Model Training

This commit is contained in:
Si11ium
2020-05-03 18:00:49 +02:00
parent 3e75d73a6b
commit 6d8fbd7184
4 changed files with 80 additions and 52 deletions

View File

@@ -13,6 +13,9 @@ import pytorch_lightning as pl
# Utility - Modules # Utility - Modules
################### ###################
from ml_lib.utils.model_io import ModelParameters
class F_x(object): class F_x(object):
def __init__(self): def __init__(self):
pass pass
@@ -111,12 +114,15 @@ class LightningBaseModule(pl.LightningModule, ABC):
def __init__(self, hparams): def __init__(self, hparams):
super(LightningBaseModule, self).__init__() super(LightningBaseModule, self).__init__()
self.hparams = deepcopy(hparams)
# Data loading # Set Parameters
# ============================================================================= ################################
# Map Object self.hparams = hparams
# self.map_storage = MapStorage(self.hparams.data_param.map_root) self.params = ModelParameters(hparams)
# Dataset Loading
################################
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
def size(self): def size(self):
return self.shape return self.shape
@@ -158,25 +164,28 @@ class LightningBaseModule(pl.LightningModule, ABC):
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_) weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
self.apply(weight_initializer) self.apply(weight_initializer)
class BaseModuleMixin_Dataloaders(ABC):
# Dataloaders # Dataloaders
# ================================================================================ # ================================================================================
# Train Dataloader # Train Dataloader
def train_dataloader(self): def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True, return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.train_param.batch_size, batch_size=self.params.batch_size,
num_workers=self.hparams.data_param.worker) num_workers=self.params.worker)
# Test Dataloader # Test Dataloader
def test_dataloader(self): def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True, return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.train_param.batch_size, batch_size=self.params.batch_size,
num_workers=self.hparams.data_param.worker) num_workers=self.params.worker)
# Validation Dataloader # Validation Dataloader
def val_dataloader(self): def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True, return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.train_param.batch_size, batch_size=self.params.batch_size,
num_workers=self.hparams.data_param.worker) num_workers=self.params.worker)
class FilterLayer(nn.Module): class FilterLayer(nn.Module):

View File

@@ -1,7 +1,9 @@
import ast import ast
from copy import deepcopy
from abc import ABC from abc import ABC
from argparse import Namespace from argparse import Namespace, ArgumentParser
from collections import defaultdict from collections import defaultdict
from configparser import ConfigParser from configparser import ConfigParser
from pathlib import Path from pathlib import Path
@@ -20,13 +22,13 @@ def is_jsonable(x):
class Config(ConfigParser, ABC): class Config(ConfigParser, ABC):
# TODO: Do this programmatically; This did not work: @property
# Initialize Default Sections def _model_weight_init(self):
# for section in self.default_sections: mod = __import__('torch.nn.init', fromlist=[self.model.weight_init])
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section)) return getattr(mod, self.model.weight_init)
@property @property
def model_map(self): def _model_map(self):
""" """
This is function is supposed to return a dict, which holds a mapping from string model names to model classes This is function is supposed to return a dict, which holds a mapping from string model names to model classes
@@ -42,10 +44,15 @@ class Config(ConfigParser, ABC):
@property @property
def model_class(self): def model_class(self):
try: try:
return self.model_map[self.get('model', 'type')] return self._model_map[self.model.type]
except KeyError as e: except KeyError:
raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n' raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
f'Try one of these:\n{list(self.model_map.keys())}') f'Try one of these:\n{list(self._model_map.keys())}')
# TODO: Do this programmatically; This did not work:
# Initialize Default Sections as Property
# for section in self.default_sections:
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
@property @property
def main(self): def main(self):
@@ -71,7 +78,12 @@ class Config(ConfigParser, ABC):
@property @property
def model_paramters(self): def model_paramters(self):
return ModelParameters(self.model, self.train, self.data) params = deepcopy(self.model.__dict__)
assert all(key not in list(params.keys()) for key in self.train.__dict__)
params.update(self.train.__dict__)
assert all(key not in list(params.keys()) for key in self.data.__dict__)
params.update(self.data.__dict__)
return params
@property @property
def tags(self, ): def tags(self, ):
@@ -113,12 +125,22 @@ class Config(ConfigParser, ABC):
new_config.read_dict(sorted_dict) new_config.read_dict(sorted_dict)
return new_config return new_config
@classmethod
def read_argparser(cls, argparser: ArgumentParser):
# Parse it
args = argparser.parse_args()
sorted_dict = cls._sort_combined_section_key_mapping(args.__dict__)
new_config = cls()
new_config.read_dict(sorted_dict)
return new_config
def build_model(self): def build_model(self):
return self.model_class(self.model_paramters) return self.model_class(self.model_paramters)
def build_and_init_model(self, weight_init_function): def build_and_init_model(self):
model = self.build_model() model = self.build_model()
model.init_weights(weight_init_function) model.init_weights(self._model_weight_init)
return model return model
def update(self, mapping): def update(self, mapping):

