Hparams passing with user warnings
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
from typing import Union
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -19,9 +20,9 @@ class ConvModule(nn.Module):
|
||||
|
||||
def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
|
||||
use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
|
||||
conv_class=nn.Conv2d, conv_stride=1, conv_padding=0):
|
||||
conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
|
||||
super(ConvModule, self).__init__()
|
||||
|
||||
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
||||
# Module Parameters
|
||||
self.in_shape = in_shape
|
||||
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
|
||||
|
@ -11,8 +11,16 @@ import pytorch_lightning as pl
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
class F_x(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, x):
|
||||
return x
|
||||
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
class Flatten(nn.Module):
|
||||
|
||||
@property
|
||||
|
Reference in New Issue
Block a user