I would like to understand how to define user-defined activation functions with two learnable parameters for a neural network, using TensorFlow in Python.
Any reference or case study would be helpful.
Thank you
Answer
If you create a tf.Variable within your model, TensorFlow will track its state and adjust it during training like any other trainable parameter. Such a tf.Variable can be a parameter of your activation function.
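As a quick illustration (a minimal sketch, not part of the walkthrough below; the class name Demo and the attribute alpha are made up), any tf.Variable assigned as an attribute of a Model or Layer shows up in trainable_variables, which is how the optimizer finds and updates it:

import tensorflow as tf

class Demo(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # A scalar variable acting as a learnable activation parameter.
        self.alpha = tf.Variable(0.1)

m = Demo()
# The variable is tracked, so gradient updates will reach it.
print(m.trainable_variables)  # [<tf.Variable ... numpy=0.1>]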
Let’s start with a toy dataset.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
import matplotlib.pyplot as plt
from tensorflow.keras import Model
from sklearn.datasets import load_iris

# Load the iris dataset: 150 samples, 4 features, 3 classes.
iris = load_iris(return_X_y=True)
X = iris[0].astype(np.float32)
y = iris[1].astype(np.float32)

# Build a shuffled, batched input pipeline.
ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)
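If you want to sanity-check the pipeline (an optional step, not required for the rest of the example), each element is a (features, labels) batch:

# Each batch pairs a (None, 4) feature tensor with a (None,) label tensor.
print(ds.element_spec)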
Now, let’s create a tf.keras.Model and make a parametric ReLU whose slope is learnable, and whose minimum value (usually 0 for classical ReLU) is learnable as well. We’ll initialize both the slope and the minimum value to 0.1.
# Lists to record the parameter values at every forward pass.
slope_values = list()
min_values = list()

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        # Learnable activation parameters; TensorFlow tracks these
        # tf.Variables and updates them during training.
        self.prelu_slope = tf.Variable(0.1)
        self.min_value = tf.Variable(0.1)
        self.d0 = Dense(16, activation=self.prelu)
        self.d1 = Dense(32, activation=self.prelu)
        self.d2 = Dense(3, activation='softmax')

    def prelu(self, x):
        # Parametric ReLU with a learnable slope and minimum value.
        return tf.maximum(self.min_value, x * self.prelu_slope)

    def call(self, x, **kwargs):
        # Record the current parameter values (requires eager execution).
        slope_values.append(self.prelu_slope.numpy())
        min_values.append(self.min_value.numpy())
        x = self.d0(x)
        x = self.d1(x)
        x = self.d2(x)
        return x

model = MyModel()
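At this point you can verify (an optional check, not essential to the example) that the two activation parameters are already tracked; the Dense kernels and biases will join them once the model is built on its first call:

# Before the first forward pass, only the two activation parameters exist.
print(model.trainable_variables)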
Now, let’s train the model (in eager mode, so we can record the parameter values on every forward pass).
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam', run_eagerly=True)
history = model.fit(ds, epochs=500, verbose=0)
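After training, you can inspect the learned values directly (an optional check, using the attribute names from the class above):

# Final learned activation parameters and the last recorded training loss.
print('slope:', model.prelu_slope.numpy())
print('min value:', model.min_value.numpy())
print('loss:', history.history['loss'][-1])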
Let’s look at the slope. TensorFlow adjusts it toward the best value for this task; as you will see, it approaches 1, the slope of a non-parametric ReLU.
plt.plot(slope_values, label='Slope Value')
plt.plot(min_values, label='Minimum Value')
plt.legend()
plt.title('Parametric ReLU Parameters Across Time')
plt.show()
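As a closing note, a more reusable variant of the same idea is a custom Keras Layer that creates its parameters with add_weight. This is a hedged sketch (the class name TrainablePReLU and the weight names are my own, not from the answer above), but the mechanics are identical: Keras tracks the weights and the optimizer trains them.

import tensorflow as tf

class TrainablePReLU(tf.keras.layers.Layer):
    """Parametric ReLU with a learnable slope and minimum value."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Scalar weights, both initialized to 0.1 as in the example above.
        self.slope = self.add_weight(
            name='slope', shape=(),
            initializer=tf.keras.initializers.Constant(0.1), trainable=True)
        self.min_value = self.add_weight(
            name='min_value', shape=(),
            initializer=tf.keras.initializers.Constant(0.1), trainable=True)

    def call(self, x):
        # Same formula as the prelu method in the model above.
        return tf.maximum(self.min_value, x * self.slope)

# Usage: apply the activation as its own layer after a Dense layer.
layer = TrainablePReLU()
print(layer(tf.constant([-1.0, 0.0, 2.0])))  # [0.1, 0.1, 0.2]

Because the parameters live inside the layer, you can drop several independent copies into one model, each learning its own slope and minimum.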