diff --git a/README.md b/README.md index 214a23c..39d4967 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Data Exchange: [Google Drive Folder](***REMOVED***) - Übersetung in ein Gewichtsskalar - Einbettung in ein Reguläres Netz -- [ ] Übersetung in ein Explainable AI Framework +- [ ] Übersetzung in ein Explainable AI Framework - Rückschlüsse auf Mikro Netze - [ ] Visualiserung diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py index 270b2ed..000c159 100644 --- a/experiments/meta_task_exp.py +++ b/experiments/meta_task_exp.py @@ -40,6 +40,7 @@ from network import MetaNet from functionalities_test import test_for_fixpoints, FixTypes WORKER = 10 if not debug else 2 +debug = False BATCHSIZE = 500 if not debug else 50 EPOCH = 200 VALIDATION_FRQ = 5 if not debug else 1 @@ -200,14 +201,22 @@ if __name__ == '__main__': as_sparse_network_test = True self_train_alpha = 1 batch_train_beta = 1 + weight_hidden_size = 5 + residual_skip = True + dropout = 0.1 data_path = Path('data') data_path.mkdir(exist_ok=True, parents=True) - run_path = Path('output') / 'mn_st_400_2_no_res' + st_str = f'{"" if self_train else "no_"}st' + res_str = f'{"" if residual_skip else "_no"}_res' + dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}' + run_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{res_str}{dr_str}' + model_path = run_path / '0000_trained_model.zip' df_store_path = run_path / 'train_store.csv' weight_store_path = run_path / 'weight_store.csv' + srnn_parameters = dict() if training: utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)]) @@ -218,7 +227,9 @@ if __name__ == '__main__': d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER) interface = np.prod(dataset[0][0].shape) - metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=False).to(DEVICE) + metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip, dropout=dropout, + weight_hidden_size=weight_hidden_size, + ).to(DEVICE) meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters()) loss_fn = nn.CrossEntropyLoss() @@ -315,7 +326,13 @@ if __name__ == '__main__': plot_training_particle_types(df_store_path) if particle_analysis: - model_path = next(run_path.glob(f'*e{EPOCH}.tp')) + try: + model_path = next(run_path.glob(f'*e{EPOCH}.tp')) + except StopIteration: + print('Model pattern did not trigger.') + print(f'Search path was: {run_path}:') + print(f'Found Models are: {list(run_path.rglob(".tp"))}') + exit(1) latest_model = torch.load(model_path, map_location=DEVICE).eval() counter_dict = defaultdict(lambda: 0) _ = test_for_fixpoints(counter_dict, list(latest_model.particles)) @@ -323,21 +340,22 @@ if __name__ == '__main__': if as_sparse_network_test: acc_pre = validate(model_path, ratio=1).item() - diff_table = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff']) + diff_df = pd.DataFrame(columns=['Particle Type', 'Accuracy', 'Diff']) for fixpoint_type in FixTypes.all_types(): new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type) - new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True) - acc_post = validate(new_ckpt, ratio=1).item() - acc_diff = abs(acc_post-acc_pre) - tqdm.write(f'Zero_ident diff = {acc_diff}') - diff_table.iloc[diff_table.shape[0]] = (fixpoint_type, acc_post, acc_diff) + if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]: + new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True) + acc_post = validate(new_ckpt, ratio=1).item() + acc_diff = abs(acc_post-acc_pre) + tqdm.write(f'Zero_ident diff = {acc_diff}') + diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff) if plotting: plt.clf() fig, ax = plt.subplots(ncols=2) labels = ['Full Network', 'Sparse, No Identity', 'Sparse, No Other'] - barplot = sns.barplot(data=diff_table, y='Accurady', x=['Particle Type'], - color=sns.color_palette()[:diff_table.shape[0]], ax=ax[0]) + colors = sns.color_palette()[:diff_df.shape[0]] if diff_df.shape[0] >= 2 else sns.color_palette()[0] + barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', color=colors, ax=ax[0]) # noinspection PyUnboundLocalVariable for idx, patch in enumerate(barplot.patches): if idx != 0: diff --git a/network.py b/network.py index cf008aa..6893eca 100644 --- a/network.py +++ b/network.py @@ -291,7 +291,7 @@ class SecondaryNet(Net): class MetaCell(nn.Module): - def __init__(self, name, interface): + def __init__(self, name, interface, weight_interface=5, weight_hidden_size=2, weight_output_size=1): super().__init__() self.name = name self.interface = interface @@ -342,7 +342,8 @@ class MetaCell(nn.Module): class MetaLayer(nn.Module): - def __init__(self, name, interface=4, width=4, residual_skip=True): + def __init__(self, name, interface=4, width=4, residual_skip=True, + weight_interface=5, weight_hidden_size=2, weight_output_size=1): super().__init__() self.residual_skip = residual_skip self.name = name @@ -351,7 +352,9 @@ class MetaLayer(nn.Module): self.meta_cell_list = nn.ModuleList() self.meta_cell_list.extend([MetaCell(name=f'{self.name}_C{cell_idx}', - interface=interface + interface=interface, + weight_interface=weight_interface, weight_hidden_size=weight_hidden_size, + weight_output_size=weight_output_size, ) for cell_idx in range(self.width)] ) @@ -371,26 +374,42 @@ class MetaLayer(nn.Module): class MetaNet(nn.Module): - def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True): + def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True, dropout=0, + weight_interface=5, weight_hidden_size=2, weight_output_size=1,): super().__init__() + self.dropout = dropout self.activation = activation self.out = out self.interface = interface self.width = width self.depth = depth + self.weight_interface = weight_interface + self.weight_hidden_size = weight_hidden_size + self.weight_output_size = weight_output_size self._meta_layer_list = nn.ModuleList() self._meta_layer_list.append(MetaLayer(name=f'L{0}', interface=self.interface, - width=self.width, residual_skip=residual_skip) + width=self.width, residual_skip=residual_skip, + weight_interface=weight_interface, + weight_hidden_size=weight_hidden_size, + weight_output_size=weight_output_size) ) self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}', - interface=self.width, width=self.width, residual_skip=residual_skip + interface=self.width, width=self.width, residual_skip=residual_skip, + weight_interface=weight_interface, + weight_hidden_size=weight_hidden_size, + weight_output_size=weight_output_size, ) for layer_idx in range(self.depth - 2)] ) self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}', - interface=self.width, width=self.out, residual_skip=residual_skip) + interface=self.width, width=self.out, residual_skip=residual_skip, + weight_interface=weight_interface, + weight_hidden_size=weight_hidden_size, + weight_output_size=weight_output_size, + ) ) + self.dropout_layer = nn.Dropout(p=self.dropout) def replace_with_zero(self, ident_key): replaced_particles = 0 @@ -406,6 +425,8 @@ class MetaNet(nn.Module): def forward(self, x): tensor = x for meta_layer in self._meta_layer_list: + if self.dropout: + tensor = self.dropout_layer(tensor) tensor = meta_layer(tensor) return tensor @@ -423,6 +444,10 @@ class MetaNet(nn.Module): losses.append(F.mse_loss(output, target_data)) return torch.hstack(losses).sum(dim=-1, keepdim=True) + @property + def hyperparams(self): + return {key: val for key, val in self.__dict__.items() if not key.startswith('_')} + class MetaNetCompareBaseline(nn.Module): @@ -437,7 +462,7 @@ class MetaNetCompareBaseline(nn.Module): self._meta_layer_list = nn.ModuleList() self._meta_layer_list.append(nn.Linear(self.interface, self.width, bias=False)) - self._meta_layer_list.extend([ nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)]) + self._meta_layer_list.extend([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)]) self._meta_layer_list.append(nn.Linear(self.width, self.out, bias=False)) def forward(self, x): diff --git a/sparse_tensor_combined.ipynb b/sparse_tensor_combined.ipynb index 141cd51..7b430dc 100644 --- a/sparse_tensor_combined.ipynb +++ b/sparse_tensor_combined.ipynb @@ -495,4 +495,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file