# TP5: Training Tricks

In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt

from collections import defaultdict

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.utils import make_grid

from torch.utils.tensorboard import SummaryWriter

In [None]:
# Preprocessing applied to every CIFAR-10 image:
#  - Resize(32): CIFAR-10 images are already 32x32, so this does nothing here
#  - ToTensor(): PIL image -> float tensor (C, H, W) with values in [0, 1]
#  - Normalize((0.5,), (0.5,)): computes (x - 0.5) / 0.5 on every channel,
#    i.e. maps [0, 1] to [-1, 1]
transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
# Train split first, then test split, both with the same preprocessing.
trainset, testset = (
    datasets.CIFAR10(root="./data", train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

print(f"Train points {len(trainset)}")
print(f"Test points {len(testset)}")

In [None]:
# Preview one shuffled batch of 8 training images together with their class names.
train_dl = DataLoader(trainset, batch_size=8, shuffle=True)
imgs, labels = next(iter(train_dl))
# Invert Normalize((0.5,), (0.5,)) so pixel values are back in [0, 1] for display.
imgs = 0.5 + 0.5 * imgs[:8]
grid = make_grid(imgs, nrow=4)

plt.figure(figsize=(8, 4))
plt.imshow(grid.cpu().numpy().transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
plt.axis("off")
plt.title(", ".join(trainset.classes[int(lbl)] for lbl in labels[:8]))
plt.show()

print(f"Number of classes: {len(trainset.classes)}")
print(f"Image shape: {imgs[0].shape}")  # C, H, W

# Our Base Model

In [None]:
# Our baseline model
class CNN_base(nn.Module):
    """Plain five-stage ConvNet used as the baseline in every experiment.

    Each stage is ``Conv2d(kernel 3, padding 1) -> ReLU``. A 2x2 max-pool follows
    stages 2, 4 and 5, so a 32x32 input reaches the classifier as 256 x 4 x 4.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        widths = [3, 32, 64, 128, 128, 256]
        pool_after = {1, 3, 4}  # 0-based stages followed by 2x2 max-pooling
        modules = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            modules.append(nn.ReLU())
            if stage in pool_after:
                modules.append(nn.MaxPool2d(2, 2))
        self.features = nn.Sequential(*modules)
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(256 * 4 * 4, num_classes))

    def forward(self, x):
        feature_maps = self.features(x)
        return self.classifier(feature_maps)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_train = 200
train_subset = Subset(trainset, list(range(n_train)))
train_dl = DataLoader(train_subset, batch_size=32, shuffle=True)
# The full 10,000-image test set is evaluated after every epoch, so use large
# batches: batch_size=4 meant 2,500 tiny forward passes per evaluation.
# 250 divides 10,000 evenly, so every batch has the same size.
test_dl = DataLoader(testset, batch_size=250)

writer = SummaryWriter("runs/")

# Log a small grid of training images to TensorBoard. make_grid/add_image work on
# CPU tensors, so there is no need to move the batch to `device` first.
imgs_sample, labels_sample = next(iter(train_dl))
grid = make_grid(imgs_sample[:16], nrow=4, normalize=True, scale_each=True)
writer.add_image("train/sample_images", grid)

# Our Training Set 
We focus on a tiny subset of CIFAR-10 (its first 200 training images) to study optimization in a simpler setup.
Normally, the model should quickly overfit this small training set.

In [None]:
# Fixed (non-random) training subset: the first n_train CIFAR-10 training images.
n_train = 200
train_subset = Subset(trainset, [*range(n_train)])
train_dl = DataLoader(train_subset, shuffle=True, batch_size=32)

# Some utility functions to log info during training


In [None]:
def log_grad_norms(model, writer, step, model_name=""):
    """Write the L2 norm of every parameter gradient of *model* to *writer*.

    One scalar per parameter, tagged ``"{model_name}/grads/{param_name}"`` at
    x-value *step*. Parameters whose ``.grad`` is None are skipped.
    """
    with_grad = ((pname, p.grad) for pname, p in model.named_parameters() if p.grad is not None)
    for pname, grad in with_grad:
        writer.add_scalar(f"{model_name}/grads/{pname}", grad.norm().item(), step)


def get_layer_grad_norms(model):
    """Return the gradient L2 norm of each layer of *model*, keyed by layer name.

    A layer's norm combines all of its parameters (e.g. ``weight`` and ``bias``):
    sqrt(sum of squared per-parameter gradient norms). The key is the parameter's
    qualified name without its last component (``"features.0.weight"`` ->
    ``"features.0"``). Parameters registered directly on the root module keep their
    own name. Parameters without a gradient are skipped.
    """
    squared = defaultdict(float)
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        # Group by owning module; keying by the full name (as before) made the
        # sum-of-squares aggregation a no-op.
        layer = name.rsplit(".", 1)[0]
        squared[layer] += p.grad.norm().item() ** 2
    return {layer: v**0.5 for layer, v in squared.items()}

# Training Loop

In [None]:
def test_model(model, test_dl, criterion, writer=None, global_step=None, name=""):
    """Evaluate *model* on every batch of *test_dl*; print and return accuracy and loss.

    The model is switched to eval mode for the evaluation and restored to its
    previous mode afterwards. Batches are moved to the device of the model's
    parameters (CPU if it has none).

    Args:
        model: classifier returning logits of shape (batch, n_classes).
        test_dl: iterable of ``(inputs, labels)`` batches.
        criterion: loss returning the mean loss of a batch (e.g. ``nn.CrossEntropyLoss()``).
        writer: optional SummaryWriter; metrics are logged only when both *writer*
            and *global_step* are given.
        global_step: x-axis value for the logged scalars.
        name: tag prefix for the logged scalars.

    Returns:
        ``(accuracy_percent, mean_loss)``, where mean_loss is averaged per sample.

    Raises:
        ValueError: if *test_dl* yields no samples.
    """
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    try:
        with torch.no_grad():
            for imgs, labels in test_dl:
                imgs, labels = imgs.to(dev), labels.to(dev)
                logits = model(imgs)
                batch_size = labels.size(0)
                # Weight each batch-mean loss by its size so a smaller last batch does not
                # skew the average; .item() also works for GPU tensors, unlike .numpy().
                loss_sum += criterion(logits, labels).mean().item() * batch_size
                correct += (logits.argmax(1) == labels).sum().item()
                total += batch_size
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("test_dl yielded no samples")
    accuracy = 100 * correct / total
    mean_loss = loss_sum / total
    if writer is not None and global_step is not None:
        writer.add_scalar(f"{name}/test/accuracy", accuracy, global_step)
        writer.add_scalar(f"{name}/test/loss", mean_loss, global_step)
    print(f"Test accuracy: {accuracy:.2f}%, test loss: {mean_loss}")
    return accuracy, mean_loss


# This is an evaluation helper, not a unit test: keep pytest from collecting it.
test_model.__test__ = False

In [None]:
def train(
    model,
    train_dl,
    opt,
    criterion,
    writer,
    n_epochs,
    name="",
    global_step=0,
    prints=False,
):
    """Train *model* for *n_epochs*, evaluating on the global ``test_dl`` after each epoch.

    Every step logs the training loss to *writer* under ``"{name}/train/loss"``;
    every 10 steps the per-parameter gradient norms are logged too.

    Args:
        model: network to train (batches are moved to the global ``device``).
        train_dl: iterable of ``(inputs, labels)`` training batches.
        opt: optimizer over the model's parameters.
        criterion: loss function.
        writer: SummaryWriter used for all logging.
        n_epochs: number of passes over *train_dl*.
        name: run name, used as the TensorBoard tag prefix.
        global_step: step counter to continue from.
        prints: if True, also print per-layer gradient norms after each epoch.

    Returns:
        The global step after the last update, so a later call can continue the curves.

    Raises:
        ValueError: if *train_dl* yields no batches.
    """
    for epoch in range(n_epochs):
        model.train()
        loss = None
        for imgs, labels in train_dl:
            global_step += 1
            imgs, labels = imgs.to(device), labels.to(device)
            opt.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            # Prefix with the run name so runs sharing one writer get separate curves.
            writer.add_scalar(f"{name}/train/loss", loss.item(), global_step)
            loss.backward()
            # Gradients exist only after backward(); right after zero_grad() they are
            # cleared (None by default), so logging must happen here.
            if global_step % 10 == 0:
                log_grad_norms(model, writer, global_step, model_name=name)
            opt.step()
        if loss is None:
            raise ValueError("train_dl yielded no batches")
        print(f"Epoch {epoch + 1}: train loss = {loss.item():.4f}")
        if prints:
            print(get_layer_grad_norms(model))
        model.eval()
        test_model(
            model, test_dl, criterion, writer=writer, global_step=global_step, name=name
        )
    return global_step

# Sigmoid activation

In [None]:
# Try sigmoid
class CNN5_sigmoid(nn.Module):
    """Baseline ConvNet layout with every ReLU replaced by a sigmoid.

    Five ``Conv2d(kernel 3, padding 1) -> Sigmoid`` stages; a 2x2 max-pool follows
    stages 2, 4 and 5, so a 32x32 input reaches the classifier as 256 x 4 x 4.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        channels = [3, 32, 64, 128, 128, 256]
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(channels, channels[1:])):
            layers.extend((nn.Conv2d(n_in, n_out, 3, padding=1), nn.Sigmoid()))
            if idx in (1, 3, 4):  # downsample after the 2nd, 4th and 5th conv
                layers.append(nn.MaxPool2d(2, 2))
        self.features = nn.Sequential(*layers)
        head = [nn.Flatten(), nn.Linear(256 * 4 * 4, num_classes)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        h = self.features(x)
        return self.classifier(h)

In [None]:
# Rebuild the 200-image training subset and its loader.
n_train = 200
train_subset = Subset(trainset, range(n_train))
train_dl = DataLoader(train_subset, shuffle=True, batch_size=32)

# Sigmoid-activated network and its parameter count.
cnn_sig = CNN5_sigmoid().to(device)
total_params = sum(param.numel() for param in cnn_sig.parameters())
print(f"Training SimpleCNN model with {total_params} parameters")

criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(cnn_sig.parameters(), lr=3e-4)
global_step = 0

n_epochs, name = 50, "sig_act"

train(cnn_sig, train_dl, opt, criterion, writer, n_epochs, name=name, global_step=global_step)

# Compare to the baseline (ReLU) model

In [None]:
# ReLU baseline, trained with Adam (lr=3e-4) for 50 epochs.
name = "baseline"
cnn = CNN_base().to(device)
total_params = sum(map(lambda t: t.numel(), cnn.parameters()))
print(f"Training SimpleCNN model with {total_params} parameters")

opt = torch.optim.Adam(cnn.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
global_step, n_epochs = 0, 50

train(cnn, train_dl, opt, criterion, writer, n_epochs, name=name, global_step=global_step)

# Visualize the kernel weights and images

In [None]:
# visualize first-layer kernels and their feature maps on a test image
def plot_kernels(model, test_dl, i_test):
    """Show a test image, the first conv layer's kernels, and that layer's responses.

    Draws: the image in RGB, its three channels side by side, up to 32 first-layer
    kernels rendered as RGB (each rescaled to [0, 1]), and the matching raw conv
    outputs (before the activation) computed on the image.

    Args:
        model: network whose ``features[0]`` is the first Conv2d (e.g. CNN_base).
        test_dl: loader; the image is taken from its first batch.
        i_test: index of the image within that first batch.
    """
    imgs_test, labels_test = next(iter(test_dl))
    conv1 = model.features[0]  # first Conv2d of the model passed in (not a global)
    # Keep a batch dimension -> (1, 3, H, W), on the same device as the layer.
    img = imgs_test[i_test].unsqueeze(0).to(conv1.weight.device)
    img_cpu = img.squeeze(0).cpu()  # (3, H, W)

    # show full RGB image (denormalize from Normalize((0.5,), (0.5,)))
    img_rgb = np.clip(img_cpu.permute(1, 2, 0).numpy() * 0.5 + 0.5, 0, 1)
    plt.figure(figsize=(3, 3))
    plt.imshow(img_rgb)
    plt.title(f"Label: {testset.classes[int(labels_test[i_test])]}")
    plt.axis("off")

    # show the three channels side-by-side, each rescaled to [0, 1] for display
    plt.figure(figsize=(9, 3))
    cmaps = ["Reds", "Greens", "Blues"]
    for i in range(3):
        ch = img_cpu[i].numpy()
        ch = (ch - ch.min()) / (ch.max() - ch.min() + 1e-8)
        ax = plt.subplot(1, 3, i + 1)
        ax.imshow(ch, cmap=cmaps[i])
        ax.set_title(f"channel {i}")
        ax.axis("off")
    plt.suptitle("Test image channels")
    plt.tight_layout()
    plt.show()

    kernels = conv1.weight.detach().cpu()  # (out_ch, in_ch, kH, kW)
    out_ch = kernels.shape[0]

    # number of kernels / feature maps to show (never more than the layer has)
    n_show = min(32, out_ch)
    n_cols = 8
    n_rows = math.ceil(n_show / n_cols)

    # plot kernels as RGB images (normalize per-kernel)
    plt.figure(figsize=(n_cols * 2, n_rows * 2))
    for i in range(n_show):
        k = kernels[i]  # (3, kH, kW)
        k_min, k_max = k.min(), k.max()
        k_img = (k - k_min) / (k_max - k_min + 1e-8)  # normalize to 0-1
        k_img = k_img.permute(1, 2, 0).numpy()  # H,W,C
        ax = plt.subplot(n_rows, n_cols, i + 1)
        ax.imshow(k_img)
        ax.set_title(f"kernel {i}")
        ax.axis("off")
    plt.suptitle("First-layer kernels (normalized RGB)")
    plt.tight_layout()
    plt.show()

    # compute feature maps produced by conv1 for the chosen test image
    with torch.no_grad():
        acts = conv1(img).squeeze(0).cpu()  # (out_ch, H, W)

    # plot first n_show feature maps (grayscale)
    plt.figure(figsize=(n_cols * 2, n_rows * 2))
    for i in range(n_show):
        act = acts[i].numpy()
        # normalize each activation map for visualization
        act = (act - act.min()) / (act.max() - act.min() + 1e-8)
        ax = plt.subplot(n_rows, n_cols, i + 1)
        ax.imshow(act, cmap="gray")
        ax.set_title(f"Features {i}")
        ax.axis("off")
    plt.suptitle("Feature maps after first conv (normalized)")
    plt.tight_layout()
    plt.show()

In [None]:
# Histogram of the absolute values of every parameter of the current `cnn`.
with torch.no_grad():
    # reshape (unlike view) also handles non-contiguous tensors; .cpu() is required
    # before .numpy() when the model lives on the GPU.
    model_weights = (
        torch.cat([p.detach().reshape(-1).abs() for p in cnn.parameters()])
        .cpu()
        .numpy()
    )


plt.hist(model_weights, bins=100)
plt.show()

# Initialization Matters

### Zero Initialization
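With every weight and bias set to zero, all hidden activations are zero (ReLU(0) = 0), and so are the logits. The gradients of all convolutional and linear weights therefore vanish: only the final bias receives a gradient, so at best the network learns the class frequencies of the training set.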

In [None]:
name = "zero_init"
cnn = CNN_base().to(device)

# Set every weight and bias to zero. Iterate over parameters() rather than
# named_parameters(): a loop variable called `name` would overwrite the run name
# passed to train() below, and the run would be logged under a parameter's name.
with torch.no_grad():
    for param in cnn.parameters():
        param.zero_()
print("Initialized all cnn parameters to zero.")

total_params = sum(p.numel() for p in cnn.parameters())
print(f"Training SimpleCNN model with {total_params} parameters")

opt = torch.optim.Adam(cnn.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
global_step = 0

n_epochs = 50

train(cnn, train_dl, opt, criterion, writer, n_epochs, name=name, global_step=0)

### Default Initialization

Check PyTorch's default initialization for:
- conv2d layers
- linear layers

Non-adaptive initialization schemes:

In [None]:
# Fresh baseline network for the fixed-std normal initialization experiment.
name, cnn = "normal_init", CNN_base().to(device)


def init_normal(m):
    """Draw Linear/Conv2d weights from N(0, 0.1^2) and zero their biases, in place.

    Any other module type is left untouched, which makes this suitable for
    ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.1)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)


# Re-draw every Linear/Conv2d weight from N(0, 0.1^2) and zero the biases.
cnn.apply(init_normal)

criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(cnn.parameters(), lr=3e-4)
global_step = 0
n_epochs = 50

train(cnn, train_dl, opt, criterion, writer, n_epochs, name=name, global_step=global_step)

# Dropout Regularization

In [None]:
# Baseline architecture with dropout in front of the linear classifier
class CNN_dropout(nn.Module):
    """Baseline ConvNet with ``Dropout(0.5)`` applied to the features before the classifier.

    Same five ``Conv2d -> ReLU`` stages (2x2 max-pool after stages 2, 4 and 5) as the
    baseline. Dropout is only active in training mode.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        channels = (3, 32, 64, 128, 128, 256)
        stack = []
        for idx in range(len(channels) - 1):
            stack += [nn.Conv2d(channels[idx], channels[idx + 1], 3, padding=1), nn.ReLU()]
            if idx in (1, 3, 4):  # downsample after the 2nd, 4th and 5th conv
                stack.append(nn.MaxPool2d(2, 2))
        self.features = nn.Sequential(*stack)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, num_classes),
        )

    def forward(self, x):
        pooled = self.features(x)
        return self.classifier(pooled)

In [None]:
name = "dropout"
# Named after the model it holds (previously stored in a misleading `cnn_skip`).
cnn_dropout = CNN_dropout().to(device)

opt = torch.optim.Adam(cnn_dropout.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
global_step = 0

n_epochs = 50

train(cnn_dropout, train_dl, opt, criterion, writer, n_epochs, name=name, global_step=0)

# Skip connections

In [None]:
# Skip connections
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions plus an additive shortcut, then ReLU and an optional 2x2 max-pool.

    Output: ``pool(relu(conv(x) + skip(x)))``. The shortcut is the identity when
    ``in_ch == out_ch`` and a 1x1 convolution otherwise, so the two paths can be added.
    """

    def __init__(self, in_ch, out_ch, downsample=False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
        )
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, 1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2) if downsample else nn.Identity()

    def forward(self, x):
        return self.pool(self.relu(self.conv(x) + self.skip(x)))


class CNN_skip(nn.Module):
    """Baseline layout built from ResidualBlocks instead of plain conv layers.

    Five blocks with widths 3 -> 32 -> 64 -> 128 -> 128 -> 256; blocks 2, 4 and 5
    downsample, so a 32x32 input reaches the classifier as 256 x 4 x 4.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        widths = [3, 32, 64, 128, 128, 256]
        pooled_blocks = {1, 3, 4}  # same downsampling pattern as the baseline
        self.features = nn.Sequential(
            *[
                ResidualBlock(widths[k], widths[k + 1], downsample=k in pooled_blocks)
                for k in range(len(widths) - 1)
            ]
        )
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(256 * 4 * 4, num_classes))

    def forward(self, x):
        out = self.features(x)
        return self.classifier(out)

In [None]:
name = "skip"
# Residual network: previously this cell trained CNN_dropout by mistake.
cnn_skip = CNN_skip().to(device)

opt = torch.optim.Adam(cnn_skip.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
global_step = 0

n_epochs = 50

train(cnn_skip, train_dl, opt, criterion, writer, n_epochs, name=name, global_step=0)

# Explicit regularization

In [None]:
def l1norm(model):
    """Return the sum of absolute values of all parameters of *model*.

    The result stays in the autograd graph, so it can be used as a loss penalty.
    """
    total = 0
    for param in model.parameters():
        total = total + param.abs().sum()
    return total


def l2norm(model):
    """Return the global L2 norm of all parameters of *model* as a differentiable tensor.

    Computed as the norm of the per-parameter norms, which equals the L2 norm of all
    parameters concatenated. Unlike ``math.sqrt`` (which converts to a Python float
    and drops the graph), this keeps the penalty differentiable. A model without
    parameters yields a zero tensor.
    """
    norms = [p.norm() for p in model.parameters()]
    if not norms:
        return torch.zeros(())
    return torch.stack(norms).norm()


def train_reg(
    model,
    train_dl,
    opt,
    criterion,
    writer,
    n_epochs,
    name="",
    reg="l1",
    global_step=0,
    prints=False,
    l1_coef=1e-4,
    l2_coef=1e-3,
):
    """Train *model* with an explicit weight penalty added to the loss.

    The optimized loss is ``criterion(logits, labels) + penalty``, where penalty is
    ``l1_coef * l1norm(model)`` for ``reg="l1"``, ``l2_coef * l2norm(model)`` for
    ``reg="l2"``, and 0 for any other value. The unregularized loss is logged under
    ``"{name}/train/loss"`` and the full loss under ``"{name}/train/full_loss"``;
    gradient norms are logged every 100 steps. After each epoch the model is
    evaluated on the global ``test_dl`` and the test metrics are logged to *writer*.

    Args:
        model, train_dl, opt, criterion, writer, n_epochs, name, global_step, prints:
            as for the unregularized training loop.
        reg: ``"l1"``, ``"l2"``, or anything else for no penalty.
        l1_coef: weight of the L1 penalty.
        l2_coef: weight of the L2 penalty.

    Returns:
        The global step after the last update.

    Raises:
        ValueError: if *train_dl* yields no batches.
    """
    for epoch in range(n_epochs):
        model.train()
        unreg_loss = loss = None
        for imgs, labels in train_dl:
            global_step += 1
            imgs, labels = imgs.to(device), labels.to(device)
            opt.zero_grad()
            logits = model(imgs)
            unreg_loss = criterion(logits, labels)
            if reg == "l1":
                penalty = l1_coef * l1norm(model)
            elif reg == "l2":
                penalty = l2_coef * l2norm(model)
            else:
                penalty = 0
            loss = unreg_loss + penalty
            # Prefix with the run name so runs sharing one writer get separate curves.
            writer.add_scalar(f"{name}/train/loss", unreg_loss.item(), global_step)
            writer.add_scalar(f"{name}/train/full_loss", loss.item(), global_step)
            loss.backward()
            # Gradients exist only after backward(); log them here, not after zero_grad().
            if global_step % 100 == 0:
                log_grad_norms(model, writer, global_step, model_name=name)
            opt.step()
        if loss is None:
            raise ValueError("train_dl yielded no batches")
        print(f"Epoch {epoch + 1}: unregularized train loss = {unreg_loss.item():.4f}")
        print(f"Epoch {epoch + 1}: full train loss = {loss.item():.4f}")
        if prints:
            print(get_layer_grad_norms(model))
        model.eval()
        test_model(
            model, test_dl, criterion, writer=writer, global_step=global_step, name=name
        )
    return global_step

In [None]:
# Baseline architecture trained with an L1 weight penalty.
name = "ell1_reg"
cnn = CNN_base().to(device)

criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(cnn.parameters(), lr=3e-4)
global_step = 0
n_epochs = 50

train_reg(
    cnn, train_dl, opt, criterion, writer, n_epochs, name=name, reg="l1", global_step=global_step
)

In [None]:
@torch.no_grad()
def plot_weights_histogram(model, bins=100):
    """Plot a histogram of the absolute values of all parameters of *model*.

    Args:
        model: any nn.Module; its parameters may live on CPU or GPU.
        bins: number of histogram bins.
    """
    # .cpu() before .numpy(): numpy conversion fails for CUDA tensors.
    abs_weights = torch.cat([p.detach().reshape(-1).abs().cpu() for p in model.parameters()])
    plt.hist(abs_weights.numpy(), bins=bins)
    plt.show()


# Untrained reference network with PyTorch's default initialization.
cnn_init = CNN_base().to(device)

# Absolute-weight histograms: untrained reference first, then the most recently
# trained `cnn` (the L1-regularized run above).
plot_weights_histogram(cnn_init)
plot_weights_histogram(cnn)

# Data augmentation

In [None]:
from torch.utils.data import Dataset, Subset, DataLoader
import torchvision.transforms as T

# Augmentations must run on raw PIL images, before ToTensor/Normalize: ToTensor
# rejects tensors, and ColorJitter clamps float tensors to [0, 1], which would destroy
# the [-1, 1] range produced by Normalize. The pipeline therefore ends with the same
# ToTensor + Normalize used for the test set, so train and test inputs match.
train_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


# We want to augment only our subset of data for fair comparison
class AugmentedSubset(Dataset):
    """Wrap *subset* and apply *transform* to the input of each sample when it is read.

    Labels are returned unchanged; with ``transform=None`` samples pass through as-is.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, idx):
        x, y = self.subset[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)


# Same CIFAR-10 training images, but returned as PIL images (no dataset-level
# transform) so train_transforms can augment them.
raw_trainset = datasets.CIFAR10(root="./data", train=True, download=True, transform=None)

n_train = 200
train_subset = Subset(trainset, list(range(n_train)))
raw_train_subset = Subset(raw_trainset, list(range(n_train)))  # same 200 indices

augmented_subset = AugmentedSubset(raw_train_subset, transform=train_transforms)
aug_train_dl = DataLoader(augmented_subset, batch_size=32, shuffle=True)


In [None]:
# Baseline architecture (no explicit regularization) trained on the augmented loader.
# Use a distinct run name: "ell1_reg" would merge these curves with the L1 run.
name = "data_aug"
cnn = CNN_base().to(device)

opt = torch.optim.Adam(cnn.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
global_step = 0

n_epochs = 50

train(cnn, aug_train_dl, opt, criterion, writer, n_epochs, name=name, global_step=0)

# Do your own experiments!
Try to achieve the best possible test accuracy on CIFAR10 with these 200 points. 
You can use any of the tricks seen in this TP or others you may know of.

Add more data points to the training set and see how the different tricks scale with more data.

# Class figure:  Learning rate scheduling illustration

In [None]:
import torch
import matplotlib.pyplot as plt

# Number of scheduler steps to simulate. Deliberately not called `T`, which would
# clobber the `torchvision.transforms as T` alias imported earlier in the notebook.
n_steps = 200
base_lr = 1e-3


def make_optimizer():
    """Return a fresh Adam optimizer (lr=base_lr) over a throwaway linear model.

    Only the optimizer's ``param_groups`` matter here: schedulers read and write the
    learning rate there.
    """
    model = torch.nn.Linear(10, 1)
    return torch.optim.Adam(model.parameters(), lr=base_lr)


schedulers = {
    "Constant": torch.optim.lr_scheduler.LambdaLR(make_optimizer(), lambda _: 1.0),
    "StepLR": torch.optim.lr_scheduler.StepLR(
        make_optimizer(), step_size=50, gamma=0.5
    ),
    "ExponentialLR": torch.optim.lr_scheduler.ExponentialLR(
        make_optimizer(), gamma=0.98
    ),
    "CosineAnnealingLR": torch.optim.lr_scheduler.CosineAnnealingLR(
        make_optimizer(), T_max=n_steps
    ),
    "OneCycleLR": torch.optim.lr_scheduler.OneCycleLR(
        make_optimizer(), max_lr=1e-3, total_steps=n_steps
    ),
}

# Record learning rates
lrs = {name: [] for name in schedulers}

for step in range(n_steps):
    for name, sched in schedulers.items():
        lrs[name].append(sched.optimizer.param_groups[0]["lr"])
        # PyTorch expects optimizer.step() before scheduler.step() and warns otherwise.
        # No parameter has a gradient, so this Adam step leaves the weights unchanged.
        sched.optimizer.step()
        sched.step()

plt.figure(figsize=(7, 4))
for name, values in lrs.items():
    plt.plot(values, label=name)
plt.xlabel("Step")
plt.ylabel("Learning rate")
plt.title("Typical Learning Rate Schedulers in PyTorch")
plt.legend()
plt.tight_layout()
plt.show()