smaller task train
@@ -17,7 +17,7 @@ from torch.nn import Flatten
 from torch.utils.data import Dataset, DataLoader
 from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize
-from tqdm import tqdm, trange
+from tqdm import tqdm

 # noinspection DuplicatedCode
 if platform.node() == 'CarbonX':
@@ -231,7 +231,7 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
 plt.show()
 else:
 plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
-tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0]}')
+tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["type"] == fixtype].shape[0] // 2}')
 else:
 tqdm.write(f'No Connectivity {fixtype}')

@@ -260,22 +260,17 @@ def run_particle_dropout_test(model_path):
 return diff_store_path


-def plot_dropout_stacked_barplot(model_path):
-diff_store_path = model_path.parent / 'diff_store.csv'
+def plot_dropout_stacked_barplot(mdl_path):
+diff_store_path = mdl_path.parent / 'diff_store.csv'
 diff_df = pd.read_csv(diff_store_path)
 particle_dict = defaultdict(lambda: 0)
-latest_model = torch.load(model_path, map_location=DEVICE).eval()
+latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
 _ = test_for_fixpoints(particle_dict, list(latest_model.particles))
 tqdm.write(str(dict(particle_dict)))
 plt.clf()
 fig, ax = plt.subplots(ncols=2)
 colors = sns.color_palette()[1:diff_df.shape[0]+1]
-barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
-# noinspection PyUnboundLocalVariable
-#for idx, patch in enumerate(barplot.patches):
-# if idx != 0:
-# # we recenter the bar
-# patch.set_x(patch.get_x() + idx * 0.035)
+_ = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)

 ax[0].set_title('Accuracy after particle dropout')
 ax[0].set_xlabel('Particle Type')
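
The particle_dict that test_for_fixpoints fills here is a plain counting dict; a minimal, standalone sketch of that defaultdict pattern (the labels below are made up for illustration, not the project's real fixpoint types):

from collections import defaultdict

# Missing keys default to 0 via the lambda factory, so counting needs no setup.
counts = defaultdict(lambda: 0)
for label in ['identity_func', 'other_func', 'identity_func']:  # hypothetical labels
    counts[label] += 1

print(dict(counts))  # {'identity_func': 2, 'other_func': 1}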
@@ -299,12 +294,12 @@ def flat_for_store(parameters):
 return (x.item() for y in parameters for x in y.detach().flatten())


-def train_self_replication(model, optimizer, st_steps) -> dict:
-for _ in range(st_steps):
+def train_self_replication(model, optimizer, st_stps) -> dict:
+for _ in range(st_stps):
 self_train_loss = model.combined_self_train(optimizer)
 # noinspection PyUnboundLocalVariable
-step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
-return step_log
+stp_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
+return stp_log


 def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
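
For context, the flat_for_store generator in this hunk turns every parameter tensor into a flat stream of Python floats; a small self-contained check of the same expression on a throwaway layer:

import torch
from torch import nn

def flat_for_store(parameters):
    # One float per scalar parameter, detached from the autograd graph.
    return (x.item() for y in parameters for x in y.detach().flatten())

tiny = nn.Linear(2, 1)                                # throwaway module: 2 weights + 1 bias
print(len(list(flat_for_store(tiny.parameters()))))   # 3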
@@ -346,6 +341,7 @@ if __name__ == '__main__':

 assert not (train_to_task_first and train_to_id_first)

+# noinspection PyUnresolvedReferences
 ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
 res_str = f'{"" if residual_skip else "_no_res"}'
 # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
@@ -364,14 +360,14 @@ if __name__ == '__main__':
 for seed in range(n_seeds):
 seed_path = exp_path / str(seed)

-model_path = seed_path / '0000_trained_model.zip'
+model_save_path = seed_path / '0000_trained_model.zip'
 df_store_path = seed_path / 'train_store.csv'
 weight_store_path = seed_path / 'weight_store.csv'
 srnn_parameters = dict()

 if training:
 # Check if files do exist on project location, warn and break.
-for path in [model_path, df_store_path, weight_store_path]:
+for path in [model_save_path, df_store_path, weight_store_path]:
 assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'

 utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
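
The utility_transforms chain resizes each MNIST digit to 15x15 and flattens it into a 225-vector; a quick shape check with a blank stand-in image, leaving out the project-specific ToFloat transform (assumes a reasonably recent torchvision whose Resize accepts tensors):

import torch
from torch.nn import Flatten
from torchvision.transforms import Compose, Resize, ToTensor
from PIL import Image

transforms = Compose([ToTensor(), Resize((15, 15)), Flatten(start_dim=0)])

dummy = Image.new('L', (28, 28))   # blank stand-in for an MNIST digit
print(transforms(dummy).shape)     # torch.Size([225])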
@@ -488,8 +484,10 @@ if __name__ == '__main__':
 for particle in dense_metanet.particles:
 weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
 weight_store.loc[weight_store.shape[0]] = weight_log
-train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+train_store.to_csv(df_store_path, mode='a',
+header=not df_store_path.exists(), index=False)
+weight_store.to_csv(weight_store_path, mode='a',
+header=not weight_store_path.exists(), index=False)
 train_store = new_storage_df('train', None)
 weight_store = new_storage_df('weights', dense_metanet.particle_parameter_count)

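
The rewrapped to_csv calls use the usual append-with-header-once idiom: write the header only when the file does not exist yet, so repeated appends do not duplicate it. A minimal standalone sketch with a throwaway path:

from pathlib import Path
import pandas as pd

log_path = Path('/tmp/train_store_demo.csv')   # illustrative path, not the repo's

for epoch in range(3):
    chunk = pd.DataFrame([dict(Epoch=epoch, Metric='Self Train Loss', Score=0.1 * epoch)])
    # Append each chunk; the header is written only on the first pass.
    chunk.to_csv(log_path, mode='a', header=not log_path.exists(), index=False)

print(pd.read_csv(log_path).shape)  # (3, 3) after a fresh run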
@@ -518,7 +516,7 @@ if __name__ == '__main__':
 plot_training_particle_types(df_store_path)

 try:
-model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
+_ = next(seed_path.glob(f'*e{EPOCH}.tp'))
 except StopIteration:
 print('Model pattern did not trigger.')
 print(f'Search path was: {seed_path}:')
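
The next(seed_path.glob(...)) idiom grabs the first file matching a pattern and raises StopIteration when nothing matches, which is exactly what the except branch reports; a small sketch with an illustrative directory:

from pathlib import Path

search_dir = Path('/tmp/some_seed_dir')   # illustrative path
try:
    # glob() is lazy; next() returns the first match or raises StopIteration.
    checkpoint = next(search_dir.glob('*e100.tp'))
    print(f'Found: {checkpoint}')
except StopIteration:
    print(f'Model pattern did not trigger. Search path was: {search_dir}')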
@@ -530,7 +528,7 @@ if __name__ == '__main__':
 except ValueError as e:
 print(e)
 try:
-plot_network_connectivity_by_fixtype(model_path)
+plot_network_connectivity_by_fixtype(model_save_path)
 except ValueError as e:
 print(e)

@@ -99,16 +99,16 @@ if __name__ == '__main__':
 train_to_task_first = False
 seq_task_train = True
 force_st_for_epochs_n = 5
-n_st_per_batch = 10
+n_st_per_batch = 2
 activation = None # nn.ReLU()

 use_sparse_network = False

-for weight_hidden_size in [3, 4]:
+for weight_hidden_size in [3, 4, 5, 6]:

 tsk_threshold = 0.85
 weight_hidden_size = weight_hidden_size
-residual_skip = False
+residual_skip = True
 n_seeds = 3
 depth = 3
 width = 3
@@ -167,9 +167,9 @@ if __name__ == '__main__':
 sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)

 loss_fn = nn.MSELoss()
-dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.004, momentum=0.9)
+dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.00004, momentum=0.9)
 sparse_optimizer = torch.optim.SGD(
-sparse_metanet.parameters(), lr=0.001, momentum=0.9
+sparse_metanet.parameters(), lr=0.00001, momentum=0.9
 ) if use_sparse_network else dense_optimizer

 dense_weights_updated = False
@@ -212,6 +212,7 @@ if __name__ == '__main__':
 sparse_weights_updated = True

 # Task Train
+init_st = True
 if not init_st:
 # Transfer weights
 if sparse_weights_updated:
@@ -231,7 +232,7 @@ if __name__ == '__main__':
 sparse_weights_updated = False

 dense_metanet = dense_metanet.eval()
-if do_tsk_train:
+if not init_st:
 validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
 Metric='Train Accuracy', Score=metric.compute().item())
 train_store.loc[train_store.shape[0]] = validation_log
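
The metric.compute() call follows the accumulate-then-compute pattern of torchmetrics-style metrics; a minimal sketch with a stand-in multiclass accuracy (assuming a recent torchmetrics; the script's actual metric object is created elsewhere):

import torch
import torchmetrics

metric = torchmetrics.Accuracy(task='multiclass', num_classes=10)  # stand-in metric

for _ in range(3):                       # pretend batches
    preds = torch.randn(8, 10)           # logits
    target = torch.randint(0, 10, (8,))
    metric.update(preds, target)         # accumulate per batch

print(metric.compute().item())           # accuracy over everything seen so far
metric.reset()                           # start fresh for the next epoch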
@@ -10,7 +10,7 @@ from torch import nn
 from torch.utils.data import Dataset, DataLoader
 import numpy as np
 import seaborn as sns
-from tqdm import trange
+from tqdm import trange, tqdm
 from tqdm.contrib import tenumerate


@@ -53,7 +53,7 @@ class MultiplyByXTaskDataset(Dataset):


 if __name__ == '__main__':
-net = Net(5, 1, 1)
+net = Net(5, 4, 1)
 multiplication_target = 0.03

 loss_fn = nn.MSELoss()
@@ -68,6 +68,9 @@ if __name__ == '__main__':
 mean_self_tain_loss = []
 for batch, (batch_x, batch_y) in tenumerate(dataloader):
 self_train_loss, _ = net.self_train(2, save_history=False)
+is_fixpoint = functionalities_test.is_zero_fixpoint(net)

+optimizer.zero_grad()
 batch_x_emb = torch.zeros(batch_x.shape[0], 5)
 batch_x_emb[:, -1] = batch_x.squeeze()
 y = net(batch_x_emb)
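
The two batch_x_emb lines pad each scalar input into the net's 5-dimensional input vector, with the value placed in the last slot; a quick standalone shape check:

import torch

batch_x = torch.rand(8, 1)                      # stand-in batch of scalar inputs
batch_x_emb = torch.zeros(batch_x.shape[0], 5)  # 5-dim input expected by Net(5, ...)
batch_x_emb[:, -1] = batch_x.squeeze()          # value goes into the last column

print(batch_x_emb.shape)   # torch.Size([8, 5])
print(batch_x_emb[0])      # tensor([0., 0., 0., 0., x0])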
@@ -75,6 +78,9 @@ if __name__ == '__main__':

 loss.backward()
 optimizer.step()
+if is_fixpoint:
+tqdm.write(f'is fixpoint after st : {is_fixpoint}')
+tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_zero_fixpoint(net)}')

 mean_batch_loss.append(loss.detach())
 mean_self_tain_loss.append(self_train_loss.detach())
@@ -221,7 +221,7 @@ if __name__ == "__main__":
 ST_steps = 1000
 ST_epochs = 5
 ST_log_step_size = 10
-ST_population_size = 1000
+ST_population_size = 10
 ST_net_hidden_size = 2
 ST_net_learning_rate = 0.004
 ST_name_hash = random.getrandbits(32)

network.py (13 changed lines)
@@ -45,7 +45,8 @@ class Net(nn.Module):
 # target_weight_matrix[i] = input_weight_matrix[i][0]

 # Fast and simple
-return input_weight_matrix[:, 0].unsqueeze(-1)
+target_weights = input_weight_matrix[:, 0].detach().unsqueeze(-1)
+return target_weights


 @staticmethod
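
The replacement lines take the first column of the (N, k) input matrix, cut it out of the autograd graph, and restore a (N, 1) column shape; the same steps in plain torch:

import torch

input_weight_matrix = torch.rand(6, 4, requires_grad=True)  # stand-in (N, k) matrix
target_weights = input_weight_matrix[:, 0].detach().unsqueeze(-1)

print(target_weights.shape)          # torch.Size([6, 1])
print(target_weights.requires_grad)  # False: no gradient flows back into the inputs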
@@ -132,6 +133,7 @@ class Net(nn.Module):
 # Normalize 1,2,3 column of dim 1
 last_pos_idx = self.input_size - 4
 max_per_col, _ = weight_matrix[:, 1:-last_pos_idx].max(keepdim=True, dim=0)
+max_per_col += 1e-8
 weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / max_per_col) + 1e-8

 # computations
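
Adding 1e-8 to the column-wise maximum is the standard guard against 0/0 when an entire column is zero; a tiny numeric check of the pattern:

import torch

weight_matrix = torch.tensor([[0.0, 0.0],
                              [0.5, 0.0]])   # second column is all zeros

max_per_col, _ = weight_matrix.max(keepdim=True, dim=0)
max_per_col += 1e-8                          # without this, 0 / 0 would give nan

print(weight_matrix / max_per_col)           # zero column stays 0 instead of nan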
@@ -139,7 +141,7 @@ class Net(nn.Module):
 mask = torch.ones_like(weight_matrix)
 mask[:, 0] = 0

-self._weight_pos_enc_and_mask = weight_matrix, mask
+self._weight_pos_enc_and_mask = weight_matrix.detach(), mask.detach()
 return self._weight_pos_enc_and_mask

 def forward(self, x):
@@ -160,10 +162,11 @@ class Net(nn.Module):

 def input_weight_matrix(self) -> Tensor:
 """ Calculating the input tensor formed from the weights of the net """
+with torch.no_grad():
 weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
 pos_enc, mask = self._weight_pos_enc
 weight_matrix = pos_enc * mask + weight_matrix.expand(-1, pos_enc.shape[-1]) * (1 - mask)
-return weight_matrix
+return weight_matrix.detach()

 def target_weight_matrix(self) -> Tensor:
 weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
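
Both additions in this hunk serve the same goal: the matrix the net feeds back to itself should carry no autograd history. A short sketch of how torch.no_grad() and .detach() affect gradient tracking:

import torch
from torch import nn

layer = nn.Linear(2, 2)

with torch.no_grad():
    flat = torch.cat([p.view(-1, 1) for p in layer.parameters()])
print(flat.requires_grad)                    # False: built without a graph

flat_tracked = torch.cat([p.view(-1, 1) for p in layer.parameters()])
print(flat_tracked.requires_grad)            # True
print(flat_tracked.detach().requires_grad)   # False: detach() cuts it from the graph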
@@ -204,10 +207,11 @@ class Net(nn.Module):
 self.s_train_weights_history.append(weights.T.detach().numpy())
 self.loss_history.append(loss.item())

-weights = self.create_target_weights(self.input_weight_matrix())
 # Saving weights only at the end of a soup/mixed exp. epoch.
 if save_history:
 if "soup" in self.name or "mixed" in self.name:
+weights = self.create_target_weights(self.input_weight_matrix())
 self.s_train_weights_history.append(weights.T.detach().numpy())
 self.loss_history.append(loss.item())

@@ -462,6 +466,7 @@ class MetaNet(nn.Module):
 def combined_self_train(self, optimizer, reduction='mean'):
 optimizer.zero_grad()
 losses = []
+n = 10
 for particle in self.particles:
 # Intergrate optimizer and backward function
 input_data = particle.input_weight_matrix()
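
The hunk is cut off here, but its visible frame (a single zero_grad, one loss collected per particle) is the standard accumulate-then-step pattern; a generic, self-contained illustration with plain linear layers standing in for the particles (purely illustrative, not the repo's actual implementation):

import torch
from torch import nn

modules = [nn.Linear(4, 4) for _ in range(3)]   # stand-ins for the particles
optimizer = torch.optim.SGD((p for m in modules for p in m.parameters()), lr=0.01)

optimizer.zero_grad()
losses = []
for module in modules:
    x = torch.rand(8, 4)
    losses.append(nn.functional.mse_loss(module(x), torch.zeros(8, 4)))

loss = torch.stack(losses).mean()   # combine the per-module losses
loss.backward()                     # one backward pass covers all of them
optimizer.step()
print(loss.item())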