New Dataset for per spatial cluster training
This commit is contained in:
@ -222,6 +222,9 @@ class DatasetMixin:
|
||||
|
||||
def build_dataset(self, dataset_class, **kwargs):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
assert dataset_class.name == self.params.dataset_type, f'Check the dataset! ' + \
|
||||
f'Expected was {self.params.dataset_type}, ' + \
|
||||
f'given:{dataset_class.name}'
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
@ -258,7 +261,7 @@ class BaseDataloadersMixin(ABC):
|
||||
# In case you want to implement bootstraping
|
||||
# sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
|
||||
sampler = None
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=False if not sampler else None, sampler=sampler,
|
||||
batch_size=self.params.batch_size,
|
||||
num_workers=self.params.worker)
|
||||
|
||||
|
Reference in New Issue
Block a user