From 5f1f5833d8b05ff84f8e40495f8c056eab339466 Mon Sep 17 00:00:00 2001
From: Steffen Illium <steffen.illium@ifi.lmu.de>
Date: Fri, 21 Jan 2022 17:28:45 +0100
Subject: [PATCH] Journal TEx Text

---
 experiments/meta_task_exp.py | 50 ++++++++++++++++++++
 journal_robustness.py        | 19 ++++----
 network.py                   | 91 ++++++++++++++++++++++++++++++------
 3 files changed, 138 insertions(+), 22 deletions(-)
 create mode 100644 experiments/meta_task_exp.py

diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py
new file mode 100644
index 0000000..9973dca
--- /dev/null
+++ b/experiments/meta_task_exp.py
@@ -0,0 +1,50 @@
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+import seaborn as sns
+from torch import nn
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+
+from network import MetaNet
+
+
+class TaskDataset(Dataset):
+    def __init__(self, length=int(5e5)):
+        super().__init__()
+        self.length = length
+        self.prng = np.random.default_rng()
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, _):
+        ab = self.prng.normal(size=(2,)).astype(np.float32)
+        return ab, ab.sum(axis=-1, keepdims=True)
+
+
+if __name__ == '__main__':
+    metanet = MetaNet(2, 3, 4, 1)
+    loss_fn = nn.MSELoss()
+    optimizer = torch.optim.AdamW(metanet.parameters(), lr=0.004)
+
+    d = DataLoader(TaskDataset(), batch_size=50, shuffle=True, drop_last=True)
+    # metanet.train(True)
+    losses = []
+    for batch_x, batch_y in tqdm(d, total=len(d)):
+        # Zero your gradients for every batch!
+        optimizer.zero_grad()
+
+        y = metanet(batch_x)
+        loss = loss_fn(y, batch_y)
+        loss.backward()
+
+        # Adjust learning weights
+        optimizer.step()
+
+        losses.append(loss.item())
+
+    sns.lineplot(y=np.asarray(losses), x=np.arange(len(losses)))
+    plt.show()
+
+
diff --git a/journal_robustness.py b/journal_robustness.py
index acce534..27bf0b6 100644
--- a/journal_robustness.py
+++ b/journal_robustness.py
@@ -137,7 +137,7 @@ class RobustnessComparisonExperiment:
                     for noise_level in range(noise_levels):
                         steps = 0
                         clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
-                                    f"{fixpoint.name}_clone_noise10e-{noise_level}")
+                                    f"{fixpoint.name}_clone_noise_1e-{noise_level}")
                         clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
                         clone = clone.apply_noise(pow(10, -noise_level))
 
@@ -159,7 +159,8 @@ class RobustnessComparisonExperiment:
                                 # When this raises a Type Error, we found a second order fixpoint!
                             steps += 1
 
-                            df.loc[df.shape[0]] = [setting, f'$10^{{-{noise_level}}}$', steps, absolute_loss,
+                            df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$',
+                                                   steps, absolute_loss,
                                                    time_to_vergence[setting][noise_level],
                                                    time_as_fixpoint[setting][noise_level]]
                     pbar.update(1)
@@ -171,12 +172,12 @@ class RobustnessComparisonExperiment:
                                                  var_name="Measurement",
                                                  value_name="Steps").sort_values('Noise Level')
         # Plotting
-        plt.rcParams.update({
-            "text.usetex": True,
-            "font.family": "sans-serif",
-            "font.size": 12,
-            "font.weight": 'bold',
-            "font.sans-serif": ["Helvetica"]})
+        # plt.rcParams.update({
+        #    "text.usetex": True,
+        #    "font.family": "sans-serif",
+        #    "font.size": 12,
+        #    "font.weight": 'bold',
+        #    "font.sans-serif": ["Helvetica"]})
         sns.set(style='whitegrid', font_scale=2)
         bf = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE)
         synthetic = 'synthetic' if self.is_synthetic else 'natural'
