I would like to do a simple element-wise division followed by a row-wise average using a jit function with nopython=True.
import numpy as np
from numba import jit, prange, typed

A = np.array([[2,2,2],[1,0,0],[1,2,1]], dtype=np.float32)
B = np.array([[2,0,2],[0,1,0],[1,2,1]], dtype=np.float32)
C = np.array([[2,0,1],[0,1,0],[1,1,2]], dtype=np.float32)
My jit function is:
@jit(nopython=True)
def test(a,b,c):
    mask = a+b >0
    div = np.divide(c, a+b, where=mask)
    result = div.mean(axis=1)
    return result
test_res = test(A,B,C)
However, this throws an error. What would be the workaround? I am trying to do this without a loop; any insight would be appreciated.
Answer
Numba doesn't support every argument of every NumPy function in nopython mode. In particular, the where argument of np.divide and the axis argument of np.mean() / ndarray.mean() are not supported. You can work around this with an explicit loop, for example:
import numpy as np
import numba as nb

# Eagerly compiled for C-contiguous float32 inputs; for parallel execution add parallel=True
@nb.njit("float64[::1](float32[:, ::1], float32[:, ::1], float32[:, ::1])")
def test(a, b, c):
    result = np.zeros(c.shape[0])
    for i in range(c.shape[0]):     # parallel --> for i in nb.prange(c.shape[0]):
        for j in range(c.shape[1]):
            s = a[i, j] + b[i, j]
            if s > 0:
                # divide only where the denominator is positive
                result[i] += c[i, j] / s
            else:
                # otherwise keep the original value of c
                result[i] += c[i, j]
    return result / c.shape[1]     # row-wise mean
Alternatively, the JAX library can be used to accelerate it; jnp.where takes the place of the where argument:
import jax
import jax.numpy as jnp
@jax.jit
def test_jax(a, b, c):
    mask = a + b > 0
    div = jnp.where(mask, jnp.divide(c, a + b), c)
    return jnp.mean(div, axis=1)