Model Training

Si11ium
2020-05-03 18:00:49 +02:00
parent 3e75d73a6b
commit 6d8fbd7184
4 changed files with 80 additions and 52 deletions


@@ -13,6 +13,9 @@ import pytorch_lightning as pl
 # Utility - Modules
 ###################
 from ml_lib.utils.model_io import ModelParameters
+class F_x(object):
+    def __init__(self):
+        pass
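
The hunk above adds F_x with nothing but an empty constructor. In utility collections like this one, such a class typically serves as an identity mapping (f(x) = x), used as a no-op stand-in where a transform or layer is expected. A minimal sketch of that assumed intent; the __call__ method is an assumption, not part of the diff:

    class F_x(object):
        def __init__(self):
            pass

        def __call__(self, x):
            # Assumed intent: identity function, returns the input unchanged.
            return x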
@@ -111,12 +114,15 @@ class LightningBaseModule(pl.LightningModule, ABC):
     def __init__(self, hparams):
         super(LightningBaseModule, self).__init__()
-        self.hparams = deepcopy(hparams)
-        # Data loading
-        # =============================================================================
-        # Map Object
-        # self.map_storage = MapStorage(self.hparams.data_param.map_root)
+        # Set Parameters
+        ################################
+        self.hparams = hparams
+        self.params = ModelParameters(hparams)
+        # Dataset Loading
+        ################################
+        # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here

     def size(self):
         return self.shape
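
This hunk routes hyperparameter access through ModelParameters from ml_lib.utils.model_io, whose implementation is not part of this diff. A minimal sketch, assuming the wrapper merely flattens the nested namespaces (train_param, data_param, ...) so that the later params.batch_size and params.worker resolve to the old hparams.train_param.batch_size and hparams.data_param.worker; everything here beyond those names is illustrative:

    from argparse import Namespace

    class ModelParameters(Namespace):
        # Hypothetical sketch, not the actual ml_lib implementation.
        def __init__(self, hparams):
            flat = dict()
            for value in vars(hparams).values():
                if isinstance(value, Namespace):
                    flat.update(vars(value))  # lift group members to the top level
            super(ModelParameters, self).__init__(**flat)

    # Illustrative usage:
    hparams = Namespace(train_param=Namespace(batch_size=32),
                        data_param=Namespace(worker=4))
    params = ModelParameters(hparams)
    assert params.batch_size == 32 and params.worker == 4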
@@ -158,25 +164,28 @@ class LightningBaseModule(pl.LightningModule, ABC):
         weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
         self.apply(weight_initializer)


 class BaseModuleMixin_Dataloaders(ABC):

     # Dataloaders
     # ================================================================================
     # Train Dataloader
     def train_dataloader(self):
         return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
-                          batch_size=self.hparams.train_param.batch_size,
-                          num_workers=self.hparams.data_param.worker)
+                          batch_size=self.params.batch_size,
+                          num_workers=self.params.worker)

     # Test Dataloader
     def test_dataloader(self):
         return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
-                          batch_size=self.hparams.train_param.batch_size,
-                          num_workers=self.hparams.data_param.worker)
+                          batch_size=self.params.batch_size,
+                          num_workers=self.params.worker)

     # Validation Dataloader
     def val_dataloader(self):
         return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
-                          batch_size=self.hparams.train_param.batch_size,
-                          num_workers=self.hparams.data_param.worker)
+                          batch_size=self.params.batch_size,
+                          num_workers=self.params.worker)

 class FilterLayer(nn.Module):
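
For context on how these hooks are consumed: PyTorch Lightning calls the *_dataloader methods itself during fitting and testing, so the refactor above only changes where the loaders read their settings from. A minimal, assumed usage sketch; MyModel and hparams are hypothetical placeholders, not names from this diff:

    import pytorch_lightning as pl

    model = MyModel(hparams)             # hypothetical subclass of LightningBaseModule
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model)                   # invokes model.train_dataloader() / val_dataloader()
    trainer.test(model)                  # invokes model.test_dataloader()

One design note: shuffle=True on the test and validation loaders is unusual; evaluation loaders are normally left unshuffled so that metrics are computed over a deterministic ordering.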