TP5 : Neural Tangent Kernel#

import torch

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
import torchvision

import numpy as np
import matplotlib.pyplot as plt

from IPython.display import clear_output

clear_output()

Today we study attempt to compare the behavior of neural nets and their NTK on some simple example.

The data#

To do so we are going to use the FashionMNIST dataset to build a small binary classification task, train some neural nets on those tasks, and build the corresponding NTK.

n = 200  # number of training points that we keep
c1, c2 = 1, 3  # subclasses that are kept, make sure that c1 < c2

batch_size = 32  # training batch size

classes = (
    "T-shirt/Top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
)

# Define a sequence of operations that will be performed to all training images before use
# the 'ToTensor()' function sets the image in tensor object and puts the values of every pixel between 0 and 1
# the 'Normalize' performs the dataset to a given mean and variance
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Creates an object to load the images. We use FashionMNIST as a substitute for MNIST because binary classification on
# subclasses of MNIST is too easy.

trainset = datasets.FashionMNIST(
    "../data", train=True, download=True, transform=transform
)

r = torch.arange(len(trainset))

# Build a training set that only contains images with classes c1 and c2, with n / 2 images from each label.

idxc1 = torch.as_tensor(trainset.targets) == c1
x1 = np.where(np.cumsum(idxc1) == (n / 2))[0][0]
idxc1 = idxc1 & (r <= x1)

idxc2 = torch.as_tensor(trainset.targets) == c2
x2 = np.where(np.cumsum(idxc2) == (n / 2))[0][0]
idxc2 = (torch.as_tensor(trainset.targets) == c2) & (r <= x2)

idx = idxc1 + idxc2
dset_train = torch.utils.data.dataset.Subset(trainset, np.where(idx == 1)[0])

print(f"Number of training points : {len(dset_train)}")
trainloader = torch.utils.data.DataLoader(dset_train, batch_size=batch_size)

# Build the corresponding test set

testset = datasets.FashionMNIST(
    root="./data", train=False, download=True, transform=transform
)
clear_output()
idx = torch.as_tensor(testset.targets) == c1
idx += torch.as_tensor(testset.targets) == c2
dset_test = torch.utils.data.dataset.Subset(testset, np.where(idx == 1)[0])
testloader = torch.utils.data.DataLoader(dset_test)
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get batch_size random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(" ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)))
../_images/5f590d535cbda0c4f5b09ea4d10512600577ea8944596543f8cef021cb5d05b9.png
Dress Trouser Dress Trouser Dress Dress Trouser Dress Dress Dress Dress Dress Dress Trouser Dress Trouser Dress Trouser Trouser Trouser Dress Trouser Dress Dress Trouser Trouser Trouser Trouser Trouser Dress Trouser Trouser

A neural network#

Fully connected 1-hidden layer neural network. Flattens the image and treats it as a vector.

Use the methods

nn.Flatten nn.Linear

nn.init.xavier_normal_ nn.init.zeros_

to build the network

class MLPshallow(nn.Module):
    def __init__(self, hidden_dim=10):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(28 * 28, self.hidden_dim)
        self.lout = nn.Linear(self.hidden_dim, 1)

        nn.init.xavier_normal_(self.l1.weight)
        nn.init.xavier_normal_(self.lout.weight)

        nn.init.zeros_(self.l1.bias)
        nn.init.zeros_(self.lout.bias)

    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.l1(x))
        x = self.lout(x)
        return x


net = MLPshallow(hidden_dim=2000)
class MLPshallow(nn.Module):
    def __init__(self, hidden_dim=10):
        super().__init__()
        self.hidden_dim = hidden_dim
        # YOUR CODE HERE

    def forward(self, x):
        # YOUR CODE HERE
        return x


net = MLPshallow(hidden_dim=1000)

The training function#

We train the network with the square loss to predict the class. The labels are normalized to be between \(0\) and \(1\).

i.e. the y-value for images with label i is float(i/9)

Write the training function.

def train(net, trainloader, N_passes=1, lr=0.01):
    optimizer = optim.SGD(net.parameters(), lr=lr)
    criterion = nn.MSELoss()

    losses = []
    i = 0

    for _ in range(N_passes):
        for inputs, labels in trainloader:
            i += 1
            optimizer.zero_grad()

            # print(inputs.shape)
            # print(torch.linalg.vector_norm(inputs, dim=(1, 2, 3)))

            outputs = net(inputs)
            target = labels.float().unsqueeze(1) / 9
            loss = criterion(outputs, target)

            loss.backward()
            optimizer.step()

            losses.append(loss.detach().numpy())

    print(f"Number of gradient steps {i}")

    return losses
