
ImageDataGenerator() for CNN with input and output as an Image

I’m looking to train a mapping along these lines:

Grayscale Image -> Coloured Image

But the dataset can’t all be loaded into RAM as X and Y, for obvious reasons.

I looked into the ImageDataGenerator() class, but its documentation didn’t make it clear how to get it working here.

Summary:

Input Shape = (2048, 2048, 1)

Output Shape = (2048, 2048, 2)

Training Dataset = 17,000 images

Validation Dataset = 1,000 images

Here’s the structure of the model I’m trying to train:

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape               Param #     Connected to                 
==================================================================================================
input_1 (InputLayer)            [(None, 2048, 2048, 1)]    0                                        
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 2048, 2048, 16)     160         input_1[0][0]                
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 2048, 2048, 16)     0           conv2d[0][0]                 
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 2048, 2048, 32)     4640        leaky_re_lu[0][0]            
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 2048, 2048, 32)     0           conv2d_1[0][0]               
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 2048, 2048, 32)     128         leaky_re_lu_1[0][0]          
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 1024, 1024, 32)     0           batch_normalization[0][0]    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 1024, 1024, 64)     18496       max_pooling2d[0][0]          
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 1024, 1024, 64)     0           conv2d_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 1024, 1024, 64)     256         leaky_re_lu_2[0][0]          
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 512, 512, 64)       0           batch_normalization_1[0][0]  
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 512, 512, 128)      73856       max_pooling2d_1[0][0]        
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 512, 512, 128)      0           conv2d_3[0][0]               
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 512, 512, 128)      512         leaky_re_lu_3[0][0]          
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 512, 512, 256)      295168      batch_normalization_2[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 512, 512, 256)      0           conv2d_4[0][0]               
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 512, 512, 256)      1024        leaky_re_lu_4[0][0]          
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 1024, 1024, 256)    0           batch_normalization_3[0][0]  
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 1024, 1024, 128)    295040      up_sampling2d[0][0]          
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 1024, 1024, 128)    0           conv2d_5[0][0]               
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 1024, 1024, 128)    512         leaky_re_lu_5[0][0]          
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 2048, 2048, 128)    0           batch_normalization_4[0][0]  
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 2048, 2048, 64)     73792       up_sampling2d_1[0][0]        
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 2048, 2048, 64)     0           conv2d_6[0][0]               
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 2048, 2048, 65)     0           leaky_re_lu_6[0][0]          
                                                                       input_1[0][0]                
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 2048, 2048, 64)     37504       concatenate[0][0]            
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 2048, 2048, 64)     0           conv2d_7[0][0]               
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 2048, 2048, 64)     256         leaky_re_lu_7[0][0]          
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 2048, 2048, 32)     18464       batch_normalization_5[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 2048, 2048, 32)     0           conv2d_8[0][0]               
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 2048, 2048, 2)      578         leaky_re_lu_8[0][0]          
==================================================================================================
Total params: 820,386
Trainable params: 819,042
Non-trainable params: 1,344
__________________________________________________________________________________________________


Answer

That would be easiest with a custom training loop.

def reconstruct(colored_inputs):
    with tf.GradientTape() as tape:
        grayscale_inputs = tf.image.rgb_to_grayscale(colored_inputs)

        out = autoencoder(grayscale_inputs)
        loss = loss_object(colored_inputs, out)

    gradients = tape.gradient(loss, autoencoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))

    reconstruction_loss(loss)

Here, my data iterator cycles through all the color pictures, but each batch is converted to grayscale before being passed to the model. Then the RGB output of the model is compared against the original RGB image. You will have to pass class_mode=None to flow_from_directory so the iterator yields images without labels, and tf.image.rgb_to_grayscale handles the conversion from RGB to grayscale.
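As a quick illustration of that conversion step, tf.image.rgb_to_grayscale simply collapses the channel dimension from 3 to 1 (the batch here is random dummy data):

```python
import tensorflow as tf

# A dummy batch of four 32x32 RGB images with values in [0, 1]
batch = tf.random.uniform([4, 32, 32, 3])

# Collapses RGB to a single luminance channel (ITU-R 601 weighting)
gray = tf.image.rgb_to_grayscale(batch)
print(gray.shape)  # (4, 32, 32, 1)
```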

Full example:

import tensorflow as tf

# Guard against machines without a GPU before enabling memory growth
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
import os

os.chdir(r'catsanddogs')

# rescale maps pixel values from [0, 255] into [0, 1], which the
# binary cross-entropy loss below expects
generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
iterator = generator.flow_from_directory(
    target_size=(32, 32),
    directory='.',
    batch_size=4,
    class_mode=None)

encoder = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 1)),
    tf.keras.layers.Dense(32),
    tf.keras.layers.Dense(16)
])

decoder = tf.keras.Sequential([
    tf.keras.layers.Dense(32, input_shape=[16]),
    # sigmoid keeps the reconstruction in [0, 1], the range
    # binary cross-entropy expects
    tf.keras.layers.Dense(32 * 32 * 3, activation='sigmoid'),
    tf.keras.layers.Reshape([32, 32, 3])
])


autoencoder = tf.keras.Sequential([encoder, decoder])

loss_object = tf.losses.BinaryCrossentropy()

reconstruction_loss = tf.metrics.Mean(name='reconstruction_loss')

optimizer = tf.optimizers.Adam()


def reconstruct(colored_inputs):
    with tf.GradientTape() as tape:
        grayscale_inputs = tf.image.rgb_to_grayscale(colored_inputs)

        out = autoencoder(grayscale_inputs)
        loss = loss_object(colored_inputs, out)

    gradients = tape.gradient(loss, autoencoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))

    reconstruction_loss(loss)


if __name__ == '__main__':
    template = 'Epoch {:2} Reconstruction Loss {:.4f}'
    for epoch in range(50):
        reconstruction_loss.reset_states()
        # flow_from_directory iterators cycle forever, so cap each
        # epoch at one full pass over the dataset
        for _ in range(len(iterator)):
            reconstruct(next(iterator))
        print(template.format(epoch + 1, reconstruction_loss.result()))
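The original question wants a 1-channel input mapped to a 2-channel target at 2048×2048, which looks like a luma → chroma prediction. Assuming that reading, one way to adapt the loop above is to split each RGB batch with tf.image.rgb_to_yuv: the Y channel becomes the model input and the U/V channels become the training target. The model, optimizer, and loss names below are placeholders, and mean squared error is used because chroma values are not confined to [0, 1]:

```python
import tensorflow as tf

def split_luma_chroma(rgb_batch):
    """Split an RGB batch (values in [0, 1]) into a 1-channel luma
    input and a 2-channel chroma target."""
    yuv = tf.image.rgb_to_yuv(rgb_batch)
    luma = yuv[..., :1]    # shape (batch, H, W, 1) - model input
    chroma = yuv[..., 1:]  # shape (batch, H, W, 2) - training target
    return luma, chroma

def reconstruct_chroma(model, optimizer, loss_object, rgb_batch):
    # Same structure as reconstruct() above, except the model now
    # maps 1 channel -> 2 channels instead of reproducing RGB
    luma, chroma = split_luma_chroma(rgb_batch)
    with tf.GradientTape() as tape:
        predicted = model(luma)
        loss = loss_object(chroma, predicted)  # e.g. tf.losses.MeanSquaredError()
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```

The same flow_from_directory iterator with class_mode=None still drives this loop; only the channel split and the loss target change.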