Model Training
This commit is contained in:
@ -13,6 +13,9 @@ import pytorch_lightning as pl
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
from ml_lib.utils.model_io import ModelParameters
|
||||
|
||||
|
||||
class F_x(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
@ -111,12 +114,15 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
self.hparams = deepcopy(hparams)
|
||||
|
||||
# Data loading
|
||||
# =============================================================================
|
||||
# Map Object
|
||||
# self.map_storage = MapStorage(self.hparams.data_param.map_root)
|
||||
# Set Parameters
|
||||
################################
|
||||
self.hparams = hparams
|
||||
self.params = ModelParameters(hparams)
|
||||
|
||||
# Dataset Loading
|
||||
################################
|
||||
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
@ -158,25 +164,28 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
||||
self.apply(weight_initializer)
|
||||
|
||||
|
||||
class BaseModuleMixin_Dataloaders(ABC):
|
||||
|
||||
# Dataloaders
|
||||
# ================================================================================
|
||||
# Train Dataloader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
batch_size=self.params.batch_size,
|
||||
num_workers=self.params.worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
batch_size=self.params.batch_size,
|
||||
num_workers=self.params.worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
batch_size=self.params.batch_size,
|
||||
num_workers=self.params.worker)
|
||||
|
||||
|
||||
class FilterLayer(nn.Module):
|
||||
|
Reference in New Issue
Block a user