Audio Dataset

commit 95dcf22f3d (parent 95561acc35)
@@ -2,7 +2,8 @@ from argparse import ArgumentParser, Namespace
 from distutils.util import strtobool
 from pathlib import Path

-import os
+NEPTUNE_API_KEY = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUu' \
+                  'YWkiLCJhcGlfa2V5IjoiZmI0OGMzNzUtOTg1NS00Yzg2LThjMzYtMWFiYjUwMDUyMjVlIn0='

 # Parameter Configuration
 # =============================================================================
@@ -18,10 +19,10 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")

 # Data Parameters
-main_arg_parser.add_argument("--data_worker", type=int, default=11, help="")
+main_arg_parser.add_argument("--data_class_name", type=str, default='Urban8K', help="")
+main_arg_parser.add_argument("--data_worker", type=int, default=6, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
-main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
-main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="")
+main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
 main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
 main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
@@ -39,7 +40,7 @@ main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="")

 # Model Parameters
 # General
-main_arg_parser.add_argument("--model_type", type=str, default="SequentialVisualTransformer", help="")
+main_arg_parser.add_argument("--model_type", type=str, default="HorizontalVisualTransformer", help="")
 main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="gelu", help="")
 main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
@@ -63,8 +64,8 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
 main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--train_weight_decay", type=float, default=0, help="")
 main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
 main_arg_parser.add_argument("--train_lr_warmup_steps", type=int, default=10, help="")
 main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
@@ -72,7 +73,7 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0
 # Project Parameters
 main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
 main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
-main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_API_TOKEN'), help="")
+main_arg_parser.add_argument("--project_neptune_key", type=str, default=NEPTUNE_API_KEY, help="")

 if __name__ == '__main__':
     # Parse it
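For orientation, a minimal sketch of how the updated defaults above could be exercised in isolation. The main_arg_parser import name is taken from the multi_run.py hunk further down; everything else in the snippet is illustrative rather than part of the commit:

    from _paramters import main_arg_parser   # import path as used in multi_run.py

    args = main_arg_parser.parse_args([])    # rely on the defaults only
    assert args.data_class_name == 'Urban8K'
    assert args.model_type == 'HorizontalVisualTransformer'
    print(args.data_worker, args.train_epochs, args.train_batch_size)   # 6, 200, 200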
@@ -18,15 +18,18 @@ class BinaryMasksDataset(Dataset):
     def sample_shape(self):
         return self[0][0].shape

+    @property
+    def _fingerprint(self):
+        return dict(**self._mel_kwargs, normalize=self.normalize)
+
     def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
                  use_preprocessed=True):
-        self.use_preprocessed = use_preprocessed
         self.stretch = stretch_dataset
         assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
         assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
         super(BinaryMasksDataset, self).__init__()

-        self.data_root = Path(data_root)
+        self.data_root = Path(data_root) / 'ComParE2020_Mask'
         self.setting = setting
         self._wav_folder = self.data_root / 'wav'
         self._mel_folder = self.data_root / 'mel'
@@ -37,16 +40,36 @@ class BinaryMasksDataset(Dataset):
         self._wav_files = list(sorted(self._labels.keys()))
         self._transforms = transforms or F_x(in_shape=None)

+        param_storage = self._mel_folder / 'data_params.pik'
+        self._mel_folder.mkdir(parents=True, exist_ok=True)
+        try:
+            pik_data = param_storage.read_bytes()
+            fingerprint = pickle.loads(pik_data)
+            if fingerprint == self._fingerprint:
+                self.use_preprocessed = use_preprocessed
+            else:
+                print('Diverging parameters were found; Refreshing...')
+                param_storage.unlink()
+                pik_data = pickle.dumps(self._fingerprint)
+                param_storage.write_bytes(pik_data)
+                self.use_preprocessed = True
+
+        except FileNotFoundError:
+            pik_data = pickle.dumps(self._fingerprint)
+            param_storage.write_bytes(pik_data)
+            self.use_preprocessed = True
+
     def _build_labels(self):
         labeldict = dict()
-        with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
+        labelfile = 'labels' if self.setting != V.DATA_OPTIONS.test else V.DATA_OPTIONS.test
+        with open(Path(self.data_root) / 'lab' / f'{labelfile}.csv', mode='r') as f:
             # Exclude the header
             _ = next(f)
             for row in f:
                 if self.setting not in row:
                     continue
                 filename, label = row.strip().split(',')
-                labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
+                labeldict[filename] = self._to_label[label.lower()]  # if not self.setting == 'test' else filename
         if self.stretch and self.setting == V.DATA_OPTIONS.train:
             additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
             additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
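The block added above stores a fingerprint of the mel parameters next to the preprocessed files, so a stale cache is detected when the parameters change. A standalone sketch of that pattern; the function and argument names here are placeholders, not part of the commit:

    import pickle
    from pathlib import Path

    def cache_is_valid(cache_dir: Path, current_fingerprint) -> bool:
        """Return True if the cached mel files were built with the current parameters."""
        storage = cache_dir / 'data_params.pik'
        cache_dir.mkdir(parents=True, exist_ok=True)
        try:
            if pickle.loads(storage.read_bytes()) == current_fingerprint:
                return True                      # parameters match, reuse the cache
            storage.unlink()                     # diverging parameters, drop the marker
        except FileNotFoundError:
            pass                                 # no marker stored yet
        storage.write_bytes(pickle.dumps(current_fingerprint))
        return False                             # cache has to be rebuilt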
@@ -1,95 +1,140 @@
 import pickle
-from collections import defaultdict
 from pathlib import Path
+import multiprocessing as mp

 import librosa as librosa
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, ConcatDataset
 import torch
+from tqdm import tqdm

 import variables as V
+from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
 from ml_lib.modules.util import F_x


-class BinaryMasksDataset(Dataset):
-    _to_label = defaultdict(lambda: -1)
-    _to_label.update(dict(clear=V.CLEAR, mask=V.MASK))
+class Urban8K(Dataset):

     @property
     def sample_shape(self):
         return self[0][0].shape

-    def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
-                 use_preprocessed=True):
-        self.use_preprocessed = use_preprocessed
-        self.stretch = stretch_dataset
+    @property
+    def _fingerprint(self):
+        return str(self._mel_transform)
+
+    def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
+                 use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(),
+                 **_):
         assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
         assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
-        super(BinaryMasksDataset, self).__init__()
+        assert fold in range(1, 11)
+        super(Urban8K, self).__init__()

-        self.data_root = Path(data_root)
+        self.data_root = Path(data_root) / 'UrbanSound8K'
         self.setting = setting
-        self._wav_folder = self.data_root / 'wav'
-        self._mel_folder = self.data_root / 'mel'
+        self.num_worker = num_worker
+        self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
+        self.use_preprocessed = use_preprocessed
+        self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
+        self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
         self.container_ext = '.pik'
         self._mel_transform = mel_transforms

         self._labels = self._build_labels()
         self._wav_files = list(sorted(self._labels.keys()))
-        self._transforms = transforms or F_x(in_shape=None)
+        transforms = transforms or F_x(in_shape=None)

+        param_storage = self._mel_folder / 'data_params.pik'
+        self._mel_folder.mkdir(parents=True, exist_ok=True)
+        try:
+            pik_data = param_storage.read_bytes()
+            fingerprint = pickle.loads(pik_data)
+            if fingerprint == self._fingerprint:
+                self.use_preprocessed = use_preprocessed
+            else:
+                print('Diverging parameters were found; Refreshing...')
+                param_storage.unlink()
+                pik_data = pickle.dumps(self._fingerprint)
+                param_storage.write_bytes(pik_data)
+                self.use_preprocessed = False
+
+        except FileNotFoundError:
+            pik_data = pickle.dumps(self._fingerprint)
+            param_storage.write_bytes(pik_data)
+            self.use_preprocessed = False
+
+
+        while True:
+            if not self.use_preprocessed:
+                self._pre_process()
+            try:
+                self._dataset = ConcatDataset(
+                    [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
+                                     segment_len=audio_segment_len, hop_len=audio_hop_len,
+                                     label=self._labels[key]['label']
+                                     ) for key in self._labels.keys()]
+                )
+                break
+            except IOError:
+                self.use_preprocessed = False
+                pass

     def _build_labels(self):
         labeldict = dict()
-        with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
+        with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
             # Exclude the header
             _ = next(f)
             for row in f:
-                if self.setting not in row:
-                    continue
-                filename, label = row.strip().split(',')
-                labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
-        if self.stretch and self.setting == V.DATA_OPTIONS.train:
-            additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
-            additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
-            additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
-            labeldict.update(additional_dict)
+                slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
+                if int(fold) == self.fold:
+                    key = slice_file_name.replace('.wav', '')
+                    labeldict[key] = dict(label=int(class_id), fold=int(fold))

         # Delete File if one exists.
         if not self.use_preprocessed:
             for key in labeldict.keys():
+                for mel_file in self._mel_folder.rglob(f'{key}_*'):
                     try:
-                        (self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink()
+                        mel_file.unlink(missing_ok=True)
                     except FileNotFoundError:
                         pass

         return labeldict

     def __len__(self):
-        return len(self._labels)
+        return len(self._dataset)

-    def _compute_or_retrieve(self, filename):
+    def _pre_process(self):
+        print('Preprocessing Mel Files....')
+        with mp.Pool(processes=self.num_worker) as pool:
+            with tqdm(total=len(self._labels)) as pbar:
+                for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
+                    pbar.update()

-        if not (self._mel_folder / (filename + self.container_ext)).exists():
-            raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
+    def _build_mel(self, filename):
+
+        wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
+        mel_file = list(self._mel_folder.glob(f'{filename}_*'))
+
+        if not mel_file:
+            raw_sample, sr = librosa.core.load(wav_file)
             mel_sample = self._mel_transform(raw_sample)
+            m, n = mel_sample.shape
+            mel_file = self._mel_folder / f'{filename}_{m}_{n}'
             self._mel_folder.mkdir(exist_ok=True, parents=True)
-            with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
+            with mel_file.open(mode='wb') as f:
                 pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
+        else:
+            # print(f"Already existed.. Skipping {filename}")
+            mel_file = mel_file[0]

-        with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
+        with mel_file.open(mode='rb') as f:
             mel_sample = pickle.load(f, fix_imports=True)
-        return mel_sample
+        return mel_sample, mel_file

     def __getitem__(self, item):
+        transformed_samples, label = self._dataset[item]

-        key: str = list(self._labels.keys())[item]
-        filename = key.replace('.wav', '')
-        mel_sample = self._compute_or_retrieve(filename)
-        label = self._labels[key]
-
-        transformed_samples = self._transforms(mel_sample)
-
-        if self.setting != V.DATA_OPTIONS.test:
-            # In test, filenames instead of labels are returned. This is a little hacky though.
-            label = torch.as_tensor(label, dtype=torch.float)
+        label = torch.as_tensor(label, dtype=torch.float)

         return transformed_samples, label
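A minimal usage sketch of the rewritten dataset above. The module path datasets.urban_8k, the plain string 'train' standing in for V.DATA_OPTIONS.train, and the librosa-based mel transform are all assumptions made for illustration; the project wires its own ml_lib transforms in util/module_mixins.py further down.

    import librosa
    import numpy as np
    from datasets.urban_8k import Urban8K   # module path assumed; only the class name comes from the diff

    def mel_transform(y, sr=22050, n_mels=64, hop_length=256):
        # stand-in for the project's mel pipeline: a simple log-mel spectrogram
        mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, hop_length=hop_length)
        return np.log(mel + 1e-7)

    train_set = Urban8K(data_root='data', setting='train', mel_transforms=mel_transform,
                        fold=1, use_preprocessed=True, audio_segment_len=62, audio_hop_len=30)
    sample, label = train_set[0]
    print(train_set.sample_shape, len(train_set), label)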
datasets/urban_8k_torchaudio.py (new file, 140 lines)
@@ -0,0 +1,140 @@
+import pickle
+from pathlib import Path
+import multiprocessing as mp
+
+import librosa as librosa
+from torch.utils.data import Dataset, ConcatDataset
+import torch
+from tqdm import tqdm
+
+import variables as V
+from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
+from ml_lib.modules.util import F_x
+
+
+class Urban8K_TO(Dataset):
+
+    @property
+    def sample_shape(self):
+        return self[0][0].shape
+
+    @property
+    def _fingerprint(self):
+        return str(self._mel_transform)
+
+    def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
+                 use_preprocessed=True, audio_segment_len=1, audio_hop_len=1, num_worker=mp.cpu_count(),
+                 **_):
+        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
+        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
+        assert fold in range(1, 11)
+        super(Urban8K_TO, self).__init__()
+
+        self.data_root = Path(data_root) / 'UrbanSound8K'
+        self.setting = setting
+        self.num_worker = num_worker
+        self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
+        self.use_preprocessed = use_preprocessed
+        self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
+        self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
+        self.container_ext = '.pik'
+        self._mel_transform = mel_transforms
+
+        self._labels = self._build_labels()
+        self._wav_files = list(sorted(self._labels.keys()))
+        transforms = transforms or F_x(in_shape=None)
+
+        param_storage = self._mel_folder / 'data_params.pik'
+        self._mel_folder.mkdir(parents=True, exist_ok=True)
+        try:
+            pik_data = param_storage.read_bytes()
+            fingerprint = pickle.loads(pik_data)
+            if fingerprint == self._fingerprint:
+                self.use_preprocessed = use_preprocessed
+            else:
+                print('Diverging parameters were found; Refreshing...')
+                param_storage.unlink()
+                pik_data = pickle.dumps(self._fingerprint)
+                param_storage.write_bytes(pik_data)
+                self.use_preprocessed = False
+
+        except FileNotFoundError:
+            pik_data = pickle.dumps(self._fingerprint)
+            param_storage.write_bytes(pik_data)
+            self.use_preprocessed = False
+
+
+        while True:
+            if not self.use_preprocessed:
+                self._pre_process()
+            try:
+                self._dataset = ConcatDataset(
+                    [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
+                                     segment_len=audio_segment_len, hop_len=audio_hop_len,
+                                     label=self._labels[key]['label']
+                                     ) for key in self._labels.keys()]
+                )
+                break
+            except IOError:
+                self.use_preprocessed = False
+                pass
+
+    def _build_labels(self):
+        labeldict = dict()
+        with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
+            # Exclude the header
+            _ = next(f)
+            for row in f:
+                slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
+                if int(fold) == self.fold:
+                    key = slice_file_name.replace('.wav', '')
+                    labeldict[key] = dict(label=int(class_id), fold=int(fold))
+
+        # Delete File if one exists.
+        if not self.use_preprocessed:
+            for key in labeldict.keys():
+                for mel_file in self._mel_folder.rglob(f'{key}_*'):
+                    try:
+                        mel_file.unlink(missing_ok=True)
+                    except FileNotFoundError:
+                        pass
+
+        return labeldict
+
+    def __len__(self):
+        return len(self._dataset)
+
+    def _pre_process(self):
+        print('Preprocessing Mel Files....')
+        with mp.Pool(processes=self.num_worker) as pool:
+            with tqdm(total=len(self._labels)) as pbar:
+                for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
+                    pbar.update()
+
+    def _build_mel(self, filename):
+
+        wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
+        mel_file = list(self._mel_folder.glob(f'{filename}_*'))
+
+        if not mel_file:
+            raw_sample, sr = librosa.core.load(wav_file)
+            mel_sample = self._mel_transform(raw_sample)
+            m, n = mel_sample.shape
+            mel_file = self._mel_folder / f'{filename}_{m}_{n}'
+            self._mel_folder.mkdir(exist_ok=True, parents=True)
+            with mel_file.open(mode='wb') as f:
+                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
+        else:
+            # print(f"Already existed.. Skipping {filename}")
+            mel_file = mel_file[0]
+
+        with mel_file.open(mode='rb') as f:
+            mel_sample = pickle.load(f, fix_imports=True)
+        return mel_sample, mel_file
+
+    def __getitem__(self, item):
+        transformed_samples, label = self._dataset[item]
+
+        label = torch.as_tensor(label, dtype=torch.float)
+
+        return transformed_samples, label
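The only functional difference between Urban8K and Urban8K_TO above is the windowing defaults handed to TorchMelDataset (audio_segment_len=62, audio_hop_len=30 versus 1, 1). Assuming TorchMelDataset slices the cached mel spectrogram along the time axis with the usual sliding-window count (the exact behaviour lives in ml_lib and is not shown in this commit), the number of items produced per clip would be roughly:

    # illustrative only; the real formula is implemented in ml_lib.audio_toolset.mel_dataset
    def n_segments(n_frames: int, segment_len: int, hop_len: int) -> int:
        return max(0, (n_frames - segment_len) // hop_len + 1)

    print(n_segments(345, 62, 30))   # ~10 windows for a clip of 345 mel frames
    print(n_segments(345, 1, 1))     # 345 single-frame items with the _TO defaults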
main.py
@@ -82,47 +82,9 @@ def run_lightning_loop(config_obj):
     # Save the last state & all parameters
     trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
     model.save_to_disk(logger.log_dir)
+    # trainer.run_evaluation(test_mode=True)

-    # Evaluate It
-    if config_obj.main.eval:
-        with torch.no_grad():
-            model.eval()
-            if torch.cuda.is_available():
-                model.cuda()
-            outputs = []
-            from tqdm import tqdm
-            for idx, batch in enumerate(tqdm(model.val_dataloader()[0])):
-                batch_x, label = batch
-                batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
-                label = label.to(device='cuda' if model.on_gpu else 'cpu')
-                outputs.append(
-                    model.validation_step((batch_x, label), idx, 1)
-                )
-            model.validation_epoch_end([outputs])
-
-            # trainer.test()
-            outpath = Path(config_obj.train.outpath)
-            model_type = config_obj.model.type
-            parameters = logger.name
-            version = f'version_{logger.version}'
-            inference_out = f'{parameters}_test_out.csv'
-
-            from main_inference import prepare_dataloader
-            import variables as V
-            test_dataloader = prepare_dataloader(config_obj)
-
-            with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
-                outfile.write(f'file_name,prediction\n')
-
-                from tqdm import tqdm
-                for batch in tqdm(test_dataloader, total=len(test_dataloader)):
-                    batch_x, file_names = batch
-                    batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
-                    y = model(batch_x).main_out
-                    predictions = (y >= 0.5).int()
-                    for prediction, file_name in zip(predictions, file_names):
-                        prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
-                        outfile.write(f'{file_name},{prediction_text}\n')
     return model


@@ -5,11 +5,11 @@ from torch.nn import ModuleList

 from ml_lib.modules.blocks import ConvModule, LinearModule
 from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
-from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
+from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
                                 BaseDataloadersMixin)


-class BandwiseConvClassifier(BinaryMaskDatasetMixin,
+class BandwiseConvClassifier(DatasetMixin,
                              BaseDataloadersMixin,
                              BaseTrainMixin,
                              BaseValMixin,
@@ -6,11 +6,11 @@ from torch.nn import ModuleList

 from ml_lib.modules.blocks import ConvModule, LinearModule
 from ml_lib.modules.util import (LightningBaseModule, Splitter)
-from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
+from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
                                 BaseDataloadersMixin)


-class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin,
+class BandwiseConvMultiheadClassifier(DatasetMixin,
                                       BaseDataloadersMixin,
                                       BaseTrainMixin,
                                       BaseValMixin,
@@ -5,11 +5,11 @@ from torch.nn import ModuleList

 from ml_lib.modules.blocks import ConvModule, LinearModule
 from ml_lib.modules.util import LightningBaseModule
-from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
+from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
                                 BaseDataloadersMixin)


-class ConvClassifier(BinaryMaskDatasetMixin,
+class ConvClassifier(DatasetMixin,
                      BaseDataloadersMixin,
                      BaseTrainMixin,
                      BaseValMixin,
@@ -8,11 +8,11 @@ from torch.nn import ModuleList
 from ml_lib.modules.util import LightningBaseModule
 from ml_lib.utils.config import Config
 from ml_lib.utils.model_io import SavedLightningModels
-from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
+from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
                                 BaseDataloadersMixin)


-class Ensemble(BinaryMaskDatasetMixin,
+class Ensemble(DatasetMixin,
               BaseDataloadersMixin,
               BaseTrainMixin,
               BaseValMixin,
@@ -5,11 +5,11 @@ from torch.nn import ModuleList

 from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule
 from ml_lib.modules.util import LightningBaseModule
-from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
+from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
                                 BaseDataloadersMixin)


-class ResidualConvClassifier(BinaryMaskDatasetMixin,
+class ResidualConvClassifier(DatasetMixin,
                              BaseDataloadersMixin,
                              BaseTrainMixin,
                              BaseValMixin,
@@ -9,15 +9,16 @@ from einops import rearrange, repeat

 from ml_lib.modules.blocks import TransformerModule
 from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
-from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
-                                BaseDataloadersMixin)
+from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
+                                BaseDataloadersMixin, BaseTestMixin)

 MIN_NUM_PATCHES = 16

-class VisualTransformer(BinaryMaskDatasetMixin,
+class VisualTransformer(DatasetMixin,
                         BaseDataloadersMixin,
                         BaseTrainMixin,
                         BaseValMixin,
+                        BaseTestMixin,
                         BaseOptimizerMixin,
                         LightningBaseModule
                         ):
models/transformer_model_horizontal.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+from argparse import Namespace
+
+import warnings
+
+import torch
+from torch import nn
+
+from ml_lib.modules.blocks import TransformerModule
+from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
+from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
+                                BaseDataloadersMixin, BaseTestMixin)
+
+MIN_NUM_PATCHES = 16
+
+
+class HorizontalVisualTransformer(DatasetMixin,
+                                  BaseDataloadersMixin,
+                                  BaseTrainMixin,
+                                  BaseValMixin,
+                                  BaseTestMixin,
+                                  BaseOptimizerMixin,
+                                  LightningBaseModule
+                                  ):
+
+    def __init__(self, hparams):
+        super(HorizontalVisualTransformer, self).__init__(hparams)
+
+        # Dataset
+        # =============================================================================
+        self.dataset = self.build_dataset()
+
+        self.in_shape = self.dataset.train_dataset.sample_shape
+        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
+        channels, height, width = self.in_shape
+
+        # Model Paramters
+        # =============================================================================
+        # Additional parameters
+        self.embed_dim = self.params.embedding_size
+        self.patch_size = self.params.patch_size
+        self.height = height
+        self.width = width
+        self.channels = channels
+
+        self.new_height = ((self.height - self.patch_size)//1) + 1
+
+        num_patches = self.new_height - (self.patch_size // 2)
+        patch_dim = channels * self.patch_size * self.width
+        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
+                                               f'attention. Try decreasing your patch size'
+
+        # Correct the Embedding Dim
+        if not self.embed_dim % self.params.heads == 0:
+            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
+            message = ('Embedding Dimension was fixed to be devideable by the number' +
+                       f' of attention heads, is now: {self.embed_dim}')
+            for func in print, warnings.warn:
+                func(message)
+
+        # Utility Modules
+        self.autopad = AutoPadToShape((self.new_height, self.width))
+        self.dropout = nn.Dropout(self.params.dropout)
+        self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.patch_size, self.width),
+                                    keepdim=False)
+
+        # Modules with Parameters
+        self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
+                                             n_heads=self.params.heads, num_layers=self.params.attn_depth,
+                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
+                                             activation=self.params.activation_as_string
+                                             )
+
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
+        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
+            else F_x(self.embed_dim)
+        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
+        self.to_cls_token = nn.Identity()
+
+        self.mlp_head = nn.Sequential(
+            nn.LayerNorm(self.embed_dim),
+            nn.Linear(self.embed_dim, self.params.lat_dim),
+            nn.GELU(),
+            nn.Dropout(self.params.dropout),
+            nn.Linear(self.params.lat_dim, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x, mask=None):
+        """
+        :param x: the sequence to the encoder (required).
+        :param mask: the mask for the src sequence (optional).
+        :return:
+        """
+        tensor = self.autopad(x)
+        tensor = self.slider(tensor)
+
+        tensor = self.patch_to_embedding(tensor)
+        b, n, _ = tensor.shape
+
+        # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+        cls_tokens = self.cls_token.repeat((b, 1, 1))
+
+        tensor = torch.cat((cls_tokens, tensor), dim=1)
+        tensor += self.pos_embedding[:, :(n + 1)]
+        tensor = self.dropout(tensor)
+
+        tensor = self.transformer(tensor, mask)
+
+        tensor = self.to_cls_token(tensor[:, 0])
+        tensor = self.mlp_head(tensor)
+        return Namespace(main_out=tensor)
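The patch bookkeeping in HorizontalVisualTransformer.__init__ above can be checked by hand. A worked example of those formulas, assuming an input mel shape of (1, 64, 62) for (channels, height, width) and the patch_size=7 / embedding_size=7**2 / heads=8 combination swept by multi_run.py further down; the input shape is an assumption, not something fixed by the commit:

    # worked example of the arithmetic in HorizontalVisualTransformer.__init__
    channels, height, width = 1, 64, 62          # assumed sample_shape of the mel segments
    patch_size, embedding_size, heads = 7, 7**2, 8

    new_height = ((height - patch_size) // 1) + 1     # 58 horizontal strips
    num_patches = new_height - (patch_size // 2)      # 55, comfortably above MIN_NUM_PATCHES = 16
    patch_dim = channels * patch_size * width         # 434 values per strip fed to patch_to_embedding

    embed_dim = embedding_size
    if embed_dim % heads != 0:                        # 49 % 8 == 1, so the dimension is rounded down
        embed_dim = (embed_dim // heads) * heads      # -> 48
    print(new_height, num_patches, patch_dim, embed_dim)   # 58 55 434 48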
@@ -7,21 +7,22 @@ from torch import nn

 from ml_lib.modules.blocks import TransformerModule
 from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
-from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
-                                BaseDataloadersMixin)
+from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
+                                BaseDataloadersMixin, BaseTestMixin)

 MIN_NUM_PATCHES = 16

-class SequentialVisualTransformer(BinaryMaskDatasetMixin,
+class VerticalVisualTransformer(DatasetMixin,
                                 BaseDataloadersMixin,
                                 BaseTrainMixin,
                                 BaseValMixin,
+                                BaseTestMixin,
                                 BaseOptimizerMixin,
                                 LightningBaseModule
                                 ):

     def __init__(self, hparams):
-        super(SequentialVisualTransformer, self).__init__(hparams)
+        super(VerticalVisualTransformer, self).__init__(hparams)

         # Dataset
         # =============================================================================
multi_run.py
@@ -4,6 +4,7 @@ from _paramters import main_arg_parser
 from main import run_lightning_loop

 import warnings
+import shutil

 from ml_lib.utils.config import Config

@@ -22,7 +23,7 @@ if __name__ == '__main__':
         arg_dict.update(main_seed=seed)
         if False:
             for patch_size in [3, 5 , 9]:
-                for model in ['SequentialVisualTransformer']:
+                for model in ['VerticalVisualTransformer']:
                     arg_dict.update(model_type=model, model_patch_size=patch_size)
                     raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
                                     data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
@@ -52,12 +53,12 @@ if __name__ == '__main__':

         arg_dict.update(dicts)
         if True:
-            for patch_size in [3, 7]:
-                for lat_dim in [4, 32]:
-                    for heads in [2, 4]:
-                        for embedding_size in [32, 64]:
-                            for attn_depth in [1, 3]:
-                                for model in ['SequentialVisualTransformer', 'VisualTransformer']:
+            for patch_size in [7]:
+                for lat_dim in [32]:
+                    for heads in [8]:
+                        for embedding_size in [7**2]:
+                            for attn_depth in [1, 3, 5, 7]:
+                                for model in ['HorizontalVisualTransformer']:
                                     arg_dict.update(
                                         model_type=model,
                                         model_patch_size=patch_size,
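With the narrowed sweep in the active if True: branch above, each seed now launches one run per combination of the listed values; a quick count, copied from the loops, gives 4 runs per seed, compared with the 2*2*2*2*2*2 = 64 combinations of the previous grid.

    # quick count of the active hyperparameter sweep (values copied from the loops above)
    from itertools import product

    grid = list(product([7], [32], [8], [7**2], [1, 3, 5, 7], ['HorizontalVisualTransformer']))
    print(len(grid))   # 4 combinations per seed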
@@ -121,7 +121,45 @@ class BaseValMixin:
             self.log(key, summary_dict[key])


-class BinaryMaskDatasetMixin:
+class BaseTestMixin:

+    absolute_loss = nn.L1Loss()
+    nll_loss = nn.NLLLoss()
+    bce_loss = nn.BCELoss()
+
+    def test_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
+        assert isinstance(self, LightningBaseModule)
+        batch_x, batch_y = batch_xy
+        y = self(batch_x).main_out
+        test_bce_loss = self.bce_loss(y.squeeze(), batch_y)
+        return dict(test_bce_loss=test_bce_loss,
+                    batch_idx=batch_idx, y=y, batch_y=batch_y)
+
+    def test_epoch_end(self, outputs, *_, **__):
+        assert isinstance(self, LightningBaseModule)
+        summary_dict = dict()
+
+        keys = list(outputs[0].keys())
+        summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
+                                                                    for output in outputs]))
+                             for key in keys if 'loss' in key}
+                            )
+
+        # UnweightedAverageRecall
+        y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
+        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
+
+        y_pred = (y_pred >= 0.5).astype(np.float32)
+
+        uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
+                                                 sample_weight=None, zero_division='warn')
+        uar_score = torch.as_tensor(uar_score)
+        summary_dict.update({f'uar_score': uar_score})
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])
+
+
+class DatasetMixin:

     def build_dataset(self):
         assert isinstance(self, LightningBaseModule)
@@ -159,21 +197,20 @@ class BinaryMaskDatasetMixin:
                                     util_transforms])

         # Datasets
-        from datasets.binar_masks import BinaryMasksDataset
         dataset = Namespace(
             **dict(
                 # TRAIN DATASET
-                train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
+                train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
                                                  use_preprocessed=self.params.use_preprocessed,
                                                  stretch_dataset=self.params.stretch,
                                                  mel_transforms=mel_transforms_train, transforms=aug_transforms),
                 # VALIDATION DATASET
-                val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
+                val_train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
                                                      mel_transforms=mel_transforms, transforms=util_transforms),
-                val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
+                val_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.devel,
                                                mel_transforms=mel_transforms, transforms=util_transforms),
                 # TEST DATASET
-                test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
+                test_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.test,
                                                 mel_transforms=mel_transforms, transforms=util_transforms),
             )
         )
@@ -190,22 +227,23 @@ class BaseDataloadersMixin(ABC):
         # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
         sampler = None
         return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
-                          batch_size=self.params.batch_size,
+                          batch_size=self.params.batch_size, pin_memory=True,
                           num_workers=self.params.worker)

     # Test Dataloader
     def test_dataloader(self):
         assert isinstance(self, LightningBaseModule)
         return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
-                          batch_size=self.params.batch_size,
+                          batch_size=self.params.batch_size, pin_memory=True,
                           num_workers=self.params.worker)

     # Validation Dataloader
     def val_dataloader(self):
         assert isinstance(self, LightningBaseModule)
-        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
+        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False, pin_memory=True,
                                     batch_size=self.params.batch_size, num_workers=self.params.worker)

         train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
+                                      pin_memory=True,
                                       batch_size=self.params.batch_size, shuffle=False)
         return [val_dataloader, train_dataloader]
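BaseTestMixin.test_epoch_end above reports the unweighted average recall (UAR), that is, recall computed per class and then averaged with equal weight. A toy check of the metric, independent of the Lightning plumbing; the numbers are made up for illustration:

    # toy UAR computation mirroring the sklearn call in test_epoch_end
    import numpy as np
    import sklearn.metrics

    y_true = np.array([0, 0, 0, 1, 1])
    y_pred = np.array([0, 1, 0, 1, 1])   # recall: class 0 -> 2/3, class 1 -> 2/2
    uar = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro')
    print(uar)                           # (2/3 + 1) / 2 = 0.8333...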