Debugging Validation and testing
This commit is contained in:
@@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
# Data loading
|
||||
# =============================================================================
|
||||
# Map Object
|
||||
self.map_storage = MapStorage(self.hparams.data_param.map_root)
|
||||
self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True)
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
@@ -143,19 +143,19 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
# Train Dataloader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user