
JAX: Passing a dictionary rather than arg nums to identify variables for autodifferentiation

I want to use JAX as a vehicle for gradient descent; however, I have a moderately large number of parameters and would prefer to pass them as a dictionary, f(func, dict), rather than as separate arguments, f(func, x1, ..., xn).

So instead of

# https://www.kaggle.com/code/grez911/tutorial-efficient-gradient-descent-with-jax/notebook
# X, y, w, b and learning_rate are defined as in the workaround further down.
from jax import grad

def J(X, w, b, y):
    """Cost function for a linear regression. A forward pass of our model.

    Args:
        X: a features matrix.
        w: weights (a column vector).
        b: a bias.
        y: a target vector.

    Returns:
        scalar: a cost of this solution.    
    """
    y_hat = X.dot(w) + b # Predict values.
    return ((y_hat - y)**2).mean() # Return cost.

for i in range(100):
    w -= learning_rate * grad(J, argnums=1)(X, w, b, y)
    b -= learning_rate * grad(J, argnums=2)(X, w, b, y)

I would like something more like this:

for i in range(100):
    w -= learning_rate * grad(J, arg_key='w')(arg_dict)
    b -= learning_rate * grad(J, arg_key='b')(arg_dict)

Is this possible?
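For reference, the two separate `grad` calls in the first snippet can already be merged into one: `jax.grad` accepts a tuple for `argnums` and then returns a tuple of gradients, one per listed argument. This does not give name-based selection, but it removes the per-variable `grad(...)` call. A minimal sketch on small made-up data (the values of `X`, `y`, and `learning_rate` are illustrative assumptions, not from the notebook):

```python
import jax.numpy as jnp
from jax import grad

def J(X, w, b, y):
    """Mean squared error of a linear model (same cost as in the question)."""
    y_hat = X.dot(w) + b
    return ((y_hat - y) ** 2).mean()

# Illustrative data (assumed, not from the original notebook).
X = jnp.array([[1., 2.], [3., 4.]])
y = jnp.array([[5.], [11.]])
w = jnp.zeros((2, 1))
b = 0.
learning_rate = 0.01

for i in range(100):
    # One call returns (dJ/dw, dJ/db), both evaluated at the same (w, b).
    dw, db = grad(J, argnums=(1, 2))(X, w, b, y)
    w -= learning_rate * dw
    b -= learning_rate * db
```

Unlike the first snippet, this updates `w` and `b` simultaneously from gradients taken at the same point; the original updates `b` using the already-updated `w`.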

EDIT:

This is my current workaround:

import jax.numpy as jnp
from jax import grad

# A features matrix.
X = jnp.array([
    [4., 7.],
    [1., 8.],
    [-5., -6.],
    [3., -1.],
    [0., 9.]
])

# A target column vector.
y = jnp.array([
    [37.],
    [24.],
    [-34.],
    [16.],
    [21.]
])

learning_rate = 0.01

w = jnp.zeros((2, 1))
b = 0.

def J(X, w, b, y):
    """Cost function for a linear regression. A forward pass of our model.

    Args:
        X: a features matrix.
        w: weights (a column vector).
        b: a bias.
        y: a target vector.

    Returns:
        scalar: a cost of this solution.
    """
    y_hat = X.dot(w) + b # Predict values.
    return ((y_hat - y)**2).mean() # Return cost.

# Define your function arguments as a dictionary, in the order J expects them.
arg_dict = {
    'X': X,
    'w': w,
    'b': b,
    'y': y
}
# Map each positional index to its argument name.
idx_dict = {idx: name for idx, name in enumerate(arg_dict.keys())}
# Only the model parameters are trained; X and y are data.
trainable = {'w', 'b'}

for i in range(100):
    for idx, name in idx_dict.items():
        if name not in trainable:
            continue
        # Rebuild the positional arguments from the current dictionary values.
        arg_arr = [arg_dict[idx_dict[j]] for j in range(len(arg_dict))]
        # JAX arrays are immutable, so `var -= ...` would only rebind a local name;
        # store the updated value back in the dictionary instead.
        arg_dict[name] = arg_dict[name] - learning_rate * grad(J, argnums=idx)(*arg_arr)

The gist is that now I don’t need to write grad(…) for every single variable that needs autodifferentiation.
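A related pattern, which neither the question nor the answer covers: `jax.grad` differentiates with respect to any pytree, including a dict, so the trainable parameters can be packed into a single dictionary argument, and `grad` then returns a dictionary of gradients with the same keys. A minimal sketch, assuming a hypothetical `loss(params, X, y)` wrapper around the same linear model and reusing the question's data:

```python
import jax
import jax.numpy as jnp

X = jnp.array([[4., 7.], [1., 8.], [-5., -6.], [3., -1.], [0., 9.]])
y = jnp.array([[37.], [24.], [-34.], [16.], [21.]])
learning_rate = 0.01

def loss(params, X, y):
    """Hypothetical wrapper: same cost as J, but w and b come from one dict."""
    y_hat = X.dot(params['w']) + params['b']
    return ((y_hat - y) ** 2).mean()

params = {'w': jnp.zeros((2, 1)), 'b': 0.}

for i in range(100):
    # grads mirrors the structure of params: {'w': dJ/dw, 'b': dJ/db}.
    grads = jax.grad(loss)(params, X, y)
    # Apply the gradient step to every leaf of the dict in one call.
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
```

Because only `params` is differentiated (the first positional argument, the default `argnums=0`), the data `X` and `y` never need to be excluded by hand.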


Answer

Specifying autodiff argnums by name is not currently supported in JAX, although the idea is under discussion in this issue: https://github.com/google/jax/issues/10614

Until that is implemented, there are ways you can convert argnames to argnums automatically using inspect.signature for your function (some examples are in the linked issue), but overall it’s probably simpler to do that mapping manually for your specific function.
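A minimal sketch of the `inspect.signature` approach. The helper `grad_by_name` is hypothetical (it is not part of JAX): it looks up the position of each requested parameter name in the function's signature, forwards those indices to `jax.grad` as `argnums`, and returns the gradients as a dict keyed by name. It assumes `fun` takes plain positional parameters.

```python
import inspect

import jax
import jax.numpy as jnp

def grad_by_name(fun, argnames):
    """Hypothetical helper: like jax.grad, but selects arguments by name.

    argnames must be a sequence of parameter names of `fun`; an unknown
    name raises ValueError from list.index.
    """
    names = list(inspect.signature(fun).parameters)
    argnums = tuple(names.index(name) for name in argnames)
    grad_fun = jax.grad(fun, argnums=argnums)

    def wrapped(arg_dict):
        # jax.grad only differentiates positional arguments, so call positionally.
        grads = grad_fun(*(arg_dict[name] for name in names))
        return dict(zip(argnames, grads))

    return wrapped

def J(X, w, b, y):
    y_hat = X.dot(w) + b
    return ((y_hat - y) ** 2).mean()

X = jnp.array([[4., 7.], [1., 8.], [-5., -6.], [3., -1.], [0., 9.]])
y = jnp.array([[37.], [24.], [-34.], [16.], [21.]])
arg_dict = {'X': X, 'w': jnp.zeros((2, 1)), 'b': 0., 'y': y}
learning_rate = 0.01

grad_J = grad_by_name(J, ('w', 'b'))
for i in range(100):
    grads = grad_J(arg_dict)
    # Update only the named parameters; X and y are left untouched.
    for name, g in grads.items():
        arg_dict[name] = arg_dict[name] - learning_rate * g
```

This gets close to the `grad(J, arg_key='w')(arg_dict)` call the question asks for, at the cost of a small wrapper per function.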

User contributions licensed under: CC BY-SA