This question is similar to the question here, but I cannot work out from it what I should change.
I have a function

def elbo(variational_parameters, eps, a, b):
    ...
    return theta, _

elbo = jit(elbo, static_argnames=["a", "b"])

where variational_parameters is a vector (one-dimensional array) of length P, eps is a two-dimensional array of dimensions K by N, and a, b are fixed values.
The elbo function has been successfully vmapped over the rows of eps, and has been jitted by passing a and b to static_argnames, to return theta, which is a two-dimensional array of dimensions K by P.
I want to take the Jacobian of the output theta with respect to variational_parameters through the elbo function. The first value returned by
jacobian(elbo, argnums=0, has_aux=True)(variational_parameters, eps, a, b)
gives me a three-dimensional array of dimensions K by P by N. This is what I want. As soon as I try to jit this function with
jit(jacobian(elbo, argnums=0, has_aux=True))(variational_parameters, eps, a, b)
I get the error
ValueError: Non-hashable static arguments are not supported, which can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function elbo is non-hashable.
Any help would be greatly appreciated; thanks!
Answer
Any parameters you pass to a JIT-compiled function are traced (non-static) unless you explicitly mark them as static. So this line:
jit(jacobian(elbo, argnums=0, has_aux=True))(variational_parameters, eps, a, b)
makes variational_parameters, eps, a, and b non-static. Within the transformed function, these traced values are then passed to this function:
elbo = jit(elbo, static_argnames=["a", "b"])
which means that you are attempting to pass non-static values (tracers) as static arguments. Tracers are not hashable, which causes the error.
To fix this, you should mark the static parameters as static any time they enter a jit-compiled function. In your case it might look something like this:
jit(jacobian(elbo, argnums=0, has_aux=True), static_argnums=(2, 3))(variational_parameters, eps, a, b)
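For a self-contained illustration, here is a minimal sketch of the pattern. The body of elbo is elided in the question, so the helper _elbo_single, its formula, and the values chosen for a, b, K, N and P are all made up for this example; only the jit / jacobian / static_argnums structure mirrors the code above. The shapes printed are those of this toy body and need not match the real function.

```python
from functools import partial

import jax.numpy as jnp
from jax import jacobian, jit, vmap


def _elbo_single(variational_parameters, eps_row, a, b):
    # Hypothetical stand-in: the real elbo body is not shown in the question.
    scale = jnp.mean(eps_row)
    theta_row = a * variational_parameters + b * scale * variational_parameters**2
    return theta_row, scale  # (output, auxiliary value)


@partial(jit, static_argnames=["a", "b"])
def elbo(variational_parameters, eps, a, b):
    # vmap over the rows of eps; the parameters, a and b are shared across rows.
    return vmap(_elbo_single, in_axes=(None, 0, None, None))(
        variational_parameters, eps, a, b
    )


# a and b sit at positions 2 and 3, so they are marked static in the outer jit too.
jac_elbo = jit(jacobian(elbo, argnums=0, has_aux=True), static_argnums=(2, 3))

K, N, P = 4, 5, 3  # assumed sizes, for illustration only
variational_parameters = jnp.arange(1.0, P + 1.0)
eps = jnp.ones((K, N))

jac, aux = jac_elbo(variational_parameters, eps, 2.0, 0.5)
print(jac.shape, aux.shape)
```

Static arguments must be hashable, so plain Python numbers, strings or tuples work for a and b, and each distinct (a, b) pair triggers a fresh compilation.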