def train(net, trainloader, N_passes=1, lr=0.01):
    optimizer = optim.SGD(net.parameters(), lr=lr)
    criterion = nn.MSELoss()

    losses = []
    i = 0

    for _ in range(N_passes):
        for inputs, labels in trainloader:
            i += 1

            # YOUR CODE HERE

            losses.append(loss.detach().numpy())

    print(f"Number of gradient steps {i}")

    return losses
print("Batch size : ", batch_size)

losses = train(net, trainloader, N_passes=200, lr=0.01)

plt.ylim(0, 1.1 * np.max(losses))
plt.plot(losses)
plt.grid(alpha=0.3)
Batch size :  32
Number of gradient steps 1400
../_images/cf75a3596a29d1141bd8906de4afcc97c4f7d31276be9ca86f69f59a1da0d9f1.png

Checking the training values#

Complete the following function to get the prediction of the network on the data.

Train the network until you interpolate the data.

correct = 0
total = 0
with torch.no_grad():
    for images, labels in trainloader:
        outputs = net(images)
        predicted = c1 + (c2 - c1) * (
            9 * outputs > (c1 + c2) / 2
        )  # c1 if prediction is less than average,

        total += labels.size(0)
        correct += (predicted.squeeze(1) == labels).sum()

print(
    f"Accuracy of the network on the {total} train images: {100 * correct // total} %"
)
Accuracy of the network on the 200 train images: 100 %
correct = 0
total = 0
with torch.no_grad():
    for images, labels in trainloader:

        outputs = net(images)
        predicted = #YOUR CODE HERE

        total += labels.size(0)
        correct += (predicted.squeeze(1) == labels).sum()

print(f'Accuracy of the network on the {total} train images: {100 * correct // total} %')
testloader = torch.utils.data.DataLoader(dset_test, batch_size=64)
dataiter = iter(testloader)
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images[32:40]))
print("GrndTruth: ", " ".join(f"{classes[labels[j]]:5s}" for j in range(8)))

with torch.no_grad():
    outputs = net(images)
    predicted = c1 + (c2 - c1) * (9 * outputs > (c1 + c2) / 2)

    print("Predicted: ", " ".join(f"{classes[predicted[j]]:5s}" for j in range(8)))
../_images/2f12c0ad6230e964a39f5890dfc469e31ee49ccb041399797fd6eaaa49d8f0c7.png
GrndTruth:  Trouser Trouser Trouser Dress Trouser Trouser Dress Dress
Predicted:  Trouser Trouser Trouser Dress Trouser Trouser Dress Dress
testloader = torch.utils.data.DataLoader(dset_test, batch_size=8)
dataiter = iter(testloader)
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images[:8]))
print('GrndTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(8)))

with torch.no_grad():
    outputs = net(images)
    predicted = #YOUR CODE HERE

    print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(8)))
  Cell In[17], line 11
    predicted = #YOUR CODE HERE
                ^
SyntaxError: invalid syntax

Test loss#

correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        predicted = c1 + (c2 - c1) * (
            9 * outputs > (c1 + c2) / 2
        )  # c1 if prediction is less than average,

        total += labels.size(0)
        correct += (predicted.squeeze(1) == labels).sum()

print(f"Accuracy of the network on the {total} test images: {100 * correct // total} %")
Accuracy of the network on the 2000 test images: 92 %
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        predicted = #YOUR CODE HERE

        total += labels.size(0)
        correct += (predicted.squeeze(1) == labels).sum()

print(f'Accuracy of the network on the {total} test images: {100 * correct // total} %')

The NTK#

Let us write the corresponding NTK in the infinite width limit. Use the formula seen in class for the NTK kernel to complete the code below.

