diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py
index 05826a4..c0f1553 100644
--- a/experiments/meta_task_exp.py
+++ b/experiments/meta_task_exp.py
@@ -295,8 +295,7 @@ def flat_for_store(parameters):
 
 
 def train_self_replication(model, optimizer, st_stps) -> dict:
-    for _ in range(st_stps):
-        self_train_loss = model.combined_self_train(optimizer)
+    self_train_loss = model.combined_self_train(optimizer, st_stps)
     # noinspection PyUnboundLocalVariable
     stp_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
     return stp_log
diff --git a/experiments/meta_task_sanity_exp.py b/experiments/meta_task_sanity_exp.py
index dc152cf..b54bc0f 100644
--- a/experiments/meta_task_sanity_exp.py
+++ b/experiments/meta_task_sanity_exp.py
@@ -68,7 +68,7 @@ if __name__ == '__main__':
     mean_self_tain_loss = []
 
     for batch, (batch_x, batch_y) in tenumerate(dataloader):
-        # self_train_loss, _ = net.self_train(2, save_history=False, learning_rate=0.004)
+        self_train_loss, _ = net.self_train(2, save_history=False, learning_rate=0.004)
 
         optimizer.zero_grad()
         input_data = net.input_weight_matrix()
diff --git a/network.py b/network.py
index aa6af5d..92018c4 100644
--- a/network.py
+++ b/network.py
@@ -463,19 +463,20 @@ class MetaNet(nn.Module):
     def particles(self):
         return (cell for metalayer in self.all_layers for cell in metalayer.particles)
 
-    def combined_self_train(self, optimizer, reduction='mean'):
-        optimizer.zero_grad()
+    def combined_self_train(self, optimizer, n_st_steps, reduction='mean'):
+
         losses = []
-        n = 10
         for particle in self.particles:
-            # Intergrate optimizer and backward function
-            input_data = particle.input_weight_matrix()
-            target_data = particle.create_target_weights(input_data)
-            output = particle(input_data)
-            losses.append(F.mse_loss(output, target_data, reduction=reduction))
+            for _ in range(n_st_steps):
+                optimizer.zero_grad()
+                # Integrate optimizer and backward step per particle
+                input_data = particle.input_weight_matrix()
+                target_data = particle.create_target_weights(input_data)
+                output = particle(input_data)
+                losses.append(F.mse_loss(output, target_data, reduction=reduction))
+                losses[-1].backward()
+                optimizer.step()
         losses = torch.hstack(losses).sum(dim=-1, keepdim=True)
-        losses.backward()
-        optimizer.step()
         return losses.detach()
 
     @property