Custom cache with iterator does not work as intended
Your bug is due to these lines:

```python
case item if item >= len(self.cache):
    while item - len(self.cache) >= 0:
        self.cache_next()
```
Basically, `CachedTuple((1, 2, 3))[50]` will loop indefinitely: 50 is larger than the length of the cache, and `self.cache_next()` won't generate any new values once the three items of the tuple have been cached, so the loop condition never becomes false.
A simple change adding a `self.finished` check will work:

```python
case item if item >= len(self.cache):
    while item - len(self.cache) >= 0 and not self.finished:
        self.cache_next()
```
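With that guard, once the underlying iterable is exhausted (and `self.finished` is set) the loop stops, and the lookup falls through to the normal cache access, which should then raise `IndexError` for an out-of-range index instead of hanging.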
I do, however, believe you have numerous other issues with the code, and I think you can improve it tremendously:

- Drop the `match` statement. It does nothing.
- Implement iteration using `__iter__` instead of relying on the old iteration mechanism of `__getitem__` (see the sketch after this list).
- Inherit from `collections.abc.Sequence` and adhere to the `Sequence` protocol.
- Drop the dataclass. This is not a dataclass. You seem to enjoy the delightful new language features, but unfortunately none of them are relevant here, and they are causing your code to be longer, less clear, and not work as intended.

Remember, simple readable code is infinitely more important than using new language features.
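To illustrate the `__iter__` point: when a class defines no `__iter__`, `iter()` falls back to the legacy protocol and calls `__getitem__(0)`, `__getitem__(1)`, ... until `IndexError` is raised. A minimal sketch of that fallback (the class name is made up for illustration):

```python
class LegacyIteration:
    """Iterable only through the legacy __getitem__ protocol."""

    def __getitem__(self, index: int) -> int:
        if index >= 3:
            raise IndexError(index)
        return index * 10


# iter() keeps calling __getitem__ with 0, 1, 2, ... until IndexError.
print(list(LegacyIteration()))  # [0, 10, 20]
```

Relying on that fallback works, but it is easy to get wrong (as your infinite loop shows) and it gives you none of the other `Sequence` behaviour for free.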
I took the liberty of spending a few hours creating example code that complies with `collections.abc.Sequence`. Enjoy!
```python
from collections.abc import Sequence
import itertools
from typing import Iterable, Iterator, Optional, TypeVar, overload

_T_co = TypeVar("_T_co", covariant=True)


class CachedIterable(Sequence[_T_co]):
    def __init__(self, iterable: Iterable[_T_co], *,
                 max_length: Optional[int] = None) -> None:
        self._cache: list[_T_co] = []
        if max_length is not None:
            if max_length <= 0:
                raise ValueError('max_length must be > 0')
            iterable = itertools.islice(iterable, max_length)
        else:
            try:
                # Attempt to optimize and get a length.
                max_length = len(iterable)  # type: ignore
            except TypeError:
                max_length = None
        self._max_length = max_length
        self._iterator: Optional[Iterator[_T_co]] = iter(iterable)
    def __repr__(self) -> str:
        return (f'<{self.__class__.__name__} {self._cache!r}'
                f'{"+" if self._iterator else ""}>')

    def _exhaust_iterator(self) -> None:
        """Fully exhaust the iterator."""
        assert self._iterator
        try:
            self._cache.extend(self._iterator)
        finally:
            self._iterator = None

    def _advance_iterator(self, n: int) -> None:
        """Attempt to advance the iterator by n steps.

        May advance by less than n steps if the iterator is exhausted.
        """
        assert self._iterator
        pre_advance_length = len(self._cache)
        try:
            self._cache.extend(itertools.islice(self._iterator, n))
        except Exception:
            # Iterator threw an exception.
            self._iterator = None
            raise
        # If the iterator is exhausted, clear it.
        if pre_advance_length + n > len(self._cache):
            self._iterator = None

    def _grow_cache(self, size: int) -> None:
        """Attempt to grow the cache to at least size items.

        May grow to less than size if the iterator is exhausted.
        """
        if size <= len(self._cache):
            return
        if self._max_length and size >= self._max_length:
            self._exhaust_iterator()
            return
        self._advance_iterator(size - len(self._cache))

    @overload
    def __getitem__(self, i: int) -> _T_co: ...
    @overload
    def __getitem__(self, s: slice) -> Sequence[_T_co]: ...

    def __getitem__(self, index):
        if not isinstance(index, (slice, int)):
            raise TypeError(f'index must be int or slice, not {index!r}')
        if not self._iterator:
            return self._cache[index]
        if isinstance(index, slice):
            # Stop might be less than start if step is negative.
            max_index = max(index.stop or 0, index.start or 0)
            # If the slice is unbounded or counts from the end,
            # exhaust the iterator.
            if (index.stop is None or index.stop < 0 or
                    index.start is not None and index.start < 0):
                self._exhaust_iterator()
            else:
                self._grow_cache(max_index + 1)
            return self._cache[index]
        # Asking for a number beyond the limit.
        if self._max_length and index > self._max_length:
            raise IndexError(f'index {index} out of range')
        # If we're counting from the end, exhaust the iterator.
        if index < 0:
            self._exhaust_iterator()
        else:
            self._grow_cache(index + 1)
        return self._cache[index]
    def __iter__(self) -> Iterator[_T_co]:
        if not self._iterator:
            yield from self._cache
            return
        yield from self._cache
        while True:
            try:
                item = next(self._iterator)
            except StopIteration:
                self._iterator = None
                return
            except BaseException:
                # Iterator threw an exception.
                self._iterator = None
                raise
            self._cache.append(item)
            # Yield outside the try block to avoid capturing GeneratorExit
            # and other gen.throw() exceptions.
            yield item
    def __len__(self) -> int:
        # TODO: Can optimize for known lengths.
        if not self._iterator:
            return len(self._cache)
        self._exhaust_iterator()
        return len(self._cache)
```
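A quick sanity check of how the class behaves in practice (using `itertools.count()` as a convenient infinite iterator):

```python
import itertools

cached = CachedIterable(itertools.count(), max_length=10)
print(cached[3])     # 3 -- only the first few values are pulled into the cache
print(cached[1:4])   # [1, 2, 3]
print(len(cached))   # 10 -- forces the remaining (sliced) values to be cached
print(list(cached))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -- served from the cache
```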