Fixed the Model classes, Visualization
This commit is contained in:
@ -1,11 +1,34 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU, AvgPool2d
|
||||
from pytorch_lightning import data_loader
|
||||
from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
#######################
|
||||
# Abstract NN Class
|
||||
# Abstract NN Class & Lightning Module
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset import DataContainer
|
||||
|
||||
|
||||
class LightningModuleOverrides:
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def forward(self, x):
|
||||
return self.network.forward(x)
|
||||
|
||||
@data_loader
|
||||
def tng_dataloader(self):
|
||||
num_workers = os.cpu_count() // 2
|
||||
return DataLoader(DataContainer('data', self.size, self.step),
|
||||
shuffle=True, batch_size=100, num_workers=num_workers)
|
||||
|
||||
|
||||
class AbstractNeuralNetwork(Module):
|
||||
|
||||
|
Reference in New Issue
Block a user