By Mathieu Larose
January 13, 2013
You will find lots of solutions on the Web to flatten a list of lists in Python. Most of them take linear time, which is efficient because each element is copied once. But some solutions take more than linear time. In this article, I will explain why one of the simplest solutions, sum(lst, []), is actually one of the worst.
The inefficiency comes from how the + operator (concatenation) is defined on a list: it creates a new list and copies each element into it. The algorithm thus creates many unnecessary intermediate lists and makes many unnecessary copies.
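A minimal sketch (my own illustration, not from the original article) of that behavior: concatenation always allocates a fresh list, leaving both operands untouched.

```python
a = [1, 2]
b = [3]

# + allocates a brand-new list and copies all elements of a and b into it.
c = a + b
assert c == [1, 2, 3]
assert c is not a and c is not b

# Mutating a afterwards does not affect c: the elements were copied.
a.append(99)
assert c == [1, 2, 3]
```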
First, let's note that:

sum(lst, [])

is equivalent to:

functools.reduce(operator.add, lst, [])

which expands to:

[] + lst[0] + lst[1] + lst[2] + ...
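As a quick sanity check (a sketch I'm adding, not part of the original article), the equivalence can be verified directly:

```python
import functools
import operator

lst = [[1, 2, 3], [4], [5, 6]]

# sum() folds the lists together with + starting from the empty list,
# exactly as reduce(operator.add, ...) does.
assert sum(lst, []) == functools.reduce(operator.add, lst, []) == [1, 2, 3, 4, 5, 6]
```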
As you can see, this will make lots of copies. How many? Let's look at an example:
sum([[1,2,3], [4], [5,6]], [])  =  [] + [1,2,3] + [4] + [5,6]

[1,2,3] = [] + [1,2,3]             (3 copies)
[1,2,3,4] = [1,2,3] + [4]          (4 copies)
[1,2,3,4,5,6] = [1,2,3,4] + [5,6]  (6 copies)
It takes 13 copies to flatten these three lists. Compared to the obvious solution, which is simply to allocate a new list and copy everything into it (6 copies), this is very slow.
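To see where the 13 comes from, here is a small sketch (my own instrumentation, not from the article) that tallies the copies made by each concatenation:

```python
lists = [[1, 2, 3], [4], [5, 6]]

copies = 0
result = []
for sub in lists:
    # result + sub allocates a new list and copies every element
    # of both operands into it.
    copies += len(result) + len(sub)
    result = result + sub

assert result == [1, 2, 3, 4, 5, 6]
assert copies == 13  # versus the 6 copies an in-place approach needs
```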
Let's do the math to calculate how many copies it takes for the general case.
Let n be the number of lists and m_i the number of elements in the i-th list.

At the first iteration, the algorithm copies m_1 elements. At the second iteration it copies m_1 + m_2 elements. At the third iteration it copies m_1 + m_2 + m_3 elements.

So, the number of copies at iteration i is:

m_1 + m_2 + ... + m_i

Since there are n lists, the total number of copies is:

(m_1) + (m_1 + m_2) + ... + (m_1 + m_2 + ... + m_n)

The worst-case scenario is n lists of one element each (m_i = 1). So an upper bound for the worst-case scenario is:

1 + 2 + ... + n = n(n+1)/2
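The formula above can be checked with a short sketch (the helper name sum_flatten_copies is mine, introduced for illustration):

```python
def sum_flatten_copies(sizes):
    """Total copies made by sum(lst, []) given the length of each sublist."""
    total = prefix = 0
    for m in sizes:
        prefix += m   # copies at iteration i: m_1 + m_2 + ... + m_i
        total += prefix
    return total

# The example from the article: sublists of sizes 3, 1 and 2 cost 13 copies.
assert sum_flatten_copies([3, 1, 2]) == 13

# Worst case: n one-element lists -> 1 + 2 + ... + n = n(n+1)/2 copies.
n = 10000
assert sum_flatten_copies([1] * n) == n * (n + 1) // 2
```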
We can conclude that sum(lst, []) has quadratic complexity. Compared to:

list(itertools.chain.from_iterable(lst))

which has linear complexity, sum(lst, []) is very inefficient.
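For reference, here is a sketch of two linear-time ways to flatten (the nested comprehension is my addition, not from the article):

```python
import itertools

lst = [[1, 2, 3], [4], [5, 6]]

# Each element is copied exactly once into the output list.
flat = list(itertools.chain.from_iterable(lst))
assert flat == [1, 2, 3, 4, 5, 6]

# A nested list comprehension is also linear.
assert [x for sub in lst for x in sub] == [1, 2, 3, 4, 5, 6]
```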
And if you are still not convinced, see for yourself:
$ python3 -mtimeit -s'lst=[[1]] * 10000' 'sum(lst, [])'
10 loops, best of 3: 237 msec per loop
$ python3 -mtimeit -s'import itertools; lst=[[1]] * 10000' 'list(itertools.chain.from_iterable(lst))'
1000 loops, best of 3: 488 usec per loop