Variational Generator

This commit is contained in:
Si11ium
2020-03-10 16:59:51 +01:00
parent 21e7e31805
commit 1b5a7dc69e
10 changed files with 177 additions and 95 deletions

View File

@ -17,7 +17,7 @@ class ConvModule(nn.Module):
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False,
dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()

View File

@ -154,7 +154,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)