Flattening a shallow list in Python

Is there a simple way to flatten a list of iterables with a list comprehension, or failing that, what would you all consider to be the best way to flatten a shallow list like this, balancing performance and readability?

I tried to flatten such a list with a nested list comprehension, like this:

[image for image in menuitem for menuitem in list_of_menuitems]

But that raises a NameError, because the name 'menuitem' is not defined. After googling and looking around on Stack Overflow, I got the desired results with a reduce statement:

reduce(list.__add__, map(lambda x: list(x), list_of_menuitems))

But this method is fairly unreadable, because I need that list(x) call there: x is a Django QuerySet object, not a list.

Conclusion:

Thanks to everyone who contributed to this question. Here is a summary of what I learned. I'm also making this a community wiki in case others want to add to or correct these observations.

My original reduce statement is redundant and is better written this way:

>>> reduce(list.__add__, (list(mi) for mi in list_of_menuitems))

This is the correct syntax for a nested list comprehension (Brilliant summary dF!):

>>> [image for mi in list_of_menuitems for image in mi]

But neither of these methods is as efficient as using itertools.chain:

>>> from itertools import chain
>>> list(chain(*list_of_menuitems))

And as @cdleary notes, it's probably better style to avoid * operator magic by using chain.from_iterable like so:

>>> from itertools import chain
>>> flattened = chain.from_iterable([[1,2],[3],[5,89],[],[6]])
>>> print(list(flattened))
[1, 2, 3, 5, 89, 6]

Solution 1:

If you're just looking to iterate over a flattened version of the data structure and don't need an indexable sequence, consider itertools.chain and company.

>>> list_of_menuitems = [['image00', 'image01'], ['image10'], []]
>>> import itertools
>>> chain = itertools.chain(*list_of_menuitems)
>>> print(list(chain))
['image00', 'image01', 'image10']

It will work on anything that's iterable, which should include Django's iterable QuerySets, which it appears that you're using in the question.
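
As a rough illustration of the "anything that's iterable" point, the sketch below chains a list, a tuple, a generator, and an empty list. The names `mixed` and `flat` are made up for this example, and plain Python iterables stand in for the QuerySets from the question.

```python
import itertools

# Hypothetical stand-ins for the inner iterables (QuerySets in the question).
mixed = [
    ['image00', 'image01'],          # list
    ('image10',),                    # tuple
    (name for name in ['image20']),  # generator
    [],                              # empty iterables contribute nothing
]

flat = list(itertools.chain.from_iterable(mixed))
print(flat)
```

Note that a generator can only be consumed once, so `mixed` cannot be flattened a second time without rebuilding it.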

Edit: This is probably as good as a reduce anyway, because reduce will have the same overhead copying the items into the list that's being extended. chain will only incur this (same) overhead if you run list(chain) at the end.

Meta-Edit: Actually, chain has less overhead than the question's proposed solution, because reduce creates and then throws away a temporary list each time it extends the accumulated result.

Edit: As J.F. Sebastian says, itertools.chain.from_iterable avoids the unpacking, and you should use it to avoid * magic, but the timeit app shows a negligible performance difference.
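
One practical difference between the two spellings, sketched here with a hypothetical generator `menuitems()`: `chain.from_iterable` pulls inner iterables from the outer one only as needed, while `chain(*...)` has to exhaust the outer iterable up front to build the argument tuple. The `consumed` list is only there to make that visible.

```python
import itertools

consumed = []  # records which inner lists the generator has produced so far

def menuitems():
    """Hypothetical lazy source of inner lists."""
    for inner in (['image00', 'image01'], ['image10'], []):
        consumed.append(inner)
        yield inner

lazy = itertools.chain.from_iterable(menuitems())
first = next(lazy)
# Only the first inner list has been pulled from the generator at this point.
consumed_after_first = len(consumed)
rest = list(lazy)
```

With `itertools.chain(*menuitems())`, all three inner lists would already have been produced before the first item is returned. For a small in-memory list this does not matter much.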

Solution 2:

You almost have it! The way to do nested list comprehensions is to put the for statements in the same order as they would go in regular nested for statements.

Thus, this

for inner_list in outer_list:
    for item in inner_list:
        ...

corresponds to

[... for inner_list in outer_list for item in inner_list]

So you want

[image for menuitem in list_of_menuitems for image in menuitem]
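
To make the correspondence concrete, here is a small sketch that runs the explicit loops and the comprehension on the sample data, then shows why the order from the question fails. The helper `wrong_order` and its parameter names are invented for this illustration.

```python
list_of_menuitems = [['image00', 'image01'], ['image10'], []]

# Explicit nested loops.
looped = []
for menuitem in list_of_menuitems:
    for image in menuitem:
        looped.append(image)

# Same clause order in the comprehension: outer loop first, inner loop second.
comprehended = [image for menuitem in list_of_menuitems for image in menuitem]

def wrong_order(outer):
    # Reversed clauses, as in the question: the first `for` iterates over
    # `inner`, which has not been bound yet.
    return [image for image in inner for inner in outer]

try:
    wrong_order(list_of_menuitems)
    raised = False
except NameError:
    raised = True
```

The reversed version is wrapped in a function so that a leftover loop variable at module level cannot accidentally hide the error.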

Solution 3:

@S.Lott: You inspired me to write a timeit app.

I figured it would also vary based on the number of partitions (number of iterators within the container list) -- your comment didn't mention how many partitions there were of the thirty items. This plot is flattening a thousand items in every run, with varying number of partitions. The items are evenly distributed among the partitions.

[Figure: Flattening Comparison]

Code (Python 2.6):

#!/usr/bin/env python2.6

"""Usage: %prog item_count"""

from __future__ import print_function

import collections
import itertools
import operator
from timeit import Timer
import sys

import matplotlib.pyplot as pyplot

def itertools_flatten(iter_lst):
    return list(itertools.chain(*iter_lst))

def itertools_iterable_flatten(iter_iter):
    return list(itertools.chain.from_iterable(iter_iter))

def reduce_flatten(iter_lst):
    return reduce(operator.add, map(list, iter_lst))

def reduce_lambda_flatten(iter_lst):
    return reduce(operator.add, map(lambda x: list(x), [i for i in iter_lst]))

def comprehension_flatten(iter_lst):
    return list(item for iter_ in iter_lst for item in iter_)

METHODS = ['itertools', 'itertools_iterable', 'reduce', 'reduce_lambda',
           'comprehension']

def _time_test_assert(iter_lst):
    """Make sure all methods produce an equivalent value.
    :raise AssertionError: On any non-equivalent value."""
    # Look up each '<method>_flatten' function by name; avoid shadowing the
    # built-in callable().
    funcs = (globals()[method + '_flatten'] for method in METHODS)
    results = [func(iter_lst) for func in funcs]
    if not all(result == results[0] for result in results[1:]):
        raise AssertionError

def time_test(partition_count, item_count_per_partition, test_count=10000):
    """Run flatten methods on a list of :param:`partition_count` iterables.
    Normalize results over :param:`test_count` runs.
    :return: Mapping from method to (normalized) microseconds per pass.
    """
    iter_lst = [[dict()] * item_count_per_partition] * partition_count
    print('Partition count:    ', partition_count)
    print('Items per partition:', item_count_per_partition)
    _time_test_assert(iter_lst)
    test_str = 'flatten(%r)' % iter_lst
    result_by_method = {}
    for method in METHODS:
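        # Assumes this script is saved as test.py so timeit can import it.
        setup_str = 'from test import %s_flatten as flatten' % method
        t = Timer(test_str, setup_str)
        # timeit() returns total seconds; convert to microseconds per pass.
        per_pass = 1000000 * t.timeit(number=test_count) / test_count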
        print('%20s: %.2f usec/pass' % (method, per_pass))
        result_by_method[method] = per_pass
    return result_by_method

if __name__ == '__main__':
    if len(sys.argv) != 2:
        raise ValueError('Need a number of items to flatten')
    item_count = int(sys.argv[1])
    partition_counts = []
    pass_times_by_method = collections.defaultdict(list)
    for partition_count in xrange(1, item_count):
        if item_count % partition_count != 0:
            continue
        items_per_partition = item_count / partition_count
        result_by_method = time_test(partition_count, items_per_partition)
        partition_counts.append(partition_count)
        for method, result in result_by_method.iteritems():
            pass_times_by_method[method].append(result)
    for method, pass_times in pass_times_by_method.iteritems():
        pyplot.plot(partition_counts, pass_times, label=method)
    pyplot.legend()
    pyplot.title('Flattening Comparison for %d Items' % item_count)
    pyplot.xlabel('Number of Partitions')
    pyplot.ylabel('Microseconds')
    pyplot.show()

Edit: Decided to make it community wiki.

Note: METHODS should probably be accumulated with a decorator, but I figure it'd be easier for people to read this way.
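
For what the decorator approach might look like, here is a minimal sketch written for Python 3 rather than the 2.6 script above. The registry name `FLATTENERS` and the decorator name `flattener` are hypothetical, and only two of the methods are shown.

```python
import itertools

FLATTENERS = {}  # hypothetical registry: method name -> function

def flattener(func):
    """Register func under its name minus the '_flatten' suffix."""
    name = func.__name__
    if name.endswith('_flatten'):
        name = name[:-len('_flatten')]
    FLATTENERS[name] = func
    return func  # return the function unchanged so it can still be called directly

@flattener
def itertools_iterable_flatten(iter_iter):
    return list(itertools.chain.from_iterable(iter_iter))

@flattener
def comprehension_flatten(iter_lst):
    return [item for iter_ in iter_lst for item in iter_]

sample = [[1, 2], [3], []]
results = {name: func(sample) for name, func in FLATTENERS.items()}
```

The timing loop could then iterate over `FLATTENERS` instead of looking names up through `globals()`.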

Solution 4:

sum(list_of_lists, []) would flatten it.

l = [['image00', 'image01'], ['image10'], []]
print(sum(l, []))  # prints ['image00', 'image01', 'image10']
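
Two caveats are worth sketching. Because `sum` relies on `list + list`, it fails when an inner iterable is not a list (a tuple here, standing in for something like a QuerySet), and each addition builds a new list, so the cost grows quickly with the number of inner lists. The variable names below are made up for this example.

```python
list_of_lists = [['image00', 'image01'], ['image10'], []]
flat = sum(list_of_lists, [])

# With a non-list inner iterable, list + tuple is not defined, so sum() fails.
mixed = [['image00'], ('image10',)]
try:
    sum(mixed, [])
    failed = False
except TypeError:
    failed = True

# Converting each inner iterable to a list first makes it work again.
flat_mixed = sum((list(inner) for inner in mixed), [])
```

The Python documentation for `sum` itself suggests `itertools.chain` for concatenating a series of iterables.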