I am not able to understand: if the prediction is calculated in the forward method, then why is “out = self(images)” needed, and what does it do? I am a bit confused about this code.
import torch
import torch.nn as nn
import torch.nn.functional as F

# input_size, num_classes and accuracy() are defined elsewhere (not shown here)

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        return loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))

model = MnistModel()
Answer
In Python, self refers to the instance that you have created from a class (similar to this in Java and C++). An instance is callable, which means it can be called like a function, if its class defines the __call__ method.
Example:
class A:
    def __init__(self):
        pass

    def __call__(self, x, y):
        return x + y

a = A()
print(a(3, 4))  # Prints 7
In your case, the __call__ method is implemented in the superclass nn.Module, so self(images) invokes it, and it in turn runs your forward method.
Since this is a neural network module, it needs a variable to hold its result. “out” is that variable: it holds the output of the module, which is then passed on to the next layer or module of your model (here, to the loss function).
For instances of nn.Module (and of classes that inherit from it), calling the instance runs the forward method, because nn.Module defines __call__ so that it invokes forward. So self(images) runs self.forward(images), as long as the class defines forward.
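To make the delegation concrete, here is a minimal pure-Python sketch of the idea. The Module base class below is a hypothetical, heavily simplified stand-in written for illustration, not the actual nn.Module implementation, and Doubler with its toy loss is likewise made up. It only shows how calling self(...) inside a method can end up running forward.

```python
class Module:
    """Hypothetical, simplified stand-in for torch.nn.Module (not the real class)."""

    def __call__(self, *args, **kwargs):
        # Calling the instance delegates to forward(). The real nn.Module
        # does extra bookkeeping around this step, which is omitted here.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must define forward()")


class Doubler(Module):
    """Toy 'model' (an assumption for illustration) that doubles each input."""

    def forward(self, xb):
        return [2 * x for x in xb]

    def training_step(self, batch):
        inputs, targets = batch
        out = self(inputs)  # goes through __call__, which runs forward()
        # Toy loss: sum of absolute errors between outputs and targets
        return sum(abs(o - t) for o, t in zip(out, targets))


model = Doubler()
print(model([1, 2, 3]))                             # [2, 4, 6]
print(model.training_step(([1, 2, 3], [2, 4, 7])))  # 1
```

In this sketch model(x) and model.forward(x) give the same result. With a real nn.Module, calling the instance is generally the recommended form, since __call__ may do additional work around forward.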