Data Classfication
2021-04-18 13:16
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
data = torch.ones(100, 2)
x0 = torch.normal(data*2, 1)
x1 = torch.normal(data*-2, 1)
x = torch.cat([x0, x1], 0)#.type(torch.FloatTensor)
y0 = torch.zeros(100)
y1 = torch.ones(100)
y = torch.cat([y0, y1], 0).type(torch.LongTensor)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.classify = nn.Sequential(
nn.Linear(2, 15),
nn.ReLU(),
nn.Linear(15, 2),
nn.Softmax(dim=1)
)
def forward(self, x):
classification = self.classify(x)
return classification
net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.3)
loss_func = nn.CrossEntropyLoss()
for epoch in range(100):
y_hat = net(x)
loss = loss_func(y_hat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
classification = torch.max(y_hat, 1)[1]
class_y = classification.data.numpy()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:,0],
x.data.numpy()[:,1],
c = class_y,
s = 100,
cmap = 'RdYlGn',
)
accuracy = sum(class_y == target_y)/200
# plt.title('accuracy = %s'%accuracy)
plt.text(1.5,
-4,
f'accuracy = {accuracy}',
fontdict = {'size':'20',
'color':'blue'}
)
plt.show()