MetaNetworks Debugged II

This commit is contained in:
Steffen Illium
2022-02-01 18:17:11 +01:00
parent 246d825bb4
commit 1b7581e656
4 changed files with 105 additions and 61 deletions

View File

@ -9,6 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim, Tensor
from tqdm import tqdm
def prng():
@ -391,6 +392,17 @@ class MetaNet(nn.Module):
interface=self.width, width=self.out)
)
def replace_with_zero(self, ident_key):
replaced_particles = 0
for particle in self.particles:
if particle.is_fixpoint == ident_key:
particle.load_state_dict(
{key: torch.zeros_like(state) for key, state in particle.state_dict().items()}
)
replaced_particles += 1
tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
return self
def forward(self, x):
tensor = x
for meta_layer in self._meta_layer_list:
@ -401,15 +413,22 @@ class MetaNet(nn.Module):
def particles(self):
return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)
def combined_self_train(self):
def combined_self_train(self, external_optimizer):
losses = []
for particle in self.particles:
# Zero your gradients for every batch!
external_optimizer.zero_grad()
# Intergrate optimizer and backward function
input_data = particle.input_weight_matrix()
target_data = particle.create_target_weights(input_data)
output = particle(input_data)
losses.append(F.mse_loss(output, target_data))
return torch.hstack(losses).sum(dim=-1, keepdim=True)
loss = F.mse_loss(output, target_data)
losses.append(loss.detach)
loss.backward()
# Adjust learning weights
external_optimizer.step()
# return torch.hstack(losses).sum(dim=-1, keepdim=True)
return sum(losses)
if __name__ == '__main__':