Residual Model
This commit is contained in:
@ -109,11 +109,16 @@ class DeConvModule(ShapeMixin, nn.Module):
|
|||||||
|
|
||||||
class ResidualModule(ShapeMixin, nn.Module):
|
class ResidualModule(ShapeMixin, nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_shape, module_class, n, **module_parameters):
|
def __init__(self, in_shape, module_class, n, norm=False, **module_parameters):
|
||||||
assert n >= 1
|
assert n >= 1
|
||||||
super(ResidualModule, self).__init__()
|
super(ResidualModule, self).__init__()
|
||||||
self.in_shape = in_shape
|
self.in_shape = in_shape
|
||||||
module_parameters.update(in_shape=in_shape)
|
module_parameters.update(in_shape=in_shape)
|
||||||
|
if norm:
|
||||||
|
self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
|
||||||
|
self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
|
||||||
|
else:
|
||||||
|
self.norm = F_x(self.in_shape)
|
||||||
self.activation = module_parameters.get('activation', None)
|
self.activation = module_parameters.get('activation', None)
|
||||||
if self.activation is not None:
|
if self.activation is not None:
|
||||||
self.activation = self.activation()
|
self.activation = self.activation()
|
||||||
|
@ -105,6 +105,7 @@ class Config(ConfigParser, ABC):
|
|||||||
params.update(self.train.__dict__)
|
params.update(self.train.__dict__)
|
||||||
assert all(key not in list(params.keys()) for key in self.data.__dict__)
|
assert all(key not in list(params.keys()) for key in self.data.__dict__)
|
||||||
params.update(self.data.__dict__)
|
params.update(self.data.__dict__)
|
||||||
|
params.update(version=self.version)
|
||||||
params.update(exp_path=str(self.exp_path), exp_fingerprint=str(self.fingerprint))
|
params.update(exp_path=str(self.exp_path), exp_fingerprint=str(self.fingerprint))
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ class Logger(LightningLoggerBase, ABC):
|
|||||||
api_key=self.config.project.neptune_key,
|
api_key=self.config.project.neptune_key,
|
||||||
experiment_name=self.name,
|
experiment_name=self.name,
|
||||||
project_name=self.project_name,
|
project_name=self.project_name,
|
||||||
upload_source_files=list())
|
params=self.config.model_paramters)
|
||||||
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
|
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
|
||||||
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
|
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
|
||||||
self.log_config_as_ini()
|
self.log_config_as_ini()
|
||||||
|
Reference in New Issue
Block a user