LinearModule

This commit is contained in:
Si11ium
2020-05-09 21:56:58 +02:00
parent 5e6b0e598f
commit 3fbc98dfa3
7 changed files with 145 additions and 138 deletions

View File

@@ -14,6 +14,7 @@ from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.transforms import ToTensor
import variables as V
@@ -22,6 +23,7 @@ import variables as V
class BaseOptimizerMixin:
def configure_optimizers(self):
assert isinstance(self, LightningBaseModule)
opt = Adam(params=self.parameters(), lr=self.params.lr)
if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
@@ -33,7 +35,7 @@ class BaseOptimizerMixin:
opt.swap_swa_sgd()
def on_epoch_end(self):
if False: # FIXME: Pass a new parameter to model args.
if self.params.opt_reset_interval:
if self.current_epoch % self.params.opt_reset_interval == 0:
for opt in self.trainer.optimizers:
opt.state = defaultdict(dict)
@@ -42,6 +44,7 @@ class BaseOptimizerMixin:
class BaseTrainMixin:
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
loss = self.criterion(y, batch_y)
@@ -60,7 +63,7 @@ class BaseValMixin:
absolute_loss = L1Loss()
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
val_bce_loss = self.criterion(y, batch_y)
@@ -69,52 +72,63 @@ class BaseValMixin:
batch_idx=batch_idx, y=y, batch_y=batch_y
)
def validation_epoch_end(self, outputs):
keys = list(outputs[0].keys())
def validation_epoch_end(self, outputs, *args, **kwargs):
summary_dict = dict(log=dict())
for output_idx, output in enumerate(outputs):
keys = list(output[0].keys())
ident = '' if output_idx == 0 else '_train'
summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
for output in output]))
for key in keys if 'loss' in key}
)
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
# UnweightedAverageRecall
y_true = torch.cat([output['batch_y'] for output in output]) .cpu().numpy()
y_pred = torch.cat([output['y'] for output in output]).squeeze().cpu().numpy()
# UnweightedAverageRecall
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
y_pred = (y_pred >= 0.5).astype(np.float32)
y_pred = (y_pred >= 0.5).astype(np.float32)
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
summary_dict['log'].update(uar_score=uar_score)
summary_dict['log'].update({f'uar{ident}_score': uar_score})
return summary_dict
class BinaryMaskDatasetFunction:
def build_dataset(self):
assert isinstance(self, LightningBaseModule)
# Dataset
# =============================================================================
# Mel Transforms
mel_transforms = Compose([
# Audio to Mel Transformations
AudioToMel(n_mels=self.params.n_mels), MelToImage()])
AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
hop_length=self.params.hop_length), MelToImage()])
# Data Augmentations
aug_transforms = Compose([
RandomApply([
NoiseInjection(self.params.noise_ratio),
LoudnessManipulator(self.params.loudness_ratio),
ShiftTime(self.params.shift_ratio)], p=0.5),
NoiseInjection(self.params.noise_ratio),
LoudnessManipulator(self.params.loudness_ratio),
ShiftTime(self.params.shift_ratio)], p=0.5),
# Utility
NormalizeLocal(), ToTensor()
])
val_transforms = Compose([NormalizeLocal(), ToTensor()])
# sampler = RandomSampler(train_dataset, True, len(train_dataset)) if params['bootstrap'] else None
# Datasets
from datasets.binar_masks import BinaryMasksDataset
dataset = Namespace(
**dict(
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, mixup=self.params.mixup,
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
mixup=self.params.mixup,
mel_transforms=mel_transforms, transforms=aug_transforms),
val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
mel_transforms=mel_transforms, transforms=val_transforms),
val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
mel_transforms=mel_transforms, transforms=val_transforms),
test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
@@ -142,6 +156,9 @@ class BaseDataloadersMixin(ABC):
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.params.batch_size, num_workers=self.params.worker)
train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
batch_size=self.params.batch_size, shuffle=False)
return [val_dataloader, train_dataloader]