Journal TEx Text
This commit is contained in:
50
experiments/meta_task_exp.py
Normal file
50
experiments/meta_task_exp.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from network import MetaNet
|
||||||
|
|
||||||
|
|
||||||
|
class TaskDataset(Dataset):
    """Synthetic regression task: samples (a, b) ~ N(0, 1), target is a + b.

    Items are generated on the fly from a private RNG; the dataset only
    pretends to have `length` elements so DataLoader knows when to stop.
    """

    def __init__(self, length=int(5e5)):
        super().__init__()
        # Nominal size reported to DataLoader; no data is stored.
        self.length = length
        self.prng = np.random.default_rng()

    def __len__(self):
        return self.length

    def __getitem__(self, _):
        # The index is ignored — every access draws a fresh random pair.
        sample = self.prng.normal(size=(2,)).astype(np.float32)
        target = sample.sum(axis=-1, keepdims=True)
        return sample, target
|
if __name__ == '__main__':
    # Train a small MetaNet on the synthetic a+b regression task and
    # plot the per-batch loss curve at the end.
    metanet = MetaNet(2, 3, 4, 1)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(metanet.parameters(), lr=0.004)

    loader = DataLoader(TaskDataset(), batch_size=50, shuffle=True, drop_last=True)
    # metanet.train(True)
    losses = []
    for batch_x, batch_y in tqdm(loader, total=len(loader)):
        # Reset accumulated gradients before every batch.
        optimizer.zero_grad()

        prediction = metanet(batch_x)
        batch_loss = loss_fn(prediction, batch_y)
        batch_loss.backward()

        # Apply the parameter update.
        optimizer.step()

        losses.append(batch_loss.item())

    sns.lineplot(y=np.asarray(losses), x=np.arange(len(losses)))
    plt.show()
@ -137,7 +137,7 @@ class RobustnessComparisonExperiment:
|
|||||||
for noise_level in range(noise_levels):
|
for noise_level in range(noise_levels):
|
||||||
steps = 0
|
steps = 0
|
||||||
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
|
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
|
||||||
f"{fixpoint.name}_clone_noise10e-{noise_level}")
|
f"{fixpoint.name}_clone_noise_1e-{noise_level}")
|
||||||
clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
|
clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
|
||||||
clone = clone.apply_noise(pow(10, -noise_level))
|
clone = clone.apply_noise(pow(10, -noise_level))
|
||||||
|
|
||||||
@ -159,7 +159,8 @@ class RobustnessComparisonExperiment:
|
|||||||
# When this raises a Type Error, we found a second order fixpoint!
|
# When this raises a Type Error, we found a second order fixpoint!
|
||||||
steps += 1
|
steps += 1
|
||||||
|
|
||||||
df.loc[df.shape[0]] = [setting, f'$10^{{-{noise_level}}}$', steps, absolute_loss,
|
df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$',
|
||||||
|
steps, absolute_loss,
|
||||||
time_to_vergence[setting][noise_level],
|
time_to_vergence[setting][noise_level],
|
||||||
time_as_fixpoint[setting][noise_level]]
|
time_as_fixpoint[setting][noise_level]]
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
@ -171,12 +172,12 @@ class RobustnessComparisonExperiment:
|
|||||||
var_name="Measurement",
|
var_name="Measurement",
|
||||||
value_name="Steps").sort_values('Noise Level')
|
value_name="Steps").sort_values('Noise Level')
|
||||||
# Plotting
|
# Plotting
|
||||||
plt.rcParams.update({
|
# plt.rcParams.update({
|
||||||
"text.usetex": True,
|
# "text.usetex": True,
|
||||||
"font.family": "sans-serif",
|
# "font.family": "sans-serif",
|
||||||
"font.size": 12,
|
# "font.size": 12,
|
||||||
"font.weight": 'bold',
|
# "font.weight": 'bold',
|
||||||
"font.sans-serif": ["Helvetica"]})
|
# "font.sans-serif": ["Helvetica"]})
|
||||||
sns.set(style='whitegrid', font_scale=2)
|
sns.set(style='whitegrid', font_scale=2)
|
||||||
bf = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE)
|
bf = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE)
|
||||||
synthetic = 'synthetic' if self.is_synthetic else 'natural'
|
synthetic = 'synthetic' if self.is_synthetic else 'natural'
|
||||||
@ -191,7 +192,7 @@ class RobustnessComparisonExperiment:
|
|||||||
plt.savefig(str(filepath))
|
plt.savefig(str(filepath))
|
||||||
|
|
||||||
if print_it:
|
if print_it:
|
||||||
col_headers = [str(f"10e-{d}") for d in range(noise_levels)]
|
col_headers = [str(f"1e-{d}") for d in range(noise_levels)]
|
||||||
|
|
||||||
print(f"\nAppplications steps until divergence / zero: ")
|
print(f"\nAppplications steps until divergence / zero: ")
|
||||||
# print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
# print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
||||||
|
91
network.py
91
network.py
@ -245,17 +245,82 @@ class SecondaryNet(Net):
|
|||||||
return df, is_diverged
|
return df, is_diverged
|
||||||
|
|
||||||
|
|
||||||
|
class MetaWeight(Net):
    """Thin alias of Net: the per-weight sub-network instantiated by MetaCell."""
