Journal TeX Text
commit 5f1f5833d8
parent 21dd572969
experiments/meta_task_exp.py | 50 (new file)
@@ -0,0 +1,50 @@
import numpy as np
import torch
from matplotlib import pyplot as plt
import seaborn as sns
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from network import MetaNet

class TaskDataset(Dataset):
    """Synthetic regression task: two standard-normal values, target is their sum."""

    def __init__(self, length=int(5e5)):
        super().__init__()
        self.length = length
        self.prng = np.random.default_rng()

    def __len__(self):
        return self.length

    def __getitem__(self, _):
        ab = self.prng.normal(size=(2,)).astype(np.float32)
        return ab, ab.sum(axis=-1, keepdims=True)

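Aside: a minimal sketch of what one sample looks like (values differ on every draw, since `prng` is freshly seeded per instance):

    x, y = TaskDataset()[0]
    # x: float32 array of shape (2,); y: float32 array of shape (1,), equal to x.sum()
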
if __name__ == '__main__':
    metanet = MetaNet(2, 3, 4, 1)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(metanet.parameters(), lr=0.004)

    d = DataLoader(TaskDataset(), batch_size=50, shuffle=True, drop_last=True)
    # metanet.train(True)
    losses = []
    for batch_x, batch_y in tqdm(d, total=len(d)):
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        y = metanet(batch_x)
        loss = loss_fn(y, batch_y)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        losses.append(loss.item())

    sns.lineplot(y=np.asarray(losses), x=np.arange(len(losses)))
    plt.show()
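Aside: a quick sanity check one might run after the loop above finishes, as a sketch; `metanet` and the a+b task are as defined in this file:

    with torch.no_grad():
        test = torch.tensor([[1.0, 2.0], [-0.5, 0.5]])
        print(metanet(test))  # close to [[3.0], [0.0]] if training converged
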
@@ -137,7 +137,7 @@ class RobustnessComparisonExperiment:
         for noise_level in range(noise_levels):
             steps = 0
             clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
-                        f"{fixpoint.name}_clone_noise10e-{noise_level}")
+                        f"{fixpoint.name}_clone_noise_1e-{noise_level}")
             clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
             clone = clone.apply_noise(pow(10, -noise_level))
 
@@ -159,7 +159,8 @@ class RobustnessComparisonExperiment:
                         # When this raises a Type Error, we found a second order fixpoint!
                     steps += 1
 
-                df.loc[df.shape[0]] = [setting, f'$10^{{-{noise_level}}}$', steps, absolute_loss,
+                df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$',
+                                       steps, absolute_loss,
                                        time_to_vergence[setting][noise_level],
                                        time_as_fixpoint[setting][noise_level]]
                 pbar.update(1)
@@ -171,12 +172,12 @@ class RobustnessComparisonExperiment:
                              var_name="Measurement",
                              value_name="Steps").sort_values('Noise Level')
         # Plotting
-        plt.rcParams.update({
-            "text.usetex": True,
-            "font.family": "sans-serif",
-            "font.size": 12,
-            "font.weight": 'bold',
-            "font.sans-serif": ["Helvetica"]})
+        # plt.rcParams.update({
+        #     "text.usetex": True,
+        #     "font.family": "sans-serif",
+        #     "font.size": 12,
+        #     "font.weight": 'bold',
+        #     "font.sans-serif": ["Helvetica"]})
         sns.set(style='whitegrid', font_scale=2)
         bf = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE)
         synthetic = 'synthetic' if self.is_synthetic else 'natural'
@@ -191,7 +192,7 @@ class RobustnessComparisonExperiment:
             plt.savefig(str(filepath))
 
         if print_it:
-            col_headers = [str(f"10e-{d}") for d in range(noise_levels)]
+            col_headers = [str(f"1e-{d}") for d in range(noise_levels)]
 
             print(f"\nApplication steps until divergence / zero: ")
             # print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
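Aside on the relabeling above: the noise actually applied is `pow(10, -noise_level)`, so `1e-{noise_level}` is the matching scientific notation, and `\mathregular` lets matplotlib's built-in mathtext render the exponent without a TeX installation (hence the commented-out `text.usetex` block). A minimal illustration:

    noise_level = 3
    pow(10, -noise_level)                       # 0.001
    f"1e-{noise_level}"                         # '1e-3' -- matches the applied noise
    f'$\mathregular{{10^{{-{noise_level}}}}}$'  # mathtext label; no usetex needed
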
network.py | 91
@@ -245,17 +245,82 @@ class SecondaryNet(Net):
        return df, is_diverged


class MetaWeight(Net):
    pass

class MetaCell(nn.Module):
    """A single neuron whose incoming weights are emulated by tiny networks.

    One MetaWeight per input dimension: each sees its scalar input zero-padded
    to `weight_interface` and emits one value. The outputs are summed
    (optionally with a residual skip over the input) into one scalar.
    """

    def __init__(self, name, interface, residual_skip=True):
        super().__init__()
        self.residual_skip = residual_skip
        self.name = name
        self.interface = interface
        self.weight_interface = 4
        self.net_hidden_size = 4
        self.net_output_size = 1
        self.meta_weight_list = nn.ModuleList(
            [MetaWeight(self.weight_interface, self.net_hidden_size,
                        self.net_output_size, name=f'{self.name}_{weight_idx}'
                        ) for weight_idx in range(self.interface)])

    def forward(self, x):
        xs = [torch.hstack((x[:, idx].unsqueeze(-1), torch.zeros((x.shape[0], self.weight_interface - 1))))
              for idx in range(len(self.meta_weight_list))]
        tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
        if self.residual_skip:
            tensor += x

        result = torch.sum(tensor, dim=-1, keepdim=True)
        return result

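Aside: a shape walk-through of MetaCell as a sketch (names are illustrative; it assumes Net's forward maps (batch, 4) to (batch, 1)):

    cell = MetaCell(name='demo', interface=3)
    x = torch.randn(5, 3)  # batch of 5, one scalar per MetaWeight
    out = cell(x)          # each scalar: padded to width 4 -> MetaWeight -> 1 value; then summed
    print(out.shape)       # torch.Size([5, 1])
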
class MetaLayer(nn.Module):
    def __init__(self, name, interface=4, out=1, width=4):
        super().__init__()
        self.name = name
        self.interface = interface
        self.width = width

        meta_cell_list = nn.ModuleList([MetaCell(name=f'{self.name}_{cell_idx}',
                                                 interface=interface
                                                 ) for cell_idx in range(self.width)])
        self.meta_cell_list = meta_cell_list

    def forward(self, x):
        # One output column per cell: (batch, width)
        result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
        return result

class MetaNet(nn.Module):

    def __init__(self, interface=4, depth=3, width=4, out=1):
        super().__init__()
        self.out = out
        self.interface = interface
        self.width = width
        self.depth = depth

        # First layer maps the task interface to `width`, the middle layers
        # keep `width`, and the last layer narrows to `out` cells.
        meta_layer_list = nn.ModuleList([MetaLayer(name=f'Weight_{0}',
                                                   interface=self.interface,
                                                   width=self.width)])
        meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
                                          interface=self.width, width=self.width
                                          ) for layer_idx in range(self.depth - 2)])
        meta_layer_list.append(MetaLayer(name=f'Weight_{len(meta_layer_list)}',
                                         interface=self.width, width=self.out))
        self._meta_layer_list = meta_layer_list
        self._net = nn.Sequential(*self._meta_layer_list)

    def forward(self, x):
        result = self._net.forward(x)
        return result

if __name__ == '__main__':
    is_div = True
    while is_div:
        net = SecondaryNet(4, 2, 1, "SecondaryNet")
        data_df, is_div = net.self_train(20000, 25, 1e-4)
    from matplotlib import pyplot as plt
    import seaborn as sns
    # data_df = data_df[::-1]  # Reverse
    fig = sns.lineplot(data=data_df[[x for x in data_df.columns if x != 'step']])
    # fig.set(yscale='log')
    print(data_df.iloc[-1])
    print(data_df.iloc[0])
    plt.show()
    print("done")
    metanet = MetaNet(2, 3, 4, 1)
    metanet(torch.ones((5, 2)))
    print('Test')