Journal TEx Text

This commit is contained in:
Steffen Illium
2022-01-21 17:28:45 +01:00
parent 21dd572969
commit 5f1f5833d8
3 changed files with 138 additions and 22 deletions

View File

@ -245,17 +245,82 @@ class SecondaryNet(Net):
return df, is_diverged
class MetaWeight(Net):
pass
class MetaCell(nn.Module):
def __init__(self, name, interface, residual_skip=True):
super().__init__()
self.residual_skip = residual_skip
self.name = name
self.interface = interface
self.weight_interface = 4
self.net_hidden_size = 4
self.net_ouput_size = 1
self.meta_weight_list = nn.ModuleList(
[MetaWeight(self.weight_interface, self.net_hidden_size,
self.net_ouput_size, name=f'{self.name}_{weight_idx}'
) for weight_idx in range(self.interface)])
def forward(self, x):
xs = [torch.hstack((x[:, idx].unsqueeze(-1), torch.zeros((x.shape[0], self.weight_interface - 1))))
for idx in range(len(self.meta_weight_list))]
tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
if self.residual_skip:
tensor += x
result = torch.sum(tensor, dim=-1, keepdim=True)
return result
class MetaLayer(nn.Module):
def __init__(self, name, interface=4, out=1, width=4):
super().__init__()
self.name = name
self.interface = interface
self.width = width
meta_cell_list = nn.ModuleList([MetaCell(name=f'{self.name}_{cell_idx}',
interface=interface
) for cell_idx in range(self.width)])
self.meta_cell_list = meta_cell_list
def forward(self, x):
result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
return result
class MetaNet(nn.Module):
def __init__(self, interface=4, depth=3, width=4, out=1):
super().__init__()
self.out = out
self.interface = interface
self.width = width
self.depth = depth
meta_layer_list = nn.ModuleList([MetaLayer(name=f'Weight_{0}',
interface=self.interface,
width=self.width)])
meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
interface=self.width, width=self.width
) for layer_idx in range(self.depth - 2)])
meta_layer_list.append(MetaLayer(name=f'Weight_{len(meta_layer_list)}',
interface=self.width, width=self.out))
self._meta_layer_list = meta_layer_list
self._net = nn.Sequential(*self._meta_layer_list)
def forward(self, x):
result = self._net.forward(x)
return result
if __name__ == '__main__':
is_div = True
while is_div:
net = SecondaryNet(4, 2, 1, "SecondaryNet")
data_df, is_div = net.self_train(20000, 25, 1e-4)
from matplotlib import pyplot as plt
import seaborn as sns
# data_df = data_df[::-1] # Reverse
fig = sns.lineplot(data=data_df[[x for x in data_df.columns if x != 'step']])
# fig.set(yscale='log')
print(data_df.iloc[-1])
print(data_df.iloc[0])
plt.show()
print("done")
metanet = MetaNet(2, 3, 4, 1)
metanet(torch.ones((5, 2)))
print('Test')
print('Test')
print('Test')
print('Test')
print('Test')