
Tensorflow/keras: “logits and labels must have the same first dimension” How to squeeze logits or expand labels?

I’m trying to build a simple CNN classifier. My training images (BATCH_SIZE x 227 x 227 x 1) and labels (BATCH_SIZE x 7) are NumPy ndarrays fed to the model in batches via ImageDataGenerator. The loss function I’m using is sparse_categorical_crossentropy. The problem arises when the model tries to train: with a batch size of 1 (for my simplified experiments), the model outputs shape [1, 7] while the labels have shape [7].
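Here is a minimal sketch of what I mean, with hypothetical layers and random data standing in for my actual model; it should reproduce the same mismatch:

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 1
x = np.random.rand(BATCH_SIZE, 227, 227, 1).astype("float32")
y = np.eye(7)[np.random.randint(0, 7, BATCH_SIZE)]  # one-hot labels, shape (BATCH_SIZE, 7)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(8, 3, activation="relu", input_shape=(227, 227, 1)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(7, activation="softmax"),  # output shape (BATCH_SIZE, 7)
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Raises "logits and labels must have the same first dimension":
# the loss flattens the one-hot labels to shape (BATCH_SIZE * 7,) while the
# model output stays (BATCH_SIZE, 7).
model.fit(x, y, batch_size=BATCH_SIZE, epochs=1)
```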

I’m almost positive I know the cause, but I am unsure how to fix it. My hypothesis is that sparse_categorical_crossentropy is squeezing the dimensions of my labels (e.g. when BATCH_SIZE is 2, the ground-truth label shape goes from [2, 7] to [14]), so I cannot fix the label shape, and all my attempts to fix the logits shape have been fruitless.

I originally tried fixing the label shape with np.expand_dims, but the loss function always flattens the labels, no matter how I expand the dimensions.

Next, I tried adding a tf.keras.layers.Flatten() at the end of my model to get rid of the extraneous first dimension, but it had no effect; I still got the exact same error. I then tried tf.keras.layers.Reshape((-1,)) to squeeze all the dimensions, which resulted in a different error:

in sparse_categorical_crossentropy
    logits = array_ops.reshape(output, [-1, int(output_shape[-1])])
TypeError: __int__ returned non-int (type NoneType)

Question: How can I squash the logits to the same shape as the labels produced by sparse_categorical_crossentropy?


— full error trace —



Answer

No, you got the cause all wrong. You are passing one-hot encoded labels, but sparse_categorical_crossentropy expects integer labels, since it does the one-hot encoding itself (hence "sparse").
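For example, this is roughly what each loss expects for 7 classes and a batch of 2 (illustrative arrays, not your data):

```python
import numpy as np

# sparse_categorical_crossentropy: integer class indices, shape (batch_size,)
y_sparse = np.array([3, 5])

# categorical_crossentropy: one-hot rows, shape (batch_size, num_classes)
y_onehot = np.array([[0, 0, 0, 1, 0, 0, 0],
                     [0, 0, 0, 0, 0, 1, 0]])
```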

An easy solution would be to change the loss to categorical_crossentropy, not the sparse version. Also note that y_true with shape (7,) is incorrect; it should be (1, 7).
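In code, either of these should work (a sketch only; `model` and `y_onehot` are the hypothetical objects from the snippets above, not your actual code):

```python
import numpy as np

# Option 1: keep the one-hot labels and switch the loss.
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Option 2: keep sparse_categorical_crossentropy and convert the one-hot
# labels to integer class indices first.
y_int = np.argmax(y_onehot, axis=-1)   # shape (batch_size,)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
```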
