Journal TEx Text

This commit is contained in:
Steffen Illium
2022-01-21 17:28:45 +01:00
parent 21dd572969
commit 5f1f5833d8
3 changed files with 138 additions and 22 deletions

@ -0,0 +1,50 @@
import numpy as np
import torch
from matplotlib import pyplot as plt
import seaborn as sns
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from network import MetaNet
class TaskDataset(Dataset):
def __init__(self, length=int(5e5)):
super().__init__()
self.length = length
self.prng = np.random.default_rng()
def __len__(self):
return self.length
def __getitem__(self, _):
ab = self.prng.normal(size=(2,)).astype(np.float32)
return ab, ab.sum(axis=-1, keepdims=True)
if __name__ == '__main__':
metanet = MetaNet(2, 3, 4, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(metanet.parameters(), lr=0.004)
d = DataLoader(TaskDataset(), batch_size=50, shuffle=True, drop_last=True)
# metanet.train(True)
losses = []
for batch_x, batch_y in tqdm(d, total=len(d)):
# Zero your gradients for every batch!
optimizer.zero_grad()
y = metanet(batch_x)
loss = loss_fn(y, batch_y)
loss.backward()
# Adjust learning weights
optimizer.step()
losses.append(loss.item())
sns.lineplot(y=np.asarray(losses), x=np.arange(len(losses)))
plt.show()

@ -137,7 +137,7 @@ class RobustnessComparisonExperiment:
for noise_level in range(noise_levels): for noise_level in range(noise_levels):
steps = 0 steps = 0
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size, clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
f"{fixpoint.name}_clone_noise10e-{noise_level}") f"{fixpoint.name}_clone_noise_1e-{noise_level}")
clone.load_state_dict(copy.deepcopy(fixpoint.state_dict())) clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
clone = clone.apply_noise(pow(10, -noise_level)) clone = clone.apply_noise(pow(10, -noise_level))
@ -159,7 +159,8 @@ class RobustnessComparisonExperiment:
# When this raises a Type Error, we found a second order fixpoint! # When this raises a Type Error, we found a second order fixpoint!
steps += 1 steps += 1
df.loc[df.shape[0]] = [setting, f'$10^{{-{noise_level}}}$', steps, absolute_loss, df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$',
steps, absolute_loss,
time_to_vergence[setting][noise_level], time_to_vergence[setting][noise_level],
time_as_fixpoint[setting][noise_level]] time_as_fixpoint[setting][noise_level]]
pbar.update(1) pbar.update(1)
@ -171,12 +172,12 @@ class RobustnessComparisonExperiment:
var_name="Measurement", var_name="Measurement",
value_name="Steps").sort_values('Noise Level') value_name="Steps").sort_values('Noise Level')
# Plotting # Plotting
plt.rcParams.update({ # plt.rcParams.update({
"text.usetex": True, # "text.usetex": True,
"font.family": "sans-serif", # "font.family": "sans-serif",
"font.size": 12, # "font.size": 12,
"font.weight": 'bold', # "font.weight": 'bold',
"font.sans-serif": ["Helvetica"]}) # "font.sans-serif": ["Helvetica"]})
sns.set(style='whitegrid', font_scale=2) sns.set(style='whitegrid', font_scale=2)
bf = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE) bf = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE)
synthetic = 'synthetic' if self.is_synthetic else 'natural' synthetic = 'synthetic' if self.is_synthetic else 'natural'
@ -191,7 +192,7 @@ class RobustnessComparisonExperiment:
plt.savefig(str(filepath)) plt.savefig(str(filepath))
if print_it: if print_it:
col_headers = [str(f"10e-{d}") for d in range(noise_levels)] col_headers = [str(f"1e-{d}") for d in range(noise_levels)]
print(f"\nAppplications steps until divergence / zero: ") print(f"\nAppplications steps until divergence / zero: ")
# print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) # print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))

@ -245,17 +245,82 @@ class SecondaryNet(Net):
return df, is_diverged return df, is_diverged
class MetaWeight(Net):
pass
class MetaCell(nn.Module):
def __init__(self, name, interface, residual_skip=True):
super().__init__()
self.residual_skip = residual_skip
self.name = name
self.interface = interface
self.weight_interface = 4
self.net_hidden_size = 4
self.net_ouput_size = 1
self.meta_weight_list = nn.ModuleList(
[MetaWeight(self.weight_interface, self.net_hidden_size,
self.net_ouput_size, name=f'{self.name}_{weight_idx}'
) for weight_idx in range(self.interface)])
def forward(self, x):
xs = [torch.hstack((x[:, idx].unsqueeze(-1), torch.zeros((x.shape[0], self.weight_interface - 1))))
for idx in range(len(self.meta_weight_list))]
tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
if self.residual_skip:
tensor += x
result = torch.sum(tensor, dim=-1, keepdim=True)
return result
class MetaLayer(nn.Module):
def __init__(self, name, interface=4, out=1, width=4):
super().__init__()
self.name = name
self.interface = interface
self.width = width
meta_cell_list = nn.ModuleList([MetaCell(name=f'{self.name}_{cell_idx}',
interface=interface
) for cell_idx in range(self.width)])
self.meta_cell_list = meta_cell_list
def forward(self, x):
result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
return result
class MetaNet(nn.Module):
def __init__(self, interface=4, depth=3, width=4, out=1):
super().__init__()
self.out = out
self.interface = interface
self.width = width
self.depth = depth
meta_layer_list = nn.ModuleList([MetaLayer(name=f'Weight_{0}',
interface=self.interface,
width=self.width)])
meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
interface=self.width, width=self.width
) for layer_idx in range(self.depth - 2)])
meta_layer_list.append(MetaLayer(name=f'Weight_{len(meta_layer_list)}',
interface=self.width, width=self.out))
self._meta_layer_list = meta_layer_list
self._net = nn.Sequential(*self._meta_layer_list)
def forward(self, x):
result = self._net.forward(x)
return result
if __name__ == '__main__': if __name__ == '__main__':
is_div = True metanet = MetaNet(2, 3, 4, 1)
while is_div: metanet(torch.ones((5, 2)))
net = SecondaryNet(4, 2, 1, "SecondaryNet") print('Test')
data_df, is_div = net.self_train(20000, 25, 1e-4) print('Test')
from matplotlib import pyplot as plt print('Test')
import seaborn as sns print('Test')
# data_df = data_df[::-1] # Reverse print('Test')
fig = sns.lineplot(data=data_df[[x for x in data_df.columns if x != 'step']])
# fig.set(yscale='log')
print(data_df.iloc[-1])
print(data_df.iloc[0])
plt.show()
print("done")