How Not to Flatten a List of Lists in Python

By Mathieu Larose

January 2013

You will find lots of solutions on the Web to flatten a list of lists in Python. Most of them take linear time, which is efficient because each element is copied only once. But some solutions take more than linear time. In this article, I will explain why one of the simplest solutions, sum(lst, []), is actually one of the worst.

The inefficiency comes from how the + operator (concatenation) is defined for lists: it creates a new list and copies every element of both operands into it. The algorithm thus creates many unnecessary intermediate lists and makes many unnecessary copies.
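A quick way to see this behaviour is to check object identities: the result of + is a brand-new list, and neither operand is modified. A minimal sketch (the names a, b and c are just for illustration):

```python
a = [1, 2, 3]
b = [4]

# `+` allocates a new list of length len(a) + len(b) and copies
# the references from both operands into it.
c = a + b

assert c == [1, 2, 3, 4]
assert c is not a and c is not b    # a brand-new list object
assert a == [1, 2, 3] and b == [4]  # the operands are left untouched
```

What gets copied are references to the elements rather than the elements themselves, but each copy still costs one step per element.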

First, let’s note that:

sum(lst, [])

is equivalent to:

functools.reduce(operator.add, lst, [])

and:

[] + lst[0] + lst[1] + lst[2] + ...
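A small check that these spellings agree on a small example input (a sketch; any list of lists would do):

```python
import functools
import operator

lst = [[1, 2, 3], [4], [5, 6]]

# sum() starts from the initial value [] and applies + left to right.
via_sum = sum(lst, [])
# reduce() with operator.add performs exactly the same left fold.
via_reduce = functools.reduce(operator.add, lst, [])
# The fully expanded chain of concatenations.
via_chain_of_plus = [] + lst[0] + lst[1] + lst[2]

assert via_sum == via_reduce == via_chain_of_plus == [1, 2, 3, 4, 5, 6]
```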

As you can see, this will make lots of copies. How many? Let’s look at an example:

sum([[1,2,3], [4], [5,6]], [])
[] + [1,2,3] + [4] + [5,6]

3 copies:

[1,2,3] = [] + [1,2,3] 

4 copies:

[1,2,3,4] = [1,2,3] + [4]

6 copies:

[1,2,3,4,5,6] = [1,2,3,4] + [5,6]

It takes 13 copies (3 + 4 + 6) to flatten these three lists. Compared to the obvious solution, which is to allocate a single new list and copy every element into it once (6 copies), that is more than twice the work.
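The count can be reproduced by modelling both strategies and tallying one copy per element written into a list. This is a sketch under a simplifying cost model (one unit per element copied, allocation overhead ignored); the helper names are hypothetical.

```python
def copies_with_plus(lists):
    """Model sum(lists, []): each + builds a new list holding everything so far."""
    result = []
    copies = 0
    for sub in lists:
        result = result + sub  # new list; every element in it is copied
        copies += len(result)
    return result, copies


def copies_single_pass(lists):
    """Model the obvious approach: allocate once, copy each element once."""
    total = sum(len(sub) for sub in lists)
    result = [None] * total  # single allocation of the final size
    copies = 0
    for sub in lists:
        for x in sub:
            result[copies] = x  # the copy count doubles as the write position
            copies += 1
    return result, copies


example = [[1, 2, 3], [4], [5, 6]]
print(copies_with_plus(example))    # ([1, 2, 3, 4, 5, 6], 13)
print(copies_single_pass(example))  # ([1, 2, 3, 4, 5, 6], 6)
```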

Let’s do the math to calculate how many copies it takes for the general case.

Let \(m\) be the number of lists and \(n_i\) the number of elements in the \(i\)th list.

At the first iteration, the algorithm copies \(n_1\) elements. At the second iteration it copies \(n_1 + n_2\) elements. At the third iteration it copies \(n_1 + n_2 + n_3\) elements.

So, the number of copies at iteration \(k\) is: \[\sum_{i=1}^k n_i\]

Since there are \(m\) lists, the total number of copies is: \[\sum_{k=1}^m \sum_{i=1}^k n_i\]
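The double sum is simply the sum of the prefix sums of the list sizes, which itertools.accumulate computes directly. Swapping the order of summation also shows that each element of the \(i\)th list is copied \(m - i + 1\) times. A sketch, where total_copies and total_copies_by_list are hypothetical helpers taking the sizes \(n_1, \ldots, n_m\):

```python
from itertools import accumulate


def total_copies(sizes):
    """Sum over k of (n_1 + ... + n_k): the copies made by sum(lst, [])."""
    # accumulate([3, 1, 2]) yields the prefix sums 3, 4, 6.
    return sum(accumulate(sizes))


def total_copies_by_list(sizes):
    """Same count grouped by list: the i-th list (1-based) is copied m - i + 1 times."""
    m = len(sizes)
    return sum((m - i + 1) * n for i, n in enumerate(sizes, start=1))


sizes = [3, 1, 2]                   # sizes of [[1,2,3], [4], [5,6]]
print(total_copies(sizes))          # 13
print(total_copies_by_list(sizes))  # 13
```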

For a fixed total number of elements, the worst-case scenario (assuming no list is empty) is \(m\) lists of one element each (\(n_i = 1\) for \(i = 1, \ldots, m\)). In that case, the total number of copies is:

\[\sum_{k=1}^m \sum_{i=1}^k 1 = \sum_{k=1}^m k = \frac{m (m+1)}{2} \in O(m^2)\]
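Plugging a few values of \(m\) into this worst case shows the quadratic growth: doubling the number of one-element lists roughly quadruples the number of copies. A minimal sketch (worst_case_copies is a hypothetical helper):

```python
def worst_case_copies(m):
    """Copies made by sum(lst, []) when lst holds m one-element lists."""
    # Iteration k copies k elements, so the total is 1 + 2 + ... + m.
    return sum(range(1, m + 1))


for m in (1000, 2000, 4000):
    copies = worst_case_copies(m)
    assert copies == m * (m + 1) // 2  # matches the closed form
    print(m, copies)
# 1000 500500
# 2000 2001000
# 4000 8002000
```

For the 10,000 one-element lists used in the timing below, that comes to 50,005,000 copies, against 10,000 for a single pass.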

We can conclude that sum(lst, []) has quadratic complexity. Compare this to:

list(itertools.chain.from_iterable(lst))

which has linear complexity. Next to it, sum(lst, []) is very inefficient.
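Any approach that copies each element once into a single result list runs in (amortized) linear time. A sketch of three common ones; only the itertools.chain version comes from the comparison above, and the function names are illustrative:

```python
import itertools


def flatten_chain(lst):
    # Walk each sublist in turn; list() collects the items once.
    return list(itertools.chain.from_iterable(lst))


def flatten_comprehension(lst):
    # Nested comprehension: outer loop over sublists, inner loop over items.
    return [x for sub in lst for x in sub]


def flatten_extend(lst):
    # Grow one list in place; list.extend appends without rebuilding it.
    result = []
    for sub in lst:
        result.extend(sub)
    return result


lst = [[1, 2, 3], [4], [5, 6]]
assert flatten_chain(lst) == flatten_comprehension(lst) == flatten_extend(lst) == [1, 2, 3, 4, 5, 6]
```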

And if you are still not convinced, see for yourself:

$ python3 -mtimeit -s'lst=[[1]] * 10000' 'sum(lst, [])'
10 loops, best of 3: 237 msec per loop

$ python3 -mtimeit -s'import itertools; lst=[[1]] * 10000' 'list(itertools.chain.from_iterable(lst))'
1000 loops, best of 3: 488 usec per loop
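The same comparison can be scripted with the timeit module, for example to try other list sizes. A sketch: the sizes and repeat counts are arbitrary choices, benchmark is a hypothetical helper, and absolute timings will depend on the machine and Python version.

```python
import itertools
import timeit


def benchmark(size=10_000, number=5):
    """Time sum(lst, []) against itertools.chain on `size` one-element lists."""
    lst = [[1]] * size
    # Sanity check: both approaches must produce the same flat list.
    assert sum(lst, []) == list(itertools.chain.from_iterable(lst))

    t_sum = timeit.timeit(lambda: sum(lst, []), number=number)
    t_chain = timeit.timeit(
        lambda: list(itertools.chain.from_iterable(lst)), number=number
    )
    # Average time per call, in milliseconds.
    return t_sum / number * 1000, t_chain / number * 1000


if __name__ == "__main__":
    for size in (2_500, 5_000, 10_000):
        ms_sum, ms_chain = benchmark(size)
        print(f"{size:>6} lists: sum {ms_sum:8.2f} ms, chain {ms_chain:8.3f} ms")
```

If the analysis above holds, the sum column should grow roughly fourfold each time the size doubles, while the chain column should only about double.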