I am seeking to sample n random permutations of a list in Python.
This is my code:
obj = [ 5 8 9 ... 45718 45719 45720]  # type(obj) = numpy.ndarray
pairs = random.sample(list(permutations(obj, 2)), k=150)
Although the code does what I want, it causes memory issues: I sometimes get a MemoryError when running on CPU, and when running on GPU my virtual machine crashes.
How can I make the code work in a more memory-efficient manner?
Answer
Building on Pablo Ruiz’s excellent answer, I suggest wrapping his sampling solution in a generator function that yields unique permutations by keeping track of what it has already yielded:
import numpy as np

def unique_permutations(sequence, r, n):
    """Yield n unique permutations of r elements from sequence."""
    seen = set()
    while len(seen) < n:
        # This line of code adapted from Pablo Ruiz's answer:
        candidate_permutation = tuple(np.random.choice(sequence, r, replace=False))
        if candidate_permutation not in seen:
            seen.add(candidate_permutation)
            yield candidate_permutation

obj = list(range(10))

for permutation in unique_permutations(obj, 2, 15):
    print(permutation)  # do something with the permutation

# Or, to save the result as a list:
pairs = list(unique_permutations(obj, 2, 15))
My assumption is that you are sampling a small subset of the very large number of possible permutations, in which case collisions will be rare enough that keeping a seen set will not be expensive.
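
To put a rough number on that assumption (my own back-of-the-envelope illustration, assuming your array has on the order of 45,000 elements, as the repr in the question suggests), a birthday-problem style estimate shows how rare collisions are at k=150:

import math

n_elements = 45_000  # assumed size, based on the array repr in the question
n_possible = math.perm(n_elements, 2)  # ~2.0e9 ordered pairs
k = 150
# Expected number of collisions when drawing k samples
# uniformly at random from n_possible possibilities:
expected_collisions = k * (k - 1) / 2 / n_possible
print(f"{n_possible:.2e} possible permutations, "
      f"~{expected_collisions:.1e} expected collisions")

For inputs of that scale, the seen set will essentially never reject a candidate.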
Warnings: this function becomes an infinite loop if you ask for more permutations than are possible given the inputs. It will also get increasingly slow as n gets close to the number of possible permutations, since collisions will become increasingly frequent.
If I were to put this function in my code base, I would put a shield at the top that calculated the number of possible permutations and raised a ValueError if n exceeded that number, and maybe output a warning if n exceeded one tenth of that number, or something like that.
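
For reference, here is a minimal sketch of that shield, assuming Python 3.8+ for math.perm; the one-tenth threshold is just the arbitrary heuristic mentioned above, not a tuned value:

import math
import warnings

import numpy as np

def unique_permutations(sequence, r, n):
    """Yield n unique permutations of r elements from sequence."""
    n_possible = math.perm(len(sequence), r)
    if n > n_possible:
        raise ValueError(
            f"asked for {n} unique permutations, but only {n_possible} exist"
        )
    if n > n_possible / 10:
        # Arbitrary heuristic threshold: rejection sampling slows down
        # as n approaches the number of possible permutations.
        warnings.warn("n is a large fraction of the possible permutations; "
                      "sampling may be slow due to frequent collisions")
    seen = set()
    while len(seen) < n:
        candidate = tuple(np.random.choice(sequence, r, replace=False))
        if candidate not in seen:
            seen.add(candidate)
            yield candidate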