Audio Dataset
This commit is contained in:
parent
95561acc35
commit
95dcf22f3d
@ -2,7 +2,8 @@ from argparse import ArgumentParser, Namespace
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
# SECURITY: credentials must never be committed to the repository. The Neptune
# API token is read from the NEPTUNE_API_TOKEN environment variable instead; the
# value is None when the variable is unset.
# NOTE(review): any token that was ever committed here should be revoked.
NEPTUNE_API_KEY = os.getenv('NEPTUNE_API_TOKEN')
|
||||
|
||||
# Parameter Configuration
|
||||
# =============================================================================
|
||||
@ -18,10 +19,10 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="
|
||||
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
||||
|
||||
# Data Parameters
|
||||
main_arg_parser.add_argument("--data_worker", type=int, default=11, help="")
|
||||
main_arg_parser.add_argument("--data_class_name", type=str, default='Urban8K', help="")
|
||||
main_arg_parser.add_argument("--data_worker", type=int, default=6, help="")
|
||||
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||
main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
|
||||
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="")
|
||||
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
|
||||
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
|
||||
main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
|
||||
@ -29,17 +30,17 @@ main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
|
||||
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
|
||||
|
||||
# Transformation Parameters
|
||||
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.3, help="") # 0.2
|
||||
main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7
|
||||
main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7
|
||||
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.3, help="") # 0.2
|
||||
main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7
|
||||
main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7
|
||||
|
||||
# Model Parameters
|
||||
# General
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="SequentialVisualTransformer", help="")
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="HorizontalVisualTransformer", help="")
|
||||
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="gelu", help="")
|
||||
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
|
||||
@ -63,8 +64,8 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
|
||||
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
|
||||
main_arg_parser.add_argument("--train_weight_decay", type=float, default=0, help="")
|
||||
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
|
||||
main_arg_parser.add_argument("--train_lr_warmup_steps", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
|
||||
@ -72,7 +73,7 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0
|
||||
# Project Parameters
|
||||
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
|
||||
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
|
||||
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_API_TOKEN'), help="")
|
||||
main_arg_parser.add_argument("--project_neptune_key", type=str, default=NEPTUNE_API_KEY, help="")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Parse it
|
||||
|
@ -18,15 +18,18 @@ class BinaryMasksDataset(Dataset):
|
||||
def sample_shape(self):
|
||||
return self[0][0].shape
|
||||
|
||||
@property
|
||||
def _fingerprint(self):
|
||||
return dict(**self._mel_kwargs, normalize=self.normalize)
|
||||
|
||||
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
|
||||
use_preprocessed=True):
|
||||
self.use_preprocessed = use_preprocessed
|
||||
self.stretch = stretch_dataset
|
||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
||||
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
|
||||
super(BinaryMasksDataset, self).__init__()
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
self.data_root = Path(data_root) / 'ComParE2020_Mask'
|
||||
self.setting = setting
|
||||
self._wav_folder = self.data_root / 'wav'
|
||||
self._mel_folder = self.data_root / 'mel'
|
||||
@ -37,16 +40,36 @@ class BinaryMasksDataset(Dataset):
|
||||
self._wav_files = list(sorted(self._labels.keys()))
|
||||
self._transforms = transforms or F_x(in_shape=None)
|
||||
|
||||
param_storage = self._mel_folder / 'data_params.pik'
|
||||
self._mel_folder.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
pik_data = param_storage.read_bytes()
|
||||
fingerprint = pickle.loads(pik_data)
|
||||
if fingerprint == self._fingerprint:
|
||||
self.use_preprocessed = use_preprocessed
|
||||
else:
|
||||
print('Diverging parameters were found; Refreshing...')
|
||||
param_storage.unlink()
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
self.use_preprocessed = True
|
||||
|
||||
except FileNotFoundError:
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
self.use_preprocessed = True
|
||||
|
||||
def _build_labels(self):
|
||||
labeldict = dict()
|
||||
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
|
||||
labelfile = 'labels' if self.setting != V.DATA_OPTIONS.test else V.DATA_OPTIONS.test
|
||||
with open(Path(self.data_root) / 'lab' / f'{labelfile}.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
_ = next(f)
|
||||
for row in f:
|
||||
if self.setting not in row:
|
||||
continue
|
||||
filename, label = row.strip().split(',')
|
||||
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
|
||||
labeldict[filename] = self._to_label[label.lower()] # if not self.setting == 'test' else filename
|
||||
if self.stretch and self.setting == V.DATA_OPTIONS.train:
|
||||
additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
|
||||
|
@ -1,95 +1,140 @@
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import multiprocessing as mp
|
||||
|
||||
import librosa as librosa
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import variables as V
|
||||
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
|
||||
from ml_lib.modules.util import F_x
|
||||
|
||||
|
||||
class BinaryMasksDataset(Dataset):
|
||||
_to_label = defaultdict(lambda: -1)
|
||||
_to_label.update(dict(clear=V.CLEAR, mask=V.MASK))
|
||||
class Urban8K(Dataset):
|
||||
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return self[0][0].shape
|
||||
|
||||
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
|
||||
use_preprocessed=True):
|
||||
self.use_preprocessed = use_preprocessed
|
||||
self.stretch = stretch_dataset
|
||||
@property
|
||||
def _fingerprint(self):
|
||||
return str(self._mel_transform)
|
||||
|
||||
def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
|
||||
use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(),
|
||||
**_):
|
||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
||||
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
|
||||
super(BinaryMasksDataset, self).__init__()
|
||||
assert fold in range(1, 11)
|
||||
super(Urban8K, self).__init__()
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
self.data_root = Path(data_root) / 'UrbanSound8K'
|
||||
self.setting = setting
|
||||
self._wav_folder = self.data_root / 'wav'
|
||||
self._mel_folder = self.data_root / 'mel'
|
||||
self.num_worker = num_worker
|
||||
self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
|
||||
self.use_preprocessed = use_preprocessed
|
||||
self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
|
||||
self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
|
||||
self.container_ext = '.pik'
|
||||
self._mel_transform = mel_transforms
|
||||
|
||||
self._labels = self._build_labels()
|
||||
self._wav_files = list(sorted(self._labels.keys()))
|
||||
self._transforms = transforms or F_x(in_shape=None)
|
||||
transforms = transforms or F_x(in_shape=None)
|
||||
|
||||
param_storage = self._mel_folder / 'data_params.pik'
|
||||
self._mel_folder.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
pik_data = param_storage.read_bytes()
|
||||
fingerprint = pickle.loads(pik_data)
|
||||
if fingerprint == self._fingerprint:
|
||||
self.use_preprocessed = use_preprocessed
|
||||
else:
|
||||
print('Diverging parameters were found; Refreshing...')
|
||||
param_storage.unlink()
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
self.use_preprocessed = False
|
||||
|
||||
except FileNotFoundError:
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
self.use_preprocessed = False
|
||||
|
||||
|
||||
while True:
|
||||
if not self.use_preprocessed:
|
||||
self._pre_process()
|
||||
try:
|
||||
self._dataset = ConcatDataset(
|
||||
[TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
|
||||
segment_len=audio_segment_len, hop_len=audio_hop_len,
|
||||
label=self._labels[key]['label']
|
||||
) for key in self._labels.keys()]
|
||||
)
|
||||
break
|
||||
except IOError:
|
||||
self.use_preprocessed = False
|
||||
pass
|
||||
|
||||
def _build_labels(self):
|
||||
labeldict = dict()
|
||||
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
|
||||
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
_ = next(f)
|
||||
for row in f:
|
||||
if self.setting not in row:
|
||||
continue
|
||||
filename, label = row.strip().split(',')
|
||||
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
|
||||
if self.stretch and self.setting == V.DATA_OPTIONS.train:
|
||||
additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
|
||||
labeldict.update(additional_dict)
|
||||
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
|
||||
if int(fold) == self.fold:
|
||||
key = slice_file_name.replace('.wav', '')
|
||||
labeldict[key] = dict(label=int(class_id), fold=int(fold))
|
||||
|
||||
# Delete File if one exists.
|
||||
if not self.use_preprocessed:
|
||||
for key in labeldict.keys():
|
||||
try:
|
||||
(self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
for mel_file in self._mel_folder.rglob(f'{key}_*'):
|
||||
try:
|
||||
mel_file.unlink(missing_ok=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return labeldict
|
||||
|
||||
def __len__(self):
|
||||
return len(self._labels)
|
||||
return len(self._dataset)
|
||||
|
||||
def _compute_or_retrieve(self, filename):
|
||||
def _pre_process(self):
|
||||
print('Preprocessing Mel Files....')
|
||||
with mp.Pool(processes=self.num_worker) as pool:
|
||||
with tqdm(total=len(self._labels)) as pbar:
|
||||
for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
|
||||
pbar.update()
|
||||
|
||||
if not (self._mel_folder / (filename + self.container_ext)).exists():
|
||||
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
|
||||
def _build_mel(self, filename):
|
||||
|
||||
wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
|
||||
mel_file = list(self._mel_folder.glob(f'{filename}_*'))
|
||||
|
||||
if not mel_file:
|
||||
raw_sample, sr = librosa.core.load(wav_file)
|
||||
mel_sample = self._mel_transform(raw_sample)
|
||||
m, n = mel_sample.shape
|
||||
mel_file = self._mel_folder / f'{filename}_{m}_{n}'
|
||||
self._mel_folder.mkdir(exist_ok=True, parents=True)
|
||||
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
|
||||
with mel_file.open(mode='wb') as f:
|
||||
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
else:
|
||||
# print(f"Already existed.. Skipping {filename}")
|
||||
mel_file = mel_file[0]
|
||||
|
||||
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
|
||||
with mel_file.open(mode='rb') as f:
|
||||
mel_sample = pickle.load(f, fix_imports=True)
|
||||
return mel_sample
|
||||
return mel_sample, mel_file
|
||||
|
||||
def __getitem__(self, item):
|
||||
transformed_samples, label = self._dataset[item]
|
||||
|
||||
key: str = list(self._labels.keys())[item]
|
||||
filename = key.replace('.wav', '')
|
||||
mel_sample = self._compute_or_retrieve(filename)
|
||||
label = self._labels[key]
|
||||
|
||||
transformed_samples = self._transforms(mel_sample)
|
||||
|
||||
if self.setting != V.DATA_OPTIONS.test:
|
||||
# In test, filenames instead of labels are returned. This is a little hacky though.
|
||||
label = torch.as_tensor(label, dtype=torch.float)
|
||||
label = torch.as_tensor(label, dtype=torch.float)
|
||||
|
||||
return transformed_samples, label
|
||||
|
140
datasets/urban_8k_torchaudio.py
Normal file
140
datasets/urban_8k_torchaudio.py
Normal file
@ -0,0 +1,140 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import multiprocessing as mp
|
||||
|
||||
import librosa as librosa
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import variables as V
|
||||
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
|
||||
from ml_lib.modules.util import F_x
|
||||
|
||||
|
||||
class Urban8K_TO(Dataset):
    """One fold of UrbanSound8K, served from cached (pickled) mel spectrograms.

    The constructor reads ``metadata/UrbanSound8K.csv`` below
    ``<data_root>/UrbanSound8K``, keeps the rows of the selected fold, makes sure
    a mel file exists for every clip in ``mel/fold<k>`` and finally wraps one
    ``TorchMelDataset`` per clip in a ``ConcatDataset``.

    A fingerprint (``str(mel_transforms)``) is stored next to the cached mels in
    ``data_params.pik``. Whenever the stored fingerprint is missing, unreadable
    or differs from the current one - or the caller passes
    ``use_preprocessed=False`` - the cached mel files of this fold are removed
    and rebuilt, so the cache can never silently contain spectrograms that were
    produced with other parameters.

    NOTE(review): the cache and the fingerprint are read with ``pickle``; only
    point ``data_root`` at directories you trust.
    """

    @property
    def sample_shape(self):
        """Shape of the first (transformed) sample of the dataset."""
        return self[0][0].shape

    @property
    def _fingerprint(self):
        """String that identifies the mel parameters the cache was built with."""
        return str(self._mel_transform)

    def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
                 use_preprocessed=True, audio_segment_len=1, audio_hop_len=1, num_worker=mp.cpu_count(),
                 **_):
        """
        :param data_root: directory that contains the ``UrbanSound8K`` folder.
        :param setting: one of ``V.DATA_OPTIONS``; every setting but train uses fold 10.
        :param mel_transforms: callable that turns a raw waveform into a 2D mel array.
        :param fold: fold (1..10) used when ``setting`` is the train option.
        :param transforms: per-sample transform handed to ``TorchMelDataset``
                           (defaults to ``F_x(in_shape=None)``).
        :param use_preprocessed: reuse cached mel files if their fingerprint matches.
        :param audio_segment_len: ``segment_len`` handed to ``TorchMelDataset``.
        :param audio_hop_len: ``hop_len`` handed to ``TorchMelDataset``.
        :param num_worker: number of processes used to build the mel files.
        :param _: ignored; swallows unrelated keyword arguments.
        :raises IOError: if the dataset still cannot be assembled after the mel
                         files were rebuilt once.
        """
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        assert fold in range(1, 11)
        super(Urban8K_TO, self).__init__()

        self.data_root = Path(data_root) / 'UrbanSound8K'
        self.setting = setting
        self.num_worker = num_worker
        self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
        self.use_preprocessed = use_preprocessed
        self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
        self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
        self.container_ext = '.pik'
        self._mel_transform = mel_transforms

        self._labels = self._build_labels()
        self._wav_files = list(sorted(self._labels.keys()))
        transforms = transforms or F_x(in_shape=None)

        param_storage = self._mel_folder / 'data_params.pik'
        self._mel_folder.mkdir(parents=True, exist_ok=True)

        stored_fingerprint = self._read_fingerprint(param_storage)
        cache_is_current = stored_fingerprint == self._fingerprint
        if not cache_is_current:
            if stored_fingerprint is not None:
                print('Diverging parameters were found; Refreshing...')
            self.use_preprocessed = False

        if not self.use_preprocessed:
            # _build_mel skips clips whose mel file already exists, so stale
            # files have to be removed *after* the fingerprint decision.
            self._purge_cached_mels(self._labels)
        if not cache_is_current:
            param_storage.write_bytes(pickle.dumps(self._fingerprint))

        def assemble():
            return ConcatDataset(
                [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
                                 segment_len=audio_segment_len, hop_len=audio_hop_len,
                                 label=self._labels[key]['label']
                                 ) for key in self._labels.keys()]
            )

        if not self.use_preprocessed:
            self._pre_process()
        try:
            self._dataset = assemble()
        except IOError:
            # The cache is incomplete: build what is missing and retry exactly
            # once. A second failure is propagated instead of looping forever.
            self.use_preprocessed = False
            self._pre_process()
            self._dataset = assemble()

    @staticmethod
    def _read_fingerprint(param_storage):
        """Return the fingerprint stored at ``param_storage``; None if missing or unreadable."""
        try:
            return pickle.loads(param_storage.read_bytes())
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            return None

    def _purge_cached_mels(self, keys):
        """Delete every cached mel file (``<key>_*``) that belongs to one of ``keys``."""
        for key in keys:
            for mel_file in self._mel_folder.rglob(f'{key}_*'):
                try:
                    mel_file.unlink()
                except FileNotFoundError:
                    # Already removed, e.g. by a concurrent run.
                    pass

    def _build_labels(self):
        """Map ``slice_file_name`` (without ``.wav``) to ``dict(label=..., fold=...)``.

        Only rows whose fold equals ``self.fold`` are kept; blank lines are skipped.
        """
        import csv

        labeldict = dict()
        with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r', newline='') as f:
            reader = csv.reader(f)
            # Exclude the header
            next(reader, None)
            for row in reader:
                if not row:
                    continue
                slice_file_name, _fs_id, _start, _end, _salience, fold, class_id, _class_name = (
                    cell.strip() for cell in row
                )
                if int(fold) == self.fold:
                    key = slice_file_name.replace('.wav', '')
                    labeldict[key] = dict(label=int(class_id), fold=int(fold))
        return labeldict

    def __len__(self):
        return len(self._dataset)

    def _pre_process(self):
        """Build the missing mel file of every labelled clip in a process pool."""
        print('Preprocessing Mel Files....')
        with mp.Pool(processes=self.num_worker) as pool:
            with tqdm(total=len(self._labels)) as pbar:
                for _ in pool.imap_unordered(self._build_mel, self._labels.keys()):
                    pbar.update()

    def _build_mel(self, filename):
        """Return ``(mel_sample, mel_file)`` for ``filename``, computing and caching it if needed.

        The mel shape is encoded in the cache file name: ``<filename>_<m>_<n>``.
        """
        wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
        mel_file = next(iter(self._mel_folder.glob(f'{filename}_*')), None)

        if mel_file is None:
            raw_sample, sr = librosa.core.load(wav_file)
            mel_sample = self._mel_transform(raw_sample)
            m, n = mel_sample.shape
            mel_file = self._mel_folder / f'{filename}_{m}_{n}'
            self._mel_folder.mkdir(exist_ok=True, parents=True)
            with mel_file.open(mode='wb') as f:
                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # Already cached: load instead of recomputing.
            with mel_file.open(mode='rb') as f:
                mel_sample = pickle.load(f, fix_imports=True)
        return mel_sample, mel_file

    def __getitem__(self, item):
        """Return ``(transformed_samples, label)`` with the label as a float tensor."""
        transformed_samples, label = self._dataset[item]

        label = torch.as_tensor(label, dtype=torch.float)

        return transformed_samples, label
