Skip to content
Snippets Groups Projects

net_sample3.py

  • Clone with SSH
  • Clone with HTTPS
  • Embed
  • Share
    The snippet can be accessed without any authentication.
    Authored by Matteo

    A modification of: https://git.garr.it/snippets/19

    Modified the checkpoint_filename variable to checkpoint_filename = '/workspace/mnist-{:03d}.pkl'.format(epoch); otherwise the script would terminate after the first epoch.

    Increased the number of epochs from 10 to 50 for a longer script run.

    Edited
    net_sample3.py 5.44 KiB
    # pip install pytorch-ignite
    # pip install matplotlib
    
    import sys
    import numpy as np
    
    import torch
    import torch.nn.functional as F
    import torch.nn as nn
    import torch.optim as optim
    import torchvision.transforms as transforms
    
    from torchvision.datasets import MNIST
    from torch.utils.data import DataLoader
    
    import matplotlib.pyplot as plt
    
    import ignite
    
    
    device = 'cuda'
    
    class SimpleCNN(nn.Module):
    
        def __init__(self, num_channels=1, num_classes=10):
            super(SimpleCNN, self).__init__()
            self.conv1 = nn.Conv2d(num_channels, 32, 3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
            self.pool1 = nn.MaxPool2d(2)
            self.drop1 = nn.Dropout(0.25)
            self.fc1 = nn.Linear(14*14*32, 128)
            self.drop2 = nn.Dropout(0.5)
            self.fc2 = nn.Linear(128, num_classes)
    
        def forward(self, X):
            X = F.relu(self.conv1(X))
            X = F.relu(self.conv2(X)) 
            X = self.pool1(X)
            X = self.drop1(X)
            X = X.reshape(-1, 14*14*32)
            X = F.relu(self.fc1(X))
            X = self.drop2(X)
            X = self.fc2(X)
            return X  # logits
    
    def save_checkpoint(optimizer, model, epoch, filename):
        checkpoint_dict = {
            'optimizer': optimizer.state_dict(),
            'model': model.state_dict(),
            'epoch': epoch
        }
        torch.save(checkpoint_dict, filename)
    
    
    def load_checkpoint(optimizer, model, filename):
        checkpoint_dict = torch.load(filename)
        epoch = checkpoint_dict['epoch']
        model.load_state_dict(checkpoint_dict['model'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint_dict['optimizer'])
        return epoch
    
    # !mkdir -p checkpoints
    
    class MovingAverage():
        def __init__(self, nitems=10):
            self.__mylist = list()
            self.__nitems = nitems
        def update(self, value):
            self.__mylist.append(value)
        def value(self):
            self.__mylist = self.__mylist[-self.__nitems:]
            return sum(self.__mylist) / len(self.__mylist)
    
    
    def train(optimizer, model, num_epochs=50, first_epoch=1):
        
        criterion = nn.CrossEntropyLoss()
    
        train_losses = []
        valid_losses = []
    
        for epoch in range(first_epoch, first_epoch + num_epochs):
            print('Epoch', epoch)
    
            # train phase
            model.train()
    
            # create a progress bar
            # progress = ProgressMonitor(length=len(train_set))
    
            train_loss = MovingAverage()
    
            for batch, targets in train_loader:
                # Move the training data to the GPU
                batch = batch.to(device)
                targets = targets.to(device)
    
                # clear previous gradient computation
                optimizer.zero_grad()
    
                # forward propagation
                predictions = model(batch)
    
                # calculate the loss
                loss = criterion(predictions, targets)
    
                # backpropagate to compute gradients
                loss.backward()
    
                # update model weights
                optimizer.step()
    
                # update average loss
                train_loss.update(loss)
    
                # update progress bar
                #progress.update(batch.shape[0], train_loss)
    
            print('Training loss:', train_loss)
            train_losses.append(train_loss.value)
    
    
            # validation phase
            model.eval()
    
            #valid_loss = RunningAverage()
            valid_loss = MovingAverage()
    
            # keep track of predictions
            y_pred = []
    
            # We don't need gradients for validation, so wrap in 
            # no_grad to save memory
            with torch.no_grad():
    
                for batch, targets in valid_loader:
    
                    # Move the training batch to the GPU
                    batch = batch.to(device)
                    targets = targets.to(device)
    
                    # forward propagation
                    predictions = model(batch)
    
                    # calculate the loss
                    loss = criterion(predictions, targets)
    
                    # update running loss value
                    valid_loss.update(loss)
    
                    # save predictions
                    y_pred.extend(predictions.argmax(dim=1).cpu().numpy())
    
            print('Validation loss:', valid_loss)
            valid_losses.append(valid_loss.value)
    
            # Calculate validation accuracy
            y_pred = torch.tensor(y_pred, dtype=torch.int64)
            accuracy = torch.mean((y_pred == valid_set.targets).float())
            print('Validation accuracy: {:.4f}%'.format(float(accuracy) * 100))
    
            # Save a checkpoint
            checkpoint_filename = '/workspace/mnist-{:03d}.pkl'.format(epoch)
            save_checkpoint(optimizer, model, epoch, checkpoint_filename)
        
        return train_losses, valid_losses, y_pred
    
    
    # transform for the training data
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.1307], [0.3081])
    ])
    
    # use the same transform for the validation data
    valid_transform = train_transform
    
    # load datasets, downloading if needed
    train_set = MNIST('./data/mnist', train=True, download=True, 
                      transform=train_transform)
    valid_set = MNIST('./data/mnist', train=False, download=True, 
                      transform=valid_transform)
    
    print(train_set.data.shape)
    print(valid_set.data.shape)
    
    
    train_loader = DataLoader(train_set, batch_size=256, num_workers=0, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=512, num_workers=0, shuffle=False)
    
    model = SimpleCNN()
    model.to(device)
    
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
    train_losses, valid_losses, y_pred = train(optimizer, model, num_epochs=50)
    0% Loading or .
    You are about to add 0 people to the discussion. Proceed with caution.
    Finish editing this message first!
    Please register or to comment