I would like to do a simple element-wise division followed by a row-wise average inside a function compiled with Numba's `jit` using `nopython=True`.
```python
import numpy as np
from numba import jit, prange, typed

A = np.array([[2,2,2],[1,0,0],[1,2,1]], dtype=np.float32)
B = np.array([[2,0,2],[0,1,0],[1,2,1]], dtype=np.float32)
C = np.array([[2,0,1],[0,1,0],[1,1,2]], dtype=np.float32)
```
My jit function is:
```python
@jit(nopython=True)
def test(a, b, c):
    mask = a + b > 0
    div = np.divide(c, a + b, where=mask)
    result = div.mean(axis=1)

    return result

test_res = test(A, B, C)
```
However, this throws an error. What would be the workaround? I am trying to do this without a loop; any light on this would be appreciated.
Answer
Numba doesn't support every argument of every NumPy function. For example, `np.mean()` does not accept the `axis` argument in nopython mode, and `np.divide` does not accept `where`. You can work around this with explicit loops, for example:
```python
import numpy as np
import numba as nb

# Eagerly compiled with an explicit signature.
# For a parallel version, add parallel=True to the decorator.
@nb.njit("float64[::1](float32[:, ::1], float32[:, ::1], float32[:, ::1])")
def test(a, b, c):
    result = np.zeros(c.shape[0])
    c = np.copy(c)  # work on a copy so the caller's array is not modified
    for i in range(c.shape[0]):  # parallel: for i in nb.prange(c.shape[0]):
        for j in range(c.shape[1]):
            if a[i, j] + b[i, j] > 0:
                c[i, j] = c[i, j] / (a[i, j] + b[i, j])
            # entries where a + b <= 0 keep their original value from c
            result[i] += c[i, j]

    return result / c.shape[1]
```
The JAX library can also be used to accelerate this:
```python
import jax
import jax.numpy as jnp

@jax.jit
def test_jax(a, b, c):
    mask = a + b > 0
    # where the mask is False, keep the original value from c
    div = jnp.where(mask, jnp.divide(c, a + b), c)
    return jnp.mean(div, axis=1)
```