Journal TEx Text
This commit is contained in:
91
network.py
91
network.py
@ -245,17 +245,82 @@ class SecondaryNet(Net):
|
||||
return df, is_diverged
|
||||
|
||||
|
||||
class MetaWeight(Net):
|
||||
pass
|
||||
|
||||
|
||||
class MetaCell(nn.Module):
|
||||
def __init__(self, name, interface, residual_skip=True):
|
||||
super().__init__()
|
||||
self.residual_skip = residual_skip
|
||||
self.name = name
|
||||
self.interface = interface
|
||||
self.weight_interface = 4
|
||||
self.net_hidden_size = 4
|
||||
self.net_ouput_size = 1
|
||||
self.meta_weight_list = nn.ModuleList(
|
||||
[MetaWeight(self.weight_interface, self.net_hidden_size,
|
||||
self.net_ouput_size, name=f'{self.name}_{weight_idx}'
|
||||
) for weight_idx in range(self.interface)])
|
||||
|
||||
def forward(self, x):
|
||||
xs = [torch.hstack((x[:, idx].unsqueeze(-1), torch.zeros((x.shape[0], self.weight_interface - 1))))
|
||||
for idx in range(len(self.meta_weight_list))]
|
||||
tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
|
||||
if self.residual_skip:
|
||||
tensor += x
|
||||
|
||||
result = torch.sum(tensor, dim=-1, keepdim=True)
|
||||
return result
|
||||
|
||||
|
||||
class MetaLayer(nn.Module):
|
||||
def __init__(self, name, interface=4, out=1, width=4):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.interface = interface
|
||||
self.width = width
|
||||
|
||||
meta_cell_list = nn.ModuleList([MetaCell(name=f'{self.name}_{cell_idx}',
|
||||
interface=interface
|
||||
) for cell_idx in range(self.width)])
|
||||
self.meta_cell_list = meta_cell_list
|
||||
|
||||
def forward(self, x):
|
||||
result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
|
||||
return result
|
||||
|
||||
|
||||
class MetaNet(nn.Module):
|
||||
|
||||
def __init__(self, interface=4, depth=3, width=4, out=1):
|
||||
super().__init__()
|
||||
self.out = out
|
||||
self.interface = interface
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
|
||||
meta_layer_list = nn.ModuleList([MetaLayer(name=f'Weight_{0}',
|
||||
interface=self.interface,
|
||||
width=self.width)])
|
||||
meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
|
||||
interface=self.width, width=self.width
|
||||
) for layer_idx in range(self.depth - 2)])
|
||||
meta_layer_list.append(MetaLayer(name=f'Weight_{len(meta_layer_list)}',
|
||||
interface=self.width, width=self.out))
|
||||
self._meta_layer_list = meta_layer_list
|
||||
self._net = nn.Sequential(*self._meta_layer_list)
|
||||
|
||||
def forward(self, x):
|
||||
result = self._net.forward(x)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
is_div = True
|
||||
while is_div:
|
||||
net = SecondaryNet(4, 2, 1, "SecondaryNet")
|
||||
data_df, is_div = net.self_train(20000, 25, 1e-4)
|
||||
from matplotlib import pyplot as plt
|
||||
import seaborn as sns
|
||||
# data_df = data_df[::-1] # Reverse
|
||||
fig = sns.lineplot(data=data_df[[x for x in data_df.columns if x != 'step']])
|
||||
# fig.set(yscale='log')
|
||||
print(data_df.iloc[-1])
|
||||
print(data_df.iloc[0])
|
||||
plt.show()
|
||||
print("done")
|
||||
metanet = MetaNet(2, 3, 4, 1)
|
||||
metanet(torch.ones((5, 2)))
|
||||
print('Test')
|
||||
print('Test')
|
||||
print('Test')
|
||||
print('Test')
|
||||
print('Test')
|
||||
|
Reference in New Issue
Block a user