Final Train Runs
This commit is contained in:
@ -4,7 +4,6 @@ from pathlib import Path
|
||||
|
||||
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
|
||||
from torchvision.transforms import Compose, RandomApply
|
||||
from tqdm import tqdm
|
||||
|
||||
from ml_lib.audio_toolset.audio_io import NormalizeLocal
|
||||
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
|
||||
@ -69,13 +68,14 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], num_workers=self.num_worker, pin_memory=True,
|
||||
sampler=self.samplers[DATA_OPTION_devel], batch_size=self.batch_size)
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False,
|
||||
batch_size=self.batch_size, pin_memory=False,
|
||||
num_workers=self.num_worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
|
||||
batch_size=self.batch_size, pin_memory=True,
|
||||
batch_size=self.batch_size, pin_memory=False,
|
||||
num_workers=self.num_worker)
|
||||
|
||||
def _build_subdataset(self, row, build=False):
|
||||
@ -134,7 +134,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
datasets[data_option] = ConcatDataset(dataset)
|
||||
|
||||
# Build Weighted Sampler for train and val
|
||||
if data_option in [DATA_OPTION_train, DATA_OPTION_devel]:
|
||||
if data_option in [DATA_OPTION_train]:
|
||||
if self.sampler == EqualSampler.__name__:
|
||||
class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
|
||||
for class_idx in range(len(self.class_names))
|
||||
@ -147,6 +147,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
len_largest_class = max(class_counts.values())
|
||||
|
||||
weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
|
||||
##############################################################################
|
||||
weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
|
||||
for i in range(len(datasets[data_option]))]
|
||||
samplers[data_option] = WeightedRandomSampler(weights[data_option],
|
||||
|
Reference in New Issue
Block a user