PyTorch Intro

PyTorch Official Tutorial: https://pytorch.org/tutorials/

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import matplotlib.pyplot as plt
import numpy as np

import pdb
In [2]:
# transform to do random affine and cast image to PyTorch tensor
trans_ = torchvision.transforms.Compose(
    [
     # torchvision.transforms.RandomAffine(10),
     torchvision.transforms.ToTensor()]
)

# Setup the dataset
ds = torchvision.datasets.ImageFolder("train_img/",
                                     transform=trans_)

# Setup the dataloader
loader = torch.utils.data.DataLoader(ds, 
                                     batch_size=16, 
                                     shuffle=True)
In [3]:
# [16, 3, 30, 30] = [batch size, channels, width, height]
for x, y in loader:
    print(x.shape)
    print(y.shape)
    print(y)
    break

# vis
for i in range(16):
    plt.imshow(np.transpose(x[i,:], (1,2,0))) # 30 x 30 x 3
    plt.show()
torch.Size([16, 3, 30, 30])
torch.Size([16])
tensor([4, 0, 3, 3, 2, 2, 2, 1, 2, 0, 2, 2, 1, 3, 1, 3])
In [4]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        # define the layers
        # kernel size = 3 means (3,3) kernel
        # rgb -> 3 -> in channel
        # number of feature maps = 16
        # number of filters = 3 x 16
        self.l1 = nn.Conv2d(kernel_size=3, in_channels=3, out_channels=16)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 
        # MaxPool2d, AvgPool2d. 
        # The first 2 = 2x2 kernel size, 
        # The second 2 means the stride=2
        
        self.l2 = nn.Conv2d(kernel_size=3, in_channels=16, out_channels=32)
        
        # FC layer
        self.fc1 = nn.Linear(32 * 6 * 6, 5)
        
    def forward(self, x):
        # define the data flow through the deep learning layers
        x = self.pool(F.relu(self.l1(x))) # 16x16 x 14 x 14
        x = self.pool(F.relu(self.l2(x))) # 16x32x6x6
        # print(x.shape)
        x = x.reshape(-1, 32*6*6) # [16 x 1152]# CRUCIAL: 
        # print(x.shape)
        x = self.fc1(x)
        return x
In [5]:
m = CNN()
pred = m(x)
print(pred.shape)
torch.Size([16, 5])
In [6]:
print(pred)
tensor([[ 0.0639,  0.0436,  0.0704, -0.0715,  0.1685],
        [ 0.0127,  0.0206, -0.0038, -0.0404,  0.0772],
        [ 0.0295, -0.0017,  0.0853, -0.1172,  0.1308],
        [ 0.0850,  0.0475,  0.0775, -0.0509,  0.1571],
        [ 0.0563,  0.0451,  0.0610, -0.1154,  0.1714],
        [ 0.0335,  0.0121,  0.0726, -0.0556,  0.1618],
        [ 0.0424,  0.0136,  0.1042, -0.0438,  0.1198],
        [ 0.0828,  0.0228,  0.0699, -0.0575,  0.1268],
        [ 0.0491,  0.0670, -0.0061, -0.0210,  0.0731],
        [ 0.0241,  0.0363,  0.0657,  0.0047,  0.0502],
        [ 0.0034,  0.0490,  0.0348, -0.0445,  0.0543],
        [ 0.0610,  0.0214,  0.0644,  0.0061,  0.1064],
        [ 0.0444,  0.0325,  0.0192, -0.1069,  0.1403],
        [ 0.0528,  0.0218,  0.0242, -0.0196,  0.0937],
        [ 0.0321,  0.0210,  0.0212, -0.0482,  0.1059],
        [ 0.0493,  0.0357,  0.0338, -0.0419,  0.1400]],
       grad_fn=<AddmmBackward>)

Training

In [7]:
criterion = nn.CrossEntropyLoss()
num_epoches = 50
import tqdm

import torch.optim as optim
In [8]:
for epoch_id in range(num_epoches):
    optimizer = optim.SGD(m.parameters(), lr=0.01 * 0.95 ** epoch_id)
    for x, y in tqdm.tqdm(loader):
        optimizer.zero_grad() # clear (reset) the gradient for the optimizer
        pred = m(x)
        loss = criterion(pred, y)
        loss.backward() # calculating the gradient
        optimizer.step() # backpropagation: optimize the model
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 47.51it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 46.98it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 47.05it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 47.75it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 48.46it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 48.06it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 48.14it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.88it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.28it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.05it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.22it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.13it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 51.46it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 51.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.13it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 45.43it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 51.92it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.66it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 51.46it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 48.87it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 47.13it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 48.30it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.12it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 47.28it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 48.14it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.12it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 48.14it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.03it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.83it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 47.75it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.79it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 45.72it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.79it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 51.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.48it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.57it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 48.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.57it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.05it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.37it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 51.56it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 50.22it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.20it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.37it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 47.67it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.79it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.62it/s]

