CCS integration dataloader
This commit is contained in:
parent 6ace861016
commit d4059779c4

datasets/ccs_librosa_datamodule.py (new file, 169 lines)

@@ -0,0 +1,169 @@
import multiprocessing as mp
from collections import defaultdict
from pathlib import Path

from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply

from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.transforms import ToTensor

data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]

class CCSLibrosaDatamodule(_BaseDataModule):

    @property
    def class_names(self):
        return {key: val for val, key in enumerate(['negative', 'positive'])}

    @property
    def n_classes(self):
        return len(self.class_names)

    @property
    def shape(self):
        return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape

    @property
    def mel_folder(self):
        return self.root / 'mel_folder'

    @property
    def wav_folder(self):
        return self.root / 'wav'

    def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
                 random_apply_chance=0.5, target_mel_length_in_seconds=1,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
        super(CCSLibrosaDatamodule, self).__init__()
        self.sampler = sampler
        self.samplers = None

        self.num_worker = num_worker or 1
        self.batch_size = batch_size
        self.root = Path(data_root) / 'ComParE2021_CCS'
        self.mel_length_in_seconds = target_mel_length_in_seconds

        # Mel transforms - will be pushed with all other parameters by self.__dict__ to the subdataset class
        self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

        # Utility
        self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Data augmentations
        self.random_apply_chance = random_apply_chance
        self.mel_augmentations = Compose([
            RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
            RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
            RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
            RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
            self.utility_transforms])
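Each RandomApply wrapper above fires independently with probability random_apply_chance, so a training sample passes through anywhere between zero and four augmentations before the utility transforms run. A minimal, self-contained sketch of that pattern, with torchvision Lambda transforms standing in for the ml_lib augmentations:

import torch
from torchvision.transforms import Compose, Lambda, RandomApply

# Two toy "augmentations"; each is applied independently with p=0.5.
pipeline = Compose([
    RandomApply([Lambda(lambda t: t + torch.randn_like(t) * 0.1)], p=0.5),  # stand-in for NoiseInjection
    RandomApply([Lambda(lambda t: t * 0.5)], p=0.5),                        # stand-in for LoudnessManipulator
])

mel = torch.zeros(1, 64, 63)   # (channels, n_mels, time) toy mel spectrogram
out = pipeline(mel)            # result differs from run to run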

    def train_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
                          sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)

    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False,
                          batch_size=self.batch_size, pin_memory=False,
                          num_workers=self.num_worker)

    # Test Dataloader
    def test_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
                          batch_size=self.batch_size, pin_memory=False,
                          num_workers=self.num_worker)

    def _build_subdataset(self, row, build=False):
        slice_file_name, class_name = row.strip().split(',')
        class_id = self.class_names.get(class_name, -1)
        audio_file_path = self.wav_folder / slice_file_name

        # DATA OPTION DIFFERENTIATION - Begin
        # Copy, so that the devel/test override below does not permanently
        # replace self.mel_augmentations on the datamodule itself.
        kwargs = dict(self.__dict__)
        if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
            kwargs.update(mel_augmentations=self.utility_transforms)
        # DATA OPTION DIFFERENTIATION - End

        target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
        sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
        kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length // 2)
        mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
        if build:
            assert mel_dataset.build_mel()
        return mel_dataset, class_id, slice_file_name
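To make the segment arithmetic concrete: target_frames counts raw audio samples, and dividing by the STFT hop gives the number of mel frames per segment. With assumed values sr=16000 and hop_length=256 (the real ones come from _parameters.ini):

sr, hop_length = 16000, 256                 # assumed; actual values come from the config
mel_length_in_seconds = 1

target_frames = mel_length_in_seconds * sr              # 16000 audio samples
sample_segment_len = target_frames // hop_length + 1    # 16000 // 256 + 1 = 63 mel frames
sample_hop_len = sample_segment_len // 2                # 31 frames -> roughly 50% segment overlap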

    def prepare_data(self, *args, **kwargs):
        datasets = dict()
        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)
            # Guard against a chunksize of 0 when there are fewer rows than workers.
            chunksize = max(1, len(all_rows) // max(self.num_worker, 1))
            dataset = list()
            with mp.Pool(processes=self.num_worker) as pool:
                from itertools import repeat
                results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
                                             chunksize=chunksize)
                for sub_dataset in results.get():
                    dataset.append(sub_dataset[0])
            datasets[data_option] = ConcatDataset(dataset)
        self.datasets = datasets
        return datasets
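The starmap_async call simply pairs every CSV row with build=True; itertools.repeat supplies the constant second argument. A stripped-down sketch of the same dispatch, with a toy worker standing in for _build_subdataset:

import multiprocessing as mp
from itertools import repeat

def build_one(row, build):              # toy stand-in for _build_subdataset
    return row.strip().upper(), build

if __name__ == '__main__':
    rows = ['a.wav,negative\n', 'b.wav,positive\n']
    with mp.Pool(processes=2) as pool:
        results = pool.starmap_async(build_one, zip(rows, repeat(True, len(rows))),
                                     chunksize=max(1, len(rows) // 2))
        print([r[0] for r in results.get()])   # ['A.WAV,NEGATIVE', 'B.WAV,POSITIVE']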

    def setup(self, stage=None):
        datasets = dict()
        samplers = dict()
        weights = dict()

        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)
            dataset = list()
            for row in all_rows:
                mel_dataset, class_id, _ = self._build_subdataset(row)
                dataset.append(mel_dataset)
            datasets[data_option] = ConcatDataset(dataset)

            # Build a weighted sampler for the training split
            if data_option in [DATA_OPTION_train]:
                if self.sampler == EqualSampler.__name__:
                    class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option])
                                   if label == class_idx]
                                  for class_idx in range(len(self.class_names))
                                  ]
                    samplers[data_option] = EqualSampler(class_idxs)
                elif self.sampler == WeightedRandomSampler.__name__:
                    class_counts = defaultdict(lambda: 0)
                    for _, __, label in datasets[data_option]:
                        class_counts[label] += 1
                    len_largest_class = max(class_counts.values())

                    # Inverse-frequency weight per class, then one weight per sample
                    weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
                    weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
                                            for i in range(len(datasets[data_option]))]
                    samplers[data_option] = WeightedRandomSampler(weights[data_option],
                                                                  len_largest_class * len(self.class_names))
                else:
                    samplers[data_option] = None
        self.datasets = datasets
        self.samplers = samplers
        return datasets
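The inverse-frequency weighting above makes WeightedRandomSampler draw both classes at roughly equal rates despite the imbalance. A toy version with bare labels:

from collections import Counter
from torch.utils.data import WeightedRandomSampler

labels = [0, 0, 0, 0, 1]                                  # imbalanced toy labels
counts = Counter(labels)                                  # Counter({0: 4, 1: 1})
class_weights = [1 / counts[c] for c in range(len(counts))]
sample_weights = [class_weights[l] for l in labels]       # [0.25, 0.25, 0.25, 0.25, 1.0]

# num_samples mirrors the datamodule: size of the largest class times the class count
sampler = WeightedRandomSampler(sample_weights, num_samples=max(counts.values()) * len(counts))
print([labels[i] for i in list(sampler)])                 # classes appear in near-equal shares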

    def purge(self):
        import shutil

        shutil.rmtree(self.mel_folder, ignore_errors=True)
        print('Mel Folder has been recursively deleted')
        print(f'Folder still exists: {self.mel_folder.exists()}')
        return not self.mel_folder.exists()
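Taken together, a hypothetical end-to-end use of the new datamodule looks like this (argument values are illustrative only, not taken from the repo's _parameters.ini):

# assuming: from datasets.ccs_librosa_datamodule import CCSLibrosaDatamodule
# and the ComParE2021_CCS data living under data_root
dm = CCSLibrosaDatamodule(data_root='data', batch_size=5, num_worker=4,
                          sr=16000, n_mels=64, n_fft=512, hop_length=256,
                          sampler='WeightedRandomSampler')
dm.prepare_data()                     # builds the mel files once, in parallel
dm.setup()                            # builds per-split datasets and the train sampler
train_loader = dm.train_dataloader()  # batches drawn via the weighted sampler

The hunks that follow apply the matching refactor to the existing primates datamodule.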
@@ -17,7 +17,13 @@ data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
 
 class PrimatesLibrosaDatamodule(_BaseDataModule):
 
-    class_names = {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
+    @property
+    def class_names(self):
+        return {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
+
+    @property
+    def n_classes(self):
+        return len(self.class_names)
 
     @property
     def shape(self):
@@ -33,19 +39,16 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
         return self.root / 'wav'
 
     def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
-                 sample_segment_len=40, sample_hop_len=15, random_apply_chance=0.5,
+                 target_mel_length_in_seconds=0.7, random_apply_chance=0.5,
                  loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
         super(PrimatesLibrosaDatamodule, self).__init__()
         self.sampler = sampler
         self.samplers = None
 
-        self.sample_hop_len = sample_hop_len
-        self.sample_segment_len = sample_segment_len
-
         self.num_worker = num_worker or 1
         self.batch_size = batch_size
         self.root = Path(data_root) / 'primates'
-        self.mel_length_in_seconds = 0.7
+        self.target_mel_length_in_seconds = target_mel_length_in_seconds
 
         # Mel Transforms - will be pushed with all other parameters by self.__dict__ to the subdataset class
         self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
@@ -89,7 +92,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
         kwargs.update(mel_augmentations=self.utility_transforms)
         # DATA OPTION DIFFERENTIATION - End
 
-        target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
+        target_frames = self.target_mel_length_in_seconds * self.mel_kwargs['sr']
         sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
         kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length // 2)
         mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
main.py (2 lines changed)
@@ -60,7 +60,7 @@ def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_ca
     trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=callbacks)
 
     # Let Model pull what it wants
-    model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
+    model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=datamodule.n_classes)
     model.init_weights()
 
     # trainer.test(model=model, datamodule=datamodule)
@@ -10,7 +10,7 @@ from einops import rearrange, repeat
 
 from ml_lib.metrics.multi_class_classification import MultiClassScores
 from ml_lib.modules.blocks import TransformerModule
-from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
+from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
 from util.module_mixins import CombinedModelMixins
 
 MIN_NUM_PATCHES = 16
@@ -25,7 +25,7 @@ class VisualTransformer(CombinedModelMixins,
                  use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
                  lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval):
 
-        # TODO: Move this to parent class, or make it much easieer to access... But How...
+        # TODO: Move this to parent class, or make it much easier to access... But How...
         a = dict(locals())
         params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
         super(VisualTransformer, self).__init__(params)
@@ -75,7 +75,7 @@ class VisualTransformer(CombinedModelMixins,
             nn.Linear(self.embed_dim, self.params.lat_dim),
             nn.GELU(),
             nn.Dropout(self.params.dropout),
-            nn.Linear(self.params.lat_dim, n_classes),
+            nn.Linear(self.params.lat_dim, self.params.n_classes),
             nn.Softmax()
         )
 
@@ -88,7 +88,7 @@ class VisualTransformer(CombinedModelMixins,
         tensor = self.autopad(x)
         p = self.params.patch_size
 
-        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (w h) (p1 p2 c)', p1=p, p2=p)
+        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
         b, n, _ = tensor.shape
 
         # mask
@@ -96,7 +96,7 @@ class VisualTransformer(CombinedModelMixins,
         mask = (lengths == torch.zeros_like(lengths))
         # CLS-token awareness
         # mask = torch.cat((torch.zeros(b, 1), mask), dim=-1)
-        # mask = repeat(mask, 'b n -> b n', h=self.params.heads)
+        # mask = repeat(mask, 'b n -> b h n', h=self.params.heads)
 
         tensor = self.patch_to_embedding(tensor)
 
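The rearrange fix swaps patch serialization from column-major '(w h)' to row-major '(h w)', so the patch order matches what the positional embeddings expect. A quick check of the two orderings on a tiny tensor:

import torch
from einops import rearrange

x = torch.arange(16.).reshape(1, 1, 4, 4)   # b c h w, values 0..15 row by row
row_major = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=2, p2=2)
col_major = rearrange(x, 'b c (h p1) (w p2) -> b (w h) (p1 p2 c)', p1=2, p2=2)
print(row_major[0, 1])   # second patch scans left to right:   tensor([2., 3., 6., 7.])
print(col_major[0, 1])   # old ordering walked top to bottom:  tensor([8., 9., 12., 13.])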
multi_run.py (33 lines changed)
@@ -10,26 +10,27 @@ import itertools
 if __name__ == '__main__':
 
     # Set new values
-    hparams_dict = dict(seed=range(10),
+    hparams_dict = dict(seed=[69],
                         model_name=['VisualTransformer'],
-                        batch_size=[50],
-                        max_epochs=[250],
-                        random_apply_chance=[0.3],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
-                        loudness_ratio=[0],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
+                        data_name=['CCSLibrosaDatamodule'],
+                        batch_size=[5],
+                        max_epochs=[200],
+                        random_apply_chance=[0.5],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
+                        loudness_ratio=[0.3],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
                         shift_ratio=[0.3],  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
                         noise_ratio=[0.3],  # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
                         mask_ratio=[0.3],  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
-                        lr=[5e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
+                        lr=[1e-2],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
                         dropout=[0.2],  # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
-                        lat_dim=[32],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
-                        mlp_dim=[16],  # 2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
-                        head_dim=[6],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
+                        lat_dim=[48],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
+                        mlp_dim=[30],  # 2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
+                        head_dim=[12],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
                         patch_size=[12],  # trial.suggest_int('patch_size', 6, 12, step=3),
-                        attn_depth=[10],  # trial.suggest_int('attn_depth', 2, 14, step=4),
-                        heads=[6],  # trial.suggest_int('heads', 2, 16, step=2),
-                        scheduler=['CosineAnnealingWarmRestarts'],  # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
-                        lr_scheduler_parameter=[25],  # [0.98],
-                        embedding_size=[30],  # trial.suggest_int('embedding_size', 12, 64, step=12),
+                        attn_depth=[12],  # trial.suggest_int('attn_depth', 2, 14, step=4),
+                        heads=[12],  # trial.suggest_int('heads', 2, 16, step=2),
+                        scheduler=['LambdaLR'],  # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
+                        lr_scheduler_parameter=[0.95],  # [0.98],
+                        embedding_size=[64],  # trial.suggest_int('embedding_size', 12, 64, step=12),
                         loss=['ce_loss'],
                         sampler=['WeightedRandomSampler'],
                         # trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
@@ -40,7 +41,7 @@ if __name__ == '__main__':
     permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
     for permutations_dict in tqdm(permutations_dicts, total=len(permutations_dicts)):
         # Parse command line args, read config and get model
-        cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults(
+        cmd_args, *data_model_seed = parse_comandline_args_add_defaults(
             '_parameters.ini', overrides=permutations_dict)
 
         hparams = dict(**cmd_args)
@@ -50,6 +51,6 @@ if __name__ == '__main__':
         # RUN
         # ---------------------------------------
         print(f'Running Loop, parameters are: {permutations_dict}')
-        run_lightning_loop(hparams, found_data_class, found_model_class)
+        run_lightning_loop(hparams, *data_model_seed)
         print(f'Done, parameters were: {permutations_dict}')
         pass
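multi_run expands hparams_dict into the cartesian product of all its list values; every resulting permutation becomes one training run. The expansion in isolation:

import itertools

hparams_dict = dict(seed=[69], lr=[1e-2, 1e-3], heads=[12])
keys, values = zip(*hparams_dict.items())
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
print(permutations_dicts)
# [{'seed': 69, 'lr': 0.01, 'heads': 12}, {'seed': 69, 'lr': 0.001, 'heads': 12}]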
@@ -22,7 +22,7 @@ def rebuild_dataset(h_params, data_class):
 
 if __name__ == '__main__':
     # Parse command line args, read config and get model
-    cmd_args, found_data_class, _ = parse_comandline_args_add_defaults('_parameters.ini')
+    cmd_args, found_data_class, _, _ = parse_comandline_args_add_defaults('_parameters.ini')
 
     # To NameSpace
     hparams = Namespace(**cmd_args)
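The mixin hunks below replace every hard-coded num_classes=5 (the five primate classes) with self.params.n_classes, so the same train/val/test code now serves the two-class ComParE CCS task. For reference, torch.nn.functional.one_hot simply widens its output to the given class count:

import torch

batch_y = torch.tensor([0, 1, 1])
print(torch.nn.functional.one_hot(batch_y, num_classes=2))
# tensor([[1, 0],
#         [0, 1],
#         [0, 1]])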
@@ -19,7 +19,7 @@ class TrainMixin:
         y = self(batch_x).main_out
 
         if self.params.loss == 'focal_loss_rob':
-            labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=5)
+            labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
             loss = self.__getattribute__(self.params.loss)(y, labels_one_hot)
         else:
             loss = self.__getattribute__(self.params.loss)(y, batch_y.long())
@@ -58,7 +58,7 @@ class ValMixin:
         y_max = torch.stack(
             [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
         ).squeeze()
-        y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=5).float()
+        y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
         self.metrics.update(y_one_hot, torch.stack(tuple(sorted_batch_y.values())).long())
 
         val_loss = self.ce_loss(y, batch_y.long())
@@ -96,7 +96,7 @@ class ValMixin:
         y_max = torch.stack(
             [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
         ).squeeze()
-        y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=5).float()
+        y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
         max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
         summary_dict.update(val_max_vote_loss=max_vote_loss)
 
@@ -145,7 +145,11 @@ class TestMixin:
         y_max = torch.stack(
             [torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
         ).squeeze().cpu()
-        class_names = {val: key for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
+        if self.params.n_classes == 5:
+            class_names = {val: key for val, key in
+                           enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
+        elif self.params.n_classes == 2:
+            class_names = {val: key for val, key in enumerate(['negative', 'positive'])}
 
         df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
                                     prediction=y_max.cpu().numpy()))
@@ -154,7 +158,7 @@ class TestMixin:
         try:
             result_file.unlink()
         except:
-            print('File allready existed')
+            print('File already existed')
             pass
         with result_file.open(mode='wb') as csv_file:
             df.to_csv(index=False, path_or_buf=csv_file)

@@ -4,5 +4,6 @@ from pathlib import Path
 sr = 16000
 
 PRIMATES_Root = Path(__file__).parent / 'data' / 'primates'
+CCS_Root = Path(__file__).parent / 'data' / 'ComParE2021_CCS'
 
 N_CLASS_multi = 5