Why do I get this many iterations when adding to and removing from a set while iterating over it?
Trying to understand the Python for-loop, I thought this would give the result {1} for one iteration, or just get stuck in an infinite loop, depending on whether it does the iteration like in C or other languages. But actually it did neither.
>>> s = {0}
>>> for i in s:
...     s.add(i + 1)
...     s.remove(i)
...
>>> print(s)
{16}
Why does it do 16 iterations? Where does the result {16} come from?
This was using Python 3.8.2. On PyPy it produces the expected result {1}.
Solution 1:
Python makes no promises about when (if ever) this loop will end. Modifying a set during iteration can lead to skipped elements, repeated elements, and other weirdness. Never rely on such behavior.
Everything I am about to say is implementation details, subject to change without notice. If you write a program that relies on any of it, your program may break on any combination of Python implementation and version other than CPython 3.8.2.
The short explanation for why the loop ends at 16 is that 16 is the first element that happens to be placed at a lower hash table index than the previous element. The full explanation is below.
The internal hash table of a Python set always has a power of 2 size. For a table of size 2^n, if no collisions occur, elements are stored in the position in the hash table corresponding to the n least-significant bits of their hash. You can see this implemented in set_add_entry:
mask = so->mask;
i = (size_t)hash & mask;
entry = &so->table[i];
if (entry->key == NULL)
    goto found_unused;
Most small Python ints hash to themselves; in particular, all ints in your test hash to themselves. You can see this implemented in long_hash. Since your set never contains two elements with equal low bits in their hashes, no collision occurs.
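To make the index arithmetic concrete, here is a small Python sketch (a model, not CPython's C code). It assumes, as described above, that small non-negative ints hash to themselves and that the table size is a power of 2, and computes each element's slot as the hash masked with size - 1. The helper name `slot` is hypothetical.

```python
def slot(value, table_size):
    """Hypothetical helper: home slot of value in a collision-free table.

    table_size must be a power of 2, so mask = table_size - 1 keeps
    exactly the low bits of the hash.
    """
    mask = table_size - 1
    return hash(value) & mask


# Small non-negative ints hash to themselves in CPython.
assert all(hash(n) == n for n in range(17))

# In a table of size 8, 0..7 land in slots 0..7.
print([slot(n, 8) for n in range(8)])   # [0, 1, 2, 3, 4, 5, 6, 7]

# In a table of size 16, 15 lands in the last slot, but 16 wraps to slot 0.
print(slot(15, 16), slot(16, 16))       # 15 0
```

The last line is the crux of the puzzle: in a 16-slot table, 16 has the same low 4 bits as 0.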
A Python set iterator keeps track of its position in a set with a simple integer index into the set's internal hash table. When the next element is requested, the iterator searches for a populated entry in the hash table starting at that index, then sets its stored index to immediately after the found entry and returns the entry's element. You can see this in setiter_iternext:
while (i <= mask && (entry[i].key == NULL || entry[i].key == dummy))
    i++;
si->si_pos = i+1;
if (i > mask)
    goto fail;
si->len--;
key = entry[i].key;
Py_INCREF(key);
return key;
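A rough Python transcription of that scan, as a sketch only: the table is modeled as a list whose entries are either None (never used), a `DUMMY` sentinel (deleted), or a key. The names `DUMMY` and `iter_next` are invented for this illustration and are not CPython identifiers.

```python
DUMMY = object()  # hypothetical stand-in for CPython's dummy marker


def iter_next(table, pos):
    """Mimic the setiter_iternext scan: return (key, new_pos), or None.

    Scans forward from pos, skipping empty (None) and deleted (DUMMY)
    slots; the stored position becomes the index just past the hit.
    """
    mask = len(table) - 1
    i = pos
    while i <= mask and (table[i] is None or table[i] is DUMMY):
        i += 1
    if i > mask:
        return None           # walked off the end of the table: iteration over
    return table[i], i + 1    # the key, and the iterator's new stored index


table = [None, DUMMY, 2, None, 4, None, None, None]
print(iter_next(table, 0))   # (2, 3)
print(iter_next(table, 3))   # (4, 5)
print(iter_next(table, 5))   # None
```

Note that the iterator never looks back: anything placed at an index below its stored position is simply never visited.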
Your set initially starts with a hash table of size 8, and a pointer to a 0 int object at index 0 in the hash table. The iterator is also positioned at index 0. As you iterate, elements are added to the hash table, each at the next index because that's where their hash says to put them, and that's always the next index the iterator looks at. Removed elements have a dummy marker stored at their old position, for collision resolution purposes. You can see that implemented in set_discard_entry:
entry = set_lookkey(so, key, hash);
if (entry == NULL)
    return -1;
if (entry->key == NULL)
    return DISCARD_NOTFOUND;
old_key = entry->key;
entry->key = dummy;
entry->hash = -1;
so->used--;
Py_DECREF(old_key);
return DISCARD_FOUND;
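The same idea in Python: deleting does not empty the slot, it overwrites it with a marker, so `used` drops while `fill` (slots that are live or dummy) stays the same. This is a simplified sketch that assumes each key sits at its home slot (no collisions, no dummy reuse); `DUMMY`, `ToySet`, and its fields are hypothetical names mirroring so->used and so->fill.

```python
DUMMY = object()  # hypothetical marker for a deleted slot


class ToySet:
    """Toy hash table: no probing, ints assumed to hash to themselves."""

    def __init__(self, size=8):
        self.table = [None] * size
        self.fill = 0   # live + dummy slots, like so->fill
        self.used = 0   # live slots only, like so->used

    def add(self, key):
        i = hash(key) & (len(self.table) - 1)
        assert self.table[i] is None, "sketch handles neither collisions nor dummy reuse"
        self.table[i] = key
        self.fill += 1
        self.used += 1

    def discard(self, key):
        i = hash(key) & (len(self.table) - 1)
        if self.table[i] is not None and self.table[i] is not DUMMY and self.table[i] == key:
            self.table[i] = DUMMY  # leave a marker instead of emptying the slot
            self.used -= 1         # fill is deliberately NOT decremented


t = ToySet()
t.add(0)
t.add(1)
t.discard(0)
print(t.table[0] is DUMMY, t.fill, t.used)   # True 2 1
```

Because `fill` only grows between rebuilds, the add/remove loop steadily pushes it toward the resize threshold even though the set never holds more than two elements.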
When 4 is added to the set, the number of elements and dummies in the set becomes high enough that set_add_entry triggers a hash table rebuild, calling set_table_resize:
if ((size_t)so->fill*5 < mask*3)
    return 0;
return set_table_resize(so, so->used>50000 ? so->used*2 : so->used*4);
so->used is the number of populated, non-dummy entries in the hash table, which is 2, so set_table_resize receives 8 as its second argument. Based on this, set_table_resize decides the new hash table size should be 16:
/* Find the smallest table size > minused. */
/* XXX speed-up with intrinsics */
size_t newsize = PySet_MINSIZE;
while (newsize <= (size_t)minused) {
    newsize <<= 1; // The largest possible value is PY_SSIZE_T_MAX + 1.
}
It rebuilds the hash table with size 16. All elements still end up at their old indexes in the new hash table, since they didn't have any high bits set in their hashes.
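Both pieces of arithmetic can be checked directly. This Python sketch copies the threshold test from set_add_entry and the size loop from set_table_resize as quoted above; PySet_MINSIZE is taken to be 8, matching the starting table size described in this answer, and the function names are made up.

```python
PYSET_MINSIZE = 8  # assumed: the answer says sets start with 8 slots


def needs_resize(fill, mask):
    """Mirror of: if ((size_t)so->fill*5 < mask*3) return 0; ..."""
    return not (fill * 5 < mask * 3)


def new_table_size(used):
    """Mirror of the growth argument plus the 'smallest size > minused' loop."""
    minused = used * 2 if used > 50000 else used * 4
    newsize = PYSET_MINSIZE
    while newsize <= minused:
        newsize <<= 1
    return newsize


# Adding 4: slots 0-3 (dummies plus the live 3) and the new 4 -> fill 5, mask 7.
print(needs_resize(4, 7), needs_resize(5, 7))   # False True
# Only 3 and 4 are live, so used == 2 and minused == 8.
print(new_table_size(2))                         # 16
```

With a 16-slot table (mask 15) the threshold moves to fill 9, which is why a later rebuild happens but, with only two live elements, again picks size 16.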
As the loop continues, elements keep getting placed at the next index the iterator will look. Another hash table rebuild is triggered, but the new size is still 16.
The pattern breaks when the loop adds 16 as an element. There is no index 16 to place the new element at. The 4 lowest bits of 16 are 0000, putting 16 at index 0. The iterator's stored index is 16 at this point, and when the loop asks for the next element from the iterator, the iterator sees that it has gone past the end of the hash table.
The iterator terminates the loop at this point, leaving only 16 in the set.
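Putting the pieces together, the whole run can be replayed in a toy model. This is a simplified Python sketch, not CPython: it assumes ints hash to themselves, does no collision probing (it asserts instead), copies the resize rule and size loop quoted above, and uses invented names (`ToySet`, `DUMMY`, `simulate`). Under those assumptions it reproduces the 16 iterations and the final {16}.

```python
DUMMY = object()  # hypothetical marker for deleted slots


class ToySet:
    """Collision-free model of a CPython 3.8 set table (illustration only)."""

    def __init__(self):
        self.table = [None] * 8   # start with 8 slots
        self.fill = 0             # live + dummy slots
        self.used = 0             # live slots

    def _slot(self, key):
        return hash(key) & (len(self.table) - 1)

    def add(self, key):
        i = self._slot(key)
        assert self.table[i] is None, "sketch does not model probing"
        self.table[i] = key
        self.fill += 1
        self.used += 1
        mask = len(self.table) - 1
        if self.fill * 5 >= mask * 3:       # same test as set_add_entry
            self._resize(self.used * 4)     # used stays far below 50000 here

    def _resize(self, minused):
        newsize = 8
        while newsize <= minused:           # smallest power of 2 > minused
            newsize <<= 1
        live = [k for k in self.table if k is not None and k is not DUMMY]
        self.table = [None] * newsize       # dummies are dropped on rebuild
        self.fill = self.used = 0
        for k in live:
            self.table[self._slot(k)] = k
            self.fill += 1
            self.used += 1

    def remove(self, key):
        i = self._slot(key)
        assert self.table[i] == key
        self.table[i] = DUMMY
        self.used -= 1

    def keys(self):
        return [k for k in self.table if k is not None and k is not DUMMY]


def simulate(start):
    s = ToySet()
    s.add(start)
    pos = 0            # the iterator's stored index
    iterations = 0
    while True:
        # Like the real iterator, read whatever table the set has right now.
        mask = len(s.table) - 1
        i = pos
        while i <= mask and (s.table[i] is None or s.table[i] is DUMMY):
            i += 1
        if i > mask:   # ran past the end of the table: the loop stops
            break
        pos = i + 1
        x = s.table[i]
        iterations += 1
        s.add(x + 1)   # loop body from the question
        s.remove(x)
    return iterations, s.keys()


print(simulate(0))   # (16, [16])
```

Tracing it shows the same events as the explanation: a rebuild to 16 slots when 4 is added, a second rebuild (still 16 slots) when 11 is added, and 16 landing in slot 0 behind the iterator.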
Solution 2:
I believe this has to do with the actual implementation of sets in Python. Sets use hash tables for storing their items, so iterating over a set means iterating over the rows of its hash table.
As you iterate and add items to your set, each new item is placed in the next row of the hash table until you reach the number 16. At that point, the new number is placed at the beginning of the hash table rather than at the end. Since you have already iterated past the first row of the table, the loop ends.
My answer is based on this answer to a similar question, which shows this exact example. I really recommend reading it for more detail.
Solution 3:
From the Python 3 documentation:
Code that modifies a collection while iterating over that same collection can be tricky to get right. Instead, it is usually more straight-forward to loop over a copy of the collection or to create a new collection:
Iterate over a copy
s = {0}
s2 = s.copy()
for i in s2:
    s.add(i + 1)
    s.remove(i)
which iterates only once:
>>> print(s)
{1}
>>> print(s2)
{0}
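The documentation quote also mentions the other option, creating a new collection instead of mutating the one being iterated. A minimal sketch with a set comprehension that maps each element to i + 1; for the single-element set in the question this gives the same result as the copy-based loop above, though for sets holding adjacent values the two approaches are not equivalent.

```python
s = {0}
# Build a fresh set from the old one; s is never mutated while being iterated.
s = {i + 1 for i in s}
print(s)   # {1}
```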
Edit:
A possible reason for this many iterations is that a set is unordered, so where new elements end up relative to the loop's current position is not predictable. If you do the same thing with a list instead of a set, the loop just ends, with s = [1], because lists are ordered: the for loop starts at index 0, then moves on to the next index, finds that there isn't one, and exits.
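The list comparison can be checked with a small sketch of the same loop using a list. Mutating a list while iterating over it is also discouraged; this only illustrates the index-based iteration described above.

```python
s = [0]
for i in s:          # the list iterator walks indexes 0, 1, 2, ...
    s.append(i + 1)
    s.remove(i)      # removes the first occurrence of i, shifting later items left
print(s)             # [1]
```

After the first pass the list is [1], so index 1 no longer exists and the loop stops.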
Solution 4:
A Python set is an unordered collection, which does not record element position or order of insertion. There is no index attached to any element in a Python set, so sets do not support indexing or slicing.
So don't expect your for loop to traverse the set in a defined order.
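A quick check of the "no indexing" point, as a minimal sketch: asking a set for an element by position raises TypeError.

```python
s = {0, 1}
try:
    s[0]               # sets have no positions to index into
    indexable = True
except TypeError:      # 'set' object is not subscriptable
    indexable = False
print(indexable)       # False
```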
Why does it do 16 iterations?
user2357112 supports Monica already explains the main cause. Here is another way of thinking about it.
s = {0}
for i in s:
    s.add(i + 1)
    print(s)
    s.remove(i)
print(s)
When you run this code, it gives this output:
{0, 1}
{1, 2}
{2, 3}
{3, 4}
{4, 5}
{5, 6}
{6, 7}
{7, 8}
{8, 9}
{9, 10}
{10, 11}
{11, 12}
{12, 13}
{13, 14}
{14, 15}
{16, 15}
{16}
When we access all the elements together, as in a loop or when printing the set, there must be some order in which the whole set is traversed.
So, in the last iteration you can see that the order changes from {i, i+1} to {i+1, i}.
After that iteration, i+1 sits at a position that has already been traversed, so the loop exits.
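To see why 16 comes before 15 in traversal order, order the two values by their slot in a 16-slot table. This sketch assumes, as in the first answer, that ints hash to themselves and that there are no collisions; `slot16` is a hypothetical helper, not a Python API.

```python
def slot16(n):
    """Hypothetical helper: slot of n in a 16-slot table (no collisions)."""
    return hash(n) & 15


order = sorted([15, 16], key=slot16)
print(order)   # [16, 15]
```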
Interesting fact: starting with any value less than 16 except 6 and 7 always gives you the result {16}.