small fixes new parameters

This commit is contained in:
Steffen Illium
2022-02-25 15:32:56 +01:00
parent 5b2b5b5beb
commit 9d8496a725
5 changed files with 292 additions and 236 deletions

View File

@ -161,7 +161,7 @@ def embed_vector(x, repeat_dim):
class SparseNetwork(nn.Module):
def __init__(self, input_dim, depth, width, out, residual_skip=True,
def __init__(self, input_dim, depth, width, out, residual_skip=True, activation=None,
weight_interface=5, weight_hidden_size=2, weight_output_size=1
):
super(SparseNetwork, self).__init__()
@ -170,6 +170,7 @@ class SparseNetwork(nn.Module):
self.depth_dim = depth
self.hidden_dim = width
self.out_dim = out
self.activation = activation
self.first_layer = SparseLayer(self.input_dim * self.hidden_dim,
interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
self.last_layer = SparseLayer(self.hidden_dim * self.out_dim,
@ -182,13 +183,17 @@ class SparseNetwork(nn.Module):
def __call__(self, x):
tensor = self.sparse_layer_forward(x, self.first_layer)
if self.activation:
tensor = self.activation(tensor)
for nl_idx, network_layer in enumerate(self.hidden_layers):
if nl_idx % 2 == 0 and self.residual_skip:
residual = tensor
# Sparse Layer pass
tensor = self.sparse_layer_forward(tensor, network_layer)
if nl_idx % 2 != 0 and self.residual_skip:
if self.activation:
tensor = self.activation(tensor)
if nl_idx % 2 == 0 and self.residual_skip:
residual = tensor.clone()
if nl_idx % 2 == 1 and self.residual_skip:
# noinspection PyUnboundLocalVariable
tensor += residual
tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
@ -234,14 +239,19 @@ class SparseNetwork(nn.Module):
def sparselayers(self):
return (x for x in (self.first_layer, *self.hidden_layers, self.last_layer))
def combined_self_train(self):
def combined_self_train(self, optimizer, reduction='mean'):
losses = []
for layer in self.sparselayers:
optimizer.zero_grad()
x, target_data = layer.get_self_train_inputs_and_targets()
output = layer(x)
losses.append(F.mse_loss(output, target_data) / layer.nr_nets)
return torch.hstack(losses).sum(dim=-1, keepdim=True)
loss = F.mse_loss(output, target_data, reduction=reduction)
losses.append(loss.detach())
loss.backward()
optimizer.step()
return sum(losses)
def replace_weights_by_particles(self, particles):
particles = list(particles)
@ -274,12 +284,7 @@ def test_sparse_net_sef_train():
if True:
optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
for _ in trange(epochs):
optimizer.zero_grad()
loss = net.combined_self_train()
print(loss)
exit()
loss.backward()
optimizer.step()
_ = net.combined_self_train(optimizer)
else:
optimizer_dict = {