|
|
||||||
|
|
||||||
|
class MetaCell(nn.Module):
    """A 'neuron' in which every incoming weight is itself a small MetaWeight net.

    One MetaWeight is created per input feature. In forward, each scalar input
    column is zero-padded to the MetaWeight's input width, pushed through its
    own MetaWeight, and the per-weight outputs are summed into one scalar per
    batch row (optionally with a residual add of the raw input).
    """

    def __init__(self, name, interface, residual_skip=True):
        """
        :param name: Base name; MetaWeight idx gets the name '{name}_{idx}'.
        :param interface: Number of input features (one MetaWeight each).
        :param residual_skip: If True, add the raw input back onto the
            MetaWeight outputs before the final sum.
        """
        super().__init__()
        self.residual_skip = residual_skip
        self.name = name
        self.interface = interface
        # Fixed geometry of every per-weight sub-network.
        self.weight_interface = 4
        self.net_hidden_size = 4
        # NOTE(review): attribute keeps its historical typo ('ouput') so any
        # external reader of this attribute stays compatible.
        self.net_ouput_size = 1
        self.meta_weight_list = nn.ModuleList(
            [MetaWeight(self.weight_interface, self.net_hidden_size,
                        self.net_ouput_size, name=f'{self.name}_{weight_idx}'
                        ) for weight_idx in range(self.interface)])

    def forward(self, x):
        """Map a (batch, interface) tensor to a (batch, 1) tensor."""
        # Fix: allocate the zero padding once, on x's device and dtype, so the
        # cell also works for CUDA inputs / non-default float types (the
        # original rebuilt a CPU-default-dtype tensor for every column).
        pad = torch.zeros((x.shape[0], self.weight_interface - 1),
                          device=x.device, dtype=x.dtype)
        xs = [torch.hstack((x[:, idx].unsqueeze(-1), pad))
              for idx in range(len(self.meta_weight_list))]
        tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
        if self.residual_skip:
            tensor += x

        result = torch.sum(tensor, dim=-1, keepdim=True)
        return result
|
|
||||||
|
|
||||||
|
class MetaLayer(nn.Module):
    """A layer of `width` MetaCells whose scalar outputs are stacked column-wise."""

    def __init__(self, name, interface=4, out=1, width=4):
        """
        :param name: Base name; cell idx gets the name '{name}_{idx}'.
        :param interface: Number of input features handed to every cell.
        :param out: NOTE(review): currently unused — kept only so the
            signature stays compatible with existing callers.
        :param width: Number of MetaCells, i.e. the layer's output width.
        """
        super().__init__()
        self.name = name
        self.interface = interface
        self.width = width

        self.meta_cell_list = nn.ModuleList(
            MetaCell(name=f'{self.name}_{cell_idx}', interface=interface)
            for cell_idx in range(self.width)
        )

    def forward(self, x):
        """Concatenate each cell's (batch, 1) output into a (batch, width) tensor."""
        return torch.hstack([cell(x) for cell in self.meta_cell_list])
||||||
|
|
||||||
|
|
||||||
|
class MetaNet(nn.Module):
    """A stack of MetaLayers: `interface` inputs -> (depth-1) hidden layers of
    `width` cells -> a final layer of `out` cells."""

    def __init__(self, interface=4, depth=3, width=4, out=1):
        """
        :param interface: Input feature count of the first layer.
        :param depth: Total number of MetaLayers.
        :param width: Cell count of every hidden layer.
        :param out: Cell count (output width) of the final layer.
        """
        super().__init__()
        self.out = out
        self.interface = interface
        self.width = width
        self.depth = depth

        # First layer adapts the external interface to the hidden width.
        layers = [MetaLayer(name='Weight_0',
                            interface=self.interface,
                            width=self.width)]
        # Hidden layers are width -> width.
        layers += [MetaLayer(name=f'Weight_{layer_idx + 1}',
                             interface=self.width, width=self.width)
                   for layer_idx in range(self.depth - 2)]
        # Final layer narrows to the requested output width.
        layers.append(MetaLayer(name=f'Weight_{len(layers)}',
                                interface=self.width, width=self.out))
        self._meta_layer_list = nn.ModuleList(layers)
        self._net = nn.Sequential(*self._meta_layer_list)

    def forward(self, x):
        return self._net(x)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Smoke test: build a MetaNet and push one dummy batch through it.
    metanet = MetaNet(2, 3, 4, 1)
    metanet(torch.ones((5, 2)))
    for _ in range(5):
        print('Test')
||||||
|
Reference in New Issue
Block a user