README.md Update

This commit is contained in:
Steffen Illium
2022-02-10 16:53:49 +01:00
parent 594bbaa3dd
commit 14768ffc0a
8 changed files with 134 additions and 18 deletions

View File

@@ -198,12 +198,13 @@ if __name__ == '__main__':
plotting = True
particle_analysis = True
as_sparse_network_test = True
self_train_alpha = 100
self_train_alpha = 1
batch_train_beta = 1
data_path = Path('data')
data_path.mkdir(exist_ok=True, parents=True)
run_path = Path('output') / 'mn_st_200_8_alpha_100'
run_path = Path('output') / 'mn_st_400_2_no_res'
model_path = run_path / '0000_trained_model.zip'
df_store_path = run_path / 'train_store.csv'
weight_store_path = run_path / 'weight_store.csv'
@@ -217,7 +218,7 @@ if __name__ == '__main__':
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
interface = np.prod(dataset[0][0].shape)
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=True).to(DEVICE)
metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=False).to(DEVICE)
meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())
loss_fn = nn.CrossEntropyLoss()
@@ -249,7 +250,7 @@ if __name__ == '__main__':
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
y = metanet(batch_x)
# loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
loss = loss_fn(y, batch_y.to(torch.long))
loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
loss.backward()
# Adjust learning weights
@@ -312,7 +313,7 @@ if __name__ == '__main__':
plot_training_result(df_store_path)
if particle_analysis:
plot_training_particle_types(df_store_path)
exit()
if particle_analysis:
model_path = next(run_path.glob(f'*e{EPOCH}.tp'))
latest_model = torch.load(model_path, map_location=DEVICE).eval()

View File

@@ -0,0 +1,103 @@
import sys
from collections import defaultdict
from pathlib import Path
import platform
import pandas as pd
import torch.optim
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
from tqdm import trange
from tqdm.contrib import tenumerate
if platform.node() == 'CarbonX':
debug = True
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("@ Warning, Debugging Config@!!!!!! @")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
else:
debug = False
try:
# noinspection PyUnboundLocalVariable
if __package__ is None:
DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(DIR.parent))
__package__ = DIR.name
else:
DIR = None
except NameError:
DIR = None
pass
import functionalities_test
from network import Net
class MultiplyByXTaskDataset(Dataset):
def __init__(self, x=0.23, length=int(5e5)):
super().__init__()
self.length = length
self.x = x
self.prng = np.random.default_rng()
def __len__(self):
return self.length
def __getitem__(self, _):
ab = self.prng.normal(size=(1,)).astype(np.float32)
return ab, ab * self.x
if __name__ == '__main__':
net = Net(5, 1, 1)
multiplication_target = 0.03
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9)
train_frame = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
dataset = MultiplyByXTaskDataset(x=multiplication_target, length=1000000)
dataloader = DataLoader(dataset=dataset, batch_size=8000)
for epoch in trange(30):
mean_batch_loss = []
mean_self_tain_loss = []
for batch, (batch_x, batch_y) in tenumerate(dataloader):
self_train_loss, _ = net.self_train(10, save_history=False)
batch_x_emb = torch.zeros(batch_x.shape[0], 5)
batch_x_emb[:, -1] = batch_x.squeeze()
y = net(batch_x_emb)
loss = loss_fn(y, batch_y)
loss.backward()
optimizer.step()
mean_batch_loss.append(loss.detach())
mean_self_tain_loss.append(self_train_loss.detach())
train_frame.loc[train_frame.shape[0]] = dict(Epoch=epoch, Batch=batch,
Metric='Self Train Loss', Score=np.average(mean_self_tain_loss))
train_frame.loc[train_frame.shape[0]] = dict(Epoch=epoch, Batch=batch,
Metric='Batch Loss', Score=np.average(mean_batch_loss))
counter = defaultdict(lambda: 0)
functionalities_test.test_for_fixpoints(counter, nets=[net])
print(dict(counter))
sanity = net(torch.Tensor([0,0,0,0,1])).detach()
print(sanity)
print(abs(sanity - multiplication_target))
sns.lineplot(data=train_frame, x='Epoch', y='Score', hue='Metric')
outpath = Path('output') / 'sanity' / 'test.png'
outpath.parent.mkdir(exist_ok=True, parents=True)
plt.savefig(outpath)