I have a loop in which I'm calculating several pseudoinverses of rather large, non-sparse matrices (e.g. 20000x800).
As my code spends most of its time on the pinv call, I was trying to find a way to speed up the computation. I'm already using multiprocessing (joblib/loky) to run with several processes, but that of course also adds overhead. Using numba's jit did not help much.
Is there a faster way / better implementation to compute the pseudoinverse using any function? Precision isn't key.
My current benchmark:

```python
import time
import numba
import numpy as np
from numpy.linalg import pinv as np_pinv
from scipy.linalg import pinv as scipy_pinv
from scipy.linalg import pinv2 as scipy_pinv2


@numba.njit
def np_jit_pinv(A):
    return np_pinv(A)


matrix = np.random.rand(20000, 800)
for pinv in [np_pinv, scipy_pinv, scipy_pinv2, np_jit_pinv]:
    start = time.time()
    pinv(matrix)
    print(f'{pinv.__module__ + "." + pinv.__name__} took {time.time()-start:.3f}')
```
```
numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446
```
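One small tweak worth trying on the SciPy side (my addition, not part of the timings above): `scipy.linalg.pinv` accepts `check_finite=False`, which skips the NaN/inf validation scan over all 16 million entries before the decomposition. A minimal sketch, using a smaller stand-in matrix:

```python
import numpy as np
from scipy.linalg import pinv

matrix = np.random.rand(2000, 200)  # smaller stand-in for the 20000x800 case

# check_finite=False skips the input-validation pass; the savings are small
# relative to the SVD itself, but it's free if the input is trusted.
P = pinv(matrix, check_finite=False)
```

The result has the usual transposed shape (200, 2000), same as the default call.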
EDIT: JAX seems to be ~30% faster. Impressive! Thanks for letting me know, @yuri-brigance. On Windows it works well under WSL.
```
numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446
jax._src.numpy.linalg.pinv took 0.995
```
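One caveat when timing JAX: it dispatches work asynchronously, so a fair comparison should block on the result before stopping the clock. A sketch of how the JAX entry can be timed (assuming jax is installed; a smaller matrix is used here for illustration):

```python
import time
import numpy as np
import jax.numpy as jnp

matrix = np.random.rand(2000, 200)  # smaller stand-in for 20000x800

start = time.time()
result = jnp.linalg.pinv(matrix)
# Without this, the timer may only measure async dispatch, not the compute.
result.block_until_ready()
print(f'jax pinv took {time.time() - start:.3f}')
```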
Answer
Try with JAX:
```python
import jax.numpy as jnp

jnp.linalg.pinv(A)
```
Seems to be slightly faster than regular numpy.linalg.pinv. On my machine your benchmark looks like this:
```
jax._src.numpy.linalg.pinv took 3.127
numpy.linalg.pinv took 4.284
```
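Since the question says precision isn't key, another route worth benchmarking (a sketch of my own, not from the answer above) exploits the tall shape: for a matrix A with full column rank, the pseudoinverse equals (AᵀA)⁻¹Aᵀ, which replaces the SVD of a 20000x800 matrix with a single 800x800 linear solve. Shown with a smaller stand-in matrix:

```python
import numpy as np

def normal_eq_pinv(A):
    # Valid only when A has full column rank. Less numerically robust than
    # the SVD-based pinv (it squares the condition number), but usually much
    # faster for tall matrices.
    return np.linalg.solve(A.T @ A, A.T)

A = np.random.rand(2000, 200)  # stand-in for the 20000x800 case
P = normal_eq_pinv(A)
```

For a random dense matrix like this, P @ A is the identity up to solver round-off; for ill-conditioned inputs the SVD route remains the safer choice.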