class NTK:
    def __init__(self, dset_train):
        self.n = len(dset_train)
        self.train_set = dset_train
        self.train()

    def k(self, x, xprime):
        """
        NTK Kernel for ReLU with one hidden layer. The delta factor
        is to avoid nan values for arccos(1+eps) from rounding.
        """
        with torch.no_grad():
            v = torch.linalg.norm(x) * torch.linalg.norm(xprime)
            u = 0.99999 * torch.dot(x, xprime) / v
            return v * (
                u * (torch.pi - torch.arccos(u) + torch.sqrt(1 - u**2)) / (2 * np.pi)
                + u * (torch.pi - torch.arccos(u)) / (2 * np.pi)
            )

    def train(self):
        ntrainloader = torch.utils.data.DataLoader(self.train_set, batch_size=self.n)
        dataiter = iter(ntrainloader)
        images, labels = next(dataiter)
        print(images.flatten(start_dim=1, end_dim=3).shape)

        xis = images.flatten(start_dim=1, end_dim=3)
        print(xis.shape)

        self.xis = xis

        # plt.imshow(xis)
        # plt.title('Flattened images')
        # plt.show()

        H = torch.empty((n, n))

        for i in range(n):
            for j in range(n):
                H[i, j] = self.k(xis[i], xis[j])

        # plt.title(f'Gram matrix for {n} x {n} inputs')
        # plt.imshow(H)
        # plt.colorbar()
        # plt.show()

        eigenvalues = np.linalg.eigvalsh(H)
        print(f"Smallest eigenvalue : {eigenvalues[0]}")
        # plt.title('Sorted eigenvalues of the Gram matrix')
        # plt.plot(eigenvalues)
        # plt.grid(alpha=0.3)
        # plt.ylim(0, 1.01 * np.max(eigenvalues))
        # plt.show()

        self.H = H
        self.Hinv = torch.linalg.inv(H)
        self.V = self.Hinv @ labels.float() / 9

        print(self.V.shape)

        # plt.imshow(V)
        # plt.colorbar()
        # plt.show()

    def apply(self, x):
        s = 0
        for i, xi in enumerate(self.xis):
            s += self.k(x, xi) * self.V[i]  # V = ((Phi^\top Phi)^{-1} Y)
        return s

    def gen_bound(self):
        ntrainloader = torch.utils.data.DataLoader(self.train_set, batch_size=n)
        dataiter = iter(ntrainloader)
        _, labels = next(dataiter)

        return torch.dot(labels.float(), self.Hinv @ labels.float())
class NTK:
    def __init__(self, dset_train):
        self.n = len(dset_train)
        self.train_set = dset_train
        self.train()

    def k(self, x, xprime):
        """
        NTK Kernel for ReLU with one hidden layer. T
        """
        delta = 0.999999
        with torch.no_grad():
            v = torch.linalg.norm(x) * torch.linalg.norm(xprime)
            u = delta * torch.dot(x, xprime) / v
            return  # YOUR CODE HERE

    def train(self):
        ntrainloader = torch.utils.data.DataLoader(self.train_set, batch_size=self.n)
        dataiter = iter(ntrainloader)
        images, labels = next(dataiter)
        print(images.flatten(start_dim=1, end_dim=3).shape)

        xis = images.flatten(start_dim=1, end_dim=3)
        print(xis.shape)

        self.xis = xis

        # plt.imshow(xis)
        # plt.title('Flattened images')
        # plt.show()

        H = torch.empty((n, n))

        for i in range(n):
            for j in range(n):
                H[i, j] = self.k(xis[i], xis[j])

        # plt.title(f'Gram matrix for {n} x {n} inputs')
        # plt.imshow(H)
        # plt.colorbar()
        # plt.show()

        eigenvalues = np.linalg.eigvalsh(H)
        print(f"Smallest eigenvalue : {eigenvalues[0]}")
        # plt.title('Sorted eigenvalues of the Gram matrix')
        # plt.plot(eigenvalues)
        # plt.grid(alpha=0.3)
        # plt.ylim(0, 1.01 * np.max(eigenvalues))
        # plt.show()

        self.H = H
        self.Hinv = torch.linalg.inv(H)
        self.V = self.Hinv @ labels.float() / 9

        print(self.V.shape)

        # plt.imshow(V)
        # plt.colorbar()
        # plt.show()

    def apply(self, x):
        s = 0
        for i, xi in enumerate(self.xis):
            s += self.k(x, xi) * self.V[i]
        return s