|
40
main.py
40
main.py
@ -82,47 +82,9 @@ def run_lightning_loop(config_obj):
|
||||
# Save the last state & all parameters
|
||||
trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
|
||||
model.save_to_disk(logger.log_dir)
|
||||
# trainer.run_evaluation(test_mode=True)
|
||||
|
||||
# Evaluate It
|
||||
if config_obj.main.eval:
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
outputs = []
|
||||
from tqdm import tqdm
|
||||
for idx, batch in enumerate(tqdm(model.val_dataloader()[0])):
|
||||
batch_x, label = batch
|
||||
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
|
||||
label = label.to(device='cuda' if model.on_gpu else 'cpu')
|
||||
outputs.append(
|
||||
model.validation_step((batch_x, label), idx, 1)
|
||||
)
|
||||
model.validation_epoch_end([outputs])
|
||||
|
||||
# trainer.test()
|
||||
outpath = Path(config_obj.train.outpath)
|
||||
model_type = config_obj.model.type
|
||||
parameters = logger.name
|
||||
version = f'version_{logger.version}'
|
||||
inference_out = f'{parameters}_test_out.csv'
|
||||
|
||||
from main_inference import prepare_dataloader
|
||||
import variables as V
|
||||
test_dataloader = prepare_dataloader(config_obj)
|
||||
|
||||
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
|
||||
outfile.write(f'file_name,prediction\n')
|
||||
|
||||
from tqdm import tqdm
|
||||
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
|
||||
batch_x, file_names = batch
|
||||
batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
|
||||
y = model(batch_x).main_out
|
||||
predictions = (y >= 0.5).int()
|
||||
for prediction, file_name in zip(predictions, file_names):
|
||||
prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
|
||||
outfile.write(f'{file_name},{prediction_text}\n')
|
||||
return model
|
||||
|
||||
|
||||
|
@ -5,11 +5,11 @@ from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class BandwiseConvClassifier(BinaryMaskDatasetMixin,
|
||||
class BandwiseConvClassifier(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
@ -6,11 +6,11 @@ from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, Splitter)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin,
|
||||
class BandwiseConvMultiheadClassifier(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
@ -5,11 +5,11 @@ from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class ConvClassifier(BinaryMaskDatasetMixin,
|
||||
class ConvClassifier(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
@ -8,11 +8,11 @@ from torch.nn import ModuleList
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.config import Config
|
||||
from ml_lib.utils.model_io import SavedLightningModels
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class Ensemble(BinaryMaskDatasetMixin,
|
||||
class Ensemble(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
@ -5,11 +5,11 @@ from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
||||
class ResidualConvClassifier(BinaryMaskDatasetMixin,
|
||||
class ResidualConvClassifier(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
|
@ -9,15 +9,16 @@ from einops import rearrange, repeat
|
||||
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin, BaseTestMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class VisualTransformer(BinaryMaskDatasetMixin,
|
||||
class VisualTransformer(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
BaseTestMixin,
|
||||
BaseOptimizerMixin,
|
||||
LightningBaseModule
|
||||
):
|
||||
|
111
models/transformer_model_horizontal.py
Normal file
111
models/transformer_model_horizontal.py
Normal file
@ -0,0 +1,111 @@
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin, BaseTestMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class HorizontalVisualTransformer(DatasetMixin,
                                  BaseDataloadersMixin,
                                  BaseTrainMixin,
                                  BaseValMixin,
                                  BaseTestMixin,
                                  BaseOptimizerMixin,
                                  LightningBaseModule
                                  ):
    """Transformer classifier over full-width patches cut from the input by a sliding window.

    Each patch spans ``patch_size`` rows and the whole input width. The patches
    are embedded linearly, prefixed with a learned class token, enriched with a
    learned positional embedding and passed through a ``TransformerModule``; an
    MLP head with a sigmoid maps the class token to a single output value.
    """

    def __init__(self, hparams):
        super(HorizontalVisualTransformer, self).__init__(hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset()

        self.in_shape = self.dataset.train_dataset.sample_shape
        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        self.embed_dim = self.params.embedding_size
        self.patch_size = self.params.patch_size
        self.height = height
        self.width = width
        self.channels = channels

        self.new_height = ((self.height - self.patch_size) // 1) + 1

        num_patches = self.new_height - (self.patch_size // 2)
        patch_dim = channels * self.patch_size * self.width
        assert num_patches >= MIN_NUM_PATCHES, (
            f'your number of patches ({num_patches}) is way too small for '
            'attention. Try decreasing your patch size'
        )

        # The embedding size has to be a multiple of the number of attention
        # heads; round it down if it is not and tell the user about it.
        remainder = self.embed_dim % self.params.heads
        if remainder != 0:
            self.embed_dim = self.embed_dim - remainder
            message = (f'Embedding Dimension was fixed to be devideable by the number'
                       f' of attention heads, is now: {self.embed_dim}')
            print(message)
            warnings.warn(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.new_height, self.width))
        self.dropout = nn.Dropout(self.params.dropout)
        window_in_shape = (channels, *self.autopad.target_shape)
        window_size = (self.patch_size, self.width)
        self.slider = SlidingWindow(window_in_shape, window_size, keepdim=False)

        # Modules with Parameters
        self.transformer = TransformerModule(
            in_shape=self.embed_dim,
            hidden_size=self.params.lat_dim,
            n_heads=self.params.heads,
            num_layers=self.params.attn_depth,
            dropout=self.params.dropout,
            use_norm=self.params.use_norm,
            activation=self.params.activation_as_string,
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        if self.params.embedding_size:
            self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim)
        else:
            self.patch_to_embedding = F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.to_cls_token = nn.Identity()

        head_layers = [
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, 1),
            nn.Sigmoid(),
        ]
        self.mlp_head = nn.Sequential(*head_layers)

    def forward(self, x, mask=None):
        """
        :param x: the sequence to the encoder (required).
        :param mask: the mask for the src sequence (optional).
        :return: Namespace whose ``main_out`` holds the output of the MLP head.
        """
        patches = self.slider(self.autopad(x))
        tokens = self.patch_to_embedding(patches)
        batch_size, seq_len, _ = tokens.shape

        # Prepend one copy of the class token to every sequence of the batch.
        cls_tokens = self.cls_token.repeat(batch_size, 1, 1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens += self.pos_embedding[:, :(seq_len + 1)]
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens, mask)

        # Only the class token is used for the prediction.
        cls_out = self.to_cls_token(tokens[:, 0])
        return Namespace(main_out=self.mlp_head(cls_out))
|
@ -7,21 +7,22 @@ from torch import nn
|
||||
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
|
||||
BaseDataloadersMixin, BaseTestMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class SequentialVisualTransformer(BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
BaseOptimizerMixin,
|
||||
LightningBaseModule
|
||||
):
|
||||
class VerticalVisualTransformer(DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
BaseTestMixin,
|
||||
BaseOptimizerMixin,
|
||||
LightningBaseModule
|
||||
):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(SequentialVisualTransformer, self).__init__(hparams)
|
||||
super(VerticalVisualTransformer, self).__init__(hparams)
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
15
multi_run.py
15
multi_run.py
@ -4,6 +4,7 @@ from _paramters import main_arg_parser
|
||||
from main import run_lightning_loop
|
||||
|
||||
import warnings
|
||||
import shutil
|
||||
|
||||
from ml_lib.utils.config import Config
|
||||
|
||||
@ -22,7 +23,7 @@ if __name__ == '__main__':
|
||||
arg_dict.update(main_seed=seed)
|
||||
if False:
|
||||
for patch_size in [3, 5 , 9]:
|
||||
for model in ['SequentialVisualTransformer']:
|
||||
for model in ['VerticalVisualTransformer']:
|
||||
arg_dict.update(model_type=model, model_patch_size=patch_size)
|
||||
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
|
||||
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
|
||||
@ -52,12 +53,12 @@ if __name__ == '__main__':
|
||||
|
||||
arg_dict.update(dicts)
|
||||
if True:
|
||||
for patch_size in [3, 7]:
|
||||
for lat_dim in [4, 32]:
|
||||
for heads in [2, 4]:
|
||||
for embedding_size in [32, 64]:
|
||||
for attn_depth in [1, 3]:
|
||||
for model in ['SequentialVisualTransformer', 'VisualTransformer']:
|
||||
for patch_size in [7]:
|
||||
for lat_dim in [32]:
|
||||
for heads in [8]:
|
||||
for embedding_size in [7**2]:
|
||||
for attn_depth in [1, 3, 5, 7]:
|
||||
for model in ['HorizontalVisualTransformer']:
|
||||
arg_dict.update(
|
||||
model_type=model,
|
||||
model_patch_size=patch_size,
|
||||
|
@ -103,9 +103,9 @@ class BaseValMixin:
|
||||
keys = list(output[0].keys())
|
||||
ident = '' if output_idx == 0 else '_train'
|
||||
summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in output]))
|
||||
for key in keys if 'loss' in key}
|
||||
)
|
||||
for output in output]))
|
||||
for key in keys if 'loss' in key}
|
||||
)
|
||||
|
||||
# UnweightedAverageRecall
|
||||
y_true = torch.cat([output['batch_y'] for output in output]) .cpu().numpy()
|
||||
@ -121,7 +121,45 @@ class BaseValMixin:
|
||||
self.log(key, summary_dict[key])
|
||||
|
||||
|
||||
class BinaryMaskDatasetMixin:
|
||||
class BaseTestMixin:
    """Mixin supplying the Lightning test hooks for a model with a single probability output.

    ``test_step`` computes the binary cross-entropy per batch; ``test_epoch_end`` logs the mean
    of every ``*loss*`` entry plus the unweighted average recall over the whole test set.
    """

    # Loss modules are class attributes; the hooks below only use ``bce_loss``.
    absolute_loss = nn.L1Loss()
    nll_loss = nn.NLLLoss()
    bce_loss = nn.BCELoss()

    def test_step(self, batch_xy, batch_idx, dataloader_idx=0, *args, **kwargs):
        """Evaluate one test batch.

        :param batch_xy: tuple ``(batch_x, batch_y)`` as yielded by the test dataloader.
        :param batch_idx: index of the batch; returned unchanged in the result dict.
        :param dataloader_idx: index of the dataloader. Optional, because Lightning only passes
                               it when more than one test dataloader is configured.
        :return: dict with ``test_bce_loss``, ``batch_idx``, the raw prediction ``y`` and ``batch_y``.
        """
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        # Reshape to the target's shape instead of a bare squeeze(): squeeze() also drops the
        # batch dimension when the batch holds one sample, which BCELoss rejects.
        test_bce_loss = self.bce_loss(y.reshape(batch_y.shape), batch_y)
        return dict(test_bce_loss=test_bce_loss,
                    batch_idx=batch_idx, y=y, batch_y=batch_y)

    def test_epoch_end(self, outputs, *_, **__):
        """Aggregate all ``test_step`` results and log the summary metrics.

        :param outputs: list of the dicts returned by ``test_step``.
        """
        assert isinstance(self, LightningBaseModule)
        summary_dict = dict()

        # Average every entry whose key mentions 'loss' over all batches.
        keys = list(outputs[0].keys())
        summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                             for key in keys if 'loss' in key})

        # UnweightedAverageRecall: macro-averaged recall over the labels 0 and 1.
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        # reshape(-1) keeps a 1-d array even if the whole test set is a single sample.
        y_pred = torch.cat([output['y'] for output in outputs]).reshape(-1).cpu().numpy()

        # Threshold the predicted probabilities into hard 0/1 labels.
        y_pred = (y_pred >= 0.5).astype(np.float32)

        uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                 sample_weight=None, zero_division='warn')
        uar_score = torch.as_tensor(uar_score)
        summary_dict.update({'uar_score': uar_score})
        for key, value in summary_dict.items():
            self.log(key, value)
