CCS intergration training running
notebooks
This commit is contained in:
		
							
								
								
									
										0
									
								
								datasets/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								datasets/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -1,172 +1,19 @@ | |||||||
| import multiprocessing as mp | from datasets.compare_base import CompareBase | ||||||
| from collections import defaultdict | from ml_lib.utils.tools import add_argparse_args | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler |  | ||||||
| from torchvision.transforms import Compose, RandomApply |  | ||||||
|  |  | ||||||
| from ml_lib.audio_toolset.audio_io import NormalizeLocal |  | ||||||
| from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset |  | ||||||
| from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug |  | ||||||
| from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel |  | ||||||
| from ml_lib.utils.equal_sampler import EqualSampler |  | ||||||
| from ml_lib.utils.transforms import ToTensor |  | ||||||
|  |  | ||||||
| data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class CCSLibrosaDatamodule(_BaseDataModule): | class CCSLibrosaDatamodule(CompareBase): | ||||||
|  |  | ||||||
|     @property |     class_names = ['negative', 'positive'] | ||||||
|     def class_names(self): |     sub_dataset_name = 'ComParE2021_CCS' | ||||||
|         return {key: val for val, key in enumerate(['negative', 'positive'])} |  | ||||||
|  |  | ||||||
|     @property |     def __init__(self, *args, **kwargs): | ||||||
|     def n_classes(self): |         super(CCSLibrosaDatamodule, self).__init__(*args, **kwargs) | ||||||
|         return len(self.class_names) |  | ||||||
|  |  | ||||||
|     @property |     @classmethod | ||||||
|     def shape(self): |     def add_argparse_args(cls, parent_parser): | ||||||
|         return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape |         return add_argparse_args(CompareBase, parent_parser) | ||||||
|  |  | ||||||
|     @property |     @classmethod | ||||||
|     def mel_folder(self): |     def from_argparse_args(cls, args, **kwargs): | ||||||
|         return self.root / 'mel_folder' |         return CompareBase.from_argparse_args(args, class_names=cls.class_names, sub_dataset_name=cls.sub_dataset_name) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def wav_folder(self): |  | ||||||
|         return self.root / 'wav' |  | ||||||
|  |  | ||||||
|     def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None, |  | ||||||
|                  random_apply_chance=0.5, target_mel_length_in_seconds=1, |  | ||||||
|                  loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3): |  | ||||||
|         super(CCSLibrosaDatamodule, self).__init__() |  | ||||||
|         self.sampler = sampler |  | ||||||
|         self.samplers = None |  | ||||||
|  |  | ||||||
|         self.num_worker = num_worker or 1 |  | ||||||
|         self.batch_size = batch_size |  | ||||||
|         self.root = Path(data_root) / 'ComParE2021_CCS' |  | ||||||
|         self.mel_length_in_seconds = target_mel_length_in_seconds |  | ||||||
|  |  | ||||||
|         # Mel Transforms - will be pushed with all other paramters by self.__dict__ to subdataset-class |  | ||||||
|         self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length) |  | ||||||
|  |  | ||||||
|         # Utility |  | ||||||
|         self.utility_transforms = Compose([NormalizeLocal(), ToTensor()]) |  | ||||||
|  |  | ||||||
|         # Data Augmentations |  | ||||||
|         self.random_apply_chance = random_apply_chance |  | ||||||
|         self.mel_augmentations = Compose([ |  | ||||||
|             RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance), |  | ||||||
|             RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance), |  | ||||||
|             RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance), |  | ||||||
|             RandomApply([MaskAug(mask_ratio)], p=random_apply_chance), |  | ||||||
|             self.utility_transforms]) |  | ||||||
|  |  | ||||||
|     def train_dataloader(self): |  | ||||||
|         return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True, |  | ||||||
|                           sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size) |  | ||||||
|  |  | ||||||
|     # Validation Dataloader |  | ||||||
|     def val_dataloader(self): |  | ||||||
|         return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, |  | ||||||
|                           batch_size=self.batch_size, pin_memory=False, |  | ||||||
|                           num_workers=self.num_worker) |  | ||||||
|  |  | ||||||
|     # Test Dataloader |  | ||||||
|     def test_dataloader(self): |  | ||||||
|         return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False, |  | ||||||
|                           batch_size=self.batch_size, pin_memory=False, |  | ||||||
|                           num_workers=self.num_worker) |  | ||||||
|  |  | ||||||
|     def _build_subdataset(self, row, build=False): |  | ||||||
|         slice_file_name, class_name = row.strip().split(',') |  | ||||||
|         class_id = self.class_names.get(class_name, -1) |  | ||||||
|         audio_file_path = self.wav_folder / slice_file_name |  | ||||||
|  |  | ||||||
|         # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin |  | ||||||
|         kwargs = self.__dict__ |  | ||||||
|         if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]): |  | ||||||
|             kwargs.update(mel_augmentations=self.utility_transforms) |  | ||||||
|         # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End |  | ||||||
|  |  | ||||||
|         target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr'] |  | ||||||
|         sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1 |  | ||||||
|         kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length//2) |  | ||||||
|         mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs) |  | ||||||
|         if build: |  | ||||||
|             assert mel_dataset.build_mel() |  | ||||||
|         return mel_dataset, class_id, slice_file_name |  | ||||||
|  |  | ||||||
|     def prepare_data(self, *args, **kwargs): |  | ||||||
|         datasets = dict() |  | ||||||
|         for data_option in data_options: |  | ||||||
|             with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f: |  | ||||||
|                 # Exclude the header |  | ||||||
|                 _ = next(f) |  | ||||||
|                 all_rows = list(f) |  | ||||||
|             chunksize = len(all_rows) // max(self.num_worker, 1) |  | ||||||
|             dataset = list() |  | ||||||
|             with mp.Pool(processes=self.num_worker) as pool: |  | ||||||
|  |  | ||||||
|                 from itertools import repeat |  | ||||||
|                 results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))), |  | ||||||
|                                              chunksize=chunksize) |  | ||||||
|                 for sub_dataset in results.get(): |  | ||||||
|                     dataset.append(sub_dataset[0]) |  | ||||||
|             datasets[data_option] = ConcatDataset(dataset) |  | ||||||
|             print(f'{data_option}-dataset prepared.') |  | ||||||
|         self.datasets = datasets |  | ||||||
|         return datasets |  | ||||||
|  |  | ||||||
|     def setup(self, stag=None): |  | ||||||
|         datasets = dict() |  | ||||||
|         samplers = dict() |  | ||||||
|         weights = dict() |  | ||||||
|  |  | ||||||
|         for data_option in data_options: |  | ||||||
|             with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f: |  | ||||||
|                 # Exclude the header |  | ||||||
|                 _ = next(f) |  | ||||||
|                 all_rows = list(f) |  | ||||||
|             dataset = list() |  | ||||||
|             for row in all_rows: |  | ||||||
|                 mel_dataset, class_id, _ = self._build_subdataset(row) |  | ||||||
|                 dataset.append(mel_dataset) |  | ||||||
|             print(f'{data_option}-dataset prepared!') |  | ||||||
|             datasets[data_option] = ConcatDataset(dataset) |  | ||||||
|  |  | ||||||
|             # Build Weighted Sampler for train and val |  | ||||||
|             if data_option in [DATA_OPTION_train]: |  | ||||||
|                 if self.sampler == EqualSampler.__name__: |  | ||||||
|                     class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx] |  | ||||||
|                                   for class_idx in range(len(self.class_names)) |  | ||||||
|                                   ] |  | ||||||
|                     samplers[data_option] = EqualSampler(class_idxs) |  | ||||||
|                 elif self.sampler == WeightedRandomSampler.__name__: |  | ||||||
|                     class_counts = defaultdict(lambda: 0) |  | ||||||
|                     for _, __, label in datasets[data_option]: |  | ||||||
|                         class_counts[label] += 1 |  | ||||||
|                     len_largest_class = max(class_counts.values()) |  | ||||||
|  |  | ||||||
|                     weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))] |  | ||||||
|                     ############################################################################## |  | ||||||
|                     weights[data_option] = [weights[data_option][datasets[data_option][i][-1]] |  | ||||||
|                                             for i in range(len(datasets[data_option]))] |  | ||||||
|                     samplers[data_option] = WeightedRandomSampler(weights[data_option], |  | ||||||
|                                                                   len_largest_class * len(self.class_names)) |  | ||||||
|                 else: |  | ||||||
|                     samplers[data_option] = None |  | ||||||
|         self.datasets = datasets |  | ||||||
|         self.samplers = samplers |  | ||||||
|         print(f'Dataset {self.__class__.__name__} setup done.') |  | ||||||
|         return datasets |  | ||||||
|  |  | ||||||
|     def purge(self): |  | ||||||
|         import shutil |  | ||||||
|  |  | ||||||
|         shutil.rmtree(self.mel_folder, ignore_errors=True) |  | ||||||
|         print('Mel Folder has been recursively deleted') |  | ||||||
|         print(f'Folder still exists: {self.mel_folder.exists()}') |  | ||||||
|         return not self.mel_folder.exists() |  | ||||||
|   | |||||||
							
								
								
									
										181
									
								
								datasets/compare_base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										181
									
								
								datasets/compare_base.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,181 @@ | |||||||