@@ -191,7 +192,7 @@ class RobustnessComparisonExperiment:
         plt.savefig(str(filepath))
 
         if print_it:
-            col_headers = [str(f"10e-{d}") for d in range(noise_levels)]
+            col_headers = [str(f"1e-{d}") for d in range(noise_levels)]
 
             print(f"\nAppplications steps until divergence / zero: ")
             # print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
diff --git a/network.py b/network.py
index 9ec060a..2b2c2af 100644
--- a/network.py
+++ b/network.py
@@ -245,17 +245,82 @@ class SecondaryNet(Net):
         return df, is_diverged
 
 
+class MetaWeight(Net):
+    pass
+
+
+class MetaCell(nn.Module):
+    def __init__(self, name, interface, residual_skip=True):
+        super().__init__()
+        self.residual_skip = residual_skip
+        self.name = name
+        self.interface = interface
+        self.weight_interface = 4
+        self.net_hidden_size = 4
+        self.net_ouput_size = 1
+        self.meta_weight_list = nn.ModuleList(
+            [MetaWeight(self.weight_interface, self.net_hidden_size,
+                        self.net_ouput_size, name=f'{self.name}_{weight_idx}'
+                        ) for weight_idx in range(self.interface)])
+
+    def forward(self, x):
+        xs = [torch.hstack((x[:, idx].unsqueeze(-1), torch.zeros((x.shape[0], self.weight_interface - 1))))
+              for idx in range(len(self.meta_weight_list))]
+        tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
+        if self.residual_skip:
+            tensor += x
+
+        result = torch.sum(tensor, dim=-1, keepdim=True)
+        return result
+
+
+class MetaLayer(nn.Module):
+    def __init__(self, name, interface=4, out=1, width=4):
+        super().__init__()
+        self.name = name
+        self.interface = interface
+        self.width = width
+
+        meta_cell_list = nn.ModuleList([MetaCell(name=f'{self.name}_{cell_idx}',
+                                                 interface=interface
+                                                 ) for cell_idx in range(self.width)])
+        self.meta_cell_list = meta_cell_list
+
+    def forward(self, x):
+        result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
+        return result
+
+
+class MetaNet(nn.Module):
+
+    def __init__(self, interface=4, depth=3, width=4, out=1):
+        super().__init__()
+        self.out = out
+        self.interface = interface
+        self.width = width
+        self.depth = depth
+
+        meta_layer_list = nn.ModuleList([MetaLayer(name=f'Weight_{0}',
+                                                   interface=self.interface,
+                                                   width=self.width)])
+        meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
+                                          interface=self.width, width=self.width
+                                          ) for layer_idx in range(self.depth - 2)])
+        meta_layer_list.append(MetaLayer(name=f'Weight_{len(meta_layer_list)}',
+                                         interface=self.width, width=self.out))
+        self._meta_layer_list = meta_layer_list
+        self._net = nn.Sequential(*self._meta_layer_list)
+
+    def forward(self, x):
+        result = self._net.forward(x)
+        return result
+
+
 if __name__ == '__main__':
-    is_div = True
-    while is_div:
-        net = SecondaryNet(4, 2, 1, "SecondaryNet")
-        data_df, is_div = net.self_train(20000, 25, 1e-4)
-    from matplotlib import pyplot as plt
-    import seaborn as sns
-    # data_df = data_df[::-1]  # Reverse
-    fig = sns.lineplot(data=data_df[[x for x in data_df.columns if x != 'step']])
-    # fig.set(yscale='log')
-    print(data_df.iloc[-1])
-    print(data_df.iloc[0])
-    plt.show()
-    print("done")
+    metanet = MetaNet(2, 3, 4, 1)
+    metanet(torch.ones((5, 2)))
+    print('Test')
+    print('Test')
+    print('Test')
+    print('Test')
+    print('Test')