|
||||
|
||||
|
||||
class DatasetMixin:
|
||||
|
||||
def build_dataset(self):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
@ -159,21 +197,20 @@ class BinaryMaskDatasetMixin:
|
||||
util_transforms])
|
||||
|
||||
# Datasets
|
||||
from datasets.binar_masks import BinaryMasksDataset
|
||||
dataset = Namespace(
|
||||
**dict(
|
||||
# TRAIN DATASET
|
||||
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
|
||||
train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
|
||||
use_preprocessed=self.params.use_preprocessed,
|
||||
stretch_dataset=self.params.stretch,
|
||||
mel_transforms=mel_transforms_train, transforms=aug_transforms),
|
||||
# VALIDATION DATASET
|
||||
val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
|
||||
val_train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
|
||||
mel_transforms=mel_transforms, transforms=util_transforms),
|
||||
val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
|
||||
val_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.devel,
|
||||
mel_transforms=mel_transforms, transforms=util_transforms),
|
||||
# TEST DATASET
|
||||
test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
|
||||
test_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.test,
|
||||
mel_transforms=mel_transforms, transforms=util_transforms),
|
||||
)
|
||||
)
|
||||
@ -190,22 +227,23 @@ class BaseDataloadersMixin(ABC):
|
||||
# sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
|
||||
sampler = None
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
|
||||
batch_size=self.params.batch_size,
|
||||
batch_size=self.params.batch_size, pin_memory=True,
|
||||
num_workers=self.params.worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
    """Return the loader over ``self.dataset.test_dataset``.

    Order is kept fixed (``shuffle=False``); batch size and worker count come from
    ``self.params``. ``pin_memory=True`` matches the other loaders of this mixin.
    """
    assert isinstance(self, LightningBaseModule)
    return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
                      batch_size=self.params.batch_size, pin_memory=True,
                      num_workers=self.params.worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
    """Return the two loaders used during validation.

    :return: list ``[val_dataloader, train_dataloader]`` - index 0 iterates
             ``self.dataset.val_dataset``, index 1 iterates ``self.dataset.val_train_dataset``.
             Both are unshuffled and share batch size and worker count from ``self.params``.
    """
    assert isinstance(self, LightningBaseModule)
    val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False, pin_memory=True,
                                batch_size=self.params.batch_size, num_workers=self.params.worker)

    train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
                                  pin_memory=True,
                                  batch_size=self.params.batch_size, shuffle=False)
    return [val_dataloader, train_dataloader]
|
||||
|
Loading…
x
Reference in New Issue
Block a user