MetaNetworks Debugged II
network.py (25 changed lines)
@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import optim, Tensor
+from tqdm import tqdm


 def prng():
@@ -391,6 +392,17 @@ class MetaNet(nn.Module):
                 interface=self.width, width=self.out)
             )

+    def replace_with_zero(self, ident_key):
+        replaced_particles = 0
+        for particle in self.particles:
+            if particle.is_fixpoint == ident_key:
+                particle.load_state_dict(
+                    {key: torch.zeros_like(state) for key, state in particle.state_dict().items()}
+                )
+                replaced_particles += 1
+        tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
+        return self
+
     def forward(self, x):
         tensor = x
         for meta_layer in self._meta_layer_list:
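
The new replace_with_zero method can be exercised roughly as follows. This is a hypothetical sketch, not part of the commit: the MetaNet constructor arguments and the 'identity_func' label value are assumptions, since neither is visible in this diff.

    # Hypothetical usage sketch: zero out every particle flagged with a given
    # fixpoint label; replace_with_zero returns the network itself.
    metanet = MetaNet(interface=4, depth=3, width=6, out=1)  # assumed constructor signature
    metanet = metanet.replace_with_zero('identity_func')     # assumed label value
    zeroed = [p for p in metanet.particles if p.is_fixpoint == 'identity_func']
    print(f'{len(zeroed)} particles now hold all-zero weights')
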
@@ -401,15 +413,22 @@ class MetaNet(nn.Module):
     def particles(self):
         return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)

-    def combined_self_train(self):
+    def combined_self_train(self, external_optimizer):
         losses = []
         for particle in self.particles:
+            # Zero your gradients for every batch!
+            external_optimizer.zero_grad()
             # Integrate optimizer and backward function
             input_data = particle.input_weight_matrix()
             target_data = particle.create_target_weights(input_data)
             output = particle(input_data)
-            losses.append(F.mse_loss(output, target_data))
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+            loss = F.mse_loss(output, target_data)
+            losses.append(loss.detach())
+            loss.backward()
+            # Adjust learning weights
+            external_optimizer.step()
+        # return torch.hstack(losses).sum(dim=-1, keepdim=True)
+        return sum(losses)


 if __name__ == '__main__':
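
The reworked combined_self_train now expects an externally constructed optimizer and performs zero_grad/backward/step per particle internally, returning the detached loss sum. A minimal driver-loop sketch, assuming a MetaNet instance built as in the previous sketch, that its particles are registered as sub-modules (so metanet.parameters() reaches their weights), and placeholder hyperparameters:

    from torch import optim

    # Hypothetical driver loop: the optimizer lives outside the network and is
    # handed to combined_self_train, which steps it once per particle.
    external_optimizer = optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
    for epoch in range(100):
        combined_loss = metanet.combined_self_train(external_optimizer)
        # The return value is a sum of detached per-particle losses, so it can
        # be logged without keeping the autograd graph alive.
        print(f'epoch {epoch}: self-train loss {combined_loss.item():.6f}')
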