I’m looking to train a mapping like this:
Grayscale Image -> Coloured Image
But the dataset can’t all be loaded into RAM as X and Y arrays, for obvious reasons.
I looked into ImageDataGenerator(), but its documentation didn’t give me a clear answer on how to make it work here.
Summary:
Input Shape = (2048, 2048, 1)
Output Shape = (2048, 2048, 2)
Training Dataset = 17,000 images
Validation Dataset = 1,000 images
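A quick back-of-the-envelope check (assuming float32 arrays) shows why loading everything into memory is off the table:

```python
# Rough memory footprint of the full training set held in RAM as float32.
bytes_per_float = 4

x = 17_000 * 2048 * 2048 * 1 * bytes_per_float  # grayscale inputs
y = 17_000 * 2048 * 2048 * 2 * bytes_per_float  # two-channel targets

print(f"X: {x / 1e9:.0f} GB, Y: {y / 1e9:.0f} GB, total: {(x + y) / 1e9:.0f} GB")
```

That's roughly 285 GB for X and 570 GB for Y, so the data has to be streamed from disk in batches.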
Here’s the structure of the model I’m trying to train:
```
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                              Output Shape             Param #   Connected to
==================================================================================================
input_1 (InputLayer)                      [(None, 2048, 2048, 1)]  0
conv2d (Conv2D)                           (None, 2048, 2048, 16)   160       input_1[0][0]
leaky_re_lu (LeakyReLU)                   (None, 2048, 2048, 16)   0         conv2d[0][0]
conv2d_1 (Conv2D)                         (None, 2048, 2048, 32)   4640      leaky_re_lu[0][0]
leaky_re_lu_1 (LeakyReLU)                 (None, 2048, 2048, 32)   0         conv2d_1[0][0]
batch_normalization (BatchNormalization)  (None, 2048, 2048, 32)   128       leaky_re_lu_1[0][0]
max_pooling2d (MaxPooling2D)              (None, 1024, 1024, 32)   0         batch_normalization[0][0]
conv2d_2 (Conv2D)                         (None, 1024, 1024, 64)   18496     max_pooling2d[0][0]
leaky_re_lu_2 (LeakyReLU)                 (None, 1024, 1024, 64)   0         conv2d_2[0][0]
batch_normalization_1 (BatchNormalization)(None, 1024, 1024, 64)   256       leaky_re_lu_2[0][0]
max_pooling2d_1 (MaxPooling2D)            (None, 512, 512, 64)     0         batch_normalization_1[0][0]
conv2d_3 (Conv2D)                         (None, 512, 512, 128)    73856     max_pooling2d_1[0][0]
leaky_re_lu_3 (LeakyReLU)                 (None, 512, 512, 128)    0         conv2d_3[0][0]
batch_normalization_2 (BatchNormalization)(None, 512, 512, 128)    512       leaky_re_lu_3[0][0]
conv2d_4 (Conv2D)                         (None, 512, 512, 256)    295168    batch_normalization_2[0][0]
leaky_re_lu_4 (LeakyReLU)                 (None, 512, 512, 256)    0         conv2d_4[0][0]
batch_normalization_3 (BatchNormalization)(None, 512, 512, 256)    1024      leaky_re_lu_4[0][0]
up_sampling2d (UpSampling2D)              (None, 1024, 1024, 256)  0         batch_normalization_3[0][0]
conv2d_5 (Conv2D)                         (None, 1024, 1024, 128)  295040    up_sampling2d[0][0]
leaky_re_lu_5 (LeakyReLU)                 (None, 1024, 1024, 128)  0         conv2d_5[0][0]
batch_normalization_4 (BatchNormalization)(None, 1024, 1024, 128)  512       leaky_re_lu_5[0][0]
up_sampling2d_1 (UpSampling2D)            (None, 2048, 2048, 128)  0         batch_normalization_4[0][0]
conv2d_6 (Conv2D)                         (None, 2048, 2048, 64)   73792     up_sampling2d_1[0][0]
leaky_re_lu_6 (LeakyReLU)                 (None, 2048, 2048, 64)   0         conv2d_6[0][0]
concatenate (Concatenate)                 (None, 2048, 2048, 65)   0         leaky_re_lu_6[0][0]
                                                                             input_1[0][0]
conv2d_7 (Conv2D)                         (None, 2048, 2048, 64)   37504     concatenate[0][0]
leaky_re_lu_7 (LeakyReLU)                 (None, 2048, 2048, 64)   0         conv2d_7[0][0]
batch_normalization_5 (BatchNormalization)(None, 2048, 2048, 64)   256       leaky_re_lu_7[0][0]
conv2d_8 (Conv2D)                         (None, 2048, 2048, 32)   18464     batch_normalization_5[0][0]
leaky_re_lu_8 (LeakyReLU)                 (None, 2048, 2048, 32)   0         conv2d_8[0][0]
conv2d_9 (Conv2D)                         (None, 2048, 2048, 2)    578       leaky_re_lu_8[0][0]
==================================================================================================
Total params: 820,386
Trainable params: 819,042
Non-trainable params: 1,344
__________________________________________________________________________________________________
```
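For reference, the summary above is consistent with a stack of 3×3 same-padded convolutions whose channel counts can be recovered from the parameter counts. The following is a hedged sketch, not the asker's actual code — kernel sizes, padding, and the default LeakyReLU slope are assumptions — but it reproduces every layer shape and the 820,386-parameter total:

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_colorizer():
    # Encoder-decoder with a skip connection from the grayscale input,
    # reconstructed to match the layer shapes and param counts in the summary.
    inp = tf.keras.Input(shape=(2048, 2048, 1))

    x = layers.Conv2D(16, 3, padding='same')(inp)
    x = layers.LeakyReLU()(x)
    x = layers.Conv2D(32, 3, padding='same')(x)
    x = layers.LeakyReLU()(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D()(x)

    x = layers.Conv2D(64, 3, padding='same')(x)
    x = layers.LeakyReLU()(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D()(x)

    x = layers.Conv2D(128, 3, padding='same')(x)
    x = layers.LeakyReLU()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256, 3, padding='same')(x)
    x = layers.LeakyReLU()(x)
    x = layers.BatchNormalization()(x)

    x = layers.UpSampling2D()(x)
    x = layers.Conv2D(128, 3, padding='same')(x)
    x = layers.LeakyReLU()(x)
    x = layers.BatchNormalization()(x)
    x = layers.UpSampling2D()(x)
    x = layers.Conv2D(64, 3, padding='same')(x)
    x = layers.LeakyReLU()(x)

    x = layers.Concatenate()([x, inp])  # skip connection back to the input
    x = layers.Conv2D(64, 3, padding='same')(x)
    x = layers.LeakyReLU()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(32, 3, padding='same')(x)
    x = layers.LeakyReLU()(x)
    out = layers.Conv2D(2, 3, padding='same')(x)

    return tf.keras.Model(inp, out)


model = build_colorizer()
print(model.count_params())  # 820386
```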
Answer
That would be easiest with a custom training loop.
```python
def reconstruct(colored_inputs):
    with tf.GradientTape() as tape:
        grayscale_inputs = tf.image.rgb_to_grayscale(colored_inputs)
        out = autoencoder(grayscale_inputs)
        loss = loss_object(colored_inputs, out)
    gradients = tape.gradient(loss, autoencoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
    reconstruction_loss(loss)
```
Here, my data iterator cycles through all the color pictures, but each batch is converted to grayscale before being passed to the model. Then, the RGB output of the model is compared to the original RGB image. You will have to use the argument `class_mode=None` in `flow_from_directory` so that it yields raw images rather than (image, label) pairs. I used `tf.image.rgb_to_grayscale` to make the conversion from RGB to grayscale.
Full example:
```python
import os

import tensorflow as tf

physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

os.chdir(r'catsanddogs')

# Rescale to [0, 1] so the pixel values match the BinaryCrossentropy loss.
generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
iterator = generator.flow_from_directory(
    target_size=(32, 32),
    directory='.',
    batch_size=4,
    class_mode=None)  # yields raw image batches, no labels

encoder = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 1)),
    tf.keras.layers.Dense(32),
    tf.keras.layers.Dense(16)
])

decoder = tf.keras.Sequential([
    tf.keras.layers.Dense(32, input_shape=[16]),
    # Sigmoid keeps the reconstruction in [0, 1], matching the rescaled inputs.
    tf.keras.layers.Dense(32 * 32 * 3, activation='sigmoid'),
    tf.keras.layers.Reshape([32, 32, 3])
])

autoencoder = tf.keras.Sequential([encoder, decoder])

loss_object = tf.losses.BinaryCrossentropy()
reconstruction_loss = tf.metrics.Mean(name='reconstruction_loss')
optimizer = tf.optimizers.Adam()


def reconstruct(colored_inputs):
    with tf.GradientTape() as tape:
        grayscale_inputs = tf.image.rgb_to_grayscale(colored_inputs)
        out = autoencoder(grayscale_inputs)
        loss = loss_object(colored_inputs, out)
    gradients = tape.gradient(loss, autoencoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
    reconstruction_loss(loss)


if __name__ == '__main__':
    template = 'Epoch {:2} Reconstruction Loss {:.4f}'
    for epoch in range(50):
        reconstruction_loss.reset_states()
        # flow_from_directory loops forever, so cap each epoch at one full pass.
        for _ in range(len(iterator)):
            reconstruct(next(iterator))
        print(template.format(epoch + 1, reconstruction_loss.result()))
```
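For the original 2048×2048 images, a `tf.data` pipeline that reads files lazily is worth considering as an alternative to `ImageDataGenerator`: each color image is decoded on the fly and its grayscale input derived from it, so only the current batch lives in memory. This is a sketch under assumptions (PNG files, a made-up `colored/` directory pattern); the demo at the bottom just writes one synthetic PNG so the pipeline has something to read:

```python
import tensorflow as tf


def load_pair(path):
    # Decode one color image from disk and derive the grayscale input from it.
    rgb = tf.io.decode_png(tf.io.read_file(path), channels=3)
    rgb = tf.image.convert_image_dtype(rgb, tf.float32)  # scale to [0, 1]
    return tf.image.rgb_to_grayscale(rgb), rgb


def make_dataset(pattern, batch_size):
    return (tf.data.Dataset.list_files(pattern)
            .map(load_pair, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE))


# Demo with one synthetic image; point the pattern at your real folder,
# e.g. make_dataset('colored/*.png', batch_size=4).
tf.io.write_file('demo.png', tf.io.encode_png(tf.zeros([8, 8, 3], tf.uint8)))
gray, rgb = next(iter(make_dataset('demo*.png', batch_size=1)))
print(gray.shape, rgb.shape)  # (1, 8, 8, 1) (1, 8, 8, 3)
```

Each `(gray, rgb)` batch can then be fed straight into a training step like `reconstruct` above, with the grayscale tensor as the input and the color tensor as the target.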