MetaNetworks Debugged

This commit is contained in:
Steffen Illium 2022-01-31 10:35:11 +01:00
parent 49c0d8a621
commit 246d825bb4
8 changed files with 169 additions and 109 deletions

View File

@ -1,5 +1,4 @@
""" ----------------------------------------- Methods for summarizing the experiments ------------------------------------------ """
import os
""" -------------------------------- Methods for summarizing the experiments --------------------------------- """
from pathlib import Path
from visualization import line_chart_fixpoints, bar_chart_fixpoints
@ -52,7 +51,9 @@ def summary_fixpoint_percentage(runs, epochs, fixpoints_percentages, ST_steps, S
line_chart_fixpoints(fixpoints_percentages, epochs, ST_steps, SA_steps, directory_name, population_size)
""" --------------------------------------------------- Miscellaneous ---------------------------------------------------------- """
""" -------------------------------------------- Miscellaneous --------------------------------------------------- """
def check_folder(experiment_folder: str):
exp_path = Path('experiments') / experiment_folder
exp_path.mkdir(parents=True, exist_ok=True)

View File

@ -1,5 +1,5 @@
import pickle
import time
from collections import defaultdict
from pathlib import Path
import sys
import platform
@ -7,7 +7,14 @@ import platform
import pandas as pd
import torchmetrics
if platform.node() != 'CarbonX':
from functionalities_test import test_for_fixpoints
if platform.node() == 'CarbonX':
debug = True
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("@ Warning, Debugging Config@!!!!!! @")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
else:
debug = False
try:
# noinspection PyUnboundLocalVariable
@ -20,8 +27,7 @@ if platform.node() != 'CarbonX':
except NameError:
DIR = None
pass
else:
debug = True
import numpy as np
import torch
@ -31,7 +37,7 @@ from torch import nn
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose
from torchvision.transforms import ToTensor, Compose, Resize
from tqdm import tqdm
from network import MetaNet
@ -69,7 +75,7 @@ class AddTaskDataset(Dataset):
def set_checkpoint(model, out_path, epoch_n, final_model=False):
epoch_n = str(epoch_n)
if final_model:
if not final_model:
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
else:
ckpt_path = Path(out_path) / f'trained_model_ckpt.tp'
@ -84,32 +90,32 @@ def validate(checkpoint_path, ratio=0.1):
import torchmetrics
# initialize metric
metric = torchmetrics.Accuracy()
validmetric = torchmetrics.Accuracy()
try:
dataset = MNIST(str(data_path), transform=utility_transforms, train=False)
datas = MNIST(str(data_path), transform=utility_transforms, train=False)
except RuntimeError:
dataset = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
datas = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
model = torch.load(checkpoint_path, map_location=DEVICE).eval()
n_samples = int(len(d) * ratio)
with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
for idx, (batch_x, batch_y) in enumerate(d):
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
y = model(batch_x)
for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
y_valid = model(valid_batch_x)
# metric on current batch
acc = metric(y.cpu(), batch_y.cpu())
acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
pbar.set_postfix_str(f'Acc: {acc}')
pbar.update()
if idx == n_samples:
break
# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")
acc = validmetric.compute()
tqdm.write(f"Avg. accuracy on all data: {acc}")
return acc
@ -125,41 +131,39 @@ def plot_training_result(path_to_dataframe):
df = pd.read_csv(path_to_dataframe, index_col=0)
fig, ax1 = plt.subplots() # initializes figure and plots
ax2 = ax1.twinx() # applies twinx to ax2, which is the second y axis.
ax2 = ax1.twinx() # applies twinx to ax2, which is the second y-axis.
# plots the first set of data, and sets it to ax1.
data = df[df['Metric'] == 'BatchLoss']
# plots the second set, and sets to ax2.
sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax2)
data = df[df['Metric'] == 'Test Accuracy']
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='red')
data = df[df['Metric'] == 'Train Accuracy']
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='green')
sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax1, color='blue')
data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', legend=True)
ax2.set(yscale='log')
ax1.set(yscale='log')
ax1.set_title('Training Lineplot')
plt.tight_layout()
if debug:
plt.show()
else:
plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'))
plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)
if __name__ == '__main__':
self_train = True
soup_interaction = True
training = True
plotting = True
self_train = False
training = False
plotting = False
particle_analysis = True
data_path = Path('data')
data_path.mkdir(exist_ok=True, parents=True)
run_path = Path('output') / 'intergrated_self_train'
run_path = Path('output') / 'mnist_test_half_size'
model_path = run_path / '0000_trained_model.zip'
if training:
utility_transforms = Compose([ToTensor(), ToFloat(), Flatten(start_dim=0)])
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
try:
dataset = MNIST(str(data_path), transform=utility_transforms)
@ -179,6 +183,8 @@ if __name__ == '__main__':
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
if is_validation_epoch:
metric = torchmetrics.Accuracy()
else:
metric = None
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
if self_train and is_self_train_epoch:
# Zero your gradients for every batch!
@ -221,10 +227,19 @@ if __name__ == '__main__':
accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
Metric='Test Accuracy', Score=accuracy.item())
train_store.loc[train_store.shape[0]] = validation_log
torch.save(metanet, model_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
train_store.loc[train_store.shape[0]] = validation_log
train_store.to_csv(run_path / 'train_store.csv')
if plotting:
plot_training_result(run_path / 'train_store.csv')
if particle_analysis:
model_path = next(run_path.glob('*.tp'))
latest_model = torch.load(model_path, map_location=DEVICE).eval()
analysis_dict = defaultdict(dict)
counter_dict = defaultdict(lambda: 0)
for particle in latest_model.particles:
analysis_dict[particle.name]['is_diverged'] = particle.are_weights_diverged()
test_for_fixpoints(counter_dict, latest_model.particles)

View File

@ -1,7 +1,6 @@
import copy
from typing import Dict, List
import numpy as np
from torch import Tensor
from network import Net
@ -9,7 +8,7 @@ def is_divergent(network: Net) -> bool:
for i in network.input_weight_matrix():
weight_value = i[0].item()
if np.isnan(weight_value).all() or np.isinf(weight_value).all():
if np.isnan(weight_value).any() or np.isinf(weight_value).any():
return True
return False
@ -25,7 +24,7 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
target_data = network.create_target_weights(network.input_weight_matrix().detach().numpy())
target_data = network.create_target_weights(network.input_weight_matrix().detach())
result = np.allclose(target_data, np.zeros_like(target_data), rtol=0, atol=epsilon)
# result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
return result
@ -95,4 +94,4 @@ def test_status(net: Net) -> Net:
else:
net.is_fixpoint = "other_func"
return net
return net

View File

@ -6,7 +6,6 @@ import pickle
import pandas as pd
import numpy as np
import torch
from sklearn import preprocessing
from functionalities_test import is_identity_function, test_status
from journal_basins import SpawnExperiment, mean_invariate_manhattan_distance
@ -22,8 +21,11 @@ class SpawnLinspaceExperiment(SpawnExperiment):
number_clones = number_clones or self.nr_clones
df = pd.DataFrame(
columns=['clone', 'parent', 'parent2', 'MAE_pre', 'MAE_post', 'MSE_pre', 'MSE_post', 'MIM_pre', 'MIM_post', 'noise',
'status_pst'])
columns=['clone', 'parent', 'parent2',
'MAE_pre', 'MAE_post',
'MSE_pre', 'MSE_post',
'MIM_pre', 'MIM_post',
'noise', 'status_pst'])
# For every initial net {i} after populating (that is fixpoint after first epoch);
# parent = self.parents[0]
@ -38,15 +40,15 @@ class SpawnLinspaceExperiment(SpawnExperiment):
# to see full trajectory (but the clones will be very hard to see).
# Make one target to compare distances to clones later when they have trained.
net1.start_time = self.ST_steps - 150
net1_input_data = net1.input_weight_matrix()
net1_target_data = net1.create_target_weights(net1_input_data)
net1_input_data = net1.input_weight_matrix().detach()
net1_target_data = net1.create_target_weights(net1_input_data).detach()
net2.start_time = self.ST_steps - 150
net2_input_data = net2.input_weight_matrix()
net2_target_data = net2.create_target_weights(net2_input_data)
net2_input_data = net2.input_weight_matrix().detach()
net2_target_data = net2.create_target_weights(net2_input_data).detach()
if is_identity_function(net1) and is_identity_function(net2):
# if True:
# if True:
# Clone the fixpoint x times and add (+-)self.noise to weight-sets randomly;
# To plot clones starting after first epoch (z=ST_steps), set that as start_time!
# To make sure PCA will plot the same trajectory up until this point, we clone the
@ -64,7 +66,7 @@ class SpawnLinspaceExperiment(SpawnExperiment):
clone.number_trained = copy.deepcopy(net1.number_trained)
# Pre Training distances (after noise application of course)
clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix())
clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix()).detach()
MAE_pre = MAE(net1_target_data, clone_pre_weights)
MSE_pre = MSE(net1_target_data, clone_pre_weights)
MIM_pre = mean_invariate_manhattan_distance(net1_target_data, clone_pre_weights)
@ -78,10 +80,15 @@ class SpawnLinspaceExperiment(SpawnExperiment):
raise ValueError
except ValueError:
print("Ran into nan in 'in beetween weights' array.")
df.loc[len(df)] = [j, net1.name, net2.name,
MAE_pre, 0,
MSE_pre, 0,
MIM_pre, 0,
self.noise, clone.is_fixpoint]
continue
# Post Training distances for comparison
clone_post_weights = clone.create_target_weights(clone.input_weight_matrix())
clone_post_weights = clone.create_target_weights(clone.input_weight_matrix()).detach()
MAE_post = MAE(net1_target_data, clone_post_weights)
MSE_post = MSE(net1_target_data, clone_post_weights)
MIM_post = mean_invariate_manhattan_distance(net1_target_data, clone_post_weights)
@ -95,16 +102,23 @@ class SpawnLinspaceExperiment(SpawnExperiment):
f"\nMIM({net1.name},{j}): {MIM_post}\n")
self.nets.append(clone)
df.loc[len(df)] = [j, net1.name, net2.name, MAE_pre, MAE_post, MSE_pre, MSE_post, MIM_pre, MIM_post,
self.noise, clone.is_fixpoint]
df.loc[len(df)] = [j, net1.name, net2.name,
MAE_pre, MAE_post,
MSE_pre, MSE_post,
MIM_pre, MIM_post,
self.noise, clone.is_fixpoint]
for net1, net2 in pairwise_net_list:
value = 'MAE'
c_selector = [f'{value}_pre', f'{value}_post']
values = df.loc[(df['parent'] == net1.name) & (df['parent2'] == net2.name)][c_selector]
this_min, this_max = values.values.min(), values.values.max()
df.loc[(df['parent'] == net1.name) &
(df['parent2'] == net2.name), c_selector] = (values - this_min) / (this_max - this_min)
try:
value = 'MAE'
c_selector = [f'{value}_pre', f'{value}_post']
values = df.loc[(df['parent'] == net1.name) & (df['parent2'] == net2.name)][c_selector]
this_min, this_max = values.values.min(), values.values.max()
df.loc[(df['parent'] == net1.name) &
(df['parent2'] == net2.name), c_selector] = (values - this_min) / (this_max - this_min)
except ValueError:
pass
for parent in self.parents:
for _ in range(self.epochs - 1):
for _ in range(self.ST_steps):
@ -148,7 +162,8 @@ if __name__ == '__main__':
df = exp.df
directory = Path('output') / 'spawn_basin' / f'{ST_name_hash}' / 'linage'
pickle.dump(exp, open(f"{directory}/experiment_pickle_{ST_name_hash}.p", "wb"))
with (directory / f"experiment_pickle_{ST_name_hash}.p").open('wb') as f:
pickle.dump(exp, f)
print(f"\nSaved experiment to {directory}.")
# Boxplot with counts of nr_fixpoints, nr_other, nr_etc. on y-axis
@ -183,6 +198,6 @@ if __name__ == '__main__':
# else:
# label.set_visible(False)
filepath = exp.directory / 'mim_dist_plot.png'
filepath = exp.directory / 'mim_dist_plot.pdf'
plt.tight_layout()
plt.savefig(filepath)
plt.savefig(filepath, dpi=600, format='pdf', bbox_inches='tight')

