I would like to understand how to define a user-defined activation function with two learnable parameters for a neural network, using TensorFlow in Python.
Any reference or case study would be helpful.
Thank you
Answer
If you create a tf.Variable within your model, TensorFlow will track its state and adjust it like any other parameter. Such a tf.Variable can be a parameter of your activation function.
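As a minimal sketch of that idea outside of any model, the snippet below takes the gradient of a toy loss with respect to two tf.Variable objects used inside an activation function, then applies one manual gradient step. The input values, the squared-error loss, and the learning rate of 0.01 are assumptions chosen only for illustration.

```python
import tensorflow as tf

# Two learnable parameters of the activation function.
slope = tf.Variable(0.1)
min_value = tf.Variable(0.1)

def activation(x):
    # ReLU-like function: a linear part with a learnable slope,
    # clipped from below at a learnable minimum value.
    return tf.maximum(min_value, x * slope)

# Illustrative inputs and target (assumptions, not from the original answer).
x = tf.constant([-2.0, 0.5, 3.0])

with tf.GradientTape() as tape:
    y = activation(x)
    loss = tf.reduce_sum(tf.square(y - 1.0))

# Trainable tf.Variable objects are watched by the tape automatically,
# so gradients are available for both parameters.
grad_slope, grad_min = tape.gradient(loss, [slope, min_value])

# One plain gradient-descent step, written out by hand.
learning_rate = 0.01
slope.assign_sub(learning_rate * grad_slope)
min_value.assign_sub(learning_rate * grad_min)

print(slope.numpy(), min_value.numpy())
```

Inside a Keras model the optimizer performs this update for you; the sketch only shows that both parameters receive gradients.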
Let’s start with some toy dataset.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
import matplotlib.pyplot as plt
from tensorflow.keras import Model
from sklearn.datasets import load_iris

iris = load_iris(return_X_y=True)

X = iris[0].astype(np.float32)
y = iris[1].astype(np.float32)

ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)
Now, let’s create a tf.keras.Model and make a parametric ReLU function in which both the slope and the minimum value (0 for the classical ReLU) are learnable. Both parameters are initialized to 0.1 for now.
# Parameter values recorded on every forward pass, for plotting later.
slope_values = list()
min_values = list()


class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        # The two learnable parameters of the activation function.
        self.prelu_slope = tf.Variable(0.1)
        self.min_value = tf.Variable(0.1)

        self.d0 = Dense(16, activation=self.prelu)
        self.d1 = Dense(32, activation=self.prelu)
        self.d2 = Dense(3, activation='softmax')

    def prelu(self, x):
        return tf.maximum(self.min_value, x * self.prelu_slope)

    def call(self, x, **kwargs):
        # .numpy() requires eager execution (see run_eagerly=True below).
        slope_values.append(self.prelu_slope.numpy())
        min_values.append(self.min_value.numpy())
        x = self.d0(x)
        x = self.d1(x)
        x = self.d2(x)
        return x


model = MyModel()
Now, let’s train the model (in eager mode so we can keep the slope values).
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', run_eagerly=True)

history = model.fit(ds, epochs=500, verbose=0)
Let’s look at how the slope and the minimum value evolve during training. TensorFlow adjusts them to minimize the loss on this task. As you will see, the slope approaches 1, which is the slope of the standard (non-parametric) ReLU.
plt.plot(slope_values, label='Slope Value')
plt.plot(min_values, label='Minimum Value')
plt.legend()
plt.title('Parametric ReLU Parameters Across Time')
plt.show()
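If you want to reuse the activation across models, one possible variant (not part of the answer above) is to wrap the same function in a custom Keras layer and create the two parameters with add_weight. The class name LearnableReLU and its constructor arguments init_slope and init_min are hypothetical names introduced here for illustration; this is a sketch under those assumptions, not a definitive implementation.

```python
import tensorflow as tf


class LearnableReLU(tf.keras.layers.Layer):
    """Hypothetical layer computing max(min_value, slope * x)
    with two trainable scalar parameters."""

    def __init__(self, init_slope=0.1, init_min=0.1, **kwargs):
        super().__init__(**kwargs)
        self.init_slope = init_slope
        self.init_min = init_min

    def build(self, input_shape):
        # Scalar weights, shared across all units of the layer.
        self.slope = self.add_weight(
            name="slope",
            shape=(),
            initializer=tf.keras.initializers.Constant(self.init_slope),
            trainable=True,
        )
        self.min_value = self.add_weight(
            name="min_value",
            shape=(),
            initializer=tf.keras.initializers.Constant(self.init_min),
            trainable=True,
        )

    def call(self, inputs):
        return tf.maximum(self.min_value, inputs * self.slope)


# The Dense layer is left linear and the activation is applied as its own layer.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16),
    LearnableReLU(),
    tf.keras.layers.Dense(3, activation="softmax"),
])

# Calling the model once builds all layers, including the two scalars.
out = model(tf.zeros((2, 4)))
print(model.layers[1].trainable_weights)
```

Each LearnableReLU instance owns its own slope and minimum value, whereas the model above shares one pair across both hidden layers; which of the two you want depends on the task.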