I would like to know how to write a custom gradient for a function which has multiple outputs (i.e. returns an array). As a simple example, I wrote the following code for y = tan(x @ w + b), where x has shape (2, 3) and y has shape (2, 2). To compare results, I computed the operation both in the usual way and with the custom gradient.
Here is the code.
import tensorflow as tf

#-----------------------------------------------------------------
# Gradient test example
#-----------------------------------------------------------------
w = tf.Variable([[1., 2.], [2., 3.], [3., 4.]], name='w')  # shape (input_size, unit_size) = (3,2)
b = tf.Variable([1.0, 2.0], name='b')                      # shape (2,)
x = tf.constant([[1., 2., 3.], [1., 1., 1.]])              # shape (batch_size, input_size) = (2,3)

@tf.custom_gradient
def custom_op(x):
    @tf.function
    def _inner_function():
        y = x @ w + b
        return tf.math.tan(y)  # y shape = (batch_size, unit_size)
    y = _inner_function()
    def grads(upstream, variables):
        # here upstream has shape (batch_size, unit_size)
        assert variables[0] is w
        yp = 1.0 / tf.square(tf.math.cos(y))         # tan'(y) = 1/cos(y)**2
        dydx = (upstream * yp) @ tf.transpose(w)     # (batch_size, input_size)
        dydw = tf.transpose(x) @ (upstream * yp)     # (input_size, unit_size)
        dydb = tf.reduce_sum(upstream * yp, axis=0)  # (unit_size,)
        return dydx, [dydw, dydb]
    return y, grads

#-----------------------------------------------
# feed forward
#-----------------------------------------------
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = tf.math.tan(x @ w + b)  # shape (2,2)
    y2 = custom_op(x)
    loss = tf.reduce_mean(y**2)
    loss2 = tf.reduce_mean(y2**2)

#----------------------------------------------
# compute gradient
#----------------------------------------------
my_vars = {'w': w, 'b': b}
dldx = tape.gradient(loss, x)
dldy = tape.gradient(loss, y)
dldwb = tape.gradient(loss, my_vars)
dldx_2 = tape.gradient(loss2, x)
dldy_2 = tape.gradient(loss2, y2)
dldwb_2 = tape.gradient(loss2, my_vars)

print('w :', w)
print('b :', b)
print('x :', x)
print('y :', y, y2)
print('loss :', loss, loss2)
print('dldx:', dldx, dldx_2)
print('dldy:', dldy, dldy_2)
print('dldwb:', dldwb, dldwb_2)

dydx = tape.gradient(y, x)
dydw = tape.gradient(y, w)
dydb = tape.gradient(y, b)
dydx_2 = tape.gradient(y2, x)
dydw_2 = tape.gradient(y2, w)
dydb_2 = tape.gradient(y2, b)
print('dydx:', dydx, dydx_2)
print('dydw:', dydw, dydw_2)
print('dydb:', dydb, dydb_2)
The code gives different gradients for y and y2. Obviously I did something wrong, but I could not figure out how to fix it. (When there was no tan function, i.e. y = x @ w + b, the code seemed to work; it only breaks once the tan function is added.)
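For reference, a minimal sketch of the linear case (reconstructed from the remark above, not the exact code I ran; linear_op is just a placeholder name). With no outer nonlinearity the local derivative is 1, so upstream is propagated through the matmul directly:

@tf.custom_gradient
def linear_op(x):
    y = x @ w + b
    def grads(upstream, variables):
        # upstream has shape (batch_size, unit_size); no yp factor is needed
        assert variables[0] is w
        dydx = upstream @ tf.transpose(w)       # (batch_size, input_size)
        dydw = tf.transpose(x) @ upstream       # (input_size, unit_size)
        dydb = tf.reduce_sum(upstream, axis=0)  # (unit_size,)
        return dydx, [dydw, dydb]
    return y, grads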
Answer
There is some confusion about y and the _inner_function: the inner function is x @ w + b, while tan is the outer function. Since y was set to tan(...), yp was later calculated as 1/cos(tan(...))**2, i.e. the derivative of tan was evaluated at tan's output rather than at its argument, which does not make sense.
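Written out with z for the pre-activation (the same name the fixed code below uses) and $\bar{y}$ for the upstream gradient, the chain rule the gradient function has to implement is

$$z = xW + b, \qquad y = \tan(z), \qquad \frac{\partial y}{\partial z} = \frac{1}{\cos^2 z},$$

so the returned gradients are $(\bar{y} \odot \frac{1}{\cos^2 z})\,W^{\top}$ for x, $x^{\top}(\bar{y} \odot \frac{1}{\cos^2 z})$ for w, and the sum of $\bar{y} \odot \frac{1}{\cos^2 z}$ over the batch axis for b.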
This will give the correct results:
@tf.custom_gradient
def custom_op(x):
    @tf.function
    def _inner_function():
        z = x @ w + b
        return z
    z = _inner_function()
    y = tf.math.tan(z)
    def grads(upstream, variables):
        # here upstream has shape (batch_size, unit_size)
        assert variables[0] is w
        yp = 1.0 / tf.square(tf.math.cos(z))         # tan'(z) = 1/cos(z)**2
        dydx = (upstream * yp) @ tf.transpose(w)     # (batch_size, input_size)
        dydw = tf.transpose(x) @ (upstream * yp)     # (input_size, unit_size)
        dydb = tf.reduce_sum(upstream * yp, axis=0)  # (unit_size,)
        return dydx, [dydw, dydb]
    return y, grads
Output:
w : <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=
array([[1., 2.],
       [2., 3.],
       [3., 4.]], dtype=float32)>
b : <tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=array([1., 2.], dtype=float32)>
x : tf.Tensor(
[[1. 2. 3.]
 [1. 1. 1.]], shape=(2, 3), dtype=float32)
y : tf.Tensor(
[[-8.5599345e-01  8.8516558e-03]
 [ 8.7144798e-01 -2.2595085e+02]], shape=(2, 2), dtype=float32) tf.Tensor(
[[-8.5599345e-01  8.8516558e-03]
 [ 8.7144798e-01 -2.2595085e+02]], shape=(2, 2), dtype=float32)
loss : tf.Tensor(12763.819, shape=(), dtype=float32) tf.Tensor(12763.819, shape=(), dtype=float32)
dldx: tf.Tensor(
[[-7.3274821e-01 -1.4699227e+00 -2.2070971e+00]
 [-1.1535872e+07 -1.7303808e+07 -2.3071744e+07]], shape=(2, 3), dtype=float32) tf.Tensor(
[[-7.3274809e-01 -1.4699224e+00 -2.2070966e+00]
 [-1.1535871e+07 -1.7303806e+07 -2.3071742e+07]], shape=(2, 3), dtype=float32)
dldy: tf.Tensor(
[[-4.27996725e-01  4.42582788e-03]
 [ 4.35723990e-01 -1.12975426e+02]], shape=(2, 2), dtype=float32) tf.Tensor(
[[-4.27996725e-01  4.42582788e-03]
 [ 4.35723990e-01 -1.12975426e+02]], shape=(2, 2), dtype=float32)
dldwb: {'w': <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[ 2.5021553e-02, -5.7679365e+06],
       [-7.1657902e-01, -5.7679365e+06],
       [-1.4581797e+00, -5.7679365e+06]], dtype=float32)>, 'b': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 2.5021553e-02, -5.7679365e+06], dtype=float32)>} {'w': <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[ 2.5021732e-02, -5.7679360e+06],
       [-7.1657872e-01, -5.7679360e+06],
       [-1.4581791e+00, -5.7679360e+06]], dtype=float32)>, 'b': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 2.5021732e-02, -5.7679360e+06], dtype=float32)>}
dydx: tf.Tensor(
[[3.73288178e+00 6.46568489e+00 9.19848824e+00]
 [1.02111336e+05 1.53167891e+05 2.04224438e+05]], shape=(2, 3), dtype=float32) tf.Tensor(
[[3.7328813e+00 6.4656844e+00 9.1984873e+00]
 [1.0211133e+05 1.5316788e+05 2.0422442e+05]], shape=(2, 3), dtype=float32)
dydw: tf.Tensor(
[[3.4921465e+00 5.1055789e+04]
 [5.2248712e+00 5.1056789e+04]
 [6.9575958e+00 5.1057789e+04]], shape=(3, 2), dtype=float32) tf.Tensor(
[[3.4921463e+00 5.1055785e+04]
 [5.2248707e+00 5.1056785e+04]
 [6.9575958e+00 5.1057785e+04]], shape=(3, 2), dtype=float32)
dydb: tf.Tensor([3.4921465e+00 5.1055789e+04], shape=(2,), dtype=float32) tf.Tensor([3.4921463e+00 5.1055785e+04], shape=(2,), dtype=float32)
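The built-in and custom gradients now agree up to float32 rounding. Instead of comparing the printed tensors by eye, a short numerical check can be appended to the script above (a sketch; the rtol value is an arbitrary choice):

import numpy as np

# All pairs should match up to float32 noise.
np.testing.assert_allclose(dldx.numpy(), dldx_2.numpy(), rtol=1e-5)
np.testing.assert_allclose(dldwb['w'].numpy(), dldwb_2['w'].numpy(), rtol=1e-5)
np.testing.assert_allclose(dldwb['b'].numpy(), dldwb_2['b'].numpy(), rtol=1e-5)
np.testing.assert_allclose(dydx.numpy(), dydx_2.numpy(), rtol=1e-5)
print('custom gradient matches autodiff')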