diff --git a/modules/blocks.py b/modules/blocks.py index 5a28e94..04cb264 100644 --- a/modules/blocks.py +++ b/modules/blocks.py @@ -109,11 +109,16 @@ class DeConvModule(ShapeMixin, nn.Module): class ResidualModule(ShapeMixin, nn.Module): - def __init__(self, in_shape, module_class, n, **module_parameters): + def __init__(self, in_shape, module_class, n, norm=False, **module_parameters): assert n >= 1 super(ResidualModule, self).__init__() self.in_shape = in_shape module_parameters.update(in_shape=in_shape) + if norm: + self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d + self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0]) + else: + self.norm = F_x(self.in_shape) self.activation = module_parameters.get('activation', None) if self.activation is not None: self.activation = self.activation() diff --git a/utils/config.py b/utils/config.py index 5892604..9797a2c 100644 --- a/utils/config.py +++ b/utils/config.py @@ -105,6 +105,7 @@ class Config(ConfigParser, ABC): params.update(self.train.__dict__) assert all(key not in list(params.keys()) for key in self.data.__dict__) params.update(self.data.__dict__) + params.update(version=self.version) params.update(exp_path=str(self.exp_path), exp_fingerprint=str(self.fingerprint)) return params diff --git a/utils/logging.py b/utils/logging.py index 11c2a5b..645ee4f 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -64,7 +64,7 @@ class Logger(LightningLoggerBase, ABC): api_key=self.config.project.neptune_key, experiment_name=self.name, project_name=self.project_name, - upload_source_files=list()) + params=self.config.model_paramters) self.neptunelogger = NeptuneLogger(**self._neptune_kwargs) self.testtubelogger = TestTubeLogger(**self._testtube_kwargs) self.log_config_as_ini()