Small fixes; new parameters

Steffen Illium
2022-02-25 15:32:56 +01:00
parent 5b2b5b5beb
commit 9d8496a725
5 changed files with 292 additions and 236 deletions

View File

@@ -45,8 +45,8 @@ from functionalities_test import test_for_fixpoints
 WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 500 if not debug else 50
-EPOCH = 50
-VALIDATION_FRQ = 3 if not debug else 1
+EPOCH = 100
+VALIDATION_FRQ = 4 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -56,6 +56,9 @@ if debug:
 class ToFloat:
+
+    def __init__(self):
+        pass
 
     def __call__(self, x):
         return x.to(torch.float32)
@@ -194,7 +197,7 @@ def plot_training_result(path_to_dataframe):
 def plot_network_connectivity_by_fixtype(path_to_trained_model):
     m = torch.load(path_to_trained_model, map_location=torch.device('cpu'))
     # noinspection PyProtectedMember
-    particles = [y for x in m._meta_layer_list for y in x.particles]
+    particles = list(m.particles)
     df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name'])
     for prtcl in particles:
@@ -210,10 +213,16 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
     for n, fixtype in enumerate([ft.other_func, ft.identity_func]):
         plt.clf()
         ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
-                          legend=False, estimator=None,
-                          palette=[sns.color_palette()[n]] * (df[df['type'] == fixtype].shape[0]//2), lw=1)
+                          legend=False, estimator=None, lw=1)
+        _ = sns.lineplot(y=[0, 1], x=[-1, df['layer'].max()], legend=False, estimator=None, lw=0)
         ax.set_title(fixtype)
-        plt.show()
+        lines = ax.get_lines()
+        for line in lines:
+            line.set_color(sns.color_palette()[n])
+        if debug:
+            plt.show()
+        else:
+            plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
     print('plottet')
@@ -234,7 +243,7 @@ def run_particle_dropout_test(run_path):
         tqdm.write(f'Zero_ident diff = {acc_diff}')
         diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
-    diff_df.to_csv(diff_store_path, mode='a', header=not df_store_path.exists(), index=False)
+    diff_df.to_csv(diff_store_path, mode='a', header=not diff_store_path.exists(), index=False)
     return diff_store_path
@@ -246,18 +255,18 @@ def plot_dropout_stacked_barplot(path_to_diff_df):
     plt.clf()
     fig, ax = plt.subplots(ncols=2)
     colors = sns.color_palette()[:diff_df.shape[0]]
-    barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', palette=colors, ax=ax[0])
+    barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
     # noinspection PyUnboundLocalVariable
-    for idx, patch in enumerate(barplot.patches):
-        if idx != 0:
-            # we recenter the bar
-            patch.set_x(patch.get_x() + idx * 0.035)
+    # for idx, patch in enumerate(barplot.patches):
+    #     if idx != 0:
+    #         # we recenter the bar
+    #         patch.set_x(patch.get_x() + idx * 0.035)
     ax[0].set_title('Accuracy after particle dropout')
-    ax[0].set_xlabel('Accuracy')
+    ax[0].set_xlabel('Particle Type')
     ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=colors, )
-    ax[1].set_title('Particle Count for ')
+    ax[1].set_title('Particle Count')
     plt.tight_layout()
     if debug:
@@ -278,37 +287,42 @@ def flat_for_store(parameters):
 if __name__ == '__main__':
     self_train = True
-    training = True
-    train_to_id_first = False
+    training = False
+    train_to_id_first = True
     train_to_task_first = False
-    train_to_task_first_sequential = True
+    sequential_task_train = True
     force_st_for_n_from_last_epochs = 5
+    n_st_per_batch = 3
+    activation = None  # nn.ReLU()
 
-    use_sparse_network = False
+    use_sparse_network = True
 
-    tsk_threshold = 0.855
-    self_train_alpha = 1
-    batch_train_beta = 1
-    weight_hidden_size = 3
-    residual_skip = True
-    n_seeds = 5
+    for weight_hidden_size in [3, 4, 5, 6]:
+
+        tsk_threshold = 0.85
+        weight_hidden_size = weight_hidden_size
+        residual_skip = True
+        n_seeds = 3
 
         data_path = Path('data')
         data_path.mkdir(exist_ok=True, parents=True)
 
         assert not (train_to_task_first and train_to_id_first)
 
-    st_str = f'{"" if self_train else "no_"}st'
-    a_str = f'_alpha_{self_train_alpha}' if self_train_alpha != 1 else ''
+        st_str = f'{"" if self_train else "no_"}st{f"_n_{n_st_per_batch}" if n_st_per_batch else ""}'
+        ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
         res_str = f'{"" if residual_skip else "_no_res"}'
         # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
         id_str = f'{f"_StToId" if train_to_id_first else ""}'
         tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
         sprs_str = '_sprs' if use_sparse_network else ''
-    f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
-        force_st_for_n_from_last_epochs and train_to_task_first_sequential and train_to_task_first \
-        else ""
-    config_str = f'{a_str}{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
-    exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}'
+        f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
+            force_st_for_n_from_last_epochs and sequential_task_train and train_to_task_first else ""
+        config_str = f'{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
+        exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}'
+
+        if not training:
+            # noinspection PyRedeclaration
+            exp_path = Path('output') / 'mn_st_n_2_100_4'
 
         for seed in range(n_seeds):
             seed_path = exp_path / str(seed)
@@ -317,10 +331,12 @@ if __name__ == '__main__':
             df_store_path = seed_path / 'train_store.csv'
             weight_store_path = seed_path / 'weight_store.csv'
             srnn_parameters = dict()
+            if training:
+                # Check if files do exist on project location, warn and break.
                 for path in [model_path, df_store_path, weight_store_path]:
                     assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
 
-            if training:
                 utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
                 try:
                     dataset = MNIST(str(data_path), transform=utility_transforms)
@@ -330,9 +346,9 @@ if __name__ == '__main__':
                 interface = np.prod(dataset[0][0].shape)
                 dense_metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                        weight_hidden_size=weight_hidden_size,).to(DEVICE)
+                                        weight_hidden_size=weight_hidden_size, activation=activation).to(DEVICE)
                 sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                               weight_hidden_size=weight_hidden_size
+                                               weight_hidden_size=weight_hidden_size, activation=activation
                                                ).to(DEVICE) if use_sparse_network else dense_metanet
                 meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())
@@ -358,7 +374,7 @@ if __name__ == '__main__':
                         x.is_fixpoint == ft.identity_func for x in dense_metanet.particles
                     )
                     force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
-                                ) and train_to_task_first_sequential and force_st_for_n_from_last_epochs
+                                ) and sequential_task_train and force_st_for_n_from_last_epochs
                     for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
                         # Self Train
@@ -366,12 +382,9 @@ if __name__ == '__main__':
                         # Transfer weights
                         if use_sparse_network:
                             sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
-                        # Zero your gradients for every batch!
-                        sparse_optimizer.zero_grad()
-                        self_train_loss = sparse_metanet.combined_self_train() * self_train_alpha
-                        self_train_loss.backward()
-                        # Adjust learning weights
-                        sparse_optimizer.step()
+                        for _ in range(n_st_per_batch):
+                            self_train_loss = sparse_metanet.combined_self_train(sparse_optimizer, reduction='mean')
+                        # noinspection PyUnboundLocalVariable
                         step_log = dict(Epoch=epoch, Batch=batch,
                                         Metric='Self Train Loss', Score=self_train_loss.item())
                         train_store.loc[train_store.shape[0]] = step_log
@@ -386,7 +399,7 @@ if __name__ == '__main__':
                         batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
                         y_pred = dense_metanet(batch_x)
                         # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
-                        loss = loss_fn(y_pred, batch_y.to(torch.long)) * batch_train_beta
+                        loss = loss_fn(y_pred, batch_y.to(torch.long))
                         loss.backward()
 
                         # Adjust learning weights
@@ -412,15 +425,17 @@ if __name__ == '__main__':
                     validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                           Metric='Test Accuracy', Score=accuracy)
                     train_store.loc[train_store.shape[0]] = validation_log
-                    if init_tsk or (train_to_task_first and train_to_task_first_sequential):
+                    if init_tsk or (train_to_task_first and sequential_task_train):
                         init_tsk = accuracy <= tsk_threshold
                 if init_st or is_validation_epoch:
                     counter_dict = defaultdict(lambda: 0)
                     # This returns ID-functions
                     _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
-                    for key, value in dict(counter_dict).items():
+                    counter_dict = dict(counter_dict)
+                    for key, value in counter_dict.items():
                         step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
                         train_store.loc[train_store.shape[0]] = step_log
+                    tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
                 if init_st or is_validation_epoch:
                     for particle in dense_metanet.particles:
                         weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
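
For orientation, a minimal self-contained sketch of the call pattern this file now uses: combined_self_train owns the zero_grad/backward/step cycle and is simply invoked n_st_per_batch times per batch. The ToySelfTrainer module and its identity-mapping objective are illustrative stand-ins, not the repo's actual particle networks or self-train targets.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySelfTrainer(nn.Module):
    """Toy stand-in for a self-training network; the real objective lives in the repo."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4, bias=False)

    def combined_self_train(self, optimizer, reduction='mean'):
        # zero_grad / backward / step are owned by the method, mirroring the refactor above
        optimizer.zero_grad()
        inputs = torch.randn(8, 4)
        loss = F.mse_loss(self.layer(inputs), inputs, reduction=reduction)  # toy target: identity mapping
        loss.backward()
        optimizer.step()
        return loss.detach()


net = ToySelfTrainer()
opt = torch.optim.SGD(net.parameters(), lr=0.01)
n_st_per_batch = 3
for _ in range(n_st_per_batch):  # repeated self-train steps per batch, as in the new training loop
    self_train_loss = net.combined_self_train(opt, reduction='mean')
print(self_train_loss.item())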

View File

@@ -6,11 +6,14 @@ from tqdm import tqdm
 from network import FixTypes, Net
 
+epsilon_error_margin = pow(10, -5)
+
 
 def is_divergent(network: Net) -> bool:
     return network.input_weight_matrix().isinf().any().item() or network.input_weight_matrix().isnan().any().item()
 
 
-def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
+def is_identity_function(network: Net, epsilon=epsilon_error_margin) -> bool:
     input_data = network.input_weight_matrix()
     target_data = network.create_target_weights(input_data)
@@ -20,14 +23,14 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
                           rtol=0, atol=epsilon)
 
 
-def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
+def is_zero_fixpoint(network: Net, epsilon=epsilon_error_margin) -> bool:
     target_data = network.create_target_weights(network.input_weight_matrix().detach())
     result = torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)
     # result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
     return result
 
 
-def is_secondary_fixpoint(network: Net, epsilon: float = pow(10, -5)) -> bool:
+def is_secondary_fixpoint(network: Net, epsilon: float = epsilon_error_margin) -> bool:
     """ Secondary fixpoint check is done like this: compare first INPUT with second OUTPUT.
         If they are within the boundaries, then is secondary fixpoint. """
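
A small self-contained illustration of why the shared module-level tolerance introduced above is convenient: every allclose-based fixpoint check now reads the same atol. The tensors here are toy values, not taken from the repo.

import torch

epsilon_error_margin = pow(10, -5)  # single shared tolerance, as defined above

target = torch.zeros(3)
candidate = torch.full((3,), 1e-6)
# zero-fixpoint, identity-fixpoint and secondary-fixpoint checks all agree on what "equal" means
print(torch.allclose(candidate, target, rtol=0, atol=epsilon_error_margin))  # True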

View File

@@ -420,7 +420,7 @@ class MetaNet(nn.Module):
                                            ) for layer_idx in range(self.depth - 2)]
         )
-        self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list)}',
+        self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list) + 1}',
                                           interface=self.width, width=self.out,
                                           weight_interface=weight_interface,
                                           weight_hidden_size=weight_hidden_size,
@@ -428,8 +428,6 @@ class MetaNet(nn.Module):
                                           )
         self.dropout_layer = nn.Dropout(p=self.dropout)
-        self._all_layers_with_particles = [self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last]
 
     def replace_with_zero(self, ident_key):
         replaced_particles = 0
         for particle in self.particles:
@@ -442,48 +440,51 @@ class MetaNet(nn.Module):
         return self
 
     def forward(self, x):
-        if self.dropout != 0:
-            x = self.dropout_layer(x)
         tensor = self._meta_layer_first(x)
+        residual = None
         for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
-            if self.dropout != 0:
-                tensor = self.dropout_layer(tensor)
             if idx % 2 == 1 and self.residual_skip:
-                x = tensor.clone()
+                residual = tensor.clone()
             tensor = meta_layer(tensor)
             if idx % 2 == 0 and self.residual_skip:
-                tensor = tensor + x
-        if self.dropout != 0:
-            x = self.dropout_layer(x)
-        tensor = self._meta_layer_last(x)
+                tensor = tensor + residual
+        tensor = self._meta_layer_last(tensor)
         return tensor
 
     @property
     def particles(self):
-        return (cell for metalayer in self._all_layers_with_particles for cell in metalayer.particles)
+        return (cell for metalayer in self.all_layers for cell in metalayer.particles)
 
-    def combined_self_train(self):
+    def combined_self_train(self, optimizer, reduction='mean'):
+        optimizer.zero_grad()
         losses = []
         for particle in self.particles:
             # Intergrate optimizer and backward function
             input_data = particle.input_weight_matrix()
             target_data = particle.create_target_weights(input_data)
             output = particle(input_data)
-            losses.append(F.mse_loss(output, target_data))
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+            losses.append(F.mse_loss(output, target_data, reduction=reduction))
+        losses = torch.hstack(losses).sum(dim=-1, keepdim=True)
+        losses.backward()
+        optimizer.step()
+        return losses.detach()
 
     @property
     def hyperparams(self):
         return {key: val for key, val in self.__dict__.items() if not key.startswith('_')}
 
     def replace_particles(self, particle_weights_list):
-        for layer in self._all_layers_with_particles:
+        for layer in self.all_layers:
             for cell in layer.meta_cell_list:
                 # Individual replacement on cell lvl
                 for weight in cell.meta_weight_list:
                     weight.apply_weights(next(particle_weights_list).detach())
         return self
 
+    @property
+    def all_layers(self):
+        return (x for x in (self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last))
+
 
 class MetaNetCompareBaseline(nn.Module):
@@ -495,19 +496,24 @@ class MetaNetCompareBaseline(nn.Module):
         self.interface = interface
         self.width = width
         self.depth = depth
 
         self._first_layer = nn.Linear(self.interface, self.width, bias=False)
-        self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
+        self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False
+                                                          ) for _ in range(self.depth - 2)])
         self._last_layer = nn.Linear(self.width, self.out, bias=False)
 
     def forward(self, x):
         tensor = self._first_layer(x)
+        if self.activation:
+            tensor = self.activation(tensor)
+        residual = None
         for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
-            if idx % 2 == 1 and self.residual_skip:
-                x = tensor.clone()
             tensor = meta_layer(tensor)
+            if idx % 2 == 1 and self.residual_skip:
+                residual = tensor.clone()
             if idx % 2 == 0 and self.residual_skip:
-                tensor = tensor + x
+                tensor = tensor + residual
+            if self.activation:
+                tensor = self.activation(tensor)
         tensor = self._last_layer(tensor)
         return tensor
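
A compact, self-contained sketch of the residual-skip pattern used in the rewritten MetaNet.forward above: cache the input of each odd-indexed hidden layer and add it back after the following even layer. TinyResidualStack is an illustrative stand-in, not the repo's MetaNet.

import torch
import torch.nn as nn


class TinyResidualStack(nn.Module):
    """Toy stand-in for the MetaNet layer stack; width and depth are illustrative."""

    def __init__(self, width=6, depth=4, residual_skip=True):
        super().__init__()
        self.residual_skip = residual_skip
        self.first = nn.Linear(width, width, bias=False)
        self.hidden = nn.ModuleList(nn.Linear(width, width, bias=False) for _ in range(depth))
        self.last = nn.Linear(width, width, bias=False)

    def forward(self, x):
        tensor = self.first(x)
        residual = None
        for idx, layer in enumerate(self.hidden, start=1):
            if idx % 2 == 1 and self.residual_skip:
                residual = tensor.clone()   # remember the input of every odd hidden layer
            tensor = layer(tensor)
            if idx % 2 == 0 and self.residual_skip:
                tensor = tensor + residual  # add it back after the following even layer
        return self.last(tensor)


print(TinyResidualStack()(torch.randn(1, 6)).shape)  # torch.Size([1, 6])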

View File

@@ -10,8 +10,11 @@ from torch.utils.data import Dataset, DataLoader
 from torchvision.datasets import MNIST, CIFAR10
 from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Grayscale
 import torchmetrics
 
+from functionalities_test import epsilon_error_margin as e
 from network import MetaNet, MetaNetCompareBaseline
 
 
 def extract_weights_from_model(model:MetaNet)->dict:
     inpt = torch.zeros(5)
     inpt[-1] = 1
@@ -25,27 +28,51 @@ def extract_weights_from_model(model:MetaNet)->dict:
     return dict(weights)
 
 
-def test_weights_as_model(model, new_weights:dict, data):
-    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out,
-                                         residual_skip=True)
+def test_weights_as_model(meta_net, new_weights:dict, data):
+    transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth, width=meta_net.width, out=meta_net.out,
+                                          residual_skip=True)
     with torch.no_grad():
-        for weights, parameters in zip(new_weights.values(), TransferNet.parameters()):
+        new_weight_values = list(new_weights.values())
+        old_parameters = list(transfer_net.parameters())
+        assert len(new_weight_values) == len(old_parameters)
+        for weights, parameters in zip(new_weights.values(), transfer_net.parameters()):
             parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]
 
-    TransferNet.eval()
+    transfer_net.eval()
+
+    # Test if the margin of error is similar
+    im_t = defaultdict(list)
+    rand = torch.randn((1, 15 * 15))
+    for net in [meta_net, transfer_net]:
+        tensor = rand.clone()
+        for layer in net.all_layers:
+            tensor = layer(tensor)
+            im_t[net.__class__.__name__].append(tensor.detach())
+    im_t = dict(im_t)
+
+    all_close = {f'layer_{idx}': torch.allclose(y1.detach(), y2.detach(), rtol=0, atol=e
+                                                ) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
+                 }
+    print(f'Cummulative differences per layer is smaller then {e}:\n {all_close}')
+    # all_errors = {f'layer_{idx}': torch.absolute(y1.detach(), y2.detach(), rtol=0, atol=e
+    #                                              ) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
+    #               }
+
+    for net in [meta_net, transfer_net]:
+        net.eval()
         metric = torchmetrics.Accuracy()
         with tqdm(desc='Test Batch: ') as pbar:
             for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
-                y = TransferNet(batch_x)
+                y = net(batch_x)
                 acc = metric(y.cpu(), batch_y.cpu())
                 pbar.set_postfix_str(f'Acc: {acc}')
                 pbar.update()
 
         # metric on all batches using custom accumulation
         acc = metric.compute()
-    tqdm.write(f"Avg. accuracy on all data: {acc}")
+        tqdm.write(f"Avg. accuracy on {net.__class__.__name__}: {acc}")
+    return acc
 
 
 if __name__ == '__main__':
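
The sanity check added above pushes one random input through both networks layer by layer and compares the intermediates with torch.allclose. A self-contained toy version of that idea, with plain Linear stacks standing in for MetaNet and MetaNetCompareBaseline.

import torch
import torch.nn as nn
from collections import defaultdict

torch.manual_seed(0)
net_a = nn.ModuleList(nn.Linear(4, 4, bias=False) for _ in range(3))
net_b = nn.ModuleList(nn.Linear(4, 4, bias=False) for _ in range(3))
with torch.no_grad():  # copy weights so both layer-wise traces should match
    for src, dst in zip(net_a, net_b):
        dst.weight.copy_(src.weight)

im_t = defaultdict(list)
rand = torch.randn(1, 4)
for name, net in (('a', net_a), ('b', net_b)):
    tensor = rand.clone()
    for layer in net:
        tensor = layer(tensor)
        im_t[name].append(tensor.detach())  # record every intermediate activation

all_close = {f'layer_{idx}': torch.allclose(y1, y2, rtol=0, atol=1e-5)
             for idx, (y1, y2) in enumerate(zip(*im_t.values()))}
print(all_close)  # {'layer_0': True, 'layer_1': True, 'layer_2': True}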

View File

@@ -161,7 +161,7 @@ def embed_vector(x, repeat_dim):
 class SparseNetwork(nn.Module):
 
-    def __init__(self, input_dim, depth, width, out, residual_skip=True,
+    def __init__(self, input_dim, depth, width, out, residual_skip=True, activation=None,
                  weight_interface=5, weight_hidden_size=2, weight_output_size=1
                  ):
         super(SparseNetwork, self).__init__()
@@ -170,6 +170,7 @@ class SparseNetwork(nn.Module):
         self.depth_dim = depth
         self.hidden_dim = width
         self.out_dim = out
+        self.activation = activation
         self.first_layer = SparseLayer(self.input_dim * self.hidden_dim,
                                        interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
         self.last_layer = SparseLayer(self.hidden_dim * self.out_dim,
@@ -182,13 +183,17 @@ class SparseNetwork(nn.Module):
     def __call__(self, x):
         tensor = self.sparse_layer_forward(x, self.first_layer)
+        if self.activation:
+            tensor = self.activation(tensor)
         for nl_idx, network_layer in enumerate(self.hidden_layers):
-            if nl_idx % 2 == 0 and self.residual_skip:
-                residual = tensor
             # Sparse Layer pass
             tensor = self.sparse_layer_forward(tensor, network_layer)
-            if nl_idx % 2 != 0 and self.residual_skip:
+            if self.activation:
+                tensor = self.activation(tensor)
+            if nl_idx % 2 == 0 and self.residual_skip:
+                residual = tensor.clone()
+            if nl_idx % 2 == 1 and self.residual_skip:
                 # noinspection PyUnboundLocalVariable
                 tensor += residual
         tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
@@ -234,14 +239,19 @@ class SparseNetwork(nn.Module):
     def sparselayers(self):
         return (x for x in (self.first_layer, *self.hidden_layers, self.last_layer))
 
-    def combined_self_train(self):
+    def combined_self_train(self, optimizer, reduction='mean'):
         losses = []
         for layer in self.sparselayers:
+            optimizer.zero_grad()
             x, target_data = layer.get_self_train_inputs_and_targets()
             output = layer(x)
-            losses.append(F.mse_loss(output, target_data) / layer.nr_nets)
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+            loss = F.mse_loss(output, target_data, reduction=reduction)
+            losses.append(loss.detach())
+            loss.backward()
+            optimizer.step()
+        return sum(losses)
 
     def replace_weights_by_particles(self, particles):
         particles = list(particles)
@@ -274,12 +284,7 @@ def test_sparse_net_sef_train():
     if True:
         optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
         for _ in trange(epochs):
-            optimizer.zero_grad()
-            loss = net.combined_self_train()
-            print(loss)
-            exit()
-            loss.backward()
-            optimizer.step()
+            _ = net.combined_self_train(optimizer)
 
     else:
         optimizer_dict = {
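
In the sparse variant above the optimizer is zeroed and stepped once per layer rather than once per call. A minimal self-contained sketch of that per-layer update loop, with toy Linear layers and an identity-mapping objective standing in for the real self-train targets.

import torch
import torch.nn as nn
import torch.nn.functional as F

layers = nn.ModuleList(nn.Linear(4, 4, bias=False) for _ in range(3))
optimizer = torch.optim.SGD(layers.parameters(), lr=0.004, momentum=0.9)


def combined_self_train(optimizer, reduction='mean'):
    losses = []
    for layer in layers:
        optimizer.zero_grad()                                # fresh gradients for every layer
        x = torch.randn(8, 4)
        loss = F.mse_loss(layer(x), x, reduction=reduction)  # toy target: identity mapping
        losses.append(loss.detach())
        loss.backward()
        optimizer.step()                                     # update before moving to the next layer
    return sum(losses)


for _ in range(5):
    _ = combined_self_train(optimizer)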