View File

@@ -1,3 +1,4 @@
from abc import ABC
from pathlib import Path from pathlib import Path
from pytorch_lightning.loggers.base import LightningLoggerBase from pytorch_lightning.loggers.base import LightningLoggerBase
@@ -5,10 +6,9 @@ from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger from pytorch_lightning.loggers.test_tube import TestTubeLogger
from ml_lib.utils.config import Config from ml_lib.utils.config import Config
import numpy as np
class Logger(LightningLoggerBase): class Logger(LightningLoggerBase, ABC):
media_dir = 'media' media_dir = 'media'
@@ -29,7 +29,7 @@ class Logger(LightningLoggerBase):
@property @property
def project_name(self): def project_name(self):
return f"{self.config.project.owner}/{self.config.project.name}" return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"
@property @property
def version(self): def version(self):
@@ -37,8 +37,7 @@ class Logger(LightningLoggerBase):
@property @property
def outpath(self): def outpath(self):
# FIXME: Move this out of here, this is not the right place to do this!!! raise NotImplementedError
return Path(self.config.train.outpath) / self.config.model.type
def __init__(self, config: Config): def __init__(self, config: Config):
""" """

View File

@@ -1,5 +1,6 @@
from argparse import Namespace from argparse import Namespace
from collections import Mapping from collections import Mapping
from copy import deepcopy
from pathlib import Path from pathlib import Path
import torch import torch
@@ -8,12 +9,12 @@ from torch import nn
# Hyperparamter Object # Hyperparamter Object
class ModelParameters(Mapping, Namespace): class ModelParameters(Namespace, Mapping):
@property @property
def module_paramters(self): def module_kwargs(self):
paramter_mapping = dict()
paramter_mapping.update(self.model_param.__dict__) paramter_mapping = deepcopy(self.__dict__)
paramter_mapping.update( paramter_mapping.update(
dict( dict(
@@ -21,26 +22,37 @@ class ModelParameters(Mapping, Namespace):
) )
) )
del paramter_mapping['in_shape']
return paramter_mapping return paramter_mapping
@property
def test_activation(self):
try:
return self._activations[self.model.activation]
except KeyError:
return nn.ReLU
def __getitem__(self, k): def __getitem__(self, k):
# k: _KT -> _VT_co # k: _KT -> _VT_co
return self.__dict__[k] return self.__dict__[k]
def __len__(self): def __len__(self):
# -> int # -> int
return len(self.__dict__.keys()) return len(self.__dict__)
def __iter__(self): def __iter__(self):
# -> Iterator[_T_co] # -> Iterator[_T_co]
return iter(list(self.__dict__.keys())) return iter(list(self.__dict__.keys()))
def __delitem__(self, key): def __delitem__(self, key):
self.__dict__.__delitem__(key) self.__delattr__(key)
return True return True
def __getattribute__(self, name):
if name == 'activation':
return self._activations[self['activation']]
else:
return super(ModelParameters, self).__getattribute__(name)
_activations = dict( _activations = dict(
leaky_relu=nn.LeakyReLU, leaky_relu=nn.LeakyReLU,
relu=nn.ReLU, relu=nn.ReLU,
@@ -48,22 +60,8 @@ class ModelParameters(Mapping, Namespace):
tanh=nn.Tanh tanh=nn.Tanh
) )
def __init__(self, model_param, train_param, data_param): def __init__(self, parameter_mapping):
self.model_param = model_param super(ModelParameters, self).__init__(**parameter_mapping)
self.train_param = train_param
self.data_param = data_param
kwargs = vars(model_param)
kwargs.update(vars(train_param))
kwargs.update(vars(data_param))
super(ModelParameters, self).__init__(**kwargs)
def __getattribute__(self, item):
if item == 'activation':
try:
return self._activations[item]
except KeyError:
return nn.ReLU
return super(ModelParameters, self).__getattribute__(item)
class SavedLightningModels(object): class SavedLightningModels(object):