Fastest way for computing pseudoinverse (pinv) in Python

I have a loop in which I'm calculating several pseudoinverses of rather large, non-sparse matrices (e.g. 20000x800).

As my code spends most of its time in pinv, I was trying to find a way to speed up the computation. I'm already using multiprocessing (joblib/loky) to run several processes, but that of course also adds overhead. Using a JIT (numba) did not help much.

Is there a faster way or a better implementation to compute the pseudoinverse, using any function? Precision isn't key.

My current benchmark

import time
import numba
import numpy as np
from numpy.linalg import pinv as np_pinv
from scipy.linalg import pinv as scipy_pinv
# Note: pinv2 was deprecated in SciPy 1.7 and removed in 1.9, so this import needs an older SciPy.
from scipy.linalg import pinv2 as scipy_pinv2

@numba.njit
def np_jit_pinv(A):
    return np_pinv(A)

matrix = np.random.rand(20000, 800)
for pinv in [np_pinv, scipy_pinv, scipy_pinv2, np_jit_pinv]:
    start = time.time()
    pinv(matrix)  # single timed call; for np_jit_pinv this includes JIT compilation
    print(f'{pinv.__module__ +"."+pinv.__name__} took {time.time()-start:.3f}')

Output (times in seconds):

numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446
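
A caveat on these numbers: each function is timed over a single call, and for the numba-jitted wrapper that first call also includes compilation. A minimal sketch of a fairer timing loop, assuming a hypothetical helper `best_time` (not part of the original benchmark) that makes one untimed warm-up call and reports the best of a few repeats. A smaller matrix is used only to keep the run short:

```python
import time

import numpy as np


def best_time(fn, arg, repeats=3):
    """Return the best wall-clock time (seconds) of fn(arg) over `repeats` runs.

    Hypothetical helper for illustration; one untimed warm-up call is made
    first so that one-off costs (e.g. JIT compilation) are not counted.
    """
    fn(arg)  # warm-up, not timed
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(arg)
        timings.append(time.perf_counter() - start)
    return min(timings)


# Smaller than the 20000x800 case in the question, to keep the example quick.
matrix = np.random.rand(2000, 80)
elapsed = best_time(np.linalg.pinv, matrix)
print(f"numpy.linalg.pinv took {elapsed:.3f}")
```

The same helper can be called with any of the pinv variants from the benchmark, so all of them are measured the same way.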

EDIT: JAX turns out to be considerably faster (0.995 s versus 1.682 s for scipy's pinv2). Impressive! Thanks for letting me know, @yuri-brigance. On Windows it works well under WSL.

numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446
jax._src.numpy.linalg.pinv took 0.995

Answer

Try with JAX:

import jax.numpy as jnp

jnp.linalg.pinv(A)

It seems to be slightly faster than the regular numpy.linalg.pinv. On my machine, your benchmark looks like this (times in seconds):

jax._src.numpy.linalg.pinv took 3.127
numpy.linalg.pinv took 4.284
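
Part of the gap may simply be precision: JAX works in 32-bit floats by default, whereas `np.random.rand` returns float64. Since precision isn't key here, the same trade-off can be tried in plain NumPy by casting the input to float32. A minimal sketch, assuming the matrix is well conditioned enough for single precision (the size is reduced to keep it quick, and whether it is actually faster depends on the BLAS/LAPACK build behind NumPy):

```python
import numpy as np

rng = np.random.default_rng(0)
A64 = rng.random((2000, 80))      # float64, like np.random.rand
A32 = A64.astype(np.float32)      # same values, stored in single precision

P64 = np.linalg.pinv(A64)
P32 = np.linalg.pinv(A32)

# The single-precision result agrees with the double-precision one only to
# roughly float32 accuracy, which is the price paid for the cheaper arithmetic.
max_abs_diff = np.max(np.abs(P64 - P32))
print(P32.shape, max_abs_diff)
```

If the printed difference is acceptable for the application, timing `np.linalg.pinv(A32)` against `np.linalg.pinv(A64)` on the real 20000x800 data shows how much of the speed-up is due to precision alone.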
User contributions licensed under: CC BY-SA