|  | import multiprocessing as mp | ||||||
|  | from collections import defaultdict | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler | ||||||
|  | from torchvision.transforms import Compose, RandomApply | ||||||
|  |  | ||||||
|  | from ml_lib.audio_toolset.audio_io import NormalizeLocal | ||||||
|  | from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset | ||||||
|  | from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug | ||||||
|  | from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel | ||||||
|  | from ml_lib.utils.equal_sampler import EqualSampler | ||||||
|  | from ml_lib.utils.tools import add_argparse_args | ||||||
|  | from ml_lib.utils.transforms import ToTensor | ||||||
|  |  | ||||||
|  | data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CompareBase(_BaseDataModule): | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def class_names(self): | ||||||
|  |         return {key: val for val, key in enumerate(self._class_names)} | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def n_classes(self): | ||||||
|  |         return len(self.class_names) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def shape(self): | ||||||
|  |         return 1, int(self.mel_kwargs['n_mels']), int(self.sample_segment_length) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def mel_folder(self): | ||||||
|  |         return self.root / 'mel_folder' | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def wav_folder(self): | ||||||
|  |         return self.root / 'wav' | ||||||
|  |  | ||||||
|  |     def __init__(self, sub_dataset_name, class_names, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None, | ||||||
|  |                  random_apply_chance=0.5, target_mel_length_in_seconds=1, | ||||||
|  |                  loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3): | ||||||
|  |         super(CompareBase, self).__init__() | ||||||
|  |         self.sampler = sampler | ||||||
|  |         self.samplers = None | ||||||
|  |  | ||||||
|  |         self.num_worker = num_worker or 1 | ||||||
|  |         self.batch_size = batch_size | ||||||
|  |         self.root = Path(data_root) / sub_dataset_name | ||||||
|  |         self._class_names = class_names | ||||||
|  |         self.mel_length_in_seconds = target_mel_length_in_seconds | ||||||
|  |  | ||||||
|  |         # Mel Transforms - will be pushed with all other paramters by self.__dict__ to subdataset-class | ||||||
|  |         self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length) | ||||||
|  |  | ||||||
|  |         target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr'] | ||||||
|  |         self.sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1 | ||||||
|  |  | ||||||
|  |         # Utility | ||||||
|  |         self.utility_transforms = Compose([NormalizeLocal(), ToTensor()]) | ||||||
|  |  | ||||||
|  |         # Data Augmentations | ||||||
|  |         self.random_apply_chance = random_apply_chance | ||||||
|  |         self.mel_augmentations = Compose([ | ||||||
|  |             RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance), | ||||||
|  |             RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance), | ||||||
|  |             RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance), | ||||||
|  |             RandomApply([MaskAug(mask_ratio)], p=random_apply_chance), | ||||||
|  |             self.utility_transforms]) | ||||||
|  |  | ||||||
|  |     def train_dataloader(self): | ||||||
|  |         return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True, | ||||||
|  |                           sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size) | ||||||
|  |  | ||||||
|  |     # Validation Dataloader | ||||||
|  |     def val_dataloader(self): | ||||||
|  |         return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, | ||||||
|  |                           batch_size=self.batch_size, pin_memory=False, | ||||||
|  |                           num_workers=self.num_worker) | ||||||
|  |  | ||||||
|  |     # Test Dataloader | ||||||
|  |     def test_dataloader(self): | ||||||
|  |         return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False, | ||||||
|  |                           batch_size=self.batch_size, pin_memory=False, | ||||||
|  |                           num_workers=self.num_worker) | ||||||
|  |  | ||||||
|  |     def _build_subdataset(self, row, build=False): | ||||||
|  |         slice_file_name, class_name = row.strip().split(',') | ||||||
|  |         class_id = self.class_names.get(class_name, -1) | ||||||
|  |         audio_file_path = self.wav_folder / slice_file_name | ||||||
|  |  | ||||||
|  |         # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin | ||||||
|  |         kwargs = self.__dict__ | ||||||
|  |         if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]): | ||||||
|  |             kwargs.update(mel_augmentations=self.utility_transforms) | ||||||
|  |         # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         kwargs.update(sample_segment_len=self.sample_segment_length, sample_hop_len=self.sample_segment_length//2) | ||||||
|  |         mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs) | ||||||
|  |         if build: | ||||||
|  |             assert mel_dataset.build_mel() | ||||||
|  |         return mel_dataset, class_id, slice_file_name | ||||||
|  |  | ||||||
|  |     def manual_setup(self, stag=None): | ||||||
|  |         datasets = dict() | ||||||
|  |         for data_option in data_options: | ||||||
|  |             with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f: | ||||||
|  |                 # Exclude the header | ||||||
|  |                 _ = next(f) | ||||||
|  |                 all_rows = list(f) | ||||||
|  |             chunksize = len(all_rows) // max(self.num_worker, 1) | ||||||
|  |             dataset = list() | ||||||
|  |             with mp.Pool(processes=self.num_worker) as pool: | ||||||
|  |  | ||||||
|  |                 from itertools import repeat | ||||||
|  |                 results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))), | ||||||
|  |                                              chunksize=chunksize) | ||||||
|  |                 for sub_dataset in results.get(): | ||||||
|  |                     dataset.append(sub_dataset[0]) | ||||||
|  |             datasets[data_option] = ConcatDataset(dataset) | ||||||
|  |             print(f'{data_option}-dataset prepared.') | ||||||
|  |         self.datasets = datasets | ||||||
|  |         return datasets | ||||||
|  |  | ||||||
|  |     def prepare_data(self, *args, rebuild=False, **kwargs): | ||||||
|  |         datasets = dict() | ||||||
|  |         samplers = dict() | ||||||
|  |         weights = dict() | ||||||
|  |  | ||||||
|  |         for data_option in data_options: | ||||||
|  |             with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f: | ||||||
|  |                 # Exclude the header | ||||||
|  |                 _ = next(f) | ||||||
|  |                 all_rows = list(f) | ||||||
|  |             chunksize = len(all_rows) // max(self.num_worker, 1) | ||||||
|  |             dataset = list() | ||||||
|  |             with mp.Pool(processes=self.num_worker) as pool: | ||||||
|  |  | ||||||
|  |                 from itertools import repeat | ||||||
|  |                 results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(rebuild, len(all_rows))), | ||||||
|  |                                              chunksize=chunksize) | ||||||
|  |                 for sub_dataset in results.get(): | ||||||
|  |                     dataset.append(sub_dataset[0]) | ||||||
|  |             datasets[data_option] = ConcatDataset(dataset) | ||||||
|  |             print(f'{data_option}-dataset set up!') | ||||||
|  |  | ||||||
|  |             # Build Weighted Sampler for train and val | ||||||
|  |             if data_option in [DATA_OPTION_train]: | ||||||
|  |                 if self.sampler == EqualSampler.__name__: | ||||||
|  |                     class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx] | ||||||
|  |                                   for class_idx in range(len(self.class_names)) | ||||||
|  |                                   ] | ||||||
|  |                     samplers[data_option] = EqualSampler(class_idxs) | ||||||
|  |                 elif self.sampler == WeightedRandomSampler.__name__: | ||||||
|  |                     class_counts = defaultdict(lambda: 0) | ||||||
|  |                     for _, __, label in datasets[data_option]: | ||||||
|  |                         class_counts[label] += 1 | ||||||
|  |                     len_largest_class = max(class_counts.values()) | ||||||
|  |  | ||||||
|  |                     weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))] | ||||||
|  |                     ############################################################################## | ||||||
|  |                     weights[data_option] = [weights[data_option][datasets[data_option][i][-1]] | ||||||
|  |                                             for i in range(len(datasets[data_option]))] | ||||||
|  |                     samplers[data_option] = WeightedRandomSampler(weights[data_option], | ||||||
|  |                                                                   len_largest_class * len(self.class_names)) | ||||||
|  |                 else: | ||||||
|  |                     samplers[data_option] = None | ||||||
|  |         self.datasets = datasets | ||||||
|  |         self.samplers = samplers | ||||||
|  |         print(f'Dataset {self.__class__.__name__} setup done.') | ||||||
|  |         return datasets | ||||||
|  |  | ||||||
|  |     def purge(self): | ||||||
|  |         import shutil | ||||||
|  |  | ||||||
|  |         shutil.rmtree(self.mel_folder, ignore_errors=True) | ||||||
|  |         print('Mel Folder has been recursively deleted') | ||||||
|  |         print(f'Folder still exists: {self.mel_folder.exists()}') | ||||||
|  |         return not self.mel_folder.exists() | ||||||
| @@ -1,170 +1,24 @@ | |||||||
| import multiprocessing as mp | from argparse import ArgumentParser, Namespace | ||||||
| from collections import defaultdict | from ctypes import Union | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler | from datasets.compare_base import CompareBase | ||||||
| from torchvision.transforms import Compose, RandomApply | from ml_lib.utils.tools import add_argparse_args | ||||||
|  |  | ||||||
| from ml_lib.audio_toolset.audio_io import NormalizeLocal |  | ||||||
| from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset |  | ||||||
| from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug |  | ||||||
| from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel |  | ||||||
| from ml_lib.utils.equal_sampler import EqualSampler |  | ||||||
| from ml_lib.utils.transforms import ToTensor |  | ||||||
|  |  | ||||||
| data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class PrimatesLibrosaDatamodule(_BaseDataModule): | class PrimatesLibrosaDatamodule(CompareBase): | ||||||
|  |  | ||||||
|     @property |     class_names = ['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'] | ||||||
|     def class_names(self): |     sub_dataset_name = 'primates' | ||||||
|         return {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])} |  | ||||||
|  |  | ||||||
|     @property |     def __init__(self, *args, **kwargs): | ||||||
|     def n_classes(self): |         super(PrimatesLibrosaDatamodule, self).__init__(*args, **kwargs) | ||||||
|         return len(self.class_names) |  | ||||||
|  |  | ||||||
|     @property |     @classmethod | ||||||
|     def shape(self): |     def add_argparse_args(cls, parent_parser): | ||||||
|  |         return add_argparse_args(CompareBase, parent_parser) | ||||||
|  |  | ||||||
|         return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape |  | ||||||
|  |  | ||||||
|     @property |     @classmethod | ||||||
|     def mel_folder(self): |     def from_argparse_args(cls, args, **kwargs): | ||||||
|         return self.root / 'mel_folder' |         return CompareBase.from_argparse_args(args, class_names=cls.class_names, sub_dataset_name=cls.sub_dataset_name) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def wav_folder(self): |  | ||||||
|         return self.root / 'wav' |  | ||||||
|  |  | ||||||
|     def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None, |  | ||||||
|                  target_mel_length_in_seconds=0.7, random_apply_chance=0.5, |  | ||||||
|                  loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3): |  | ||||||
|         super(PrimatesLibrosaDatamodule, self).__init__() |  | ||||||
|         self.sampler = sampler |  | ||||||
|         self.samplers = None |  | ||||||
|  |  | ||||||
|         self.num_worker = num_worker or 1 |  | ||||||
|         self.batch_size = batch_size |  | ||||||
|         self.root = Path(data_root) / 'primates' |  | ||||||
|         self.target_mel_length_in_seconds = target_mel_length_in_seconds |  | ||||||
|  |  | ||||||
|         # Mel Transforms - will be pushed with all other paramters by self.__dict__ to subdataset-class |  | ||||||
|         self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length) |  | ||||||
|  |  | ||||||
|         # Utility |  | ||||||
|         self.utility_transforms = Compose([NormalizeLocal(), ToTensor()]) |  | ||||||
|  |  | ||||||
|         # Data Augmentations |  | ||||||
|         self.random_apply_chance = random_apply_chance |  | ||||||
|         self.mel_augmentations = Compose([ |  | ||||||
|             RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance), |  | ||||||
|             RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance), |  | ||||||
|             RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance), |  | ||||||
|             RandomApply([MaskAug(mask_ratio)], p=random_apply_chance), |  | ||||||
|             self.utility_transforms]) |  | ||||||
|  |  | ||||||
|     def train_dataloader(self): |  | ||||||
|         return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True, |  | ||||||
|                           sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size) |  | ||||||
|  |  | ||||||
|     # Validation Dataloader |  | ||||||
|     def val_dataloader(self): |  | ||||||
|         return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, |  | ||||||
|                           batch_size=self.batch_size, pin_memory=False, |  | ||||||
|                           num_workers=self.num_worker) |  | ||||||
|  |  | ||||||
|     # Test Dataloader |  | ||||||
|     def test_dataloader(self): |  | ||||||
|         return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False, |  | ||||||
|                           batch_size=self.batch_size, pin_memory=False, |  | ||||||
|                           num_workers=self.num_worker) |  | ||||||
|  |  | ||||||
|     def _build_subdataset(self, row, build=False): |  | ||||||
|         slice_file_name, class_name = row.strip().split(',') |  | ||||||
|         class_id = self.class_names.get(class_name, -1) |  | ||||||
|         audio_file_path = self.wav_folder / slice_file_name |  | ||||||
|  |  | ||||||
|         # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin |  | ||||||
|         kwargs = self.__dict__ |  | ||||||
|         if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]): |  | ||||||
|             kwargs.update(mel_augmentations=self.utility_transforms) |  | ||||||
|         # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End |  | ||||||
|  |  | ||||||
|         target_frames = self.target_mel_length_in_seconds * self.mel_kwargs['sr'] |  | ||||||
|         sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1 |  | ||||||
|         kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length//2) |  | ||||||
|         mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs) |  | ||||||
|         if build: |  | ||||||
|             assert mel_dataset.build_mel() |  | ||||||
|         return mel_dataset, class_id, slice_file_name |  | ||||||
|  |  | ||||||
|     def prepare_data(self, *args, **kwargs): |  | ||||||
|         datasets = dict() |  | ||||||
|         for data_option in data_options: |  | ||||||
|             with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f: |  | ||||||
|                 # Exclude the header |  | ||||||
|                 _ = next(f) |  | ||||||
|                 all_rows = list(f) |  | ||||||
|             chunksize = len(all_rows) // max(self.num_worker, 1) |  | ||||||
|             dataset = list() |  | ||||||
|             with mp.Pool(processes=self.num_worker) as pool: |  | ||||||
|  |  | ||||||
|                 from itertools import repeat |  | ||||||
|                 results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))), |  | ||||||
|                                              chunksize=chunksize) |  | ||||||
|                 for sub_dataset in results.get(): |  | ||||||
|                     dataset.append(sub_dataset[0]) |  | ||||||
|             datasets[data_option] = ConcatDataset(dataset) |  | ||||||
|         self.datasets = datasets |  | ||||||
|         return datasets |  | ||||||
|  |  | ||||||
|     def setup(self, stag=None): |  | ||||||
|         datasets = dict() |  | ||||||
|         samplers = dict() |  | ||||||
|         weights = dict() |  | ||||||
|  |  | ||||||
|         for data_option in data_options: |  | ||||||
|             with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f: |  | ||||||
|                 # Exclude the header |  | ||||||
|                 _ = next(f) |  | ||||||
|                 all_rows = list(f) |  | ||||||
|             dataset = list() |  | ||||||
|             for row in all_rows: |  | ||||||
|                 mel_dataset, class_id, _ = self._build_subdataset(row) |  | ||||||
|                 dataset.append(mel_dataset) |  | ||||||
|             datasets[data_option] = ConcatDataset(dataset) |  | ||||||
|  |  | ||||||
|             # Build Weighted Sampler for train and val |  | ||||||
|             if data_option in [DATA_OPTION_train]: |  | ||||||
|                 if self.sampler == EqualSampler.__name__: |  | ||||||
|                     class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx] |  | ||||||
|                                   for class_idx in range(len(self.class_names)) |  | ||||||
|                                   ] |  | ||||||
|                     samplers[data_option] = EqualSampler(class_idxs) |  | ||||||
|                 elif self.sampler == WeightedRandomSampler.__name__: |  | ||||||
|                     class_counts = defaultdict(lambda: 0) |  | ||||||
|                     for _, __, label in datasets[data_option]: |  | ||||||
|                         class_counts[label] += 1 |  | ||||||
|                     len_largest_class = max(class_counts.values()) |  | ||||||
|  |  | ||||||
|                     weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))] |  | ||||||
|                     ############################################################################## |  | ||||||
|                     weights[data_option] = [weights[data_option][datasets[data_option][i][-1]] |  | ||||||
|                                             for i in range(len(datasets[data_option]))] |  | ||||||
|                     samplers[data_option] = WeightedRandomSampler(weights[data_option], |  | ||||||
|                                                                   len_largest_class * len(self.class_names)) |  | ||||||
|                 else: |  | ||||||
|                     samplers[data_option] = None |  | ||||||
|         self.datasets = datasets |  | ||||||
|         self.samplers = samplers |  | ||||||
|         return datasets |  | ||||||
|  |  | ||||||
|     def purge(self): |  | ||||||
|         import shutil |  | ||||||
|  |  | ||||||
|         shutil.rmtree(self.mel_folder, ignore_errors=True) |  | ||||||
|         print('Mel Folder has been recursively deleted') |  | ||||||
|         print(f'Folder still exists: {self.mel_folder.exists()}') |  | ||||||
|         return not self.mel_folder.exists() |  | ||||||
|   | |||||||
							
								
								
									
										23
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								main.py
									
									
									
									
									
								
							| @@ -2,6 +2,7 @@ from argparse import Namespace | |||||||
