small fixes, new parameters
@@ -161,7 +161,7 @@ def embed_vector(x, repeat_dim):


 class SparseNetwork(nn.Module):
-    def __init__(self, input_dim, depth, width, out, residual_skip=True,
+    def __init__(self, input_dim, depth, width, out, residual_skip=True, activation=None,
                  weight_interface=5, weight_hidden_size=2, weight_output_size=1
                  ):
         super(SparseNetwork, self).__init__()
@@ -170,6 +170,7 @@ class SparseNetwork(nn.Module):
         self.depth_dim = depth
         self.hidden_dim = width
         self.out_dim = out
+        self.activation = activation
         self.first_layer = SparseLayer(self.input_dim * self.hidden_dim,
                                        interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
         self.last_layer = SparseLayer(self.hidden_dim * self.out_dim,
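For orientation, a hedged usage sketch of the newly added activation keyword; the constructor sizes below are hypothetical and only illustrate how the new parameter is passed:

    # Hypothetical sizes; only activation=... is new in this commit.
    import torch
    net = SparseNetwork(input_dim=4, depth=5, width=8, out=1,
                        residual_skip=True, activation=torch.relu)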
@@ -182,13 +183,17 @@ class SparseNetwork(nn.Module):
     def __call__(self, x):

         tensor = self.sparse_layer_forward(x, self.first_layer)
+        if self.activation:
+            tensor = self.activation(tensor)
         for nl_idx, network_layer in enumerate(self.hidden_layers):
-            if nl_idx % 2 == 0 and self.residual_skip:
-                residual = tensor
             # Sparse Layer pass
             tensor = self.sparse_layer_forward(tensor, network_layer)

-            if nl_idx % 2 != 0 and self.residual_skip:
+            if self.activation:
+                tensor = self.activation(tensor)
+            if nl_idx % 2 == 0 and self.residual_skip:
+                residual = tensor.clone()
+            if nl_idx % 2 == 1 and self.residual_skip:
                 # noinspection PyUnboundLocalVariable
                 tensor += residual
         tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
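The reworked __call__ takes the residual snapshot after the layer and its activation, and takes it with clone(), so the later in-place tensor += residual cannot alias the saved tensor; snapshotting on even-indexed hidden layers and adding back on odd-indexed ones gives a skip over every pair of layers. A minimal self-contained sketch of that alternating pattern, with nn.Linear standing in for SparseLayer (an assumption, for illustration only):

    import torch
    import torch.nn as nn

    class ResidualSkipSketch(nn.Module):
        # Illustrative stand-in for SparseNetwork's hidden-layer loop;
        # nn.Linear replaces SparseLayer.
        def __init__(self, dim, n_hidden, activation=torch.relu, residual_skip=True):
            super().__init__()
            self.hidden_layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_hidden))
            self.activation = activation
            self.residual_skip = residual_skip

        def forward(self, x):
            tensor = x
            for nl_idx, layer in enumerate(self.hidden_layers):
                tensor = layer(tensor)
                if self.activation:
                    tensor = self.activation(tensor)
                if nl_idx % 2 == 0 and self.residual_skip:
                    residual = tensor.clone()   # snapshot after even-indexed layers
                if nl_idx % 2 == 1 and self.residual_skip:
                    tensor = tensor + residual  # add the skip back after odd-indexed layers
            return tensor

    out = ResidualSkipSketch(dim=8, n_hidden=4)(torch.randn(2, 8))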
@@ -234,14 +239,19 @@ class SparseNetwork(nn.Module):
     def sparselayers(self):
         return (x for x in (self.first_layer, *self.hidden_layers, self.last_layer))

-    def combined_self_train(self):
+    def combined_self_train(self, optimizer, reduction='mean'):
         losses = []
         for layer in self.sparselayers:
+            optimizer.zero_grad()
             x, target_data = layer.get_self_train_inputs_and_targets()
             output = layer(x)

-            losses.append(F.mse_loss(output, target_data) / layer.nr_nets)
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+            loss = F.mse_loss(output, target_data, reduction=reduction)
+            losses.append(loss.detach())
+            loss.backward()
+            optimizer.step()
+
+        return sum(losses)

     def replace_weights_by_particles(self, particles):
         particles = list(particles)
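Note the contract change: combined_self_train now zeroes, back-propagates, and steps the passed optimizer once per layer, and the returned value is a sum of detached per-layer losses, usable for logging but not for a further backward() call; that is why the test below collapses to a single call. A hedged usage sketch, assuming net and epochs from the surrounding test code:

    optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
    for _ in range(epochs):
        loss_sum = net.combined_self_train(optimizer)  # detached scalar, logging only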
@@ -274,12 +284,7 @@ def test_sparse_net_sef_train():
     if True:
         optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
         for _ in trange(epochs):
-            optimizer.zero_grad()
-            loss = net.combined_self_train()
-            print(loss)
-            exit()
-            loss.backward()
-            optimizer.step()
+            _ = net.combined_self_train(optimizer)

     else:
         optimizer_dict = {