Model Training

Si11ium 2020-05-03 18:00:49 +02:00
parent 3e75d73a6b
commit 6d8fbd7184
4 changed files with 80 additions and 52 deletions

View File

@@ -13,6 +13,9 @@ import pytorch_lightning as pl
# Utility - Modules
###################
from ml_lib.utils.model_io import ModelParameters
+class F_x(object):
+    def __init__(self):
+        pass
@@ -111,12 +114,15 @@ class LightningBaseModule(pl.LightningModule, ABC):
    def __init__(self, hparams):
        super(LightningBaseModule, self).__init__()
-        self.hparams = deepcopy(hparams)
-        # Data loading
-        # =============================================================================
-        # Map Object
-        # self.map_storage = MapStorage(self.hparams.data_param.map_root)
+        # Set Parameters
+        ################################
+        self.hparams = hparams
+        self.params = ModelParameters(hparams)
+        # Dataset Loading
+        ################################
+        # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
    def size(self):
        return self.shape
@@ -158,25 +164,28 @@ class LightningBaseModule(pl.LightningModule, ABC):
        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
        self.apply(weight_initializer)
+class BaseModuleMixin_Dataloaders(ABC):
    # Dataloaders
    # ================================================================================
    # Train Dataloader
    def train_dataloader(self):
        return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
-                          batch_size=self.hparams.train_param.batch_size,
-                          num_workers=self.hparams.data_param.worker)
+                          batch_size=self.params.batch_size,
+                          num_workers=self.params.worker)
    # Test Dataloader
    def test_dataloader(self):
        return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
-                          batch_size=self.hparams.train_param.batch_size,
-                          num_workers=self.hparams.data_param.worker)
+                          batch_size=self.params.batch_size,
+                          num_workers=self.params.worker)
    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
-                          batch_size=self.hparams.train_param.batch_size,
-                          num_workers=self.hparams.data_param.worker)
+                          batch_size=self.params.batch_size,
+                          num_workers=self.params.worker)
class FilterLayer(nn.Module):

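The dataloader hunks above all make the same substitution: the nested `self.hparams.train_param.batch_size` and `self.hparams.data_param.worker` lookups become flat attribute lookups on the new `self.params` object. Below is a minimal runnable sketch of the resulting call pattern; `ToyDataset` and the `params` namespace are hypothetical stand-ins for the mixin's real `self.dataset` and `ModelParameters`, not part of the commit.

```python
from argparse import Namespace

import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyDataset:
    """Hypothetical stand-in for the container the mixin reads its splits from."""
    def __init__(self, n=64):
        split = TensorDataset(torch.randn(n, 4), torch.randint(0, 2, (n,)))
        self.train_dataset = self.val_dataset = self.test_dataset = split


params = Namespace(batch_size=16, worker=0)  # flattened, as ModelParameters now exposes
dataset = ToyDataset()

train_loader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
                          batch_size=params.batch_size,
                          num_workers=params.worker)
batch_x, batch_y = next(iter(train_loader))
print(batch_x.shape)  # torch.Size([16, 4])
```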
View File

@@ -1,7 +1,9 @@
import ast
+from copy import deepcopy
from abc import ABC
-from argparse import Namespace
+from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
@@ -20,13 +22,13 @@ def is_jsonable(x):
class Config(ConfigParser, ABC):
-    # TODO: Do this programmatically; This did not work:
-    # Initialize Default Sections
-    # for section in self.default_sections:
-    #     self.__setattr__(section, property(lambda x: x._get_namespace_for_section(section))
+    @property
+    def _model_weight_init(self):
+        mod = __import__('torch.nn.init', fromlist=[self.model.weight_init])
+        return getattr(mod, self.model.weight_init)
    @property
-    def model_map(self):
+    def _model_map(self):
        """
        This function is supposed to return a dict, which holds a mapping from string model names to model classes
@@ -42,10 +44,15 @@ class Config(ConfigParser, ABC):
    @property
    def model_class(self):
        try:
-            return self.model_map[self.get('model', 'type')]
-        except KeyError as e:
+            return self._model_map[self.model.type]
+        except KeyError:
            raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
-                          f'Try one of these:\n{list(self.model_map.keys())}')
+                          f'Try one of these:\n{list(self._model_map.keys())}')
+    # TODO: Do this programmatically; This did not work:
+    # Initialize Default Sections as Property
+    # for section in self.default_sections:
+    #     self.__setattr__(section, property(lambda x: x._get_namespace_for_section(section))
    @property
    def main(self):
@@ -71,7 +78,12 @@ class Config(ConfigParser, ABC):
    @property
    def model_paramters(self):
-        return ModelParameters(self.model, self.train, self.data)
+        params = deepcopy(self.model.__dict__)
+        assert all(key not in list(params.keys()) for key in self.train.__dict__)
+        params.update(self.train.__dict__)
+        assert all(key not in list(params.keys()) for key in self.data.__dict__)
+        params.update(self.data.__dict__)
+        return params
    @property
    def tags(self, ):
@@ -113,12 +125,22 @@ class Config(ConfigParser, ABC):
        new_config.read_dict(sorted_dict)
        return new_config
+    @classmethod
+    def read_argparser(cls, argparser: ArgumentParser):
+        # Parse it
+        args = argparser.parse_args()
+        sorted_dict = cls._sort_combined_section_key_mapping(args.__dict__)
+        new_config = cls()
+        new_config.read_dict(sorted_dict)
+        return new_config
    def build_model(self):
        return self.model_class(self.model_paramters)
-    def build_and_init_model(self, weight_init_function):
+    def build_and_init_model(self):
        model = self.build_model()
-        model.init_weights(weight_init_function)
+        model.init_weights(self._model_weight_init)
        return model
    def update(self, mapping):

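Two additions in this file do related work. `read_argparser` gives `Config` a second constructor path: parse the `ArgumentParser`, sort the flat argument names back into sections with `_sort_combined_section_key_mapping`, and hydrate a fresh instance via `read_dict`, mirroring the dict-based reader directly above it. Meanwhile `model_paramters` now flattens the `model`, `train`, and `data` sections into a single dict, with asserts guarding against a key appearing in two sections. Here is a self-contained sketch of that merge-with-collision-check pattern, using plain `Namespace` objects as hypothetical stand-ins for the config sections:

```python
from argparse import Namespace
from copy import deepcopy

# Hypothetical section contents; the real values come from the parsed config.
model = Namespace(type='cnn', activation='relu')
train = Namespace(batch_size=16, lr=1e-3)
data = Namespace(worker=4)

params = deepcopy(model.__dict__)
# Each assert fails loudly if a key is defined in two sections, which
# update() would otherwise resolve by silently overwriting the value.
assert all(key not in params for key in train.__dict__)
params.update(train.__dict__)
assert all(key not in params for key in data.__dict__)
params.update(data.__dict__)
print(sorted(params))  # ['activation', 'batch_size', 'lr', 'type', 'worker']
```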
View File

@@ -1,3 +1,4 @@
+from abc import ABC
from pathlib import Path
from pytorch_lightning.loggers.base import LightningLoggerBase
@@ -5,10 +6,9 @@ from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from ml_lib.utils.config import Config
-import numpy as np
-class Logger(LightningLoggerBase):
+class Logger(LightningLoggerBase, ABC):
    media_dir = 'media'
@@ -29,7 +29,7 @@ class Logger(LightningLoggerBase):
    @property
    def project_name(self):
-        return f"{self.config.project.owner}/{self.config.project.name}"
+        return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"
    @property
    def version(self):
@@ -37,8 +37,7 @@ class Logger(LightningLoggerBase):
    @property
    def outpath(self):
-        # FIXME: Move this out of here, this is not the right place to do this!!!
-        return Path(self.config.train.outpath) / self.config.model.type
+        raise NotImplementedError
    def __init__(self, config: Config):
        """

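The `project_name` change swaps underscores for dashes in the project part of the identifier, presumably because Neptune project slugs are dash-separated, while `outpath` now raises `NotImplementedError` instead of hard-coding a path, acting on the old FIXME. A trivial sketch of the slug rewrite, with made-up owner and project values:

```python
owner, name = 'Si11ium', 'model_training'   # hypothetical config values
project_name = f"{owner}/{name.replace('_', '-')}"
print(project_name)  # Si11ium/model-training
```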
View File

@@ -1,5 +1,6 @@
from argparse import Namespace
from collections import Mapping
+from copy import deepcopy
from pathlib import Path
import torch
@@ -8,12 +9,12 @@ from torch import nn
# Hyperparameter Object
-class ModelParameters(Mapping, Namespace):
+class ModelParameters(Namespace, Mapping):
    @property
-    def module_paramters(self):
-        paramter_mapping = dict()
-        paramter_mapping.update(self.model_param.__dict__)
+    def module_kwargs(self):
+        paramter_mapping = deepcopy(self.__dict__)
        paramter_mapping.update(
            dict(
@@ -21,26 +22,37 @@ class ModelParameters(Mapping, Namespace):
            )
        )
        del paramter_mapping['in_shape']
        return paramter_mapping
+    @property
+    def test_activation(self):
+        try:
+            return self._activations[self.model.activation]
+        except KeyError:
+            return nn.ReLU
    def __getitem__(self, k):
        # k: _KT -> _VT_co
        return self.__dict__[k]
    def __len__(self):
        # -> int
-        return len(self.__dict__.keys())
+        return len(self.__dict__)
    def __iter__(self):
        # -> Iterator[_T_co]
        return iter(list(self.__dict__.keys()))
    def __delitem__(self, key):
-        self.__dict__.__delitem__(key)
+        self.__delattr__(key)
        return True
+    def __getattribute__(self, name):
+        if name == 'activation':
+            return self._activations[self['activation']]
+        else:
+            return super(ModelParameters, self).__getattribute__(name)
    _activations = dict(
        leaky_relu=nn.LeakyReLU,
        relu=nn.ReLU,
@@ -48,22 +60,8 @@ class ModelParameters(Mapping, Namespace):
        tanh=nn.Tanh
    )
-    def __init__(self, model_param, train_param, data_param):
-        self.model_param = model_param
-        self.train_param = train_param
-        self.data_param = data_param
-        kwargs = vars(model_param)
-        kwargs.update(vars(train_param))
-        kwargs.update(vars(data_param))
-        super(ModelParameters, self).__init__(**kwargs)
-    def __getattribute__(self, item):
-        if item == 'activation':
-            try:
-                return self._activations[item]
-            except KeyError:
-                return nn.ReLU
-        return super(ModelParameters, self).__getattribute__(item)
+    def __init__(self, parameter_mapping):
+        super(ModelParameters, self).__init__(**parameter_mapping)
class SavedLightningModels(object):
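Taken together, the changes in this file turn `ModelParameters` into a flat hyperparameter bag: it is constructed from a single pre-merged mapping (as now produced by `Config.model_paramters`), acts as a `Namespace` for attribute access and as a `Mapping` for item access and `**`-unpacking, and resolves the stored `activation` string to a `torch.nn` class on attribute lookup. A condensed, runnable sketch of that behavior under stated assumptions; it uses `collections.abc.Mapping`, since the bare `collections.Mapping` import shown in the diff was removed in Python 3.10:

```python
from argparse import Namespace
from collections.abc import Mapping  # modern location of the Mapping ABC

from torch import nn


class ModelParameters(Namespace, Mapping):
    _activations = dict(leaky_relu=nn.LeakyReLU, relu=nn.ReLU, tanh=nn.Tanh)

    def __init__(self, parameter_mapping):
        # One flat dict in, one Namespace full of attributes out.
        super(ModelParameters, self).__init__(**parameter_mapping)

    def __getitem__(self, k):
        return self.__dict__[k]

    def __len__(self):
        return len(self.__dict__)

    def __iter__(self):
        return iter(list(self.__dict__.keys()))

    def __getattribute__(self, name):
        # Intercept 'activation' so the stored string resolves to an nn class.
        if name == 'activation':
            return self._activations[self['activation']]
        return super(ModelParameters, self).__getattribute__(name)


params = ModelParameters(dict(batch_size=16, activation='relu'))
print(params.batch_size)     # 16, attribute access via Namespace
print(params['batch_size'])  # 16, item access via Mapping
print(params.activation)     # <class 'torch.nn.modules.activation.ReLU'>
print(dict(**params))        # the raw mapping, unpackable as kwargs
```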