
How to use a mask in a numba @jit function

I would like to do a simple masked division followed by a row-wise average inside a jit function with nopython=True.

import numpy as np
from numba import jit,prange,typed

A = np.array([[2,2,2],[1,0,0],[1,2,1]], dtype=np.float32)
B = np.array([[2,0,2],[0,1,0],[1,2,1]],dtype=np.float32)
C = np.array([[2,0,1],[0,1,0],[1,1,2]],dtype=np.float32)

My jit function is:

@jit(nopython=True)
def test(a,b,c):
    mask = a+b >0
    div = np.divide(c, a+b, where=mask)
    result = div.mean(axis=1)

    return result

test_res = test(A,B,C)

However, this throws an error. What would be the workaround? I am trying to do this without a loop; any pointers would be appreciated.


Answer

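Numba supports only a subset of NumPy, and many supported functions accept only some of their arguments: for example, np.mean() (and ndarray.mean()) cannot take an “axis” argument, and np.divide cannot take a “where” argument. Inside a nopython function, explicit loops are compiled to machine code, so writing the loops yourself is the usual workaround rather than a performance problem: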

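import numpy as np
import numba as nb

# Eager compilation: accepts only C-contiguous 2-D float32 arrays, returns a float64 array.
@nb.njit("float64[::1](float32[:, ::1], float32[:, ::1], float32[:, ::1])")  # to run in parallel, add: parallel=True
def test(a, b, c):
    result = np.zeros(c.shape[0])  # float64 row sums
    c = np.copy(c)                 # work on a copy so the caller's array is not modified
    for i in range(c.shape[0]):    # to run in parallel: for i in nb.prange(c.shape[0]):
        for j in range(c.shape[1]):
            # divide only where the denominator is positive; otherwise keep c[i, j]
            if a[i, j] + b[i, j] > 0:
                c[i, j] = c[i, j] / (a[i, j] + b[i, j])
            result[i] += c[i, j]

    return result / c.shape[1]     # row sums -> row means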

Alternatively, the JAX library can compile the vectorized version directly, since jnp.mean accepts “axis” and jnp.where can apply the mask:

import jax
import jax.numpy as jnp

@jax.jit
def test_jax(a, b, c):
    mask = a + b > 0
    div = jnp.where(mask, jnp.divide(c, a + b), c)
    return jnp.mean(div, axis=1)

User contributions licensed under: CC BY-SA