View File

@ -26,11 +26,14 @@ def l1(tup):
def mean_invariate_manhattan_distance(x, y):
# One of these one-liners that might be smart or really dumb. Goal is to find pairwise
# distances of ascending values, ie. sum (abs(min1_X-min1_Y), abs(min2_X-min2Y) ...) / mean.
# One of these one-liners that might be smart or really dumb. Goal is to find pairwise
# distances of ascending values, ie. sum (abs(min1_X-min1_Y), abs(min2_X-min2Y) ...) / mean.
# Idea was to find weight sets that have same values but just in different positions, that would
# make this distance 0.
return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))
try:
return np.mean(list(map(l1, zip(sorted(x.detach().numpy()), sorted(y.detach().numpy())))))
except AttributeError:
return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))
def distance_matrix(nets, distance="MIM", print_it=True):

View File

@ -8,6 +8,7 @@ from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error as MAE
from sklearn.metrics import mean_squared_error as MSE
@ -16,6 +17,7 @@ from tqdm import tqdm
from functionalities_test import is_identity_function, test_status, is_zero_fixpoint, is_divergent, \
is_secondary_fixpoint
from journal_basins import mean_invariate_manhattan_distance
from network import Net
from visualization import plot_loss, plot_3d_soup
@ -25,14 +27,6 @@ def l1(tup):
return abs(a - b)
def mean_invariate_manhattan_distance(x, y):
# One of these one-liners that might be smart or really dumb. Goal is to find pairwise
# distances of ascending values, ie. sum (abs(min1_X-min1_Y), abs(min2_X-min2Y) ...) / mean.
# Idea was to find weight sets that have same values but just in different positions, that would
# make this distance 0.
return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))
def distance_matrix(nets, distance="MIM", print_it=True):
matrix = [[0 for _ in range(len(nets))] for _ in range(len(nets))]
for net in range(len(nets)):

