I would like to do a simple element-wise division followed by a row-wise average inside a Numba `jit` function with `nopython=True`.
```python
import numpy as np
from numba import jit, prange, typed

A = np.array([[2, 2, 2], [1, 0, 0], [1, 2, 1]], dtype=np.float32)
B = np.array([[2, 0, 2], [0, 1, 0], [1, 2, 1]], dtype=np.float32)
C = np.array([[2, 0, 1], [0, 1, 0], [1, 1, 2]], dtype=np.float32)
```
My jit function is:
```python
@jit(nopython=True)
def test(a, b, c):
    mask = a + b > 0
    div = np.divide(c, a + b, where=mask)
    result = div.mean(axis=1)
    return result

test_res = test(A, B, C)
```
However, this throws an error. What would be the workaround? I am trying to do this without a loop; any pointers would be appreciated.
Answer
Numba doesn't support some arguments of some NumPy functions: for example, `mean()` does not accept the `axis` argument in nopython mode, and the `where` argument of `np.divide` is not supported either. You can work around this with explicit loops, for example:
```python
import numba as nb
import numpy as np

@nb.njit("float64[::1](float32[:, ::1], float32[:, ::1], float32[:, ::1])")  # parallel --> , parallel=True
def test(a, b, c):
    result = np.zeros(c.shape[0])
    c = np.copy(c)  # work on a copy so the caller's array is not modified
    for i in range(c.shape[0]):  # parallel --> for i in nb.prange(c.shape[0]):
        for j in range(c.shape[1]):
            if a[i, j] + b[i, j] > 0:
                c[i, j] = c[i, j] / (a[i, j] + b[i, j])
            result[i] += c[i, j]  # masked-out entries contribute their original value
    return result / c.shape[1]
```
Alternatively, the JAX library can be used to accelerate it:
```python
import jax
import jax.numpy as jnp

@jax.jit
def test_jax(a, b, c):
    mask = a + b > 0
    div = jnp.where(mask, jnp.divide(c, a + b), c)
    return jnp.mean(div, axis=1)
```
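To check either answer, a plain NumPy reference (no Numba or JAX) can help. One caveat about the question's original code: when `np.divide` is given `where=` without `out=`, NumPy leaves the masked-out entries of the result uninitialized, so `mean` would average arbitrary values even outside Numba. Passing `out=c.copy()` keeps `c` at those positions, an assumption chosen to match the JAX version; the name `reference` is illustrative.

```python
import numpy as np

def reference(a, b, c):
    mask = a + b > 0
    # out=c.copy() gives masked-out positions a defined value (the original c);
    # without out=, those entries would be left uninitialized
    div = np.divide(c, a + b, out=c.copy(), where=mask)
    return div.mean(axis=1)

A = np.array([[2, 2, 2], [1, 0, 0], [1, 2, 1]], dtype=np.float32)
B = np.array([[2, 0, 2], [0, 1, 0], [1, 2, 1]], dtype=np.float32)
C = np.array([[2, 0, 1], [0, 1, 0], [1, 1, 2]], dtype=np.float32)
print(reference(A, B, C))  # approximately [0.25, 0.3333, 0.5833]
```

Comparing the output of `test(A, B, C)` or `test_jax(A, B, C)` with `reference(A, B, C)` via `np.allclose` should agree up to float32 rounding.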