smaller task train
@@ -17,7 +17,7 @@ from torch.nn import Flatten
 from torch.utils.data import Dataset, DataLoader
 from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize
-from tqdm import tqdm, trange
+from tqdm import tqdm

 # noinspection DuplicatedCode
 if platform.node() == 'CarbonX':
@@ -231,7 +231,7 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
             plt.show()
         else:
             plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
-        tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0]}')
+        tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0] // 2}')
     else:
         tqdm.write(f'No Connectivity {fixtype}')

@@ -260,22 +260,17 @@ def run_particle_dropout_test(model_path):
     return diff_store_path


-def plot_dropout_stacked_barplot(model_path):
-    diff_store_path = model_path.parent / 'diff_store.csv'
+def plot_dropout_stacked_barplot(mdl_path):
+    diff_store_path = mdl_path.parent / 'diff_store.csv'
     diff_df = pd.read_csv(diff_store_path)
     particle_dict = defaultdict(lambda: 0)
-    latest_model = torch.load(model_path, map_location=DEVICE).eval()
+    latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
     _ = test_for_fixpoints(particle_dict, list(latest_model.particles))
     tqdm.write(str(dict(particle_dict)))
     plt.clf()
     fig, ax = plt.subplots(ncols=2)
     colors = sns.color_palette()[1:diff_df.shape[0]+1]
-    barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
-    # noinspection PyUnboundLocalVariable
-    # for idx, patch in enumerate(barplot.patches):
-    #     if idx != 0:
-    #         # we recenter the bar
-    #         patch.set_x(patch.get_x() + idx * 0.035)
+    _ = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)

     ax[0].set_title('Accuracy after particle dropout')
     ax[0].set_xlabel('Particle Type')
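The `particle_dict = defaultdict(lambda: 0)` line above lets `test_for_fixpoints` increment per-type counters without initializing any keys first. A minimal standalone sketch of that counting idiom (the fixtype labels here are made up, not the project's):

from collections import defaultdict

particle_dict = defaultdict(lambda: 0)  # unknown keys start at 0
for fixtype in ['identity_func', 'other_func', 'identity_func']:
    particle_dict[fixtype] += 1

print(dict(particle_dict))  # {'identity_func': 2, 'other_func': 1}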
@@ -299,12 +294,12 @@ def flat_for_store(parameters):
     return (x.item() for y in parameters for x in y.detach().flatten())


-def train_self_replication(model, optimizer, st_steps) -> dict:
-    for _ in range(st_steps):
+def train_self_replication(model, optimizer, st_stps) -> dict:
+    for _ in range(st_stps):
         self_train_loss = model.combined_self_train(optimizer)
     # noinspection PyUnboundLocalVariable
-    step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
-    return step_log
+    stp_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
+    return stp_log


def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
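The `flat_for_store` helper in the hunk above lazily yields every weight of a particle as a Python float, so a whole parameter set can be spread into one CSV row with `*flat_for_store(...)`. A self-contained sketch of how it behaves (the `nn.Linear` toy layer is illustrative only, not a project particle):

import torch
from torch import nn

def flat_for_store(parameters):
    # lazily yield each scalar of each parameter tensor as a Python float
    return (x.item() for y in parameters for x in y.detach().flatten())

layer = nn.Linear(2, 2, bias=False)  # toy stand-in for a particle
row = ('epoch_0', 'particle_0', *flat_for_store(layer.parameters()))
print(row)  # ('epoch_0', 'particle_0', w00, w01, w10, w11)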
@@ -346,6 +341,7 @@ if __name__ == '__main__':

+    assert not (train_to_task_first and train_to_id_first)

     # noinspection PyUnresolvedReferences
     ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
     res_str = f'{"" if residual_skip else "_no_res"}'
     # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
@@ -364,14 +360,14 @@ if __name__ == '__main__':
         for seed in range(n_seeds):
             seed_path = exp_path / str(seed)

-            model_path = seed_path / '0000_trained_model.zip'
+            model_save_path = seed_path / '0000_trained_model.zip'
             df_store_path = seed_path / 'train_store.csv'
             weight_store_path = seed_path / 'weight_store.csv'
             srnn_parameters = dict()

             if training:
                 # Check if files do exist on project location, warn and break.
-                for path in [model_path, df_store_path, weight_store_path]:
+                for path in [model_save_path, df_store_path, weight_store_path]:
                     assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'

                 utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
@@ -488,8 +484,10 @@ if __name__ == '__main__':
                 for particle in dense_metanet.particles:
                     weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
                     weight_store.loc[weight_store.shape[0]] = weight_log
-                train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-                weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+                train_store.to_csv(df_store_path, mode='a',
+                                   header=not df_store_path.exists(), index=False)
+                weight_store.to_csv(weight_store_path, mode='a',
+                                    header=not weight_store_path.exists(), index=False)
                 train_store = new_storage_df('train', None)
                 weight_store = new_storage_df('weights', dense_metanet.particle_parameter_count)

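The reflowed `to_csv` calls rely on the append-with-header-once idiom: `mode='a'` appends rows, and `header=not path.exists()` writes the column header only on the very first flush. A standalone sketch (the file name is hypothetical):

from pathlib import Path
import pandas as pd

store_path = Path('train_store_demo.csv')  # hypothetical file
log = pd.DataFrame([dict(Metric='Self Train Loss', Score=0.5)])

# First call creates the file with a header; later calls only append rows.
log.to_csv(store_path, mode='a', header=not store_path.exists(), index=False)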
@@ -518,7 +516,7 @@ if __name__ == '__main__':
         plot_training_particle_types(df_store_path)

         try:
-            model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
+            _ = next(seed_path.glob(f'*e{EPOCH}.tp'))
         except StopIteration:
             print('Model pattern did not trigger.')
             print(f'Search path was: {seed_path}:')
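The `try/except StopIteration` above works because `next()` on an exhausted generator such as `Path.glob` raises `StopIteration` when nothing matches. A standalone sketch with hypothetical values for the search directory and `EPOCH`:

from pathlib import Path

seed_path = Path('.')  # hypothetical search directory
EPOCH = 100            # hypothetical epoch count

try:
    _ = next(seed_path.glob(f'*e{EPOCH}.tp'))
except StopIteration:
    print('Model pattern did not trigger.')
    print(f'Search path was: {seed_path}:')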
@@ -530,7 +528,7 @@ if __name__ == '__main__':
         except ValueError as e:
             print(e)
         try:
-            plot_network_connectivity_by_fixtype(model_path)
+            plot_network_connectivity_by_fixtype(model_save_path)
         except ValueError as e:
             print(e)

@@ -99,16 +99,16 @@ if __name__ == '__main__':
     train_to_task_first = False
     seq_task_train = True
     force_st_for_epochs_n = 5
-    n_st_per_batch = 10
+    n_st_per_batch = 2
     activation = None  # nn.ReLU()

     use_sparse_network = False

-    for weight_hidden_size in [3, 4]:
+    for weight_hidden_size in [3, 4, 5, 6]:

         tsk_threshold = 0.85
         weight_hidden_size = weight_hidden_size
-        residual_skip = False
+        residual_skip = True
         n_seeds = 3
         depth = 3
         width = 3
@@ -167,9 +167,9 @@ if __name__ == '__main__':
         sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)

         loss_fn = nn.MSELoss()
-        dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.004, momentum=0.9)
+        dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.00004, momentum=0.9)
         sparse_optimizer = torch.optim.SGD(
-            sparse_metanet.parameters(), lr=0.001, momentum=0.9
+            sparse_metanet.parameters(), lr=0.00001, momentum=0.9
         ) if use_sparse_network else dense_optimizer

         dense_weights_updated = False
@@ -212,6 +212,7 @@ if __name__ == '__main__':
                     sparse_weights_updated = True

                 # Task Train
+                init_st = True
                 if not init_st:
                     # Transfer weights
                     if sparse_weights_updated:
@@ -231,7 +232,7 @@ if __name__ == '__main__':
                         sparse_weights_updated = False

                 dense_metanet = dense_metanet.eval()
-                if do_tsk_train:
+                if not init_st:
                     validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                           Metric='Train Accuracy', Score=metric.compute().item())
                     train_store.loc[train_store.shape[0]] = validation_log
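`train_store.loc[train_store.shape[0]] = validation_log` appends a record in place: assigning to a `.loc` index one past the current end adds a new row, and pandas aligns the dict keys to the column names. A minimal sketch with hypothetical values:

import pandas as pd

train_store = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
validation_log = dict(Epoch=0, Batch=32, Metric='Train Accuracy', Score=0.91)  # hypothetical

train_store.loc[train_store.shape[0]] = validation_log  # append one row in place
print(train_store)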
@@ -10,7 +10,7 @@ from torch import nn
 from torch.utils.data import Dataset, DataLoader
 import numpy as np
 import seaborn as sns
-from tqdm import trange
+from tqdm import trange, tqdm
 from tqdm.contrib import tenumerate


@@ -53,7 +53,7 @@ class MultiplyByXTaskDataset(Dataset):


 if __name__ == '__main__':
-    net = Net(5, 1, 1)
+    net = Net(5, 4, 1)
     multiplication_target = 0.03

     loss_fn = nn.MSELoss()
@@ -68,6 +68,9 @@ if __name__ == '__main__':
     mean_self_tain_loss = []
     for batch, (batch_x, batch_y) in tenumerate(dataloader):
+        self_train_loss, _ = net.self_train(2, save_history=False)
+        is_fixpoint = functionalities_test.is_zero_fixpoint(net)
+
         optimizer.zero_grad()
         batch_x_emb = torch.zeros(batch_x.shape[0], 5)
         batch_x_emb[:, -1] = batch_x.squeeze()
         y = net(batch_x_emb)
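The `batch_x_emb` lines above embed the scalar task input into the last of the net's five input slots, so the same 5-input network can serve both the self-replication input format and the multiply-by-x task. A standalone shape check (the batch size is hypothetical):

import torch

batch_x = torch.rand(8, 1)  # hypothetical batch of scalar task inputs
batch_x_emb = torch.zeros(batch_x.shape[0], 5)
batch_x_emb[:, -1] = batch_x.squeeze()  # only the last feature carries the input

print(batch_x_emb.shape)  # torch.Size([8, 5])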
@@ -75,6 +78,9 @@ if __name__ == '__main__':

         loss.backward()
         optimizer.step()
+        if is_fixpoint:
+            tqdm.write(f'is fixpoint after st : {is_fixpoint}')
+            tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_zero_fixpoint(net)}')

         mean_batch_loss.append(loss.detach())
         mean_self_tain_loss.append(self_train_loss.detach())
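`functionalities_test.is_zero_fixpoint` is project code not shown in this diff. A plausible reading, offered purely as a hypothetical re-implementation rather than the project's actual check, is that a self-training net sits at the zero fixpoint once all of its weights have collapsed to (near) zero:

import torch
from torch import nn

def is_zero_fixpoint(net: nn.Module, eps: float = 1e-5) -> bool:
    # Hypothetical sketch, not the project's implementation:
    # true when every parameter tensor is numerically zero.
    return all(torch.allclose(p.detach(), torch.zeros_like(p), atol=eps)
               for p in net.parameters())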