View File

@ -34,17 +34,21 @@ class Net(nn.Module):
def are_weights_diverged(network_weights):
""" Testing if the weights are eiter converging to infinity or -infinity. """
for layer_id, layer in enumerate(network_weights):
for cell_id, cell in enumerate(layer):
for weight_id, weight in enumerate(cell):
if torch.isnan(weight):
return True
if torch.isinf(weight):
return True
return False
# Slow and shitty:
# for layer_id, layer in enumerate(network_weights):
# for cell_id, cell in enumerate(layer):
# for weight_id, weight in enumerate(cell):
# if torch.isnan(weight):
# return True
# if torch.isinf(weight):
# return True
# return False
# Fast and modern:
return any(x.isnan.any() or x.isinf().any() for x in network_weights.parameters)
def apply_weights(self, new_weights: Tensor):
""" Changing the weights of a network to new given values. """
# TODO: Change this to 'parameters' version
i = 0
for layer_id, layer_name in enumerate(self.state_dict()):
for line_id, line_values in enumerate(self.state_dict()[layer_name]):
@ -101,15 +105,17 @@ class Net(nn.Module):
# Cell Enumeration
torch.arange(layer.out_features, device=d).repeat_interleave(layer.in_features).view(-1, 1),
# Weight Enumeration within the Cells
torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1)
torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1),
*(torch.full((x.numel(), 1), 0, device=d) for _ in range(self.input_size-4))
), dim=1)
)
# Finalize
weight_matrix = torch.cat(weight_matrix).float()
# Normalize all along the 1 dimensions
norm2 = weight_matrix[:, 1:].pow(2).sum(keepdim=True, dim=0).sqrt()
weight_matrix[:, 1:] = weight_matrix[:, 1:] / norm2
# Normalize 1,2,3 column of dim 1
last_pos_idx = self.input_size - 4
norm2 = weight_matrix[:, 1:-last_pos_idx].pow(2).sum(keepdim=True, dim=0).sqrt()
weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / norm2) + 1e-8
# computations
# create a mask where pos is 0 if it is to be replaced
@ -117,7 +123,7 @@ class Net(nn.Module):
mask[:, 0] = 0
self._weight_pos_enc_and_mask = weight_matrix, mask
return self._weight_pos_enc_and_mask
return tuple(x.clone() for x in self._weight_pos_enc_and_mask)
def forward(self, x):
for layer in self.layers:
@ -125,6 +131,7 @@ class Net(nn.Module):
return x
def normalize(self, value, norm):
raise NotImplementedError
# FIXME, This is bullshit, the code does not do what the docstring explains
# Obsolete now
""" Normalizing the values >= 1 and adding pow(10, -8) to the values equal to 0 """
@ -138,7 +145,7 @@ class Net(nn.Module):
""" Calculating the input tensor formed from the weights of the net """
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
pos_enc, mask = self._weight_pos_enc
weight_matrix = pos_enc * mask + weight_matrix.expand(-1, 4) * (1 - mask)
weight_matrix = pos_enc * mask + weight_matrix.expand(-1, pos_enc.shape[-1]) * (1 - mask)
return weight_matrix
def self_train(self,
@ -283,33 +290,50 @@ class SecondaryNet(Net):
class MetaCell(nn.Module):
def __init__(self, name, interface, residual_skip=True):
def __init__(self, name, interface):
super().__init__()
self.residual_skip = residual_skip
self.name = name
self.interface = interface
self.weight_interface = 4
self.weight_interface = 5
self.net_hidden_size = 4
self.net_ouput_size = 1
self.meta_weight_list = nn.ModuleList()
self.meta_weight_list.extend(
[Net(self.weight_interface, self.net_hidden_size,
self.net_ouput_size, name=f'{self.name}_{weight_idx}'
self.net_ouput_size, name=f'{self.name}_W{weight_idx}'
) for weight_idx in range(self.interface)]
)
self.__bed_mask = None
@property
def _bed_mask(self):
if self.__bed_mask is None:
d = next(self.parameters()).device
embedding = torch.zeros(1, self.weight_interface, device=d)
# computations
# create a mask where pos is 0 if it is to be replaced
mask = torch.ones_like(embedding)
mask[:, -1] = 0
self.__bed_mask = embedding, mask
return tuple(x.clone() for x in self.__bed_mask)
def forward(self, x):
xs = [torch.hstack(
(torch.zeros((x.shape[0], self.weight_interface - 1), device=x.device), x[:, idx].unsqueeze(-1))
)
for idx in range(len(self.meta_weight_list))]
tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
embedding, mask = self._bed_mask
expanded_mask = mask.expand(*x.shape, embedding.shape[-1])
embedding = embedding.repeat(*x.shape, 1)
if self.residual_skip:
tensor += x
# Row-wise
# xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1]).swapdims(0, 1)
# Column-wise
xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1])
xs = embedding * expanded_mask + xs * (1 - expanded_mask)
# ToDo Speed this up!
tensor = torch.hstack([meta_weight(xs[:, idx, :]) for idx, meta_weight in enumerate(self.meta_weight_list)])
result = torch.sum(tensor, dim=-1, keepdim=True)
return result
tensor = torch.sum(tensor, dim=-1, keepdim=True)
return tensor
@property
def particles(self):
@ -317,21 +341,27 @@ class MetaCell(nn.Module):
class MetaLayer(nn.Module):
def __init__(self, name, interface=4, width=4):
def __init__(self, name, interface=4, width=4, residual_skip=True):
super().__init__()
self.residual_skip = residual_skip
self.name = name
self.interface = interface
self.width = width
self.meta_cell_list = nn.ModuleList()
self.meta_cell_list.extend([MetaCell(name=f'{self.name}_{cell_idx}',
self.meta_cell_list.extend([MetaCell(name=f'{self.name}_C{cell_idx}',
interface=interface
) for cell_idx in range(self.width)]
)
def forward(self, x):
result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
return result
cell_results = []
for metacell in self.meta_cell_list:
cell_results.append(metacell(x))
tensor = torch.hstack(cell_results)
if self.residual_skip and x.shape == tensor.shape:
tensor += x
return tensor
@property
def particles(self):
@ -349,15 +379,15 @@ class MetaNet(nn.Module):
self.depth = depth
self._meta_layer_list = nn.ModuleList()
self._meta_layer_list.append(MetaLayer(name=f'Weight_{0}',
self._meta_layer_list.append(MetaLayer(name=f'L{0}',
interface=self.interface,
width=self.width)
)
self._meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}',
interface=self.width, width=self.width
) for layer_idx in range(self.depth - 2)]
)
self._meta_layer_list.append(MetaLayer(name=f'Weight_{len(self._meta_layer_list)}',
self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}',
interface=self.width, width=self.out)
)
@ -383,9 +413,9 @@ class MetaNet(nn.Module):
if __name__ == '__main__':
metanet = MetaNet(interface=2, depth=3, width=2, out=1)
metanet = MetaNet(interface=3, depth=5, width=3, out=1)
next(metanet.particles).input_weight_matrix()
metanet(torch.ones((5, 2)))
metanet(torch.hstack([torch.full((2, 1), x) for x in range(metanet.interface)]))
a = metanet.particles
print('Test')
print('Test')

View File

@ -2,10 +2,13 @@ torch~=1.8.1+cpu
tqdm~=4.60.0
numpy~=1.20.3
matplotlib~=3.4.2
sklearn
sklearn~=0.0
scipy
tabulate~=0.8.9
scikit-learn~=0.24.2
pandas~=1.2.4
seaborn~=0.11.1
seaborn~=0.11.1
future~=0.18.2
torchmetrics~=0.7.0
torchvision~=0.9.1+cpu