New Images

Steffen Illium
2022-02-17 13:13:41 +01:00
parent 4f99251e68
commit 5aea4f9f55
10 changed files with 109 additions and 50 deletions


@@ -6,7 +6,7 @@ Data Exchange: [Google Drive Folder](***REMOVED***)
 ### Fixpoint Tests:

-- [ ] Dropout Test
+- [X] Dropout Test
   - (Does the particle take part in the goal task, or is it only SRN?)
   - Zero_ident diff = -00.04999637603759766 %
@@ -29,6 +29,8 @@ Data Exchange: [Google Drive Folder](***REMOVED***)
 - does this already exist?
 - Hypernetwork?
   - arxiv: 1905.02898
+- Sparse Networks
+- Pruning

 ---
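If the new "Sparse Networks / Pruning" item gets picked up, PyTorch already ships magnitude pruning; a minimal sketch of the built-in API (not this repo's method, purely an illustration):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 4, bias=False)
prune.l1_unstructured(layer, name='weight', amount=0.5)   # zero the 50% smallest-magnitude weights
print((layer.weight == 0).float().mean().item())          # fraction of pruned weights, ~0.5
```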
@@ -42,6 +44,16 @@ Data Exchange: [Google Drive Folder](***REMOVED***)
 | ![](./figures/sanity/sanity_3hidden_xtimesn.png) | ![](./figures/sanity/sanity_4hidden_xtimesn.png)  |
 | SRNN x*n 6 Neurons Other_Func                     | SRNN x*n 10 Neurons Other_Func                     |
 | ![](./figures/sanity/sanity_6hidden_xtimesn.png) | ![](./figures/sanity/sanity_10hidden_xtimesn.png) |
+- [ ] Connectivity
+  - The network really does thin itself out.
+
+  |||
+  |----------------------------------------------------|------------------------------------------------------------|
+  | 200 Epochs - 4 Neurons - \alpha 100 RES             |                                                            |
+  | ![](./figures/connectivity/training_lineplot.png)   | ![](./figures/connectivity/training_particle_type_lp.png) |
+  | OTHER FUNCTIONS                                      | IDENTITY FUNCTIONS                                         |
+  | ![](./figures/connectivity/other.png)               | ![](./figures/connectivity/identity.png)                   |
 - [ ] Training with smaller GNs
@@ -59,6 +71,7 @@ Data Exchange: [Google Drive Folder](***REMOVED***)
 - [ ] Test with baseline dense network
   - [ ] with comparable neuron count
   - [ ] with total weight count
 - [ ] Task/Goal instead of SRNN task

 ---
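For the "baseline dense network" item above, one simple way to match on total weight count is to compare parameter sums; a sketch with assumed layer sizes (28*28 MNIST inputs, 10 classes), not taken from this repo:

```python
import torch.nn as nn

def n_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

baseline = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
print(n_params(baseline))  # compare against n_params(metanet) of the sparse model
```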

as_line_plot.py (new file, 40 lines)

@@ -0,0 +1,40 @@
import numpy as np
import torch
import pandas as pd
import re
from pathlib import Path
import seaborn as sns
from matplotlib import pyplot as plt

from network import FixTypes

if __name__ == '__main__':
    p = Path(r'experiments\output\mn_st_200_4_alpha_100\trained_model_ckpt_e200.tp')
    m = torch.load(p, map_location=torch.device('cpu'))
    particles = [y for x in m._meta_layer_list for y in x.particles]
    df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name', 'color'])
    colors = []
    for particle in particles:
        # Strip non-digits from the particle name and unpack the remaining indices as l, c, w.
        l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", particle.name).split('_')]
        color = sns.color_palette()[0 if particle.is_fixpoint == FixTypes.identity_func else 1]
        # color = 'orange' if particle.is_fixpoint == FixTypes.identity_func else 'blue'
        colors.append(color)
        # Two rows per particle (position w in layer l-1, position c in layer l),
        # so each particle becomes one line segment between consecutive layers.
        df.loc[df.shape[0]] = (particle.is_fixpoint, l - 1, w, particle.name, color)
        df.loc[df.shape[0]] = (particle.is_fixpoint, l, c, particle.name, color)
    # Normalise positions per layer by the layer's maximum index.
    for layer in list(df['layer'].unique()):
        divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
        df.loc[(df['layer'] == layer), 'neuron'] /= divisor
    print('gathered')
    # One figure per particle type: other_func vs. identity_func.
    for n, (fixtype, color) in enumerate(zip([FixTypes.other_func, FixTypes.identity_func], ['blue', 'orange'])):
        plt.clf()
        ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
                          legend=False, estimator=None,
                          palette=[sns.color_palette()[n]] * (df[df['type'] == fixtype].shape[0] // 2), lw=1)
        # ax.set(yscale='log', ylabel='Neuron')
        ax.set_title(fixtype)
        plt.show()
    print('plotted')
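A hypothetical illustration of the name parsing in the loop above; the actual particle naming scheme is defined in network.py and is not part of this diff:

```python
import re

name = 'layer1_cell2_weight3'  # assumed format, not the repo's real names
l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", name).split('_')]
print(l, c, w)  # 1.0 2.0 3.0
```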


@@ -17,6 +17,7 @@ from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize
 from tqdm import tqdm

 if platform.node() == 'CarbonX':
     debug = True
     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
@@ -36,8 +37,8 @@ else:
     DIR = None
     pass

-from network import MetaNet
-from functionalities_test import test_for_fixpoints, FixTypes
+from network import MetaNet, FixTypes
+from functionalities_test import test_for_fixpoints

 WORKER = 10 if not debug else 2
 debug = False
@@ -195,13 +196,14 @@ def flat_for_store(parameters):
 if __name__ == '__main__':
     self_train = True
-    training = True
+    training = False
     plotting = True
     particle_analysis = True
     as_sparse_network_test = True

-    self_train_alpha = 1
+    train_to_id_first = False
+    self_train_alpha = 100
     batch_train_beta = 1
-    weight_hidden_size = 5
+    weight_hidden_size = 4
     residual_skip = True
     dropout = 0
@@ -209,9 +211,11 @@ if __name__ == '__main__':
     data_path.mkdir(exist_ok=True, parents=True)

     st_str = f'{"" if self_train else "no_"}st'
+    a_str = f'_alpha_{self_train_alpha}' if self_train_alpha != 1 else ''
     res_str = f'{"" if residual_skip else "_no"}_res'
     dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
-    run_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{res_str}{dr_str}'
+    id_str = f'{f"_StToId" if train_to_id_first else ""}'
+    run_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{dr_str}{id_str}'

     model_path = run_path / '0000_trained_model.zip'
     df_store_path = run_path / 'train_store.csv'
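Worked example of the new run-folder naming, assuming EPOCH = 200 (EPOCH itself is not visible in this hunk) and the flags set earlier in this diff (self_train=True, weight_hidden_size=4, self_train_alpha=100, residual_skip=True, dropout=0, train_to_id_first=False):

```python
st_str, a_str, res_str, dr_str, id_str = 'st', '_alpha_100', '_res', '', ''
print(f'mn_{st_str}_{200}_{4}{a_str}{res_str}{dr_str}{id_str}')
# -> mn_st_200_4_alpha_100_res
```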
@@ -245,8 +249,9 @@ if __name__ == '__main__':
             metric = torchmetrics.Accuracy()
         else:
             metric = None
+        init_st = train_to_id_first and all(x.is_fixpoint == FixTypes.identity_func for x in metanet.particles)
         for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
-            if self_train and is_self_train_epoch:
+            if (self_train and is_self_train_epoch) or init_st:
                 # Zero your gradients for every batch!
                 optimizer.zero_grad()
                 self_train_loss = metanet.combined_self_train() * self_train_alpha
@@ -255,44 +260,46 @@ if __name__ == '__main__':
                 optimizer.step()

                 step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
                 train_store.loc[train_store.shape[0]] = step_log

-            # Zero your gradients for every batch!
-            optimizer.zero_grad()
-            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-            y = metanet(batch_x)
-            # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
-            loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
-            loss.backward()
-
-            # Adjust learning weights
-            optimizer.step()
-
-            step_log = dict(Epoch=epoch, Batch=batch,
-                            Metric='Task Loss', Score=loss.item())
-            train_store.loc[train_store.shape[0]] = step_log
-            if is_validation_epoch:
-                metric(y.cpu(), batch_y.cpu())
+            if train_to_id_first <= epoch:
+                # Zero your gradients for every batch!
+                optimizer.zero_grad()
+                batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
+                y = metanet(batch_x)
+                # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
+                loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
+                loss.backward()
+
+                # Adjust learning weights
+                optimizer.step()
+
+                step_log = dict(Epoch=epoch, Batch=batch,
+                                Metric='Task Loss', Score=loss.item())
+                train_store.loc[train_store.shape[0]] = step_log
+                if is_validation_epoch:
+                    metric(y.cpu(), batch_y.cpu())

             if batch >= 3 and debug:
                 break

         if is_validation_epoch:
             metanet = metanet.eval()
-            validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                  Metric='Train Accuracy', Score=metric.compute().item())
-            train_store.loc[train_store.shape[0]] = validation_log
+            if train_to_id_first <= epoch:
+                validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
+                                      Metric='Train Accuracy', Score=metric.compute().item())
+                train_store.loc[train_store.shape[0]] = validation_log

             accuracy = checkpoint_and_validate(metanet, run_path, epoch)
             validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                   Metric='Test Accuracy', Score=accuracy.item())
             train_store.loc[train_store.shape[0]] = validation_log

-        if particle_analysis:
+        if particle_analysis and (init_st or is_validation_epoch):
             counter_dict = defaultdict(lambda: 0)
             # This returns ID-functions
             _ = test_for_fixpoints(counter_dict, list(metanet.particles))
             for key, value in dict(counter_dict).items():
                 step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
                 train_store.loc[train_store.shape[0]] = step_log
+        if init_st or is_validation_epoch:
             for particle in metanet.particles:
                 weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
                 weight_store.loc[weight_store.shape[0]] = weight_log
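One detail worth noting in the new `train_to_id_first <= epoch` guards: Python compares booleans as integers, so the default `train_to_id_first = False` behaves like 0 and the task-loss branch runs from the first epoch; only an integer setting would postpone it:

```python
print(False <= 0)  # True  -> with the False default, task training is never skipped
print(3 <= 1)      # False -> with train_to_id_first = 3, earlier epochs skip the task-loss branch
```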
@@ -355,7 +362,7 @@ if __name__ == '__main__':
         fig, ax = plt.subplots(ncols=2)
         labels = ['Full Network', 'Sparse, No Identity', 'Sparse, No Other']
         colors = sns.color_palette()[:diff_df.shape[0]] if diff_df.shape[0] >= 2 else sns.color_palette()[0]
-        barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', color=colors, ax=ax[0])
+        barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', palette=colors, ax=ax[0])
         # noinspection PyUnboundLocalVariable
         for idx, patch in enumerate(barplot.patches):
             if idx != 0:
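Background on the `color=` to `palette=` switch above (general seaborn behaviour, not repo-specific): `color=` paints every bar the same single colour, while `palette=` accepts a list and colours the bars per x-category. A standalone sketch:

```python
import seaborn as sns

sns.barplot(x=['full', 'no_identity', 'no_other'], y=[0.9, 0.7, 0.5],
            palette=sns.color_palette()[:3])  # one colour per bar
```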
@@ -366,7 +373,7 @@ if __name__ == '__main__':
         ax[0].set_xlabel('Accuracy')
         # ax[0].legend()
-        ax[1].pie(counter_dict.values(), labels=counter_dict.keys(), colors=sns.color_palette()[:3], )
+        ax[1].pie(counter_dict.values(), labels=counter_dict.keys(), colors=colors, )
         ax[1].set_title('Particle Count for ')
         # ax[1].set_xlabel('')

Binary file not shown (new image, 98 KiB)

Binary file not shown (new image, 97 KiB)

Binary file not shown (new image, 198 KiB)

Binary file not shown (new image, 91 KiB)


@@ -3,20 +3,7 @@ from typing import Dict, List
 import torch
 from tqdm import tqdm

-from network import Net
+from network import FixTypes, Net


-class FixTypes:
-    divergent = 'divergent'
-    fix_zero = 'fix_zero'
-    identity_func = 'identity_func'
-    fix_sec = 'fix_sec'
-    other_func = 'other_func'
-
-    @classmethod
-    def all_types(cls):
-        return [val for key, val in cls.__dict__.items() if isinstance(val, str) and not key.startswith('_')]
-
-
 def is_divergent(network: Net) -> bool:

helpers.py (new, empty file)


@@ -15,6 +15,18 @@ from tqdm import tqdm
 def prng():
     return random.random()


+class FixTypes:
+    divergent = 'divergent'
+    fix_zero = 'fix_zero'
+    identity_func = 'identity_func'
+    fix_sec = 'fix_sec'
+    other_func = 'other_func'
+
+    @classmethod
+    def all_types(cls):
+        return [val for key, val in cls.__dict__.items() if isinstance(val, str) and not key.startswith('_')]
+
+
 class Net(nn.Module):
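For reference, with FixTypes now living in network.py, `all_types()` simply collects the public string attributes of the class in definition order:

```python
print(FixTypes.all_types())
# ['divergent', 'fix_zero', 'identity_func', 'fix_sec', 'other_func']
```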
@@ -79,7 +91,7 @@ class Net(nn.Module):
         self.trained = False
         self.number_trained = 0
-        self.is_fixpoint = ""
+        self.is_fixpoint = FixTypes.other_func
         self.layers = nn.ModuleList(
             [nn.Linear(i_size, h_size, False),
              nn.Linear(h_size, h_size, False),