diff --git a/network.py b/network.py index 280f3b3..13cbc46 100644 --- a/network.py +++ b/network.py @@ -342,10 +342,10 @@ class MetaCell(nn.Module): class MetaLayer(nn.Module): - def __init__(self, name, interface=4, width=4, residual_skip=True, + def __init__(self, name, interface=4, width=4, # residual_skip=False, weight_interface=5, weight_hidden_size=2, weight_output_size=1): super().__init__() - self.residual_skip = residual_skip + self.residual_skip = False self.name = name self.interface = interface self.width = width @@ -389,20 +389,20 @@ class MetaNet(nn.Module): self._meta_layer_first = MetaLayer(name=f'L{0}', interface=self.interface, - width=self.width, residual_skip=residual_skip, + width=self.width, weight_interface=weight_interface, weight_hidden_size=weight_hidden_size, weight_output_size=weight_output_size) self._meta_layer_list = nn.ModuleList([MetaLayer(name=f'L{layer_idx + 1}', - interface=self.width, width=self.width, residual_skip=residual_skip, + interface=self.width, width=self.width, weight_interface=weight_interface, weight_hidden_size=weight_hidden_size, weight_output_size=weight_output_size, ) for layer_idx in range(self.depth - 2)] ) self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list)}', - interface=self.width, width=self.out, residual_skip=residual_skip, + interface=self.width, width=self.out, weight_interface=weight_interface, weight_hidden_size=weight_hidden_size, weight_output_size=weight_output_size,