CIFAR-10 with GoogLeNet

2021-05-16 11:32

1. Training part

import PIL
import time
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from torch import optim
from torchvision import transforms
from torch.nn import functional as F
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
!wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

--2021-05-16 11:43:20-- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'

Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30

Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:443... connected.

HTTP request sent, awaiting response... 200 OK

Length: 170498071 (163M) [application/x-gzip]

Saving to: ‘cifar-10-python.tar.gz’

cifar-10-python.tar 100%[===================>] 162.60M 3.42MB/s in 50s

2021-05-16 11:44:11 (3.27 MB/s) - ‘cifar-10-python.tar.gz’ saved [170498071/170498071]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, 0.5, 0.5),
                         std = (0.5, 0.5, 0.5))
])
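
With per-channel mean 0.5 and std 0.5, Normalize rescales ToTensor's [0, 1] output to [-1, 1] via (x - 0.5) / 0.5. A quick single-channel check (not part of the original notebook):

x = torch.tensor([[[0.0, 0.5, 1.0]]])             # a 1-channel, 1x3 "image"
print(transforms.Normalize((0.5,), (0.5,))(x))    # tensor([[[-1.,  0.,  1.]]])
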
train_set = CIFAR10(
    root = './cifar-10',
    train = True,
    # download = True,
    transform = transform
)
train_loader = DataLoader(
    train_set,
    batch_size = 100,
    shuffle = True
)
class Inception_A(nn.Module):
    """Simplified Inception block: four parallel branches concatenated along the channel axis."""
    def __init__(self, in_channels):
        super(Inception_A, self).__init__()
        # branch 1: 1x1 reduction, then two 3x3 convolutions -> 24 channels
        self.branch3x3_1 = nn.Conv2d(in_channels, 16, 1)
        self.branch3x3_2 = nn.Conv2d(16, 24, 3, padding = 1)
        self.branch3x3_3 = nn.Conv2d(24, 24, 3, padding = 1)

        # branch 2: 1x1 reduction, then a 5x5 convolution -> 24 channels
        self.branch5x5_1 = nn.Conv2d(in_channels, 16, 1)
        self.branch5x5_2 = nn.Conv2d(16, 24, 5, padding = 2)

        # branch 3: plain 1x1 convolution -> 16 channels
        self.branch1x1 = nn.Conv2d(in_channels, 16, 1)

        # branch 4: average pooling (applied in forward), then a 1x1 convolution -> 24 channels
        self.branch_pool = nn.Conv2d(in_channels, 24, 1)

    def forward(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch1x1 = self.branch1x1(x)

        # 3x3 average pooling with stride 1 and padding 1 keeps the spatial size
        branch_pool = F.avg_pool2d(x, 3, stride = 1, padding = 1)
        branch_pool = self.branch_pool(branch_pool)

        # concatenate along the channel axis: 16 + 24 + 24 + 24 = 88 output channels
        output = [branch1x1, branch5x5, branch3x3, branch_pool]

        return torch.cat(output, dim = 1)
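
Since every branch preserves height and width, Inception_A always returns 16 + 24 + 24 + 24 = 88 channels at the input's spatial size. A quick sanity check (not part of the original notebook):

block = Inception_A(in_channels = 10)
print(block(torch.randn(2, 10, 14, 14)).shape)   # expected: torch.Size([2, 88, 14, 14])
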
class GoogLeNet(nn.Module):
    def __init__(self):
        super(GoogLeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, 5)
        self.incep1 = Inception_A(in_channels=10)

        # each Inception_A block outputs 88 channels (16 + 24 + 24 + 24)
        self.conv2 = nn.Conv2d(88, 20, 5)
        self.incep2 = Inception_A(in_channels=20)

        self.mp = nn.MaxPool2d(2)
        # after the second Inception block: 88 channels * 5 * 5 spatial = 2200 features
        self.fc = nn.Linear(2200, 10)
        # note: CrossEntropyLoss below already applies log-softmax internally,
        # so this extra Softmax compresses the gradients and slows training
        self.cls = nn.Softmax(dim = 1)

    def forward(self, x):
        x = F.relu(self.mp(self.conv1(x)))
        x = self.incep1(x)
        x = F.relu(self.mp(self.conv2(x)))
        x = self.incep2(x)
        x = x.view(x.size()[0], -1)
        x = self.fc(x)
        x = self.cls(x)
        return x
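
Tracing the shapes for one 32x32 CIFAR-10 image explains the 2200 input features of the fully connected layer (a sanity-check sketch, not part of the original notebook):

# conv1 (3->10, k=5):   3x32x32  -> 10x28x28
# maxpool(2) + relu:    10x28x28 -> 10x14x14
# incep1:               10x14x14 -> 88x14x14
# conv2 (88->20, k=5):  88x14x14 -> 20x10x10
# maxpool(2) + relu:    20x10x10 -> 20x5x5
# incep2:               20x5x5   -> 88x5x5, flattened to 88*5*5 = 2200 features
dummy = torch.randn(1, 3, 32, 32)
print(GoogLeNet()(dummy).shape)   # expected: torch.Size([1, 10])
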
net = GoogLeNet()
torch.cuda.empty_cache()
device = ('cuda' if torch.cuda.is_available() else 'cpu')
if not (device == 'cpu'):
    net.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr = 0.005, momentum= 0.2)
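
Note that torch.nn.CrossEntropyLoss expects raw logits because it applies log-softmax internally; feeding it the Softmax output from self.cls, as this model does, compresses the loss range, which is one plausible reason the loss below barely moves over 50 epochs. A small illustration with made-up logits (not from the original run):

logits = torch.tensor([[4.0, 0.5, -1.0, 0.0, 0.2, -0.3, 1.0, -2.0, 0.1, 0.4]])
target = torch.tensor([0])
print(criterion(logits, target).item())                      # small loss: the raw logits are confidently correct
print(criterion(F.softmax(logits, dim = 1), target).item())  # much larger: probabilities fed through a second softmax are squashed toward uniform
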
egimage, eglabel = next(iter(train_loader))   # grab one example batch

print(egimage.size())
print(eglabel.size())
# print result
torch.Size([100, 3, 32, 32])
torch.Size([100])
# CIFAR-10 label names in dataset order
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
egindex = 4

plt.figure()
plt.imshow(egimage[egindex][0])   # show only channel 0 of the normalized image
plt.colorbar()
plt.grid()
plt.show()

print(classes[eglabel[egindex]])

cat

start_time = time.time()
epochs = 50
epoch_loss = []

for epoch in range(epochs):
    running_loss = 0
    for i, (inputs, labels) in enumerate(train_loader):
        if not (device == 'cpu'):
            inputs = inputs.to(device)
            labels = labels.to(device)
        y_hats = net(inputs)

        loss = criterion(y_hats, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        avr_loss = running_loss / (i+1)   # running average loss over the batches seen so far

    epoch_loss.append(avr_loss)
    if epoch%5 == 0:
        print('epoch %d, loss: %.5f'%(epoch, avr_loss))

end_time = time.time()
print('Training finished, time used %.3f s'%(end_time - start_time))
epoch 0, loss: 1.81920
epoch 5, loss: 1.81351
epoch 10, loss: 1.80854
epoch 15, loss: 1.80460
epoch 20, loss: 1.80021
epoch 25, loss: 1.79726
epoch 30, loss: 1.79350
epoch 35, loss: 1.78997
epoch 40, loss: 1.78629
epoch 45, loss: 1.78374
Training finished, time used 1777.526 s
plt.plot(epoch_loss)

(figure: average training loss per epoch, decreasing slowly from about 1.82 to 1.78)

torch.save(net.state_dict(), './GoogLeNet_weights_0005_02.pkl')
# accuracy on the test set; test_loader is built exactly as in the test part below
ac = 0
total = 0
with torch.no_grad():
    for i, (inputs, labels) in enumerate(test_loader):
        if not (device =='cpu'):
            inputs = inputs.to(device)
            labels = labels.to(device)
        y_hats = net(inputs)
        y_pred = y_hats.argmax(dim = 1)
        ac += (y_pred == labels).sum().item()
        total += labels.size()[0]

print('accuracy: ', ac/total)

2. Test part

import torch
from torchvision import transforms
from GoogLeNet_model import GoogLeNet   # the GoogLeNet class from the training part, saved as GoogLeNet_model.py
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, 0.5, 0.5),
                        std = (0.5, 0.5, 0.5))
])
test_set  = CIFAR10(
    root = './cifar-10/',
    train = False,
    transform = transform
)
test_loader = DataLoader(
    test_set,
    batch_size = 100,
    shuffle = True,
)
net = GoogLeNet()
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print('Using: ', device)
if not (device == 'cpu'):
    net.to(device)

Using: cuda

net.load_state_dict(torch.load('./GoogLeNet_weights_0005_02.pkl'))
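
torch.load also accepts a map_location argument for the case where the weights were saved on a GPU but are loaded on a CPU-only machine; calling eval() before inference is good practice as well, although this model has no dropout or batch-norm layers. A hedged variant of the line above (not part of the original run):

net.load_state_dict(torch.load('./GoogLeNet_weights_0005_02.pkl', map_location = device))
net.eval()   # switch to inference mode; harmless here since the model has no dropout/batch-norm
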

ac = 0
total = 0 
with torch.no_grad():
    for i, (inputs, labels) in enumerate(test_loader):
        if not (device =='cpu'):
            inputs = inputs.to(device)
            labels = labels.to(device)
        y_hats = net(inputs)
        y_pred = y_hats.argmax(dim = 1)
        ac += (y_pred == labels).sum().item()
        total += labels.size()[0]

print('accuracy: ', ac/total)
accuracy:  0.6253