Model Training
@@ -13,6 +13,9 @@ import pytorch_lightning as pl
 
 # Utility - Modules
 ###################
+from ml_lib.utils.model_io import ModelParameters
+
+
 class F_x(object):
     def __init__(self):
         pass
@@ -111,12 +114,15 @@ class LightningBaseModule(pl.LightningModule, ABC):
 
     def __init__(self, hparams):
         super(LightningBaseModule, self).__init__()
-        self.hparams = deepcopy(hparams)
 
-        # Data loading
-        # =============================================================================
-        # Map Object
-        # self.map_storage = MapStorage(self.hparams.data_param.map_root)
+        # Set Parameters
+        ################################
+        self.hparams = hparams
+        self.params = ModelParameters(hparams)
+
+        # Dataset Loading
+        ################################
+        # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
 
     def size(self):
         return self.shape
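The base module now keeps the raw hparams and additionally wraps them in the ModelParameters mapping from ml_lib.utils.model_io (changed further down in this commit). A minimal sketch of the resulting access pattern, assuming ml_lib is importable and using made-up parameter values:

from ml_lib.utils.model_io import ModelParameters

hparams = dict(batch_size=64, worker=4)   # hypothetical flat hyperparameter mapping
params = ModelParameters(hparams)         # what self.params now holds

print(params.batch_size)                  # 64 -- attribute-style access (Namespace side)
print(params['worker'])                   # 4  -- mapping-style access (Mapping side)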
@@ -158,25 +164,28 @@ class LightningBaseModule(pl.LightningModule, ABC):
         weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
         self.apply(weight_initializer)
 
+
+class BaseModuleMixin_Dataloaders(ABC):
+
     # Dataloaders
     # ================================================================================
     # Train Dataloader
     def train_dataloader(self):
         return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
-                          batch_size=self.hparams.train_param.batch_size,
-                          num_workers=self.hparams.data_param.worker)
+                          batch_size=self.params.batch_size,
+                          num_workers=self.params.worker)
 
     # Test Dataloader
     def test_dataloader(self):
         return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
-                          batch_size=self.hparams.train_param.batch_size,
-                          num_workers=self.hparams.data_param.worker)
+                          batch_size=self.params.batch_size,
+                          num_workers=self.params.worker)
 
     # Validation Dataloader
     def val_dataloader(self):
         return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
-                          batch_size=self.hparams.train_param.batch_size,
-                          num_workers=self.hparams.data_param.worker)
+                          batch_size=self.params.batch_size,
+                          num_workers=self.params.worker)
 
 
 class FilterLayer(nn.Module):
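The dataloader methods move into a separate mixin and read batch_size and worker from the flat self.params instead of the nested self.hparams.train_param / self.hparams.data_param sections. A standalone sketch of the DataLoader call they now make, with a toy dataset and made-up values standing in for self.dataset.train_dataset and self.params:

import torch
from torch.utils.data import DataLoader, TensorDataset

params = dict(batch_size=16, worker=0)                # stands in for self.params
train_dataset = TensorDataset(torch.randn(128, 4))    # stands in for self.dataset.train_dataset

loader = DataLoader(dataset=train_dataset, shuffle=True,
                    batch_size=params['batch_size'],
                    num_workers=params['worker'])

batch, = next(iter(loader))
print(batch.shape)                                    # torch.Size([16, 4])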
@@ -1,7 +1,9 @@
 import ast
+from copy import deepcopy
 
 from abc import ABC
-from argparse import Namespace
+
+from argparse import Namespace, ArgumentParser
 from collections import defaultdict
 from configparser import ConfigParser
 from pathlib import Path
@@ -20,13 +22,13 @@ def is_jsonable(x):
 
 class Config(ConfigParser, ABC):
 
-    # TODO: Do this programmatically; This did not work:
-    # Initialize Default Sections
-    # for section in self.default_sections:
-    # self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
+    @property
+    def _model_weight_init(self):
+        mod = __import__('torch.nn.init', fromlist=[self.model.weight_init])
+        return getattr(mod, self.model.weight_init)
 
     @property
-    def model_map(self):
+    def _model_map(self):
         """
         This is function is supposed to return a dict, which holds a mapping from string model names to model classes
 
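The new _model_weight_init property resolves the weight-init function named in the config by importing torch.nn.init and looking the name up on the module. The same resolution, sketched standalone with an example name ('xavier_uniform_' is only an illustration, not taken from the commit):

import torch.nn.init

weight_init_name = 'xavier_uniform_'    # would come from self.model.weight_init

mod = __import__('torch.nn.init', fromlist=[weight_init_name])
init_fn = getattr(mod, weight_init_name)

# with the module already imported, a plain getattr does the same thing
assert init_fn is getattr(torch.nn.init, weight_init_name)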
@@ -42,10 +44,15 @@ class Config(ConfigParser, ABC):
     @property
     def model_class(self):
         try:
-            return self.model_map[self.get('model', 'type')]
-        except KeyError as e:
+            return self._model_map[self.model.type]
+        except KeyError:
             raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
-                           f'Try one of these:\n{list(self.model_map.keys())}')
+                           f'Try one of these:\n{list(self._model_map.keys())}')
+
+    # TODO: Do this programmatically; This did not work:
+    # Initialize Default Sections as Property
+    # for section in self.default_sections:
+    # self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
 
     @property
     def main(self):
@@ -71,7 +78,12 @@ class Config(ConfigParser, ABC):
 
     @property
     def model_paramters(self):
-        return ModelParameters(self.model, self.train, self.data)
+        params = deepcopy(self.model.__dict__)
+        assert all(key not in list(params.keys()) for key in self.train.__dict__)
+        params.update(self.train.__dict__)
+        assert all(key not in list(params.keys()) for key in self.data.__dict__)
+        params.update(self.data.__dict__)
+        return params
 
     @property
     def tags(self, ):
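model_paramters no longer builds a ModelParameters object from three section namespaces; it deep-copies the model section and folds the train and data sections into one flat dict, with asserts guarding against duplicate keys across sections. The same pattern, sketched with hypothetical section contents:

from argparse import Namespace
from copy import deepcopy

model = Namespace(type='CNN', activation='relu')   # hypothetical self.model section
train = Namespace(batch_size=64, lr=1e-3)          # hypothetical self.train section
data = Namespace(worker=4)                         # hypothetical self.data section

params = deepcopy(model.__dict__)
assert all(key not in list(params.keys()) for key in train.__dict__)   # fail loudly on key clashes
params.update(train.__dict__)
assert all(key not in list(params.keys()) for key in data.__dict__)
params.update(data.__dict__)

print(params)   # {'type': 'CNN', 'activation': 'relu', 'batch_size': 64, 'lr': 0.001, 'worker': 4}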
@@ -113,12 +125,22 @@ class Config(ConfigParser, ABC):
         new_config.read_dict(sorted_dict)
         return new_config
 
+    @classmethod
+    def read_argparser(cls, argparser: ArgumentParser):
+        # Parse it
+        args = argparser.parse_args()
+        sorted_dict = cls._sort_combined_section_key_mapping(args.__dict__)
+        new_config = cls()
+        new_config.read_dict(sorted_dict)
+        return new_config
+
+
     def build_model(self):
         return self.model_class(self.model_paramters)
 
-    def build_and_init_model(self, weight_init_function):
+    def build_and_init_model(self):
         model = self.build_model()
-        model.init_weights(weight_init_function)
+        model.init_weights(self._model_weight_init)
         return model
 
     def update(self, mapping):
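read_argparser builds a Config directly from an ArgumentParser: it calls parse_args(), routes the flat argument names through the existing _sort_combined_section_key_mapping helper, and reads the result in as config sections. A usage sketch; the flag names below are hypothetical and assume the sorter maps flat names like 'train_batch_size' back to sections, which this commit does not show:

from argparse import ArgumentParser
from ml_lib.utils.config import Config

parser = ArgumentParser()
parser.add_argument('--main_debug', default=False, action='store_true')   # hypothetical flags
parser.add_argument('--train_batch_size', type=int, default=64)
parser.add_argument('--data_worker', type=int, default=4)

config = Config.read_argparser(parser)   # parses sys.argv internally, then builds the Config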
@@ -1,3 +1,4 @@
+from abc import ABC
 from pathlib import Path
 
 from pytorch_lightning.loggers.base import LightningLoggerBase
@@ -5,10 +6,9 @@ from pytorch_lightning.loggers.neptune import NeptuneLogger
 from pytorch_lightning.loggers.test_tube import TestTubeLogger
 
 from ml_lib.utils.config import Config
-import numpy as np
 
 
-class Logger(LightningLoggerBase):
+class Logger(LightningLoggerBase, ABC):
 
     media_dir = 'media'
 
@@ -29,7 +29,7 @@ class Logger(LightningLoggerBase):
 
     @property
     def project_name(self):
-        return f"{self.config.project.owner}/{self.config.project.name}"
+        return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"
 
     @property
     def version(self):
@@ -37,8 +37,7 @@ class Logger(LightningLoggerBase):
 
     @property
     def outpath(self):
-        # FIXME: Move this out of here, this is not the right place to do this!!!
-        return Path(self.config.train.outpath) / self.config.model.type
+        raise NotImplementedError
 
     def __init__(self, config: Config):
         """
@@ -1,5 +1,6 @@
 from argparse import Namespace
 from collections import Mapping
+from copy import deepcopy
 from pathlib import Path
 
 import torch
@@ -8,12 +9,12 @@ from torch import nn
 
 
 # Hyperparamter Object
-class ModelParameters(Mapping, Namespace):
+class ModelParameters(Namespace, Mapping):
 
     @property
-    def module_paramters(self):
-        paramter_mapping = dict()
-        paramter_mapping.update(self.model_param.__dict__)
+    def module_kwargs(self):
+        paramter_mapping = deepcopy(self.__dict__)
 
         paramter_mapping.update(
             dict(
@@ -21,26 +22,37 @@ class ModelParameters(Mapping, Namespace):
             )
         )
 
-        del paramter_mapping['in_shape']
-
         return paramter_mapping
 
+    @property
+    def test_activation(self):
+        try:
+            return self._activations[self.model.activation]
+        except KeyError:
+            return nn.ReLU
+
     def __getitem__(self, k):
         # k: _KT -> _VT_co
         return self.__dict__[k]
 
     def __len__(self):
         # -> int
-        return len(self.__dict__.keys())
+        return len(self.__dict__)
 
     def __iter__(self):
         # -> Iterator[_T_co]
         return iter(list(self.__dict__.keys()))
 
     def __delitem__(self, key):
-        self.__dict__.__delitem__(key)
+        self.__delattr__(key)
        return True
 
+    def __getattribute__(self, name):
+        if name == 'activation':
+            return self._activations[self['activation']]
+        else:
+            return super(ModelParameters, self).__getattribute__(name)
+
     _activations = dict(
         leaky_relu=nn.LeakyReLU,
         relu=nn.ReLU,
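With the Mapping protocol plus the new __getattribute__ hook, a stored activation string is returned untouched by item access but resolved to the matching torch.nn class by attribute access. A small sketch, assuming ml_lib is importable and using illustrative values:

from torch import nn
from ml_lib.utils.model_io import ModelParameters

params = ModelParameters(dict(activation='leaky_relu', lr=1e-3))   # illustrative values

print(params['activation'])                 # 'leaky_relu' -- raw value via __getitem__
print(params.activation)                    # <class 'torch.nn.modules.activation.LeakyReLU'>
assert params.activation is nn.LeakyReLU    # resolved through the __getattribute__ hook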
@@ -48,22 +60,8 @@ class ModelParameters(Mapping, Namespace):
         tanh=nn.Tanh
     )
 
-    def __init__(self, model_param, train_param, data_param):
-        self.model_param = model_param
-        self.train_param = train_param
-        self.data_param = data_param
-        kwargs = vars(model_param)
-        kwargs.update(vars(train_param))
-        kwargs.update(vars(data_param))
-        super(ModelParameters, self).__init__(**kwargs)
-
-    def __getattribute__(self, item):
-        if item == 'activation':
-            try:
-                return self._activations[item]
-            except KeyError:
-                return nn.ReLU
-        return super(ModelParameters, self).__getattribute__(item)
+    def __init__(self, parameter_mapping):
+        super(ModelParameters, self).__init__(**parameter_mapping)
 
 
 class SavedLightningModels(object):
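Worth noting about the removed __getattribute__: it looked up the literal attribute name 'activation' in the _activations dict, so the lookup always raised KeyError and fell back to nn.ReLU regardless of the configured value; the replacement shown in the previous hunk keys the lookup on the stored string instead. The difference, sketched with plain dicts:

from torch import nn

_activations = dict(leaky_relu=nn.LeakyReLU, relu=nn.ReLU, tanh=nn.Tanh)

item = 'activation'            # the attribute name the old code used as the dict key
stored_value = 'leaky_relu'    # an example of the value a config would actually carry

print(_activations.get(item, nn.ReLU))          # nn.ReLU      -- old lookup, always the fallback
print(_activations.get(stored_value, nn.ReLU))  # nn.LeakyReLU -- new lookup, keyed on the value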