TP3: Lazy training#
Observing the lazy training regime with different models
With inspiration from
https://rajatvd.github.io/NTK/
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
### Executes the script tp3_utils.py, which contains code from last time
## Go take a quick look at the functions implemented there
from tp3_utils import *
# %load_ext autoreload
# %autoreload 2
The data#
We continue with last week's toy problem of 1-dimensional regression.
xs = np.linspace(-1, 1, 100)
plt.plot(xs, simple_f(xs), label='Simple f')
plt.plot(xs, middle_f(xs), label='Middle f')
plt.plot(xs, complex_f(xs), label='Complicated f')
plt.title('Target functions')
plt.legend()
plt.show()
dataset = Data(n=15,
               xmin=-1,
               xmax=1,
               noise_level=0,
               type='middle')
plot_data(dataset)
plt.legend()
plt.show()


The training function#
Write the training function
def train(net,
          dataset,
          N_steps=1,
          batch_size=30,
          lr=0.01,
          save_weights_every=100):
    # Plain SGD on the mean squared error
    optimizer = optim.SGD(net.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses = []
    for i in range(N_steps):
        inputs, labels = dataset.next_batch(batch_size)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().numpy())
        # Periodically store the current weights in net.weight_history
        # (disabled when save_weights_every is negative)
        if save_weights_every > 0 and i % save_weights_every == 0:
            with torch.no_grad():
                net.save_weights()
    return losses
def train(net,
          dataset,
          N_steps=1,
          batch_size=30,
          lr=0.01,
          save_weights_every=100):
    # YOUR CODE HERE
    losses = []
    for i in range(N_steps):
        # YOUR CODE HERE
        losses.append(loss.detach().numpy())
        if save_weights_every > 0 and i % save_weights_every == 0:
            with torch.no_grad():
                net.save_weights()
    return losses
Test your training#
model = MLPdeep(hidden_dim=10)
losses = train(model,
               dataset,
               N_steps=10000,
               batch_size=15,
               lr=.1,
               save_weights_every=1,
               )
plot_net(model)
plot_data(dataset)
plt.show()
plt.plot(losses)
plt.show()
A linear model#
We define a family of linear models with sinusoidal features. The feature map is \( \phi(x) = \big( \sin(2\pi k x),\ \cos(2\pi k x) \big)_{0 \le k < K} \),
where \(K\) is feature_dim / 2.
class LinSin(nn.Module):
    """
    A linear model with sinusoidal features. Assumes feature_dim is even.
    """
    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.lout = nn.Linear(self.feature_dim, 1)
        self.weight_history = []
    def phi(self, x):
        """
        Computes $\sin(2\pi k x), \cos(2\pi k x)$ for k in $[feature_dim / 2]$
        """
        r = 2 * np.pi * torch.arange(int(self.feature_dim / 2)) * x
        return torch.concatenate([torch.sin(r), torch.cos(r)], axis=1)
    def save_weights(self):
        with torch.no_grad():
            l = list(self.parameters())
            self.weight_history.append(np.copy(l[-1].numpy()))
    def forward(self, x):
        return self.lout(self.phi(x))
class LinSin(nn.Module):
    """
    A linear model with sinusoidal features
    """
    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.lout = nn.Linear(self.feature_dim, 1)
        self.weight_history = []
    def phi(self, x):
        """
        Computes $\sin(2\pi k x), \cos(2\pi k x)$ for k in $[feature_dim / 2]$
        """
        # YOUR CODE HERE
        pass
    def save_weights(self):
        with torch.no_grad():
            l = list(self.parameters())
            self.weight_history.append(np.copy(l[-1].numpy()))
    def forward(self, x):
        # YOUR CODE HERE
        pass
Test the linear model
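Before training, a quick sanity check of the feature map (a sketch; x_check is just an illustrative batch of shape (batch, 1), as produced by the dataset):
x_check = torch.linspace(-1, 1, 5).reshape(-1, 1)  # 5 inputs in [-1, 1]
print(LinSin(feature_dim=8).phi(x_check).shape)    # expected: torch.Size([5, 8])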
model = LinSin(feature_dim=8)
losses = train(model,
               dataset,
               N_steps=10000,
               batch_size=15,
               lr=.01,
               save_weights_every=1,
               )
plot_net(model)
plot_data(dataset)
plt.show()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].loglog(losses)
axs[0].set_title('Train loss')
all_weights1 = np.array(model.weight_history)
weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])
axs[1].plot(weight_evol, alpha=.5)
axs[1].set_title('Relative weight movement of last layer')


Question: what do the plots mean?
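For reference, the quantity shown in the right panel (here and in the experiments below) is the relative movement of the saved weights, \( \lVert w_t - w_0 \rVert_2 / \lVert w_0 \rVert_2 \), where \(w_t\) denotes the parameters stored by save_weights at step \(t\).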
Let us now repeat the code above with a large feature dimension. What do you observe?
model = LinSin(feature_dim=200)
losses = train(model,
               dataset,
               N_steps=100,
               batch_size=15,
               lr=.01,
               save_weights_every=1,
               )
plot_net(model)
plot_data(dataset)
plt.title('Trained model')
plt.show()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].loglog(losses)
axs[0].set_title('Train loss')
all_weights1 = np.array(model.weight_history)
weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])
axs[1].plot(weight_evol, alpha=.5)
axs[1].set_title('Relative weight movement of last layer')
plt.show()


Robustness of fast convergence for linear models#
Run 10 instances of the linear model with different initializations.
Are there significant differences between the models?
model_list = []
N_models = 10
for _ in range(N_models):
    model = LinSin(feature_dim=100)
    model_list.append(model)
    plot_net(model)
plot_data(dataset)
plt.title('At initialization')
all_losses = [[] for _ in range(N_models)]
all_weights = []

N_steps = 100000
for i, model in enumerate(model_list):
    losses = train(model,
                   dataset,
                   N_steps=N_steps,
                   batch_size=30,
                   lr=.01,
                   save_weights_every=1,
                   )
    all_losses[i] += losses
    plot_net(model)
plot_data(dataset)
plt.title('After some training')
plt.show()

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].set_title('Train loss')
for losses in all_losses:
    points = np.arange(0, len(losses), 1)
    losses = np.array(losses)
    axs[0].loglog(points, losses[points], alpha=.5)
axs[1].set_title('Relative movement of weights')
for model in model_list:
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])
    axs[1].plot(weight_evol, alpha=.5)
plt.show()

Lazy training of non-linear models#
We now train some neural nets in different regimes and try to see when lazy training occurs.
Question: How are nets initialized in PyTorch?
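A quick way to check empirically (a minimal sketch, assuming MLPshallow and MLPdeep are built from standard nn.Linear layers; PyTorch's default for nn.Linear draws weights and biases uniformly in \( \pm 1/\sqrt{\mathrm{fan\_in}} \)):
layer = nn.Linear(100, 100)  # fan_in = 100, so the default range is about +/- 0.1
with torch.no_grad():
    print(layer.weight.min().item(), layer.weight.max().item())  # close to -0.1 and 0.1
    print(layer.weight.std().item())  # close to 0.2 / sqrt(12), i.e. about 0.058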
model_list = []
N_models = 10
for _ in range(N_models):
    model = MLPshallow(hidden_dim=10)
    model_list.append(model)
    plot_net(model)
plot_data(dataset)
plt.title('At initialization')
all_losses = [[] for _ in range(N_models)]
all_weights = []

N_steps = 5000
for i, model in enumerate(model_list):
    losses = train(model,
                   dataset,
                   N_steps=N_steps,
                   batch_size=30,
                   lr=.001,
                   save_weights_every=1,
                   )
    all_losses[i] += losses
    plot_net(model)
plot_data(dataset)
plt.title('After some training')
plt.show()

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].set_title('Train loss')
for losses in all_losses:
    points = np.arange(0, len(losses), 1)
    losses = np.array(losses)
    axs[0].loglog(points, losses[points], alpha=.5)
axs[1].set_title('Relative weight movement of last layer')
for model in model_list:
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])
    axs[1].plot(weight_evol, alpha=.5)
plt.show()
Thin vs. Wide networks#
Let us now vary the width of our network. What happens?
model_list = []
width_list = [5, 10, 100, 1000]
for width in width_list:
    model = MLPshallow(hidden_dim=width)
    model_list.append(model)
    plot_net(model, label=f'Width = {model.hidden_dim}')
plot_data(dataset)
plt.title('At initialization')
plt.legend()
all_losses = [[] for _ in width_list]
all_weights = []

N_steps = 5000
for i, model in enumerate(model_list):
    losses = train(model,
                   dataset,
                   N_steps=N_steps,
                   batch_size=30,
                   lr=.001,
                   save_weights_every=1,
                   )
    all_losses[i] += losses
    plot_net(model, label=f'Width = {model.hidden_dim}')
plot_data(dataset)
plt.title('After some training')
plt.show()

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].set_title('Train loss')
for i, losses in enumerate(all_losses):
    points = np.arange(0, len(losses), 1)
    losses = np.array(losses)
    axs[0].loglog(points, losses[points], alpha=.5, label=f'{width_list[i]}')
axs[1].set_title('Relative weight movement of last layer')
for model in model_list:
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])
    axs[1].plot(weight_evol, alpha=.5, label=f'{model.hidden_dim}')
plt.legend()
plt.show()

Bad initialization can get you out of the lazy regime#
We simulate bad initialization by taking two steps of gradient descent from initialization.
model_list = []
N_models = 10
for _ in range(N_models):
    model = MLPshallow(hidden_dim=100)
    model_list.append(model)
    train(model,
          dataset,
          N_steps=2,
          batch_size=30,
          lr=1,
          save_weights_every=-1,
          )
    plot_net(model)
plot_data(dataset)
plt.title('At bad initialization')
all_losses = [[] for _ in range(N_models)]
all_weights = []
N_steps = 1000
xmin = np.min(dataset.inputs.detach().numpy())
xmax = np.max(dataset.inputs.detach().numpy())
plt.xlim(xmin, xmax)
print(xmin, xmax)
for i, model in enumerate(model_list):
    losses = train(model,
                   dataset,
                   N_steps=N_steps,
                   batch_size=30,
                   lr=.001,
                   save_weights_every=1,
                   )
    all_losses[i] += losses
    plot_net(model, xmin=xmin, xmax=xmax)
plot_data(dataset)
plt.title('After some training')
plt.show()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].set_title('Train loss')
for i, losses in enumerate(all_losses):
    points = np.arange(0, len(losses), 1)
    losses = np.array(losses)
    axs[0].loglog(points, losses[points], alpha=.5)
axs[1].set_title('Relative weight movement of last layer')
for model in model_list:
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])
    axs[1].plot(weight_evol, alpha=.5)
plt.show()
Scaling#
Now we attempt to enter the lazy regime by scaling the model outputs.
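The idea, in formulas (a heuristic sketch): starting from a base network \(f_w\) initialized at \(w_0\), we train the scaled, centered predictor \( h_\alpha(x) = \alpha \, \big(f_w(x) - f_{w_0}(x)\big) \), which is exactly zero at initialization. Because the output is multiplied by \(\alpha\), one SGD step of size \(\eta\) moves \(h_\alpha\) by roughly \(\eta \alpha^2\) times a kernel term, so the learning rates in the cells below are divided by \(\alpha^2\) to keep the function-space dynamics comparable across \(\alpha\). The parameters themselves then only need to move by \(O(1/\alpha)\) to fit the data, which is what pushes large-\(\alpha\) models into the lazy regime.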
n_points = 10
dataset = Data(n=n_points,
               xmin=-1,
               xmax=1,
               noise_level=0,
               type='middle')
class ScaledModel(MLPdeep):
    """
    From a base network, creates a scaled copy whose output is zero at initialization.
    Keeps a frozen copy of the base model.
    alpha: scaling factor
    """
    def __init__(self, base_model, alpha=1):
        super().__init__(hidden_dim=base_model.hidden_dim)
        self.alpha = alpha
        self.load_state_dict(base_model.state_dict())
        # Wrapped in a list so the frozen copy is not registered as a submodule
        # and its parameters are not trained
        self.base_model = [base_model]
    def forward(self, x):
        """
        Scales the output and subtracts the function at initialization
        """
        with torch.no_grad():
            base = self.base_model[0].forward(x).detach()
        return self.alpha * (super().forward(x) - base)
base_model = MLPdeep(hidden_dim=8)
alphas = [1, 100, 1000]
model_list = []
for alpha in alphas:
    new_model = ScaledModel(base_model, alpha=alpha)
    new_model.load_state_dict(base_model.state_dict())
    model_list.append(new_model)
    plot_net(new_model, label=f'{alpha}')
all_losses = [[] for _ in range(len(model_list))]
plt.legend()
plt.show()

train(model_list[2],
      dataset,
      N_steps=10000,
      batch_size=n_points,
      lr=.1 / alphas[2]**2,
      save_weights_every=1,
      )
plot_data(dataset)
plt.legend()
plt.title('After some training')
for i, model in enumerate(model_list):
    plot_net(model, xmin=xmin, xmax=xmax, label=f'alpha = {alphas[i]}')
N_steps = 10000
xmin = -1 #np.min(dataset.inputs.detach().numpy())
xmax = 1 #np.max(dataset.inputs.detach().numpy())
plt.xlim(xmin, xmax)
print(xmin, xmax)
for i, model in enumerate(model_list):
    losses = train(model,
                   dataset,
                   N_steps=N_steps,
                   batch_size=n_points,
                   lr=.01 / alphas[i]**2,
                   save_weights_every=1,
                   )
    all_losses[i] += losses
    plot_net(model, xmin=xmin, xmax=xmax, label=f'alpha = {alphas[i]}')
plot_data(dataset)
plt.legend()
plt.title('After some training')
plt.show()
-1 1

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].set_title('Train loss')
for i, losses in enumerate(all_losses):
    points = np.arange(0, len(losses), 1)
    losses = np.array(losses)  # / alphas[i] ** 2
    axs[0].loglog(points, losses[points], alpha=.5, label=f'{alphas[i]}')
axs[1].set_title('Relative weight movement of last layer')
for i, model in enumerate(model_list):
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])
    axs[1].plot(weight_evol, alpha=.5, label=rf'$\alpha$ = {alphas[i]}')
plt.legend()
plt.show()

The effect of scaling on a 1D model#
Question: Explain the piece of code below (see the blog post).
import numpy as np
import matplotlib.pyplot as plt
w_0 = 1.5
def f(x, w):
    return w * x / 2 + np.log(1 + w**2) * x / 5 + .1 * np.sin(np.exp(2*w))
def h(x, w, w_0=.4):
    return f(x, w) * (w - w_0)
def linh(x, w, w_0=.4):
    return f(x, w_0) * (w - w_0)
xys = [(1.5, 3)]
def l(w, alpha):
    return np.sum([(alpha * h(x, w, w_0=w_0) - y) ** 2 / 2 for x, y in xys]) / alpha ** 2
def linl(w, alpha):
    return np.sum([(alpha * linh(x, w, w_0=w_0) - y) ** 2 / 2 for x, y in xys]) / alpha ** 2
alphas = [1, 5, 10, 50, 100, 1000, 2000, 10000]
ws = np.linspace(-1, 1, 100)
plt.title('h(2; w)')
plt.plot(ws, [h(2, w, w_0=w_0) for w in ws])
plt.show()
for alpha in alphas:
    delta = 2 / alpha ** (3/4)
    ws = np.linspace(w_0 - delta, w_0 + delta, 5000)
    plt.title(f'Loss landscape of scaled model: alpha = {alpha}')
    plt.plot(ws, [l(w, alpha) for w in ws], label='Loss of scaled model')
    plt.plot(ws, [linl(w, alpha) for w in ws], label='Loss of scaled linearized model')
    plt.grid(alpha=.2)
    plt.legend()
    plt.show()
import numpy as np
import matplotlib.pyplot as plt
w_0 = 1.5
def f(x, w):
    return w * x / 2 + np.log(1 + w**2) * x / 5 + .1 * np.sin(np.exp(2*w))
def h(x, w, w_0=.4):
    return f(x, w) * (w - w_0)
def linh(x, w, w_0=.4):
    return f(x, w_0) * (w - w_0)
xys = [(1.5, 3)]
def l(w, alpha):
    return np.sum([(alpha * h(x, w, w_0=w_0) - y) ** 2 / 2 for x, y in xys]) / alpha ** 2
def linl(w, alpha):
    return np.sum([(alpha * linh(x, w, w_0=w_0) - y) ** 2 / 2 for x, y in xys]) / alpha ** 2
alphas = [1, 5, 10, 50, 100, 1000, 2000, 10000]
for alpha in alphas:
    delta = 2 / alpha ** (3/4)
    ws = np.linspace(w_0 - delta, w_0 + delta, 5000)
    plt.title(f'??: alpha = {alpha}')
    plt.plot(ws, [l(w, alpha) for w in ws], label='??')
    plt.plot(ws, [linl(w, alpha) for w in ws], label='??')
    plt.grid(alpha=.2)
    plt.legend()
    plt.show()
Pieces#
mlp = model_list[0]
for name, param in mlp.named_parameters():
    print(name)
    print(param.shape)
    print(np.linalg.norm(param.detach().numpy()))
a = np.arange(10,30)
shifted_a = np.zeros(20)
print(shifted_a)
shifted_a[1:] = a[:-1]
print(shifted_a)
print(a)
evol = (a - shifted_a)[1:] / a[1:]
print(evol)
plt.title('Difference in successive weights')
for model in model_list:
    all_weights1 = np.array(model.weight_history)
    shifted_weights = np.zeros(all_weights1.shape)
    shifted_weights[1:, :] = all_weights1[:-1, :]
    weight_evol = (np.linalg.norm((all_weights1 - shifted_weights)[1:], axis=1)
                   / np.linalg.norm(all_weights1[1:], axis=1))
    plt.plot(weight_evol, alpha=.5)
plt.show()
You can use the following code snippet to save your network.
# mlp = MLPdeep() # MLPshallow()
# PATH = f'./{dataset.type}_{mlp.net_type}.pth'
# torch.save(mlp.state_dict(), PATH)
# net = MLPdeep() # MLPshallow()
# net.load_state_dict(torch.load(PATH))
p = 100
n = 5000
X = np.random.normal(size=(p, n)) / np.sqrt(n)
Y = np.random.normal(size=(n, 1))
XXt = X @ X.transpose()
XY = X @ Y
fig, axs = plt.subplots(1, 2)
axs[0].set_title(r'$X X^\top $')
axs[0].imshow(XXt)
eigvals, P = np.linalg.eigh(XXt)
print(max(eigvals))
axs[1].set_title('Diagonalized')
axs[1].imshow(P.transpose() @ XXt @ P)
plt.show()
print(np.linalg.matrix_rank(X))
1.2872850938240181

100
T = 1000
eta = 0.001
ws = np.zeros((T, p, 1))
w = np.random.normal(size=(p, 1)) / np.sqrt(p)
losses = []
for t in range(T):
    w = w - eta * XXt @ w + eta * XY
    ws[t] = w
    losses.append(np.linalg.norm(X.transpose() @ w - Y)**2 / n)
ws = ws.squeeze(axis=2)
fig, axs = plt.subplots(1, 2, figsize=(16, 4))
axs[0].set_title('Loss')
axs[0].plot(losses)
axs[1].set_title('Relative distance of weights to init')
weight_evol = np.linalg.norm(ws - ws[0, :], axis=1) / np.linalg.norm(ws[0, :])
axs[1].plot(weight_evol, alpha=.5)
plt.show()

In the right basis, most coordinates of w don't move#
w only moves in a low-dimensional subspace: the image of X.
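A short derivation (a sketch, using the eigendecomposition \( X X^\top = P \,\mathrm{diag}(\lambda_1, \dots, \lambda_p)\, P^\top \) computed above and any \(w_\star\) with \( X X^\top w_\star = X Y \)): the gradient update gives \( w_{t+1} - w_\star = (I - \eta X X^\top)(w_t - w_\star) \), hence \( P^\top (w_{t+1} - w_\star) = \mathrm{diag}(1 - \eta \lambda_i)\, P^\top (w_t - w_\star) \). In the eigenbasis each coordinate evolves independently at rate \(1 - \eta \lambda_i\): directions with \(\lambda_i = 0\) (orthogonal to the image of \(X\)) never move, and directions with \(\eta \lambda_i t \ll 1\) have barely moved after \(t\) steps.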
fig, ax = plt.subplots(figsize=(40, 3))
plt.title(r'$|w_{t, i} - w_{0, i}|$')
ws_in_base = ws @ P
a = ax.imshow(np.abs((ws_in_base - ws_in_base[0]).transpose()))
fig.colorbar(a)
plt.tight_layout()
plt.show()

Bias-complexity tradeoff#
xs = np.linspace(-1, 1, 100)
plt.plot(xs, simple_f(xs), label='Simple f')
plt.plot(xs, middle_f(xs), label='Middle f')
plt.plot(xs, complex_f(xs), label='Complicated f')
plt.title('Target functions')
plt.legend()
plt.show()
npoints = 50
dataset = Data(n=npoints,
               xmin=-1,
               xmax=1,
               noise_level=0.05,
               type='middle')
plot_data(dataset)
plt.legend()
plt.show()


class LinSin(nn.Module):
    """
    A linear model with sinusoidal features. Assumes feature_dim is even.
    """
    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.lout = nn.Linear(self.feature_dim, 1)
        self.weight_history = []
    def phi(self, x):
        """
        Computes $\sin(0.2\pi k x), \cos(0.2\pi k x)$ for k in $[feature_dim / 2]$
        (lower frequencies than the LinSin defined earlier)
        """
        r = .2 * np.pi * torch.arange(int(self.feature_dim / 2)) * x
        return torch.concatenate([torch.sin(r), torch.cos(r)], axis=1)
    def save_weights(self):
        with torch.no_grad():
            l = list(self.parameters())
            self.weight_history.append(np.copy(l[-1].numpy()))
    def forward(self, x):
        return self.lout(self.phi(x))
model = LinSin(feature_dim=30)
losses = train(model,
               dataset,
               N_steps=2000,
               batch_size=npoints,
               lr=.01,
               save_weights_every=-1,
               )
plot_net(model)
plot_data(dataset)
plt.title('Trained linear model = dim 30')
plt.show()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].loglog(losses)
axs[0].set_title('Train loss')
true_risk = dataset.risk(model)
print(true_risk)
axs[0].axhline(true_risk, color='red', alpha=.5, linestyle='--')
plt.show()

tensor(0.0026)

train_losses = []
var_train = []
risks = []
dimensions = np.arange(2, 200, 2)
for dim in dimensions:
    losses = []
    local_risks = []
    # Average over 10 random initializations for each feature dimension
    for _ in range(10):
        model = LinSin(feature_dim=dim)
        losses.append(train(model,
                            dataset,
                            N_steps=1000,
                            batch_size=100,
                            lr=.1,
                            save_weights_every=-1,
                            )[-1])
        local_risks.append(dataset.risk(model).detach())
    train_losses.append(np.mean(losses))
    var_train.append(np.var(losses))
    risks.append(np.mean(local_risks))
plt.grid(alpha=.2)
plt.plot(dimensions, train_losses, label='Train')
plt.plot(dimensions, risks, label='Test')
plt.legend()
plt.show()
