How to limit the size of a comprehension?
Solution 1:
You can use a generator expression to do the filtering, then use islice()
to limit the number of iterations:
```python
from itertools import islice

filtered = (i for i in a if i == 1)  # lazy: nothing is filtered yet
b = list(islice(filtered, 3))        # stop after the first 3 matches
```
This ensures you don't do more work than necessary to produce those 3 elements.
Note that there is no longer any point in using a list comprehension here: a list comprehension can't be broken out of, so you are locked into iterating to the end.
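To see that islice() really stops early, here is a minimal sketch. The generator function `matching` is a hypothetical helper (not part of the answer above) that behaves like the generator expression but also logs every item it pulls from `a`:

```python
from itertools import islice

def matching(values, target, seen):
    """Hypothetical helper: yield items equal to target, logging each item inspected."""
    for value in values:
        seen.append(value)  # record every item pulled from the source
        if value == target:
            yield value

a = [1, 2, 1, 4, 1, 1, 1, 1]
seen = []
b = list(islice(matching(a, 1, seen), 3))
print(b)          # [1, 1, 1]
print(len(seen))  # 5: the scan stopped at the third match, leaving 3 items unread
```

By contrast, `[i for i in a if i == 1][:3]` always scans all 8 items before slicing.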
Solution 2:
@Martijn Pieters is completely right that itertools.islice is the best way to solve this. However, if you don't mind an additional (external) library, you can use iteration_utilities, which wraps a lot of these itertools and their applications (and some additional ones). It could make this a bit easier, at least if you like functional programming:
```python
>>> from iteration_utilities import Iterable
>>> Iterable([1, 2, 1, 2, 1, 2]).filter((1).__eq__)[:2].as_list()
[1, 1]
>>> (Iterable([1, 2, 1, 2, 1, 2])
...  .filter((1).__eq__)   # like "if item == 1"
...  [:2]                  # like "islice(iterable, 2)"
...  .as_list())           # like "list(iterable)"
[1, 1]
```
The iteration_utilities.Iterable class uses generators internally, so it is lazy: items are only processed once you call one of the as_* (or get_*) methods, and even then only as many as necessary.
Disclaimer: I'm the author of the iteration_utilities library.
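To illustrate why wrapping generators keeps such a chain lazy, here is a minimal stdlib-only sketch. `LazyChain` is a hypothetical class written for this example; it is not iteration_utilities.Iterable and does not reproduce its implementation, it only imitates the `.filter(...)[:n].as_list()` call style using the built-in filter() and itertools.islice():

```python
from itertools import count, islice

class LazyChain:
    """Hypothetical wrapper imitating the chained style; NOT iteration_utilities."""

    def __init__(self, iterable):
        self._it = iter(iterable)

    def filter(self, pred):
        # Calls the built-in filter(), which is lazy: nothing is evaluated yet.
        return LazyChain(filter(pred, self._it))

    def __getitem__(self, s):
        # Only slices are supported in this sketch; islice() keeps them lazy.
        if not isinstance(s, slice):
            raise TypeError("only slices are supported in this sketch")
        return LazyChain(islice(self._it, s.start, s.stop, s.step))

    def as_list(self):
        # The underlying iterator is consumed only here.
        return list(self._it)

print(LazyChain([1, 2, 1, 2, 1, 2]).filter((1).__eq__)[:2].as_list())  # [1, 1]
# Laziness is what makes an infinite source safe to use:
print(LazyChain(count()).filter(lambda x: x % 2 == 0)[:3].as_list())  # [0, 2, 4]
```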
Solution 3:
You could use itertools.count to generate a counter and itertools.takewhile to stop iterating over a generator when the counter reaches the desired integer (3 in this case):
```python
from itertools import count, takewhile

c = count()
b = list(takewhile(lambda x: next(c) < 3, (i for i in a if i == 1)))
```
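One subtlety worth knowing: takewhile() only stops once its predicate fails, so it pulls one more item from the generator than it returns, and that item is discarded. The sketch below (variable names are illustrative) compares how much of the same generator is left over after takewhile() and after islice():

```python
from itertools import count, islice, takewhile

a = [1, 2, 1, 4, 1, 1, 1, 1]  # contains six 1s

# takewhile + count: the predicate must fail once before iteration stops
c = count()
matches = (i for i in a if i == 1)
b = list(takewhile(lambda x: next(c) < 3, matches))
left_after_takewhile = list(matches)  # the 4th match was consumed and dropped

# islice: stops right after the 3rd item without pulling a 4th
matches = (i for i in a if i == 1)
b2 = list(islice(matches, 3))
left_after_islice = list(matches)

print(b, len(left_after_takewhile))  # [1, 1, 1] 2
print(b2, len(left_after_islice))    # [1, 1, 1] 3
```

This only matters when the generator is reused afterwards or producing each item is expensive, but it is another reason to prefer islice().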
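Or, along similar lines, you can build a construct that raises StopIteration to terminate the generator. That is the closest you'll get to your original idea of breaking out of the list comprehension, but I would not recommend it as best practice. It also only works on Python 3.6 and earlier: since PEP 479 (enforced from Python 3.7 on), a StopIteration raised inside a generator expression is converted into a RuntimeError instead of ending the iteration: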
```python
c = count()
b = list(i if next(c) < 3 else next(iter([])) for i in a if i == 1)
```
Examples:
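```python
>>> from itertools import count, takewhile
>>> a = [1, 2, 1, 4, 1, 1, 1, 1]
>>> c = count()
>>> list(takewhile(lambda x: next(c) < 3, (i for i in a if i == 1)))
[1, 1, 1]
>>> c = count()
>>> # Output from Python 3.6 and earlier; Python 3.7+ raises RuntimeError here
>>> list(i if next(c) < 3 else next(iter([])) for i in a if i == 1)
[1, 1, 1]
```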
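Because the StopIteration trick breaks on current Python versions, here is a minimal sketch that first checks the PEP 479 behavior and then shows the plain-loop way of "breaking out": an explicit for loop with break. `first_matching` is a hypothetical helper name, and the islice() approach from Solution 1 remains the more idiomatic one-liner:

```python
from itertools import count

a = [1, 2, 1, 4, 1, 1, 1, 1]

# Python 3.7+ (PEP 479): a StopIteration escaping a generator expression
# is re-raised as RuntimeError instead of quietly ending the iteration.
c = count()
try:
    list(i if next(c) < 3 else next(iter([])) for i in a if i == 1)
    raised_runtime_error = False
except RuntimeError:
    raised_runtime_error = True
print(raised_runtime_error)  # True on Python 3.7 and later

def first_matching(iterable, n, target):
    """Hypothetical helper: return the first n items of iterable equal to target."""
    found = []
    if n <= 0:
        return found
    for item in iterable:
        if item == target:
            found.append(item)
            if len(found) == n:
                break  # unlike a comprehension, an explicit loop can stop early
    return found

print(first_matching(a, 3, 1))  # [1, 1, 1]
```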