Speeding up the loops or different ideas for counting primitive triples

import time
from math import gcd, sqrt

def pythag_triples(n):
    i = 0
    start = time.time()
    for x in range(1, int(sqrt(n) + sqrt(n)) + 1, 2):
        for m in range(x+2,int(sqrt(n) + sqrt(n)) + 1, 2):
            if gcd(x, m) == 1:
                # q = x*m
                # l = (m**2 - x**2)/2
                c = (m**2 + x**2)/2
                # trips.append((q,l,c))
                if c < n:
                    i += 1
    end = time.time()
    return i, end-start
print(pythag_triples(3141592653589793))

I'm trying to count primitive Pythagorean triples using the fact that every primitive triple is generated by a pair of numbers that are both odd and coprime (x and m in the code above). I already know that the function works for limits up to 1000000, but for the larger number it has been running for more than 24 hours. Any ideas on how to speed this up, or how to avoid brute force altogether? I only need the count of triples, not the triples themselves.
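For illustration, the commented-out lines in the function correspond to the construction below. This is only a sketch, and the helper name triple_from is made up here.

```python
from math import gcd

def triple_from(x, m):
    """Build the triple for odd, coprime x < m (hypothetical helper name)."""
    assert x % 2 == 1 and m % 2 == 1 and x < m and gcd(x, m) == 1
    a = x * m
    b = (m * m - x * x) // 2  # exact: the difference of two odd squares is even
    c = (m * m + x * x) // 2  # exact: the sum of two odd squares is even
    return a, b, c

for x, m in [(1, 3), (1, 5), (3, 5)]:
    a, b, c = triple_from(x, m)
    # the last column checks a**2 + b**2 == c**2
    print((x, m), (a, b, c), a * a + b * b == c * c)
```

The hypotenuse c is the only value the counting loop needs, which is why the other two lines are commented out.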


Solution 1:

Instead of the double loop over x and m and repeatedly checking if they are co-prime, we iterate only over m (the larger of the two), and apply either Euler's totient function or a custom version of it to directly count the number of x values that are relatively prime to m. This gives us a much faster method (the speed remains to be quantified more precisely): for example 43ms for n = 100_000_000 instead of 30s with the OP's code (700x speedup).

The need for a custom version arises when the maximum value xmax that x is allowed to take is smaller than m (to satisfy the inequality (m**2 + x**2)/2 <= n). In that case, not all co-primes of m should be counted but only those up to that bound.
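To make the bound concrete, here is a brute-force reference that applies the cap on x explicitly. It is a sketch with a made-up name (brute_force_counts), far too slow for large n and only meant for checking small cases. For n = 300 and m = 19, for instance, x can go no further than isqrt(600 - 361) = 15, although m - 1 = 18.

```python
from math import gcd, isqrt

def brute_force_counts(n):
    """Map each odd m to the number of odd x < m, coprime to m,
    with (m*m + x*x) // 2 <= n. Assumes n >= 1."""
    counts = {}
    for m in range(3, isqrt(2 * n - 1) + 1, 2):
        # x is capped both by m - 1 and by the hypotenuse bound
        xmax = min(m - 1, isqrt(2 * n - m * m))
        cnt = sum(1 for x in range(1, xmax + 1, 2) if gcd(x, m) == 1)
        if cnt:
            counts[m] = cnt
    return counts

counts = brute_force_counts(300)
print(counts[19], sum(counts.values()))
```

The faster code in this answer replaces the inner sum with a closed-form count.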

def distinct_factors(n):
    # a variant of the well-known factorization, but that
    # yields only distinct factors, rather than all of them
    # (including possible repeats)
    last = None
    i = 2
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            if i != last:
                yield i
                last = i
    if n > 1 and n != last:
        yield n

def products_of(p_list, upto):
    for i, p in enumerate(p_list):
        if p > upto:
            break
        yield -p
        for q in products_of(p_list[i+1:], upto=upto // p):
            yield -p * q

def phi(n, upto=None):
    # Euler's totient or "phi" function
    if upto is not None and upto < n:
        # custom version: all co-primes of n up to the `upto` bound
        cnt = upto
        p_list = list(distinct_factors(n))
        for q in products_of(p_list, upto):
            cnt += upto // q if q > 0 else -(upto // -q)
        return cnt
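    # standard formulation: all co-primes of n up to n-1
    cnt = n
    for p in distinct_factors(n):
        # integer arithmetic (exact, since p divides cnt) avoids
        # floating-point rounding errors for large n
        cnt = cnt * (p - 1) // p
    return cnt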

phi(n) is Euler's totient or ϕ(n) function.

phi(n, upto=x) is a custom variant that counts only the co-primes up to a given value x. To understand it, let's work with an example:

>>> n = 3*3*3*5  # 135
>>> list(factors(n))  # factors(): the usual full factorization (not shown here)
[3, 3, 3, 5]

>>> list(distinct_factors(n))
[3, 5]

# there are 72 integers between 1 and 135 that are co-primes of 135
>>> phi(n)
72

# ...but only 53 of them are no greater than 100:
# 100 - (100//3 + 100//5 - 100//(3*5)) 
>>> phi(n, upto=100)
53

When counting the co-primes of n up to a value x, we take all the numbers 1 .. x and remove those that are multiples of any of the distinct prime factors of n. However, simply subtracting x // p_i for every p_i removes the multiples of two factors twice, so we need to "add those back". Doing so, in turn, adds back too many times the numbers that are multiples of three factors, so we need to account for those as well, and so on (this is the inclusion-exclusion principle). In the example n = 135, we subtract x // 3 and x // 5, but that removes twice the integers that are multiples of both 3 and 5 (multiples of 15), so we add x // 15 back. For a longer set of factors, we need to:

  • take x as initial count;
  • subtract the number of multiples of each factor p;
  • "un-subtract" (add) the number of multiples of any product of 2 factors;
  • "un-un-subtract" (subtract) the number of multiples of any product of 3 factors;
  • etc.
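The steps above can be sketched directly with itertools.combinations. This is a minimal illustration rather than the optimized code of this answer, and the names distinct_prime_factors and coprimes_upto are made up for the example.

```python
from itertools import combinations
from math import gcd, prod

def distinct_prime_factors(n):
    """Distinct prime factors of n by trial division (illustrative helper)."""
    found = []
    p = 2
    while p * p <= n:
        if n % p == 0:
            found.append(p)
            while n % p == 0:
                n //= p
        p += 1
    if n > 1:
        found.append(n)
    return found

def coprimes_upto(n, x):
    """Count the integers in 1..x that are coprime to n."""
    p_list = distinct_prime_factors(n)
    total = x
    for size in range(1, len(p_list) + 1):
        sign = -1 if size % 2 else 1  # subtract odd-sized subsets, add even-sized ones
        for subset in combinations(p_list, size):
            total += sign * (x // prod(subset))
    return total

brute = sum(1 for k in range(1, 101) if gcd(k, 135) == 1)
print(coprimes_upto(135, 100), brute)  # both are 53
```

This visits every subset of the distinct factors, including those whose product already exceeds x and therefore contributes 0.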

The initial answer was doing this by iterating over all combinations of distinct factors, but this is substantially optimized in this answer by the products_of(p_list, upto) generator, which gives the products of all subsets of the given p_list distinct factors whose product is no greater than upto. The sign indicates how to account for each product: positively or negatively depending on whether the subset size is even or odd, respectively.
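To see the signs and the pruning at work, here is products_of() again (copied so the snippet runs on its own), followed by a small helper with a made-up name, count_with_products, that applies the same accumulation as phi(n, upto). Note that the early break assumes p_list is sorted in increasing order, which is what distinct_factors() produces.

```python
def products_of(p_list, upto):
    # same generator as in the answer, repeated for self-containment
    for i, p in enumerate(p_list):
        if p > upto:
            break
        yield -p
        for q in products_of(p_list[i+1:], upto=upto // p):
            yield -p * q

def count_with_products(p_list, upto):
    """Count integers in 1..upto with no prime factor in p_list."""
    cnt = upto
    for q in products_of(p_list, upto):
        cnt += upto // q if q > 0 else -(upto // -q)
    return cnt

print(list(products_of([3, 5], 100)))  # [-3, 15, -5]
print(list(products_of([3, 5], 10)))   # [-3, -5]: 15 > 10 is never generated
print(count_with_products([3, 5], 100))  # 53
```

With a small bound, most subsets are cut off before they are ever formed, which is where the savings over plain combinations come from.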

With phi(n) and phi(n, upto) in hand, we can now write the following:

from math import isqrt

def pyth_m_counts(n):
    # yield tuples (m, count(x)) where 0 < x < m and odd(x)
    # and odd(m) and coprime(x, m) and m**2 + x**2 <= 2*n.
    mmax = isqrt(2*n - 1)
    for m in range(3, mmax + 1, 2):
        # requirement: (m**2 + x**2) // 2 <= n
        # and both m and x are odd
        # (so (m**2 + x**2) // 2 == (m**2 + x**2) / 2)
        xmax = isqrt(2*n - m**2)
        cnt_m = phi(2*m, upto=xmax) if xmax < m else phi(2*m) // 2
        if cnt_m > 0:
            yield m, cnt_m

Why the expression phi(2*m) // 2? Since x and m must both be odd, according to the OP, we need to remove all the even values of x. We can do that without modifying phi(), by passing 2*m (which has 2 as a factor, and thus "kills" all even values of x) and then dividing by 2 to obtain the number of odd co-primes of m that are smaller than m. A similar (but slightly more subtle) consideration applies to phi(2*m, upto=xmax); we'll leave it as an exercise for the reader...
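The halving can be checked by brute force for small odd m. The sketch below uses a naive totient (not the phi() above) and made-up helper names. The reason it works: the odd co-primes of m in 1 .. 2m pair up as x and 2m - x, and exactly one member of each pair is below m.

```python
from math import gcd

def naive_totient(n):
    """Euler's totient by direct counting (only sensible for small n)."""
    return sum(1 for k in range(1, n + 1) if gcd(k, n) == 1)

def odd_coprimes_below(m):
    """Number of odd x with 0 < x < m and gcd(x, m) == 1."""
    return sum(1 for x in range(1, m, 2) if gcd(x, m) == 1)

mismatches = [m for m in range(3, 200, 2)
              if odd_coprimes_below(m) != naive_totient(2 * m) // 2]
print(mismatches)  # expected to be empty
```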

Sample run:

>>> n = 300
>>> list(pyth_m_counts(n))
[(3, 1),
 (5, 2),
 (7, 3),
 (9, 3),
 (11, 5),
 (13, 6),
 (15, 4),
 (17, 8),
 (19, 8),
 (21, 3),
 (23, 4)]

That means that, in the OP's function, pythag_triples(300) would have counted 1 tuple with m==3, 2 tuples with m==5, etc. In fact, let's modify that function to verify this:

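from collections import Counter
from math import gcd, sqrt

def mod_pythag_triples(n):
    for x in range(1, int(sqrt(n) + sqrt(n)) + 1, 2):
        for m in range(x+2, int(sqrt(n) + sqrt(n)) + 1, 2):
            if gcd(x, m) == 1:
                c = (m**2 + x**2) // 2
                # <= (the OP uses <) so that a hypotenuse equal to n
                # is counted, as in pyth_m_counts
                if c <= n:
                    yield x, m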

Then:

>>> n = 300
>>> list(pyth_m_counts(n)) == list(Counter(m for x, m in mod_pythag_triples(n)).items())
True

Same for any positive value of n.

Now on the actual count function: we just need to sum up the counts for each m:

def pyth_triples_count(n):
    cnt = 0
    mmax = isqrt(2*n - 1)
    for m in range(3, mmax + 1, 2):
        # requirement: (m**2 + x**2) // 2 <= n
        # and both m and x are odd (so (m**2 + x**2) // 2 == (m**2 + x**2) / 2)
        xmax = isqrt(2*n - m**2)
        cnt += phi(2*m, upto=xmax) if xmax < m else phi(2*m) // 2
    return cnt

Sample runs:

>>> pyth_triples_count(1_000_000)
159139

>>> pyth_triples_count(100_000_000)
15915492

>>> pyth_triples_count(1_000_000_000)
159154994

>>> big_n = 3_141_592_653_589_793
>>> pyth_triples_count(big_n)
500000000002845
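As a sanity check on these numbers (not part of the algorithm): a classical asymptotic result says that the number of primitive triples with hypotenuse at most N is approximately N / (2π). Since big_n is about π·10^15, the expected count is close to 5·10^14, which is what we see. The dictionary below simply copies the values printed above.

```python
from math import pi

# values copied from the sample runs above
reported = {
    1_000_000: 159_139,
    100_000_000: 15_915_492,
    1_000_000_000: 159_154_994,
    3_141_592_653_589_793: 500_000_000_002_845,
}

for n, count in reported.items():
    estimate = n / (2 * pi)
    rel_err = (count - estimate) / estimate
    print(f"{n:>22,} {count:>22,} {round(estimate):>22,} {rel_err:+.2e}")
```

A small relative error does not prove the count exact, but a large one would point to a bug.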

Speed:

%timeit pyth_triples_count(100_000_000)
42.8 ms ± 56.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit pyth_triples_count(1_000_000_000)
188 ms ± 571 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%%time
pyth_triples_count(big_n)
CPU times: user 1h 42min 33s, sys: 480 ms, total: 1h 42min 33s
Wall time: 1h 42min 33s

Note: on the same machine, the code in the OP's question takes 30s for n=100_000_000; this version is 700x faster for that n.

See also my other answer for a faster solution.

Solution 2:

This new answer brings the total time for big_n down to 4min 6s.

Profiling my initial answer revealed these facts:

  • Total time: 1h 42min 33s
  • Time spent factorizing numbers: almost 100% of the time

In contrast, generating all primes from 3 to sqrt(2*N - 1) takes only 38.5s (using Atkin's sieve).

I therefore decided to try a version where we generate all numbers m as known products of prime numbers. That is, the generator yields the number itself as well as the distinct prime factors involved. No factorization needed.

An earlier version of this code returned 500_000_000_002_841, off by 4, as @Koder noticed. The cause was the xmax bound: it must be isqrt(2*N - m**2) rather than isqrt(2*N - m**2 - 1), since we do want to include triangles whose hypotenuse is equal to N. With that correction, we now get the correct result.

The code for the primes generator is included at the end. Basically, I used Atkin's sieve, adapted (without spending much time on it) to Python. I am quite sure it could be sped up (e.g. using numpy and perhaps even numba).

To generate integers from primes (which we know we can do thanks to the Fundamental theorem of arithmetic), we just need to iterate through all the possible products prod(p_i**k_i) where p_i is the i^th prime number and k_i is any non-negative integer.

The easiest formulation is a recursive one:

def gen_ints_from_primes(p_list, upto):
    if p_list and upto >= p_list[0]:
        p, *p_list = p_list
        pk = 1
        p_tup = tuple()
        while pk <= upto:
            for q, p_distinct in gen_ints_from_primes(p_list, upto=upto // pk):
                yield pk * q, p_tup + p_distinct
            pk *= p
            p_tup = (p, )
    else:
        yield 1, tuple()
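A quick way to convince ourselves that this covers every integer exactly once, with the right distinct primes, is to compare against trial division on a small range. The generator is repeated so that the snippet runs on its own; distinct_by_trial_division is a made-up helper, and the prime list is written out by hand.

```python
def gen_ints_from_primes(p_list, upto):
    # same recursive generator as above
    if p_list and upto >= p_list[0]:
        p, *p_list = p_list
        pk = 1
        p_tup = tuple()
        while pk <= upto:
            for q, p_distinct in gen_ints_from_primes(p_list, upto=upto // pk):
                yield pk * q, p_tup + p_distinct
            pk *= p
            p_tup = (p, )
    else:
        yield 1, tuple()

def distinct_by_trial_division(n):
    """Distinct prime factors of n, as an increasing tuple."""
    found = []
    p = 2
    while p * p <= n:
        if n % p == 0:
            found.append(p)
            while n % p == 0:
                n //= p
        p += 1
    if n > 1:
        found.append(n)
    return tuple(found)

upto = 30
pairs = list(gen_ints_from_primes([2, 3, 5, 7, 11, 13, 17, 19, 23, 29], upto))
# every integer 1..30 appears exactly once (in no particular order)
print(sorted(q for q, _ in pairs) == list(range(1, upto + 1)))
# and the attached primes agree with trial division
print(all(ps == distinct_by_trial_division(q) for q, ps in pairs))
```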

Unfortunately, we quickly run into memory constraints (and recursion limit). So here is a non-recursive version which uses no extra memory aside from the list of primes themselves. Essentially, the current value of q (the integer in process of being generated) and an index in the list are all the information we need to generate the next integer. Of course, the values come unsorted, but that doesn't matter, as long as they are all covered.

def rem_p(q, p, p_distinct):
    q0 = q
    while q % p == 0:
        q //= p
    if q != q0:
        if p_distinct[-1] != p:
            raise ValueError(f'rem({q}, {p}, ...{p_distinct[-4:]}): p expected at end of p_distinct if q % p == 0')
        p_distinct = p_distinct[:-1]
    return q, p_distinct

def add_p(q, p, p_distinct):
    if len(p_distinct) == 0 or p_distinct[-1] != p:
        p_distinct += (p, )
    q *= p
    return q, p_distinct

def gen_prod_primes(p, upto=None):
    if upto is None:
        upto = p[-1]
    if upto >= p[-1]:
        p = p + [upto + 1]  # sentinel
    
    q = 1
    i = 0
    p_distinct = tuple()
    
    while True:
        while q * p[i] <= upto:
            i += 1
        while q * p[i] > upto:
            yield q, p_distinct
            if i <= 0:
                return
            q, p_distinct = rem_p(q, p[i], p_distinct)
            i -= 1
        q, p_distinct = add_p(q, p[i], p_distinct)

Example:

>>> p_list = list(primes(20))
>>> p_list
[2, 3, 5, 7, 11, 13, 17, 19]

>>> sorted(gen_prod_primes(p_list, 20))
[(1, ()),
 (2, (2,)),
 (3, (3,)),
 (4, (2,)),
 (5, (5,)),
 (6, (2, 3)),
 (7, (7,)),
 (8, (2,)),
 (9, (3,)),
 (10, (2, 5)),
 (11, (11,)),
 (12, (2, 3)),
 (13, (13,)),
 (14, (2, 7)),
 (15, (3, 5)),
 (16, (2,)),
 (17, (17,)),
 (18, (2, 3)),
 (19, (19,)),
 (20, (2, 5))]

As you can see, we don't need to factorize any number, as they conveniently come along with the distinct primes involved.

To get only odd numbers, simply remove 2 from the list of primes:

>>> sorted(gen_prod_primes(p_list[1:], 20))
[(1, ()),
 (3, (3,)),
 (5, (5,)),
 (7, (7,)),
 (9, (3,)),
 (11, (11,)),
 (13, (13,)),
 (15, (3, 5)),
 (17, (17,)),
 (19, (19,))]

In order to exploit this number-and-factors presentation, we need to slightly amend the phi() function given in the original answer, so that it accepts the list of distinct prime factors instead of computing it:

def phi(n, upto=None, p_list=None):
    # Euler's totient or "phi" function
    if upto is None or upto > n:
        upto = n
    if p_list is None:
        p_list = list(distinct_factors(n))
    if upto < n:
        # custom version: all co-primes of n up to the `upto` bound
        cnt = upto
        for q in products_of(p_list, upto):
            cnt += upto // q if q > 0 else -(upto // -q)
        return cnt
    # standard formulation: all co-primes of n up to n-1
    cnt = n
    for p in p_list:
        cnt = cnt * (p - 1) // p
    return cnt

With all this, we can now rewrite our counting functions:

def pt_count_m(N):
    # yield tuples (m, count(x) where 0 < x < m and odd(x)
    # and odd(m) and coprime(x, m) and m**2 + x**2 <= 2*N))
    # in this version, m is generated from primes, and the values
    # are iterated through unordered.
    mmax = isqrt(2*N - 1)
    p_list = list(primes(mmax))[1:]  # skip 2
    for m, p_distinct in gen_prod_primes(p_list, upto=mmax):
        if m < 3:
            continue
        # requirement: (m**2 + x**2) // 2 <= N
        # note, both m and x are odd (so (m**2 + x**2) // 2 == (m**2 + x**2) / 2)
        xmax = isqrt(2*N - m*m)
        cnt_m = phi(m+1, upto=xmax, p_list=(2,) + tuple(p_distinct))
        if cnt_m > 0:
            yield m, cnt_m

from tqdm import tqdm  # third-party; only used when progress=True

def pt_count(N, progress=False):
    mmax = isqrt(2*N - 1)
    it = pt_count_m(N)
    if progress:
        it = tqdm(it, total=(mmax - 3 + 1) // 2)
    return sum(cnt_m for m, cnt_m in it)

And now:

%timeit pt_count(100_000_000)
31.1 ms ± 38.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit pt_count(1_000_000_000)
104 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

# the speedup is still very moderate at that stage

# however:
%%time
big_n = 3_141_592_653_589_793
N = big_n
res = pt_count(N)

CPU times: user 4min 5s, sys: 662 ms, total: 4min 6s
Wall time: 4min 6s

>>> res
500000000002845

Addendum: Atkin's sieve

As promised, here is my version of Atkin's sieve. It can definitely be sped up.

def primes(limit):
    # Generates prime numbers between 2 and limit
    # Atkin's sieve -- see http://en.wikipedia.org/wiki/Prime_number
    sqrtLimit = isqrt(limit) + 1

    # initialize the sieve
    is_prime = [False, False, True, True, False] + [False for _ in range(5, limit + 1)]

    # put in candidate primes:
    # integers which have an odd number of
    # representations by certain quadratic forms
    for x in range(1, sqrtLimit):
        x2 = x * x
        for y in range(1, sqrtLimit):
            y2 = y*y
            n = 4 * x2 + y2
            if n <= limit and (n % 12 == 1 or n % 12 == 5): is_prime[n] ^= True
            n = 3 * x2 + y2
            if n <= limit and (n % 12 == 7): is_prime[n] ^= True
            n = 3*x2-y2
            if n <= limit and x > y and n % 12 == 11: is_prime[n] ^= True

    # eliminate composites by sieving
    for n in range(5, sqrtLimit):
        if is_prime[n]:
            sqN = n**2
            # n is prime, omit multiples of its square; this is sufficient because
            # composites which managed to get on the list cannot be square-free
            for i in range(1, limit // sqN + 1):
                k = i * sqN # k ∈ {n², 2n², 3n², ..., limit}
                is_prime[k] = False
    for i, truth in enumerate(is_prime):
        if truth: yield i
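If speed of the prime generation matters, one possible substitute is a plain Sieve of Eratosthenes over a bytearray with slice assignment. To be clear, this is a different algorithm from the Atkin sieve above, the name primes_eratosthenes is made up, and I have not timed it against the version above.

```python
from math import isqrt

def primes_eratosthenes(limit):
    """Yield the primes in 2..limit (hypothetical stand-in for primes())."""
    if limit < 2:
        return
    sieve = bytearray([1]) * (limit + 1)  # sieve[i] == 1 means "i may be prime"
    sieve[0] = sieve[1] = 0
    for p in range(2, isqrt(limit) + 1):
        if sieve[p]:
            # cross out p*p, p*p + p, ... in a single slice assignment
            sieve[p * p::p] = bytes(len(range(p * p, limit + 1, p)))
    for i, flag in enumerate(sieve):
        if flag:
            yield i

print(list(primes_eratosthenes(20)))  # [2, 3, 5, 7, 11, 13, 17, 19]
```

Like primes(), it yields the primes in increasing order, so list(primes_eratosthenes(mmax))[1:] would skip 2 in the same way.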