CCS integration, training running,
notebooks

This commit is contained in:
parent c12f3866c8
commit 82835295a1

datasets/__init__.py (new file, empty)
CCS datamodule (filename not shown)

@@ -1,172 +1,19 @@

The previous 172-line CCSLibrosaDatamodule carried its own copies of the imports, mel
transforms, augmentations, dataloaders, _build_subdataset, prepare_data, setup and purge
logic; all of that now lives in datasets/compare_base.py (shown in full below). The module
shrinks to a thin subclass:

from datasets.compare_base import CompareBase
from ml_lib.utils.tools import add_argparse_args


class CCSLibrosaDatamodule(CompareBase):

    class_names = ['negative', 'positive']
    sub_dataset_name = 'ComParE2021_CCS'

    def __init__(self, *args, **kwargs):
        super(CCSLibrosaDatamodule, self).__init__(*args, **kwargs)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        return add_argparse_args(CompareBase, parent_parser)

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        return CompareBase.from_argparse_args(args, class_names=cls.class_names,
                                              sub_dataset_name=cls.sub_dataset_name)
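A minimal usage sketch of the refactored module; the argument values below are
illustrative stand-ins for what the project's argparse setup would provide, and assume
from_argparse_args filters the namespace the way pytorch-lightning's helper does:

from argparse import Namespace

args = Namespace(data_root='data', batch_size=32, num_worker=4,
                 sr=16000, n_mels=128, n_fft=512, hop_length=256,
                 sampler='WeightedRandomSampler')

# from_argparse_args injects class_names and sub_dataset_name before delegating
dm = CCSLibrosaDatamodule.from_argparse_args(args)
dm.prepare_data()                  # builds the mel sub-datasets and the train sampler
train_loader = dm.train_dataloader()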
datasets/compare_base.py (new file, 181 lines)

@@ -0,0 +1,181 @@
import multiprocessing as mp
from collections import defaultdict
from pathlib import Path

from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply

from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.tools import add_argparse_args
from ml_lib.utils.transforms import ToTensor

data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]


class CompareBase(_BaseDataModule):

    @property
    def class_names(self):
        return {key: val for val, key in enumerate(self._class_names)}

    @property
    def n_classes(self):
        return len(self.class_names)

    @property
    def shape(self):
        return 1, int(self.mel_kwargs['n_mels']), int(self.sample_segment_length)

    @property
    def mel_folder(self):
        return self.root / 'mel_folder'

    @property
    def wav_folder(self):
        return self.root / 'wav'

    def __init__(self, sub_dataset_name, class_names, data_root, batch_size, num_worker,
                 sr, n_mels, n_fft, hop_length, sampler=None,
                 random_apply_chance=0.5, target_mel_length_in_seconds=1,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
        super(CompareBase, self).__init__()
        self.sampler = sampler
        self.samplers = None

        self.num_worker = num_worker or 1
        self.batch_size = batch_size
        self.root = Path(data_root) / sub_dataset_name
        self._class_names = class_names
        self.mel_length_in_seconds = target_mel_length_in_seconds

        # Mel Transforms - will be pushed with all other parameters by self.__dict__ to the sub-dataset class
        self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

        target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
        self.sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1

        # Utility
        self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Data Augmentations
        self.random_apply_chance = random_apply_chance
        self.mel_augmentations = Compose([
            RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
            RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
            RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
            RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
            self.utility_transforms])

    def train_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
                          sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)

    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False,
                          batch_size=self.batch_size, pin_memory=False,
                          num_workers=self.num_worker)

    # Test Dataloader
    def test_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
                          batch_size=self.batch_size, pin_memory=False,
                          num_workers=self.num_worker)

    def _build_subdataset(self, row, build=False):
        slice_file_name, class_name = row.strip().split(',')
        class_id = self.class_names.get(class_name, -1)
        audio_file_path = self.wav_folder / slice_file_name

        # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
        # NOTE: this aliases (not copies) the instance dict, so the update calls
        #       below also write through to the datamodule's own attributes.
        kwargs = self.__dict__
        if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
            kwargs.update(mel_augmentations=self.utility_transforms)
        # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End

        kwargs.update(sample_segment_len=self.sample_segment_length,
                      sample_hop_len=self.sample_segment_length // 2)
        mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
        if build:
            assert mel_dataset.build_mel()
        return mel_dataset, class_id, slice_file_name

    def manual_setup(self, stag=None):
        datasets = dict()
        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)
            chunksize = len(all_rows) // max(self.num_worker, 1)
            dataset = list()
            with mp.Pool(processes=self.num_worker) as pool:
                from itertools import repeat
                results = pool.starmap_async(self._build_subdataset,
                                             zip(all_rows, repeat(True, len(all_rows))),
                                             chunksize=chunksize)
                for sub_dataset in results.get():
                    dataset.append(sub_dataset[0])
            datasets[data_option] = ConcatDataset(dataset)
            print(f'{data_option}-dataset prepared.')
        self.datasets = datasets
        return datasets

    def prepare_data(self, *args, rebuild=False, **kwargs):
        datasets = dict()
        samplers = dict()
        weights = dict()

        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)
            chunksize = len(all_rows) // max(self.num_worker, 1)
            dataset = list()
            with mp.Pool(processes=self.num_worker) as pool:
                from itertools import repeat
                results = pool.starmap_async(self._build_subdataset,
                                             zip(all_rows, repeat(rebuild, len(all_rows))),
                                             chunksize=chunksize)
                for sub_dataset in results.get():
                    dataset.append(sub_dataset[0])
            datasets[data_option] = ConcatDataset(dataset)
            print(f'{data_option}-dataset set up!')

            # Build Weighted Sampler for train and val
            if data_option in [DATA_OPTION_train]:
                if self.sampler == EqualSampler.__name__:
                    class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option])
                                   if label == class_idx]
                                  for class_idx in range(len(self.class_names))
                                  ]
                    samplers[data_option] = EqualSampler(class_idxs)
                elif self.sampler == WeightedRandomSampler.__name__:
                    class_counts = defaultdict(lambda: 0)
                    for _, __, label in datasets[data_option]:
                        class_counts[label] += 1
                    len_largest_class = max(class_counts.values())

                    # inverse class frequency, mapped onto every sample by its label
                    weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
                    weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
                                            for i in range(len(datasets[data_option]))]
                    samplers[data_option] = WeightedRandomSampler(weights[data_option],
                                                                  len_largest_class * len(self.class_names))
                else:
                    samplers[data_option] = None
        self.datasets = datasets
        self.samplers = samplers
        print(f'Dataset {self.__class__.__name__} setup done.')
        return datasets

    def purge(self):
        import shutil

        shutil.rmtree(self.mel_folder, ignore_errors=True)
        print('Mel Folder has been recursively deleted')
        print(f'Folder still exists: {self.mel_folder.exists()}')
        return not self.mel_folder.exists()
Primates datamodule (filename not shown)

@@ -1,170 +1,24 @@

The Primates datamodule gets the same treatment: its removed 170-line body matched the
CCS implementation except for the class list, the 'primates' data root, and a default
target_mel_length_in_seconds of 0.7. The new module:

from datasets.compare_base import CompareBase
from ml_lib.utils.tools import add_argparse_args


class PrimatesLibrosaDatamodule(CompareBase):

    class_names = ['background', 'chimpanze', 'geunon', 'mandrille', 'redcap']
    sub_dataset_name = 'primates'

    def __init__(self, *args, **kwargs):
        super(PrimatesLibrosaDatamodule, self).__init__(*args, **kwargs)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        return add_argparse_args(CompareBase, parent_parser)

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        return CompareBase.from_argparse_args(args, class_names=cls.class_names,
                                              sub_dataset_name=cls.sub_dataset_name)
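The pattern this refactor buys: a further sub-challenge needs little more than two class
attributes. A hypothetical example (the labels and folder name below are made up for
illustration, not part of this commit):

class EscalationLibrosaDatamodule(CompareBase):
    class_names = ['low', 'medium', 'high']       # hypothetical labels
    sub_dataset_name = 'ComParE2021_ESS'          # hypothetical dataset folder

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        return CompareBase.from_argparse_args(args, class_names=cls.class_names,
                                              sub_dataset_name=cls.sub_dataset_name)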
main.py (23 lines changed)

@@ -2,6 +2,7 @@ from argparse import Namespace
 import warnings

+import yaml
 from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

@@ -16,7 +17,7 @@ warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)


-def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_callbacks=None):
+def run_lightning_loop(h_params: Namespace, data_class, model_class, seed=69, additional_callbacks=None):
     fix_all_random_seeds(seed)

@@ -54,16 +55,23 @@ def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_callbacks=None):
     # =============================================================================
     # Let Datamodule pull what it wants
     datamodule = data_class.from_argparse_args(h_params)
-    datamodule.setup()
+
+    # Final h_params Setup:
+    h_params = vars(h_params)
+    h_params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes)
+    h_params = Namespace(**h_params)

     # Let Trainer pull what it wants and add callbacks
     trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=callbacks)

     # Let Model pull what it wants
-    model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=datamodule.n_classes)
+    model = model_class.from_argparse_args(h_params)
     model.init_weights()

-    # trainer.test(model=model, datamodule=datamodule)
+    # Store Model in Object File:
+    model.save_to_disk(logger.save_dir)
+    # Store h_params to yaml file & Neptune (if available):
+    logger.log_hyperparams(h_params)

     trainer.fit(model, datamodule)
     trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')

@@ -73,10 +81,9 @@ def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_callbacks=None):
     except:
         print('Test did not Succeed!')
         pass
-    try:
-        logger.log_metrics(score_callback.best_scores, step=trainer.global_step + 1)
-    except:
-        print('debug max_score_logging')
+
+    logger.log_metrics(score_callback.best_scores, step=trainer.global_step + 1)
+
     return score_callback.best_scores['PL_recall_score']
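The vars()/Namespace round-trip in the hunk above works because vars() returns the
Namespace's own attribute dict; a self-contained sketch (values illustrative):

from argparse import Namespace

h_params = Namespace(lr=1e-3, batch_size=32)      # illustrative hyperparameters
d = vars(h_params)                                # the Namespace's attribute dict itself
d.update(in_shape=(1, 128, 63), n_classes=2)      # inject what the datamodule knows
h_params = Namespace(**d)                         # rebuild, so attribute access keeps working
assert h_params.n_classes == 2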
models/transformer_model.py

@@ -7,7 +7,7 @@ from torch import nn
 from einops import rearrange, repeat

 from ml_lib.metrics.multi_class_classification import MultiClassScores
-from ml_lib.modules.blocks import TransformerModule
+from ml_lib.modules.blocks import (TransformerModule, F_x)
 from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
 from util.module_mixins import CombinedModelMixins

@@ -21,7 +21,8 @@ class VisualTransformer(CombinedModelMixins,
     def __init__(self, in_shape, n_classes, weight_init, activation,
                  embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
                  use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
-                 lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval):
+                 lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
+                 return_logits=False):

         # TODO: Move this to parent class, or make it much easier to access... But How...
         a = dict(locals())

@@ -69,14 +70,20 @@ class VisualTransformer(CombinedModelMixins,
         self.to_cls_token = nn.Identity()

         logits = self.params.n_classes if self.params.n_classes > 2 else 1

+        if return_logits:
+            outbound_activation = nn.Identity()
+        else:
+            outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
+
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(self.embed_dim),
             nn.Linear(self.embed_dim, self.params.lat_dim),
-            nn.GELU(),
+            self.params.activation(),
             nn.Dropout(self.params.dropout),
             nn.Linear(self.params.lat_dim, logits),
-            nn.Softmax() if logits > 1 else nn.Sigmoid()
+            outbound_activation
         )

     def forward(self, x, mask=None, return_attn_weights=False):
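The new return_logits flag lets the head emit raw scores for loss functions that apply
the activation internally; a minimal sketch of that pairing (the linear layer is a
stand-in for the mlp_head, not the project's code):

import torch
from torch import nn

head = nn.Linear(16, 1)               # stand-in for mlp_head with return_logits=True
criterion = nn.BCEWithLogitsLoss()    # folds the sigmoid in, numerically stabler

x = torch.randn(4, 16)
target = torch.ones(4, 1)
loss = criterion(head(x), target)     # raw logits go straight to the loss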
notebooks/Dataset Analysis.ipynb (new file, 383 lines)
File diff suppressed because one or more lines are too long

notebooks/Train Eval.ipynb (new file, 397 lines; cell sources and outputs shown below)

@@ -0,0 +1,397 @@
Cell [6] — #%% Imports go here:

from collections import defaultdict
from pathlib import Path
from natsort import natsorted
from pytorch_lightning.core.saving import ModelIO
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.tools import locate_and_import_class

import yaml

import numpy as np
import torch
import pytorch_lightning as pl
import librosa
import pandas as pd
import variables as v
import seaborn as sns
from tqdm import tqdm
from matplotlib import pyplot as plt

Cell [12] — #%% Settings and Variables:

# Settings and Variables

# This Experiment (= Model and Parameter Configuration)
_ROOT = Path('..')
out_path = Path('..') / Path('output')
model_name = 'VisualTransformer'

Cell [42] — #%% Util Functions:

def print_stats(data_option, mean_duration, std_duration, min_duration, max_duration):
    print(f'For {data_option}; statistics are:')
    # the two adjacent f-strings concatenate with no separator, so "...s" runs
    # straight into "min:" in the printed output below
    print(f'Scores - mean: {mean_duration:.3f}s\tstd: {std_duration:.3f}s'
          f'min: {min_duration:.3f}s\t max: {max_duration:.3f}s')

def print_metrics(exp_path):
    print(f'--------------{exp_path.name}------------------')
    best_scores = []
    had_errors = []
    for run_folder in [x for x in exp_path.iterdir() if x.is_dir()]:
        # model_class = locate_and_import_class(model_name, 'models')
        # sorted_checkpoints = natsorted(run_folder.glob('*.ckpt'))
        # model = ModelIO.load_from_checkpoint(str(sorted_checkpoints[0]), strict=True)
        try:
            metrics = pd.read_csv(run_folder / 'metrics.csv')

            # Possible keys are:
            # -- CE - Losses:
            #    val_max_vote_loss, val_mean_vote_loss, mean_val_loss
            # -- Fallback:
            #    mean_loss, epoch, step, macro_f1_score, macro_roc_auc_ovr, uar_score, micro_f1_score
            # -- Pytorch Metrics:
            #    PL_f1_score, PL_accuracy_score_score, PL_fbeta_score, PL_recall_score, PL_precision_score
            score = metrics.PL_recall_score[-1]
            print(f'{exp_path.name} - {run_folder.name}: {score}')
            best_scores.append(score)
            had_errors.append(False)
        except (AttributeError, FileNotFoundError):
            had_errors.append(True)
            pass
    if any(had_errors):
        return
    else:
        print('\n')
        stats = np.mean(best_scores), np.std(best_scores), np.min(best_scores), np.max(best_scores)
        print_stats(exp_path.name, *stats)
        print('--------------------------------------------')

Cell [32] — #%% Mass - Load Model and read Metrics:

for model_configuration in [x for x in (out_path / model_name).iterdir() if x.is_dir()]:
    # Print metrics
    print_metrics(model_configuration)

Output:
--------------VT_259ee495ee2d2dc0e56bb23d12476f17------------------
VT_259ee495ee2d2dc0e56bb23d12476f17 - version_1: 0.8403531908988953
VT_259ee495ee2d2dc0e56bb23d12476f17 - version_3: 0.8312729001045227
VT_259ee495ee2d2dc0e56bb23d12476f17 - version_0: 0.8342075347900391
VT_259ee495ee2d2dc0e56bb23d12476f17 - version_5: 0.8459098935127258
VT_259ee495ee2d2dc0e56bb23d12476f17 - version_2: 0.8468937277793884
VT_259ee495ee2d2dc0e56bb23d12476f17 - version_4: 0.8404075503349304

For VT_259ee495ee2d2dc0e56bb23d12476f17; statistics are:
Scores - mean: 0.840s   std: 0.006smin: 0.831s   max: 0.847s
--------------------------------------------
--------------VT_012aff7c1c667073aedafcbebfa35ec7------------------
VT_012aff7c1c667073aedafcbebfa35ec7 - version_6: 0.8637051582336426
VT_012aff7c1c667073aedafcbebfa35ec7 - version_1: 0.864475429058075
VT_012aff7c1c667073aedafcbebfa35ec7 - version_3: 0.854859471321106
VT_012aff7c1c667073aedafcbebfa35ec7 - version_0: 0.8631429672241211
VT_012aff7c1c667073aedafcbebfa35ec7 - version_8: 0.8484407663345337
VT_012aff7c1c667073aedafcbebfa35ec7 - version_5: 0.8564963340759277
VT_012aff7c1c667073aedafcbebfa35ec7 - version_7: 0.8519455194473267
VT_012aff7c1c667073aedafcbebfa35ec7 - version_2: 0.8683117032051086
VT_012aff7c1c667073aedafcbebfa35ec7 - version_9: 0.8730489611625671
VT_012aff7c1c667073aedafcbebfa35ec7 - version_4: 0.8658838272094727

For VT_012aff7c1c667073aedafcbebfa35ec7; statistics are:
Scores - mean: 0.861s   std: 0.007smin: 0.848s   max: 0.873s
--------------------------------------------
--------------VT_fdf2a86085b508c1325b181c830a4cf7------------------
VT_fdf2a86085b508c1325b181c830a4cf7 - version_6: 0.854997456073761
VT_fdf2a86085b508c1325b181c830a4cf7 - version_1: 0.8609604835510254
VT_fdf2a86085b508c1325b181c830a4cf7 - version_3: 0.8558254837989807
VT_fdf2a86085b508c1325b181c830a4cf7 - version_0: 0.8728921413421631
VT_fdf2a86085b508c1325b181c830a4cf7 - version_8: 0.8631933927536011
VT_fdf2a86085b508c1325b181c830a4cf7 - version_5: 0.8612215518951416
VT_fdf2a86085b508c1325b181c830a4cf7 - version_7: 0.8661960959434509
VT_fdf2a86085b508c1325b181c830a4cf7 - version_2: 0.8636621832847595
VT_fdf2a86085b508c1325b181c830a4cf7 - version_9: 0.8614727258682251
VT_fdf2a86085b508c1325b181c830a4cf7 - version_4: 0.8657329082489014

For VT_fdf2a86085b508c1325b181c830a4cf7; statistics are:
Scores - mean: 0.863s   std: 0.005smin: 0.855s   max: 0.873s
--------------------------------------------
--------------VT_cc64c06847a7ca26f5ea4d465f9cc5bc------------------
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_6: 0.8572231531143188
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_1: 0.8442623615264893
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_3: 0.8498414754867554
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_0: 0.8569087982177734
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_8: 0.8455194234848022
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_5: 0.8435630798339844
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_7: 0.845982551574707
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_2: 0.8571171164512634
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_9: 0.8448543548583984
VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_4: 0.845399022102356

For VT_cc64c06847a7ca26f5ea4d465f9cc5bc; statistics are:
Scores - mean: 0.849s   std: 0.005smin: 0.844s   max: 0.857s
--------------------------------------------
--------------VT_2c7afd50e127f5a2339db0ddfd6bfd7c------------------
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_6: 0.8630585670471191
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_1: 0.8686699271202087
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_3: 0.8729345798492432
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_0: 0.8636038899421692
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_8: 0.8558077812194824
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_5: 0.8710847496986389
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_7: 0.8619015216827393
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_2: 0.8499867916107178
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_9: 0.8507344722747803
VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_4: 0.8555077314376831

For VT_2c7afd50e127f5a2339db0ddfd6bfd7c; statistics are:
Scores - mean: 0.861s   std: 0.008smin: 0.850s   max: 0.873s
--------------------------------------------
--------------VT_63b9fee765cdda91756af1f35cd320a3------------------
VT_63b9fee765cdda91756af1f35cd320a3 - version_6: 0.8663593530654907
VT_63b9fee765cdda91756af1f35cd320a3 - version_1: 0.8519773483276367
VT_63b9fee765cdda91756af1f35cd320a3 - version_3: 0.8519774675369263
VT_63b9fee765cdda91756af1f35cd320a3 - version_0: 0.8603388071060181
VT_63b9fee765cdda91756af1f35cd320a3 - version_8: 0.8614517450332642
VT_63b9fee765cdda91756af1f35cd320a3 - version_5: 0.8558711409568787
VT_63b9fee765cdda91756af1f35cd320a3 - version_7: 0.8537712097167969
VT_63b9fee765cdda91756af1f35cd320a3 - version_2: 0.8558205962181091
VT_63b9fee765cdda91756af1f35cd320a3 - version_9: 0.8647329211235046
VT_63b9fee765cdda91756af1f35cd320a3 - version_4: 0.8546129465103149

For VT_63b9fee765cdda91756af1f35cd320a3; statistics are:
Scores - mean: 0.858s   std: 0.005smin: 0.852s   max: 0.866s
--------------------------------------------
--------------VT_aca900a5b9566af61c91aea6525190e6------------------
VT_aca900a5b9566af61c91aea6525190e6 - version_6: 0.8575441241264343
VT_aca900a5b9566af61c91aea6525190e6 - version_1: 0.8453981280326843
VT_aca900a5b9566af61c91aea6525190e6 - version_3: 0.8621359467506409
VT_aca900a5b9566af61c91aea6525190e6 - version_0: 0.8547767400741577
VT_aca900a5b9566af61c91aea6525190e6 - version_8: 0.8613359928131104
VT_aca900a5b9566af61c91aea6525190e6 - version_5: 0.8667657375335693
VT_aca900a5b9566af61c91aea6525190e6 - version_7: 0.8474754095077515
VT_aca900a5b9566af61c91aea6525190e6 - version_2: 0.8628634214401245
VT_aca900a5b9566af61c91aea6525190e6 - version_9: 0.8585749268531799
VT_aca900a5b9566af61c91aea6525190e6 - version_4: 0.8380126357078552

For VT_aca900a5b9566af61c91aea6525190e6; statistics are:
Scores - mean: 0.855s   std: 0.009smin: 0.838s   max: 0.867s
--------------------------------------------
--------------VT_fb6b96a190455106d29f0630f002ac6f------------------
VT_fb6b96a190455106d29f0630f002ac6f - version_6: 0.8635155558586121
VT_fb6b96a190455106d29f0630f002ac6f - version_1: 0.8261691927909851
VT_fb6b96a190455106d29f0630f002ac6f - version_3: 0.8444902896881104
VT_fb6b96a190455106d29f0630f002ac6f - version_0: 0.865719735622406
VT_fb6b96a190455106d29f0630f002ac6f - version_8: 0.8533784747123718
VT_fb6b96a190455106d29f0630f002ac6f - version_5: 0.8555656671524048
VT_fb6b96a190455106d29f0630f002ac6f - version_7: 0.837948739528656
VT_fb6b96a190455106d29f0630f002ac6f - version_2: 0.8545827865600586
VT_fb6b96a190455106d29f0630f002ac6f - version_9: 0.8541560769081116
VT_fb6b96a190455106d29f0630f002ac6f - version_4: 0.85297691822052

For VT_fb6b96a190455106d29f0630f002ac6f; statistics are:
Scores - mean: 0.851s   std: 0.011smin: 0.826s   max: 0.866s
--------------------------------------------
--------------VT_378971720b930050ad7662bb96699e20------------------
VT_378971720b930050ad7662bb96699e20 - version_6: 0.8388294577598572
VT_378971720b930050ad7662bb96699e20 - version_1: 0.8333806395530701
VT_378971720b930050ad7662bb96699e20 - version_3: 0.847841203212738
VT_378971720b930050ad7662bb96699e20 - version_0: 0.8287097811698914
VT_378971720b930050ad7662bb96699e20 - version_8: 0.8436978459358215
VT_378971720b930050ad7662bb96699e20 - version_5: 0.8392724990844727
VT_378971720b930050ad7662bb96699e20 - version_7: 0.8410612344741821
VT_378971720b930050ad7662bb96699e20 - version_2: 0.8407015204429626
VT_378971720b930050ad7662bb96699e20 - version_9: 0.8334627151489258
VT_378971720b930050ad7662bb96699e20 - version_4: 0.8400266766548157

For VT_378971720b930050ad7662bb96699e20; statistics are:
Scores - mean: 0.839s   std: 0.005smin: 0.829s   max: 0.848s
--------------------------------------------
--------------VT_d55f1492ff29a3cd1026013948ce7fa7------------------
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_6: 0.8385945558547974
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_1: 0.8324360251426697
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_3: 0.8386826515197754
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_0: 0.8366813063621521
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_8: 0.8460721969604492
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_5: 0.8374781608581543
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_7: 0.8320286273956299
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_2: 0.8370164632797241
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_9: 0.8495808839797974
VT_d55f1492ff29a3cd1026013948ce7fa7 - version_4: 0.8332125544548035

For VT_d55f1492ff29a3cd1026013948ce7fa7; statistics are:
Scores - mean: 0.838s   std: 0.005smin: 0.832s   max: 0.850s
--------------------------------------------
--------------VT_15cbb349b2b50dbb97beec16af2bedab------------------
VT_15cbb349b2b50dbb97beec16af2bedab - version_6: 0.8407894372940063
VT_15cbb349b2b50dbb97beec16af2bedab - version_1: 0.836580216884613
VT_15cbb349b2b50dbb97beec16af2bedab - version_3: 0.8312996029853821
VT_15cbb349b2b50dbb97beec16af2bedab - version_0: 0.8336991667747498
VT_15cbb349b2b50dbb97beec16af2bedab - version_8: 0.8231534957885742
VT_15cbb349b2b50dbb97beec16af2bedab - version_5: 0.8243923187255859
VT_15cbb349b2b50dbb97beec16af2bedab - version_7: 0.8342592120170593
VT_15cbb349b2b50dbb97beec16af2bedab - version_2: 0.8349334001541138
VT_15cbb349b2b50dbb97beec16af2bedab - version_9: 0.8382810950279236
VT_15cbb349b2b50dbb97beec16af2bedab - version_4: 0.8381868600845337

For VT_15cbb349b2b50dbb97beec16af2bedab; statistics are:
Scores - mean: 0.834s   std: 0.006smin: 0.823s   max: 0.841s
--------------------------------------------

Cell [15] — #%% Single - Load Model and read Metrics:

# fingerprint = '012aff7c1c667073aedafcbebfa35ec7'
fingerprint = 'fdf2a86085b508c1325b181c830a4cf7'
exp_name = f'{"".join([x for x in model_name if x.isupper()])}_{fingerprint}'

# Print metrics
print_metrics(out_path / model_name / exp_name)

Output:
--------------VT_fdf2a86085b508c1325b181c830a4cf7------------------
--------------VT_fdf2a86085b508c1325b181c830a4cf7------------------
VT_fdf2a86085b508c1325b181c830a4cf7 - version_6: 0.854997456073761
VT_fdf2a86085b508c1325b181c830a4cf7 - version_1: 0.8609604835510254
VT_fdf2a86085b508c1325b181c830a4cf7 - version_3: 0.8558254837989807
VT_fdf2a86085b508c1325b181c830a4cf7 - version_0: 0.8728921413421631
VT_fdf2a86085b508c1325b181c830a4cf7 - version_8: 0.8631933927536011
VT_fdf2a86085b508c1325b181c830a4cf7 - version_5: 0.8612215518951416
VT_fdf2a86085b508c1325b181c830a4cf7 - version_7: 0.8661960959434509
VT_fdf2a86085b508c1325b181c830a4cf7 - version_2: 0.8636621832847595
VT_fdf2a86085b508c1325b181c830a4cf7 - version_9: 0.8614727258682251
VT_fdf2a86085b508c1325b181c830a4cf7 - version_4: 0.8657329082489014
--------------------------------------------
--------------------------------------------

Cell [39] — #%% Combine Predictions:

predictions_file = out_path / model_name / 'VT_15cbb349b2b50dbb97beec16af2bedab' / 'version_9' / 'predictions.csv'
df_predictions = pd.read_csv(predictions_file)
print(df_predictions.head())
df_predictions = df_predictions[['filenames', 'prediction_named']]
df_predictions.columns = ['filename', 'prediction']
df_predictions['filename'] = df_predictions['filename'] + '.wav'
predictions_file_new = predictions_file.parent / 'prediction_final.csv'
df_predictions.to_csv(index=False, path_or_buf=predictions_file_new)

Output:
    filenames  prediction prediction_named
0  test_00001           1        chimpanze
1  test_00002           0       background
2  test_00003           0       background
3  test_00004           1        chimpanze
4  test_00005           4           redcap

(Notebook metadata: Python 3 kernel.)
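A note on the metrics.PL_recall_score[-1] lookup in the Util Functions cell: Lightning's
CSV logs are typically sparse (many NaNs), and label-based [-1] indexing only works when
the index happens to resolve it. A sketch of a more explicit last-value read, assuming
the same metrics.csv layout:

import pandas as pd

metrics = pd.read_csv('metrics.csv')                  # path is illustrative
score = metrics['PL_recall_score'].dropna().iloc[-1]  # last logged value, position-based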
@ -1,104 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 47,
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": true,
|
|
||||||
"pycharm": {
|
|
||||||
"name": "#%% IMPORTS\n"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"from natsort import natsorted\n",
|
|
||||||
"from pytorch_lightning.core.saving import *\n",
|
|
||||||
"from ml_lib.utils.model_io import SavedLightningModels\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 48,
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from ml_lib.utils.tools import locate_and_import_class\n",
|
|
||||||
"from models.transformer_model import VisualTransformer\n",
|
|
||||||
"_ROOT = Path('..')\n",
|
|
||||||
"out_path = 'output'\n",
|
|
||||||
"model_class = VisualTransformer\n",
|
|
||||||
"model_name = model_class.name()\n",
|
|
||||||
"\n",
|
|
||||||
"exp_name = 'VT_01123c93daaffa92d2ed341bda32426d'\n",
|
|
||||||
"version = 'version_2'"
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%M Path resolving and variables\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 50,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"ename": "ValueError",
|
|
||||||
"evalue": "When you set `reduce` as 'macro', you have to provide the number of classes.",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
|
||||||
"\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)",
|
|
||||||
"\u001B[1;32m<ipython-input-50-0216292a172f>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 6\u001B[0m \u001B[0madditional_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mvariable_length\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;32mFalse\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mc_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;36m5\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 7\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 8\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel_class\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mhparams_file\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstr\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mhparams_yaml\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0madditional_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 9\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
|
|
||||||
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36mload_from_checkpoint\u001B[1;34m(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)\u001B[0m\n\u001B[0;32m 154\u001B[0m \u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mCHECKPOINT_HYPER_PARAMS_KEY\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 155\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 156\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_load_model_state\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 157\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36m_load_model_state\u001B[1;34m(cls, checkpoint, strict, **cls_kwargs_new)\u001B[0m\n\u001B[0;32m 196\u001B[0m \u001B[0m_cls_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0mk\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0mk\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32min\u001B[0m \u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mitems\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mk\u001B[0m \u001B[1;32min\u001B[0m \u001B[0mcls_init_args_name\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 197\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 198\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m**\u001B[0m\u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 199\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 200\u001B[0m \u001B[1;31m# give model a chance to load something\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m~\\projects\\compare_21\\models\\transformer_model.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, in_shape, n_classes, weight_init, activation, embedding_size, heads, attn_depth, patch_size, use_residual, variable_length, use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim, lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval)\u001B[0m\n\u001B[0;32m 27\u001B[0m \u001B[0ma\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mlocals\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 28\u001B[0m \u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0ma\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0marg\u001B[0m \u001B[1;32min\u001B[0m \u001B[0minspect\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msignature\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparameters\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mkeys\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0marg\u001B[0m \u001B[1;33m!=\u001B[0m \u001B[1;34m'self'\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 29\u001B[1;33m \u001B[0msuper\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mVisualTransformer\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 31\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0min_shape\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0min_shape\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, model_parameters, weight_init)\u001B[0m\n\u001B[0;32m 112\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mModelParameters\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodel_parameters\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 113\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 114\u001B[1;33m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mPLMetrics\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mtag\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'PL'\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 115\u001B[0m \u001B[1;32mpass\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 116\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, n_classes, tag)\u001B[0m\n\u001B[0;32m 30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 31\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0maccuracy_score\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mAccuracy\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 32\u001B[1;33m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mprecision\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mPrecision\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 33\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mrecall\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mRecall\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 34\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mconfusion_matrix\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mConfusionMatrix\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnormalize\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'true'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\precision_recall.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, num_classes, threshold, average, multilabel, mdmc_average, ignore_index, top_k, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m 139\u001B[0m \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34mf\"The `average` has to be one of {allowed_average}, got {average}.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 140\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 141\u001B[1;33m super().__init__(\n\u001B[0m\u001B[0;32m 142\u001B[0m \u001B[0mreduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0maverage\u001B[0m \u001B[1;32min\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;34m\"weighted\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m\"none\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32melse\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 143\u001B[0m \u001B[0mmdmc_reduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mmdmc_average\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\stat_scores.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, threshold, top_k, reduce, num_classes, ignore_index, mdmc_reduce, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m 157\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mreduce\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m<\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 159\u001B[1;33m \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"When you set `reduce` as 'macro', you have to provide the number of classes.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 160\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 161\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mand\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[1;36m0\u001B[0m \u001B[1;33m<=\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;33m<\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;31mValueError\u001B[0m: When you set `reduce` as 'macro', you have to provide the number of classes."
]
}
],
"source": [
"exp_path = _ROOT / out_path / model_name / exp_name / version\n",
"checkpoint = natsorted(exp_path.glob('*.ckpt'))[-1]\n",
"hparams_yaml = next(exp_path.glob('*.yaml'))\n",
"\n",
"hparams_file = load_hparams_from_yaml(hparams_yaml)\n",
"additional_kwargs = dict(variable_length = False, c_classes=5)\n",
"\n",
"model = model_class.load_from_checkpoint(checkpoint, hparams_file=str(hparams_yaml), **additional_kwargs)\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
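
The recorded traceback pins the failure precisely: `PLMetrics.__init__` constructs `pl.metrics.Precision(num_classes=self.params.n_classes, average='macro', ...)`, and the `StatScores` base class raises the ValueError because `num_classes` arrives empty. The cell passes `c_classes=5` into `load_from_checkpoint`, which looks like a typo: nothing in the traceback reads `c_classes`, while the metrics read `self.params.n_classes`. A minimal sketch of the likely fix, assuming the hyperparameter is really named `n_classes` (inferred from the traceback, not confirmed elsewhere in this commit):

# Hypothetical fix for the failing cell: supply the class count under the
# name the model actually reads (`n_classes` is inferred from the traceback).
additional_kwargs = dict(variable_length=False, n_classes=5)
model = model_class.load_from_checkpoint(checkpoint,
                                         hparams_file=str(hparams_yaml),
                                         **additional_kwargs)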
247
reload model.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -18,7 +18,7 @@ class TrainMixin:
     batch_files, batch_x, batch_y = batch_xy
     y = self(batch_x).main_out
     if self.params.n_classes <= 2:
-        loss = self.bce_loss(y, batch_y.long())
+        loss = self.bce_loss(y.squeeze().float(), batch_y.float())
     else:
         if self.params.loss == 'focal_loss_rob':
             labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
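
The one changed line in this hunk is a dtype/shape fix: assuming `self.bce_loss` is `torch.nn.BCELoss` (or `BCEWithLogitsLoss`), the loss requires input and target tensors of identical shape and floating-point dtype, so the integer labels from `batch_y.long()` and the trailing singleton dimension on `y` both break it. A minimal sketch of the constraint, with hypothetical shapes:

import torch

bce = torch.nn.BCELoss()
y = torch.sigmoid(torch.randn(8, 1))   # model output: probabilities, shape (8, 1)
batch_y = torch.randint(0, 2, (8,))    # integer class labels, shape (8,)

# bce(y, batch_y.long())  # fails: BCELoss needs matching shapes and float targets
loss = bce(y.squeeze().float(), batch_y.float())  # both (8,), both float32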