net_sample3.py
Authored by Matteo
A modification of: https://git.garr.it/snippets/19
Modified the checkpoint_filename variable to checkpoint_filename = '/workspace/mnist-{:03d}.pkl'.format(epoch); otherwise the script would terminate after the first epoch.
Increased the number of epochs from 10 to 50 for a longer run.
# pip install pytorch-ignite
# pip install matplotlib
import sys
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import ignite
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present
class SimpleCNN(nn.Module):
    def __init__(self, num_channels=1, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 32, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.drop1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(14*14*32, 128)
        self.drop2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, X):
        X = F.relu(self.conv1(X))
        X = F.relu(self.conv2(X))
        X = self.pool1(X)
        X = self.drop1(X)
        X = X.reshape(-1, 14*14*32)  # flatten: 32 feature maps of 14x14 after pooling
        X = F.relu(self.fc1(X))
        X = self.drop2(X)
        X = self.fc2(X)
        return X  # logits
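
# A quick shape sanity check for the model above (a minimal sketch; the dummy
# input is illustrative): with 28x28 MNIST inputs, the two same-padding convs
# keep 28x28 and the single 2x2 max-pool halves it to 14x14, which is why fc1
# expects 14*14*32 features.
# dummy = torch.zeros(1, 1, 28, 28)
# assert SimpleCNN()(dummy).shape == torch.Size([1, 10])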
def save_checkpoint(optimizer, model, epoch, filename):
    checkpoint_dict = {
        'optimizer': optimizer.state_dict(),
        'model': model.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint_dict, filename)
def load_checkpoint(optimizer, model, filename):
    checkpoint_dict = torch.load(filename)
    epoch = checkpoint_dict['epoch']
    model.load_state_dict(checkpoint_dict['model'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])
    return epoch
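
# Resuming from a checkpoint is not exercised below; a minimal sketch, assuming
# a checkpoint such as /workspace/mnist-010.pkl exists (the epoch number is
# illustrative):
# model = SimpleCNN().to(device)
# optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
# last_epoch = load_checkpoint(optimizer, model, '/workspace/mnist-010.pkl')
# train(optimizer, model, num_epochs=50, first_epoch=last_epoch + 1)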
# !mkdir -p /workspace  # ensure the checkpoint directory exists
class MovingAverage():
    def __init__(self, nitems=10):
        self.__mylist = list()
        self.__nitems = nitems

    def update(self, value):
        self.__mylist.append(value)

    def value(self):
        # average over the most recent nitems updates
        self.__mylist = self.__mylist[-self.__nitems:]
        return sum(self.__mylist) / len(self.__mylist)
def train(optimizer, model, num_epochs=50, first_epoch=1):
    criterion = nn.CrossEntropyLoss()
    train_losses = []
    valid_losses = []
    for epoch in range(first_epoch, first_epoch + num_epochs):
        print('Epoch', epoch)

        # train phase
        model.train()
        # create a progress bar
        # progress = ProgressMonitor(length=len(train_set))
        train_loss = MovingAverage()
        for batch, targets in train_loader:
            # Move the training data to the GPU
            batch = batch.to(device)
            targets = targets.to(device)
            # clear previous gradient computation
            optimizer.zero_grad()
            # forward propagation
            predictions = model(batch)
            # calculate the loss
            loss = criterion(predictions, targets)
            # backpropagate to compute gradients
            loss.backward()
            # update model weights
            optimizer.step()
            # update average loss; .item() detaches the scalar from the graph
            train_loss.update(loss.item())
            # update progress bar
            # progress.update(batch.shape[0], train_loss)
        print('Training loss:', train_loss.value())
        train_losses.append(train_loss.value())

        # validation phase
        model.eval()
        # valid_loss = RunningAverage()
        valid_loss = MovingAverage()
        # keep track of predictions
        y_pred = []
        # We don't need gradients for validation, so wrap in
        # no_grad to save memory
        with torch.no_grad():
            for batch, targets in valid_loader:
                # Move the validation batch to the GPU
                batch = batch.to(device)
                targets = targets.to(device)
                # forward propagation
                predictions = model(batch)
                # calculate the loss
                loss = criterion(predictions, targets)
                # update running loss value
                valid_loss.update(loss.item())
                # save predictions
                y_pred.extend(predictions.argmax(dim=1).cpu().numpy())
        print('Validation loss:', valid_loss.value())
        valid_losses.append(valid_loss.value())

        # Calculate validation accuracy
        y_pred = torch.tensor(y_pred, dtype=torch.int64)
        accuracy = torch.mean((y_pred == valid_set.targets).float())
        print('Validation accuracy: {:.4f}%'.format(float(accuracy) * 100))

        # Save a checkpoint
        checkpoint_filename = '/workspace/mnist-{:03d}.pkl'.format(epoch)
        save_checkpoint(optimizer, model, epoch, checkpoint_filename)

    return train_losses, valid_losses, y_pred
# transform for the training data
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.1307], [0.3081])  # MNIST training-set mean and std
])
# use the same transform for the validation data
valid_transform = train_transform
# load datasets, downloading if needed
train_set = MNIST('./data/mnist', train=True, download=True,
                  transform=train_transform)
valid_set = MNIST('./data/mnist', train=False, download=True,
                  transform=valid_transform)
print(train_set.data.shape)
print(valid_set.data.shape)
train_loader = DataLoader(train_set, batch_size=256, num_workers=0, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=512, num_workers=0, shuffle=False)
model = SimpleCNN()
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
train_losses, valid_losses, y_pred = train(optimizer, model, num_epochs=50)
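
# matplotlib is imported above but otherwise unused; a minimal sketch for
# plotting the collected loss curves (the output path is an assumption, chosen
# to match the checkpoint directory):
plt.plot(range(1, len(train_losses) + 1), train_losses, label='training loss')
plt.plot(range(1, len(valid_losses) + 1), valid_losses, label='validation loss')
plt.xlabel('epoch')
plt.ylabel('cross-entropy loss')
plt.legend()
plt.savefig('/workspace/mnist-losses.png')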