Can anyone explain what “out = self(images)” does in the code below?

I don't understand this: if the prediction is calculated in the forward method, why is “out = self(images)” needed, and what does it do? I am a bit confused by this code.

import torch
import torch.nn as nn
import torch.nn.functional as F

# input_size, num_classes and accuracy() are assumed to be defined elsewhere.

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        return loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))

model = MnistModel()

Answer

In Python, self refers to the instance that you have created from a class (similar to this in Java and C++). An instance is callable, meaning it can be called like a function, if its class defines the __call__ method.

Example:

class A:
    def __init__(self):
        pass
    def __call__(self, x, y):
        return x + y

a = A()
print(a(3,4)) # Prints 7

In your case, the __call__ method is implemented in the superclass nn.Module, so every MnistModel instance is callable. Writing self(images) therefore invokes __call__, which runs your forward method on the images. “out” is simply the variable that holds the module's output (here, the predictions), which is then passed on to the loss function or to the next layer or module of your model.

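For instances of nn.Module (and of classes that inherit from it), __call__ dispatches to the forward method, so forward is where the prediction is actually computed. __call__ also runs any hooks registered on the module, which is why calling self(images) is preferred over calling self.forward(images) directly.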
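
To make the dispatch concrete, here is a minimal pure-Python sketch. TinyModule and Doubler are hypothetical stand-ins written for illustration; they are not PyTorch's actual implementation, which does considerably more. The sketch only shows the general idea of __call__ forwarding to forward and then running hooks.

```python
class TinyModule:
    """Simplified stand-in for nn.Module (hypothetical, not PyTorch's real code)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Remember a function to run after every call to the module.
        self._forward_hooks.append(hook)

    def __call__(self, *args, **kwargs):
        # Calling the instance runs forward(), then any registered hooks.
        out = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook(self, args, out)
        return out

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x

    def training_step(self, x):
        out = self(x)  # same as self.__call__(x), which runs forward(x)
        return out


model = Doubler()
calls = []
model.register_forward_hook(lambda module, inputs, output: calls.append(output))

print(model.training_step(3))  # 6, computed by forward() via __call__
print(model.forward(3))        # 6 as well, but the hook is skipped
print(calls)                   # [6]: only the call through self(x) was recorded
```

In this sketch, self(x) and self.forward(x) return the same value, but only the former triggers the hook, which mirrors the reason self(images) is used in the question's code.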

User contributions licensed under: CC BY-SA