# self-replicating-neural-net.../meta_task_exp_small.py
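"""Small-scale meta-task experiment.

Jointly self-trains the particle networks of a MetaNet while training the
surrounding network on the AddTask regression problem. Per-batch losses,
per-epoch fixpoint counts and particle weights are logged to CSV; afterwards
the usual per-seed plots (training curves, particle types, dropout,
connectivity, 3D weight trajectories) and an optional robustness test are run.
"""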
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
import torchmetrics
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from experiments.meta_task_small_utility import AddTaskDataset, train_task
from experiments.robustness_tester import test_robustness
from network import MetaNet
from functionalities_test import test_for_fixpoints, FixTypes as ft
from experiments.meta_task_utility import (new_storage_df, flat_for_store, plot_training_result,
                                           plot_training_particle_types, run_particle_dropout_and_plot,
                                           plot_network_connectivity_by_fixtype, checkpoint_and_validate,
                                           plot_training_results_over_n_seeds, sanity_weight_swap,
                                           FINAL_CHECKPOINT_NAME)
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
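
# Experiment constants. VALIDATION_FRQ is the validation/checkpoint interval
# in epochs; a metric instance is created once here only to read its name.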
WORKER = 0
BATCHSIZE = 50
EPOCH = 60
VALIDATION_FRQ = 3
VAL_METRIC_CLASS = torchmetrics.MeanAbsoluteError
# noinspection PyProtectedMember
VAL_METRIC_NAME = VAL_METRIC_CLASS()._get_name()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
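
# Loader used only for the post-hoc analyses (particle-dropout plot and the
# weight-swap sanity check), not for training.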
plot_loader = DataLoader(AddTaskDataset(), batch_size=BATCHSIZE, shuffle=True,
                         drop_last=True, num_workers=WORKER)

if __name__ == '__main__':
    # Run configuration: which stages to execute and the perturbation settings.
    training = False
    plotting = True
    robustness = True
    attack = False          # if True, particles attack each other every batch
    attack_ratio = 0.01
    melt = False            # if True, particles are melted every batch
    melt_ratio = 0.01
    n_st = 200              # total self-train steps per epoch
    activation = None       # e.g. nn.ReLU()
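
    # Sweep over the hidden size of the particle (weight) networks; only a
    # single value is configured here, but the loop keeps the sweep intact.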
    for weight_hidden_size in [3]:
        residual_skip = True
        n_seeds = 10
        depth = 3
        width = 3
        out = 1

        data_path = Path('data')
        data_path.mkdir(exist_ok=True, parents=True)

        # Encode the run settings into the experiment folder name.
        # noinspection PyUnresolvedReferences
        ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
        res_str = '' if residual_skip else '_no_res'
        att_str = f'_att_{attack_ratio}' if attack else ''
        mlt_str = f'_mlt_{melt_ratio}' if melt else ''
        w_str = f'_w{width}wh{weight_hidden_size}d{depth}'
        # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
        config_str = f'{res_str}{att_str}{ac_str}{mlt_str}{w_str}'
        exp_path = Path('output') / f'add_st_{EPOCH}{config_str}'
        # if not training:
        #     # noinspection PyRedeclaration
        #     exp_path = Path('output') / f'add_st_{n_st}_{weight_hidden_size}'
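
        # One independent run per seed; each seed gets its own folder holding
        # a training log (train_store.csv) and a weight log (weight_store.csv).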
        for seed in range(n_seeds):
            seed_path = exp_path / str(seed)
            df_store_path = seed_path / 'train_store.csv'
            weight_store_path = seed_path / 'weight_store.csv'

            valid_data = AddTaskDataset()
            vali_load = DataLoader(valid_data, batch_size=BATCHSIZE, shuffle=True,
                                   drop_last=True, num_workers=WORKER)

            if training:
                # Refuse to overwrite existing results for this seed.
                for path in [df_store_path, weight_store_path]:
                    assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'

                train_data = AddTaskDataset()
                train_load = DataLoader(train_data, batch_size=BATCHSIZE, shuffle=True,
                                        drop_last=True, num_workers=WORKER)

                interface = int(np.prod(train_data[0][0].shape))
                metanet = MetaNet(interface, depth=depth, width=width, out=out,
                                  residual_skip=residual_skip, weight_hidden_size=weight_hidden_size,
                                  activation=activation
                                  ).to(DEVICE)
                loss_fn = nn.MSELoss()
                optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)

                train_store = new_storage_df('train', None)
                weight_store = new_storage_df('weights', metanet.particle_parameter_count)
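
                # Main loop: each batch runs several self-train steps on the
                # particles, then one task-training step; validation epochs
                # additionally checkpoint, count fixpoints and flush the logs.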
                for epoch in tqdm(range(EPOCH), desc='Train - Epochs'):
                    is_validation_epoch = epoch % VALIDATION_FRQ == 0
                    metanet = metanet.train()
                    # (Re-)initialise the running metric, even in epochs where it is not read:
                    metric = VAL_METRIC_CLASS()
                    # Spread the n_st self-train steps evenly over the batches of one epoch.
                    n_st_per_batch = max(1, n_st // len(train_load))

                    for batch, (batch_x, batch_y) in tqdm(enumerate(train_load),
                                                          total=len(train_load), desc='MetaNet Train - Batch'):
                        # Self Train
                        self_train_loss = metanet.combined_self_train(n_st_per_batch,
                                                                      reduction='mean', per_particle=False)
                        # noinspection PyUnboundLocalVariable
                        st_step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
                        st_step_log.update(dict(Epoch=epoch, Batch=batch))
                        train_store.loc[train_store.shape[0]] = st_step_log
                        # Attack
                        if attack:
                            after_attack_loss = metanet.make_particles_attack(attack_ratio)
                            st_step_log = dict(Metric='After Attack Loss', Score=after_attack_loss.item())
                            st_step_log.update(dict(Epoch=epoch, Batch=batch))
                            train_store.loc[train_store.shape[0]] = st_step_log
                        # Melt
                        if melt:
                            after_melt_loss = metanet.make_particles_melt(melt_ratio)
                            st_step_log = dict(Metric='After Melt Loss', Score=after_melt_loss.item())
                            st_step_log.update(dict(Epoch=epoch, Batch=batch))
                            train_store.loc[train_store.shape[0]] = st_step_log
                        # Task Train
                        tsk_step_log, y_pred = train_task(metanet, optimizer, loss_fn, batch_x, batch_y)
                        tsk_step_log.update(dict(Epoch=epoch, Batch=batch))
                        train_store.loc[train_store.shape[0]] = tsk_step_log
                        metric(y_pred.cpu(), batch_y.cpu())
                    if is_validation_epoch:
                        metanet = metanet.eval()
                        if metric.total.item():
                            validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                                  Metric=f'Train {VAL_METRIC_NAME}', Score=metric.compute().item())
                            train_store.loc[train_store.shape[0]] = validation_log
                        mae = checkpoint_and_validate(metanet, vali_load, seed_path, epoch,
                                                      validation_metric=VAL_METRIC_CLASS).item()
                        validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                              Metric=f'Test {VAL_METRIC_NAME}', Score=mae)
                        train_store.loc[train_store.shape[0]] = validation_log

                        # Count fixpoint types into counter_dict; the return
                        # value (the ID functions) is not needed here.
                        counter_dict = defaultdict(lambda: 0)
                        _ = test_for_fixpoints(counter_dict, list(metanet.particles))
                        counter_dict = dict(counter_dict)
                        for key, value in counter_dict.items():
                            val_step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
                            train_store.loc[train_store.shape[0]] = val_step_log
                        tqdm.write(f'Fixpoint Tester Results: {counter_dict}')

                        # Flush the logs to disk and start fresh dataframes.
                        for particle in metanet.particles:
                            weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
                            weight_store.loc[weight_store.shape[0]] = weight_log
                        train_store.to_csv(df_store_path, mode='a',
                                           header=not df_store_path.exists(), index=False)
                        weight_store.to_csv(weight_store_path, mode='a',
                                            header=not weight_store_path.exists(), index=False)
                        train_store = new_storage_df('train', None)
                        weight_store = new_storage_df('weights', metanet.particle_parameter_count)

                ###########################################################
                # End of training: final fixpoint census, checkpoint and flush.
                metanet = metanet.eval()
                counter_dict = defaultdict(lambda: 0)
                _ = test_for_fixpoints(counter_dict, list(metanet.particles))
                for key, value in dict(counter_dict).items():
                    step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
                    train_store.loc[train_store.shape[0]] = step_log

                mae = checkpoint_and_validate(metanet, vali_load, seed_path, EPOCH, final_model=True,
                                              validation_metric=VAL_METRIC_CLASS)
                validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
                                      Metric=f'Test {VAL_METRIC_NAME}', Score=mae.item())
                train_store.loc[train_store.shape[0]] = validation_log

                for particle in metanet.particles:
                    weight_log = (EPOCH, particle.name, *flat_for_store(particle.parameters()))
                    weight_store.loc[weight_store.shape[0]] = weight_log

                train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
                weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
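
            # Post-training analysis for this seed: training curves, particle
            # types, dropout, connectivity and trajectory plots, plus an
            # optional robustness test of the final checkpoint.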
            if plotting or robustness:
                # Both branches below need the final checkpoint of this seed.
                try:
                    model_path = next(seed_path.glob(f'*{FINAL_CHECKPOINT_NAME}'))
                except StopIteration:
                    print('####################################################')
                    print('ERROR: No final checkpoint matched the model pattern.')
                    print(f'INFO: Search path was: {seed_path}')
                    print(f'INFO: Found models are: {list(seed_path.rglob("*.tp"))}')
                    print('####################################################')
                    exit(1)

            if plotting:
                plot_training_result(df_store_path, metric_name=VAL_METRIC_NAME)
                plot_training_particle_types(df_store_path)

                try:
                    run_particle_dropout_and_plot(model_path, valid_loader=plot_loader,
                                                  metric_class=VAL_METRIC_CLASS)
                except ValueError as e:
                    print('ERROR:', e)
                try:
                    plot_network_connectivity_by_fixtype(model_path)
                except ValueError as e:
                    print('ERROR:', e)
                try:
                    tqdm.write('Trajectory plotting ...')
                    plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
                    plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
                    plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
                    plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
                    tqdm.write('Trajectory plotting Done')
                except ValueError as e:
                    print('ERROR:', e)

            if robustness:
                try:
                    test_robustness(model_path, seeds=10)
                except ValueError as e:
                    print('ERROR:', e)

        # Once all seeds of this configuration have finished, aggregate.
        if 2 <= n_seeds == sum(x.is_dir() for x in exp_path.iterdir()):
            if plotting:
                plot_training_results_over_n_seeds(exp_path, metric_name=VAL_METRIC_NAME)
                sanity_weight_swap(exp_path, plot_loader, VAL_METRIC_CLASS)