How not to miss the next element after itertools.takewhile()

Say we want to process an iterator in chunks.
The logic per chunk depends on previously-calculated chunks, so groupby() does not help.

Our friend in this case is itertools.takewhile():

import itertools

while True:
    # getNewChunkLogic() returns a fresh predicate for each chunk
    chunk = itertools.takewhile(getNewChunkLogic(), myIterator)
    process(chunk)

The problem is that takewhile() has to consume the first element that fails the chunk logic in order to know the chunk has ended, thus 'eating' the first element of the next chunk.
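A minimal sketch of the effect, assuming a plain range as input (the variable names are only illustrative):

```python
from itertools import takewhile

it = iter(range(10))

# takewhile() stops at the first element that fails the predicate,
# but it has already pulled that element (5) out of the iterator.
first = list(takewhile(lambda i: i < 5, it))
rest = list(it)

print(first)  # [0, 1, 2, 3, 4]
print(rest)   # [6, 7, 8, 9]
```

The 5 appears in neither list: it was read to decide that the chunk was over, and then discarded.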

There are various workarounds, such as wrapping the iterator or pushing the element back à la C's ungetc().
My question is: is there an elegant solution?


Solution 1:

takewhile() does indeed need to consume the next element to know when to stop.

You could use a wrapper that tracks the last seen element, and that can be 'reset' to back up one element:

_sentinel = object()  # marks "no value stored"

class OneStepBuffered(object):
    """Iterator wrapper that can back up by one element."""
    def __init__(self, it):
        self._it = iter(it)
        self._last = _sentinel  # most recently returned element
        self._next = _sentinel  # pushed-back element, returned first
    def __iter__(self):
        return self
    def __next__(self):
        if self._next is not _sentinel:
            # Hand out the pushed-back element and remember it as the
            # last one seen, so it can be pushed back again if needed.
            self._last, self._next = self._next, _sentinel
            return self._last
        try:
            self._last = next(self._it)
            return self._last
        except StopIteration:
            self._last = self._next = _sentinel
            raise
    next = __next__  # Python 2 compatibility
    def step_back(self):
        if self._last is _sentinel:
            raise ValueError("Can't back up a step")
        self._next, self._last = self._last, _sentinel

Wrap your iterator in this one before using it with takewhile():

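myIterator = OneStepBuffered(myIterator)
while True:
    chunk = itertools.takewhile(getNewChunkLogic(), myIterator)
    process(chunk)  # must consume the whole chunk before stepping back
    try:
        myIterator.step_back()
    except ValueError:
        # Nothing to restore, e.g. the iterator is exhausted.
        break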

Demo:

>>> from itertools import takewhile
>>> test_list = range(10)
>>> iterator = OneStepBuffered(test_list)
>>> list(takewhile(lambda i: i < 5, iterator))
[0, 1, 2, 3, 4]
>>> iterator.step_back()
>>> list(iterator)
[5, 6, 7, 8, 9]
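For comparison, the ungetc()-style idea mentioned in the question can be sketched without a wrapper class. This is a different technique from the one above, and only a rough sketch: take_chunk is a hypothetical helper name, and unlike takewhile() it builds each chunk eagerly as a list.

```python
from itertools import chain

def take_chunk(predicate, iterator):
    """Collect leading items that satisfy predicate.

    Returns (chunk, rest). The first rejected item is not lost: it is
    chained back onto the front of the returned iterator.
    """
    chunk = []
    for item in iterator:
        if not predicate(item):
            # Push the rejected item back, ungetc-style.
            return chunk, chain([item], iterator)
        chunk.append(item)
    # Input exhausted: there is nothing to push back.
    return chunk, iterator

rest = iter(range(10))
first, rest = take_chunk(lambda i: i < 5, rest)
second, rest = take_chunk(lambda i: i < 8, rest)
remaining = list(rest)

print(first)      # [0, 1, 2, 3, 4]
print(second)     # [5, 6, 7]
print(remaining)  # [8, 9]
```

Each push-back wraps the iterator in another chain() layer, so for a very large number of chunks the class-based wrapper is probably the better fit.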