Multiclassification task using keras


Classification (not detection!) of several objects in one image is the problem. How can I do this using Keras?

For example, if I have 6 classes (dogs, cats, birds, …) and two different objects (a cat and a bird) in an image, the label would be of the form [0,1,1,0,0,0]. Which metric, loss function and optimizer are recommended? I would like to use a CNN.
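To make the label format concrete, here is a minimal sketch of turning the list of objects in an image into such a multi-hot vector with NumPy. Only dogs, cats and birds come from the question; the other three class names and the helper `multi_hot` are hypothetical placeholders.

```python
import numpy as np

# Hypothetical class list: only "dog", "cat" and "bird" come from the question.
CLASSES = ["dog", "cat", "bird", "horse", "fish", "rabbit"]

def multi_hot(objects, classes=CLASSES):
    """Hypothetical helper: 0/1 vector with a 1 for every class present in the image."""
    label = np.zeros(len(classes), dtype=np.float32)
    for name in objects:
        label[classes.index(name)] = 1.0
    return label

# An image containing a cat and a bird:
print(multi_hot(["cat", "bird"]))  # [0. 1. 1. 0. 0. 0.]
```

Stacking one such vector per image gives the `y_train` array of shape `(num_images, num_classes)` that the network is trained on.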


The keyword is "multi-label classification". In the output layer you have multiple neurons, each neuron representing one of your classes.
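One way to see why each neuron should get its own sigmoid, rather than a softmax over all neurons (the usual choice when exactly one class is present per image), is to apply both to the same raw outputs. This is a NumPy sketch; the logits are made-up numbers for illustration.

```python
import numpy as np

def sigmoid(z):
    # Applied element-wise: each output is squashed into (0, 1) on its own.
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    # Subtract the max for numerical stability; outputs are normalised to sum to 1.
    e = np.exp(z - np.max(z))
    return e / e.sum()

logits = np.array([-2.0, 1.5, 4.0])  # made-up raw outputs of the last layer

print(sigmoid(logits))  # independent probabilities, may sum to more than 1
print(softmax(logits))  # probabilities compete and always sum to 1
```

With sigmoid, two or more classes can be close to 1 at the same time; softmax would force them to share a total of 1, which suits single-label but not multi-label problems.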

Each output neuron is then treated as an independent binary classifier. So if you have 3 classes, the output of your network could be [0.1, 0.8, 0.99], which means the following: the first class is present in your image with a probability of 10 %, the second with 80 % and the third with 99 %. So the network can decide that two classes are present at the same time in a single input image!
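To turn such probabilities into a final set of labels, a common approach is to threshold each output separately. The sketch below assumes a threshold of 0.5; in practice you may want to tune it per class.

```python
import numpy as np

# Network output for one image, taken from the example in the text.
probs = np.array([0.1, 0.8, 0.99])

# Assumed decision rule: a class counts as present if its probability is at least 0.5.
threshold = 0.5
predicted = (probs >= threshold).astype(int)
print(predicted)  # [0 1 1]
```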

It's pretty easy to implement this in Keras/TensorFlow. You can use binary_crossentropy as your loss function and the sigmoid function as the activation of your last layer, so every output neuron gives a value in the interval (0, 1). For the optimizer, Adam is a common default choice. As the metric you could use binary accuracy, which tells you the fraction of individual labels (not whole images) that are predicted correctly at a threshold of 0.5. Passing 'binary_accuracy' explicitly avoids relying on how Keras resolves the generic 'accuracy' string.
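To show what the loss and metric measure, here is a hand-written NumPy sketch of binary cross-entropy averaged over all labels, plus two accuracies: per-label (binary accuracy) and per-image exact match. It illustrates the formulas; it is not meant to reproduce Keras internals exactly, and the labels and predictions are made-up numbers.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over all labels, each treated as its own yes/no task."""
    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of individual labels predicted correctly."""
    return np.mean((y_pred >= threshold) == y_true)

def exact_match_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of images whose whole label vector is predicted correctly."""
    return np.mean(np.all((y_pred >= threshold) == y_true, axis=1))

# Two images, three classes (made-up numbers).
y_true = np.array([[0, 1, 1],
                   [1, 0, 0]])
y_pred = np.array([[0.1, 0.8, 0.99],   # all three labels correct
                   [0.7, 0.6, 0.2]])   # second label wrong

print(binary_crossentropy(y_true, y_pred))
print(binary_accuracy(y_true, y_pred))       # 5 of 6 labels correct
print(exact_match_accuracy(y_true, y_pred))  # 1 of 2 images fully correct
```

Binary accuracy can look high even when few images have all their labels right, so it can be worth tracking the stricter exact-match figure as well.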

See the following example:

import numpy as np
from tensorflow.keras.layers import Conv2D, Dense, Flatten
from tensorflow.keras.models import Sequential

# put your data right here:
num_classes = 3  # here I'm assuming 3 classes, e.g. dog, cat and bird
x_train = np.zeros((100, 128, 128, 3))  # I'm just using zeros to simplify the example
y_train = np.zeros((100, num_classes))  # multi-hot labels, e.g. [0, 1, 1]

model = Sequential()
# put your conv layer / conv blocks here:
model.add(Conv2D(32, kernel_size=3, activation='relu', input_shape=(128, 128, 3)))
model.add(Flatten())  # turn the 4D feature maps into one vector per image for the Dense layer
# one sigmoid neuron per class, each giving an independent probability in (0, 1)
model.add(Dense(units=num_classes, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['binary_accuracy'])
training_history = model.fit(x=x_train, y=y_train, epochs=5)

I'm using TensorFlow 2.2.0. I hope this will help you :)

Source: stackoverflow