
Rotate images for data augmentation in tf.keras only by specific angles

In tf.keras, a data augmentation layer can rotate each image during training. According to the docs, its signature is:

tf.keras.layers.RandomRotation(
    factor, fill_mode='reflect', interpolation='bilinear',
    seed=None, fill_value=0.0, **kwargs
)

The factor argument is a fraction of a full turn (2π). A single float f allows any angle in the continuous range [-f·2π, f·2π]; a tuple gives the lower and upper bounds explicitly.
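Because the angle is drawn from a continuous range, RandomRotation on its own cannot be limited to right angles. A minimal plain-Python sketch of how a factor maps to degrees, assuming the documented convention (a single number f means [-f, f], a tuple is (lower, upper)); `factor_to_degree_range` is a hypothetical helper, not part of Keras:

```python
def factor_to_degree_range(factor):
    """Hypothetical helper: convert a RandomRotation-style factor to a degree range.

    Assumes the documented convention: the factor is a fraction of a full turn
    (2*pi radians = 360 degrees); a single number f means [-f, f], while a
    (lower, upper) tuple is used as-is.
    """
    if isinstance(factor, (int, float)):
        lower, upper = -factor, factor
    else:
        lower, upper = factor
    return lower * 360.0, upper * 360.0


# factor=0.25 permits any angle between -90 and +90 degrees, not only the endpoints.
print(factor_to_degree_range(0.25))
print(factor_to_degree_range((-0.5, 0.5)))
```

So even a factor whose bounds land exactly on ±90° still samples every angle in between.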

For my specific application, only certain rotations are allowed, say 0°, 90°, 180° and 270°.

Is there a way to achieve this with the RandomRotation class, is there a good alternative to it, or should I just augment the whole dataset before training?
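For the "augment the whole dataset before training" option, one approach is to store every image in all four orientations. A minimal NumPy sketch, assuming the images fit in memory as an (N, H, W, C) array of square images; `all_right_angle_rotations` is a hypothetical helper:

```python
import numpy as np


def all_right_angle_rotations(images, labels):
    """Hypothetical helper: return each image rotated by 0, 90, 180 and 270 degrees.

    Assumes `images` has shape (N, H, W, C) with H == W, so every rotated copy
    has the same shape and the copies can be stacked into one array.
    """
    # np.rot90 with axes=(1, 2) rotates the height/width plane of every image.
    rotated = [np.rot90(images, k=k, axes=(1, 2)) for k in range(4)]
    # Concatenate along the batch axis: N images become 4 * N images.
    all_images = np.concatenate(rotated, axis=0)
    # Repeat the labels in the same order as the rotated copies.
    all_labels = np.concatenate([labels] * 4, axis=0)
    return all_images, all_labels


images = np.random.rand(10, 32, 32, 3).astype(np.float32)
labels = np.arange(10)
aug_images, aug_labels = all_right_angle_rotations(images, labels)
print(aug_images.shape, aug_labels.shape)
```

This quadruples storage but makes every orientation appear exactly once per epoch; on-the-fly augmentation keeps the dataset size unchanged and samples orientations randomly instead.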


Answer

You can do this by creating a custom PreprocessingLayer that rotates its input by a random multiple of 90° with tf.image.rot90.

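import tensorflow as tf

class Rotate90Randomly(tf.keras.layers.experimental.preprocessing.PreprocessingLayer):
    """Rotates inputs by a random multiple of 90 degrees, during training only."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, training=False):
        def random_rotate():
            # Draw k uniformly from {0, 1, 2, 3}; rot90 rotates counter-clockwise by k * 90 degrees.
            # One k is drawn per call, so all images in a batch share the same rotation.
            rotation_factor = tf.random.uniform([], minval=0,
                                                maxval=4, dtype=tf.int32)
            return tf.image.rot90(x, k=rotation_factor)

        # `training` may be None, a Python bool or a boolean tensor; normalise it.
        if training is None:
            training = False
        training = tf.convert_to_tensor(training, dtype=tf.bool)

        # Rotate only in training mode; otherwise the input passes through unchanged.
        return tf.cond(training, random_rotate, lambda: x)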
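Since the layer draws a single k per call, every image in a batch gets the same rotation. If you want independent rotations per image, one alternative is to rotate inside a tf.data pipeline before batching, where the map function runs once per element. A minimal sketch on toy data; `random_rot90` and the toy arrays are hypothetical, for illustration only:

```python
import numpy as np
import tensorflow as tf


def random_rot90(image, label):
    """Hypothetical map function: rotate one image by a random multiple of 90 degrees."""
    # Runs once per dataset element, so each image gets its own k in {0, 1, 2, 3}.
    k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(image, k=k), label


# Toy data for illustration: 8 square 4x4 RGB images and integer labels.
images = np.arange(8 * 4 * 4 * 3, dtype=np.float32).reshape(8, 4, 4, 3)
labels = np.arange(8)

augmented_ds = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .map(random_rot90, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(4)
)

for batch_images, batch_labels in augmented_ds:
    print(batch_images.shape, batch_labels.numpy())
```

As written this assumes square images: for non-square inputs, 90° and 270° rotations change the image shape, and batch() needs all elements in a batch to have the same shape.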

One thing to consider: if the inputs are not square (height and width differ), a 90° or 270° rotation swaps height and width, so you need to define input_shape as (None, None, channels) when creating the model.
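A quick NumPy check of the shape change for a non-square image (the 180x240 size is arbitrary):

```python
import numpy as np

# A single non-square "image": height 180, width 240, 3 channels.
image = np.zeros((180, 240, 3), dtype=np.uint8)

for k in range(4):
    # np.rot90 with axes=(0, 1) rotates the height/width plane by k * 90 degrees.
    rotated = np.rot90(image, k=k, axes=(0, 1))
    print(k * 90, rotated.shape)
```

The 90° and 270° cases come out as (240, 180, 3), which is why a fixed (180, 240, channels) input shape cannot describe every output of the layer.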

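Example:

model = tf.keras.Sequential([
    tf.keras.Input((180, 180, 3)),
    Rotate90Randomly(),
])

import matplotlib.pyplot as plt

# train_ds and class_names are assumed to be defined earlier, e.g. a dataset
# from tf.keras.utils.image_dataset_from_directory and its class names.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    # Keep the originals in `images`; reassigning them would stack rotations across iterations.
    augmented = model(images, training=True)
    plt.imshow(augmented[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")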


With training=False the images are returned unchanged, so the layer has no effect during inference.