ntk = NTK(dset_train)
torch.Size([200, 784])
torch.Size([200, 784])
Smallest eigenvalue : 13.815045356750488
torch.Size([200])
ntk_trainloader = torch.utils.data.DataLoader(dset_train, batch_size=1)

correct_ntk = 0
total = 0
with torch.no_grad():
    for images, labels in ntk_trainloader:
        x = torch.flatten(images.squeeze(0))

        output_ntk = ntk.apply(x)
        predicted_ntk = c1 + (c2 - c1) * (
            9 * output_ntk > (c1 + c2) / 2
        )  # c1 if prediction is less than average,
        correct_ntk += (predicted_ntk.unsqueeze(0) == labels).numpy()[0]

        total += 1

print(f"Accuracy of NTK on the {total} train images: {100 * correct_ntk // total} %")
Accuracy of NTK on the 200 train images: 100 %

Let us now compare the outputs of the NTK and of the network.

What do you observe? Why?

testloader = torch.utils.data.DataLoader(dset_test)

total = 0
correct_ntk = 0
correct_net = 0
agree = 0

with torch.no_grad():
    for images, labels in testloader:
        x = torch.flatten(images.squeeze(0))

        output_ntk = ntk.apply(x)
        predicted_ntk = c1 + (c2 - c1) * (
            9 * output_ntk > (c1 + c2) / 2
        )  # c1 if prediction is less than average,
        correct_ntk += (predicted_ntk.unsqueeze(0) == labels).numpy()[0]

        outputs = net(images)
        predicted = c1 + (c2 - c1) * (
            9 * outputs > (c1 + c2) / 2
        )  # c1 if prediction is less than average,
        correct_net += (predicted.squeeze(1) == labels).numpy()[0]

        # print(labels.shape)
        # print(predicted_ntk.unsqueeze(0).shape, predicted_ntk.unsqueeze(0))
        # print(predicted.squeeze(1).shape, predicted.squeeze(1))

        agree += (predicted_ntk.unsqueeze(0) == predicted.squeeze(1)).numpy()[0]
        total += 1

print(f"Accuracy of NTK on the {total} test images: {100 * correct_ntk // total} %")
print(f"Accuracy of net on the {total} test images: {100 * correct_net // total} %")

print(f"Both methods agree on {100 * agree // total} % of the images")
Accuracy of NTK on the 2000 test images: 96 %
Accuracy of net on the 2000 test images: 92 %
Both methods agree on 94 % of the images

Generalization error#

trainset = datasets.FashionMNIST(
    "../data", train=True, download=True, transform=transform
)
r = torch.arange(len(trainset))

bounds = []

n_values = range(100, 2000, 300)

for n in n_values:
    # Build a training set that only contains images with classes c1 and c2, with n / 2 images from each label.

    idxc1 = torch.as_tensor(trainset.targets) == c1
    x1 = np.where(np.cumsum(idxc1) == (n / 2))[0][0]
    idxc1 = idxc1 & (r <= x1)

    idxc2 = torch.as_tensor(trainset.targets) == c2
    x2 = np.where(np.cumsum(idxc2) == (n / 2))[0][0]
    idxc2 = (torch.as_tensor(trainset.targets) == c2) & (r <= x2)

    idx = idxc1 + idxc2
    dset_train = torch.utils.data.dataset.Subset(trainset, np.where(idx == 1)[0])

    ntk = NTK(dset_train)
    bounds.append(ntk.gen_bound())
torch.Size([100, 784])
torch.Size([100, 784])
Smallest eigenvalue : 23.512008666992188
torch.Size([100])
torch.Size([400, 784])
torch.Size([400, 784])
Smallest eigenvalue : 15.195989608764648
torch.Size([400])
torch.Size([700, 784])
torch.Size([700, 784])
Smallest eigenvalue : 13.155244827270508
torch.Size([700])
torch.Size([1000, 784])
torch.Size([1000, 784])
Smallest eigenvalue : 5.420177936553955
torch.Size([1000])
torch.Size([1300, 784])
torch.Size([1300, 784])
Smallest eigenvalue : 5.3553948402404785
torch.Size([1300])
torch.Size([1600, 784])
torch.Size([1600, 784])
Smallest eigenvalue : 5.308578968048096
torch.Size([1600])
torch.Size([1900, 784])
torch.Size([1900, 784])
Smallest eigenvalue : 5.274092197418213
torch.Size([1900])
plt.title("Generalization bound as a function of n")
plt.plot(n_values, np.sqrt(bounds / np.array(n_values)))
[<matplotlib.lines.Line2D at 0x147f2fd30>]
../_images/82134a02e95b78b2009fa5f1d535b4dc39e66f51773db4f74200ca3906dc022b.png

