Data Regression

2021-04-11 12:58

Fit a small neural network (regression) to noisy samples of a sine function.

#!/usr/bin/env python
# coding: utf-8
import torch
import numpy as np
from torch import nn
from matplotlib import pyplot as plt
# 100 evenly spaced sample points on [-pi, pi], reshaped to (100, 1) so each
# row is a single scalar input for the network.
x = torch.linspace(-np.pi, np.pi, 100).unsqueeze(dim=1)
# Targets: sin(x) plus uniform noise 0.5 * torch.rand(...), i.e. in [0, 0.5).
# NOTE(review): this noise is not zero-mean (mean 0.25), so the learned curve
# sits roughly 0.25 above sin(x); use torch.randn or subtract 0.25 if a
# centred fit is intended.
y = 0.5 * torch.rand(x.size()) + torch.sin(x)
class Net(nn.Module):
    """Tiny MLP mapping one scalar feature to one scalar output.

    Architecture: Linear(1, 10) -> ReLU -> Linear(10, 1), stored as a single
    ``nn.Sequential`` under the attribute ``predict``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(1, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
        ]
        # The attribute name ``predict`` determines the state_dict keys
        # (``predict.0.weight`` ...), so it must not be renamed.
        self.predict = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network's output for ``x``; the last dimension must be 1."""
        return self.predict(x)
net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
loss_func = nn.MSELoss()

# Training schedule. The script previously ran 1000 steps and then another
# 10000 steps on the same net/optimizer; a single loop over the combined
# count performs exactly the same sequence of SGD updates.
NUM_EPOCHS = 1000 + 10000
PLOT_EVERY = 1000  # redraw the current fit every PLOT_EVERY steps

# x and y never require grad, so they can be converted to NumPy once.
x_np = x.numpy()
y_np = y.numpy()

# Interactive mode lets the figure refresh while training continues; it is
# switched on once here instead of being toggled on every iteration.
plt.ion()
fig, ax = plt.subplots()
for epoch in range(NUM_EPOCHS):
    y_pred = net(x)
    loss = loss_func(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % PLOT_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        # Clear the axes so successive snapshots do not pile up.
        ax.cla()
        # Noisy data as points, model prediction as a line.
        ax.scatter(x_np, y_np, s=10, label='noisy samples')
        # y_pred was computed before this step's parameter update.
        ax.plot(x_np, y_pred.detach().numpy(), 'r-', lw=2, label='prediction')
        ax.set_title('epoch %d  training loss %.4f' % (epoch, loss.item()))
        ax.legend()  # both artists above carry labels
        plt.pause(0.1)  # let the GUI event loop render the update

# Leave interactive mode and block once so the final fit stays on screen.
plt.ioff()
plt.show()