TP3: Lazy training#

Observing the lazy training regime with different models

With inspiration from https://rajatvd.github.io/NTK/

import torch

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
### Executes the script tp3_utils.py, which contains code from last time.
## Go take a quick look at the functions implemented there.

from tp3_utils import * 

# %load_ext autoreload
# %autoreload 2

The data#

We continue with last week’s toy problem of 1-dimensional regression.

xs = np.linspace(-1, 1, 100)
plt.plot(xs, simple_f(xs), label='Simple f')
plt.plot(xs, middle_f(xs), label='Middle f')
plt.plot(xs, complex_f(xs), label='Complicated f')
plt.title('Target functions')
plt.legend()
plt.show()

dataset = Data(n=15,
                xmin=-1,
                xmax=1,
                noise_level=0,
                type='middle')

plot_data(dataset)
plt.legend()
plt.show()

The training function#

Write the training function

def train(net, 
          dataset, 
          N_steps=1, 
          batch_size=30, 
          lr=0.01, 
          save_weights_every=100):
    
    optimizer = optim.SGD(net.parameters(), lr=lr)
    criterion = nn.MSELoss()

    losses = []

    for i in range(N_steps): 
        inputs, labels = dataset.next_batch(batch_size)
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()

        losses.append(loss.detach().numpy())

        # net.save_weights() stores a snapshot in net.weight_history;
        # a non-positive save_weights_every disables saving (and avoids i % 0)
        if save_weights_every > 0 and i % save_weights_every == 0:
            with torch.no_grad():
                net.save_weights()

    return losses
def train(net, 
          dataset, 
          N_steps=1, 
          batch_size=30, 
          lr=0.01, 
          save_weights_every=100):
    
    # YOUR CODE HERE

    losses = []

    for i in range(N_steps): 
        # YOUR CODE HERE

        losses.append(loss.detach().numpy())

        if save_weights_every > 0 and i % save_weights_every == 0:
            with torch.no_grad():
                net.save_weights()

    return losses

Test your training#

model = MLPdeep(hidden_dim=10)
losses = train(model, 
                dataset, 
                N_steps=10000, 
                batch_size=15, 
                lr=.1, 
                save_weights_every=1,
                )

plot_net(model)
plot_data(dataset)
plt.show()

plt.plot(losses)
plt.show()

A linear model#

We define a family of linear models with sinusoidal features. The feature map is

\[ \phi(x) = ( \cos(2 \pi x) , \sin(2 \pi x), \cos(4 \pi x) , \sin(4 \pi x), \dots , \cos(2k \pi x) , \sin(2k \pi x)) \]

where \(k\) is feature_dim / 2.
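
As a quick sanity check of the feature map (a minimal standalone sketch, not part of tp3_utils): with feature_dim = 4 the frequencies are k = 1, 2, and the implementation below stacks all the sines first and then all the cosines, which spans the same feature space as the interleaved ordering in the formula above.

import numpy as np
import torch

x = torch.tensor([[0.25], [0.5]])                  # two points, shape (2, 1)
ks = torch.arange(1, 3, dtype=torch.float32)       # k = 1, 2, i.e. feature_dim = 4
r = 2 * np.pi * ks * x                             # shape (2, 2) by broadcasting
features = torch.cat([torch.sin(r), torch.cos(r)], dim=1)
print(features.shape)                              # torch.Size([2, 4])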

class LinSin(nn.Module): 
    """
        A linear model with sinusoidal features. Assumes feature_dim is even.
    """
    def __init__(self, feature_dim): 
        super().__init__()
        self.feature_dim = feature_dim
        self.lout = nn.Linear(self.feature_dim, 1)

        self.weight_history = []

    def phi(self, x):
        """
            Computes sin(2 pi k x), cos(2 pi k x) for k = 1, ..., feature_dim / 2
        """
        r = 2 * np.pi * torch.arange(1, self.feature_dim // 2 + 1) * x
        return torch.cat([torch.sin(r), torch.cos(r)], dim=1)

    def save_weights(self): 
        with torch.no_grad():
            l = list(self.parameters())
            self.weight_history.append(np.copy(l[-1].numpy()))

    def forward(self, x): 
        return self.lout(self.phi(x))
class LinSin(nn.Module): 
    """
        A linear model with sinusoidal features
    """
    def __init__(self, feature_dim): 
        super().__init__()
        self.feature_dim = feature_dim
        self.lout = nn.Linear(self.feature_dim, 1)

        self.weight_history = []

    def phi(self, x):
        """
            Computes $sin(2\pi k x), cos(2\pi k x) $ for k in $[feature_dim / 2]$
        """
        # YOUR CODE HERE
        pass

    def save_weights(self): 
        with torch.no_grad():
            l = list(self.parameters())
            self.weight_history.append(np.copy(l[-1].numpy()))

    def forward(self, x): 
        # YOUR CODE HERE
        pass

Test the linear model

model = LinSin(feature_dim=8)
losses = train(model, 
                dataset, 
                N_steps=10000, 
                batch_size=15, 
                lr=.01, 
                save_weights_every=1,
                )

plot_net(model)
plot_data(dataset)
plt.show()

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].loglog(losses)
axs[0].set_title('Train loss')

all_weights1 = np.array(model.weight_history)
weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])
axs[1].plot(weight_evol, alpha=.5) 
axs[1].set_title('Relative weight movement of last layer')

Question: what do the plots mean?

Let us now repeat the code above with a large feature dimension. What do you observe?

model = LinSin(feature_dim=200)
losses = train(model, 
                dataset, 
                N_steps=100, 
                batch_size=15, 
                lr=.01, 
                save_weights_every=1,
                )

plot_net(model)
plot_data(dataset)
plt.title('Trained model')
plt.show()

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].loglog(losses)
axs[0].set_title('Train loss')

all_weights1 = np.array(model.weight_history)
weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])
axs[1].plot(weight_evol, alpha=.5) 
axs[1].set_title('Relative weight movement of last layer')

plt.show()

Robustness of fast convergence for linear models#

Run 10 instances of the linear model with different initializations.

Are there significant differences between the models?

model_list = []
N_models = 10
for _ in range(N_models): 
    model = LinSin(feature_dim=100)
    model_list.append(model)
    plot_net(model)

plot_data(dataset)
plt.title('At initialization')

all_losses = [[] for _ in range(N_models)]
all_weights = []
N_steps = 100000

for i, model in enumerate(model_list):
    losses = train(model, 
                   dataset, 
                   N_steps=N_steps, 
                   batch_size=30, 
                   lr=.01, 
                   save_weights_every=1,
                   )
    all_losses[i] += losses
    plot_net(model)

plot_data(dataset)
plt.title('After some training')

plt.show()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

axs[0].set_title('Train loss')
for losses in all_losses:
    points = np.arange(0, len(losses), 1)  
    losses = np.array(losses)
    axs[0].loglog(points, losses[points], alpha=.5)

axs[1].set_title('Relative movement of weights')
for model in model_list:    
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])

    axs[1].plot(weight_evol, alpha=.5) 

plt.show()

Lazy training of non-linear models#

We now train some neural nets in different regimes and try to see when lazy training occurs.

Question: How are nets initialized in PyTorch?
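
One way to answer is to inspect a freshly created layer. The snippet below is a minimal check (my own, not part of tp3_utils); in current PyTorch versions, nn.Linear draws both weights and biases uniformly in (-1/sqrt(fan_in), 1/sqrt(fan_in)), via kaiming_uniform_ with a = sqrt(5).

import torch.nn as nn

layer = nn.Linear(100, 1)                       # fan_in = 100
bound = 1 / 100 ** 0.5                          # expected bound = 0.1
print(bound)
print(layer.weight.min().item(), layer.weight.max().item())   # within (-0.1, 0.1)
print(layer.bias.min().item(), layer.bias.max().item())       # same bound for the bias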

model_list = []
N_models = 10
for _ in range(N_models): 
    model = MLPshallow(hidden_dim=10)
    model_list.append(model)
    plot_net(model)

plot_data(dataset)
plt.title('At initialization')

all_losses = [[] for _ in range(N_models)]
all_weights = []
N_steps = 5000

for i, model in enumerate(model_list):
    losses = train(model, 
                   dataset, 
                   N_steps=N_steps, 
                   batch_size=30, 
                   lr=.001, 
                   save_weights_every=1,
                   )
    all_losses[i] += losses
    plot_net(model)

plot_data(dataset)
plt.title('After some training')

plt.show()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

axs[0].set_title('Train loss')
for losses in all_losses:
    points = np.arange(0, len(losses), 1)  
    losses = np.array(losses)
    axs[0].loglog(points, losses[points], alpha=.5)

axs[1].set_title('Relative weight movement of last layer')
for model in model_list:    
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])

    axs[1].plot(weight_evol, alpha=.5) 

plt.show()

Thin vs. Wide networks#

Let us now vary the width of our network. What happens?

model_list = []

width_list = [5, 10, 100, 1000]

for width in width_list: 
    model = MLPshallow(hidden_dim=width)
    model_list.append(model)
    plot_net(model, label=f'Width = {model.hidden_dim}')

plot_data(dataset)
plt.title('At initialization')
plt.legend()

all_losses = [[] for _ in width_list]
all_weights = []
N_steps = 5000

for i, model in enumerate(model_list):
    losses = train(model, 
                   dataset, 
                   N_steps=N_steps, 
                   batch_size=30, 
                   lr=.001, 
                   save_weights_every=1,
                   )
    all_losses[i] += losses
    plot_net(model, label=f'Width = {model.hidden_dim}')

plot_data(dataset)
plt.title('After some training')

plt.show()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

axs[0].set_title('Train loss')
for i, losses in enumerate(all_losses):
    points = np.arange(0, len(losses), 1)  
    losses = np.array(losses)
    axs[0].loglog(points, losses[points], alpha=.5, label=f'{width_list[i]}')

axs[1].set_title('Relative weight movement of last layer')
for model in model_list:    
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])

    axs[1].plot(weight_evol, alpha=.5, label=f'{model.hidden_dim}') 

plt.legend()
plt.show()

Bad initialization can get you out of the lazy regime#

We simulate bad initialization by taking two steps of gradient descent from initialization.

model_list = []
N_models = 10
for _ in range(N_models): 
    model = MLPshallow(hidden_dim=100)
    model_list.append(model)
    train(model, 
        dataset, 
        N_steps=2, 
        batch_size=30, 
        lr=1, 
        save_weights_every=-1,
        )
    plot_net(model)


plot_data(dataset)
plt.title('At bad initialization')

all_losses = [[] for _ in range(N_models)]
all_weights = []
N_steps = 1000
xmin = np.min(dataset.inputs.detach().numpy())
xmax = np.max(dataset.inputs.detach().numpy())
plt.xlim(xmin, xmax)
print(xmin, xmax)

for i, model in enumerate(model_list):
    losses = train(model, 
                   dataset, 
                   N_steps=N_steps, 
                   batch_size=30, 
                   lr=.001, 
                   save_weights_every=1,
                   )
    all_losses[i] += losses
    plot_net(model, xmin=xmin, xmax=xmax)

plot_data(dataset)
plt.title('After some training')

plt.show()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

axs[0].set_title('Train loss')
for i, losses in enumerate(all_losses):
    points = np.arange(0, len(losses), 1)  
    losses = np.array(losses)
    axs[0].loglog(points, losses[points], alpha=.5)

axs[1].set_title('Relative weight movement of last layer')
for model in model_list:    
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])

    axs[1].plot(weight_evol, alpha=.5) 

plt.show()

Scaling#

Now we attempt to enter the lazy regime by scaling the model outputs.
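
Why the learning rates below are divided by \(\alpha^2\) (a short worked sanity check): the scaled model is \( h_\alpha(x) = \alpha \big( f_\theta(x) - f_{\theta_0}(x) \big) \), so the gradient of the squared loss,

\[ \nabla_\theta \, \tfrac{1}{2} \big( h_\alpha(x) - y \big)^2 = \alpha \big( h_\alpha(x) - y \big) \, \nabla_\theta f_\theta(x), \]

carries a factor \(\alpha\), and one gradient step of size \(\eta\) changes the output \(h_\alpha\) by roughly \(\eta \alpha^2\) times an \(\alpha\)-independent quantity near initialization. Taking \(\eta \propto 1/\alpha^2\) therefore keeps the dynamics of the outputs comparable across the different values of \(\alpha\).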

n_points = 10

dataset = Data(n=n_points,
                xmin=-1,
                xmax=1,
                noise_level=0,
                type='middle')
class ScaledModel(MLPdeep):
    """
        From a base network, creates a scaled copy whose output is zero at initialization. Keeps a frozen copy of the base model.

        alpha: scaling factor
    """
    def __init__(self, base_model, alpha=1):
        super().__init__(hidden_dim=base_model.hidden_dim)
        
        self.alpha = alpha
        self.load_state_dict(base_model.state_dict())

        self.base_model = [base_model]  # wrapped in a list to hide it from this module's parameters. Probably not the proper way to do this...

    def forward(self, x):
        """
            Scales the output and subtracts the initial function 
        """
        with torch.no_grad():
            base = self.base_model[0].forward(x).detach()
        return self.alpha * (super().forward(x) - base)

base_model = MLPdeep(hidden_dim=8)

alphas = [1, 100, 1000]#, 2000, 10000]
model_list = []

for alpha in alphas: 
    new_model = ScaledModel(base_model, alpha=alpha)
    new_model.load_state_dict(base_model.state_dict())
    model_list.append(new_model)

    plot_net(new_model, label=f'{alpha}')

all_losses = [[] for _ in range(len(model_list))]

plt.legend()
plt.show()
train(model_list[2], 
    dataset, 
    N_steps=10000,
    batch_size=n_points, 
    lr=.1 / alphas[2]**2, 
    save_weights_every=1,
    )

plot_data(dataset)
plt.legend()
plt.title('After some training')


for i, model in enumerate(model_list):
    plot_net(model, xmin=xmin, xmax=xmax, label=f'alpha = {alphas[i]}')
N_steps = 10000
xmin = -1 #np.min(dataset.inputs.detach().numpy())
xmax = 1 #np.max(dataset.inputs.detach().numpy())
plt.xlim(xmin, xmax)
print(xmin, xmax)

for i, model in enumerate(model_list):
    losses = train(model, 
                   dataset, 
                   N_steps=N_steps,
                   batch_size=n_points, 
                   lr=.01 / alphas[i]**2 , 
                   save_weights_every=1,
                   )
    all_losses[i] += losses 
    plot_net(model, xmin=xmin, xmax=xmax, label=f'alpha = {alphas[i]}')

plot_data(dataset)
plt.legend()
plt.title('After some training')

plt.show()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

axs[0].set_title('Train loss')
for i, losses in enumerate(all_losses):
    points = np.arange(0, len(losses), 1)  
    losses = np.array(losses) # / alphas[i] ** 2
    axs[0].loglog(points, losses[points], alpha=.5, label=f'{alphas[i]}')

axs[1].set_title('Relative weight movement of last layer')
for i, model in enumerate(model_list):    
    all_weights1 = np.array(model.weight_history)
    weight_evol = np.linalg.norm(all_weights1 - all_weights1[0], axis=1) / np.linalg.norm(all_weights1[0])

    axs[1].plot(weight_evol, alpha=.5, label=rf'$\alpha$ = {alphas[i]}')

plt.legend()
plt.show()

The effect of scaling on a 1D model#

Question: Explain the piece of code below (see the blog post).

import numpy as np
import matplotlib.pyplot as plt

w_0 = 1.5

def f(x, w): 
    return w * x / 2 + np.log(1 + w**2) * x / 5 + .1 * np.sin(np.exp(2*w))

def h(x, w, w_0=.4): 
    return f(x, w) * (w - w_0)

def linh(x, w, w_0=.4):
    return  f(x, w_0) * (w - w_0)

xys = [(1.5, 3)]

def l(w, alpha): 
    return np.sum([(alpha * h(x, w, w_0=w_0) - y) ** 2 / 2 for x, y in xys]) / alpha ** 2

def linl(w, alpha):
    return np.sum([(alpha * linh(x, w, w_0=w_0) - y) ** 2 / 2 for x, y in xys]) / alpha ** 2

alphas = [1, 5, 10, 50,  100, 1000, 2000, 10000]

ws = np.linspace(-1, 1, 100)
plt.title('h(2; w)')
plt.plot(ws, [h(2, w, w_0=w_0) for w in ws])
plt.show()

for alpha in alphas:
    delta = 2 / alpha ** (3/4)
    ws = np.linspace(w_0 - delta, w_0 + delta, 5000)
    plt.title(f'Loss landscape of scaled model: alpha = {alpha}')
    plt.plot(ws, [l(w, alpha) for w in ws], label='Loss of scaled model')
    plt.plot(ws, [linl(w, alpha) for w in ws], label='Loss of scaled linearized model')
    plt.grid(alpha=.2)
    plt.legend()
    plt.show()
import numpy as np
import matplotlib.pyplot as plt

w_0 = 1.5

def f(x, w): 
    return w * x /2  +  np.log(1 + w**2) * x / 5  + .1 * np.sin(np.exp(2*w))

def h(x, w, w_0=.4): 
    return f(x, w) * (w - w_0)

def linh(x, w, w_0=.4):
    return  f(x, w_0) * (w - w_0)

xys = [(1.5, 3)]

def l(w, alpha): 
    return np.sum([(alpha * h(x, w, w_0=w_0) - y) ** 2 / 2 for x, y in xys]) / alpha ** 2

def linl(w, alpha):
    return np.sum([(alpha * linh(x, w, w_0=w_0) - y) ** 2 / 2 for x, y in xys]) / alpha ** 2

alphas = [1, 5, 10, 50,  100, 1000, 2000, 10000]

for alpha in alphas:
    delta = 2 / alpha ** (3/4)
    ws = np.linspace(w_0 - delta, w_0 + delta, 5000)
    plt.title(f'??: alpha = {alpha}')
    plt.plot(ws, [l(w, alpha) for w in ws], label='??')
    plt.plot(ws, [linl(w, alpha) for w in ws], label='??')
    plt.grid(alpha=.2)
    plt.legend()
    plt.show()

Pieces#

mlp = model_list[0]

for name, param in mlp.named_parameters(): 
    print(name)
    print(param.shape)
    print(np.linalg.norm((param.detach().numpy())))
a = np.arange(10,30)
shifted_a = np.zeros(20)
print(shifted_a)
shifted_a[1:] = a[:-1]
print(shifted_a)
print(a)

evol = (a - shifted_a)[1:] / a[1:]
print(evol)
plt.title('Difference in successive weights')
for model in model_list:    
    all_weights1 = np.array(model.weight_history)
    shifted_weights = np.zeros(all_weights1.shape)
    shifted_weights[1:, :] = all_weights1[:-1, :]
    weight_evol = (np.linalg.norm((all_weights1 - shifted_weights)[1:], axis=1) 
                   / np.linalg.norm(all_weights1[1:], axis=1))

    plt.plot(weight_evol, alpha=.5) 


plt.show()

You can use the following code snippet to save your network.

# mlp = MLPdeep() # MLPshallow()
# PATH = f'./{dataset.type}_{mlp.net_type}.pth'
# torch.save(mlp.state_dict(), PATH)

# net = MLPdeep() # MLPshallow()
# net.load_state_dict(torch.load(PATH))
p = 100
n = 5000

X = np.random.normal(size=(p, n))  / np.sqrt(n) 
Y = np.random.normal(size=(n, 1))

XXt = X @ X.transpose()
XY = X @ Y

fig, axs = plt.subplots(1, 2)

axs[0].set_title(r'$X X^\top $')
axs[0].imshow(XXt)

eigvals, P = np.linalg.eigh(XXt)
print(max(eigvals))

axs[1].set_title('Diagonalized')
axs[1].imshow(P.transpose() @ XXt @ P)
plt.show()

print(np.linalg.matrix_rank(X))
T = 1000
eta = 0.001

ws = np.zeros((T, p, 1))
w = np.random.normal(size=(p, 1))  / np.sqrt(p)

losses = []
for t in range(T):
    w = w - eta * XXt @ w + eta * XY
    ws[t] = w

    losses.append(np.linalg.norm(X.transpose() @ w  - Y)**2 / n )

ws = ws.squeeze(axis=2)
fig, axs = plt.subplots(1, 2, figsize=(16, 4))

axs[0].set_title('Loss')
axs[0].plot(losses) 


axs[1].set_title('Relative distance of weights to init')
weight_evol = np.linalg.norm(ws - ws[0, :], axis=1) / np.linalg.norm(ws[0, :])
axs[1].plot(weight_evol, alpha=.5) 
 

plt.show()

In the right basis, most coordinates of w don't move#

w only moves in a low-dimensional subspace: the image of X.
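
To see this concretely, here is a minimal standalone sketch (the sizes, p = 200 parameters and n = 20 samples, are chosen for illustration and differ from the cell above, so that the image of X is a strict subspace): every gradient update is X times a vector, hence the component of w - w_0 orthogonal to the image of X stays at zero. The underscored names avoid clobbering the variables above.

import numpy as np

rng = np.random.default_rng(0)
p_big, n_small = 200, 20                             # overparameterized: more parameters than samples
X_ = rng.normal(size=(p_big, n_small)) / np.sqrt(n_small)
Y_ = rng.normal(size=(n_small, 1))

w0 = rng.normal(size=(p_big, 1)) / np.sqrt(p_big)
w_ = w0.copy()
for _ in range(1000):
    w_ = w_ - 0.01 * (X_ @ (X_.T @ w_) - X_ @ Y_)    # gradient of ||X^T w - Y||^2 / 2

proj = X_ @ np.linalg.pinv(X_)                       # orthogonal projector onto the image of X
print(np.linalg.norm((np.eye(p_big) - proj) @ (w_ - w0)))   # ~ 0: no movement outside image(X)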

fig, ax = plt.subplots(figsize=(40, 3))

plt.title(r'$|w_{t, i} - w_{0, i}|$')
ws_in_base = ws @ P

a = ax.imshow(np.abs((ws_in_base - ws_in_base[0]).transpose()))
fig.colorbar(a)
plt.tight_layout()
plt.show()

Bias-complexity tradeoff#

xs = np.linspace(-1, 1, 100)
plt.plot(xs, simple_f(xs), label='Simple f')
plt.plot(xs, middle_f(xs), label='Middle f')
plt.plot(xs, complex_f(xs), label='Complicated f')
plt.title('Target functions')
plt.legend()
plt.show()

npoints = 50

dataset = Data(n=npoints,
                xmin=-1,
                xmax=1,
                noise_level=0.05,
                type='middle')

plot_data(dataset)
plt.legend()
plt.show()
class LinSin(nn.Module): 
    """
        A linear model with sinusoidal features. Assumes feature_dim is even.
    """
    def __init__(self, feature_dim): 
        super().__init__()
        self.feature_dim = feature_dim
        self.lout = nn.Linear(self.feature_dim, 1)

        self.weight_history = []

    def phi(self, x):
        """
            Computes sin(2 pi k x), cos(2 pi k x) for k = 1, ..., feature_dim / 2
        """
        r = 2 * np.pi * torch.arange(1, self.feature_dim // 2 + 1) * x
        return torch.cat([torch.sin(r), torch.cos(r)], dim=1)

    def save_weights(self): 
        with torch.no_grad():
            l = list(self.parameters())
            self.weight_history.append(np.copy(l[-1].numpy()))

    def forward(self, x): 
        return self.lout(self.phi(x))
model = LinSin(feature_dim=30)
losses = train(model, 
                dataset, 
                N_steps=2000, 
                batch_size=npoints, 
                lr=.01, 
                save_weights_every=-1,
                )

plot_net(model)
plot_data(dataset)
plt.title('Trained linear model = dim 30')
plt.show()

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].loglog(losses)
axs[0].set_title('Train loss')
true_risk = dataset.risk(model)
print(true_risk)
axs[0].axhline(true_risk, color='red', alpha=.5, linestyle='--')


plt.show()
train_losses = []
var_train = []
risks = []
dimensions = np.arange(2, 200, 2)
for dim in dimensions:
    losses = []
    local_risks = []
    for _ in range(10):
        model = LinSin(feature_dim=dim)
        losses.append(train(model, 
                        dataset, 
                        N_steps=1000, 
                        batch_size=100, 
                        lr=.1, 
                        save_weights_every=-1,
                        )[-1])
        local_risks.append(dataset.risk(model).detach())
    train_losses.append(np.mean(losses))
    var_train.append(np.var(losses))
    risks.append(np.mean(local_risks))
    
plt.grid(alpha=.2)

plt.plot(dimensions, train_losses, label='Train')
plt.plot(dimensions, risks, label='Test')
plt.legend()
plt.show()