New Dataset for per spatial cluster training

This commit is contained in:
Si11ium
2020-06-09 14:08:35 +02:00
parent 821b2d1961
commit 23f3aa878d
10 changed files with 104 additions and 12 deletions

View File

@ -222,6 +222,9 @@ class DatasetMixin:
def build_dataset(self, dataset_class, **kwargs):
assert isinstance(self, LightningBaseModule)
assert dataset_class.name == self.params.dataset_type, f'Check the dataset! ' + \
f'Expected was {self.params.dataset_type}, ' + \
f'given:{dataset_class.name}'
# Dataset
# =============================================================================
@ -258,7 +261,7 @@ class BaseDataloadersMixin(ABC):
# In case you want to implement bootstraping
# sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
sampler = None
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
return DataLoader(dataset=self.dataset.train_dataset, shuffle=False if not sampler else None, sampler=sampler,
batch_size=self.params.batch_size,
num_workers=self.params.worker)