diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py
index c7fb8ad..3889e55 100644
--- a/experiments/meta_task_exp.py
+++ b/experiments/meta_task_exp.py
@@ -45,8 +45,8 @@ from functionalities_test import test_for_fixpoints
 WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 500 if not debug else 50
-EPOCH = 50
-VALIDATION_FRQ = 3 if not debug else 1
+EPOCH = 100
+VALIDATION_FRQ = 4 if not debug else 1
 SELF_TRAIN_FRQ = 1 if not debug else 1
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -56,6 +56,9 @@ if debug:
 
 class ToFloat:
 
+    def __init__(self):
+        pass
+
     def __call__(self, x):
         return x.to(torch.float32)
 
@@ -194,7 +197,7 @@ def plot_training_result(path_to_dataframe):
 def plot_network_connectivity_by_fixtype(path_to_trained_model):
     m = torch.load(path_to_trained_model, map_location=torch.device('cpu'))
     # noinspection PyProtectedMember
-    particles = [y for x in m._meta_layer_list for y in x.particles]
+    particles = list(m.particles)
     df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name'])
 
     for prtcl in particles:
@@ -210,10 +213,16 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
     for n, fixtype in enumerate([ft.other_func, ft.identity_func]):
         plt.clf()
         ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
-                          legend=False, estimator=None,
-                          palette=[sns.color_palette()[n]] * (df[df['type'] == fixtype].shape[0]//2), lw=1)
+                          legend=False, estimator=None, lw=1)
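+        # Invisible helper line (lw=0) spanning the full layer range; presumably pins the axis extent so every fixtype plot shares it.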
+        _ = sns.lineplot(y=[0, 1], x=[-1, df['layer'].max()], legend=False, estimator=None, lw=0)
         ax.set_title(fixtype)
-        plt.show()
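+        # Re-color every plotted particle trace with the single palette color assigned to this fixtype.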
+        lines = ax.get_lines()
+        for line in lines:
+            line.set_color(sns.color_palette()[n])
+        if debug:
+            plt.show()
+        else:
+            plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
         print('plottet')
 
 
@@ -234,7 +243,7 @@ def run_particle_dropout_test(run_path):
             tqdm.write(f'Zero_ident diff = {acc_diff}')
             diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
 
-    diff_df.to_csv(diff_store_path, mode='a', header=not df_store_path.exists(), index=False)
+    diff_df.to_csv(diff_store_path, mode='a', header=not diff_store_path.exists(), index=False)
     return diff_store_path
 
 
@@ -246,18 +255,18 @@ def plot_dropout_stacked_barplot(path_to_diff_df):
     plt.clf()
     fig, ax = plt.subplots(ncols=2)
     colors = sns.color_palette()[:diff_df.shape[0]]
-    barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', palette=colors, ax=ax[0])
+    barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', ax=ax[0], palette=colors)
     # noinspection PyUnboundLocalVariable
-    for idx, patch in enumerate(barplot.patches):
-        if idx != 0:
-            # we recenter the bar
-            patch.set_x(patch.get_x() + idx * 0.035)
+    #for idx, patch in enumerate(barplot.patches):
+    #    if idx != 0:
+    #        # we recenter the bar
+    #        patch.set_x(patch.get_x() + idx * 0.035)
 
     ax[0].set_title('Accuracy after particle dropout')
-    ax[0].set_xlabel('Accuracy')
+    ax[0].set_xlabel('Particle Type')
 
     ax[1].pie(particle_dict.values(), labels=particle_dict.keys(), colors=colors, )
-    ax[1].set_title('Particle Count for ')
+    ax[1].set_title('Particle Count')
 
     plt.tight_layout()
     if debug:
@@ -278,196 +287,202 @@ def flat_for_store(parameters):
 if __name__ == '__main__':
 
     self_train = True
-    training = True
-    train_to_id_first = False
+    training = False
+    train_to_id_first = True
     train_to_task_first = False
-    train_to_task_first_sequential = True
+    sequential_task_train = True
     force_st_for_n_from_last_epochs = 5
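+    # Number of self-train updates executed per task batch (used in the inner self-train loop below).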
+    n_st_per_batch = 3
+    activation = None  # nn.ReLU()
 
-    use_sparse_network = False
+    use_sparse_network = True
 
-    tsk_threshold = 0.855
-    self_train_alpha = 1
-    batch_train_beta = 1
-    weight_hidden_size = 3
-    residual_skip = True
-    n_seeds = 5
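+    # Sweep over the hidden-layer width of the particle (meta-weight) networks.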
+    for weight_hidden_size in [3, 4, 5, 6]:
 
-    data_path = Path('data')
-    data_path.mkdir(exist_ok=True, parents=True)
-    assert not (train_to_task_first and train_to_id_first)
+        tsk_threshold = 0.85
+        residual_skip = True
+        n_seeds = 3
 
-    st_str = f'{"" if self_train else "no_"}st'
-    a_str = f'_alpha_{self_train_alpha}' if self_train_alpha != 1 else ''
-    res_str = f'{"" if residual_skip else "_no_res"}'
-    # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
-    id_str = f'{f"_StToId" if train_to_id_first else ""}'
-    tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
-    sprs_str = '_sprs' if use_sparse_network else ''
-    f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
-        force_st_for_n_from_last_epochs and train_to_task_first_sequential and train_to_task_first \
-        else ""
-    config_str = f'{a_str}{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
-    exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}'
+        data_path = Path('data')
+        data_path.mkdir(exist_ok=True, parents=True)
+        assert not (train_to_task_first and train_to_id_first)
 
-    for seed in range(n_seeds):
-        seed_path = exp_path / str(seed)
+        st_str = f'{"" if self_train else "no_"}st{f"_n_{n_st_per_batch}" if n_st_per_batch else ""}'
+        ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
+        res_str = f'{"" if residual_skip else "_no_res"}'
+        # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
+        id_str = f'{f"_StToId" if train_to_id_first else ""}'
+        tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
+        sprs_str = '_sprs' if use_sparse_network else ''
+        f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
+            force_st_for_n_from_last_epochs and sequential_task_train and train_to_task_first else ""
+        config_str = f'{res_str}{id_str}{tsk_str}{f_str}{sprs_str}'
+        exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}'
 
-        model_path = seed_path / '0000_trained_model.zip'
-        df_store_path = seed_path / 'train_store.csv'
-        weight_store_path = seed_path / 'weight_store.csv'
-        srnn_parameters = dict()
-        for path in [model_path, df_store_path, weight_store_path]:
-            assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
+        if not training:
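+            # Evaluation-only run: reuse a previously trained experiment folder instead of starting a new training run.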
+            # noinspection PyRedeclaration
+            exp_path = Path('output') / 'mn_st_n_2_100_4'
 
-        if training:
-            utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
-            try:
-                dataset = MNIST(str(data_path), transform=utility_transforms)
-            except RuntimeError:
-                dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
-            d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
+        for seed in range(n_seeds):
+            seed_path = exp_path / str(seed)
 
-            interface = np.prod(dataset[0][0].shape)
-            dense_metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                    weight_hidden_size=weight_hidden_size,).to(DEVICE)
-            sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                           weight_hidden_size=weight_hidden_size
-                                           ).to(DEVICE) if use_sparse_network else dense_metanet
-            meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())
+            model_path = seed_path / '0000_trained_model.zip'
+            df_store_path = seed_path / 'train_store.csv'
+            weight_store_path = seed_path / 'weight_store.csv'
+            srnn_parameters = dict()
 
-            loss_fn = nn.CrossEntropyLoss()
-            dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.008, momentum=0.9)
-            sparse_optimizer = torch.optim.SGD(
-                sparse_metanet.parameters(), lr=0.008, momentum=0.9
-                                               ) if use_sparse_network else dense_optimizer
+            if training:
+                # Abort if any of the output files already exist at the target location.
+                for path in [model_path, df_store_path, weight_store_path]:
+                    assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
 
-            train_store = new_storage_df('train', None)
-            weight_store = new_storage_df('weights', meta_weight_count)
-            init_tsk = train_to_task_first
-            for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
-                is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
-                is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
-                sparse_metanet = sparse_metanet.train()
-                dense_metanet = dense_metanet.train()
-                if is_validation_epoch:
-                    metric = torchmetrics.Accuracy()
-                else:
-                    metric = None
-                init_st = train_to_id_first and not all(
-                    x.is_fixpoint == ft.identity_func for x in dense_metanet.particles
-                )
-                force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
-                            ) and train_to_task_first_sequential and force_st_for_n_from_last_epochs
-                for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
+                utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
+                try:
+                    dataset = MNIST(str(data_path), transform=utility_transforms)
+                except RuntimeError:
+                    dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
+                d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
 
-                    # Self Train
-                    if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st):
-                        # Transfer weights
-                        if use_sparse_network:
-                            sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
-                        # Zero your gradients for every batch!
-                        sparse_optimizer.zero_grad()
-                        self_train_loss = sparse_metanet.combined_self_train() * self_train_alpha
-                        self_train_loss.backward()
-                        # Adjust learning weights
-                        sparse_optimizer.step()
-                        step_log = dict(Epoch=epoch, Batch=batch,
-                                        Metric='Self Train Loss', Score=self_train_loss.item())
-                        train_store.loc[train_store.shape[0]] = step_log
-                        # Transfer weights
-                        if use_sparse_network:
-                            dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+                interface = np.prod(dataset[0][0].shape)
+                dense_metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
+                                        weight_hidden_size=weight_hidden_size, activation=activation).to(DEVICE)
+                sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
+                                               weight_hidden_size=weight_hidden_size, activation=activation
+                                               ).to(DEVICE) if use_sparse_network else dense_metanet
+                meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())
 
-                    # Task Train
-                    if not init_st:
-                        # Zero your gradients for every batch!
-                        dense_optimizer.zero_grad()
-                        batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-                        y_pred = dense_metanet(batch_x)
-                        # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
-                        loss = loss_fn(y_pred, batch_y.to(torch.long)) * batch_train_beta
-                        loss.backward()
+                loss_fn = nn.CrossEntropyLoss()
+                dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.008, momentum=0.9)
+                sparse_optimizer = torch.optim.SGD(
+                    sparse_metanet.parameters(), lr=0.008, momentum=0.9
+                                                   ) if use_sparse_network else dense_optimizer
 
-                        # Adjust learning weights
-                        dense_optimizer.step()
+                train_store = new_storage_df('train', None)
+                weight_store = new_storage_df('weights', meta_weight_count)
+                init_tsk = train_to_task_first
+                for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
+                    is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
+                    is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
+                    sparse_metanet = sparse_metanet.train()
+                    dense_metanet = dense_metanet.train()
+                    if is_validation_epoch:
+                        metric = torchmetrics.Accuracy()
+                    else:
+                        metric = None
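+                    # With train_to_id_first, remain in self-train-only mode until every particle reaches the identity fixpoint.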
+                    init_st = train_to_id_first and not all(
+                        x.is_fixpoint == ft.identity_func for x in dense_metanet.particles
+                    )
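+                    # Force self-training during the last force_st_for_n_from_last_epochs epochs of the sequential schedule.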
+                    force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
+                                ) and sequential_task_train and force_st_for_n_from_last_epochs
+                    for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
 
-                        step_log = dict(Epoch=epoch, Batch=batch,
-                                        Metric='Task Loss', Score=loss.item())
-                        train_store.loc[train_store.shape[0]] = step_log
-                        if is_validation_epoch:
-                            metric(y_pred.cpu(), batch_y.cpu())
+                        # Self Train
+                        if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st):
+                            # Transfer weights
+                            if use_sparse_network:
+                                sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
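+                            # combined_self_train() zeroes gradients, backpropagates and steps the optimizer internally.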
+                            for _ in range(n_st_per_batch):
+                                self_train_loss = sparse_metanet.combined_self_train(sparse_optimizer, reduction='mean')
+                            # noinspection PyUnboundLocalVariable
+                            step_log = dict(Epoch=epoch, Batch=batch,
+                                            Metric='Self Train Loss', Score=self_train_loss.item())
+                            train_store.loc[train_store.shape[0]] = step_log
+                            # Transfer weights
+                            if use_sparse_network:
+                                dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
 
-                    if batch >= 3 and debug:
-                        break
+                        # Task Train
+                        if not init_st:
+                            # Zero your gradients for every batch!
+                            dense_optimizer.zero_grad()
+                            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
+                            y_pred = dense_metanet(batch_x)
+                            # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
+                            loss = loss_fn(y_pred, batch_y.to(torch.long))
+                            loss.backward()
 
-                if is_validation_epoch:
-                    dense_metanet = dense_metanet.eval()
-                    if not init_st:
+                            # Adjust learning weights
+                            dense_optimizer.step()
+
+                            step_log = dict(Epoch=epoch, Batch=batch,
+                                            Metric='Task Loss', Score=loss.item())
+                            train_store.loc[train_store.shape[0]] = step_log
+                            if is_validation_epoch:
+                                metric(y_pred.cpu(), batch_y.cpu())
+
+                        if batch >= 3 and debug:
+                            break
+
+                    if is_validation_epoch:
+                        dense_metanet = dense_metanet.eval()
+                        if not init_st:
+                            validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
+                                                  Metric='Train Accuracy', Score=metric.compute().item())
+                            train_store.loc[train_store.shape[0]] = validation_log
+
+                        accuracy = checkpoint_and_validate(dense_metanet, seed_path, epoch).item()
                         validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                              Metric='Train Accuracy', Score=metric.compute().item())
+                                              Metric='Test Accuracy', Score=accuracy)
                         train_store.loc[train_store.shape[0]] = validation_log
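+                        # With task-first training, self-training stays suppressed (unless forced) until test accuracy clears tsk_threshold.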
+                        if init_tsk or (train_to_task_first and sequential_task_train):
+                            init_tsk = accuracy <= tsk_threshold
+                    if init_st or is_validation_epoch:
+                        counter_dict = defaultdict(lambda: 0)
+                        # Classify every particle by fixpoint type; counter_dict collects the per-type counts
+                        _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
+                        counter_dict = dict(counter_dict)
+                        for key, value in counter_dict.items():
+                            step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
+                            train_store.loc[train_store.shape[0]] = step_log
+                        tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
+                    if init_st or is_validation_epoch:
+                        for particle in dense_metanet.particles:
+                            weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
+                            weight_store.loc[weight_store.shape[0]] = weight_log
+                        train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
+                        weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+                        train_store = new_storage_df('train', None)
+                        weight_store = new_storage_df('weights', meta_weight_count)
 
-                    accuracy = checkpoint_and_validate(dense_metanet, seed_path, epoch).item()
-                    validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                          Metric='Test Accuracy', Score=accuracy)
-                    train_store.loc[train_store.shape[0]] = validation_log
-                    if init_tsk or (train_to_task_first and train_to_task_first_sequential):
-                        init_tsk = accuracy <= tsk_threshold
-                if init_st or is_validation_epoch:
-                    counter_dict = defaultdict(lambda: 0)
-                    # This returns ID-functions
-                    _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
-                    for key, value in dict(counter_dict).items():
-                        step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
-                        train_store.loc[train_store.shape[0]] = step_log
-                if init_st or is_validation_epoch:
-                    for particle in dense_metanet.particles:
-                        weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
-                        weight_store.loc[weight_store.shape[0]] = weight_log
-                    train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-                    weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
-                    train_store = new_storage_df('train', None)
-                    weight_store = new_storage_df('weights', meta_weight_count)
+                dense_metanet.eval()
 
-            dense_metanet.eval()
+                counter_dict = defaultdict(lambda: 0)
+                # Classify every particle by fixpoint type; counter_dict collects the per-type counts
+                _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
+                for key, value in dict(counter_dict).items():
+                    step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
+                    train_store.loc[train_store.shape[0]] = step_log
+                accuracy = checkpoint_and_validate(dense_metanet, seed_path, EPOCH, final_model=True)
+                validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
+                                      Metric='Test Accuracy', Score=accuracy.item())
+                for particle in dense_metanet.particles:
+                    weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
+                    weight_store.loc[weight_store.shape[0]] = weight_log
 
-            counter_dict = defaultdict(lambda: 0)
-            # This returns ID-functions
-            _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles))
-            for key, value in dict(counter_dict).items():
-                step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
-                train_store.loc[train_store.shape[0]] = step_log
-            accuracy = checkpoint_and_validate(dense_metanet, seed_path, EPOCH, final_model=True)
-            validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
-                                  Metric='Test Accuracy', Score=accuracy.item())
-            for particle in dense_metanet.particles:
-                weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
-                weight_store.loc[weight_store.shape[0]] = weight_log
+                train_store.loc[train_store.shape[0]] = validation_log
+                train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
+                weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
 
-            train_store.loc[train_store.shape[0]] = validation_log
-            train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
-            weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
+            plot_training_result(df_store_path)
+            plot_training_particle_types(df_store_path)
 
-        plot_training_result(df_store_path)
-        plot_training_particle_types(df_store_path)
-
-        try:
-            model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
-        except StopIteration:
-            print('Model pattern did not trigger.')
-            print(f'Search path was: {seed_path}:')
-            print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
-            exit(1)
-        latest_model = torch.load(model_path, map_location=DEVICE).eval()
-        try:
-            run_particle_dropout_and_plot(seed_path)
-        except ValueError as e:
-            print(e)
-        try:
-            plot_network_connectivity_by_fixtype(model_path)
-        except ValueError as e:
-            print(e)
+            try:
+                model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
+            except StopIteration:
+                print('Model pattern did not trigger.')
+                print(f'Search path was: {seed_path}:')
+                print(f'Found models are: {list(seed_path.rglob("*.tp"))}')
+                exit(1)
+            latest_model = torch.load(model_path, map_location=DEVICE).eval()
+            try:
+                run_particle_dropout_and_plot(seed_path)
+            except ValueError as e:
+                print(e)
+            try:
+                plot_network_connectivity_by_fixtype(model_path)
+            except ValueError as e:
+                print(e)
 
     if n_seeds >= 2:
         pass
diff --git a/functionalities_test.py b/functionalities_test.py
index 7165e30..4dd1724 100644
--- a/functionalities_test.py
+++ b/functionalities_test.py
@@ -6,11 +6,14 @@ from tqdm import tqdm
 from network import FixTypes, Net
 
 
+epsilon_error_margin = pow(10, -5)  # shared numerical tolerance for all fixpoint checks
+
+
 def is_divergent(network: Net) -> bool:
     return network.input_weight_matrix().isinf().any().item() or network.input_weight_matrix().isnan().any().item()
 
 
-def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
+def is_identity_function(network: Net, epsilon=epsilon_error_margin) -> bool:
 
     input_data = network.input_weight_matrix()
     target_data = network.create_target_weights(input_data)
@@ -20,14 +23,14 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
                           rtol=0, atol=epsilon)
 
 
-def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
+def is_zero_fixpoint(network: Net, epsilon=epsilon_error_margin) -> bool:
     target_data = network.create_target_weights(network.input_weight_matrix().detach())
     result = torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)
     # result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
     return result
 
 
-def is_secondary_fixpoint(network: Net, epsilon: float = pow(10, -5)) -> bool:
+def is_secondary_fixpoint(network: Net, epsilon: float = epsilon_error_margin) -> bool:
     """ Secondary fixpoint check is done like this: compare first INPUT with second OUTPUT.
     If they are within the boundaries, then is secondary fixpoint. """
 
diff --git a/network.py b/network.py
index eed4bda..67c69ca 100644
--- a/network.py
+++ b/network.py
@@ -420,7 +420,7 @@ class MetaNet(nn.Module):
 
                                                          ) for layer_idx in range(self.depth - 2)]
                                               )
-        self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list)}',
+        self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list) + 1}',
                                           interface=self.width, width=self.out,
                                           weight_interface=weight_interface,
                                           weight_hidden_size=weight_hidden_size,
@@ -428,8 +428,6 @@ class MetaNet(nn.Module):
                                           )
         self.dropout_layer = nn.Dropout(p=self.dropout)
 
-        self._all_layers_with_particles = [self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last]
-
     def replace_with_zero(self, ident_key):
         replaced_particles = 0
         for particle in self.particles:
@@ -442,48 +440,51 @@ class MetaNet(nn.Module):
         return self
 
     def forward(self, x):
-        if self.dropout != 0:
-            x = self.dropout_layer(x)
         tensor = self._meta_layer_first(x)
+        residual = None
         for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
-            if self.dropout != 0:
-                tensor = self.dropout_layer(tensor)
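+            # Residual skip: cache the activation before every odd-indexed layer and add it back after the following layer.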
             if idx % 2 == 1 and self.residual_skip:
-                x = tensor.clone()
+                residual = tensor.clone()
             tensor = meta_layer(tensor)
             if idx % 2 == 0 and self.residual_skip:
-                tensor = tensor + x
-        if self.dropout != 0:
-            x = self.dropout_layer(x)
-        tensor = self._meta_layer_last(x)
+                tensor = tensor + residual
+        tensor = self._meta_layer_last(tensor)
         return tensor
 
     @property
     def particles(self):
-        return (cell for metalayer in self._all_layers_with_particles for cell in metalayer.particles)
+        return (cell for metalayer in self.all_layers for cell in metalayer.particles)
 
-    def combined_self_train(self):
+    def combined_self_train(self, optimizer, reduction='mean'):
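+        # Sum the self-replication MSE of all particles, then take a single backward/step on the given optimizer.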
+        optimizer.zero_grad()
         losses = []
         for particle in self.particles:
             # Intergrate optimizer and backward function
             input_data = particle.input_weight_matrix()
             target_data = particle.create_target_weights(input_data)
             output = particle(input_data)
-            losses.append(F.mse_loss(output, target_data))
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+            losses.append(F.mse_loss(output, target_data, reduction=reduction))
+        losses = torch.hstack(losses).sum(dim=-1, keepdim=True)
+        losses.backward()
+        optimizer.step()
+        return losses.detach()
 
     @property
     def hyperparams(self):
         return {key: val for key, val in self.__dict__.items() if not key.startswith('_')}
 
     def replace_particles(self, particle_weights_list):
-        for layer in self._all_layers_with_particles:
+        for layer in self.all_layers:
             for cell in layer.meta_cell_list:
                 # Individual replacement on cell lvl
                 for weight in cell.meta_weight_list:
                     weight.apply_weights(next(particle_weights_list).detach())
         return self
 
+    @property
+    def all_layers(self):
+        return (x for x in (self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last))
+
 
 class MetaNetCompareBaseline(nn.Module):
 
@@ -495,19 +496,24 @@ class MetaNetCompareBaseline(nn.Module):
         self.interface = interface
         self.width = width
         self.depth = depth
-        
         self._first_layer = nn.Linear(self.interface, self.width, bias=False)
-        self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
+        self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False
+                                                         ) for _ in range(self.depth - 2)])
         self._last_layer = nn.Linear(self.width, self.out, bias=False)
 
     def forward(self, x):
         tensor = self._first_layer(x)
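+        # self.activation is expected to be set in __init__ (None keeps the baseline purely linear).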
+        if self.activation:
+            tensor = self.activation(tensor)
+        residual = None
         for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
-            if idx % 2 == 1 and self.residual_skip:
-                x = tensor.clone()
             tensor = meta_layer(tensor)
+            if idx % 2 == 1 and self.residual_skip:
+                residual = tensor.clone()
             if idx % 2 == 0 and self.residual_skip:
-                tensor = tensor + x
+                tensor = tensor + residual
+            if self.activation:
+                tensor = self.activation(tensor)
         tensor = self._last_layer(tensor)
         return tensor
     
diff --git a/sanity_check_weights.py b/sanity_check_weights.py
index e6449f2..407d41e 100644
--- a/sanity_check_weights.py
+++ b/sanity_check_weights.py
@@ -10,8 +10,11 @@ from torch.utils.data import Dataset, DataLoader
 from torchvision.datasets import MNIST, CIFAR10
 from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Grayscale
 import torchmetrics
+
+from functionalities_test import epsilon_error_margin as e
 from network import MetaNet, MetaNetCompareBaseline
 
+
 def extract_weights_from_model(model:MetaNet)->dict:
     inpt = torch.zeros(5)
     inpt[-1] = 1
@@ -25,27 +28,51 @@ def extract_weights_from_model(model:MetaNet)->dict:
     return dict(weights)
 
 
-def test_weights_as_model(model, new_weights:dict, data):
-    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out,
-                                         residual_skip=True)
-
+def test_weights_as_model(meta_net, new_weights: dict, data):
+    transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth,
+                                          width=meta_net.width, out=meta_net.out, residual_skip=True)
     with torch.no_grad():
-        for weights, parameters in zip(new_weights.values(), TransferNet.parameters()):
+        new_weight_values = list(new_weights.values())
+        old_parameters = list(transfer_net.parameters())
+        assert len(new_weight_values) == len(old_parameters)
+        for weights, parameters in zip(new_weights.values(), transfer_net.parameters()):
             parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]
 
-    TransferNet.eval()
-    metric = torchmetrics.Accuracy()
-    with tqdm(desc='Test Batch: ') as pbar:
-        for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
-            y = TransferNet(batch_x)
-            acc = metric(y.cpu(), batch_y.cpu())
-            pbar.set_postfix_str(f'Acc: {acc}')
-            pbar.update()
-                
-        # metric on all batches using custom accumulation
-        acc = metric.compute()
-        tqdm.write(f"Avg. accuracy on all data: {acc}")
-        return acc
+    transfer_net.eval()
+
+    # Check that both models produce matching intermediate activations within tolerance e
+
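+    # Push one random input through both networks layer by layer (both are expected to expose all_layers) and record the activations.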
+    im_t = defaultdict(list)
+    rand = torch.randn((1, 15 * 15))
+    for net in [meta_net, transfer_net]:
+        tensor = rand.clone()
+        for layer in net.all_layers:
+            tensor = layer(tensor)
+            im_t[net.__class__.__name__].append(tensor.detach())
+
+    im_t = dict(im_t)
+
+    all_close = {f'layer_{idx}': torch.allclose(y1.detach(), y2.detach(), rtol=0, atol=e
+                                                ) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
+                 }
+    print(f'Cumulative difference per layer is smaller than {e}:\n {all_close}')
+    # all_errors = {f'layer_{idx}': torch.absolute(y1.detach(), y2.detach(), rtol=0, atol=e
+    #                                              ) for idx, (y1, y2) in enumerate(zip(*im_t.values()))
+    #               }
+
+    for net in [meta_net, transfer_net]:
+        net.eval()
+        metric = torchmetrics.Accuracy()
+        with tqdm(desc='Test Batch: ') as pbar:
+            for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
+                y = net(batch_x)
+                acc = metric(y.cpu(), batch_y.cpu())
+                pbar.set_postfix_str(f'Acc: {acc}')
+                pbar.update()
+
+            # metric on all batches using custom accumulation
+            acc = metric.compute()
+            tqdm.write(f"Avg. accuracy on {net.__class__.__name__}: {acc}")
 
 
 if __name__ == '__main__':
@@ -58,7 +85,7 @@ if __name__ == '__main__':
     data_path.mkdir(exist_ok=True, parents=True)
     mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
-    
+
     model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval()
     weights = extract_weights_from_model(model)
     test_weights_as_model(model, weights, d_test)
diff --git a/sparse_net.py b/sparse_net.py
index c9fc8e4..a9c6789 100644
--- a/sparse_net.py
+++ b/sparse_net.py
@@ -161,7 +161,7 @@ def embed_vector(x, repeat_dim):
 
 
 class SparseNetwork(nn.Module):
-    def __init__(self, input_dim, depth, width, out, residual_skip=True,
+    def __init__(self, input_dim, depth, width, out, residual_skip=True, activation=None,
                  weight_interface=5, weight_hidden_size=2, weight_output_size=1
                  ):
         super(SparseNetwork, self).__init__()
@@ -170,6 +170,7 @@ class SparseNetwork(nn.Module):
         self.depth_dim = depth
         self.hidden_dim = width
         self.out_dim = out
+        self.activation = activation
         self.first_layer = SparseLayer(self.input_dim  * self.hidden_dim,
                                        interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
         self.last_layer = SparseLayer(self.hidden_dim * self.out_dim,
@@ -182,13 +183,17 @@ class SparseNetwork(nn.Module):
     def __call__(self, x):
 
         tensor = self.sparse_layer_forward(x, self.first_layer)
+        if self.activation:
+            tensor = self.activation(tensor)
         for nl_idx, network_layer in enumerate(self.hidden_layers):
-            if nl_idx % 2 == 0 and self.residual_skip:
-                residual = tensor
             # Sparse Layer pass
             tensor = self.sparse_layer_forward(tensor, network_layer)
 
-            if nl_idx % 2 != 0 and self.residual_skip:
+            if self.activation:
+                tensor = self.activation(tensor)
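+            # Residual skip now caches the post-activation tensor on even layer indices and adds it back on odd ones.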
+            if nl_idx % 2 == 0 and self.residual_skip:
+                residual = tensor.clone()
+            if nl_idx % 2 == 1 and self.residual_skip:
                 # noinspection PyUnboundLocalVariable
                 tensor += residual
         tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
@@ -234,14 +239,19 @@ class SparseNetwork(nn.Module):
     def sparselayers(self):
         return (x for x in (self.first_layer, *self.hidden_layers, self.last_layer))
 
-    def combined_self_train(self):
+    def combined_self_train(self, optimizer, reduction='mean'):
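+        # Each sparse layer gets its own zero_grad/backward/step; the detached per-layer losses are summed for logging.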
         losses = []
         for layer in self.sparselayers:
+            optimizer.zero_grad()
             x, target_data = layer.get_self_train_inputs_and_targets()
             output = layer(x)
 
-            losses.append(F.mse_loss(output, target_data) / layer.nr_nets)
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+            loss = F.mse_loss(output, target_data, reduction=reduction)
+            losses.append(loss.detach())
+            loss.backward()
+            optimizer.step()
+
+        return sum(losses)
 
     def replace_weights_by_particles(self, particles):
         particles = list(particles)
@@ -274,12 +284,7 @@ def test_sparse_net_sef_train():
     if True:
         optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
         for _ in trange(epochs):
-            optimizer.zero_grad()
-            loss = net.combined_self_train()
-            print(loss)
-            exit()
-            loss.backward()
-            optimizer.step()
+            _ = net.combined_self_train(optimizer)
 
     else:
         optimizer_dict = {