Skip to content
Advertisement

How to make a custom gradient for a multiple output function?

I would like to know how to write a custom gradient for a function which have multiple outputs( or an array). For a simple example, I wrote the following code for y=tan( x @ w + b) with x shape is (2,3) and y shape is (2,2). To compare results, I calculated the operation by usual way and by the custom gradient.

Here is the code.

#-----------------------------------------------------------------
# Gradient test example 
#-----------------------------------------------------------------
w = tf.Variable([[1.,2.],[2.,3.],[3.,4.]], name='w') # shape (input_size ,unit_size) =  (3,2) 
b = tf.Variable([1.0,2.0], name='b') #shape(2,)
x = tf.constant([[1., 2., 3.],[1.,1.,1.]]) #shape = (batch_size,input_size)  (2,3) 

@tf.custom_gradient
def custom_op(x):
    @tf.function
    def _inner_function():
        y = x @ w + b
        return tf.math.tan(y) # y shape = (batch_size,unit_size)
    y = _inner_function()             
    def grads(upstream,variables):
        # here upstream is a shape of (batch_size,unit_size)  
        assert variables[0] is w
        yp = 1.0/tf.square(tf.math.cos(y)) # tan'(y) = 1/cos(y)**2 
        dydx = (upstream*yp) @ tf.transpose(w)  # (batch_size,input_size)
        dydw = tf.transpose(x) @ (upstream*yp)  # (input_size,unit_size)
        dydb = tf.reduce_sum(upstream*yp,axis=0) # (unit_size,)
        return dydx, [dydw, dydb]
    return y, grads
#-----------------------------------------------
# feed forward 
#-----------------------------------------------
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = tf.math.tan(x @ w + b )# shape (1,2)   
    y2 = custom_op(x)
    loss = tf.reduce_mean(y**2)    
    loss2 = tf.reduce_mean(y2**2)    
#----------------------------------------------
# compute gradient 
#----------------------------------------------
my_vars = {'w': w,'b': b}
dldx = tape.gradient(loss,x)
dldy = tape.gradient(loss,y)
dldwb = tape.gradient(loss, my_vars)
dldx_2 = tape.gradient(loss2,x)
dldy_2 = tape.gradient(loss2,y2)
dldwb_2 = tape.gradient(loss2, my_vars)
print('w :',w)
print('b :',b)
print('x :',x)    
print('y :',y,y2)
print('loss :',loss,loss2)
print('dldx:',dldx, dldx_2)
print('dldy:',dldy, dldy_2)
print('dldwb:',dldwb, dldwb_2)
dydx = tape.gradient(y,x)  
dydw = tape.gradient(y,w)  
dydb = tape.gradient(y,b)
dydx_2 = tape.gradient(y2,x)
dydw_2 = tape.gradient(y2,w)
dydb_2 = tape.gradient(y2,b)

print('dydx:',dydx, dydx_2)
print('dydw:',dydw, dydw_2)
print('dydb:',dydb, dydb_2)

The result of the code gives different gradients for y and y2. Obviously, I did something wrong but, could not figure out how to fix it. (When there was no tan function, y = x @ w +b, the code seems to work. But, it does not work with tan function.)

Advertisement

Answer

There is some confusion about y and the _inner_function – the inner function is x @ w + b while tan is the outer function. And since y was set to tan(...), then later yp was calculated like cos(tan(...)) which didn’t make sense.

This will give the correct results:

    @tf.function
    def _inner_function():
        z = x @ w + b
        return z
    z = _inner_function()
    y = tf.math.tan(z)
           
    def grads(upstream,variables):
        # here upstream is a shape of (batch_size,unit_size)  
        assert variables[0] is w
        yp = 1.0/tf.square(tf.math.cos(z)) # tan'(z) = 1/cos(z)**2 
        dydx = (upstream*yp) @ tf.transpose(w)  # (batch_size,input_size)
        dydw = tf.transpose(x) @ (upstream*yp)  # (input_size,unit_size)
        dydb = tf.reduce_sum(upstream*yp,axis=0) # (unit_size,)
        return dydx, [dydw, dydb]
    return y, grads

Output:

    w : <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=
array([[1., 2.],
       [2., 3.],
       [3., 4.]], dtype=float32)>
b : <tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=array([1., 2.], dtype=float32)>
x : tf.Tensor(
[[1. 2. 3.]
 [1. 1. 1.]], shape=(2, 3), dtype=float32)
y : tf.Tensor(
[[-8.5599345e-01  8.8516558e-03]
 [ 8.7144798e-01 -2.2595085e+02]], shape=(2, 2), dtype=float32) tf.Tensor(
[[-8.5599345e-01  8.8516558e-03]
 [ 8.7144798e-01 -2.2595085e+02]], shape=(2, 2), dtype=float32)
loss : tf.Tensor(12763.819, shape=(), dtype=float32) tf.Tensor(12763.819, shape=(), dtype=float32)
dldx: tf.Tensor(
[[-7.3274821e-01 -1.4699227e+00 -2.2070971e+00]
 [-1.1535872e+07 -1.7303808e+07 -2.3071744e+07]], shape=(2, 3), dtype=float32) tf.Tensor(
[[-7.3274809e-01 -1.4699224e+00 -2.2070966e+00]
 [-1.1535871e+07 -1.7303806e+07 -2.3071742e+07]], shape=(2, 3), dtype=float32)
dldy: tf.Tensor(
[[-4.27996725e-01  4.42582788e-03]
 [ 4.35723990e-01 -1.12975426e+02]], shape=(2, 2), dtype=float32) tf.Tensor(
[[-4.27996725e-01  4.42582788e-03]
 [ 4.35723990e-01 -1.12975426e+02]], shape=(2, 2), dtype=float32)
dldwb: {'w': <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[ 2.5021553e-02, -5.7679365e+06],
       [-7.1657902e-01, -5.7679365e+06],
       [-1.4581797e+00, -5.7679365e+06]], dtype=float32)>, 'b': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 2.5021553e-02, -5.7679365e+06], dtype=float32)>} {'w': <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[ 2.5021732e-02, -5.7679360e+06],
       [-7.1657872e-01, -5.7679360e+06],
       [-1.4581791e+00, -5.7679360e+06]], dtype=float32)>, 'b': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 2.5021732e-02, -5.7679360e+06], dtype=float32)>}
dydx: tf.Tensor(
[[3.73288178e+00 6.46568489e+00 9.19848824e+00]
 [1.02111336e+05 1.53167891e+05 2.04224438e+05]], shape=(2, 3), dtype=float32) tf.Tensor(
[[3.7328813e+00 6.4656844e+00 9.1984873e+00]
 [1.0211133e+05 1.5316788e+05 2.0422442e+05]], shape=(2, 3), dtype=float32)
dydw: tf.Tensor(
[[3.4921465e+00 5.1055789e+04]
 [5.2248712e+00 5.1056789e+04]
 [6.9575958e+00 5.1057789e+04]], shape=(3, 2), dtype=float32) tf.Tensor(
[[3.4921463e+00 5.1055785e+04]
 [5.2248707e+00 5.1056785e+04]
 [6.9575958e+00 5.1057785e+04]], shape=(3, 2), dtype=float32)
dydb: tf.Tensor([3.4921465e+00 5.1055789e+04], shape=(2,), dtype=float32) tf.Tensor([3.4921463e+00 5.1055785e+04], shape=(2,), dtype=float32)
User contributions licensed under: CC BY-SA
3 People found this is helpful
Advertisement