|  |  | ||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
|  | import yaml | ||||||
| from pytorch_lightning import Trainer, Callback | from pytorch_lightning import Trainer, Callback | ||||||
| from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint | ||||||
|  |  | ||||||
| @@ -16,7 +17,7 @@ warnings.filterwarnings('ignore', category=FutureWarning) | |||||||
| warnings.filterwarnings('ignore', category=UserWarning) | warnings.filterwarnings('ignore', category=UserWarning) | ||||||
|  |  | ||||||
|  |  | ||||||
| def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_callbacks=None): | def run_lightning_loop(h_params :Namespace, data_class, model_class, seed=69, additional_callbacks=None): | ||||||
|  |  | ||||||
|     fix_all_random_seeds(seed) |     fix_all_random_seeds(seed) | ||||||
|  |  | ||||||
| @@ -54,16 +55,23 @@ def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_ca | |||||||
|         # ============================================================================= |         # ============================================================================= | ||||||
|         # Let Datamodule pull what it wants |         # Let Datamodule pull what it wants | ||||||
|         datamodule = data_class.from_argparse_args(h_params) |         datamodule = data_class.from_argparse_args(h_params) | ||||||
|         datamodule.setup() |  | ||||||
|  |         # Final h_params Setup: | ||||||
|  |         h_params = vars(h_params) | ||||||
|  |         h_params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes) | ||||||
|  |         h_params = Namespace(**h_params) | ||||||
|  |  | ||||||
|         # Let Trainer pull what it wants and add callbacks |         # Let Trainer pull what it wants and add callbacks | ||||||
|         trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=callbacks) |         trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=callbacks) | ||||||
|  |  | ||||||
|         # Let Model pull what it wants |         # Let Model pull what it wants | ||||||
|         model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=datamodule.n_classes) |         model = model_class.from_argparse_args(h_params) | ||||||
|         model.init_weights() |         model.init_weights() | ||||||
|  |  | ||||||
|         # trainer.test(model=model, datamodule=datamodule) |         # Store Model in Object File: | ||||||
|  |         model.save_to_disk(logger.save_dir) | ||||||
|  |         # Store h_params to yaml_file File & Neptune (if available): | ||||||
|  |         logger.log_hyperparams(h_params) | ||||||
|  |  | ||||||
|         trainer.fit(model, datamodule) |         trainer.fit(model, datamodule) | ||||||
|         trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt') |         trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt') | ||||||
| @@ -73,10 +81,9 @@ def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_ca | |||||||
|         except: |         except: | ||||||
|             print('Test did not Suceed!') |             print('Test did not Suceed!') | ||||||
|             pass |             pass | ||||||
|         try: |  | ||||||
|             logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1) |         logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1) | ||||||
|         except: |  | ||||||
|             print('debug max_score_logging') |  | ||||||
|         return score_callback.best_scores['PL_recall_score'] |         return score_callback.best_scores['PL_recall_score'] | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -7,7 +7,7 @@ from torch import nn | |||||||
| from einops import rearrange, repeat | from einops import rearrange, repeat | ||||||
|  |  | ||||||
| from ml_lib.metrics.multi_class_classification import MultiClassScores | from ml_lib.metrics.multi_class_classification import MultiClassScores | ||||||
| from ml_lib.modules.blocks import TransformerModule | from ml_lib.modules.blocks import (TransformerModule, F_x) | ||||||
| from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape) | from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape) | ||||||
| from util.module_mixins import CombinedModelMixins | from util.module_mixins import CombinedModelMixins | ||||||
|  |  | ||||||
| @@ -21,7 +21,8 @@ class VisualTransformer(CombinedModelMixins, | |||||||
|     def __init__(self, in_shape, n_classes, weight_init, activation, |     def __init__(self, in_shape, n_classes, weight_init, activation, | ||||||
|                  embedding_size, heads, attn_depth, patch_size, use_residual, variable_length, |                  embedding_size, heads, attn_depth, patch_size, use_residual, variable_length, | ||||||
|                  use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim, |                  use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim, | ||||||
|                  lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval): |                  lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval, | ||||||
|  |                  return_logits=False): | ||||||
|  |  | ||||||
|         # TODO: Move this to parent class, or make it much easier to access... But How... |         # TODO: Move this to parent class, or make it much easier to access... But How... | ||||||
|         a = dict(locals()) |         a = dict(locals()) | ||||||
| @@ -69,14 +70,20 @@ class VisualTransformer(CombinedModelMixins, | |||||||
|         self.to_cls_token = nn.Identity() |         self.to_cls_token = nn.Identity() | ||||||
|  |  | ||||||
|         logits = self.params.n_classes if self.params.n_classes > 2 else 1 |         logits = self.params.n_classes if self.params.n_classes > 2 else 1 | ||||||
|          |  | ||||||
|  |         if return_logits: | ||||||
|  |             outbound_activation = nn.Identity() | ||||||
|  |         else: | ||||||
|  |             outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid() | ||||||
|  |  | ||||||
|  |  | ||||||
|         self.mlp_head = nn.Sequential( |         self.mlp_head = nn.Sequential( | ||||||
|             nn.LayerNorm(self.embed_dim), |             nn.LayerNorm(self.embed_dim), | ||||||
|             nn.Linear(self.embed_dim, self.params.lat_dim), |             nn.Linear(self.embed_dim, self.params.lat_dim), | ||||||
|             nn.GELU(), |             self.params.activation(), | ||||||
|             nn.Dropout(self.params.dropout), |             nn.Dropout(self.params.dropout), | ||||||
|             nn.Linear(self.params.lat_dim, logits), |             nn.Linear(self.params.lat_dim, logits), | ||||||
|             nn.Softmax() if logits > 1 else nn.Sigmoid() |             outbound_activation | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def forward(self, x, mask=None, return_attn_weights=False): |     def forward(self, x, mask=None, return_attn_weights=False): | ||||||
|   | |||||||
							
								
								
									
										383
									
								
								notebooks/Dataset Analysis.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										383
									
								
								notebooks/Dataset Analysis.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										397
									
								
								notebooks/Train Eval.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										397
									
								
								notebooks/Train Eval.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,397 @@ | |||||||
|  | { | ||||||
|  |  "cells": [ | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 6, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "from collections import defaultdict\n", | ||||||
|  |     "from pathlib import Path\n", | ||||||
|  |     "from natsort import natsorted\n", | ||||||
|  |     "from pytorch_lightning.core.saving import ModelIO\n", | ||||||
|  |     "from ml_lib.utils.model_io import SavedLightningModels\n", | ||||||
|  |     "from ml_lib.utils.tools import locate_and_import_class\n", | ||||||
|  |     "\n", | ||||||
|  |     "import yaml\n", | ||||||
|  |     "\n", | ||||||
|  |     "import numpy as np\n", | ||||||
|  |     "import torch\n", | ||||||
|  |     "import pytorch_lightning as pl\n", | ||||||
|  |     "import librosa\n", | ||||||
|  |     "import pandas as pd\n", | ||||||
|  |     "import variables as v\n", | ||||||
|  |     "import seaborn as sns\n", | ||||||
|  |     "from tqdm import tqdm\n", | ||||||
|  |     "from matplotlib import pyplot as plt" | ||||||
|  |    ], | ||||||
|  |    "metadata": { | ||||||
|  |     "collapsed": false, | ||||||
|  |     "pycharm": { | ||||||
|  |      "name": "#%% Imports go here\n" | ||||||
|  |     } | ||||||
|  |    } | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 12, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "# Settings and Variables\n", | ||||||
|  |     "\n", | ||||||
|  |     "# This Experiment (= Model and Parameter Configuration\n", | ||||||
|  |     "_ROOT = Path('..')\n", | ||||||
|  |     "out_path = Path('..') / Path('output')\n", | ||||||
|  |     "model_name = 'VisualTransformer'\n" | ||||||
|  |    ], | ||||||
|  |    "metadata": { | ||||||
|  |     "collapsed": false, | ||||||
|  |     "pycharm": { | ||||||
|  |      "name": "#%%\n" | ||||||
|  |     } | ||||||
|  |    } | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 42, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "def print_stats(data_option, mean_duration, std_duration, min_duration, max_duration):\n", | ||||||
|  |     "    print(f'For {data_option}; statistics are:')\n", | ||||||
|  |     "    print(f'Scores - mean: {mean_duration:.3f}s\\tstd: {std_duration:.3f}s'\n", | ||||||
|  |     "          f'min: {min_duration:.3f}s\\t max: {max_duration:.3f}s')\n", | ||||||
|  |     "\n", | ||||||
|  |     "def print_metrics(exp_path):\n", | ||||||
|  |     "    print(f'--------------{exp_path.name}------------------')\n", | ||||||
|  |     "    best_scores = []\n", | ||||||
|  |     "    had_errors = []\n", | ||||||
|  |     "    for run_folder in [x for x in exp_path.iterdir() if x.is_dir()]:\n", | ||||||
|  |     "        # model_class = locate_and_import_class(model_name, 'models')\n", | ||||||
|  |     "        # sorted_checkpoints = natsorted(run_folder.glob('*.ckpt'))\n", | ||||||
|  |     "        # model = ModelIO.load_from_checkpoint(str(sorted_checkpoints[0]), strict=True)\n", | ||||||
|  |     "        try:\n", | ||||||
|  |     "            metrics = pd.read_csv(run_folder / 'metrics.csv')\n", | ||||||
|  |     "\n", | ||||||
|  |     "            # Possible keys are:\n", | ||||||
|  |     "            # -- CE - Losses:\n", | ||||||
|  |     "            # val_max_vote_loss, val_mean_vote_loss, mean_val_loss\n", | ||||||
|  |     "            # -- Fallback:\n", | ||||||
|  |     "            # mean_loss,epoch,step,macro_f1_score, macro_roc_auc_ovr, uar_score, micro_f1_score\n", | ||||||
|  |     "            # Pytorch Metrics:\n", | ||||||
|  |     "            # PL_f1_score,PL_accuracy_score_score, PL_fbeta_score,PL_recall_score,PL_precision_score,\n", | ||||||
|  |     "            score = metrics.PL_recall_score[-1]\n", | ||||||
|  |     "            print(f'{exp_path.name} - {run_folder.name}: {score}')\n", | ||||||
|  |     "            best_scores.append(score)\n", | ||||||
|  |     "            had_errors.append(False)\n", | ||||||
|  |     "        except (AttributeError, FileNotFoundError):\n", | ||||||
|  |     "            had_errors.append(True)\n", | ||||||
|  |     "            pass\n", | ||||||
|  |     "    if any(had_errors):\n", | ||||||
|  |     "        return\n", | ||||||
|  |     "    else:\n", | ||||||
|  |     "        print('\\n')\n", | ||||||
|  |     "        stats = np.mean(best_scores), np.std(best_scores), np.min(best_scores), np.max(best_scores)\n", | ||||||
|  |     "        print_stats(exp_path.name, *stats)\n", | ||||||
|  |     "        print('--------------------------------------------')\n" | ||||||
|  |    ], | ||||||
|  |    "metadata": { | ||||||
|  |     "collapsed": false, | ||||||
|  |     "pycharm": { | ||||||
|  |      "name": "#%% Util Functions\n" | ||||||
|  |     } | ||||||
|  |    } | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 32, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "--------------VT_259ee495ee2d2dc0e56bb23d12476f17------------------\n", | ||||||
|  |       "VT_259ee495ee2d2dc0e56bb23d12476f17 - version_1: 0.8403531908988953\n", | ||||||
|  |       "VT_259ee495ee2d2dc0e56bb23d12476f17 - version_3: 0.8312729001045227\n", | ||||||
|  |       "VT_259ee495ee2d2dc0e56bb23d12476f17 - version_0: 0.8342075347900391\n", | ||||||
|  |       "VT_259ee495ee2d2dc0e56bb23d12476f17 - version_5: 0.8459098935127258\n", | ||||||
|  |       "VT_259ee495ee2d2dc0e56bb23d12476f17 - version_2: 0.8468937277793884\n", | ||||||
|  |       "VT_259ee495ee2d2dc0e56bb23d12476f17 - version_4: 0.8404075503349304\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_259ee495ee2d2dc0e56bb23d12476f17; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.840s\tstd: 0.006smin: 0.831s\t max: 0.847s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_012aff7c1c667073aedafcbebfa35ec7------------------\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_6: 0.8637051582336426\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_1: 0.864475429058075\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_3: 0.854859471321106\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_0: 0.8631429672241211\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_8: 0.8484407663345337\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_5: 0.8564963340759277\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_7: 0.8519455194473267\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_2: 0.8683117032051086\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_9: 0.8730489611625671\n", | ||||||
|  |       "VT_012aff7c1c667073aedafcbebfa35ec7 - version_4: 0.8658838272094727\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_012aff7c1c667073aedafcbebfa35ec7; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.861s\tstd: 0.007smin: 0.848s\t max: 0.873s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_fdf2a86085b508c1325b181c830a4cf7------------------\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_6: 0.854997456073761\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_1: 0.8609604835510254\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_3: 0.8558254837989807\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_0: 0.8728921413421631\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_8: 0.8631933927536011\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_5: 0.8612215518951416\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_7: 0.8661960959434509\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_2: 0.8636621832847595\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_9: 0.8614727258682251\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_4: 0.8657329082489014\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_fdf2a86085b508c1325b181c830a4cf7; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.863s\tstd: 0.005smin: 0.855s\t max: 0.873s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_cc64c06847a7ca26f5ea4d465f9cc5bc------------------\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_6: 0.8572231531143188\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_1: 0.8442623615264893\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_3: 0.8498414754867554\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_0: 0.8569087982177734\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_8: 0.8455194234848022\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_5: 0.8435630798339844\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_7: 0.845982551574707\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_2: 0.8571171164512634\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_9: 0.8448543548583984\n", | ||||||
|  |       "VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_4: 0.845399022102356\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_cc64c06847a7ca26f5ea4d465f9cc5bc; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.849s\tstd: 0.005smin: 0.844s\t max: 0.857s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_2c7afd50e127f5a2339db0ddfd6bfd7c------------------\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_6: 0.8630585670471191\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_1: 0.8686699271202087\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_3: 0.8729345798492432\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_0: 0.8636038899421692\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_8: 0.8558077812194824\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_5: 0.8710847496986389\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_7: 0.8619015216827393\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_2: 0.8499867916107178\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_9: 0.8507344722747803\n", | ||||||
|  |       "VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_4: 0.8555077314376831\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_2c7afd50e127f5a2339db0ddfd6bfd7c; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.861s\tstd: 0.008smin: 0.850s\t max: 0.873s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_63b9fee765cdda91756af1f35cd320a3------------------\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_6: 0.8663593530654907\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_1: 0.8519773483276367\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_3: 0.8519774675369263\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_0: 0.8603388071060181\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_8: 0.8614517450332642\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_5: 0.8558711409568787\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_7: 0.8537712097167969\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_2: 0.8558205962181091\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_9: 0.8647329211235046\n", | ||||||
|  |       "VT_63b9fee765cdda91756af1f35cd320a3 - version_4: 0.8546129465103149\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_63b9fee765cdda91756af1f35cd320a3; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.858s\tstd: 0.005smin: 0.852s\t max: 0.866s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_aca900a5b9566af61c91aea6525190e6------------------\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_6: 0.8575441241264343\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_1: 0.8453981280326843\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_3: 0.8621359467506409\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_0: 0.8547767400741577\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_8: 0.8613359928131104\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_5: 0.8667657375335693\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_7: 0.8474754095077515\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_2: 0.8628634214401245\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_9: 0.8585749268531799\n", | ||||||
|  |       "VT_aca900a5b9566af61c91aea6525190e6 - version_4: 0.8380126357078552\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_aca900a5b9566af61c91aea6525190e6; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.855s\tstd: 0.009smin: 0.838s\t max: 0.867s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_fb6b96a190455106d29f0630f002ac6f------------------\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_6: 0.8635155558586121\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_1: 0.8261691927909851\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_3: 0.8444902896881104\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_0: 0.865719735622406\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_8: 0.8533784747123718\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_5: 0.8555656671524048\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_7: 0.837948739528656\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_2: 0.8545827865600586\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_9: 0.8541560769081116\n", | ||||||
|  |       "VT_fb6b96a190455106d29f0630f002ac6f - version_4: 0.85297691822052\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_fb6b96a190455106d29f0630f002ac6f; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.851s\tstd: 0.011smin: 0.826s\t max: 0.866s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_378971720b930050ad7662bb96699e20------------------\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_6: 0.8388294577598572\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_1: 0.8333806395530701\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_3: 0.847841203212738\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_0: 0.8287097811698914\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_8: 0.8436978459358215\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_5: 0.8392724990844727\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_7: 0.8410612344741821\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_2: 0.8407015204429626\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_9: 0.8334627151489258\n", | ||||||
|  |       "VT_378971720b930050ad7662bb96699e20 - version_4: 0.8400266766548157\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_378971720b930050ad7662bb96699e20; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.839s\tstd: 0.005smin: 0.829s\t max: 0.848s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_d55f1492ff29a3cd1026013948ce7fa7------------------\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_6: 0.8385945558547974\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_1: 0.8324360251426697\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_3: 0.8386826515197754\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_0: 0.8366813063621521\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_8: 0.8460721969604492\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_5: 0.8374781608581543\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_7: 0.8320286273956299\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_2: 0.8370164632797241\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_9: 0.8495808839797974\n", | ||||||
|  |       "VT_d55f1492ff29a3cd1026013948ce7fa7 - version_4: 0.8332125544548035\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_d55f1492ff29a3cd1026013948ce7fa7; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.838s\tstd: 0.005smin: 0.832s\t max: 0.850s\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------VT_15cbb349b2b50dbb97beec16af2bedab------------------\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_6: 0.8407894372940063\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_1: 0.836580216884613\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_3: 0.8312996029853821\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_0: 0.8336991667747498\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_8: 0.8231534957885742\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_5: 0.8243923187255859\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_7: 0.8342592120170593\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_2: 0.8349334001541138\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_9: 0.8382810950279236\n", | ||||||
|  |       "VT_15cbb349b2b50dbb97beec16af2bedab - version_4: 0.8381868600845337\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "For VT_15cbb349b2b50dbb97beec16af2bedab; statistics are:\n", | ||||||
|  |       "Scores - mean: 0.834s\tstd: 0.006smin: 0.823s\t max: 0.841s\n", | ||||||
|  |       "--------------------------------------------\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "for model_configuration in [x for x in (out_path / model_name).iterdir() if x.is_dir()]:\n", | ||||||
|  |     "    # Print metrics\n", | ||||||
|  |     "    print_metrics(model_configuration)" | ||||||
|  |    ], | ||||||
|  |    "metadata": { | ||||||
|  |     "collapsed": false, | ||||||
|  |     "pycharm": { | ||||||
|  |      "name": "#%% Mass - Load Model and read Metrics\n" | ||||||
|  |     } | ||||||
|  |    } | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 15, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "--------------VT_fdf2a86085b508c1325b181c830a4cf7------------------\n", | ||||||
|  |       "--------------VT_fdf2a86085b508c1325b181c830a4cf7------------------\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_6: 0.854997456073761\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_1: 0.8609604835510254\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_3: 0.8558254837989807\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_0: 0.8728921413421631\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_8: 0.8631933927536011\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_5: 0.8612215518951416\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_7: 0.8661960959434509\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_2: 0.8636621832847595\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_9: 0.8614727258682251\n", | ||||||
|  |       "VT_fdf2a86085b508c1325b181c830a4cf7 - version_4: 0.8657329082489014\n", | ||||||
|  |       "--------------------------------------------\n", | ||||||
|  |       "--------------------------------------------\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "# fingerprint = '012aff7c1c667073aedafcbebfa35ec7'\n", | ||||||
|  |     "fingerprint = 'fdf2a86085b508c1325b181c830a4cf7'\n", | ||||||
|  |     "exp_name = f'{\"\".join([x for x in model_name if x.isupper()])}_{fingerprint}'\n", | ||||||
|  |     "\n", | ||||||
|  |     "# Print metrics\n", | ||||||
|  |     "print_metrics(out_path/model_name/exp_name)\n", | ||||||
|  |     "\n" | ||||||
|  |    ], | ||||||
|  |    "metadata": { | ||||||
|  |     "collapsed": false, | ||||||
|  |     "pycharm": { | ||||||
|  |      "name": "#%% Single - Load Model and read Metrics\n" | ||||||
|  |     } | ||||||
|  |    } | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 39, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "    filenames  prediction prediction_named\n", | ||||||
|  |       "0  test_00001           1        chimpanze\n", | ||||||
|  |       "1  test_00002           0       background\n", | ||||||
|  |       "2  test_00003           0       background\n", | ||||||
|  |       "3  test_00004           1        chimpanze\n", | ||||||
|  |       "4  test_00005           4           redcap\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "predictions_file = out_path/model_name/'VT_15cbb349b2b50dbb97beec16af2bedab'/'version_9'/'predictions.csv'\n", | ||||||
|  |     "df_predictions = pd.read_csv(predictions_file)\n", | ||||||
|  |     "print(df_predictions.head())\n", | ||||||
|  |     "df_predictions = df_predictions[['filenames', 'prediction_named']]\n", | ||||||
|  |     "df_predictions.columns = ['filename', 'prediction']\n", | ||||||
|  |     "df_predictions['filename'] = df_predictions['filename'] + '.wav'\n", | ||||||
|  |     "predictions_file_new = predictions_file.parent / 'prediction_final.csv'\n", | ||||||
|  |     "df_predictions.to_csv(index=False, path_or_buf=predictions_file_new)\n", | ||||||
|  |     "\n", | ||||||
|  |     "\n" | ||||||
|  |    ], | ||||||
|  |    "metadata": { | ||||||
|  |     "collapsed": false, | ||||||
|  |     "pycharm": { | ||||||
|  |      "name": "#%% Combine Predictions#\n" | ||||||
|  |     } | ||||||
|  |    } | ||||||
|  |   } | ||||||
|  |  ], | ||||||
|  |  "metadata": { | ||||||
|  |   "kernelspec": { | ||||||
|  |    "display_name": "Python 3", | ||||||
|  |    "language": "python", | ||||||
|  |    "name": "python3" | ||||||
|  |   }, | ||||||
|  |   "language_info": { | ||||||
|  |    "codemirror_mode": { | ||||||
|  |     "name": "ipython", | ||||||
|  |     "version": 2 | ||||||
|  |    }, | ||||||
|  |    "file_extension": ".py", | ||||||
|  |    "mimetype": "text/x-python", | ||||||
|  |    "name": "python", | ||||||
|  |    "nbconvert_exporter": "python", | ||||||
|  |    "pygments_lexer": "ipython2", | ||||||
|  |    "version": "2.7.6" | ||||||
|  |   } | ||||||
|  |  }, | ||||||
|  |  "nbformat": 4, | ||||||
|  |  "nbformat_minor": 0 | ||||||
|  | } | ||||||
| @@ -1,104 +0,0 @@ | |||||||
| { |  | ||||||
|  "cells": [ |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": 47, |  | ||||||
|    "metadata": { |  | ||||||
|     "collapsed": true, |  | ||||||
|     "pycharm": { |  | ||||||
|      "name": "#%% IMPORTS\n" |  | ||||||
|     } |  | ||||||
|    }, |  | ||||||
|    "outputs": [], |  | ||||||
|    "source": [ |  | ||||||
|     "from pathlib import Path\n", |  | ||||||
|     "from natsort import natsorted\n", |  | ||||||
|     "from pytorch_lightning.core.saving import *\n", |  | ||||||
|     "from ml_lib.utils.model_io import SavedLightningModels\n" |  | ||||||
|    ] |  | ||||||
|   }, |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": 48, |  | ||||||
|    "outputs": [], |  | ||||||
|    "source": [ |  | ||||||
|     "from ml_lib.utils.tools import locate_and_import_class\n", |  | ||||||
|     "from models.transformer_model import VisualTransformer\n", |  | ||||||
|     "_ROOT = Path('..')\n", |  | ||||||
|     "out_path = 'output'\n", |  | ||||||
|     "model_class = VisualTransformer\n", |  | ||||||
|     "model_name = model_class.name()\n", |  | ||||||
|     "\n", |  | ||||||
|     "exp_name = 'VT_01123c93daaffa92d2ed341bda32426d'\n", |  | ||||||
|     "version = 'version_2'" |  | ||||||
|    ], |  | ||||||
|    "metadata": { |  | ||||||
|     "collapsed": false, |  | ||||||
|     "pycharm": { |  | ||||||
|      "name": "#%%M Path resolving and variables\n" |  | ||||||
|     } |  | ||||||
|    } |  | ||||||
|   }, |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": 50, |  | ||||||
|    "outputs": [ |  | ||||||
|     { |  | ||||||
|      "ename": "ValueError", |  | ||||||
|      "evalue": "When you set `reduce` as 'macro', you have to provide the number of classes.", |  | ||||||
|      "output_type": "error", |  | ||||||
|      "traceback": [ |  | ||||||
|       "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", |  | ||||||
|       "\u001B[1;31mValueError\u001B[0m                                Traceback (most recent call last)", |  | ||||||
|       "\u001B[1;32m<ipython-input-50-0216292a172f>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m      6\u001B[0m \u001B[0madditional_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mvariable_length\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;32mFalse\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mc_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;36m5\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m      7\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 8\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel_class\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mhparams_file\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstr\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mhparams_yaml\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0madditional_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m      9\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n", |  | ||||||
|       "\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36mload_from_checkpoint\u001B[1;34m(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)\u001B[0m\n\u001B[0;32m    154\u001B[0m         \u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mCHECKPOINT_HYPER_PARAMS_KEY\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    155\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 156\u001B[1;33m         \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_load_model_state\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    157\u001B[0m         \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    158\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n", |  | ||||||
|       "\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36m_load_model_state\u001B[1;34m(cls, checkpoint, strict, **cls_kwargs_new)\u001B[0m\n\u001B[0;32m    196\u001B[0m             \u001B[0m_cls_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0mk\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0mk\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32min\u001B[0m \u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mitems\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mk\u001B[0m \u001B[1;32min\u001B[0m \u001B[0mcls_init_args_name\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    197\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 198\u001B[1;33m         \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m**\u001B[0m\u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    199\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    200\u001B[0m         \u001B[1;31m# give model a chance to load something\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", |  | ||||||
|       "\u001B[1;32m~\\projects\\compare_21\\models\\transformer_model.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, in_shape, n_classes, weight_init, activation, embedding_size, heads, attn_depth, patch_size, use_residual, variable_length, use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim, lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval)\u001B[0m\n\u001B[0;32m     27\u001B[0m         \u001B[0ma\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mlocals\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     28\u001B[0m         \u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0ma\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0marg\u001B[0m \u001B[1;32min\u001B[0m \u001B[0minspect\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msignature\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparameters\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mkeys\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0marg\u001B[0m \u001B[1;33m!=\u001B[0m \u001B[1;34m'self'\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 29\u001B[1;33m         \u001B[0msuper\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mVisualTransformer\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m     30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     31\u001B[0m         \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0min_shape\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0min_shape\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", |  | ||||||
|       "\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, model_parameters, weight_init)\u001B[0m\n\u001B[0;32m    112\u001B[0m             \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mModelParameters\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodel_parameters\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    113\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 114\u001B[1;33m             \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mPLMetrics\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mtag\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'PL'\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    115\u001B[0m             \u001B[1;32mpass\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    116\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n", |  | ||||||
|       "\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, n_classes, tag)\u001B[0m\n\u001B[0;32m     30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     31\u001B[0m             \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0maccuracy_score\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mAccuracy\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 32\u001B[1;33m             \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mprecision\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mPrecision\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m     33\u001B[0m             \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mrecall\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mRecall\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     34\u001B[0m             \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mconfusion_matrix\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mConfusionMatrix\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnormalize\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'true'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", |  | ||||||
|       "\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\precision_recall.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, num_classes, threshold, average, multilabel, mdmc_average, ignore_index, top_k, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m    139\u001B[0m             \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34mf\"The `average` has to be one of {allowed_average}, got {average}.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    140\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 141\u001B[1;33m         super().__init__(\n\u001B[0m\u001B[0;32m    142\u001B[0m             \u001B[0mreduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0maverage\u001B[0m \u001B[1;32min\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;34m\"weighted\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m\"none\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32melse\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    143\u001B[0m             \u001B[0mmdmc_reduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mmdmc_average\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", |  | ||||||
|       "\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\stat_scores.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, threshold, top_k, reduce, num_classes, ignore_index, mdmc_reduce, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m    157\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    158\u001B[0m         \u001B[1;32mif\u001B[0m \u001B[0mreduce\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m<\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 159\u001B[1;33m             \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"When you set `reduce` as 'macro', you have to provide the number of classes.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    160\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    161\u001B[0m         \u001B[1;32mif\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mand\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[1;36m0\u001B[0m \u001B[1;33m<=\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;33m<\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", |  | ||||||
|       "\u001B[1;31mValueError\u001B[0m: When you set `reduce` as 'macro', you have to provide the number of classes." |  | ||||||
|      ] |  | ||||||
|     } |  | ||||||
|    ], |  | ||||||
|    "source": [ |  | ||||||
|     "exp_path = _ROOT / out_path / model_name / exp_name / version\n", |  | ||||||
|     "checkpoint = natsorted(exp_path.glob('*.ckpt'))[-1]\n", |  | ||||||
|     "hparams_yaml = next(exp_path.glob('*.yaml'))\n", |  | ||||||
|     "\n", |  | ||||||
|     "hparams_file = load_hparams_from_yaml(hparams_yaml)\n", |  | ||||||
|     "additional_kwargs = dict(variable_length = False, c_classes=5)\n", |  | ||||||
|     "\n", |  | ||||||
|     "model = model_class.load_from_checkpoint(checkpoint, hparams_file=str(hparams_yaml), **additional_kwargs)\n" |  | ||||||
|    ], |  | ||||||
|    "metadata": { |  | ||||||
|     "collapsed": false, |  | ||||||
|     "pycharm": { |  | ||||||
|      "name": "#%%\n" |  | ||||||
|     } |  | ||||||
|    } |  | ||||||
|   } |  | ||||||
|  ], |  | ||||||
|  "metadata": { |  | ||||||
|   "kernelspec": { |  | ||||||
|    "display_name": "Python 3", |  | ||||||
|    "language": "python", |  | ||||||
|    "name": "python3" |  | ||||||
|   }, |  | ||||||
|   "language_info": { |  | ||||||
|    "codemirror_mode": { |  | ||||||
|     "name": "ipython", |  | ||||||
|     "version": 2 |  | ||||||
|    }, |  | ||||||
|    "file_extension": ".py", |  | ||||||
|    "mimetype": "text/x-python", |  | ||||||
|    "name": "python", |  | ||||||
|    "nbconvert_exporter": "python", |  | ||||||
|    "pygments_lexer": "ipython2", |  | ||||||
|    "version": "2.7.6" |  | ||||||
|   } |  | ||||||
|  }, |  | ||||||
|  "nbformat": 4, |  | ||||||
|  "nbformat_minor": 0 |  | ||||||
| } |  | ||||||
							
								
								
									
										247
									
								
								reload model.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										247
									
								
								reload model.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -18,7 +18,7 @@ class TrainMixin: | |||||||
|         batch_files, batch_x, batch_y = batch_xy |         batch_files, batch_x, batch_y = batch_xy | ||||||
|         y = self(batch_x).main_out |         y = self(batch_x).main_out | ||||||
|         if self.params.n_classes <= 2: |         if self.params.n_classes <= 2: | ||||||
|             loss = self.bce_loss(y, batch_y.long()) |             loss = self.bce_loss(y.squeeze().float(), batch_y.float()) | ||||||
|         else: |         else: | ||||||
|             if self.params.loss == 'focal_loss_rob': |             if self.params.loss == 'focal_loss_rob': | ||||||
|                 labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes) |                 labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Steffen
					Steffen