I want to use JAX as a vehicle for gradient descent; however, I have a moderately large number of parameters and would prefer to pass them as a dictionary, f(func, dict), rather than as separate arguments, f(func, x1, …, xn). Is this possible? EDIT: This is my current workaround: The gi…
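jax.grad accepts any pytree (nested dicts, lists and tuples of arrays) as the argument it differentiates, and it returns gradients with the same structure. The sketch below shows this with a dictionary of parameters. The linear-regression loss and the gradient_descent(func, params, ...) helper are hypothetical illustrations of the f(func, dict) shape described above, not the asker's code.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    # params is a plain dict; jax.grad treats it as a pytree
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)


def gradient_descent(func, params, *args, lr=0.05, num_steps=200):
    """Hypothetical f(func, dict): minimise func(params, *args) by plain gradient descent."""
    # grad differentiates w.r.t. the first argument, i.e. the whole dict
    grad_fn = jax.jit(jax.grad(func))
    for _ in range(num_steps):
        grads = grad_fn(params, *args)  # a dict with the same keys as params
        # apply the update leaf by leaf, keeping the dict structure
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params


x = jnp.arange(5.0)
y = 2.0 * x + 1.0
init = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
trained = gradient_descent(loss, init, x, y)
```

The same pattern works unchanged for nested dictionaries, since jax.tree_util.tree_map walks every leaf.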
Tag: jax
Why does GPU memory increase when recreating and reassigning a JAX numpy array to the same variable name?
When I recreate and reassign a JAX np array to the same variable name, for some reason the GPU memory nearly doubles on the first recreation and then stays stable for subsequent recreations and reassignments. Why does this happen, and is this generally expected behavior for JAX arrays? Fully runnable minimal example:…
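One likely contributor, offered as a hedged sketch rather than a diagnosis of the asker's setup: in x = jnp.ones(n), Python evaluates the right-hand side before rebinding x, so the old and new buffers are both alive for a moment. Assuming memory is measured with preallocation disabled (for example XLA_PYTHON_CLIENT_PREALLOCATE=false), the freed buffer typically goes back to JAX's allocator cache rather than to the driver, which would explain the plateau. The helper reassign_releasing_first below is hypothetical; it uses jax.Array.delete() to free the old buffer before allocating the new one.

```python
import jax.numpy as jnp


def reassign_releasing_first(x, n):
    """Hypothetical helper: free x's device buffer before allocating its replacement.

    delete() invalidates the buffer for every reference to it, not just this
    name, so only use it when nothing else still needs x.
    """
    x.delete()
    return jnp.ones(n)


n = 1_000_000
x = jnp.ones(n)
# Plain reassignment: the right-hand side is evaluated first, so the new
# buffer is allocated while the old one is still referenced by x.
x = jnp.ones(n)

old = x
x = reassign_releasing_first(x, n)  # at most one n-element buffer alive here
```

A simpler alternative is del x before recreating it, which frees the buffer as soon as no other Python reference holds it.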
Error when trying to jit the computation of the Jacobian in JAX: “ValueError: Non-hashable static arguments are not supported”
This question is similar to the question here, but I cannot work out what I should change in my own code. I have a function where variational_parameters is a vector (one-dimensional array) of length P, eps is a two-dimensional array of dimensions K by N, and a and b are fixed values. The elbo has been successfully vmapped over…
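This error appears when an array is passed through static_argnums or static_argnames, because jit must hash static arguments and JAX/NumPy arrays are not hashable. A minimal sketch of the usual fix, assuming a hypothetical stand-in for the question's elbo (its real body is not shown, and here N equals P): pass the arrays as ordinary traced arguments and mark only the hashable Python scalars a and b as static.

```python
from functools import partial

import jax
import jax.numpy as jnp


def elbo(variational_parameters, eps, a, b):
    # Hypothetical stand-in: variational_parameters has shape (P,), eps has shape (K, P)
    return a * jnp.sin(eps @ variational_parameters) + b  # shape (K,)


# Arrays are traced; only the hashable Python floats a and b are static.
@partial(jax.jit, static_argnames=("a", "b"))
def elbo_jacobian(variational_parameters, eps, a, b):
    # jax.jacobian differentiates w.r.t. the first argument by default
    return jax.jacobian(elbo)(variational_parameters, eps, a, b)  # shape (K, P)


P, K = 3, 4
params = jnp.linspace(0.1, 0.3, P)
eps = jnp.arange(K * P, dtype=jnp.float32).reshape(K, P) / 10.0
jac = elbo_jacobian(params, eps, a=2.0, b=0.5)
```

Each new value of a or b triggers a recompilation; if they change often, dropping them from static_argnames and letting jit trace them is usually the better choice.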
JAX: Getting rid of a zero gradient
Is there a way to modify this function (MyFunc) so that it gives the same result but its gradient is not zero? EDIT: Here is a similar function that does not give a zero gradient, but it does not return 30/20/10. Answer: The gradient of your function is zero because this is the correct res…
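MyFunc itself is not shown, so the sketch below assumes a hypothetical piecewise-constant version returning 30, 20 or 10 with thresholds at 1 and 2. Its true gradient is zero almost everywhere, which is what jax.grad reports. A smooth sigmoid surrogate has a nonzero gradient but does not return exactly 30/20/10; the straight-through estimator (a swapped-in technique, not something from the question) combines the hard forward value with the surrogate's gradient via jax.lax.stop_gradient.

```python
import jax
import jax.numpy as jnp


def my_func_hard(x):
    # Hypothetical stand-in for MyFunc; thresholds 1 and 2 are assumptions
    return jnp.where(x < 1.0, 30.0, jnp.where(x < 2.0, 20.0, 10.0))


def my_func_soft(x, temperature=0.1):
    # Smooth surrogate: each step replaced by a sigmoid; tends to 30/20/10 as temperature -> 0
    return (30.0
            - 10.0 * jax.nn.sigmoid((x - 1.0) / temperature)
            - 10.0 * jax.nn.sigmoid((x - 2.0) / temperature))


def my_func_ste(x):
    # Straight-through estimator: forward value of the hard function,
    # gradient of the soft one (stop_gradient hides the correction from autodiff)
    soft = my_func_soft(x)
    return soft + jax.lax.stop_gradient(my_func_hard(x) - soft)
```

Note that the straight-through gradient is a deliberate approximation: it is not the derivative of the function being evaluated, which remains zero.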
(Conv1D) TensorFlow and JAX Give Different Outputs for the Same Input
I am trying to use conv1d functions to compute a transposed convolution in JAX and TensorFlow, respectively. I read the documentation of both JAX and TensorFlow for the conv1d_transpose operation, but they produce different outputs for the same input. I cannot find out what the problem is. And I don…
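One common source of this mismatch, offered as a hedged guess since the asker's code is not shown: TensorFlow documents tf.nn.conv1d_transpose as the transpose (gradient) of conv1d, whereas jax.lax.conv_transpose does not flip the kernel unless transpose_kernel=True. The sketch compares the gradient-defined transposed convolution (built with jax.vjp) against both settings, using stride 1, VALID padding and an asymmetric kernel so the flip is visible. TensorFlow's kernel layout for conv1d_transpose ([width, output_channels, in_channels]) also has to be matched separately.

```python
import jax
import jax.numpy as jnp
from jax import lax

# 1-D layouts: (batch, width, channels) for data, (width, in, out) for the kernel
DN = ("NWC", "WIO", "NWC")


def conv1d(x, w):
    # Plain VALID cross-correlation with stride 1
    return lax.conv_general_dilated(x, w, window_strides=(1,), padding="VALID",
                                    dimension_numbers=DN)


g = jnp.array([1.0, 2.0, 3.0]).reshape(1, 3, 1)  # input to the transposed conv
w = jnp.array([1.0, 2.0, 3.0]).reshape(3, 1, 1)  # asymmetric kernel, so flipping matters

# Transposed conv defined as the gradient of conv1d w.r.t. its input
# (forward: width 5 -> 3, so the transpose maps width 3 -> 5)
_, vjp_fn = jax.vjp(lambda x: conv1d(x, w), jnp.zeros((1, 5, 1)))
(from_gradient,) = vjp_fn(g)  # values 1, 4, 10, 12, 9

flipped = lax.conv_transpose(g, w, strides=(1,), padding="VALID",
                             dimension_numbers=DN, transpose_kernel=True)
unflipped = lax.conv_transpose(g, w, strides=(1,), padding="VALID",
                               dimension_numbers=DN)  # transpose_kernel defaults to False
```

Only the flipped version agrees with the gradient definition, so passing transpose_kernel=True (or flipping the kernel by hand) is a reasonable first thing to try when comparing against TensorFlow.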
Not able to install jaxlib
I am trying to install jaxlib on Windows 10 using the following command, which I found in the documentation: pip install jaxlib. It shows the following error. Answer: jaxlib was not supported on Windows at the time; you can see it here: https://github.com/google/jax/issues/438