Variational Generator
This commit is contained in:
@ -17,7 +17,7 @@ class ConvModule(nn.Module):
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
|
||||
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
|
||||
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False,
|
||||
dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
|
||||
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
|
||||
super(ConvModule, self).__init__()
|
||||
|
@ -154,7 +154,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
|
Reference in New Issue
Block a user