I am new to AI/ML stuff. I'm learning TensorFlow. In some tutorial, I noticed that the input_shape argument of a Conv2D layer was specified only for the first layer. The code looked kinda like this:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(300,300,3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
In many examples, not just the one above, the instructor didn't include that argument for the later layers. Is there any reason for that?
Answer
The next layers derive the required shape from the output of the previous layer. That is, the MaxPooling2D layer derives its input shape from the output of the Conv2D layer, and so on. Note that in your sequential model, you don't even need to define an input_shape in the first layer. The model is able to derive the input_shape if you feed it real data, which gives you a bit more flexibility since you don't have to hard-code the input shape:
import tensorflow as tf

tf.random.set_seed(1)

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

# The first call on real data fixes the input shape for the whole model
print(model(tf.random.normal((1, 300, 300, 3))))
tf.Tensor([[0.6059081]], shape=(1, 1), dtype=float32)
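To make the shape derivation concrete, here is a rough sketch in plain Python (no TensorFlow needed) that traces the spatial size through the layers of the model above. It assumes the Keras defaults, i.e. padding='valid' and stride 1 for Conv2D, and a stride equal to the pool size for MaxPooling2D. The helper names conv_out and pool_out are made up for illustration and are not part of any library.

```python
def conv_out(size, kernel=3, stride=1):
    # 'valid' convolution: the kernel must fit entirely inside the input
    return (size - kernel) // stride + 1


def pool_out(size, pool=2, stride=2):
    # 'valid' max pooling: leftover rows/columns at the edge are dropped
    return (size - pool) // stride + 1


size = 300          # height and width of the input image
sizes = []
for _ in range(3):  # three Conv2D + MaxPooling2D pairs
    size = conv_out(size)
    sizes.append(size)
    size = pool_out(size)
    sizes.append(size)

# Flatten turns (height, width, channels) into one vector;
# the last Conv2D has 64 filters, so there are 64 channels
flat = size * size * 64

print(sizes)  # [298, 149, 147, 73, 71, 35]
print(flat)   # 78400
```

So each layer only needs to know what the previous one produced: 300 becomes 298 after the first convolution, 149 after pooling, and so on, until Flatten hands a vector of 78400 values to the first Dense layer.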
If data with an incorrect shape, for example (300, 3) instead of (300, 300, 3), is passed to your model, an error occurs because a Conv2D layer requires a 3D input excluding the batch dimension.
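As a small illustration of that error, the sketch below passes a wrongly shaped batch to a single Conv2D layer. It assumes the rank check surfaces as a ValueError, which is what I would expect from the TensorFlow versions I know; the exact message differs between versions, so it is only printed here.

```python
import tensorflow as tf

conv = tf.keras.layers.Conv2D(16, (3, 3), activation='relu')

# Batch of one sample, but each sample is only (300, 3):
# rank 3 in total, while Conv2D needs rank 4 (batch, height, width, channels)
bad_batch = tf.random.normal((1, 300, 3))

try:
    conv(bad_batch)
    raised = False
except ValueError as err:  # assumption: the rank check raises ValueError
    raised = True
    print("Conv2D rejected the input:", err)

# With a correctly shaped batch the same kind of layer works fine
good = tf.keras.layers.Conv2D(16, (3, 3), activation='relu')(
    tf.random.normal((1, 300, 300, 3))
)
print(good.shape)
```

The second call should give an output of shape (1, 298, 298, 16), since a 3x3 kernel without padding trims one pixel from each border.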
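If your model does not have an input_shape, you will, however, not be able to call model.summary() to view your network's output shapes and parameter counts (depending on the version, the call either fails or shows them as unknown). First you would have to build your model with an input shape: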
model.build(input_shape=(1, 300, 300, 3))
model.summary()
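If you don't want to fix the batch size, a variation on this (a sketch, not the only way to do it) is to pass None for the batch dimension when building. Once the model is built, the weights exist, so the output shape and parameter count can be checked directly:

```python
import tensorflow as tf

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

# None leaves the batch size open; only height, width and channels are fixed
model.build(input_shape=(None, 300, 300, 3))

print(model.output_shape)

# First Conv2D: 3*3 kernel * 3 input channels * 16 filters + 16 biases = 448
print(model.layers[0].count_params())

# Total over all layers: 448 + 4640 + 18496 + 40141312 + 513 = 40165409
print(model.count_params())
```

Almost all of the parameters sit in the first Dense layer (78400 inputs times 512 units, plus 512 biases), which is why the flattened size matters so much.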