Sorting an iterator in Python

I want to iterate over a big itertools.product, but in a different order from the one that product offers. The problem is that sorting the iterator with sorted takes time. For example:

from itertools import product
import time

RNG = 15
RPT = 6

start = time.time()
a = sorted(product(range(RNG), repeat=RPT), key=sum)
print("Sorted: " + str(time.time() - start))
print(type(a))

start = time.time()
a = product(range(RNG), repeat=RPT)
print("Unsorted: " + str(time.time() - start))
print(type(a))

Creating the sorted iterator takes about twice as long. I'm guessing this is because sorted has to go through the whole iterator and return a list, whereas the second, unsorted iterator is doing some sort of lazy evaluation magic.

I guess there are really two questions here.

  1. General question: is there a lazy evaluation way to change the order items appear in an iterator?
  2. Specific question: is there a way to loop through all m-length lists of ints less than n, hitting lists with smaller sums first?

Solution 1:

If your objective is to reduce memory consumption, you could write your own generator that returns the tuples in order of their sum (see below). But if memory is not a concern, sorting the output of itertools.product() will be faster than pure-Python code that produces the same result.

A recursive function can produce the combinations of values in order of their sum by merging multiple iterators (one per starting value) and always taking the one with the smallest sum next. A must be in ascending order, as range(RNG) is:

def sumCombo(A, N):
    if N == 1:
        yield from ((n,) for n in A)        # single-item combos
        return
    pA = []                                 # one state per starting value
    for i, n in enumerate(A):
        ip = sumCombo(A[i:], N-1)           # recurse on the remaining N-1 items
        p = next(ip)                        # current (N-1)-combination
        pA.append((n + sum(p), p, n, ip))   # (sum, combination, start, iterator)
    while pA:
        # index and state of the entry with the smallest sum
        i, (s, p, n, ip) = min(enumerate(pA), key=lambda e: e[1][0])
        ps = s
        while s == ps:                      # output combinations of equal sum
            yield (n, *p)                   # starting value + recursed combination
            p = next(ip, None)              # advance the iterator
            if p is None:
                del pA[i]                   # drop exhausted iterators
                break
            s = n + sum(p)                  # sum of the new combination
            pA[i] = (s, p, n, ip)           # update the stored state

This produces only combinations (with repetition) of the values, whereas product yields every distinct permutation of those combinations: 38,760 combinations versus 11,390,625 products for RNG = 15 and RPT = 6.
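The merge step can also be handed to the standard library. The following is a minimal sketch, not the code above: it assumes Python 3.8+ (for math.comb), and sum_combo_merge is a hypothetical name chosen for illustration. It builds one sum-ordered stream per starting value and lets heapq.merge interleave them lazily.

```python
import heapq
import math
from itertools import combinations_with_replacement

def sum_combo_merge(values, n):
    """Yield n-combinations (with repetition) of values in non-decreasing sum order."""
    if n < 1:
        raise ValueError("n must be at least 1")
    values = sorted(values)              # the merge relies on ascending values
    if n == 1:
        yield from ((v,) for v in values)
        return

    def with_head(head, start):
        # Prefix every (n-1)-combination drawn from values[start:] with head.
        for tail in sum_combo_merge(values[start:], n - 1):
            yield (head, *tail)

    # Each stream is already ordered by sum, so a lazy k-way merge keeps that order.
    streams = [with_head(v, i) for i, v in enumerate(values)]
    yield from heapq.merge(*streams, key=sum)

# Small check: same combinations as itertools, but ordered by sum.
combos = list(sum_combo_merge(range(4), 3))
sums = [sum(c) for c in combos]
same_items = sorted(combos) == sorted(combinations_with_replacement(range(4), 3))
expected_count = math.comb(4 + 3 - 1, 3)  # combinations with repetition
```

The same counting formula gives math.comb(15 + 6 - 1, 6) = 38,760 for the sizes used here. Note that heapq.merge pulls the first item from every stream up front, so the recursion still creates many generators before the first tuple is yielded.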

In order to obtain all the products, you would need to run these combinations through a function that generates distinct permutations:

def permuteDistinct(A):
    if len(A) == 1:
        yield tuple(A)                  # single value
        return
    seen = set()                        # starting values already used
    for i, n in enumerate(A):
        if n in seen:                   # skip repeated starting values
            continue
        seen.add(n)
        for p in permuteDistinct(A[:i] + A[i+1:]):
            yield (n, *p)               # starting value + permuted rest

def sumProd(A, N):
    for p in sumCombo(A, N):            # combinations in order of sum
        yield from permuteDistinct(p)   # expanded to distinct permutations

So sumProd(range(RNG),RPT) will produce the 11,390,625 permutations in order of their sum without storing them in a list, but it takes about 5 times longer than sorting the product:

a = sorted(product(range(RNG), repeat=RPT), key=sum) # 4.6 sec
b = list(sumProd(range(RNG),RPT))                    # 23  sec

list(map(sum,a)) == list(map(sum,b)) # True  (same order of sums)
a == b                               # False (order differs for equal sums)

a[5:15]            b[5:15]             sum
(0, 1, 0, 0, 0, 0) (0, 1, 0, 0, 0, 0)  1
(1, 0, 0, 0, 0, 0) (1, 0, 0, 0, 0, 0)  1
(0, 0, 0, 0, 0, 2) (0, 0, 0, 0, 0, 2)  2
(0, 0, 0, 0, 1, 1) (0, 0, 0, 0, 2, 0)  2
(0, 0, 0, 0, 2, 0) (0, 0, 0, 2, 0, 0)  2
(0, 0, 0, 1, 0, 1) (0, 0, 2, 0, 0, 0)  2
(0, 0, 0, 1, 1, 0) (0, 2, 0, 0, 0, 0)  2
(0, 0, 0, 2, 0, 0) (2, 0, 0, 0, 0, 0)  2
(0, 0, 1, 0, 0, 1) (0, 0, 0, 0, 1, 1)  2
(0, 0, 1, 0, 1, 0) (0, 0, 0, 1, 0, 1)  2
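For the specific question (all m-length tuples of ints below n, smallest sums first), another lazy option is to walk the target sums in increasing order and generate only the tuples that hit each target. This is an alternative sketch, not the method above; tuples_with_sum and product_by_sum are hypothetical names. Within one sum the tuples come out in lexicographic order, which is also the order a stable sort of product by sum keeps.

```python
from itertools import product

def tuples_with_sum(n, m, total):
    """Yield every m-tuple of ints in range(n) whose entries add up to total."""
    if m == 0:
        if total == 0:
            yield ()
        return
    # The first entry must leave a remainder the other m-1 entries can reach.
    lo = max(0, total - (m - 1) * (n - 1))
    hi = min(n - 1, total)
    for first in range(lo, hi + 1):
        for rest in tuples_with_sum(n, m - 1, total - first):
            yield (first, *rest)

def product_by_sum(n, m):
    """Yield all m-tuples over range(n), smallest sums first, without sorting."""
    for total in range(m * (n - 1) + 1):
        yield from tuples_with_sum(n, m, total)

# Small check against the sort-based version.
lazy = list(product_by_sum(4, 3))
eager = sorted(product(range(4), repeat=3), key=sum)
matches = lazy == eager
```

Only one tuple is held at a time, so memory stays flat, but this is still pure Python and no timing is claimed for it relative to the figures above.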

If your process is searching for specific sums, it may be worth filtering the combinations first and expanding distinct permutations only for the combinations whose sums meet your criteria. This could cut down the number of iterations considerably: sumCombo(range(RNG),RPT) runs in about 0.22 sec, far less than the 4.6 sec needed to sort the products.
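The filtering idea can be sketched with the standard library alone. This is an illustration under assumptions, not the code above: products_with_sum is a hypothetical name, the filter is a single target sum, and set(permutations(...)) is used for brevity even though it materialises up to repeat! tuples per combination.

```python
from itertools import combinations_with_replacement, permutations, product

def products_with_sum(values, repeat, target):
    """Yield every tuple of product(values, repeat=repeat) whose sum equals target."""
    for combo in combinations_with_replacement(values, repeat):
        if sum(combo) != target:
            continue                     # reject the whole combination at once
        # Expand only the combinations that passed the filter.
        yield from sorted(set(permutations(combo)))

# Small check against filtering the full product.
filtered = sorted(products_with_sum(range(4), 3, 5))
brute = sorted(t for t in product(range(4), repeat=3) if sum(t) == 5)
agrees = filtered == brute
```

Each rejected combination skips all of its permutations in one test, which is where the saving comes from.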