
Convert recursive Python code to a non-recursive version

The code below works until I start increasing distinct, n_symbols and length. For example, on my computer n_symbols=512, length=512, distinct=300 fails with RecursionError: maximum recursion depth exceeded in comparison, and if I increase the lru_cache value I then get overflow errors.
What I want is a non-recursive version of this code.

from functools import lru_cache

@lru_cache(maxsize=None)
def get_permutations_count(n_symbols, length, distinct, used=0):
    """
     - n_symbols: number of symbols in the alphabet
     - length: the number of symbols in each sequence
     - distinct: the number of distinct symbols in each sequence
    """
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    return (
        # reuse one of the `used` symbols seen so far
        get_permutations_count(n_symbols, length-1, distinct-0, used+0) * used +
        # or introduce one of the (n_symbols - used) symbols not seen yet
        get_permutations_count(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used)
    )


get_permutations_count(n_symbols=300, length=300, distinct=270)

runs in about 0.5 seconds and gives the answer
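To pin down exactly what is being counted, here is a brute-force reference. It is a sketch I'm adding, not part of the question, and the name count_by_brute_force is hypothetical. It enumerates all n_symbols ** length sequences, so it is only usable for tiny inputs, but it is handy for checking either implementation.

```python
from itertools import product

def count_by_brute_force(n_symbols, length, distinct):
    """Count sequences of `length` symbols from an alphabet of `n_symbols`
    symbols that contain exactly `distinct` different symbols.
    Enumerates every possible sequence, so keep the inputs tiny."""
    return sum(
        1
        for seq in product(range(n_symbols), repeat=length)
        if len(set(seq)) == distinct
    )

print(count_by_brute_force(3, 3, 2))  # 18
```

Of the 27 sequences of length 3 over 3 symbols, 3 use one symbol and 6 use all three, which leaves 18 with exactly two.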




Here’s mine:

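from math import comb, factorial

def get_permutations_count_improved(n_symbols, length, distinct):
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    # ways holds only the latest row of the ways[length][distinct] table
    ways = [1]
    for _ in range(length):
        # extend by one position: either reuse an already-introduced symbol
        # or introduce the next new symbol
        ways = [used * (distinct - d) + new
               for d, used, new in zip(range(distinct+1), [*ways, 0], [0, *ways])]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)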
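For comparison, the question's own recurrence can also be tabulated directly, without factoring anything out. This is my sketch, not the answer's method, and the name get_permutations_count_iterative is hypothetical. It relies on one observation: across the recursive calls, distinct + used never changes (both go down/up together or stay put), so the state can be indexed by used alone, and each loop iteration replaces one level of recursion.

```python
def get_permutations_count_iterative(n_symbols, length, distinct):
    """Bottom-up version of the question's recursive get_permutations_count."""
    if distinct < 0:
        return 0
    # distinct + used stays equal to the top-level distinct in the recursion,
    # so row[used] is the result for the current remaining length with
    # (distinct - used) distinct symbols still to introduce.
    # Base case length == 0: 1 only if none remain, i.e. used == distinct.
    row = [1 if used == distinct else 0 for used in range(distinct + 1)]
    for _ in range(length):
        row = [
            # reuse one of the `used` symbols seen so far ...
            row[used] * used
            # ... or introduce one of the (n_symbols - used) unseen symbols;
            # at used == distinct that would make distinct negative, so 0
            + (row[used + 1] * (n_symbols - used) if used < distinct else 0)
            for used in range(distinct + 1)
        ]
    return row[0]  # the top-level call starts with used=0

print(get_permutations_count_iterative(3, 3, 2))  # 18
```

Like the answer's version it does length × (distinct + 1) big-integer operations and never recurses, but it keeps the n_symbols factor inside the loop instead of pulling it out at the end.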

Speed comparison for some argument sets (a dash means the recursive version fails with RecursionError):

n_symbols length distinct   yours    mine
   300      300    270      0.62 s   0.012 s (~51 times faster)
   512      512    300        -      0.035 s
  1024     1024    600        -      0.22 s
  3000     3000   2700        -      6.0 s

In the last line, I split the overall result into three factors:

  • comb(n_symbols, distinct) for choosing which distinct of the n_symbols symbols actually get used. That essentially gets rid of the n_symbols parameter; think of it as compensating for setting n_symbols = distinct.
  • factorial(distinct) for the order in which the symbols are first used. This gets rid of the * (n_symbols - used) factor in your recurrence.
  • ways[distinct] is the number of ways to build a sequence of length length with exactly distinct distinct symbols, where the order in which they get used first is fixed.
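The three-factor split can be sanity-checked against a closed form. This is a different technique from the answer's table, swapped in only as a cross-check, and the function name is hypothetical. The product ways[distinct] * factorial(distinct) counts the surjections from the length positions onto a fixed set of distinct symbols, and inclusion-exclusion gives that count as the sum over j from 0 to distinct of (-1)^j * comb(distinct, j) * (distinct - j)^length. Equivalently, ways[distinct] is the Stirling number of the second kind S(length, distinct).

```python
from math import comb

def get_permutations_count_closed_form(n_symbols, length, distinct):
    if distinct < 0:
        return 0
    # Surjections from `length` positions onto a fixed set of `distinct`
    # symbols, by inclusion-exclusion over the symbols left out.
    # (Python's 0 ** 0 == 1 makes the length == 0 case come out right.)
    onto = sum((-1) ** j * comb(distinct, j) * (distinct - j) ** length
               for j in range(distinct + 1))
    # comb(n_symbols, distinct) picks which symbols are used;
    # math.comb returns 0 when distinct > n_symbols.
    return comb(n_symbols, distinct) * onto

print(get_permutations_count_closed_form(3, 3, 2))  # 18
```

It avoids the table entirely but raises large integers to the power length, so it is not necessarily faster; it is mainly useful as an independent check.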

It might be easier to think of the ways table as two-dimensional: ways[length][distinct]. But for memory efficiency, I compute it row by row and keep only the latest row.
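Here is a sketch of that two-dimensional layout with explicit loops. The function name is hypothetical, and the meaning of ways[k][d] in the comments is my reading of the recurrence: with d of the distinct symbols still to be introduced (in their fixed order), the next position either reuses one of the distinct - d symbols already introduced or introduces the next new one. It computes the same numbers as the row-by-row version but stores all length + 1 rows.

```python
from math import comb, factorial

def get_permutations_count_2d(n_symbols, length, distinct):
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    # ways[k][d]: ways to fill k remaining positions when d new symbols
    # still have to be introduced and distinct - d are already in use
    ways = [[0] * (distinct + 1) for _ in range(length + 1)]
    ways[0][0] = 1  # nothing left to fill and nothing left to introduce
    for k in range(1, length + 1):
        for d in range(distinct + 1):
            reuse = ways[k - 1][d] * (distinct - d)    # repeat an in-use symbol
            new = ways[k - 1][d - 1] if d > 0 else 0   # introduce the next one
            ways[k][d] = reuse + new
    return ways[length][distinct] * comb(n_symbols, distinct) * factorial(distinct)

print(get_permutations_count_2d(3, 3, 2))  # 18
```

Only row k - 1 is ever read while filling row k, which is why the answer can keep a single row and rebuild it in each iteration.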

Benchmark and some correctness checks:

from timeit import timeit
from functools import lru_cache
from math import comb, factorial

@lru_cache(maxsize=None)
def get_permutations_count(n_symbols, length, distinct, used=0):
    """
     - n_symbols: number of symbols in the alphabet
     - length: the number of symbols in each sequence
     - distinct: the number of distinct symbols in each sequence
    """
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    return (
        # reuse one of the `used` symbols seen so far
        get_permutations_count(n_symbols, length-1, distinct-0, used+0) * used +
        # or introduce one of the (n_symbols - used) symbols not seen yet
        get_permutations_count(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used)
    )

def get_permutations_count_improved(n_symbols, length, distinct):
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    ways = [1]
    for _ in range(length):
        ways = [used * (distinct - d) + new
               for d, used, new in zip(range(distinct+1), [*ways, 0], [0, *ways])]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)

funcs = get_permutations_count, get_permutations_count_improved

# Check correctness
stop = 20
for a in range(stop):
    for b in range(stop):
        for c in range(stop):
            expect = get_permutations_count(a, b, c)
            result = get_permutations_count_improved(a, b, c)
            assert result == expect, (a, b, c, expect, result)

# Benchmark
n_symbols, length, distinct = 300, 300, 270
#n_symbols, length, distinct = 512, 512, 300
#n_symbols, length, distinct = 1024, 1024, 600
#n_symbols, length, distinct = 3000, 3000, 2700
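for func in funcs[0:] * 3:  # the recursive version fails on the larger sets, so use funcs[1:] there
    if hasattr(func, 'cache_clear'):
        func.cache_clear()  # time the recursive version without reusing the previous run's cache
    t = timeit(lambda: func(n_symbols, length, distinct), number=1)
    print('%.3f seconds ' % t, func.__name__)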
User contributions licensed under: CC BY-SA