BandwiseBinaryClassifier is no longer work in progress

This commit is contained in:
Si11ium 2020-05-05 10:58:36 +02:00
parent 451f78f820
commit c2860b0aed
2 changed files with 29 additions and 27 deletions

View File

@ -30,7 +30,7 @@ class BinaryMasksDataset(Dataset):
self._labels = self._build_labels() self._labels = self._build_labels()
self._wav_folder = self.data_root / 'wav' self._wav_folder = self.data_root / 'wav'
self._wav_files = list(sorted(self._labels.keys())) self._wav_files = list(sorted(self._labels.keys()))
self._mel_folder = self.data_root / 'transformed' self._transformed_folder = self.data_root / 'transformed'
def _build_labels(self): def _build_labels(self):
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f: with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
@ -51,13 +51,13 @@ class BinaryMasksDataset(Dataset):
key = self._wav_files[item] key = self._wav_files[item]
filename = key[:-4] + '.pik' filename = key[:-4] + '.pik'
if not (self._mel_folder / filename).exists(): if not (self._transformed_folder / filename).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / self._wav_files[item]) raw_sample, sr = librosa.core.load(self._wav_folder / self._wav_files[item])
transformed_sample = self._transforms(raw_sample) transformed_sample = self._transforms(raw_sample)
self._mel_folder.mkdir(exist_ok=True, parents=True) self._transformed_folder.mkdir(exist_ok=True, parents=True)
with (self._mel_folder / filename).open(mode='wb') as f: with (self._transformed_folder / filename).open(mode='wb') as f:
pickle.dump(transformed_sample, f, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(transformed_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
with (self._mel_folder / filename).open(mode='rb') as f: with (self._transformed_folder / filename).open(mode='rb') as f:
sample = pickle.load(f, fix_imports=True) sample = pickle.load(f, fix_imports=True)
label = torch.as_tensor(self._labels[key], dtype=torch.float) label = torch.as_tensor(self._labels[key], dtype=torch.float)
return sample, label return sample, label

View File

@ -13,11 +13,11 @@ from models.module_mixins import BaseOptimizerMixin, BaseTrainMixin, BaseValMixi
class BandwiseBinaryClassifier(BaseModuleMixin_Dataloaders, class BandwiseBinaryClassifier(BaseModuleMixin_Dataloaders,
BaseTrainMixin, BaseTrainMixin,
BaseValMixin, BaseValMixin,
BaseOptimizerMixin, BaseOptimizerMixin,
LightningBaseModule LightningBaseModule
): ):
def __init__(self, hparams): def __init__(self, hparams):
super(BandwiseBinaryClassifier, self).__init__(hparams) super(BandwiseBinaryClassifier, self).__init__(hparams)
@ -25,7 +25,7 @@ class BandwiseBinaryClassifier(BaseModuleMixin_Dataloaders,
# Dataset and Dataloaders # Dataset and Dataloaders
# ============================================================================= # =============================================================================
# Transforms # Transforms
transforms = Compose([AudioToMel(), MelToImage(), ToTensor(), NormalizeLocal()]) transforms = Compose([AudioToMel(n_mels=32), MelToImage(), ToTensor(), NormalizeLocal()])
# Datasets # Datasets
from datasets.binar_masks import BinaryMasksDataset from datasets.binar_masks import BinaryMasksDataset
self.dataset = Namespace( self.dataset = Namespace(
@ -44,33 +44,35 @@ class BandwiseBinaryClassifier(BaseModuleMixin_Dataloaders,
self.criterion = nn.BCELoss() self.criterion = nn.BCELoss()
self.n_band_sections = 5 self.n_band_sections = 5
# Utility Modules # Modules
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections) self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
self.conv_dict = ModuleDict()
# Modules with Parameters self.conv_dict.update({f"conv_1_{band_section}":
modules = {f"conv_1_{band_section}": ConvModule(self.split.shape, self.conv_filters[0], 3, conv_stride=1, **self.params.module_kwargs)
ConvModule(self.in_shape, self.conv_filters[0], 3, conv_stride=2, **self.params.module_kwargs) for band_section in range(self.n_band_sections)}
for band_section in range(self.n_band_sections)} )
self.conv_dict.update({f"conv_2_{band_section}":
modules.update({f"conv_2_{band_section}": ConvModule(self.conv_dict['conv_1_1'].shape, self.conv_filters[1], 3, conv_stride=1,
ConvModule(self.conv_1.shape, self.conv_filters[1], 3, conv_stride=2,
**self.params.module_kwargs) for band_section in range(self.n_band_sections)} **self.params.module_kwargs) for band_section in range(self.n_band_sections)}
) )
modules.update({f"conv_3_{band_section}": self.conv_dict.update({f"conv_3_{band_section}":
ConvModule(self.conv_2.shape, self.conv_filters[2], 3, conv_stride=2, ConvModule(self.conv_dict['conv_2_1'].shape, self.conv_filters[2], 3, conv_stride=1,
**self.params.module_kwargs) **self.params.module_kwargs)
for band_section in range(self.n_band_sections)} for band_section in range(self.n_band_sections)}
) )
self.merge = HorizontalMerger(self.conv_dict['conv_3_1'].shape, self.n_band_sections)
self.flat = Flatten(self.merge.shape)
self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias) self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias) self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias) self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)
# Utility Modules # Utility Modules
self.merge = HorizontalMerger(self.split.shape, self.n_band_sections)
self.conv_dict = ModuleDict(modules=modules)
self.flat = Flatten(self.conv_3.shape)
self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
self.activation = self.params.activation() self.activation = self.params.activation()
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
@ -78,11 +80,11 @@ class BandwiseBinaryClassifier(BaseModuleMixin_Dataloaders,
def forward(self, batch, **kwargs): def forward(self, batch, **kwargs):
tensors = self.split(batch) tensors = self.split(batch)
for idx, tensor in enumerate(tensors): for idx, tensor in enumerate(tensors):
tensor[idx] = self.conv_dict[f"conv_1_{idx}"](tensor) tensors[idx] = self.conv_dict[f"conv_1_{idx}"](tensor)
for idx, tensor in enumerate(tensors): for idx, tensor in enumerate(tensors):
tensor[idx] = self.conv_dict[f"conv_2_{idx}"](tensor) tensors[idx] = self.conv_dict[f"conv_2_{idx}"](tensor)
for idx, tensor in enumerate(tensors): for idx, tensor in enumerate(tensors):
tensor[idx] = self.conv_dict[f"conv_3_{idx}"](tensor) tensors[idx] = self.conv_dict[f"conv_3_{idx}"](tensor)
tensor = self.merge(tensors) tensor = self.merge(tensors)
tensor = self.flat(tensor) tensor = self.flat(tensor)