
Python: Memory-efficient random sampling of list of permutations

I am seeking to sample n random permutations of a list in Python.

This is my code:

import random
from itertools import permutations

# obj is a numpy.ndarray; printed, it looks like
# [    5     8     9 ... 45718 45719 45720]
pairs = random.sample(list(permutations(obj, 2)), k=150)

Although the code does what I want, it causes memory issues. When running on the CPU I sometimes get a MemoryError, and when running on the GPU my virtual machine crashes.
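The cause is that list(permutations(obj, 2)) builds every ordered pair before random.sample picks 150 of them. A quick count makes the scale visible. This sketch ASSUMES obj holds every integer from 5 to 45720, since the question elides the real contents; only the length matters. math.perm needs Python 3.8 or later.

```python
import math

# ASSUMPTION: hypothetical stand-in for obj. The question elides its contents
# ("..."), so pretend it is every integer from 5 to 45720. Only len() matters.
obj_len = len(range(5, 45721))

# permutations(obj, 2) yields every ordered pair of distinct positions,
# i.e. obj_len * (obj_len - 1) tuples. Wrapping it in list(...) keeps all of
# them in memory at once, before random.sample ever chooses its 150.
num_pairs = math.perm(obj_len, 2)
print(f"{obj_len:,} elements -> {num_pairs:,} ordered pairs")
```

Each of those pairs is a separate Python tuple object, so the list runs to billions of objects.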

How can I make the code work in a more memory-efficient manner?


Answer

Building on Pablo Ruiz’s excellent answer, I suggest wrapping his sampling solution into a generator function that yields unique permutations by keeping track of what it has already yielded:

import numpy as np

def unique_permutations(sequence, r, n):
    """Yield n unique permutations of r elements from sequence"""
    seen = set()
    while len(seen) < n:
        # This line of code adapted from Pablo Ruiz's answer:
        candidate_permutation = tuple(np.random.choice(sequence, r, replace=False))

        if candidate_permutation not in seen:
            seen.add(candidate_permutation)
            yield candidate_permutation

obj = list(range(10))
for permutation in unique_permutations(obj, 2, 15):
    print(permutation)  # do something with the permutation

# Or, to save the result as a list:
pairs = list(unique_permutations(obj, 2, 15))
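Applied at the question's scale, a possible variant looks like this. It is a sketch, not part of the answer above: it swaps the legacy np.random.choice for NumPy's newer Generator API (np.random.default_rng) so the sampling can be seeded, and converts each pair to plain Python ints with .tolist(). The obj stand-in is an ASSUMPTION, because the real array is elided in the question.

```python
import numpy as np

def unique_permutations(sequence, r, n, rng=None):
    """Yield n unique r-permutations of sequence, sampled at random.

    Same approach as the generator above, but with an optional
    numpy.random.Generator so results can be reproduced with a seed.
    """
    rng = np.random.default_rng() if rng is None else rng
    seen = set()
    while len(seen) < n:
        # Draw r distinct elements; their order matters, so this is a permutation.
        candidate = tuple(rng.choice(sequence, r, replace=False).tolist())
        if candidate not in seen:
            seen.add(candidate)
            yield candidate

# ASSUMPTION: hypothetical stand-in for the question's obj.
obj = np.arange(5, 45721)
pairs = list(unique_permutations(obj, 2, 150, rng=np.random.default_rng(0)))
print(len(pairs), pairs[:3])
```

Only the 150 sampled pairs (plus the seen set of the same size) are ever held in memory, instead of every possible pair.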

My assumption is that you are sampling a small subset of the very large number of possible permutations. In that case collisions will be rare, so retries are cheap, and the seen set never holds more than n tuples.
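To put numbers on "rare", here is a small calculation (a coupon-collector style sum, not something from the answer itself). If i distinct permutations have been collected out of `total` possible ones, the next draw is new with probability (total - i) / total, so it takes total / (total - i) draws on average. Summing over i gives the expected number of draws for n unique results. The large total below ASSUMES the question's obj has 45716 elements, which is only a guess since the contents are elided.

```python
import math

def expected_draws(total, n):
    """Expected number of uniform random draws needed to collect n distinct
    items out of `total` equally likely possibilities."""
    return sum(total / (total - i) for i in range(n))

# Toy case from the answer: 2-permutations of list(range(10)), i.e. 90 pairs.
small_total = math.perm(10, 2)
for n in (15, 45, 90):
    print(f"n={n:>2}: about {expected_draws(small_total, n):.1f} draws")

# ASSUMPTION: hypothetical question-scale total (45716 elements, r=2).
big_total = math.perm(45716, 2)
print(f"n=150 of {big_total:,}: about {expected_draws(big_total, 150):.6f} draws")
```

For a tiny n relative to total, the expected number of draws is essentially n; it only blows up as n approaches total.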

Warnings: this function loops forever if you ask for more permutations than are possible given the inputs. It also gets increasingly slow as n approaches the number of possible permutations, since collisions become more and more frequent.

If I were to put this function in my code base, I would add a guard at the top that calculates the number of possible permutations and raises a ValueError if n exceeds that number, and perhaps emits a warning if n exceeds one tenth of it, or something along those lines.
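One way that guard could look is sketched below. The names checked_unique_permutations and _sample, and the warn_fraction parameter, are hypothetical. The check lives in a plain function that returns a generator, because a ValueError raised inside a generator function would only surface on the first next() call, not when the function is called. It uses math.perm (Python 3.8+), which returns 0 when r is larger than the sequence, so that case is rejected too.

```python
import math
import warnings

import numpy as np

def _sample(sequence, r, n):
    """The answer's sampling loop: yield n unique random r-permutations."""
    seen = set()
    while len(seen) < n:
        candidate = tuple(np.random.choice(sequence, r, replace=False).tolist())
        if candidate not in seen:
            seen.add(candidate)
            yield candidate

def checked_unique_permutations(sequence, r, n, warn_fraction=0.1):
    """Hypothetical guarded wrapper: validate n eagerly, then sample lazily."""
    total = math.perm(len(sequence), r)  # number of possible r-permutations
    if n > total:
        raise ValueError(f"requested {n} permutations, but only {total} exist")
    if n > warn_fraction * total:
        warnings.warn(
            f"n={n} is more than {warn_fraction:.0%} of the {total} possible "
            "permutations; sampling may be slow",
            stacklevel=2,
        )
    return _sample(sequence, r, n)

obj = list(range(10))
print(list(checked_unique_permutations(obj, 2, 5)))
```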

User contributions licensed under: CC BY-SA