Testing

In [9]:
# Setup the dataset
test_ds = torchvision.datasets.ImageFolder("test_img/",
                                     transform=trans_)

# Setup the dataloader
testloader = torch.utils.data.DataLoader(test_ds, 
                                     batch_size=16, 
                                     shuffle=True)
In [10]:
all_gt = []
all_pred = []

for x, y in tqdm.tqdm(loader):
    optimizer.zero_grad() # clear (reset) the gradient for the optimizer
    all_gt += list(y.numpy().reshape(-1))
    pred = torch.argmax(m(x), dim=1)
    all_pred += list(pred.numpy().reshape(-1))
    
    
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 84.53it/s]
In [11]:
print(all_gt)
print(all_pred)
[1, 0, 3, 2, 0, 4, 2, 2, 4, 4, 1, 1, 4, 3, 1, 1, 1, 1, 1, 1, 0, 3, 1, 3, 3, 0, 1, 0, 1, 1, 4, 3, 2, 1, 2, 0, 0, 4, 3, 4, 4, 2, 3, 0, 1, 4, 0, 3, 3, 4, 3, 3, 0, 1, 2, 1, 1, 4, 2, 2, 1, 0, 3, 4, 1, 4, 3, 4, 4, 2, 2, 2, 4, 0, 1, 1, 1, 1, 0, 1, 3, 1, 0, 0, 1, 2, 1, 4, 2, 1, 4, 4, 0, 2, 1, 4, 1, 3, 4, 1, 0, 3, 3, 1, 3, 3, 3, 2, 1, 1, 4, 2, 0, 4, 2, 3, 3, 4, 2, 0, 3, 4, 0, 4, 4, 2, 1, 1, 0, 3, 2, 3, 1, 0, 1, 4, 3, 2, 1, 4, 4, 0, 4, 2, 0, 2, 0, 3, 0, 2, 3, 3, 3, 4, 4, 0, 1, 0, 0, 3, 2, 0, 1, 3, 2, 0, 3, 0, 1, 3, 4, 2, 1, 3, 4, 1, 0, 1, 2, 2, 1, 2, 1, 0, 3, 3, 3, 3, 0, 2, 2, 0, 0, 0, 2, 4, 1, 1, 1, 2, 0, 0, 4, 3, 0, 4, 2, 1, 2, 2, 4, 3, 1, 3, 4, 1, 1, 2, 2, 4, 4, 0, 1, 4, 4, 3, 4, 3, 0, 2, 0, 3, 4, 4, 0, 1, 0, 4, 4, 1, 0, 3, 3, 0, 0, 3, 4, 3, 0, 1, 0, 3, 0, 3, 3, 1, 0, 3, 2, 4, 1, 3, 3, 3, 3, 3, 0, 4, 4, 4, 4, 1, 2, 4, 4, 2, 2, 0, 4, 2, 3, 2, 0, 0, 0, 2, 1, 4, 4, 3, 1, 4, 0, 3, 4, 3, 3, 0, 0, 1, 0, 2, 3, 0, 1, 3, 2, 2, 2, 1, 3, 0, 3, 0, 0, 3, 2, 0, 2, 1, 3, 0, 1, 4, 3, 2, 2, 4, 0, 2, 1, 1, 3, 1, 2, 4, 0, 2, 2, 3, 2, 2, 2, 3, 2, 0, 2, 1, 2, 2, 3, 2, 4, 4, 4, 2, 1, 4, 4, 2, 0, 0, 4, 1, 2, 0, 4, 2, 0, 2, 1, 3, 4, 3, 4, 4, 2, 4, 0, 4, 4, 0, 4, 3, 2, 1, 1, 0, 2, 0, 1, 3, 0, 3, 4, 4, 4, 2, 2, 0, 0, 3, 2, 2, 2, 2, 2, 0, 0, 3, 2, 1, 1, 0, 3, 1, 2, 0, 0, 1, 1, 3, 4, 1, 0, 0, 3, 4, 3, 4, 4, 1, 3, 2, 1, 0, 3, 1, 4, 2, 1, 3, 4, 4, 2, 2, 2, 1, 4, 3]
[2, 4, 0, 2, 0, 4, 1, 2, 4, 4, 1, 0, 2, 3, 1, 1, 1, 1, 2, 1, 3, 3, 1, 1, 2, 0, 1, 2, 1, 4, 4, 0, 1, 0, 2, 0, 0, 4, 3, 4, 4, 1, 3, 1, 2, 4, 0, 1, 3, 3, 3, 3, 1, 1, 2, 1, 1, 4, 2, 2, 1, 0, 3, 4, 1, 0, 0, 4, 3, 2, 2, 0, 0, 0, 1, 1, 2, 1, 0, 2, 3, 1, 4, 0, 2, 2, 4, 3, 2, 1, 4, 4, 2, 2, 0, 2, 0, 3, 4, 1, 0, 1, 0, 2, 3, 0, 0, 4, 1, 1, 4, 0, 1, 4, 2, 4, 0, 4, 2, 0, 3, 4, 0, 4, 4, 2, 2, 1, 3, 1, 2, 4, 1, 0, 1, 4, 3, 2, 1, 4, 2, 0, 3, 0, 2, 2, 0, 3, 1, 2, 4, 3, 2, 0, 4, 0, 1, 4, 1, 3, 2, 4, 1, 3, 2, 1, 0, 3, 2, 3, 4, 2, 1, 0, 4, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 3, 3, 3, 1, 0, 0, 3, 1, 0, 1, 3, 1, 1, 0, 1, 0, 4, 4, 3, 0, 4, 2, 1, 1, 2, 4, 3, 1, 3, 4, 0, 4, 2, 2, 4, 4, 0, 1, 0, 4, 2, 4, 1, 0, 2, 1, 3, 4, 4, 1, 1, 0, 4, 4, 1, 4, 1, 3, 0, 2, 1, 4, 3, 0, 2, 0, 1, 1, 3, 3, 1, 0, 3, 1, 1, 1, 0, 3, 3, 1, 1, 0, 2, 4, 4, 4, 4, 1, 4, 4, 2, 2, 3, 4, 1, 3, 2, 2, 4, 0, 2, 1, 4, 2, 3, 1, 4, 1, 3, 4, 3, 1, 1, 0, 1, 0, 1, 3, 0, 1, 3, 2, 2, 2, 1, 1, 1, 3, 3, 0, 3, 2, 0, 2, 1, 3, 0, 1, 4, 0, 2, 1, 1, 0, 1, 4, 0, 3, 1, 2, 4, 3, 4, 1, 3, 2, 2, 2, 0, 2, 0, 1, 1, 2, 2, 4, 2, 4, 4, 0, 2, 1, 3, 1, 0, 0, 0, 3, 0, 2, 0, 4, 2, 0, 2, 1, 0, 0, 0, 4, 4, 0, 4, 0, 4, 4, 0, 4, 3, 2, 1, 2, 0, 2, 0, 1, 3, 3, 3, 4, 4, 4, 4, 2, 2, 0, 4, 4, 2, 2, 2, 4, 1, 0, 3, 4, 1, 1, 0, 3, 1, 3, 0, 2, 1, 2, 4, 4, 4, 1, 2, 3, 4, 0, 4, 4, 2, 0, 2, 1, 0, 3, 4, 4, 2, 1, 3, 4, 1, 2, 2, 2, 4, 4, 3]
In [12]:
acc = np.sum(np.array(all_gt) == np.array(all_pred)) / len(all_gt)
print("Accuracy is:", acc)
Accuracy is: 0.6466666666666666

Dilation and Depth-wise Conv

In [15]:
standard_conv = nn.Conv2d(kernel_size=3, in_channels=16, out_channels=16, dilation=1, groups=1)
dilated_conv = nn.Conv2d(kernel_size=3, in_channels=16, out_channels=16, dilation=2, groups=1)
depth_conv = nn.Conv2d(kernel_size=3, in_channels=16, out_channels=16, dilation=1, groups=16)
In [19]:
print(sum([p.numel() for p in standard_conv.parameters()]))
print(sum([p.numel() for p in dilated_conv.parameters()]))
print(sum([p.numel() for p in depth_conv.parameters()]))
2320
2320
160
In [23]:
print(standard_conv.weight.shape)
print(dilated_conv.weight.shape)
print(depth_conv.weight.shape)
torch.Size([16, 16, 3, 3])
torch.Size([16, 16, 3, 3])
torch.Size([16, 1, 3, 3])

Exercise after class

  • Use deeper network
  • Try to use dilated convolution
  • Try to use depth-wise convolution