TP6 : PAC-Bayes#
Hedi Hadiji, Theoretical principles of Deep Learning, 23/24, CS-UPSACLAY
The goal of this notebook is to illustrate the optimization of PAC-Bayes bounds on a binary classification task built from FashionMNIST (used as a harder substitute for MNIST).
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import copy
import os
The data#
n = 2000 # number of training points that we keep
c1, c2 = 3, 6 # subclasses that are kept, make sure that c1 < c2
batch_size = 32 # training batch size
classes = ("T-shirt/Top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle Boot"
)
# Define a sequence of operations applied to every training image before use:
# 'ToTensor()' converts the image to a tensor with pixel values in [0, 1],
# and 'Normalize' shifts and rescales them to a given mean and standard deviation.
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# Creates an object to load the images. We use FashionMNIST as a
# substitute for MNIST because binary classification on subclasses of
# MNIST is too easy.
trainset = datasets.FashionMNIST('../data', train=True, download=True, transform=transform)
r = torch.arange(len(trainset))
# Build a training set that only contains images with classes c1 and c2,
# with n / 2 images from each label.
if c1 != 0:
trainset.targets[trainset.targets == 0] = -1
trainset.targets[trainset.targets == c1] = 0
if c2 != 1:
trainset.targets[trainset.targets == 1] = -1
trainset.targets[trainset.targets == c2] = 1
idxc1 = (torch.as_tensor(trainset.targets) == 0)
x1 = np.where(np.cumsum(idxc1) == (n / 2))[0][0]
idxc1 = idxc1 & (r <= x1)
idxc2 = (torch.as_tensor(trainset.targets) == 1)
x2 = np.where(np.cumsum(idxc2) == (n / 2))[0][0]
idxc2 = idxc2 & (r <= x2)
idx = idxc1 + idxc2
dset_train = torch.utils.data.dataset.Subset(trainset, np.where(idx==1)[0])
print(f'Number of training points : {len(dset_train)}')
trainloader = torch.utils.data.DataLoader(dset_train, batch_size=batch_size)
# Build the corresponding test set
testset = datasets.FashionMNIST(root='../data', train=False, download=True, transform=transform)
if c1 != 0:
testset.targets[testset.targets == 0] = -1
testset.targets[testset.targets == c1] = 0
if c2 != 1:
testset.targets[testset.targets == 1] = -1
testset.targets[testset.targets == c2] = 1
idx = torch.as_tensor(testset.targets) == 0
idx += torch.as_tensor(testset.targets) == 1
dset_test = torch.utils.data.dataset.Subset(testset, np.where(idx==1)[0])
testloader = torch.utils.data.DataLoader(dset_test)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Using downloaded and verified file: ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Using downloaded and verified file: ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100.0%
Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100.0%
Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw
Number of training points : 2000
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get batch_size random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images))
A neural network#
A fully connected neural network with three hidden layers of width 100. It flattens each image and treats it as a vector.
class MLPDeep(nn.Module):
def __init__(self,):
super().__init__()
self.flatten = nn.Flatten()
self.l1 = nn.Linear(28 * 28, 100)
self.l2 = nn.Linear(100, 100)
self.l3 = nn.Linear(100, 100)
self.lout = nn.Linear(100, 2)
# Store the initial value for the bound optimization
with torch.no_grad():
self.w0 = copy.deepcopy(self.state_dict())
self.total_parameters = sum(param.numel() for param in self.parameters())
def forward(self, x):
x = self.flatten(x)
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
x = F.relu(self.l3(x))
x = self.lout(x)
return x
net = MLPDeep()
print('Number of parameters :', net.total_parameters)
Number of parameters : 98902
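As a quick check, the layer shapes account for all of these parameters: \((784 \times 100 + 100) + (100 \times 100 + 100) + (100 \times 100 + 100) + (100 \times 2 + 2) = 78500 + 10100 + 10100 + 202 = 98902\).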
The training function#
We train the network with the cross-entropy loss to predict the class. The two kept classes are relabelled \(0\) and \(1\).
def train(net,
trainloader,
N_passes=1,
lr=0.01):
optimizer = optim.SGD(net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
losses = []
i = 0
for _ in range(N_passes):
for inputs, labels in trainloader:
i += 1
optimizer.zero_grad()
# print(inputs.shape)
# print(torch.linalg.vector_norm(inputs, dim=(1, 2, 3)))
outputs = net(inputs)
target = labels
loss = criterion(outputs, target)
loss.backward()
optimizer.step()
losses.append(loss.detach().numpy())
print(f'Number of gradient steps {i}')
return losses
print('Batch size : ', batch_size)
losses = train(net,
trainloader,
N_passes=50,
lr=0.01)
os.makedirs('models', exist_ok=True)  # make sure the save directory exists
torch.save(net.state_dict(), 'models/trained_net.pt')
plt.ylim(0, 1.1 * np.max(losses))
plt.plot(losses)
plt.grid(alpha=.3)
Batch size : 32
Number of gradient steps 3150
Checking the training accuracy#
correct = 0
total = 0
with torch.no_grad():
for images, labels in trainloader:
outputs = net(images)
predicted = torch.argmax(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print(f'Accuracy of the network on the {total} train images: {100 * correct // total} %')
Accuracy of the network on the 2000 train images: 96 %
Test accuracy#
correct = 0
total = 0
with torch.no_grad():
for images, labels in testloader:
outputs = net(images)
predicted = torch.argmax(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print(f'Accuracy of the network on the {total} test images: {100 * correct // total} %')
Accuracy of the network on the 2000 test images: 91 %
PAC-Bayes bound optimization#
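The quantity assembled below is, in spirit, the PAC-Bayes-kl bound (often attributed to Langford and Seeger), used here in the style of Dziugaite and Roy: with probability at least \(1 - \delta\) over the sample \(S\) of size \(n\), simultaneously for all posteriors \(Q\),

\[ \mathrm{kl}\big(\hat{L}_S(Q) \,\|\, L_{\mathcal{D}}(Q)\big) \;\le\; \frac{\mathrm{KL}(Q \,\|\, P) + \ln\frac{2\sqrt{n}}{\delta}}{n} . \]

Inverting the binary kl gives \(L_{\mathcal{D}}(Q) \le \mathrm{kl}^{-1}\big(\hat{L}_S(Q), \text{RHS}\big)\), and Pinsker's inequality relaxes this to \(L_{\mathcal{D}}(Q) \le \hat{L}_S(Q) + \sqrt{\text{RHS}/2}\), which is the surrogate term added to the training loss further down. In the code, full_KL_bound implements the right-hand side (with an extra "meta-prior" term accounting for the union over a grid of prior variances), and final_bound performs the kl-inversion, preceded by a second inversion that turns the Monte Carlo estimate of \(\hat{L}_S(Q)\) into a quantity holding with probability \(1 - \delta'\). The exact constants used in the code may differ slightly from this schematic statement.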
def kl(q, p):
    # KL divergence between Bernoulli(q) and Bernoulli(p)
    return q * np.log(q / p) + (1 - q) * np.log((1 - q) / (1 - p))
def klprime(q, p):
    # derivative of kl(q, p) with respect to p
    return -q / p + (1 - q) / (1 - p)
def klinvert(q, c, k=10):
    # Newton iterations for the value p >= q such that kl(q, p) = c (upper inverse),
    # starting from the Pinsker-based guess q + sqrt(c / 2)
    r = q + np.sqrt(c / 2)
    for _ in range(k):
        if r >= 1:
            return 1.
        else:
            r = r - (kl(q, r) - c) / klprime(q, r)
    return r
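A quick sanity check of the inversion (illustrative values, not part of the original experiment): klinvert(q, c) should return a \(p\) with \(\mathrm{kl}(q, p) \approx c\).

q, c = 0.05, 0.1
p = klinvert(q, c)
print(p, kl(q, p))  # kl(q, p) should be close to 0.1, up to Newton precision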
class PACBayesBound(nn.Module):
    '''
    This object keeps track of everything that enters the PAC-Bayes upper bound:
    - the prior parameter (the prior variance sigmaPrior),
    - the posterior parameters (the means w_post and the variances sigma_post).
    '''
def __init__(self, net, nsamples, delta=0.02, deltaprime=0.01):
super().__init__()
### Initialize the posterior means
self.w_post = net
### Initialize the posterior variances
self.sigma_post = copy.deepcopy(net)
norm = sum([torch.sum(param ** 2) for param in self.w_post.parameters()])
# print(norm)
for param in self.sigma_post.parameters():
param.data = torch.sqrt(norm) * torch.ones_like(param.data)
### Initalize the prior
self.lambda_max = torch.tensor([10])
self.alpha = torch.tensor([.5])
self.sigmaPrior = torch.nn.Parameter(torch.tensor([.09]), requires_grad=True)
### Other parameters
self.d = net.total_parameters
self.delta = torch.tensor([delta])
self.deltaprime = torch.tensor([deltaprime])
self.nsamples = nsamples
self.final_bound_value = None
def KL_g(self):
l = self.sigmaPrior
norm, dw, logterm = 0, 0, 0
for param_s, param_w, param_w0 in zip(self.sigma_post.parameters(), self.w_post.parameters(), self.w_post.w0.values()):
# print('w0', param_w0.shape)
# print('w', param_w.shape)
# print('s', param_s.shape)
norm += torch.sum(param_s ** 2) / (2 * l ** 2)
dw += torch.sum((param_w - param_w0) ** 2) / (2 * l ** 2)
logterm += torch.sum(torch.log((l / param_s) ** 2) - 1)/ 2
return norm + dw + logterm
def meta_prior_cost(self):
first = np.pi ** 2 / 6
second = (torch.log((self.sigmaPrior / self.lambda_max) ** 2) / (2 * torch.log(self.alpha)) ) ** 2
# print(second)
return torch.log(first * second)
def full_KL_bound(self):
KL_g = self.KL_g()
conf = torch.log(2 * torch.sqrt(self.nsamples / self.delta))
metaprior = self.meta_prior_cost()
return (KL_g + conf + metaprior) / self.nsamples
def estimate_sample_error(self, dset_train, m):
"""
Samples m networks according to the current posterior, evaluates them on the training data, and returns the average
"""
trainloader = torch.utils.data.DataLoader(dset_train, batch_size=self.nsamples)
corrects = []
with torch.no_grad():
for _ in range(m):
                sample_net = copy.deepcopy(self.w_post)
                for param, param_s, param_w in zip(sample_net.parameters(), self.sigma_post.parameters(), self.w_post.parameters()):
                    # write the sampled weights into the copy (a plain assignment to the
                    # loop variable would leave sample_net equal to the posterior mean)
                    param.copy_(param_w + param_s * torch.randn_like(param))
correct = 0
total = 0
for images, labels in trainloader:
outputs = sample_net(images)
predicted = torch.argmax(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
corrects.append(correct / total)
self.error_estimate = 1 - np.mean(corrects)
print(f'Estimated average posterior error on the training set {self.error_estimate}')
def final_bound(self, m):
'''
TODO: put the real prior from the grid
'''
half_bound = klinvert(self.error_estimate, torch.log(1 / self.deltaprime) / m)
with torch.no_grad():
full_bound = klinvert(half_bound, self.full_KL_bound())
self.final_bound_value = full_bound
success_prob = (1 - self.delta - self.deltaprime).detach()
        print(f'With probability at least {float(success_prob)}, '
              + f'the average loss under the posterior is less than {self.final_bound_value}')
return self.final_bound_value
net = MLPDeep()
net.load_state_dict(torch.load("models/trained_net.pt"))
bound = PACBayesBound(net, n)
def optimize_bound(bound,
dset_train,
N_passes=2,
lr=0.01):
trainloader = torch.utils.data.DataLoader(dset_train, batch_size=bound.nsamples)
all_parameters = (list(bound.w_post.parameters())
+ list(bound.sigma_post.parameters())
+ [bound.sigmaPrior]
)
optimizer = optim.SGD(all_parameters, lr=lr)
criterion = nn.CrossEntropyLoss()
losses = []
i = 0
for _ in range(N_passes):
for inputs, labels in trainloader:
i += 1
optimizer.zero_grad()
perturbed_net = copy.deepcopy(bound.w_post)
            for param, param_s, param_w in zip(perturbed_net.parameters(), bound.sigma_post.parameters(), bound.w_post.parameters()):
                # write the sampled weights into the copy (a plain assignment to the loop
                # variable would be a no-op); the sampling is not tracked by autograd, so
                # gradients reach the posterior parameters only through the KL term
                with torch.no_grad():
                    param.copy_(param_w + param_s * torch.randn_like(param))
outputs = perturbed_net(inputs)
target = labels
loss = criterion(outputs, target) + torch.sqrt(bound.full_KL_bound() / 2)
loss.backward()
optimizer.step()
losses.append(loss.detach().numpy())
print(f'Number of gradient steps {i}')
print(f'First optimized bound value {losses[0]}')
print(f'Last optimized bound value {losses[-1]}')
return losses
losses = optimize_bound(bound,
dset_train,
N_passes=100,
lr=.1)
plt.ylim(0, 1.1 * np.max(losses))
plt.plot(losses)
plt.grid(alpha=.3)
Number of gradient steps 100
First optimized bound value [0.12836254]
Last optimized bound value [0.12831238]
m = 200
bound.estimate_sample_error(dset_train, m)
Estimated average posterior error on the training set 0.03800004720687866
final_bound = bound.final_bound(1000000) # cheat a bit by faking a large m to see if there is a chance that the bound is good
print(float(bound.sigmaPrior.detach()))
With probability at least 0.9900000095367432, the average loss under the posterior is less than tensor([0.0485])
10.426172256469727
Interesting questions:

- evaluate the posterior on the test set (a possible sketch is given below)
- find ways to examine and interpret the prior and posterior
- investigate the flat-minimum hypothesis
- repeat the experiment with random labels
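For the first question, a minimal sketch (the helper posterior_test_accuracy below is hypothetical, not part of the notebook; it reuses the bound object and dset_test defined above):

def posterior_test_accuracy(bound, dset_test, m=20):
    # sample m networks from the posterior and average their test accuracy
    loader = torch.utils.data.DataLoader(dset_test, batch_size=256)
    accs = []
    with torch.no_grad():
        for _ in range(m):
            sample_net = copy.deepcopy(bound.w_post)
            for param, param_s, param_w in zip(sample_net.parameters(),
                                               bound.sigma_post.parameters(),
                                               bound.w_post.parameters()):
                param.copy_(param_w + param_s * torch.randn_like(param))
            correct, total = 0, 0
            for images, labels in loader:
                predicted = torch.argmax(sample_net(images), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            accs.append(correct / total)
    return np.mean(accs)

# posterior_test_accuracy(bound, dset_test)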
Pinsker’s inequality, etc.#
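For reference, the comparison plotted below illustrates Pinsker's inequality specialized to Bernoulli distributions:

\[ \mathrm{kl}(q \,\|\, p) \;\ge\; 2\,(q - p)^2 . \]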
def kl(q, p):
return q * np.log(q / p) + (1 - q) * np.log((1 - q) / (1 - p))
def dtv(q, p):
return 2 * (q - p)**2
q = .1
ps= np.linspace(0, .5, 1000)
klvals = [kl(q, p) for p in ps]
tvvals = [dtv(q, p) for p in ps]
plt.plot(ps, klvals, label=rf'KL(q, $\cdot$), $q = {q}$')
plt.plot(ps, tvvals, label=rf'$2(q - \cdot)^2$, $q = {q}$')
plt.legend()
plt.show()
def klprime(q, p):
return -q / p + (1 - q) / (1 - p)
def klinvert(q, c, k=10):
r = q + np.sqrt(c / 2)
for _ in range(k):
if r >= 1:
return 1
else:
r = r - (kl(q, r) - c) / klprime(q,r)
return r
### Testing the inversion
q = .499
cs = np.linspace(.001, .5)
kls = []
us = []
# for c in cs:
# p = klinvert(q, c)
# kls.append(kl(q , p))
# plt.title(f"$kl(q, kl^{{-1}}(q, c)) = c$")
# plt.plot(cs, kls - cs)
# plt.show()
# qs = np.linspace(0, 1, 1000)
# for c in [.01, .1, .2]:
# plt.plot(qs, [klinvert(q, c) for q in qs], label=fr'$kl^{{-1}}(\cdot, {c})$ ')
# plt.legend()
# plt.show()
cs = np.linspace(0, 9, 1000)
for q in [.01, .1, .5, .9]:
plt.plot(cs, [klinvert(q, c) for c in cs], label=fr'$kl^{{-1}}({q}, \cdot)$ ')
plt.legend()
plt.show()