smaller task train

This commit is contained in:
Steffen Illium
2022-03-03 21:16:33 +01:00
parent e167cc78c5
commit 16c08d04d4
5 changed files with 47 additions and 37 deletions

View File

@@ -17,7 +17,7 @@ from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Resize
-from tqdm import tqdm, trange
+from tqdm import tqdm
# noinspection DuplicatedCode
if platform.node() == 'CarbonX':
@@ -231,7 +231,7 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
plt.show()
else:
plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
-tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0]}')
+tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0] // 2}')
else:
tqdm.write(f'No Connectivity {fixtype}')
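A minimal sketch of the changed log line: the reported n is now the matching row count halved, presumably because each connection contributes two rows to df, one per endpoint (that rationale is an assumption; the diff itself does not state it). The dataframe here is made up:

    import pandas as pd

    df = pd.DataFrame({'type': ['fix_zero'] * 4})  # hypothetical connectivity rows
    fixtype = 'fix_zero'
    n_rows = df[df['type'] == fixtype].shape[0]    # 4 matching rows
    print(f'n = {n_rows // 2}')                    # reports 2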
@@ -260,22 +260,17 @@ def run_particle_dropout_test(model_path):
return diff_store_path
-def plot_dropout_stacked_barplot(model_path):
-diff_store_path = model_path.parent / 'diff_store.csv'
+def plot_dropout_stacked_barplot(mdl_path):
+diff_store_path = mdl_path.parent / 'diff_store.csv'
diff_df = pd.read_csv(diff_store_path)
particle_dict = defaultdict(lambda: 0)
-latest_model = torch.load(model_path, map_location=DEVICE).eval()
+latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
_ = test_for_fixpoints(particle_dict, list(latest_model.particles))
tqdm.write(str(dict(particle_dict)))
plt.clf()
fig, ax = plt.subplots(ncols=2)
colors = sns.color_palette()[1:diff_df.shape[0]+1]
-barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
-# noinspection PyUnboundLocalVariable
-#for idx, patch in enumerate(barplot.patches):
-# if idx != 0:
-# # we recenter the bar
-# patch.set_x(patch.get_x() + idx * 0.035)
+_ = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
ax[0].set_title('Accuracy after particle dropout')
ax[0].set_xlabel('Particle Type')
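The simplification drops the manual bar-recentering loop and discards the return value. A self-contained sketch of the surviving call, with made-up accuracy values and particle-type labels:

    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt

    diff_df = pd.DataFrame({'Particle Type': ['fix_zero', 'fix_other', 'fix_sec'],
                            'Accuracy': [0.91, 0.87, 0.83]})
    colors = sns.color_palette()[1:diff_df.shape[0] + 1]  # skip the first palette color
    fig, ax = plt.subplots(ncols=2)
    _ = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
    ax[0].set_title('Accuracy after particle dropout')
    plt.show()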
@@ -299,12 +294,12 @@ def flat_for_store(parameters):
return (x.item() for y in parameters for x in y.detach().flatten())
-def train_self_replication(model, optimizer, st_steps) -> dict:
-for _ in range(st_steps):
+def train_self_replication(model, optimizer, st_stps) -> dict:
+for _ in range(st_stps):
self_train_loss = model.combined_self_train(optimizer)
# noinspection PyUnboundLocalVariable
-step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
-return step_log
+stp_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
+return stp_log
def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
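The renamed train_self_replication keeps its original shape: only the loss of the last self-train step ends up in the returned log, which is why the unbound-variable inspection is suppressed. A runnable sketch with a stub model (combined_self_train here is a stand-in, not the repo's implementation):

    import torch

    class StubModel:
        def combined_self_train(self, optimizer):
            return torch.tensor(0.1)  # stand-in self-train loss

    def train_self_replication(model, optimizer, st_stps) -> dict:
        for _ in range(st_stps):
            self_train_loss = model.combined_self_train(optimizer)
        # only the final step's loss is logged
        return dict(Metric='Self Train Loss', Score=self_train_loss.item())

    print(train_self_replication(StubModel(), None, 3))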
@@ -346,6 +341,7 @@ if __name__ == '__main__':
+assert not (train_to_task_first and train_to_id_first)
# noinspection PyUnresolvedReferences
ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
res_str = f'{"" if residual_skip else "_no_res"}'
# dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
@@ -364,14 +360,14 @@ if __name__ == '__main__':
for seed in range(n_seeds):
seed_path = exp_path / str(seed)
-model_path = seed_path / '0000_trained_model.zip'
+model_save_path = seed_path / '0000_trained_model.zip'
df_store_path = seed_path / 'train_store.csv'
weight_store_path = seed_path / 'weight_store.csv'
srnn_parameters = dict()
if training:
# Check if files do exist on project location, warn and break.
-for path in [model_path, df_store_path, weight_store_path]:
+for path in [model_save_path, df_store_path, weight_store_path]:
assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
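The rename threads through the pre-flight check that refuses to overwrite results from an earlier run. A standalone sketch of that guard (the folder name is hypothetical):

    from pathlib import Path

    seed_path = Path('experiments/0')  # hypothetical location
    model_save_path = seed_path / '0000_trained_model.zip'
    df_store_path = seed_path / 'train_store.csv'
    weight_store_path = seed_path / 'weight_store.csv'

    for path in [model_save_path, df_store_path, weight_store_path]:
        assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'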
@@ -488,8 +484,10 @@ if __name__ == '__main__':
for particle in dense_metanet.particles:
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
weight_store.loc[weight_store.shape[0]] = weight_log
-train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+train_store.to_csv(df_store_path, mode='a',
+                   header=not df_store_path.exists(), index=False)
+weight_store.to_csv(weight_store_path, mode='a',
+                    header=not weight_store_path.exists(), index=False)
train_store = new_storage_df('train', None)
weight_store = new_storage_df('weights', dense_metanet.particle_parameter_count)
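Only the line wrapping changed here, but the pattern is worth spelling out: append to the CSV, write the header only if the file does not exist yet, then replace the in-memory frame so memory stays bounded. A minimal sketch (column names mirror the diff, the row is made up):

    from pathlib import Path
    import pandas as pd

    df_store_path = Path('train_store.csv')
    train_store = pd.DataFrame([dict(Metric='Self Train Loss', Score=0.1)])
    train_store.to_csv(df_store_path, mode='a',
                       header=not df_store_path.exists(), index=False)
    train_store = train_store.iloc[0:0]  # fresh empty frame, same columns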
@@ -518,7 +516,7 @@ if __name__ == '__main__':
plot_training_particle_types(df_store_path)
try:
-model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
+_ = next(seed_path.glob(f'*e{EPOCH}.tp'))
except StopIteration:
print('Model pattern did not trigger.')
print(f'Search path was: {seed_path}:')
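The glob result is no longer bound to model_path (the later hunk replaces that name with model_save_path); the call now only probes whether a matching checkpoint exists. A standalone sketch of that probe, with hypothetical path and epoch values:

    from pathlib import Path

    seed_path, EPOCH = Path('experiments/0'), 100  # hypothetical values
    try:
        _ = next(seed_path.glob(f'*e{EPOCH}.tp'))
    except StopIteration:
        print('Model pattern did not trigger.')
        print(f'Search path was: {seed_path}:')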
@@ -530,7 +528,7 @@ if __name__ == '__main__':
except ValueError as e:
print(e)
try:
-plot_network_connectivity_by_fixtype(model_path)
+plot_network_connectivity_by_fixtype(model_save_path)
except ValueError as e:
print(e)

View File

@@ -99,16 +99,16 @@ if __name__ == '__main__':
train_to_task_first = False
seq_task_train = True
force_st_for_epochs_n = 5
-n_st_per_batch = 10
+n_st_per_batch = 2
activation = None # nn.ReLU()
use_sparse_network = False
-for weight_hidden_size in [3, 4]:
+for weight_hidden_size in [3, 4, 5, 6]:
tsk_threshold = 0.85
weight_hidden_size = weight_hidden_size
-residual_skip = False
+residual_skip = True
n_seeds = 3
depth = 3
width = 3
@@ -167,9 +167,9 @@ if __name__ == '__main__':
sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
loss_fn = nn.MSELoss()
-dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.004, momentum=0.9)
+dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.00004, momentum=0.9)
sparse_optimizer = torch.optim.SGD(
-sparse_metanet.parameters(), lr=0.001, momentum=0.9
+sparse_metanet.parameters(), lr=0.00001, momentum=0.9
) if use_sparse_network else dense_optimizer
dense_weights_updated = False
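Both learning rates drop by a factor of 100. A self-contained sketch of the resulting setup, with an nn.Linear standing in for the repo's MetaNet:

    import torch
    from torch import nn

    dense_metanet = nn.Linear(4, 4)  # stand-in module
    use_sparse_network = False
    dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.00004, momentum=0.9)
    sparse_optimizer = torch.optim.SGD(
        dense_metanet.parameters(), lr=0.00001, momentum=0.9
    ) if use_sparse_network else dense_optimizer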
@@ -212,6 +212,7 @@ if __name__ == '__main__':
sparse_weights_updated = True
# Task Train
+init_st = True
if not init_st:
# Transfer weights
if sparse_weights_updated:
@@ -231,7 +232,7 @@ if __name__ == '__main__':
sparse_weights_updated = False
dense_metanet = dense_metanet.eval()
-if do_tsk_train:
+if not init_st:
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
Metric='Train Accuracy', Score=metric.compute().item())
train_store.loc[train_store.shape[0]] = validation_log
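Accuracy logging is now gated on init_st rather than do_tsk_train: the validation row is written only for epochs that were not pure self-train epochs (that reading of init_st is inferred from the diff, not confirmed). A toy sketch of the gate, with an invented flag schedule, batch size, and score:

    BATCHSIZE = 64
    for epoch, init_st in enumerate([True, False]):
        if not init_st:
            validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                  Metric='Train Accuracy', Score=0.97)
            print(validation_log)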

View File

@@ -10,7 +10,7 @@ from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
-from tqdm import trange
+from tqdm import trange, tqdm
from tqdm.contrib import tenumerate
@@ -53,7 +53,7 @@ class MultiplyByXTaskDataset(Dataset):
if __name__ == '__main__':
-net = Net(5, 1, 1)
+net = Net(5, 4, 1)
multiplication_target = 0.03
loss_fn = nn.MSELoss()
@@ -68,6 +68,9 @@ if __name__ == '__main__':
mean_self_tain_loss = []
for batch, (batch_x, batch_y) in tenumerate(dataloader):
self_train_loss, _ = net.self_train(2, save_history=False)
+is_fixpoint = functionalities_test.is_zero_fixpoint(net)
optimizer.zero_grad()
batch_x_emb = torch.zeros(batch_x.shape[0], 5)
batch_x_emb[:, -1] = batch_x.squeeze()
y = net(batch_x_emb)
@@ -75,6 +78,9 @@ if __name__ == '__main__':
loss.backward()
optimizer.step()
+if is_fixpoint:
+tqdm.write(f'is fixpoint after st : {is_fixpoint}')
+tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_zero_fixpoint(net)}')
mean_batch_loss.append(loss.detach())
mean_self_tain_loss.append(self_train_loss.detach())
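The new logging asks whether the net is a zero fixpoint before and after the task step. A hedged stand-in for functionalities_test.is_zero_fixpoint, assuming it flags a network whose parameters are all numerically zero (the repo's actual test may differ):

    import torch
    from torch import nn

    def is_zero_fixpoint(net: nn.Module, atol: float = 1e-8) -> bool:
        # zero fixpoint here: every parameter is (numerically) zero
        return all(torch.allclose(p, torch.zeros_like(p), atol=atol)
                   for p in net.parameters())

    net = nn.Linear(5, 1)
    nn.init.zeros_(net.weight)
    nn.init.zeros_(net.bias)
    print(is_zero_fixpoint(net))  # True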