Fixed the Model classes, Visualization

This commit is contained in:
Si11ium
2019-08-23 13:10:47 +02:00
parent 0e879bfdb1
commit 7b0b96eaa3
16 changed files with 141 additions and 469 deletions

View File

@ -1,11 +1,34 @@
import os
import torch
import pytorch_lightning as pl
from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU, AvgPool2d
from pytorch_lightning import data_loader
from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU
from abc import ABC, abstractmethod
#######################
# Abstract NN Class
# Abstract NN Class & Lightning Module
from torch.utils.data import DataLoader
from dataset import DataContainer
class LightningModuleOverrides:
@property
def name(self):
return self.__class__.__name__
def forward(self, x):
return self.network.forward(x)
@data_loader
def tng_dataloader(self):
num_workers = os.cpu_count() // 2
return DataLoader(DataContainer('data', self.size, self.step),
shuffle=True, batch_size=100, num_workers=num_workers)
class AbstractNeuralNetwork(Module):