I want to use JAX as a vehicle for gradient descent; however, I have a moderately large number of parameters and would prefer to pass them as a dictionary, f(func, dict), rather than f(func, x1, ..., xn).
So instead of
    # https://www.kaggle.com/code/grez911/tutorial-efficient-gradient-descent-with-jax/notebook
    def J(X, w, b, y):
        """Cost function for a linear regression. A forward pass of our model.

        Args:
            X: a features matrix.
            w: weights (a column vector).
            b: a bias.
            y: a target vector.

        Returns:
            scalar: a cost of this solution.
        """
        y_hat = X.dot(w) + b  # Predict values.
        return ((y_hat - y)**2).mean()  # Return cost.

    for i in range(100):
        w -= learning_rate * grad(J, argnums=1)(X, w, b, y)
        b -= learning_rate * grad(J, argnums=2)(X, w, b, y)
Something more like
    # Wished-for syntax: jax.grad has no arg_key parameter.
    for i in range(100):
        w -= learning_rate * grad(J, arg_key='w')(arg_dict)
        b -= learning_rate * grad(J, arg_key='b')(arg_dict)
Is this possible?
EDIT:
This is my current workaround:
    import jax.numpy as np
    from jax import grad

    # A features matrix.
    X = np.array([
        [4., 7.],
        [1., 8.],
        [-5., -6.],
        [3., -1.],
        [0., 9.]
    ])

    # A target column vector.
    y = np.array([
        [37.],
        [24.],
        [-34.],
        [16.],
        [21.]
    ])

    learning_rate = 0.01
    w = np.zeros((2, 1))
    b = 0.

    def J(X, w, b, y):
        """Cost function for a linear regression. A forward pass of our model.

        Args:
            X: a features matrix.
            w: weights (a column vector).
            b: a bias.
            y: a target vector.

        Returns:
            scalar: a cost of this solution.
        """
        y_hat = X.dot(w) + b  # Predict values.
        return ((y_hat - y)**2).mean()  # Return cost.

    # Define your function arguments as a dictionary
    # (key order must match the positional parameters of J).
    arg_dict = {
        'X': X,
        'w': w,
        'b': b,
        'y': y
    }
    idx_dict = {idx: name for idx, name in enumerate(arg_dict.keys())}
    trainable = {'w', 'b'}  # X and y are data, so they are not updated

    for i in range(100):
        # Rebuild the positional arguments from the current values each step.
        arg_arr = [arg_dict[idx_dict[idx]] for idx in range(len(arg_dict))]
        for idx, name in idx_dict.items():
            if name not in trainable:
                continue
            # JAX arrays are immutable and `var -= ...` only rebinds a local
            # name, so write the updated value back into the dict.
            arg_dict[name] = arg_dict[name] - learning_rate * grad(J, argnums=idx)(*arg_arr)
The gist is that now I don’t need to write grad(…) for every single variable that needs autodifferentiation.
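For reference, `jax.grad` also accepts a tuple for `argnums` and then returns a tuple of gradients, so a single `grad` call can cover every trainable argument without any dict bookkeeping. A minimal sketch, reusing the cost `J` and the data from the question; `dJ_dwb` is just an illustrative name:

```python
import jax.numpy as jnp
from jax import grad

def J(X, w, b, y):
    """Mean squared error of a linear model (same cost as in the question)."""
    y_hat = X.dot(w) + b
    return ((y_hat - y) ** 2).mean()

X = jnp.array([[4., 7.], [1., 8.], [-5., -6.], [3., -1.], [0., 9.]])
y = jnp.array([[37.], [24.], [-34.], [16.], [21.]])
w = jnp.zeros((2, 1))
b = 0.
learning_rate = 0.01

# One grad call returns a tuple with one gradient per requested argnum,
# in the same order as the tuple (here: w, then b).
dJ_dwb = grad(J, argnums=(1, 2))

for i in range(100):
    dw, db = dJ_dwb(X, w, b, y)
    w = w - learning_rate * dw
    b = b - learning_rate * db
```

Both gradients here are evaluated at the same point, so `w` and `b` are updated simultaneously rather than one after the other.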
Answer
Specifying autodiff argnums by name is not currently supported in JAX, although the idea is under discussion in this issue: https://github.com/google/jax/issues/10614
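An alternative that avoids argument names entirely, assuming you are free to change the signature of `J`: `jax.grad` differentiates with respect to any pytree, and a dict is a pytree, so the trainable parameters can travel together as a single dict argument. This is a different technique from name-based argnums; the `params` layout below is an illustrative assumption.

```python
import jax
import jax.numpy as jnp

def J(params, X, y):
    """Same linear-regression cost, but w and b arrive bundled in a dict."""
    y_hat = X.dot(params['w']) + params['b']
    return ((y_hat - y) ** 2).mean()

X = jnp.array([[4., 7.], [1., 8.], [-5., -6.], [3., -1.], [0., 9.]])
y = jnp.array([[37.], [24.], [-34.], [16.], [21.]])
params = {'w': jnp.zeros((2, 1)), 'b': 0.}
learning_rate = 0.01

for i in range(100):
    # grad differentiates w.r.t. argument 0 by default; because that argument
    # is a dict, the result is a dict with the same keys and shapes.
    grads = jax.grad(J)(params, X, y)
    # Apply the gradient-descent update to every leaf of the dict at once.
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
```

Adding another parameter then only means adding a key to `params` and using it inside `J`; the training loop stays the same.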
Until argument names are supported, you can convert argument names to argnums automatically by calling inspect.signature on your function (the linked issue has some examples), but overall it's probably simpler to do that mapping manually for your specific function.
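A minimal sketch of the `inspect.signature` approach, written to match the `grad(J, arg_key=...)(arg_dict)` call style from the question. `grad_by_name` is a hypothetical helper, not a JAX API, and it assumes `J` takes only ordinary positional-or-keyword parameters (no `*args`/`**kwargs`) and that `arg_dict` has a key for every one of them.

```python
import inspect
import jax
import jax.numpy as jnp

def grad_by_name(f, argname):
    """Hypothetical helper (not part of JAX): jax.grad with the argument chosen by name.

    The returned function takes a dict keyed by f's parameter names.
    """
    names = list(inspect.signature(f).parameters)  # f's parameters in positional order
    argnum = names.index(argname)                   # name -> position (ValueError if unknown)
    g = jax.grad(f, argnums=argnum)

    def wrapped(arg_dict):
        # Pass the dict's values positionally, in the order f expects.
        return g(*(arg_dict[name] for name in names))

    return wrapped

def J(X, w, b, y):
    y_hat = X.dot(w) + b
    return ((y_hat - y) ** 2).mean()

arg_dict = {
    'X': jnp.array([[4., 7.], [1., 8.], [-5., -6.], [3., -1.], [0., 9.]]),
    'w': jnp.zeros((2, 1)),
    'b': 0.,
    'y': jnp.array([[37.], [24.], [-34.], [16.], [21.]]),
}
learning_rate = 0.01

# Build the name-based gradient functions once, outside the loop.
grad_w = grad_by_name(J, 'w')
grad_b = grad_by_name(J, 'b')

for i in range(100):
    dw = grad_w(arg_dict)
    db = grad_b(arg_dict)
    arg_dict['w'] = arg_dict['w'] - learning_rate * dw
    arg_dict['b'] = arg_dict['b'] - learning_rate * db
```

For a single, fixed function, a hand-written mapping such as `{'X': 0, 'w': 1, 'b': 2, 'y': 3}` gives the same name-to-position lookup without `inspect`.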