How to make a flat list out of a list of lists?
Is there a shortcut to make a simple list out of a list of lists in Python?
I can do it in a for loop, but is there some cool "one-liner"?
I tried it with functools.reduce():
from functools import reduce
l = [[1, 2, 3], [4, 5, 6], [7], [8, 9]]
reduce(lambda x, y: x.extend(y), l)
But I get this error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 1, in <lambda>
AttributeError: 'NoneType' object has no attribute 'extend'
Solution 1:
Given a list of lists t, use a nested list comprehension:
flat_list = [item for sublist in t for item in sublist]
which is equivalent to:
flat_list = []
for sublist in t:
    for item in sublist:
        flat_list.append(item)
The comprehension is faster than the shortcuts posted so far.
Here is the corresponding function:
def flatten(t):
    return [item for sublist in t for item in sublist]
As evidence, you can use the timeit module in the standard library:
$ python -mtimeit -s't=[[1,2,3],[4,5,6], [7], [8,9]]*99' '[item for sublist in t for item in sublist]'
10000 loops, best of 3: 143 usec per loop
$ python -mtimeit -s't=[[1,2,3],[4,5,6], [7], [8,9]]*99' 'sum(t, [])'
1000 loops, best of 3: 969 usec per loop
$ python -mtimeit -s't=[[1,2,3],[4,5,6], [7], [8,9]]*99' 'reduce(lambda x,y: x+y,t)'
1000 loops, best of 3: 1.1 msec per loop
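The same comparison can be run from inside Python instead of the shell. This is only a rough sketch: the `candidates` and `timings` names are made up for illustration, and absolute numbers will differ by machine and Python version.

```python
import timeit
from functools import reduce

t = [[1, 2, 3], [4, 5, 6], [7], [8, 9]] * 99

# Hypothetical mapping of label -> zero-argument callable to time
candidates = {
    "comprehension": lambda: [item for sublist in t for item in sublist],
    "sum": lambda: sum(t, []),
    "reduce +": lambda: reduce(lambda x, y: x + y, t),
}

timings = {}
for name, func in candidates.items():
    # repeat() returns one total time per run; take the minimum as the least noisy
    best = min(timeit.repeat(func, number=100, repeat=3))
    timings[name] = best / 100  # seconds per single call
    print(f"{name}: {timings[name] * 1e6:.1f} usec per loop")
```

Note that in Python 3 reduce has to be imported from functools, which the shell command above does not do.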
Explanation: the shortcuts based on + (including the implied use in sum) are, of necessity, O(T**2) when there are T sublists: as the intermediate result list keeps getting longer, at each step a new intermediate result list object gets allocated, and all the items in the previous intermediate result must be copied over (as well as a few new ones added at the end). So, for simplicity and without actual loss of generality, say you have T sublists of k items each: the first k items are copied T-1 times, the second k items T-2 times, and so on; the total number of copies is k times the sum of x for x from 1 to T excluded, i.e., k * T * (T-1) / 2, or roughly k * (T**2)/2.
The list comprehension just generates one list, once, and copies each item over (from its original place of residence to the result list) also exactly once.
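To make the difference concrete, here is a small sketch that counts element writes rather than measuring time. Both helper names are hypothetical. It counts every element written into each newly allocated list (including the freshly appended ones), so for sum(t, []) the total comes out as k * T * (T+1) / 2, which is the same quadratic order as the estimate above.

```python
def copies_with_concat(t):
    """Count elements written when flattening with repeated +, like sum(t, [])."""
    copies = 0
    acc = []
    for sublist in t:
        acc = acc + sublist   # allocates a new list of len(acc) + len(sublist) items
        copies += len(acc)    # every element of the new list was written once
    return copies


def copies_with_comprehension(t):
    """The comprehension writes each item into the result exactly once."""
    return sum(len(sublist) for sublist in t)


for T in (10, 100, 1000):
    t = [[0] * 5 for _ in range(T)]
    print(T, copies_with_concat(t), copies_with_comprehension(t))
```

Multiplying T by 10 multiplies the first count by about 100, but the second only by 10.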
Solution 2:
You can use itertools.chain():
import itertools
list2d = [[1,2,3], [4,5,6], [7], [8,9]]
merged = list(itertools.chain(*list2d))
Or you can use itertools.chain.from_iterable(), which doesn't require unpacking the list with the * operator:
merged = list(itertools.chain.from_iterable(list2d))
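One practical consequence, sketched below: chain.from_iterable() returns a lazy iterator, so the outer iterable can itself be a generator and you only pay for the items you actually consume. The sublists() generator is a made-up stand-in for any lazily produced source of lists.

```python
import itertools


def sublists():
    # Hypothetical generator standing in for any lazily produced source of lists
    for start in (1, 4, 7):
        yield [start, start + 1, start + 2]


flat_iter = itertools.chain.from_iterable(sublists())

# Pull only the first four items; the remaining sublists are not touched yet
first_four = list(itertools.islice(flat_iter, 4))

# The iterator resumes where it stopped
rest = list(flat_iter)

print(first_four, rest)
```

With chain(*sublists()), the * unpacking would have to exhaust the generator before chain() is even called.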
Solution 3:
Note from the author: This is inefficient. But fun, because monoids are awesome. It's not appropriate for production Python code.
>>> l = [[1, 2, 3], [4, 5, 6], [7], [8, 9]]
>>> sum(l, [])
[1, 2, 3, 4, 5, 6, 7, 8, 9]
This sums the elements of the iterable passed as the first argument, treating the second argument as the initial value of the sum (if not given, 0 is used instead, and in this case that raises an error).
Because you are summing nested lists, you actually get [] + [1,3] + [2,4] as the result of sum([[1,3],[2,4]],[]), which is equal to [1,3,2,4].
Note that this only works on lists of lists. For lists of lists of lists, you'll need another solution.
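For deeper nesting, one possible approach is a small recursive helper. This is a sketch only, not something taken from the answers here: flatten_deep is a hypothetical name, and it assumes that only list instances should be descended into (tuples, strings and other iterables are kept as single items).

```python
def flatten_deep(nested):
    """Recursively flatten lists nested to any depth (hypothetical helper)."""
    flat = []
    for item in nested:
        if isinstance(item, list):
            # Descend into nested lists and splice their items in
            flat.extend(flatten_deep(item))
        else:
            flat.append(item)
    return flat


print(flatten_deep([[1, [2, 3]], [[4], 5], 6]))  # [1, 2, 3, 4, 5, 6]
```

Because it recurses once per nesting level, extremely deep structures could hit Python's recursion limit.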
Solution 4:
I tested most suggested solutions with perfplot (a pet project of mine, essentially a wrapper around timeit), and found
import functools
import operator
functools.reduce(operator.iconcat, a, [])
to be the fastest solution, both when many small lists and few long lists are concatenated. (operator.iadd is equally fast.)
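A word of caution on the [] initializer, shown in the sketch below: operator.iconcat extends its first argument in place, so if you leave the initializer out, reduce() starts from the first sublist and mutates your input. The variable names are just for illustration.

```python
import functools
import operator

a = [[1, 2, 3], [4, 5, 6], [7], [8, 9]]

# With the [] initializer, the fresh empty list is the one that gets extended
flat = functools.reduce(operator.iconcat, a, [])
print(flat)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(a)     # input unchanged

b = [[1, 2, 3], [4, 5, 6], [7], [8, 9]]

# Without it, the accumulator is b[0] itself, which is extended in place
flat_b = functools.reduce(operator.iconcat, b)
print(flat_b is b[0])  # True
print(b[0])            # the first sublist now holds all nine items
```

The same applies to operator.iadd.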
A simpler and also acceptable variant is
out = []
for sublist in a:
    out.extend(sublist)
If the number of sublists is large, this performs a little worse than the above suggestion.
Code to reproduce the plot:
import functools
import itertools
import operator
import numpy as np
import perfplot
def forfor(a):
    return [item for sublist in a for item in sublist]

def sum_brackets(a):
    return sum(a, [])

def functools_reduce(a):
    return functools.reduce(operator.concat, a)

def functools_reduce_iconcat(a):
    return functools.reduce(operator.iconcat, a, [])

def itertools_chain(a):
    return list(itertools.chain.from_iterable(a))

def numpy_flat(a):
    return list(np.array(a).flat)

def numpy_concatenate(a):
    return list(np.concatenate(a))

def extend(a):
    out = []
    for sublist in a:
        out.extend(sublist)
    return out

b = perfplot.bench(
    setup=lambda n: [list(range(10))] * n,
    # setup=lambda n: [list(range(n))] * 10,
    kernels=[
        forfor,
        sum_brackets,
        functools_reduce,
        functools_reduce_iconcat,
        itertools_chain,
        numpy_flat,
        numpy_concatenate,
        extend,
    ],
    n_range=[2 ** k for k in range(16)],
    xlabel="num lists (of length 10)",
    # xlabel="len lists (10 lists total)"
)
b.save("out.png")
b.show()
Solution 5:
>>> from functools import reduce
>>> l = [[1,2,3], [4,5,6], [7], [8,9]]
>>> reduce(lambda x, y: x+y, l)
[1, 2, 3, 4, 5, 6, 7, 8, 9]
The extend() method in your example modifies x in place and returns None instead of a useful value (which functools.reduce() expects).
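If you really want to keep extend() inside reduce(), the lambda has to hand the accumulator back itself. The sketch below shows one way to do that; the `or x` trick is my own illustration and not particularly idiomatic, so treat it as a demonstration of the problem rather than a recommendation.

```python
from functools import reduce

l = [[1, 2, 3], [4, 5, 6], [7], [8, 9]]

# list.extend mutates in place and returns None, which is what broke the original
print([1].extend([2]))  # None

# `x.extend(y) or x` evaluates to x, because extend() returns None (falsy).
# The [] initializer keeps l[0] from being mutated.
flat = reduce(lambda x, y: x.extend(y) or x, l, [])
print(flat)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```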
A faster way to do the reduce version would be:
>>> import operator
>>> l = [[1,2,3], [4,5,6], [7], [8,9]]
>>> reduce(operator.concat, l)
[1, 2, 3, 4, 5, 6, 7, 8, 9]