README.md Update
@@ -198,12 +198,13 @@ if __name__ == '__main__':
     plotting = True
     particle_analysis = True
     as_sparse_network_test = True
-    self_train_alpha = 100
+    self_train_alpha = 1
     batch_train_beta = 1

     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)

-    run_path = Path('output') / 'mn_st_200_8_alpha_100'
+    run_path = Path('output') / 'mn_st_400_2_no_res'
     model_path = run_path / '0000_trained_model.zip'
     df_store_path = run_path / 'train_store.csv'
     weight_store_path = run_path / 'weight_store.csv'
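Note on this hunk: `self_train_alpha` drops from 100 to 1 and the run directory is renamed to `mn_st_400_2_no_res`. A later hunk scales the classification loss by `batch_train_beta`; the sketch below only illustrates the presumable role of the two coefficients as relative loss weights. `combined_loss` is a hypothetical helper, not code from this repository.

```python
import torch

def combined_loss(task_loss: torch.Tensor, self_train_loss: torch.Tensor,
                  self_train_alpha: float = 1, batch_train_beta: float = 1) -> torch.Tensor:
    # Assumption: alpha scales the self-replication term, beta the task term.
    # With alpha = 100 the self-training objective dominated; alpha = 1 puts
    # both terms on the same scale.
    return self_train_alpha * self_train_loss + batch_train_beta * task_loss
```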
@@ -217,7 +218,7 @@ if __name__ == '__main__':
     d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)

     interface = np.prod(dataset[0][0].shape)
-    metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=True).to(DEVICE)
+    metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=False).to(DEVICE)
     meta_weight_count = sum(p.numel() for p in next(metanet.particles).parameters())

     loss_fn = nn.CrossEntropyLoss()
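The only change in this hunk is `residual_skip=True` to `False`, matching the `no_res` run name above. MetaNet's internals are not part of this diff; the toy module below is only a sketch of what such a flag usually toggles (a skip connection around each layer of the stack), not the actual implementation.

```python
import torch
from torch import nn

class TinyStack(nn.Module):
    """Toy stand-in for a deep stack with an optional residual skip per layer."""

    def __init__(self, width: int, depth: int, residual_skip: bool = True):
        super().__init__()
        self.residual_skip = residual_skip
        self.layers = nn.ModuleList(nn.Linear(width, width) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            out = torch.relu(layer(x))
            # residual_skip=False: the layer output replaces x instead of being added to it.
            x = x + out if self.residual_skip else out
        return x
```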
@@ -249,7 +250,7 @@ if __name__ == '__main__':
             batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
             y = metanet(batch_x)
             # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
-            loss = loss_fn(y, batch_y.to(torch.long))
+            loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
             loss.backward()

             # Adjust learning weights
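The commented-out line treated the targets as float column vectors; `nn.CrossEntropyLoss` instead expects raw logits of shape `(N, C)` and class-index targets of dtype `long` and shape `(N,)`, which is why the active line casts with `.to(torch.long)`. The new version additionally multiplies the loss by `batch_train_beta`. A standalone shape/dtype check with synthetic tensors:

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()
batch_train_beta = 1                                # value set in the first hunk
logits = torch.randn(4, 10, requires_grad=True)     # (batch, classes) raw model outputs
targets = torch.tensor([3, 1, 7, 0])                # class indices, not one-hot floats
loss = loss_fn(logits, targets.to(torch.long)) * batch_train_beta
loss.backward()
print(loss.item())
```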
@@ -312,7 +313,7 @@ if __name__ == '__main__':
         plot_training_result(df_store_path)
         if particle_analysis:
             plot_training_particle_types(df_store_path)
         exit()

     if particle_analysis:
         model_path = next(run_path.glob(f'*e{EPOCH}.tp'))
         latest_model = torch.load(model_path, map_location=DEVICE).eval()
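The analysis branch picks up the checkpoint whose file name ends in the final epoch number and loads the whole pickled model for evaluation. A minimal sketch of that lookup with a fallback, assuming checkpoint names end in `e<EPOCH>.tp` as the glob pattern suggests (`EPOCH`, `run_path` and `DEVICE` below are placeholder values):

```python
import torch
from pathlib import Path

EPOCH, DEVICE = 100, 'cpu'                              # hypothetical values
run_path = Path('output') / 'mn_st_400_2_no_res'

candidates = sorted(run_path.glob(f'*e{EPOCH}.tp'))
if not candidates:
    raise FileNotFoundError(f'no *e{EPOCH}.tp checkpoint under {run_path}')
# torch.load here returns the pickled nn.Module itself; .eval() disables
# dropout/batch-norm updates for the analysis pass.
latest_model = torch.load(candidates[0], map_location=DEVICE).eval()
```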
experiments/meta_task_sanity_exp.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import sys
from collections import defaultdict
from pathlib import Path
import platform

import pandas as pd
import torch.optim
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
from tqdm import trange
from tqdm.contrib import tenumerate


# Flag the debugging configuration when running on the 'CarbonX' host.
if platform.node() == 'CarbonX':
    debug = True
    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
    print("@ Warning, Debugging Config@!!!!!! @")
    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
else:
    debug = False
# Make the repository root importable when this file is executed as a script.
try:
    # noinspection PyUnboundLocalVariable
    if __package__ is None:
        DIR = Path(__file__).resolve().parent
        sys.path.insert(0, str(DIR.parent))
        __package__ = DIR.name
    else:
        DIR = None
except NameError:
    DIR = None
    pass

import functionalities_test
from network import Net


# Synthetic regression task: each sample is a standard-normal scalar and the
# target is that scalar multiplied by a fixed factor x.
class MultiplyByXTaskDataset(Dataset):
    def __init__(self, x=0.23, length=int(5e5)):
        super().__init__()
        self.length = length
        self.x = x
        self.prng = np.random.default_rng()

    def __len__(self):
        return self.length

    def __getitem__(self, _):
        ab = self.prng.normal(size=(1,)).astype(np.float32)
        return ab, ab * self.x


if __name__ == '__main__':
    net = Net(5, 1, 1)
    multiplication_target = 0.03

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9)

    train_frame = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])

    dataset = MultiplyByXTaskDataset(x=multiplication_target, length=1000000)
    dataloader = DataLoader(dataset=dataset, batch_size=8000)
    for epoch in trange(30):
        mean_batch_loss = []
        mean_self_train_loss = []
        for batch, (batch_x, batch_y) in tenumerate(dataloader):
            self_train_loss, _ = net.self_train(10, save_history=False)
            # Embed the scalar task input into the last of Net's five input dimensions.
            batch_x_emb = torch.zeros(batch_x.shape[0], 5)
            batch_x_emb[:, -1] = batch_x.squeeze()
            y = net(batch_x_emb)
            loss = loss_fn(y, batch_y)

            optimizer.zero_grad()  # reset accumulated gradients before the task-loss step
            loss.backward()
            optimizer.step()

            mean_batch_loss.append(loss.detach())
            mean_self_train_loss.append(self_train_loss.detach())

        train_frame.loc[train_frame.shape[0]] = dict(Epoch=epoch, Batch=batch,
                                                     Metric='Self Train Loss',
                                                     Score=np.average(mean_self_train_loss))
        train_frame.loc[train_frame.shape[0]] = dict(Epoch=epoch, Batch=batch,
                                                     Metric='Batch Loss',
                                                     Score=np.average(mean_batch_loss))

    counter = defaultdict(lambda: 0)
    functionalities_test.test_for_fixpoints(counter, nets=[net])
    print(dict(counter))
    # Sanity probe: with only the last input slot active, the output should be
    # close to 1 * multiplication_target after training.
    sanity = net(torch.Tensor([0, 0, 0, 0, 1])).detach()
    print(sanity)
    print(abs(sanity - multiplication_target))
    sns.lineplot(data=train_frame, x='Epoch', y='Score', hue='Metric')
    outpath = Path('output') / 'sanity' / 'test.png'
    outpath.parent.mkdir(exist_ok=True, parents=True)
    plt.savefig(outpath)
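For reference, a quick check of what `MultiplyByXTaskDataset` from the new file yields, assuming the class definition above is available in the current scope:

```python
import torch
from torch.utils.data import DataLoader

ds = MultiplyByXTaskDataset(x=0.23, length=16)
xb, yb = next(iter(DataLoader(ds, batch_size=4)))
print(xb.shape, yb.shape)             # torch.Size([4, 1]) torch.Size([4, 1])
print(torch.allclose(yb, xb * 0.23))  # targets are just the inputs scaled by x
```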