From 5f1f5833d8b05ff84f8e40495f8c056eab339466 Mon Sep 17 00:00:00 2001 From: Steffen Illium <steffen.illium@ifi.lmu.de> Date: Fri, 21 Jan 2022 17:28:45 +0100 Subject: [PATCH] Journal TEx Text --- experiments/meta_task_exp.py | 50 ++++++++++++++++++++ journal_robustness.py | 19 ++++---- network.py | 91 ++++++++++++++++++++++++++++++------ 3 files changed, 138 insertions(+), 22 deletions(-) create mode 100644 experiments/meta_task_exp.py diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py new file mode 100644 index 0000000..9973dca --- /dev/null +++ b/experiments/meta_task_exp.py @@ -0,0 +1,50 @@ +import numpy as np +import torch +from matplotlib import pyplot as plt +import seaborn as sns +from torch import nn +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm + +from network import MetaNet + + +class TaskDataset(Dataset): + def __init__(self, length=int(5e5)): + super().__init__() + self.length = length + self.prng = np.random.default_rng() + + def __len__(self): + return self.length + + def __getitem__(self, _): + ab = self.prng.normal(size=(2,)).astype(np.float32) + return ab, ab.sum(axis=-1, keepdims=True) + + +if __name__ == '__main__': + metanet = MetaNet(2, 3, 4, 1) + loss_fn = nn.MSELoss() + optimizer = torch.optim.AdamW(metanet.parameters(), lr=0.004) + + d = DataLoader(TaskDataset(), batch_size=50, shuffle=True, drop_last=True) + # metanet.train(True) + losses = [] + for batch_x, batch_y in tqdm(d, total=len(d)): + # Zero your gradients for every batch! + optimizer.zero_grad() + + y = metanet(batch_x) + loss = loss_fn(y, batch_y) + loss.backward() + + # Adjust learning weights + optimizer.step() + + losses.append(loss.item()) + + sns.lineplot(y=np.asarray(losses), x=np.arange(len(losses))) + plt.show() + + diff --git a/journal_robustness.py b/journal_robustness.py index acce534..27bf0b6 100644 --- a/journal_robustness.py +++ b/journal_robustness.py @@ -137,7 +137,7 @@ class RobustnessComparisonExperiment: for noise_level in range(noise_levels): steps = 0 clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size, - f"{fixpoint.name}_clone_noise10e-{noise_level}") + f"{fixpoint.name}_clone_noise_1e-{noise_level}") clone.load_state_dict(copy.deepcopy(fixpoint.state_dict())) clone = clone.apply_noise(pow(10, -noise_level)) @@ -159,7 +159,8 @@ class RobustnessComparisonExperiment: # When this raises a Type Error, we found a second order fixpoint! steps += 1 - df.loc[df.shape[0]] = [setting, f'$10^{{-{noise_level}}}$', steps, absolute_loss, + df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$', + steps, absolute_loss, time_to_vergence[setting][noise_level], time_as_fixpoint[setting][noise_level]] pbar.update(1) @@ -171,12 +172,12 @@ class RobustnessComparisonExperiment: var_name="Measurement", value_name="Steps").sort_values('Noise Level') # Plotting - plt.rcParams.update({ - "text.usetex": True, - "font.family": "sans-serif", - "font.size": 12, - "font.weight": 'bold', - "font.sans-serif": ["Helvetica"]}) + # plt.rcParams.update({ + # "text.usetex": True, + # "font.family": "sans-serif", + # "font.size": 12, + # "font.weight": 'bold', + # "font.sans-serif": ["Helvetica"]}) sns.set(style='whitegrid', font_scale=2) bf = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE) synthetic = 'synthetic' if self.is_synthetic else 'natural' @@ -191,7 +192,7 @@ class RobustnessComparisonExperiment: plt.savefig(str(filepath)) if print_it: - col_headers = [str(f"10e-{d}") for d in range(noise_levels)] + col_headers = [str(f"1e-{d}") for d in range(noise_levels)] print(f"\nAppplications steps until divergence / zero: ") # print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) diff --git a/network.py b/network.py index 9ec060a..2b2c2af 100644 --- a/network.py +++ b/network.py @@ -245,17 +245,82 @@ class SecondaryNet(Net): return df, is_diverged +class MetaWeight(Net): + pass + + +class MetaCell(nn.Module): + def __init__(self, name, interface, residual_skip=True): + super().__init__() + self.residual_skip = residual_skip + self.name = name + self.interface = interface + self.weight_interface = 4 + self.net_hidden_size = 4 + self.net_ouput_size = 1 + self.meta_weight_list = nn.ModuleList( + [MetaWeight(self.weight_interface, self.net_hidden_size, + self.net_ouput_size, name=f'{self.name}_{weight_idx}' + ) for weight_idx in range(self.interface)]) + + def forward(self, x): + xs = [torch.hstack((x[:, idx].unsqueeze(-1), torch.zeros((x.shape[0], self.weight_interface - 1)))) + for idx in range(len(self.meta_weight_list))] + tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)]) + if self.residual_skip: + tensor += x + + result = torch.sum(tensor, dim=-1, keepdim=True) + return result + + +class MetaLayer(nn.Module): + def __init__(self, name, interface=4, out=1, width=4): + super().__init__() + self.name = name + self.interface = interface + self.width = width + + meta_cell_list = nn.ModuleList([MetaCell(name=f'{self.name}_{cell_idx}', + interface=interface + ) for cell_idx in range(self.width)]) + self.meta_cell_list = meta_cell_list + + def forward(self, x): + result = torch.hstack([metacell(x) for metacell in self.meta_cell_list]) + return result + + +class MetaNet(nn.Module): + + def __init__(self, interface=4, depth=3, width=4, out=1): + super().__init__() + self.out = out + self.interface = interface + self.width = width + self.depth = depth + + meta_layer_list = nn.ModuleList([MetaLayer(name=f'Weight_{0}', + interface=self.interface, + width=self.width)]) + meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}', + interface=self.width, width=self.width + ) for layer_idx in range(self.depth - 2)]) + meta_layer_list.append(MetaLayer(name=f'Weight_{len(meta_layer_list)}', + interface=self.width, width=self.out)) + self._meta_layer_list = meta_layer_list + self._net = nn.Sequential(*self._meta_layer_list) + + def forward(self, x): + result = self._net.forward(x) + return result + + if __name__ == '__main__': - is_div = True - while is_div: - net = SecondaryNet(4, 2, 1, "SecondaryNet") - data_df, is_div = net.self_train(20000, 25, 1e-4) - from matplotlib import pyplot as plt - import seaborn as sns - # data_df = data_df[::-1] # Reverse - fig = sns.lineplot(data=data_df[[x for x in data_df.columns if x != 'step']]) - # fig.set(yscale='log') - print(data_df.iloc[-1]) - print(data_df.iloc[0]) - plt.show() - print("done") + metanet = MetaNet(2, 3, 4, 1) + metanet(torch.ones((5, 2))) + print('Test') + print('Test') + print('Test') + print('Test') + print('Test')