The NTK Gram matrix is typically invertible#

if number of data points is smaller than input dim

n = 30
d = 100

xis = [torch.randn(d) for _ in range(n)]

# print(xis)
# print(np.dot(xis[n-1], xis[n-1]))

H = torch.empty((n, n))


def k(x, xprime):
    with torch.no_grad():
        v = torch.linalg.norm(x) * torch.linalg.norm(xprime)
        u = 0.99999 * torch.dot(x, xprime) / v
        return v * (
            u * (torch.pi - torch.arccos(u) + torch.sqrt(1 - u**2)) / (2 * np.pi)
            + u * (torch.pi - torch.arccos(u)) / (2 * np.pi)
        )


for i in range(n):
    for j in range(n):
        H[i, j] = k(xis[i], xis[j])


plt.title(f"Gram matrix for {n} x {n} inputs of dim {d}")
plt.imshow(H)
plt.colorbar()

plt.show()

plt.title("Sorted eigenvalues of the Gram matrix")
eigvals = np.linalg.eigvalsh(H)
plt.plot(eigvals)
plt.grid(alpha=0.3)
plt.ylim(0, 1.01 * np.max(eigvals))
plt.show()
../_images/48a23e3b4a53d42bd1390858bd195f3e4ccf0d3eb45b2400bf5f083cc3ecaef1.png ../_images/707994b308e8f7d6a1177effe12122a92e3816a85bed463c630882b8caee177f.png
# mlp = MLPdeep() # MLPshallow()
# PATH = f'./{dataset.type}_{mlp.net_type}.pth'
# torch.save(mlp.state_dict(), PATH)

# net = MLPdeep() # MLPshallow()
# net.load_state_dict(torch.load(PATH))

Overparameterized linear regression#

p = 1000
n = 100

X = np.random.normal(size=(p, n)) / np.sqrt(n)
Y = np.random.normal(size=(n, 1))

XXt = X @ X.transpose()
XY = X @ Y

fig, axs = plt.subplots(1, 2)

axs[0].set_title(r"$X X^\top $")
axs[0].imshow(XXt)

eigvals, P = np.linalg.eigh(XXt)
print(rf"Largest eigenvalue of of XX^\top : {max(eigvals)}")

axs[1].set_title("Diagonalized")
axs[1].imshow(P.transpose() @ XXt @ P)
plt.show()

print(np.linalg.matrix_rank(X))
Largest eigenvalue of of XX^\top 16.963392764546683
../_images/35d210d9e43030e3c1119b92f190f38d34810b9b43e4e5283278786fccba3dd5.png
100
T = 1000
eta = 0.001

ws = np.zeros((T, p, 1))
w = np.random.normal(size=(p, 1)) / np.sqrt(p)

losses = []
for t in range(T):
    w = w - eta * XXt @ w + eta * XY
    ws[t] = w

    losses.append(np.linalg.norm(X.transpose() @ w - Y) ** 2 / n)

ws = ws.squeeze(axis=2)
fig, axs = plt.subplots(1, 2, figsize=(16, 4))

axs[0].set_title("Loss")
axs[0].plot(losses)


axs[1].set_title("Relative distance of weights to init")
weight_evol = np.linalg.norm(ws - ws[0, :], axis=1) / np.linalg.norm(ws[0, :])
axs[1].plot(weight_evol, alpha=0.5)


plt.show()
../_images/c0e0ef62728da835f1c30d55580bb9866c947e9023bf49123c6e2e0027cfa0cb.png

In the right basis, most coordinates of w dont move#

w only moves in a low dimensional subspace: the image of X

fig, ax = plt.subplots(figsize=(40, 3))

plt.title(r"$|w_{t, i} - w_{0, i}|$")
ws_in_base = ws @ P

a = ax.imshow(np.abs((ws_in_base - ws_in_base[0]).transpose()))
fig.colorbar(a)
plt.tight_layout()
plt.show()
../_images/9e7d2dcf2272b73ae690d6d6e23b25695626de2ca1466a9195473dbca59b030a.png