Steffen Illium
2022-02-15 10:57:40 +01:00
parent 8546cc7ddf
commit 62e640e1f0
4 changed files with 64 additions and 21 deletions

@@ -291,7 +291,7 @@ class SecondaryNet(Net):

 class MetaCell(nn.Module):
-    def __init__(self, name, interface):
+    def __init__(self, name, interface, weight_interface=5, weight_hidden_size=2, weight_output_size=1):
         super().__init__()
         self.name = name
         self.interface = interface
@@ -342,7 +342,8 @@ class MetaCell(nn.Module):

 class MetaLayer(nn.Module):
-    def __init__(self, name, interface=4, width=4, residual_skip=True):
+    def __init__(self, name, interface=4, width=4, residual_skip=True,
+                 weight_interface=5, weight_hidden_size=2, weight_output_size=1):
         super().__init__()
         self.residual_skip = residual_skip
         self.name = name
@@ -351,7 +352,9 @@ class MetaLayer(nn.Module):
         self.meta_cell_list = nn.ModuleList()
         self.meta_cell_list.extend([MetaCell(name=f'{self.name}_C{cell_idx}',
-                                             interface=interface
+                                             interface=interface,
+                                             weight_interface=weight_interface, weight_hidden_size=weight_hidden_size,
+                                             weight_output_size=weight_output_size,
                                              ) for cell_idx in range(self.width)]
                                    )
@@ -371,26 +374,42 @@ class MetaLayer(nn.Module):

 class MetaNet(nn.Module):
-    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True):
+    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True, dropout=0,
+                 weight_interface=5, weight_hidden_size=2, weight_output_size=1):
         super().__init__()
+        self.dropout = dropout
         self.activation = activation
         self.out = out
         self.interface = interface
         self.width = width
         self.depth = depth
+        self.weight_interface = weight_interface
+        self.weight_hidden_size = weight_hidden_size
+        self.weight_output_size = weight_output_size
         self._meta_layer_list = nn.ModuleList()
         self._meta_layer_list.append(MetaLayer(name=f'L{0}',
                                                interface=self.interface,
-                                               width=self.width, residual_skip=residual_skip)
+                                               width=self.width, residual_skip=residual_skip,
+                                               weight_interface=weight_interface,
+                                               weight_hidden_size=weight_hidden_size,
+                                               weight_output_size=weight_output_size)
                                      )
         self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}',
-                                                interface=self.width, width=self.width, residual_skip=residual_skip
+                                                interface=self.width, width=self.width, residual_skip=residual_skip,
+                                                weight_interface=weight_interface,
+                                                weight_hidden_size=weight_hidden_size,
+                                                weight_output_size=weight_output_size,
                                                 ) for layer_idx in range(self.depth - 2)]
                                      )
         self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}',
-                                               interface=self.width, width=self.out, residual_skip=residual_skip)
+                                               interface=self.width, width=self.out, residual_skip=residual_skip,
+                                               weight_interface=weight_interface,
+                                               weight_hidden_size=weight_hidden_size,
+                                               weight_output_size=weight_output_size,
+                                               )
                                      )
+        self.dropout_layer = nn.Dropout(p=self.dropout)

     def replace_with_zero(self, ident_key):
         replaced_particles = 0
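Taken together, this hunk threads the weight-network dimensions from MetaNet down through MetaLayer into each MetaCell, and wires in the new dropout option. A minimal usage sketch (shapes and values are illustrative, assuming the classes above are importable from this module):

    import torch

    # The new weight_interface / weight_hidden_size / weight_output_size
    # kwargs size each cell's internal weight network; dropout > 0 enables
    # the new nn.Dropout layer in the forward pass.
    net = MetaNet(interface=4, depth=3, width=4, out=1, residual_skip=True,
                  dropout=0.1, weight_interface=5, weight_hidden_size=2,
                  weight_output_size=1)

    batch = torch.randn(8, 4)   # 8 samples with interface=4 features each
    prediction = net(batch)     # passes through all depth=3 meta layers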
@@ -406,6 +425,8 @@ class MetaNet(nn.Module):
     def forward(self, x):
         tensor = x
         for meta_layer in self._meta_layer_list:
+            if self.dropout:
+                tensor = self.dropout_layer(tensor)
             tensor = meta_layer(tensor)
         return tensor
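Because the guard checks the dropout rate rather than a separate flag, the default dropout=0 skips the dropout layer entirely; with a non-zero rate, standard nn.Dropout semantics apply, i.e. it only perturbs activations in training mode. A short sketch:

    net = MetaNet(dropout=0.1)
    net.train()                  # dropout zeroes activations with p=0.1 before each layer
    net.eval()                   # dropout_layer becomes a pass-through, guard still taken
    out = net(torch.randn(8, 4))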
@@ -423,6 +444,10 @@ class MetaNet(nn.Module):
             losses.append(F.mse_loss(output, target_data))
         return torch.hstack(losses).sum(dim=-1, keepdim=True)

+    @property
+    def hyperparams(self):
+        return {key: val for key, val in self.__dict__.items() if not key.startswith('_')}


 class MetaNetCompareBaseline(nn.Module):
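The new hyperparams property collects every public instance attribute, so the values set in the extended __init__ above come back as a plain dict; submodules such as dropout_layer are excluded because nn.Module stores them in the underscore-prefixed _modules. An illustrative sketch (exact contents depend on what nn.Module itself keeps public, e.g. training):

    net = MetaNet(interface=4, depth=3, width=4, out=1)
    print(net.hyperparams)
    # e.g. {'training': True, 'dropout': 0, 'activation': None, 'out': 1,
    #       'interface': 4, 'width': 4, 'depth': 3, 'weight_interface': 5,
    #       'weight_hidden_size': 2, 'weight_output_size': 1}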
@@ -437,7 +462,7 @@ class MetaNetCompareBaseline(nn.Module):
         self._meta_layer_list = nn.ModuleList()
         self._meta_layer_list.append(nn.Linear(self.interface, self.width, bias=False))
-        self._meta_layer_list.extend([ nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
+        self._meta_layer_list.extend([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
         self._meta_layer_list.append(nn.Linear(self.width, self.out, bias=False))

     def forward(self, x):
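For context: MetaNetCompareBaseline mirrors the meta network's layer layout with plain bias-free nn.Linear blocks, which makes a like-for-like comparison straightforward. A hedged sketch (its full constructor and forward are not visible in this diff; the interface/depth/width/out arguments are assumed to match MetaNet's):

    # Assumption: the baseline takes the same structural arguments as MetaNet.
    baseline = MetaNetCompareBaseline(interface=4, depth=3, width=4, out=1)
    n_params = sum(p.numel() for p in baseline.parameters())
    y = baseline(torch.randn(8, 4))   # plain Linear stack, no weight networks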