I have a loop in which I’m calculating several pseudoinverses of rather large, non-sparse matrices (e.g. 20000x800).
As my code spends most of its time in pinv, I was trying to find a way to speed up the computation. I’m already using multiprocessing (joblib/loky) to run several processes, but that of course also increases overhead. Using jit did not help much.
Is there a faster way or a better implementation for computing the pseudoinverse, using any library? Precision isn’t key.
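Since precision isn't key, one thing worth measuring is running the same pinv call in single precision. The sketch below is only an illustration under that assumption: the matrix is a scaled-down stand-in for the 20000x800 case, and whether the cast actually helps depends on the BLAS/LAPACK build behind NumPy.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((2000, 80))  # scaled-down stand-in for the 20000x800 matrix

# Cast once to float32 so the SVD inside pinv runs in single precision.
A32 = A.astype(np.float32)
A_pinv32 = np.linalg.pinv(A32)

# Compare against the float64 result to see what the cast costs in accuracy.
A_pinv64 = np.linalg.pinv(A)
rel_err = np.linalg.norm(A_pinv32 - A_pinv64) / np.linalg.norm(A_pinv64)
print(A_pinv32.dtype, rel_err)
```

If the relative error is acceptable for the application, the float32 variant can be timed alongside the other candidates.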
My current benchmark:

    import time

    import numba
    import numpy as np
    from numpy.linalg import pinv as np_pinv
    from scipy.linalg import pinv as scipy_pinv
    # Note: pinv2 was deprecated in SciPy 1.7 and removed in 1.9.
    from scipy.linalg import pinv2 as scipy_pinv2


    @numba.njit
    def np_jit_pinv(A):
        return np_pinv(A)


    matrix = np.random.rand(20000, 800)
    for pinv in [np_pinv, scipy_pinv, scipy_pinv2, np_jit_pinv]:
        start = time.time()
        pinv(matrix)  # for np_jit_pinv the first call also includes compilation
        print(f'{pinv.__module__ +"."+pinv.__name__} took {time.time()-start:.3f}')
    numpy.linalg.pinv took 2.774
    scipy.linalg.basic.pinv took 1.906
    scipy.linalg.basic.pinv2 took 1.682
    __main__.np_jit_pinv took 2.446
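One caveat about the numbers above: each function is timed on a single call, and for the numba-jitted version that first call also includes compilation. A minimal sketch of a fairer timing loop follows; the helper name `best_time` and its `warmup`/`repeats` parameters are made up for illustration, and the matrix is smaller than the original so it runs quickly.

```python
import time

import numpy as np


def best_time(fn, a, warmup=1, repeats=3):
    """Return the best wall-clock time of fn(a) over `repeats` runs."""
    # Untimed warm-up calls absorb one-off costs such as JIT compilation.
    for _ in range(warmup):
        fn(a)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(a)
        timings.append(time.perf_counter() - start)
    return min(timings)


rng = np.random.default_rng(0)
matrix = rng.random((2000, 80))  # smaller than 20000x800, for a quick run
elapsed = best_time(np.linalg.pinv, matrix)
print(f"numpy.linalg.pinv best of 3: {elapsed:.4f} s")
```

Any of the pinv variants from the benchmark can be passed as `fn` in the same way.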
EDIT: JAX seems to be 30% faster! Impressive! Thanks for letting me know, @yuri-brigance. On Windows it works well under WSL.
    numpy.linalg.pinv took 2.774
    scipy.linalg.basic.pinv took 1.906
    scipy.linalg.basic.pinv2 took 1.682
    __main__.np_jit_pinv took 2.446
    jax._src.numpy.linalg.pinv took 0.995
Answer
Try with JAX:

    import jax.numpy as jnp

    jnp.linalg.pinv(A)
It seems to be slightly faster than the regular numpy.linalg.pinv. On my machine your benchmark looks like this:
    jax._src.numpy.linalg.pinv took 3.127
    numpy.linalg.pinv took 4.284
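Two things are worth keeping in mind when comparing these timings. JAX dispatches work asynchronously, so a timing should wait on the result, and JAX uses float32 by default unless `jax_enable_x64` is set, so part of any speedup may come from the lower precision. The following is a rough sketch of how the JAX call could be timed with that in mind; the helper `timed_jax_pinv` is a hypothetical name, and the matrix is smaller than in the question.

```python
import time

import jax.numpy as jnp
import numpy as np


def timed_jax_pinv(a, repeats=3):
    """Return (pinv, best_seconds) for jnp.linalg.pinv applied to `a`."""
    a_dev = jnp.asarray(a)  # becomes float32 unless jax_enable_x64 is set
    # Warm-up call so compilation is excluded from the measurement.
    jnp.linalg.pinv(a_dev).block_until_ready()
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # block_until_ready() waits for the asynchronous computation to finish.
        result = jnp.linalg.pinv(a_dev).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return result, best


rng = np.random.default_rng(0)
matrix = rng.random((2000, 80))  # smaller than the 20000x800 case
a_pinv, seconds = timed_jax_pinv(matrix)
print(f"jnp.linalg.pinv: {seconds:.4f} s, dtype={a_pinv.dtype}")
```

The printed dtype shows which precision the measurement actually used, which matters when comparing against the float64 NumPy and SciPy numbers.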