
Why is the code not able to approximate the square function?

Why does the following code not work as an approximator of the square function? I am getting weird dimensions, and when I try to plot the loss, the graph does not show anything. I am a beginner with PyTorch, so I would be grateful for any help.

import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np

data = [[i] for i in range(-10000, 10000)]
y = [[i[0] * i[0]] for i in data]
data=torch.FloatTensor(data)
y=torch.FloatTensor(y)


class MyModel(nn.Module):
    def __init__(self, numfeatures, outfeatures):
        super().__init__()
        self.modele = nn.Sequential(
                                    nn.Linear( numfeatures,  2*numfeatures),
                                    nn.ReLU(),
                                    nn.Linear(2 * numfeatures, 4 * numfeatures),
                                    nn.ReLU(),
                                    nn.Linear(4* numfeatures, 2 * numfeatures),
                                    nn.ReLU(),
                                    nn.Linear(2*numfeatures, outfeatures),  # last layer should produce outfeatures values (the argument was previously unused)
                                    )
    def forward(self, x):
        return self.modele(x)


model = MyModel(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

n_epochs = 10000
epoch_loss= []

for i in range(n_epochs):
    y_pred = model(data)
    loss = criterion(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss.append(loss.item())

    

plt.plot(epoch_loss)
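A quick check of the setup above helps locate the problem. The sketch below rebuilds the same data with `torch.arange` (an equivalent of the list comprehensions, chosen for brevity) and the same layer sizes as `MyModel(1, 1)` written as a plain `nn.Sequential`, then inspects the shapes and the gradient from a single backward pass. The shapes are the expected `(20000, 1)`, so the dimensions are fine for `nn.Linear(1, ...)`. The real issue is scale: the targets reach 1e8, so the initial MSE is on the order of 1e15 and the gradient on the last bias is on the order of 1e7, which with `lr=0.01` means very large SGD steps. A run that diverges to `nan` likely explains the empty loss plot, since matplotlib skips NaN values; in a plain script, `plt.plot` also needs a `plt.show()` call before any window appears.

```python
import torch
from torch import nn

torch.manual_seed(0)  # fixed seed so the check is reproducible

# Same data as in the question: inputs -10000..9999 and their squares
data = torch.arange(-10000, 10000, dtype=torch.float32).unsqueeze(1)  # shape (20000, 1)
y = data ** 2                                                         # shape (20000, 1)

# Same layer sizes as MyModel(1, 1)
model = nn.Sequential(
    nn.Linear(1, 2), nn.ReLU(),
    nn.Linear(2, 4), nn.ReLU(),
    nn.Linear(4, 2), nn.ReLU(),
    nn.Linear(2, 1),
)

loss = nn.MSELoss()(model(data), y)
loss.backward()

# d(loss)/d(last bias) = 2 * mean(y_pred - y), dominated here by mean(y) (about 3.3e7)
bias_grad = model[-1].bias.grad.item()

print(data.shape, y.shape)  # torch.Size([20000, 1]) torch.Size([20000, 1])
print(f"max target:     {y.max().item():.3e}")
print(f"initial loss:   {loss.item():.3e}")
print(f"last-bias grad: {bias_grad:.3e}")
# One SGD step with lr=0.01 moves this bias by 0.01 * |bias_grad|
```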


Answer

Your data ranges from -10000 to 10000, so the targets reach 10^8! You need to rescale your data, otherwise your model won’t be able to learn. For example, min-max normalization maps both the inputs and the targets into [0, 1]:

data = (data - data.min()) / (data.max() - data.min())
y = (y - y.min()) / (y.max() - y.min())
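Because the model is then trained on scaled targets, its predictions also come out in the scaled range. A minimal sketch, assuming the same data as in the question, keeps the min/max values so predictions can be mapped back to the original units; `unscale_y` is a hypothetical helper name used only for illustration.

```python
import torch

data = torch.arange(-10000, 10000, dtype=torch.float32).unsqueeze(1)
y = data ** 2

# Keep the statistics: they are needed again to scale new inputs
# and to map predictions back to the original units
x_min, x_max = data.min(), data.max()
y_min, y_max = y.min(), y.max()

data_scaled = (data - x_min) / (x_max - x_min)  # values in [0, 1]
y_scaled = (y - y_min) / (y_max - y_min)        # values in [0, 1]

def unscale_y(y_s: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: invert the min-max scaling of the targets."""
    return y_s * (y_max - y_min) + y_min

print(data_scaled.min().item(), data_scaled.max().item())  # 0.0 1.0
print(y_scaled.min().item(), y_scaled.max().item())        # 0.0 1.0
```

For a trained model, `unscale_y(model(data_scaled))` would give predictions comparable to the original `y`.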

Alternatively, you could standardize your input to zero mean and unit standard deviation:

mean, std = data.mean(), data.std()
data = (data - mean) / std
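As with min-max scaling, the mean and standard deviation should be computed once on the training inputs and reused for any new input. A short sketch, assuming the question's data; `standardize` is a hypothetical helper name. Note that `Tensor.std()` uses the unbiased estimator by default.

```python
import torch

data = torch.arange(-10000, 10000, dtype=torch.float32).unsqueeze(1)

# Statistics computed once on the training inputs
mean, std = data.mean(), data.std()

def standardize(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply the training-set mean/std to any input."""
    return (x - mean) / std

data_std = standardize(data)
# Mean should be close to 0 and standard deviation close to 1
print(f"{data_std.mean().item():.4f} {data_std.std().item():.4f}")
```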

[Figure: result after 100 epochs of training]
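Putting the pieces together, here is a minimal end-to-end sketch, assuming the input is standardized and the target is min-max scaled. It keeps the question's layer sizes, `MSELoss`, and SGD with `lr=0.01`, with the last layer using `outfeatures`. It does not reproduce the original figure; it only records the loss so you can check that it stays finite and decreases. How closely the network fits x² after 100 epochs will depend on the random initialization and on the very small hidden layers.

```python
import torch
from torch import nn

torch.manual_seed(0)  # fixed seed for a reproducible run

data = torch.arange(-10000, 10000, dtype=torch.float32).unsqueeze(1)
y = data ** 2

# Standardize the input, min-max scale the target
x_mean, x_std = data.mean(), data.std()
y_min, y_max = y.min(), y.max()
x = (data - x_mean) / x_std
t = (y - y_min) / (y_max - y_min)

class MyModel(nn.Module):
    def __init__(self, numfeatures, outfeatures):
        super().__init__()
        self.modele = nn.Sequential(
            nn.Linear(numfeatures, 2 * numfeatures), nn.ReLU(),
            nn.Linear(2 * numfeatures, 4 * numfeatures), nn.ReLU(),
            nn.Linear(4 * numfeatures, 2 * numfeatures), nn.ReLU(),
            nn.Linear(2 * numfeatures, outfeatures),
        )

    def forward(self, x):
        return self.modele(x)

model = MyModel(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

epoch_loss = []
for epoch in range(100):
    y_pred = model(x)
    loss = criterion(y_pred, t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss.append(loss.item())

# Map predictions back to the original units for inspection
with torch.no_grad():
    pred_original = model(x) * (y_max - y_min) + y_min
```

To view the loss curve, call `plt.plot(epoch_loss)` followed by `plt.show()` (with `import matplotlib.pyplot as plt`).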

User contributions licensed under: CC BY-SA