How not to miss the next element after itertools.takewhile()
Say we want to process an iterator in chunks.
The logic for each chunk depends on previously calculated chunks, so groupby() does not help.
Our friend in this case is itertools.takewhile():
import itertools
while True:
    # getNewChunkLogic() returns the predicate for the next chunk
    chunk = itertools.takewhile(getNewChunkLogic(), myIterator)
    process(chunk)
The problem is that takewhile() has to read one element past the last one that satisfies the chunk logic, thereby 'eating' the first element of the next chunk.
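To make the problem concrete, here is a minimal illustration on a plain iterator (the data and the threshold are just examples). The element that ends the first chunk never reappears in the rest of the input:

```python
from itertools import takewhile

it = iter(range(10))
first_chunk = list(takewhile(lambda i: i < 5, it))
rest = list(it)

print(first_chunk)  # [0, 1, 2, 3, 4]
print(rest)         # [6, 7, 8, 9]: takewhile() consumed 5 to test it, and it is lost
```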
There are various workarounds, such as wrapping the iterator or pushing the element back à la C's ungetc().
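For comparison, one such workaround sketched under assumptions: a generator that builds each chunk as a list and keeps the rejected element in a one-item buffer, re-attached in front of the input with itertools.chain (a pushback in the spirit of ungetc()). The names iter_chunks, make_predicate and doubling are hypothetical, not from any library; make_predicate receives the chunks produced so far because the chunk logic depends on them.

```python
import itertools


def iter_chunks(iterable, make_predicate):
    """Yield consecutive chunks of iterable as lists.

    make_predicate(previous_chunks) is a hypothetical callback that builds the
    predicate for the next chunk from the chunks yielded so far.
    """
    it = iter(iterable)
    previous = []
    pending = []  # holds at most one element: the one that ended the last chunk
    while True:
        predicate = make_predicate(previous)
        chunk = []
        # Build a fresh chain each round so the wrappers never nest
        source = itertools.chain(pending, it)
        pending = []
        for item in source:
            if predicate(item):
                chunk.append(item)
            else:
                pending = [item]  # "unget" the rejected element
                break
        if not chunk:
            if pending:
                raise ValueError("predicate rejected the first element of a chunk")
            return  # input exhausted
        previous.append(chunk)
        yield chunk


def doubling(previous):
    """Example chunk logic: chunk n (0-based) accepts values below 2 ** (n + 1)."""
    limit = 2 ** (len(previous) + 1)
    return lambda x: x < limit


print(list(iter_chunks(range(10), doubling)))
# [[0, 1], [2, 3], [4, 5, 6, 7], [8, 9]]
```

The trade-off is that chunks are materialized as lists rather than consumed lazily as with takewhile().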
My question is: is there an elegant solution?
Solution 1:
takewhile() does indeed need to look at the next element to decide when to stop.
You could use a wrapper that tracks the last element it handed out and can be 'reset' to back up one element:
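_sentinel = object()  # unique marker meaning "no value"

class OneStepBuffered(object):
    """Iterator wrapper that can push back the most recently returned element."""

    def __init__(self, it):
        self._it = iter(it)
        self._last = _sentinel  # last element handed out, eligible for step_back()
        self._next = _sentinel  # element pushed back by step_back(), if any

    def __iter__(self):
        return self

    def __next__(self):
        if self._next is not _sentinel:
            # Replay the pushed-back element and remember it as the last
            # value, so it can be backed up again if the next chunk rejects it
            next_val, self._next = self._next, _sentinel
            self._last = next_val
            return next_val
        try:
            self._last = next(self._it)
            return self._last
        except StopIteration:
            self._last = self._next = _sentinel
            raise

    next = __next__  # Python 2 compatibility

    def step_back(self):
        if self._last is _sentinel:
            raise ValueError("Can't back up a step")
        self._next, self._last = self._last, _sentinel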
Wrap your iterator in this one before using it with takewhile():
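myIterator = OneStepBuffered(myIterator)
while True:
    chunk = itertools.takewhile(getNewChunkLogic(), myIterator)
    process(chunk)  # process() must consume the whole chunk
    try:
        myIterator.step_back()  # push back the element that ended the chunk
    except ValueError:
        break  # nothing to push back: the input is exhausted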
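A self-contained sketch of that loop that runs as-is. It repeats a condensed copy of the wrapper, with one assumption: __next__ records a replayed element as the last value, so step_back() still works when a chunk rejects its very first element. The chunk logic doubling_logic and the driver split_into_chunks are hypothetical stand-ins for getNewChunkLogic() and process().

```python
import itertools

_sentinel = object()


class OneStepBuffered(object):
    """Condensed copy of the wrapper; replayed elements can be backed up again."""

    def __init__(self, it):
        self._it = iter(it)
        self._last = _sentinel
        self._next = _sentinel

    def __iter__(self):
        return self

    def __next__(self):
        if self._next is not _sentinel:
            self._last, self._next = self._next, _sentinel
            return self._last
        try:
            self._last = next(self._it)
            return self._last
        except StopIteration:
            self._last = self._next = _sentinel
            raise

    next = __next__  # Python 2 compatibility

    def step_back(self):
        if self._last is _sentinel:
            raise ValueError("Can't back up a step")
        self._next, self._last = self._last, _sentinel


def doubling_logic(chunks):
    """Hypothetical chunk logic: chunk n (0-based) accepts values below 2 ** (n + 1)."""
    limit = 2 ** (len(chunks) + 1)
    return lambda x: x < limit


def split_into_chunks(iterable, get_new_chunk_logic):
    it = OneStepBuffered(iterable)
    chunks = []
    while True:
        # list() consumes the whole chunk, including the element that ends it
        chunk = list(itertools.takewhile(get_new_chunk_logic(chunks), it))
        chunks.append(chunk)
        try:
            it.step_back()  # push back the element takewhile() rejected
        except ValueError:
            # Nothing to push back: the input ran out inside this chunk
            return chunks


print(split_into_chunks(range(10), doubling_logic))
# [[0, 1], [2, 3], [4, 5, 6, 7], [8, 9]]
```

Because each chunk ends either on a rejected element (which step_back() re-arms) or on exhaustion (which resets the wrapper), the ValueError from step_back() doubles as the loop's exit signal.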
Demo:
>>> from itertools import takewhile
>>> test_list = range(10)
>>> iterator = OneStepBuffered(test_list)
>>> list(takewhile(lambda i: i < 5, iterator))
[0, 1, 2, 3, 4]
>>> iterator.step_back()
>>> list(iterator)
[5, 6, 7, 8, 9]