My data is in a CSV file that contains the image path and the targets (x and y), where x and y lie in [-1, 1] after scaling. I am using Keras, and because there are so many images I cannot load them all into X_train the usual way. Thank you so much for your help! (Data in the CSV file:)
Advertisement
Answer
I'm writing this as an answer because the comment section is getting longer and longer.
You can try training the model with a custom training loop using `tf.GradientTape`. For details, please check here. With that, you will have more control over the batches.
Python
import math

# Custom GradientTape training loop over an (infinitely looping) Keras
# data generator. `training_generator`, `BS`, `model`, `optimizer`,
# `rescale_img`, `tf`, `np`, `time`, `BinaryCrossentropy` and
# `BinaryAccuracy` must already be defined/imported.

EPOCHS_start = 1
EPOCHS_end = 32

# Batches per epoch. Round *up*: a Keras iterator yields ceil(n / batch_size)
# batches per pass (the last one may be partial). Rounding down would make
# our epochs drift out of sync with the generator's own passes.
nb_train_steps = math.ceil(training_generator.n / BS)

# Whatever loss you want to use.
# NOTE: BinaryCrossentropy / BinaryAccuracy expect labels in [0, 1]. For
# continuous targets such as coordinates scaled to [-1, 1], a regression
# loss (e.g. MeanSquaredError) and metric (e.g. MeanAbsoluteError) fit better.
loss_fn = BinaryCrossentropy(from_logits=False)
train_acc_metric1 = BinaryAccuracy()

for epoch in range(EPOCHS_start, EPOCHS_end):
    print("Start of epoch %d" % (epoch,))
    start_time = time.time()
    loss_total = 0.0  # plain Python float; summed per batch below
    steps_done = 0

    # Iterate over the batches of the dataset.
    for x_batch_train, y_batch_train in training_generator:
        # Label conversion does not need to be recorded by the tape.
        y_batch_train = np.asarray(y_batch_train, dtype="float32")
        # Keep the batch axis: (batch,) -> (batch, 1) and
        # (batch, k) -> (batch, k), so multi-target labels are not flattened.
        y_batch_train = y_batch_train.reshape((y_batch_train.shape[0], -1))

        with tf.GradientTape() as tape:
            x_batch_train_scaled = rescale_img(x_batch_train)
            # you have to create your model before
            # from_logits=False above, so these are model outputs, not logits.
            predictions = model(x_batch_train_scaled, training=True)
            loss_value = loss_fn(y_batch_train, predictions)

        print("loss value : %.4f " % float(loss_value))

        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        loss_total += float(loss_value)
        steps_done += 1

        # Update training metric.
        train_acc_metric1.update_state(y_batch_train, predictions)

        # We need to break the loop by hand because the generator loops
        # indefinitely; stop after exactly nb_train_steps batches.
        if steps_done >= nb_train_steps:
            break

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric1.result()
    print(
        "Training acc over epoch: ",
        "%.4f - LOSS: %.4f ** Time taken: %.2fs"
        % (
            float(train_acc),
            # Average over batches actually run; guard against zero batches.
            loss_total / max(steps_done, 1),
            time.time() - start_time,
        ),
    )

    # Newer Keras versions name this reset_state(); adjust to your version.
    train_acc_metric1.reset_states()
51
52