How Not to Flatten a List of Lists in Python
January 2013
You will find lots of solutions on the Web to flatten a list of lists
in Python. Most of them take linear time, which is efficient because
each element is copied once. But some solutions take more than linear
time. In this article, I will explain why one of the simplest solutions,
sum(lst, [])
is actually one of the worst.
The inefficiency comes from how the + operator (concatenation) is defined on a list: it creates a new list and copies each element of both operands into it. The algorithm thus creates many unnecessary intermediate lists and makes many unnecessary copies.
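You can see this in the interpreter by checking object identity. + never reuses either operand, so every concatenation allocates a fresh list and copies both sides into it. A minimal sketch:

```python
a = [1, 2, 3]
b = [4]

# `+` allocates a brand-new list and copies the elements of both operands into it.
c = a + b

print(c)        # [1, 2, 3, 4]
print(c is a)   # False: the result is a new object, not `a` grown in place
print(a)        # [1, 2, 3]: the left operand is left untouched
```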
First, let's note that:
sum(lst, [])
is equivalent to:
[] + lst[0] + lst[1] + lst[2] + ...
and:
functools.reduce(operator.add, lst, [])
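The three forms can be checked against each other. In the sketch below, left_fold_concat is a hypothetical helper that spells out the left-to-right chain of + operations. It is not a library function.

```python
import functools
import operator

def left_fold_concat(lists):
    """Hypothetical helper: [] + lists[0] + lists[1] + ..., evaluated left to right."""
    result = []
    for sub in lists:
        result = result + sub  # each step builds a new list, just like sum() does
    return result

lst = [[1, 2, 3], [4], [5, 6]]

via_sum = sum(lst, [])
via_reduce = functools.reduce(operator.add, lst, [])
via_loop = left_fold_concat(lst)

print(via_sum == via_reduce == via_loop)  # True
print(via_sum)                            # [1, 2, 3, 4, 5, 6]
```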
As you can see, this will make lots of copies. How many? Let's look at an example:
sum([[1,2,3], [4], [5,6]], [])
which is equivalent to:
[] + [1,2,3] + [4] + [5,6]
First, it makes 3 copies:
[1,2,3] = [] + [1,2,3]
Then, 4 copies:
[1,2,3,4] = [1,2,3] + [4]
And finally, 6 copies:
[1,2,3,4,5,6] = [1,2,3,4] + [5,6]
It takes 13 copies to flatten these three lists. The obvious solution, which is to allocate a new list and copy every element into it once, takes only 6 copies. So sum() does more than twice the work even on this tiny input.
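The count can be reproduced by simulating the concatenations. This sketch does not instrument CPython. It assumes, as the reasoning above does, that x + y copies len(x) + len(y) elements, and simply tallies those copies. count_sum_copies and count_single_pass_copies are hypothetical names used only for illustration.

```python
def count_sum_copies(lists):
    """Copies made by [] + lists[0] + lists[1] + ..., assuming x + y copies len(x) + len(y) elements."""
    copies = 0
    length = 0  # length of the running intermediate result
    for sub in lists:
        length += len(sub)
        copies += length  # the new intermediate list holds `length` elements, all copied
    return copies

def count_single_pass_copies(lists):
    """Copies made by allocating one list and copying every element into it once."""
    return sum(len(sub) for sub in lists)

lst = [[1, 2, 3], [4], [5, 6]]
print(count_sum_copies(lst))          # 13 (3 + 4 + 6)
print(count_single_pass_copies(lst))  # 6
```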
Let's do the math to calculate how many copies it takes for the general case.
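Let $n$ be the number of lists and $l_i$ the number of elements in the $i$th list.
At the first iteration, the algorithm copies $l_1$ elements. At the second iteration, it copies $l_1 + l_2$ elements. At the third iteration, it copies $l_1 + l_2 + l_3$ elements.
So, the number of copies at iteration $i$ is:
$$\sum_{j=1}^{i} l_j$$
Since there are $n$ lists, the total number of copies is:
$$\sum_{i=1}^{n} \sum_{j=1}^{i} l_j$$
For a given total number of elements, and assuming no list is empty, the worst-case scenario is lists of one element ($l_i = 1$ for every $i$, so $n$ is also the total number of elements). So an upper bound for the worst-case scenario is:
$$\sum_{i=1}^{n} \sum_{j=1}^{i} 1 = \sum_{i=1}^{n} i = \frac{n(n+1)}{2} = O(n^2)$$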
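The count can be sanity-checked numerically. Swapping the order of summation turns the double sum over iterations into a single sum over lists: the elements of list $j$ are copied once at every iteration from $j$ to $n$, so the total is $\sum_{j=1}^{n} (n - j + 1)\, l_j$. With $l_j = 1$ this becomes $n(n+1)/2$. The sketch below checks both facts. The function names are hypothetical.

```python
import random

def copies_double_sum(lengths):
    """Sum over iterations i of (l_1 + ... + l_i), straight from the definition."""
    n = len(lengths)
    return sum(sum(lengths[:i]) for i in range(1, n + 1))

def copies_swapped(lengths):
    """Same quantity after swapping the sums: list j is copied (n - j + 1) times."""
    n = len(lengths)
    return sum((n - j + 1) * l for j, l in enumerate(lengths, start=1))

# The worked example: lists of lengths 3, 1 and 2 need 13 copies.
assert copies_double_sum([3, 1, 2]) == 13

# The two forms agree on random inputs.
for _ in range(100):
    lengths = [random.randint(0, 10) for _ in range(random.randint(0, 20))]
    assert copies_double_sum(lengths) == copies_swapped(lengths)

# Worst case: n one-element lists give n(n + 1) / 2 copies, i.e. quadratic in n.
for n in (1, 10, 10_000):
    assert copies_swapped([1] * n) == n * (n + 1) // 2
```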
We can conclude that sum(lst, []) has quadratic complexity in the worst case. By comparison,
list(itertools.chain.from_iterable(lst))
has linear complexity. So, sum(lst, []) is very inefficient.
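For reference, here are three common linear-time ways to flatten one level of nesting. Each appends every element to a single result list once, so the total work is linear (list growth is amortized constant time per append). The function names are illustrative only.

```python
import itertools

def flatten_chain(lst):
    # Iterate lazily over every sublist in turn; list() collects the items once.
    return list(itertools.chain.from_iterable(lst))

def flatten_comprehension(lst):
    # Nested comprehension: outer loop over sublists, inner loop over items.
    return [item for sub in lst for item in sub]

def flatten_extend(lst):
    # Grow one result list in place; extend() never rebuilds the existing prefix.
    result = []
    for sub in lst:
        result.extend(sub)
    return result

lst = [[1, 2, 3], [4], [5, 6]]
for flatten in (flatten_chain, flatten_comprehension, flatten_extend):
    print(flatten(lst))  # [1, 2, 3, 4, 5, 6], three times
```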
And if you are still not convinced, here is a benchmark:
$ python3 -mtimeit -s'lst=[[1]] * 10000' 'sum(lst, [])'
10 loops, best of 3: 237 msec per loop
$ python3 -mtimeit -s'lst=[[1]] * 10000' 'list(itertools.chain.from_iterable(lst))'
1000 loops, best of 3: 488 usec per loop
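The same comparison can also be run from a script with the standard timeit module. This is a sketch: the list size and repeat counts are arbitrary choices, absolute timings depend on the machine and Python version, and the gap should widen as size grows because one side is quadratic.

```python
import itertools
import timeit

def benchmark(size=10_000, number=3, repeat=3):
    """Best time per call, in seconds, for each approach on `size` one-element lists."""
    lst = [[1]] * size  # same input as the command-line benchmark
    candidates = {
        "sum(lst, [])": lambda: sum(lst, []),
        "chain.from_iterable": lambda: list(itertools.chain.from_iterable(lst)),
    }
    results = {}
    for name, func in candidates.items():
        # timeit.repeat returns one total per run; keep the best and divide by `number`.
        best = min(timeit.repeat(func, number=number, repeat=repeat))
        results[name] = best / number
    return results

if __name__ == "__main__":
    for name, seconds in benchmark().items():
        print(f"{name:>22}: {seconds * 1e3:.3f} ms per call")
```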