Final Train Runs

This commit is contained in:
Steffen Illium
2021-03-18 07:45:07 +01:00
parent ad254dae92
commit fecf4923c2
14 changed files with 672 additions and 362 deletions

View File

@ -4,7 +4,6 @@ from pathlib import Path
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply
from tqdm import tqdm
from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
@ -69,13 +68,14 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], num_workers=self.num_worker, pin_memory=True,
sampler=self.samplers[DATA_OPTION_devel], batch_size=self.batch_size)
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False,
batch_size=self.batch_size, pin_memory=False,
num_workers=self.num_worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
batch_size=self.batch_size, pin_memory=True,
batch_size=self.batch_size, pin_memory=False,
num_workers=self.num_worker)
def _build_subdataset(self, row, build=False):
@ -134,7 +134,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
datasets[data_option] = ConcatDataset(dataset)
# Build Weighted Sampler for train and val
if data_option in [DATA_OPTION_train, DATA_OPTION_devel]:
if data_option in [DATA_OPTION_train]:
if self.sampler == EqualSampler.__name__:
class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
for class_idx in range(len(self.class_names))
@ -147,6 +147,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
len_largest_class = max(class_counts.values())
weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
##############################################################################
weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
for i in range(len(datasets[data_option]))]
samplers[data_option] = WeightedRandomSampler(weights[data_option],