Mathieu Larose

How Not to Flatten a List of Lists in Python

January 2013

You will find lots of solutions on the Web to flatten a list of lists in Python. Most of them take linear time, which is efficient because each element is copied once. But some solutions take more than linear time. In this article, I will explain why one of the simplest solutions, sum(lst, []), is actually one of the worst.

The inefficiency comes from how the + operator (concatenation) is defined on lists: it creates a new list and copies every element of both operands into it. Flattening with sum(lst, []) therefore creates many unnecessary intermediate lists and makes many unnecessary copies.
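A quick illustration of that behavior of +: the result is a brand-new list object, and both operands are left untouched. (Strictly speaking, what gets copied are references to the elements, not the elements themselves, which is why each copy costs constant time.)

```python
a = [1, 2, 3]
b = [4]

c = a + b       # + allocates a new list and copies the elements of a and b into it
print(c)        # [1, 2, 3, 4]
print(c is a)   # False: a new list object was created
print(a)        # [1, 2, 3]: the left operand is unchanged
```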

First, let's note that:

sum(lst, [])

is equivalent to:

[] + lst[0] + lst[1] + lst[2] + ...

and:

list(functools.reduce(operator.add, lst, []))
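A small sanity check of these equivalences on the example list used below. All three spellings evaluate left to right, starting from the empty list, and produce the same flattened result:

```python
import functools
import operator

lst = [[1, 2, 3], [4], [5, 6]]

via_sum = sum(lst, [])                                # start value [] is added first
via_plus = [] + lst[0] + lst[1] + lst[2]              # the same chain of + written out
via_reduce = functools.reduce(operator.add, lst, [])  # the same fold, spelled with reduce

assert via_sum == via_plus == via_reduce == [1, 2, 3, 4, 5, 6]
```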

As you can see, this will make lots of copies. How many? Let's look at an example:

sum([[1,2,3], [4], [5,6]], [])

which is equivalent to:

[] + [1,2,3] + [4] + [5,6]

First, it makes 3 copies:

[1,2,3] = [] + [1,2,3]

Then, 4 copies:

[1,2,3,4] = [1,2,3] + [4]

And finally, 6 copies:

[1,2,3,4,5,6] = [1,2,3,4] + [5,6]

It takes 13 copies (3 + 4 + 6) to flatten these three lists. Compared to the obvious solution, which is to allocate a single new list and copy everything into it (6 copies), this is very slow.
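The count can be reproduced mechanically. The sketch below is a hypothetical helper, flatten_with_sum_counting, that performs the same left-to-right concatenations as sum(lst, []) and adds up the length of each new intermediate list, since every element of that list was copied into it:

```python
def flatten_with_sum_counting(lists):
    """Flatten like sum(lists, []) while counting the element copies made."""
    result = []
    copies = 0
    for sub in lists:
        result = result + sub   # builds a brand-new list each time
        copies += len(result)   # every element of the new list was copied
    return result, copies

flat, copies = flatten_with_sum_counting([[1, 2, 3], [4], [5, 6]])
print(flat)    # [1, 2, 3, 4, 5, 6]
print(copies)  # 13
```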

Let's do the math to calculate how many copies it takes for the general case.

Let $m$ be the number of lists and $n_i$ the number of elements in the $i$th list.

At the first iteration, the algorithm copies $n_1$ elements. At the second iteration, it copies $n_1 + n_2$ elements. At the third iteration, it copies $n_1 + n_2 + n_3$ elements.

So, the number of copies at iteration $k$ is:

$$\sum_{i=1}^{k} n_i$$

Since there are $m$ lists, the total number of copies is:

$$\sum_{k=1}^{m} \sum_{i=1}^{k} n_i$$
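Swapping the order of summation shows where the cost goes: the elements of list $i$ are copied again at every iteration $k \ge i$, that is, $m - i + 1$ times.

$$\sum_{k=1}^{m} \sum_{i=1}^{k} n_i = \sum_{i=1}^{m} \sum_{k=i}^{m} n_i = \sum_{i=1}^{m} (m - i + 1)\, n_i$$

For the example above ($m = 3$, $n_1 = 3$, $n_2 = 1$, $n_3 = 2$), this gives $3 \cdot 3 + 2 \cdot 1 + 1 \cdot 2 = 13$, matching the count by hand. The first list is the most expensive: it is copied $m$ times.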

The worst-case scenario, for a given total number of elements, is $m$ lists of one element each ($n_i = 1$ for $i = 1, \dots, m$). The number of copies in that case is:

$$\sum_{k=1}^{m} \sum_{i=1}^{k} 1 = \sum_{k=1}^{m} k = \frac{m(m+1)}{2} \in O(m^2)$$
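The formula is easy to check numerically. The double sum is just the sum of the prefix sums of the list lengths, which itertools.accumulate computes directly. The helper name sum_copies is hypothetical and only models the copy count; it does not flatten anything:

```python
from itertools import accumulate

def sum_copies(lengths):
    """Element copies made by sum(lst, []) for sublists of the given lengths.

    Iteration k copies n_1 + ... + n_k elements (the k-th prefix sum),
    so the total is the sum of all prefix sums.
    """
    return sum(accumulate(lengths))

print(sum_copies([3, 1, 2]))  # 13, the worked example above

for m in (10, 100, 1000, 10000):
    copies = sum_copies([1] * m)       # m lists of one element each
    assert copies == m * (m + 1) // 2  # matches the closed form
    print(m, copies, copies / m)       # average copies per element
```

The last column is $(m+1)/2$: on average each element is copied a number of times that grows with $m$, whereas a linear-time flatten copies each element once.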

We can conclude that sum(lst, []) has quadratic complexity. Compare that to:

list(itertools.chain.from_iterable(lst))

which has linear complexity because each element is copied exactly once. So sum(lst, []) is very inefficient.
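For reference, a sketch of itertools.chain.from_iterable next to two other common linear-time idioms (a nested list comprehension and repeated list.extend on a single result list). The latter two are standard Python, not taken from this article; all three copy each element into the result once:

```python
import itertools

lst = [[1, 2, 3], [4], [5, 6]]

# chain.from_iterable walks the sublists lazily; list() copies each element once.
flat_chain = list(itertools.chain.from_iterable(lst))

# A nested comprehension: outer loop over sublists, inner loop over their elements.
flat_comp = [x for sub in lst for x in sub]

# extend() appends in place to one list (amortized O(1) per element)
# instead of building a new list at every step like +.
flat_extend = []
for sub in lst:
    flat_extend.extend(sub)

assert flat_chain == flat_comp == flat_extend == [1, 2, 3, 4, 5, 6]
```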

And if you are still not convinced, here is a benchmark:

$ python3 -mtimeit -s'lst=[[1]] * 10000' 'sum(lst, [])'
10 loops, best of 3: 237 msec per loop

$ python3 -mtimeit -s'lst=[[1]] * 10000' 'list(itertools.chain.from_iterable(lst))'
1000 loops, best of 3: 488 usec per loop
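A minimal sketch of a similar comparison from inside Python, using timeit.repeat with callables instead of statement strings. It reports the best of three single runs, in the spirit of the command line's "best of 3". The helper best_time is hypothetical, and absolute timings depend on the machine and Python version, so none are quoted here. If the analysis above holds, doubling m should roughly quadruple the sum(lst, []) time, while the chain time should roughly double.

```python
import itertools
import timeit

def best_time(fn, repeat=3):
    """Best wall-clock time (seconds) of a single call over several runs."""
    return min(timeit.repeat(fn, number=1, repeat=repeat))

for m in (2_500, 5_000, 10_000):
    lst = [[1]] * m  # m one-element lists, as in the command-line benchmark
    t_sum = best_time(lambda: sum(lst, []))
    t_chain = best_time(lambda: list(itertools.chain.from_iterable(lst)))
    print(f"m={m:>6}:  sum {t_sum * 1e3:8.2f} ms   chain {t_chain * 1e6:8.1f} us")
```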
