Small fixes, new parameters

Steffen Illium
2022-02-25 15:32:56 +01:00
parent 5b2b5b5beb
commit 9d8496a725
5 changed files with 292 additions and 236 deletions

View File

@@ -45,8 +45,8 @@ from functionalities_test import test_for_fixpoints
 WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 500 if not debug else 50
-EPOCH = 50
-VALIDATION_FRQ = 3 if not debug else 1
+EPOCH = 100
+VALIDATION_FRQ = 4 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -56,6 +56,9 @@ if debug:
 class ToFloat:
+
+    def __init__(self):
+        pass

     def __call__(self, x):
         return x.to(torch.float32)
@@ -194,7 +197,7 @@ def plot_training_result(path_to_dataframe):
 def plot_network_connectivity_by_fixtype(path_to_trained_model):
     m = torch.load(path_to_trained_model, map_location=torch.device('cpu'))
     # noinspection PyProtectedMember
-    particles = [y for x in m._meta_layer_list for y in x.particles]
+    particles = list(m.particles)
     df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name'])
     for prtcl in particles:
@@ -210,10 +213,16 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
     for n, fixtype in enumerate([ft.other_func, ft.identity_func]):
         plt.clf()
         ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
-                          legend=False, estimator=None,
-                          palette=[sns.color_palette()[n]] * (df[df['type'] == fixtype].shape[0]//2), lw=1)
+                          legend=False, estimator=None, lw=1)
+        _ = sns.lineplot(y=[0, 1], x=[-1, df['layer'].max()], legend=False, estimator=None, lw=0)
         ax.set_title(fixtype)
-        plt.show()
+        lines = ax.get_lines()
+        for line in lines:
+            line.set_color(sns.color_palette()[n])
+        if debug:
+            plt.show()
+        else:
+            plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
     print('plottet')
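
The plotting change above drops the per-line palette argument and instead recolours every drawn line afterwards, saving the figure to disk unless debug is set. A minimal standalone sketch of that pattern, with dummy data and an illustrative file name (not from the commit):

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame({'layer': [0, 1, 2, 0, 1, 2],
                   'neuron': [0.1, 0.5, 0.2, 0.8, 0.4, 0.9],
                   'name': ['a', 'a', 'a', 'b', 'b', 'b']})
ax = sns.lineplot(data=df, x='layer', y='neuron', hue='name',
                  legend=False, estimator=None, lw=1)
for line in ax.get_lines():
    line.set_color(sns.color_palette()[0])   # one colour for the whole fixpoint group
plt.savefig('net_connectivity_example.png', dpi=300)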
@@ -234,7 +243,7 @@ def run_particle_dropout_test(run_path):
     tqdm.write(f'Zero_ident diff = {acc_diff}')
     diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)

-    diff_df.to_csv(diff_store_path, mode='a', header=not df_store_path.exists(), index=False)
+    diff_df.to_csv(diff_store_path, mode='a', header=not diff_store_path.exists(), index=False)
     return diff_store_path
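
The one-character fix above matters because the frame is written in append mode: the header should only be emitted when the target file does not exist yet, so the check has to test diff_store_path rather than df_store_path. A minimal sketch of that append-once-with-header pattern, with an illustrative path and columns:

from pathlib import Path
import pandas as pd

diff_store_path = Path('output') / 'experiment_dropout_diff.csv'   # illustrative path
diff_store_path.parent.mkdir(exist_ok=True, parents=True)

diff_df = pd.DataFrame([('identity_func', 0.91, 0.02)],
                       columns=['Particle Type', 'Accuracy', 'Diff'])
# Header is written exactly once: only when the file is first created.
diff_df.to_csv(diff_store_path, mode='a', header=not diff_store_path.exists(), index=False)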
@@ -246,18 +255,18 @@ def plot_dropout_stacked_barplot(path_to_diff_df):
     plt.clf()
     fig, ax = plt.subplots(ncols=2)
     colors = sns.color_palette()[:diff_df.shape[0]]
-    barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', palette=colors, ax=ax[0])
+    barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
     # noinspection PyUnboundLocalVariable
-    for idx, patch in enumerate(barplot.patches):
-        if idx != 0:
-            # we recenter the bar
-            patch.set_x(patch.get_x() + idx * 0.035)
+    # for idx, patch in enumerate(barplot.patches):
+    #     if idx != 0:
+    #         # we recenter the bar
+    #         patch.set_x(patch.get_x() + idx * 0.035)

     ax[0].set_title('Accuracy after particle dropout')
-    ax[0].set_xlabel('Accuracy')
+    ax[0].set_xlabel('Particle Type')
     ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=colors, )
-    ax[1].set_title('Particle Count for ')
+    ax[1].set_title('Particle Count')
     plt.tight_layout()
     if debug:
@@ -278,196 +287,202 @@ def flat_for_store(parameters):
 if __name__ == '__main__':

     self_train = True
-    training = True
-    train_to_id_first = False
+    training = False
+    train_to_id_first = True
     train_to_task_first = False
-    train_to_task_first_sequential = True
+    sequential_task_train = True
     force_st_for_n_from_last_epochs = 5
+    n_st_per_batch = 3
+    activation = None  # nn.ReLU()

-    use_sparse_network = False
+    use_sparse_network = True

-    tsk_threshold = 0.855
-    self_train_alpha = 1
-    batch_train_beta = 1
-    weight_hidden_size = 3
-    residual_skip = True
-    n_seeds = 5
-
-    data_path = Path('data')
-    data_path.mkdir(exist_ok=True, parents=True)
-    assert not (train_to_task_first and train_to_id_first)
-
-    st_str = f'{"" if self_train else "no_"}st'
-    a_str = f'_alpha_{self_train_alpha}' if self_train_alpha != 1 else ''
-    res_str = f'{"" if residual_skip else "_no_res"}'
-    # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
-    id_str = f'{f"_StToId" if train_to_id_first else ""}'
-    tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
-    sprs_str = '_sprs' if use_sparse_network else ''
-    f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
-        force_st_for_n_from_last_epochs and train_to_task_first_sequential and train_to_task_first \
-        else ""
-    config_str = f'{a_str}{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
-    exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}'
-
-    for seed in range(n_seeds):
-        seed_path = exp_path / str(seed)
-
-        model_path = seed_path / '0000_trained_model.zip'
-        df_store_path = seed_path / 'train_store.csv'
-        weight_store_path = seed_path / 'weight_store.csv'
-        srnn_parameters = dict()
-        for path in [model_path, df_store_path, weight_store_path]:
-            assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
-
-        if training:
-            utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
-            try:
-                dataset = MNIST(str(data_path), transform=utility_transforms)
-            except RuntimeError:
-                dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
-            d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
-
-            interface = np.prod(dataset[0][0].shape)
-            dense_metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                    weight_hidden_size=weight_hidden_size,).to(DEVICE)
-            sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                           weight_hidden_size=weight_hidden_size
-                                           ).to(DEVICE) if use_sparse_network else dense_metanet
-            meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())
-
-            loss_fn = nn.CrossEntropyLoss()
-            dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.008, momentum=0.9)
-            sparse_optimizer = torch.optim.SGD(
-                sparse_metanet.parameters(), lr=0.008, momentum=0.9
-            ) if use_sparse_network else dense_optimizer
-
-            train_store = new_storage_df('train', None)
-            weight_store = new_storage_df('weights', meta_weight_count)
-            init_tsk = train_to_task_first
-            for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
-                is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
-                is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
-                sparse_metanet = sparse_metanet.train()
-                dense_metanet = dense_metanet.train()
-                if is_validation_epoch:
-                    metric = torchmetrics.Accuracy()
-                else:
-                    metric = None
-                init_st = train_to_id_first and not all(
-                    x.is_fixpoint == ft.identity_func for x in dense_metanet.particles
-                )
-                force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
-                            ) and train_to_task_first_sequential and force_st_for_n_from_last_epochs
-                for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
-
-                    # Self Train
-                    if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st):
-                        # Transfer weights
-                        if use_sparse_network:
-                            sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
-                        # Zero your gradients for every batch!
-                        sparse_optimizer.zero_grad()
-                        self_train_loss = sparse_metanet.combined_self_train() * self_train_alpha
-                        self_train_loss.backward()
-                        # Adjust learning weights
-                        sparse_optimizer.step()
-                        step_log = dict(Epoch=epoch, Batch=batch,
-                                        Metric='Self Train Loss', Score=self_train_loss.item())
-                        train_store.loc[train_store.shape[0]] = step_log
-                        # Transfer weights
-                        if use_sparse_network:
-                            dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
-
-                    # Task Train
-                    if not init_st:
-                        # Zero your gradients for every batch!
-                        dense_optimizer.zero_grad()
-                        batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-                        y_pred = dense_metanet(batch_x)
-                        # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
-                        loss = loss_fn(y_pred, batch_y.to(torch.long)) * batch_train_beta
-                        loss.backward()
-
-                        # Adjust learning weights
-                        dense_optimizer.step()
-
-                        step_log = dict(Epoch=epoch, Batch=batch,
-                                        Metric='Task Loss', Score=loss.item())
-                        train_store.loc[train_store.shape[0]] = step_log
-                        if is_validation_epoch:
-                            metric(y_pred.cpu(), batch_y.cpu())
-
-                    if batch >= 3 and debug:
-                        break
-
-                if is_validation_epoch:
-                    dense_metanet = dense_metanet.eval()
-                    if not init_st:
-                        validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                              Metric='Train Accuracy', Score=metric.compute().item())
-                        train_store.loc[train_store.shape[0]] = validation_log
-
-                    accuracy = checkpoint_and_validate(dense_metanet, seed_path, epoch).item()
-                    validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                          Metric='Test Accuracy', Score=accuracy)
-                    train_store.loc[train_store.shape[0]] = validation_log
-                    if init_tsk or (train_to_task_first and train_to_task_first_sequential):
-                        init_tsk = accuracy <= tsk_threshold
-                if init_st or is_validation_epoch:
-                    counter_dict = defaultdict(lambda: 0)
-                    # This returns ID-functions
-                    _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
-                    for key, value in dict(counter_dict).items():
-                        step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
-                        train_store.loc[train_store.shape[0]] = step_log
-                if init_st or is_validation_epoch:
-                    for particle in dense_metanet.particles:
-                        weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
-                        weight_store.loc[weight_store.shape[0]] = weight_log
-                train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-                weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
-                train_store = new_storage_df('train', None)
-                weight_store = new_storage_df('weights', meta_weight_count)
-
-            dense_metanet.eval()
-
-            counter_dict = defaultdict(lambda: 0)
-            # This returns ID-functions
-            _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
-            for key, value in dict(counter_dict).items():
-                step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
-                train_store.loc[train_store.shape[0]] = step_log
-            accuracy = checkpoint_and_validate(dense_metanet, seed_path, EPOCH, final_model=True)
-            validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
-                                  Metric='Test Accuracy', Score=accuracy.item())
-            for particle in dense_metanet.particles:
-                weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
-                weight_store.loc[weight_store.shape[0]] = weight_log
-
-            train_store.loc[train_store.shape[0]] = validation_log
-            train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-            weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
-
-        plot_training_result(df_store_path)
-        plot_training_particle_types(df_store_path)
-
-        try:
-            model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
-        except StopIteration:
-            print('Model pattern did not trigger.')
-            print(f'Search path was: {seed_path}:')
-            print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
-            exit(1)
-        latest_model = torch.load(model_path, map_location=DEVICE).eval()
-        try:
-            run_particle_dropout_and_plot(seed_path)
-        except ValueError as e:
-            print(e)
-        try:
-            plot_network_connectivity_by_fixtype(model_path)
-        except ValueError as e:
-            print(e)
+    for weight_hidden_size in [3, 4, 5, 6]:
+
+        tsk_threshold = 0.85
+        weight_hidden_size = weight_hidden_size
+        residual_skip = True
+        n_seeds = 3
+
+        data_path = Path('data')
+        data_path.mkdir(exist_ok=True, parents=True)
+        assert not (train_to_task_first and train_to_id_first)
+
+        st_str = f'{"" if self_train else "no_"}st{f"_n_{n_st_per_batch}" if n_st_per_batch else ""}'
+        ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
+        res_str = f'{"" if residual_skip else "_no_res"}'
+        # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
+        id_str = f'{f"_StToId" if train_to_id_first else ""}'
+        tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
+        sprs_str = '_sprs' if use_sparse_network else ''
+        f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
+            force_st_for_n_from_last_epochs and sequential_task_train and train_to_task_first else ""
+        config_str = f'{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
+        exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}'
+
+        if not training:
+            # noinspection PyRedeclaration
+            exp_path = Path('output') / 'mn_st_n_2_100_4'
+
+        for seed in range(n_seeds):
+            seed_path = exp_path / str(seed)
+
+            model_path = seed_path / '0000_trained_model.zip'
+            df_store_path = seed_path / 'train_store.csv'
+            weight_store_path = seed_path / 'weight_store.csv'
+            srnn_parameters = dict()
+
+            if training:
+                # Check if files do exist on project location, warn and break.
+                for path in [model_path, df_store_path, weight_store_path]:
+                    assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
+
+                utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
+                try:
+                    dataset = MNIST(str(data_path), transform=utility_transforms)
+                except RuntimeError:
+                    dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
+                d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
+
+                interface = np.prod(dataset[0][0].shape)
+                dense_metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
+                                        weight_hidden_size=weight_hidden_size, activation=activation).to(DEVICE)
+                sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
+                                               weight_hidden_size=weight_hidden_size, activation=activation
+                                               ).to(DEVICE) if use_sparse_network else dense_metanet
+                meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())
+
+                loss_fn = nn.CrossEntropyLoss()
+                dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.008, momentum=0.9)
+                sparse_optimizer = torch.optim.SGD(
+                    sparse_metanet.parameters(), lr=0.008, momentum=0.9
+                ) if use_sparse_network else dense_optimizer
+
+                train_store = new_storage_df('train', None)
+                weight_store = new_storage_df('weights', meta_weight_count)
+                init_tsk = train_to_task_first
+                for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
+                    is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
+                    is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
+                    sparse_metanet = sparse_metanet.train()
+                    dense_metanet = dense_metanet.train()
+                    if is_validation_epoch:
+                        metric = torchmetrics.Accuracy()
+                    else:
+                        metric = None
+                    init_st = train_to_id_first and not all(
+                        x.is_fixpoint == ft.identity_func for x in dense_metanet.particles
+                    )
+                    force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
+                                ) and sequential_task_train and force_st_for_n_from_last_epochs
+                    for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
+
+                        # Self Train
+                        if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st):
+                            # Transfer weights
+                            if use_sparse_network:
+                                sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+                            for _ in range(n_st_per_batch):
+                                self_train_loss = sparse_metanet.combined_self_train(sparse_optimizer, reduction='mean')
+                            # noinspection PyUnboundLocalVariable
+                            step_log = dict(Epoch=epoch, Batch=batch,
+                                            Metric='Self Train Loss', Score=self_train_loss.item())
+                            train_store.loc[train_store.shape[0]] = step_log
+                            # Transfer weights
+                            if use_sparse_network:
+                                dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+
+                        # Task Train
+                        if not init_st:
+                            # Zero your gradients for every batch!
+                            dense_optimizer.zero_grad()
+                            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
+                            y_pred = dense_metanet(batch_x)
+                            # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
+                            loss = loss_fn(y_pred, batch_y.to(torch.long))
+                            loss.backward()
+
+                            # Adjust learning weights
+                            dense_optimizer.step()
+
+                            step_log = dict(Epoch=epoch, Batch=batch,
+                                            Metric='Task Loss', Score=loss.item())
+                            train_store.loc[train_store.shape[0]] = step_log
+                            if is_validation_epoch:
+                                metric(y_pred.cpu(), batch_y.cpu())
+
+                        if batch >= 3 and debug:
+                            break
+
+                    if is_validation_epoch:
+                        dense_metanet = dense_metanet.eval()
+                        if not init_st:
+                            validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
+                                                  Metric='Train Accuracy', Score=metric.compute().item())
+                            train_store.loc[train_store.shape[0]] = validation_log
+
+                        accuracy = checkpoint_and_validate(dense_metanet, seed_path, epoch).item()
+                        validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
+                                              Metric='Test Accuracy', Score=accuracy)
+                        train_store.loc[train_store.shape[0]] = validation_log
+                        if init_tsk or (train_to_task_first and sequential_task_train):
+                            init_tsk = accuracy <= tsk_threshold
+                    if init_st or is_validation_epoch:
+                        counter_dict = defaultdict(lambda: 0)
+                        # This returns ID-functions
+                        _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
+                        counter_dict = dict(counter_dict)
+                        for key, value in counter_dict.items():
+                            step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
+                            train_store.loc[train_store.shape[0]] = step_log
+                        tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
+                    if init_st or is_validation_epoch:
+                        for particle in dense_metanet.particles:
+                            weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
+                            weight_store.loc[weight_store.shape[0]] = weight_log
+                    train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
+                    weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+                    train_store = new_storage_df('train', None)
+                    weight_store = new_storage_df('weights', meta_weight_count)
+
+                dense_metanet.eval()
+
+                counter_dict = defaultdict(lambda: 0)
+                # This returns ID-functions
+                _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
+                for key, value in dict(counter_dict).items():
+                    step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
+                    train_store.loc[train_store.shape[0]] = step_log
+                accuracy = checkpoint_and_validate(dense_metanet, seed_path, EPOCH, final_model=True)
+                validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
+                                      Metric='Test Accuracy', Score=accuracy.item())
+                for particle in dense_metanet.particles:
+                    weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
+                    weight_store.loc[weight_store.shape[0]] = weight_log
+
+                train_store.loc[train_store.shape[0]] = validation_log
+                train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
+                weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+
+            plot_training_result(df_store_path)
+            plot_training_particle_types(df_store_path)
+
+            try:
+                model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
+            except StopIteration:
+                print('Model pattern did not trigger.')
+                print(f'Search path was: {seed_path}:')
+                print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
+                exit(1)
+            latest_model = torch.load(model_path, map_location=DEVICE).eval()
+            try:
+                run_particle_dropout_and_plot(seed_path)
+            except ValueError as e:
+                print(e)
+            try:
+                plot_network_connectivity_by_fixtype(model_path)
+            except ValueError as e:
+                print(e)

     if n_seeds >= 2:
         pass
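
With the new sweep over weight_hidden_size, each experiment folder name is assembled from the flag-dependent name fragments above. A worked example (not from the commit) for one concrete configuration; tsk_str and f_str are empty here because train_to_task_first is False:

from pathlib import Path

EPOCH, weight_hidden_size, n_st_per_batch = 100, 4, 3
self_train, residual_skip, use_sparse_network = True, True, True
train_to_id_first = True
activation = None

st_str = f'{"" if self_train else "no_"}st{f"_n_{n_st_per_batch}" if n_st_per_batch else ""}'
ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
res_str = f'{"" if residual_skip else "_no_res"}'
id_str = f'{f"_StToId" if train_to_id_first else ""}'
sprs_str = '_sprs' if use_sparse_network else ''
config_str = f'{res_str}{id_str}{sprs_str}'
exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}'
print(exp_path)   # output/mn_st_n_3_100_4_StToId_sprs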

View File

@@ -6,11 +6,14 @@ from tqdm import tqdm
 from network import FixTypes, Net

+epsilon_error_margin = pow(10, -5)
+

 def is_divergent(network: Net) -> bool:
     return network.input_weight_matrix().isinf().any().item() or network.input_weight_matrix().isnan().any().item()


-def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
+def is_identity_function(network: Net, epsilon=epsilon_error_margin) -> bool:
     input_data = network.input_weight_matrix()
     target_data = network.create_target_weights(input_data)
@@ -20,14 +23,14 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
                           rtol=0, atol=epsilon)


-def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
+def is_zero_fixpoint(network: Net, epsilon=epsilon_error_margin) -> bool:
     target_data = network.create_target_weights(network.input_weight_matrix().detach())
     result = torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)
     # result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
     return result


-def is_secondary_fixpoint(network: Net, epsilon: float = pow(10, -5)) -> bool:
+def is_secondary_fixpoint(network: Net, epsilon: float = epsilon_error_margin) -> bool:
     """ Secondary fixpoint check is done like this: compare first INPUT with second OUTPUT.
         If they are within the boundaries, then is secondary fixpoint. """
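
All tolerance defaults in this file now point at the single module-level constant instead of repeating pow(10, -5). A simplified, self-contained sketch of the idea (the real functions take a Net and derive the weight tensors themselves):

import torch

epsilon_error_margin = pow(10, -5)


def is_zero_fixpoint(target_data: torch.Tensor, epsilon: float = epsilon_error_margin) -> bool:
    # Same allclose check as above, with the shared tolerance as the default.
    return torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)


print(is_zero_fixpoint(torch.full((4,), 1e-6)))   # True: below the shared margin
print(is_zero_fixpoint(torch.full((4,), 1e-3)))   # False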

View File

@@ -420,7 +420,7 @@ class MetaNet(nn.Module):
                                        ) for layer_idx in range(self.depth - 2)]
                                       )

-        self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list)}',
+        self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list) + 1}',
                                           interface=self.width, width=self.out,
                                           weight_interface=weight_interface,
                                           weight_hidden_size=weight_hidden_size,
@@ -428,8 +428,6 @@ class MetaNet(nn.Module):
                                           )
         self.dropout_layer = nn.Dropout(p=self.dropout)

-        self._all_layers_with_particles = [self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last]
-
     def replace_with_zero(self, ident_key):
         replaced_particles = 0
         for particle in self.particles:
@@ -442,48 +440,51 @@ class MetaNet(nn.Module):
         return self

     def forward(self, x):
-        if self.dropout != 0:
-            x = self.dropout_layer(x)
         tensor = self._meta_layer_first(x)
+        residual = None
         for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
-            if self.dropout != 0:
-                tensor = self.dropout_layer(tensor)
             if idx % 2 == 1 and self.residual_skip:
-                x = tensor.clone()
+                residual = tensor.clone()
             tensor = meta_layer(tensor)
             if idx % 2 == 0 and self.residual_skip:
-                tensor = tensor + x
-        if self.dropout != 0:
-            x = self.dropout_layer(x)
-        tensor = self._meta_layer_last(x)
+                tensor = tensor + residual
+        tensor = self._meta_layer_last(tensor)
         return tensor

     @property
     def particles(self):
-        return (cell for metalayer in self._all_layers_with_particles for cell in metalayer.particles)
+        return (cell for metalayer in self.all_layers for cell in metalayer.particles)

-    def combined_self_train(self):
+    def combined_self_train(self, optimizer, reduction='mean'):
+        optimizer.zero_grad()
         losses = []
         for particle in self.particles:
             # Intergrate optimizer and backward function
             input_data = particle.input_weight_matrix()
             target_data = particle.create_target_weights(input_data)
             output = particle(input_data)
-            losses.append(F.mse_loss(output, target_data))
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+            losses.append(F.mse_loss(output, target_data, reduction=reduction))
+        losses = torch.hstack(losses).sum(dim=-1, keepdim=True)
+        losses.backward()
+        optimizer.step()
+        return losses.detach()

     @property
     def hyperparams(self):
         return {key: val for key, val in self.__dict__.items() if not key.startswith('_')}

     def replace_particles(self, particle_weights_list):
-        for layer in self._all_layers_with_particles:
+        for layer in self.all_layers:
             for cell in layer.meta_cell_list:
                 # Individual replacement on cell lvl
                 for weight in cell.meta_weight_list:
                     weight.apply_weights(next(particle_weights_list).detach())
         return self

+    @property
+    def all_layers(self):
+        return (x for x in (self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last))
+

 class MetaNetCompareBaseline(nn.Module):
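
combined_self_train changes its contract here: the optimizer is now passed in and stepped inside the method, and the detached loss is returned. A self-contained sketch of that pattern with a dummy stand-in for the particle networks (the real particles supply input_weight_matrix() and create_target_weights()):

import torch
import torch.nn as nn
import torch.nn.functional as F


class DummyParticle(nn.Module):
    # Stand-in for the project's particle Nets.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(5, 1, bias=False)

    def forward(self, x):
        return self.lin(x)


particles = [DummyParticle() for _ in range(3)]
optimizer = torch.optim.SGD([p for prt in particles for p in prt.parameters()], lr=0.008, momentum=0.9)


def combined_self_train(optimizer, reduction='mean'):
    optimizer.zero_grad()
    losses = []
    for particle in particles:
        input_data = torch.randn(8, 5)     # stands in for particle.input_weight_matrix()
        target_data = torch.zeros(8, 1)    # stands in for particle.create_target_weights(...)
        losses.append(F.mse_loss(particle(input_data), target_data, reduction=reduction))
    losses = torch.hstack(losses).sum(dim=-1, keepdim=True)
    losses.backward()
    optimizer.step()
    return losses.detach()


print(combined_self_train(optimizer))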
@@ -495,19 +496,24 @@ class MetaNetCompareBaseline(nn.Module):
         self.interface = interface
         self.width = width
         self.depth = depth

         self._first_layer = nn.Linear(self.interface, self.width, bias=False)
-        self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
+        self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False
+                                                         ) for _ in range(self.depth - 2)])
         self._last_layer = nn.Linear(self.width, self.out, bias=False)

     def forward(self, x):
         tensor = self._first_layer(x)
+        if self.activation:
+            tensor = self.activation(tensor)
+        residual = None
         for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
-            if idx % 2 == 1 and self.residual_skip:
-                x = tensor.clone()
             tensor = meta_layer(tensor)
+            if idx % 2 == 1 and self.residual_skip:
+                residual = tensor.clone()
             if idx % 2 == 0 and self.residual_skip:
-                tensor = tensor + x
+                tensor = tensor + residual
+            if self.activation:
+                tensor = self.activation(tensor)
         tensor = self._last_layer(tensor)
         return tensor
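
Both forward passes above now use the same alternating residual pattern: odd-indexed hidden layers store an activation, even-indexed ones add it back (MetaNet clones before the layer call, the baseline and the sparse network clone after it). A minimal sketch of the MetaNet variant:

import torch
import torch.nn as nn

hidden_layers = nn.ModuleList([nn.Linear(6, 6, bias=False) for _ in range(4)])
tensor = torch.randn(2, 6)
residual = None
for idx, layer in enumerate(hidden_layers, start=1):
    if idx % 2 == 1:            # odd layer: remember the input of this pair
        residual = tensor.clone()
    tensor = layer(tensor)
    if idx % 2 == 0:            # even layer: add the remembered activation back
        tensor = tensor + residual
print(tensor.shape)             # torch.Size([2, 6])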

View File

@@ -10,8 +10,11 @@ from torch.utils.data import Dataset, DataLoader
 from torchvision.datasets import MNIST, CIFAR10
 from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Grayscale
 import torchmetrics

+from functionalities_test import epsilon_error_margin as e
 from network import MetaNet, MetaNetCompareBaseline


 def extract_weights_from_model(model:MetaNet)->dict:
     inpt = torch.zeros(5)
     inpt[-1] = 1
@@ -25,27 +28,51 @@ def extract_weights_from_model(model:MetaNet)->dict:
     return dict(weights)


-def test_weights_as_model(model, new_weights:dict, data):
-    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out,
-                                         residual_skip=True)
+def test_weights_as_model(meta_net, new_weights:dict, data):
+    transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth, width=meta_net.width, out=meta_net.out,
+                                          residual_skip=True)
     with torch.no_grad():
-        for weights, parameters in zip(new_weights.values(), TransferNet.parameters()):
+        new_weight_values = list(new_weights.values())
+        old_parameters = list(transfer_net.parameters())
+        assert len(new_weight_values) == len(old_parameters)
+        for weights, parameters in zip(new_weights.values(), transfer_net.parameters()):
             parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]

-    TransferNet.eval()
-    metric = torchmetrics.Accuracy()
-    with tqdm(desc='Test Batch: ') as pbar:
-        for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
-            y = TransferNet(batch_x)
-            acc = metric(y.cpu(), batch_y.cpu())
-            pbar.set_postfix_str(f'Acc: {acc}')
-            pbar.update()
-
-    # metric on all batches using custom accumulation
-    acc = metric.compute()
-    tqdm.write(f"Avg. accuracy on all data: {acc}")
-    return acc
+    transfer_net.eval()
+
+    # Test if the margin of error is similar
+
+    im_t = defaultdict(list)
+    rand = torch.randn((1, 15 * 15))
+    for net in [meta_net, transfer_net]:
+        tensor = rand.clone()
+        for layer in net.all_layers:
+            tensor = layer(tensor)
+            im_t[net.__class__.__name__].append(tensor.detach())
+
+    im_t = dict(im_t)
+    all_close = {f'layer_{idx}': torch.allclose(y1.detach(), y2.detach(), rtol=0, atol=e
+                                                ) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
+                 }
+    print(f'Cummulative differences per layer is smaller then {e}:\n {all_close}')
+    # all_errors = {f'layer_{idx}': torch.absolute(y1.detach(), y2.detach(), rtol=0, atol=e
+    #                                              ) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
+    #               }
+
+    for net in [meta_net, transfer_net]:
+        net.eval()
+        metric = torchmetrics.Accuracy()
+        with tqdm(desc='Test Batch: ') as pbar:
+            for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
+                y = net(batch_x)
+                acc = metric(y.cpu(), batch_y.cpu())
+                pbar.set_postfix_str(f'Acc: {acc}')
+                pbar.update()
+
+        # metric on all batches using custom accumulation
+        acc = metric.compute()
+        tqdm.write(f"Avg. accuracy on {net.__class__.__name__}: {acc}")


 if __name__ == '__main__':
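
The sanity check now does two things: a layer-by-layer comparison of intermediate activations between the MetaNet and the rebuilt baseline (using the shared epsilon as tolerance), followed by the accuracy loop for both nets. A self-contained sketch of the layer-wise comparison, with two plain Linear stacks standing in for the two networks:

from collections import defaultdict

import torch
import torch.nn as nn

e = pow(10, -5)   # mirrors functionalities_test.epsilon_error_margin

net_a = nn.Sequential(nn.Linear(4, 4, bias=False), nn.Linear(4, 2, bias=False))
net_b = nn.Sequential(nn.Linear(4, 4, bias=False), nn.Linear(4, 2, bias=False))
net_b.load_state_dict(net_a.state_dict())     # identical weights, so activations must match

im_t = defaultdict(list)
rand = torch.randn((1, 4))
for name, net in [('a', net_a), ('b', net_b)]:
    tensor = rand.clone()
    for layer in net:
        tensor = layer(tensor)
        im_t[name].append(tensor.detach())

all_close = {f'layer_{idx}': torch.allclose(y1, y2, rtol=0, atol=e)
             for idx, (y1, y2) in enumerate(zip(*im_t.values()))}
print(all_close)   # expect every entry to be True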
@@ -58,7 +85,7 @@ if __name__ == '__main__':
     data_path.mkdir(exist_ok=True, parents=True)
     mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)

     model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval()
     weights = extract_weights_from_model(model)
     test_weights_as_model(model, weights, d_test)

View File

@@ -161,7 +161,7 @@ def embed_vector(x, repeat_dim):
 class SparseNetwork(nn.Module):

-    def __init__(self, input_dim, depth, width, out, residual_skip=True,
+    def __init__(self, input_dim, depth, width, out, residual_skip=True, activation=None,
                  weight_interface=5, weight_hidden_size=2, weight_output_size=1
                  ):
         super(SparseNetwork, self).__init__()
@@ -170,6 +170,7 @@ class SparseNetwork(nn.Module):
         self.depth_dim = depth
         self.hidden_dim = width
         self.out_dim = out
+        self.activation = activation
         self.first_layer = SparseLayer(self.input_dim * self.hidden_dim,
                                        interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
         self.last_layer = SparseLayer(self.hidden_dim * self.out_dim,
@@ -182,13 +183,17 @@ class SparseNetwork(nn.Module):
     def __call__(self, x):
         tensor = self.sparse_layer_forward(x, self.first_layer)
+        if self.activation:
+            tensor = self.activation(tensor)
         for nl_idx, network_layer in enumerate(self.hidden_layers):
-            if nl_idx % 2 == 0 and self.residual_skip:
-                residual = tensor
             # Sparse Layer pass
             tensor = self.sparse_layer_forward(tensor, network_layer)
-            if nl_idx % 2 != 0 and self.residual_skip:
+            if self.activation:
+                tensor = self.activation(tensor)
+            if nl_idx % 2 == 0 and self.residual_skip:
+                residual = tensor.clone()
+            if nl_idx % 2 == 1 and self.residual_skip:
                 # noinspection PyUnboundLocalVariable
                 tensor += residual
         tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
@@ -234,14 +239,19 @@ class SparseNetwork(nn.Module):
     def sparselayers(self):
         return (x for x in (self.first_layer, *self.hidden_layers, self.last_layer))

-    def combined_self_train(self):
+    def combined_self_train(self, optimizer, reduction='mean'):
         losses = []
         for layer in self.sparselayers:
+            optimizer.zero_grad()
             x, target_data = layer.get_self_train_inputs_and_targets()
             output = layer(x)
-            losses.append(F.mse_loss(output, target_data) / layer.nr_nets)
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+
+            loss = F.mse_loss(output, target_data, reduction=reduction)
+            losses.append(loss.detach())
+            loss.backward()
+            optimizer.step()
+
+        return sum(losses)

     def replace_weights_by_particles(self, particles):
         particles = list(particles)
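
The sparse network gets the same new self-train contract, but it zeroes gradients and steps the optimizer once per sparse layer rather than once for all particles, and returns the summed detached per-layer losses. A self-contained sketch of that per-layer stepping pattern, with plain Linear layers standing in for the sparse layers:

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange

layers = nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(3)])   # stand-ins for sparselayers
optimizer = torch.optim.SGD(layers.parameters(), lr=0.004, momentum=0.9)


def combined_self_train(optimizer, reduction='mean'):
    losses = []
    for layer in layers:
        optimizer.zero_grad()                  # fresh gradients for every layer
        x = torch.randn(8, 4)                  # stands in for get_self_train_inputs_and_targets()
        target = torch.zeros(8, 4)
        loss = F.mse_loss(layer(x), target, reduction=reduction)
        losses.append(loss.detach())
        loss.backward()
        optimizer.step()                       # one optimizer step per layer
    return sum(losses)


for _ in trange(5):
    _ = combined_self_train(optimizer)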
@@ -274,12 +284,7 @@ def test_sparse_net_sef_train():
     if True:
         optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
         for _ in trange(epochs):
-            optimizer.zero_grad()
-            loss = net.combined_self_train()
-            print(loss)
-            exit()
-            loss.backward()
-            optimizer.step()
+            _ = net.combined_self_train(optimizer)

     else:
         optimizer_dict = {