I created a model class which is a subclass of keras.Model. While training the model, I want to change the weights of the loss functions after some epochs. To do that, I added boolean variables to my model indicating that it should start training with an additional loss function. Below is pseudocode that shows what I am trying to achieve.
class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.start_loss_2 = False

    def train_step(self, data):
        # Check if training with loss_2 started
        weight_loss_2 = 0.0
        if self.start_loss_2:
            weight_loss_2 = 0.5

        # Pass the data through the model
        # Calculate the two loss values
        total_loss = loss_1 + weight_loss_2 * loss_2
        # Calculate gradients with tf.GradientTape
        # Update variables

    # This is called via a Callback after each epoch
    def epoch_finished(self, epoch_num):
        if epoch_num > START_LOSS_2:
            self.start_loss_2 = True

My question is:
- Is it valid to use an if-else statement whose condition changes during training? If it is not, how can I achieve this?
Answer
Yes. You can create a tf.Variable and then assign a new value to it based on some training criteria. A non-trainable tf.Variable is the right tool here because its value is read each time the graph runs, so the change takes effect even if train_step has been compiled with tf.function; a plain Python attribute would be baked into the traced graph.
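The mechanics are just the variable and an assign; as a sketch:

w = tf.Variable(0.0, trainable=False)   # extra loss switched off at first
# ... later, once your training criterion is met:
w.assign(0.5)                            # loss_2 now contributes
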
Example:
import numpy as np
import tensorflow as tf


# simple toy network
x_in = tf.keras.Input((10,))
x = tf.keras.layers.Dense(25)(x_in)
x_out = tf.keras.layers.Dense(1)(x)

# model
m = tf.keras.Model(x_in, x_out)

# fake data
X = tf.random.normal((100, 10))
y0 = tf.random.normal((100,))
y1 = tf.random.normal((100,))

# optimizer
m_opt = tf.keras.optimizers.Adam(1e-2)

# prep data
ds = tf.data.Dataset.from_tensor_slices((X, y0, y1))
ds = ds.repeat().batch(5)
train_iter = iter(ds)


# toy loss function that uses a weight
def loss_fn(y_true0, y_true1, y_pred, weight):
    mse = tf.keras.losses.MSE
    mse_0 = tf.math.reduce_mean(mse(y_true0, y_pred))
    mse_1 = tf.math.reduce_mean(mse(y_true1, y_pred))
    return mse_0 + weight * mse_1


NUM_EPOCHS = 4
NUM_BATCHES_PER_EPOCH = 10
START_NEW_LOSS_AT_GLOBAL_STEP = 20

# the weight variable is set to 0 initially and then
# changed after a certain number of steps
# (or based on some other training criteria)
w = tf.Variable(0.0, trainable=False)

for epoch in range(NUM_EPOCHS):
    losses = []
    for batch in range(NUM_BATCHES_PER_EPOCH):
        X_train, y0_train, y1_train = next(train_iter)
        with tf.GradientTape() as tape:
            y_hat = m(X_train)
            loss = loss_fn(y0_train, y1_train, y_hat, w)
            losses.append(loss)

        m_vars = m.trainable_variables
        m_grads = tape.gradient(loss, m_vars)
        m_opt.apply_gradients(zip(m_grads, m_vars))

    print(f"epoch: {epoch}\tloss: {np.mean(losses):.4f}")

    # if the criterion is met, assign a huge number to see if the
    # loss spikes up
    if (epoch + 1) * (batch + 1) >= START_NEW_LOSS_AT_GLOBAL_STEP:
        w.assign(10000.0)

# epoch: 0    loss: 1.8226
# epoch: 1    loss: 1.1143
# epoch: 2    loss: 8788.2227  <= looks like the assign worked
# epoch: 3    loss: 10999.5449
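
To connect this back to the subclassed keras.Model from the question: the same variable can live on the model and be flipped from a callback's on_epoch_end, instead of using a plain Python boolean. Below is a minimal sketch of that wiring; MyModel, LossWeightSwitch, and START_LOSS_2 are illustrative names, not part of the Keras API:

import tensorflow as tf

START_LOSS_2 = 2  # epoch after which loss_2 is switched on (illustrative)


class MyModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(1)
        # a non-trainable variable instead of a Python bool, so the
        # change is visible inside the compiled train_step
        self.weight_loss_2 = tf.Variable(0.0, trainable=False)

    def call(self, x):
        return self.dense(x)

    def train_step(self, data):
        x, (y0, y1) = data
        with tf.GradientTape() as tape:
            y_hat = tf.squeeze(self(x, training=True), axis=-1)
            loss_1 = tf.reduce_mean(tf.square(y0 - y_hat))
            loss_2 = tf.reduce_mean(tf.square(y1 - y_hat))
            total_loss = loss_1 + self.weight_loss_2 * loss_2
        grads = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": total_loss}


class LossWeightSwitch(tf.keras.callbacks.Callback):
    # turns on loss_2 once the epoch threshold is passed
    def on_epoch_end(self, epoch, logs=None):
        if epoch >= START_LOSS_2:
            self.model.weight_loss_2.assign(0.5)


model = MyModel()
model.compile(optimizer=tf.keras.optimizers.Adam(1e-2))

X = tf.random.normal((100, 10))
y0 = tf.random.normal((100,))
y1 = tf.random.normal((100,))

model.fit(X, (y0, y1), batch_size=5, epochs=4, callbacks=[LossWeightSwitch()])

Because weight_loss_2 is a tf.Variable, the assign made in on_epoch_end is seen by the next epoch's training steps even though Keras wraps train_step in a tf.function; a plain attribute would have been frozen at trace time.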