The code below works until I increase n_symbols, length and distinct. For example, on my computer n_symbols=512, length=512, distinct=300 fails with RecursionError: maximum recursion depth exceeded in comparison, and I get overflow errors if I increase the lru_cache value.
What I want is to have a non-recursive version of this code.
    from functools import lru_cache

    @lru_cache
    def get_permutations_count(n_symbols, length, distinct, used=0):
        '''
        - n_symbols: number of symbols in the alphabet
        - length: the number of symbols in each sequence
        - distinct: the number of distinct symbols in each sequence
        '''
        if distinct < 0:
            return 0
        if length == 0:
            return 1 if distinct == 0 else 0
        else:
            return (get_permutations_count(n_symbols, length-1, distinct-0, used+0) * used
                    + get_permutations_count(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used))
Then
get_permutations_count(n_symbols=300, length=300, distinct=270)
runs in ~0.5 seconds, giving the answer
2729511887951350984580070745513114266766906881300774347439917775 7093985721949669285469996223829969654724957176705978029888262889 8157939885553971500652353177628564896814078569667364402373549268 5524290993833663948683375995196081654415976659499171897405039547 1546236260377859451955180752885715923847446106509971875543496023 2494854876774756172488117802642800540206851318332940739395445903 6305051887120804168979339693187702655904071331731936748927759927 3688881301614948043182289382736687065840703041231428800720854767 0713406956719647313048146023960093662879015837313428567467555885 3564982943420444850950866922223974844727296000000000000000000000 000000000000000000000000000000000000000000000000
Answer
Here’s mine:
    from math import comb, factorial

    def get_permutations_count_improved(n_symbols, length, distinct):
        if distinct > length or distinct > n_symbols:
            return 0
        # One row of the ways table; each pass of the loop adds one position.
        ways = [1]
        for _ in range(length):
            ways = [used * (distinct - d) + new
                    for d, used, new in zip(range(distinct+1), [*ways, 0], [0, *ways])]
        return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)
Speed comparison for some argument sets:
    n_symbols  length  distinct  yours    mine
    300        300     270       0.62 s   0.012 s  (~51 times faster)
    512        512     300       -        0.035 s
    1024       1024    600       -        0.22 s
    3000       3000    2700      -        6.0 s
In my last line you see I split the overall result into three factors:
- comb(n_symbols, distinct) for choosing which distinct out of the n_symbols symbols actually get used. That essentially gets rid of the n_symbols parameter, or think of it as compensating for setting n_symbols = distinct.
- factorial(distinct) for the order in which the symbols get used first. This gets rid of the * (n_symbols - used) in your recurrence.
- ways[distinct] is the number of ways to build a sequence of length length with exactly distinct distinct symbols, where the order in which they get used first is fixed.
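To make the three factors concrete, here is a small brute-force sketch for tiny inputs. The helper names count_brute_force and count_first_use_fixed are made up for this illustration and are not part of the code above. It enumerates every sequence, so it is only usable for very small parameters.

```python
from itertools import product
from math import comb, factorial

def count_brute_force(n_symbols, length, distinct):
    """Count sequences with exactly `distinct` different symbols by enumeration."""
    return sum(
        len(set(seq)) == distinct
        for seq in product(range(n_symbols), repeat=length)
    )

def count_first_use_fixed(length, distinct):
    """Count sequences over symbols 0..distinct-1 that use all of them,
    with the symbols making their first appearance in the order 0, 1, 2, ..."""
    total = 0
    for seq in product(range(distinct), repeat=length):
        first_seen = list(dict.fromkeys(seq))  # symbols in order of first use
        if first_seen == list(range(distinct)):
            total += 1
    return total

n_symbols, length, distinct = 4, 5, 3
direct = count_brute_force(n_symbols, length, distinct)
factored = (comb(n_symbols, distinct) * factorial(distinct)
            * count_first_use_fixed(length, distinct))
print(direct, factored)  # both counts should agree
```

For n_symbols=4, length=5, distinct=3 both counts come out as 600, with the fixed-order factor contributing 25.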
It might be easier to think of the ways table as two-dimensional: ways[length][distinct]. But for more memory-efficiency, I compute it row by row and only keep the latest row.
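As a rough sketch of that two-dimensional view, the same recurrence can be written with the full table kept in memory. The name count_with_full_table is hypothetical, and the reading of the indices in the comments is one interpretation of the row-by-row code, not something stated in it.

```python
from math import comb, factorial

def count_with_full_table(n_symbols, length, distinct):
    if distinct > length or distinct > n_symbols:
        return 0
    # table[l][d]: ways to fill l positions when d symbols still have to
    # make their first appearance (one reading of the ways table).
    table = [[0] * (distinct + 1) for _ in range(length + 1)]
    table[0][0] = 1
    for l in range(1, length + 1):
        for d in range(distinct + 1):
            # Reuse one of the (distinct - d) symbols that have already appeared ...
            reuse = table[l - 1][d] * (distinct - d)
            # ... or introduce the next new symbol (only one choice, order is fixed).
            introduce = table[l - 1][d - 1] if d > 0 else 0
            table[l][d] = reuse + introduce
    return table[length][distinct] * comb(n_symbols, distinct) * factorial(distinct)

print(count_with_full_table(4, 5, 3))
```

This uses memory proportional to length * distinct big integers, which is the cost that keeping only the latest row avoids.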
Benchmark and some correctness checks:
    from timeit import timeit
    from functools import lru_cache
    from math import comb, factorial

    @lru_cache
    def get_permutations_count(n_symbols, length, distinct, used=0):
        '''
        - n_symbols: number of symbols in the alphabet
        - length: the number of symbols in each sequence
        - distinct: the number of distinct symbols in each sequence
        '''
        if distinct < 0:
            return 0
        if length == 0:
            return 1 if distinct == 0 else 0
        else:
            return (get_permutations_count(n_symbols, length-1, distinct-0, used+0) * used
                    + get_permutations_count(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used))

    def get_permutations_count_improved(n_symbols, length, distinct):
        if distinct > length or distinct > n_symbols:
            return 0
        ways = [1]
        for _ in range(length):
            ways = [used * (distinct - d) + new
                    for d, used, new in zip(range(distinct+1), [*ways, 0], [0, *ways])]
        return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)

    funcs = get_permutations_count, get_permutations_count_improved

    # Check correctness
    stop = 20
    for a in range(stop):
        for b in range(stop):
            for c in range(stop):
                expect = get_permutations_count(a, b, c)
                result = get_permutations_count_improved(a, b, c)
                assert result == expect, (a, b, c, expect, result)

    # Benchmark
    n_symbols, length, distinct = 300, 300, 270
    #n_symbols, length, distinct = 512, 512, 300
    #n_symbols, length, distinct = 1024, 1024, 600
    #n_symbols, length, distinct = 3000, 3000, 2700
    for func in funcs[0:] * 3:
        funcs[0].cache_clear()
        t = timeit(lambda: func(n_symbols, length, distinct), number=1)
        print('%.3f seconds ' % t, func.__name__)