Model Training

Si11ium 2020-05-03 18:00:51 +02:00
parent 8a97f59906
commit e4f6506a4b
9 changed files with 167 additions and 105 deletions

_paramters.py (new file, 56 lines)

@@ -0,0 +1,56 @@
from argparse import ArgumentParser, Namespace
from distutils.util import strtobool
from pathlib import Path
import os
# Parameter Configuration
# =============================================================================
# Argument Parser
_ROOT = Path(__file__).parent
main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
# Main Parameters
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
# Transformation Parameters
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
# Training Parameters
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model Parameters
main_arg_parser.add_argument("--model_type", type=str, default="BinaryClassifier", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
# Project Parameters
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
if __name__ == '__main__':
# Parse it
args: Namespace = main_arg_parser.parse_args()
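
A note on the type=strtobool pattern above: distutils.util.strtobool maps the usual truthy/falsy strings to 1/0, so the boolean flags are passed as words or digits on the command line. A minimal sketch of how the shared parser is consumed (flag values here are illustrative):

from _paramters import main_arg_parser

# strtobool accepts 'y', 'yes', 't', 'true', 'on', '1' (-> 1)
# and 'n', 'no', 'f', 'false', 'off', '0' (-> 0); anything else raises ValueError.
args = main_arg_parser.parse_args(['--main_debug', 'yes', '--train_epochs', '100'])
print(args.main_debug)    # 1
print(args.train_epochs)  # 100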

@@ -0,0 +1 @@
from datasets.binar_masks import BinaryMasksDataset

datasets/binar_masks.py

@@ -4,6 +4,7 @@ from pathlib import Path

 import librosa as librosa
 from torch.utils.data import Dataset
+import torch

 import variables as V
 from ml_lib.modules.utils import F_x

@@ -11,18 +12,16 @@ from ml_lib.modules.utils import F_x

 class BinaryMasksDataset(Dataset):
     _to_label = defaultdict(lambda: -1)
-    _to_label['clear'] = V.CLEAR
-    _to_label['mask'] = V.MASK
-    settings = ['test', 'devel', 'train']
+    _to_label.update(dict(clear=V.CLEAR, mask=V.MASK))

     @property
     def sample_shape(self):
         return self[0][0].shape

     def __init__(self, data_root, setting, transforms=None):
-        assert isinstance(setting, str), f'Setting has to be a string, but was: {self.settings}.'
-        assert setting in self.settings, f'Setting must match one of: {self.settings}.'
-        assert callable(transforms) or None, f'Transforms has to be callable, but was: {transforms}'
+        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
+        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
+        assert callable(transforms) or None, f'Transforms has to be callable, but was: {type(transforms)}'
         super(BinaryMasksDataset, self).__init__()

         self.data_root = Path(data_root)

@@ -41,7 +40,7 @@ class BinaryMasksDataset(Dataset):
             for row in f:
                 if self.setting not in row:
                     continue
-                filename, label = row.split(',')
+                filename, label = row.strip().split(',')
                 labeldict[filename] = self._to_label[label.lower()]
         return labeldict

@@ -60,5 +59,5 @@ class BinaryMasksDataset(Dataset):
                 pickle.dump(transformed_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
         with (self._mel_folder / filename).open(mode='rb') as f:
             sample = pickle.load(f, fix_imports=True)
-        label = self._labels[key]
+        label = torch.as_tensor(self._labels[key], dtype=torch.float)
         return sample, label
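
The switch to torch.as_tensor(..., dtype=torch.float) in the last hunk matters because the classifier below trains with nn.BCELoss(), which requires float targets matching the sigmoid output; integer labels would raise a dtype error at loss time. A minimal illustration of the pairing:

import torch
from torch import nn

criterion = nn.BCELoss()
prediction = torch.sigmoid(torch.randn(4, 1))  # probabilities in (0, 1)
target = torch.as_tensor([0, 1, 1, 0], dtype=torch.float).unsqueeze(-1)
loss = criterion(prediction, target)           # fine: both are float32
# criterion(prediction, target.long())         # RuntimeError: expects Float, not Long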

main.py (83 changed lines)

@@ -1,81 +1,28 @@
 # Imports
 # =============================================================================
-import os
-from distutils.util import strtobool
-from pathlib import Path
-from argparse import ArgumentParser, Namespace
 import warnings

 import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
-from torch.utils.data import DataLoader
-from torchvision.transforms import Compose, ToTensor

-from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal, AutoPadToShape
 from ml_lib.modules.utils import LightningBaseModule
-from ml_lib.utils.logging import Logger

+# Project Specific Config and Logger SubClasses
 from util.config import MConfig
+from util.logging import MLogger

 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)

-_ROOT = Path(__file__).parent
-
-# Parameter Configuration
-# =============================================================================
-# Argument Parser
-main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
-# Main Parameters
-main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
-main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
-# Data Parameters
-main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
-main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
-main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
-main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
-# Transformation Parameters
-main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
-# Training Parameters
-main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
-main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
-main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
-main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
-# Model Parameters
-main_arg_parser.add_argument("--model_type", type=str, default="BinaryClassifier", help="")
-main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
-main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
-main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
-main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
-main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--model_norm", type=strtobool, default=False, help="")
-main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
-# Project Parameters
-main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.parent.name, help="")
-main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
-main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
-# Parse it
-args: Namespace = main_arg_parser.parse_args()


 def run_lightning_loop(config_obj):

     # Logging
     # ================================================================================
     # Logger
-    with Logger(config_obj) as logger:
+    with MLogger(config_obj) as logger:
         # Callbacks
         # =============================================================================
         # Checkpoint Saving

@@ -93,24 +40,10 @@ def run_lightning_loop(config_obj):
             patience=0,
         )

-        # Dataset and Dataloaders
-        # =============================================================================
-        # Transforms
-        transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()])
-        # Datasets
-        from datasets.binar_masks import BinaryMasksDataset
-        train_dataset = BinaryMasksDataset(config_obj.data.root, setting='train', transforms=transforms)
-        val_dataset = BinaryMasksDataset(config_obj.data.root, setting='devel', transforms=transforms)
-        # Dataloaders
-        train_dataloader = DataLoader(train_dataset)
-        val_dataloader = DataLoader(val_dataset)
-
         # Model
         # =============================================================================
         # Build and Init its Weights
-        config_obj.set('model', 'in_shape', str(tuple(train_dataset.sample_shape)))
-        model: LightningBaseModule = config_obj.build_and_init_model(weight_init_function=torch.nn.init.xavier_normal_
-                                                                     )
+        model: LightningBaseModule = config_obj.build_and_init_model()

@@ -129,7 +62,7 @@ def run_lightning_loop(config_obj):
         )

         # Train It
-        trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)
+        trainer.fit(model)

         # Save the last state & all parameters
         trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')

@@ -144,5 +77,7 @@ def run_lightning_loop(config_obj):

 if __name__ == "__main__":
-    config = MConfig.read_namespace(args)
+    from _paramters import main_arg_parser
+
+    config = MConfig.read_argparser(main_arg_parser)
     trained_model = run_lightning_loop(config)
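
trainer.fit(model) now works without explicit loaders because PyTorch Lightning falls back to dataloader hooks defined on the module itself; that is presumably what BaseModuleMixin_Dataloaders from ml_lib contributes. A self-contained sketch of the pattern, with illustrative names rather than the actual ml_lib API:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class DataloaderMixin:
    # Lightning calls these hooks when trainer.fit(model) receives no loaders.
    def train_dataloader(self):
        return DataLoader(self.dataset['train'], batch_size=4)

    def val_dataloader(self):
        return DataLoader(self.dataset['devel'], batch_size=4)


class ToyModule(DataloaderMixin, pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        xs, ys = torch.randn(32, 8), torch.randint(0, 2, (32, 1)).float()
        self.dataset = dict(train=TensorDataset(xs, ys), devel=TensorDataset(xs, ys))

    def forward(self, x):
        return torch.sigmoid(self.layer(x))

    def training_step(self, batch, batch_nb):
        x, y = batch
        return dict(loss=torch.nn.functional.binary_cross_entropy(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# pl.Trainer(max_epochs=1).fit(ToyModule())  # note: no dataloaders passed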

main_inference.py (new file, 44 lines)

@@ -0,0 +1,44 @@
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor

from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal
from ml_lib.utils.model_io import SavedLightningModels

from util.config import MConfig
from util.logging import MLogger
from datasets.binar_masks import BinaryMasksDataset

# Dataset and Dataloaders
# =============================================================================
# Transforms
transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()])


def prepare_dataset(config_obj):
    # The test split is served sample-by-sample (batch_size=None), in order.
    dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test', transforms=transforms)
    return DataLoader(dataset=dataset,
                      batch_size=None,
                      num_workers=config_obj.data.worker,
                      shuffle=False)


def restore_logger_and_model(config_obj):
    # Rebuild the logger only to locate the checkpoint directory.
    logger = MLogger(config_obj)
    model = SavedLightningModels().load_checkpoint(models_root_path=logger.log_dir)
    model = model.restore()
    return model


if __name__ == '__main__':
    from _paramters import main_arg_parser

    config = MConfig.read_argparser(main_arg_parser)

    test_dataset = prepare_dataset(config)
    loaded_model = restore_logger_and_model(config)

    print("run model here and find a format to store the output")

models/binary_classifier.py

@@ -1,24 +1,23 @@
+from argparse import Namespace
+
 import torch
 from torch import nn
 from torch.optim import Adam
+from torchvision.transforms import Compose, ToTensor

+from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal
 from ml_lib.modules.blocks import ConvModule
-from ml_lib.modules.utils import LightningBaseModule, Flatten
+from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders


-class BinaryClassifier(LightningBaseModule):
-
-    def test_step(self, *args, **kwargs):
-        pass
-
-    def test_epoch_end(self, outputs):
-        pass
+class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule):

     @classmethod
     def name(cls):
         return cls.__name__

     def configure_optimizers(self):
-        return Adam(params=self.parameters(), lr=self.hparams.lr)
+        return Adam(params=self.parameters(), lr=self.params.lr)

     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
         batch_x, batch_y = batch_xy

@@ -26,11 +25,11 @@ class BinaryClassifier(LightningBaseModule):
         loss = self.criterion(y, batch_y)
         return dict(loss=loss)

-    def validation_step(self, batch_xy, **kwargs):
+    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
         batch_x, batch_y = batch_xy
-        y = self(batch_y)
+        y = self(batch_x)
         val_loss = self.criterion(y, batch_y)
-        return dict(val_loss=val_loss)
+        return dict(val_loss=val_loss, batch_idx=batch_idx)

     def validation_epoch_end(self, outputs):
         overall_val_loss = torch.mean(torch.stack([output['val_loss'] for output in outputs]))

@@ -41,22 +40,36 @@ class BinaryClassifier(LightningBaseModule):
     def __init__(self, hparams):
         super(BinaryClassifier, self).__init__(hparams)

-        self.criterion = nn.BCELoss()
+        # Dataset and Dataloaders
+        # =============================================================================
+        # Transforms
+        transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()])
+        # Datasets
+        from datasets.binar_masks import BinaryMasksDataset
+        self.dataset = Namespace(
+            **dict(
+                train_dataset=BinaryMasksDataset(self.params.root, setting='train', transforms=transforms),
+                val_dataset=BinaryMasksDataset(self.params.root, setting='devel', transforms=transforms),
+                test_dataset=BinaryMasksDataset(self.params.root, setting='test', transforms=transforms),
+            )
+        )

+        # Model Paramters
+        # =============================================================================
         # Additional parameters
-        self.in_shape = self.hparams.in_shape
+        self.in_shape = self.dataset.train_dataset.sample_shape
+        self.criterion = nn.BCELoss()

-        # Model Modules
-        self.conv_1 = ConvModule(self.in_shape, 32, 3, conv_stride=2, **self.hparams.module_paramters)
-        self.conv_2 = ConvModule(self.conv_1.shape, 64, 5, conv_stride=2, **self.hparams.module_paramters)
-        self.conv_3 = ConvModule(self.conv_2.shape, 128, 7, conv_stride=2, **self.hparams.module_paramters)
+        # Modules
+        self.conv_1 = ConvModule(self.in_shape, 32, 3, conv_stride=2, **self.params.module_kwargs)
+        self.conv_2 = ConvModule(self.conv_1.shape, 64, 5, conv_stride=2, **self.params.module_kwargs)
+        self.conv_3 = ConvModule(self.conv_2.shape, 128, 7, conv_stride=2, **self.params.module_kwargs)

         self.flat = Flatten(self.conv_3.shape)
-        self.full_1 = nn.Linear(self.flat.shape, 32, self.hparams.bias)
-        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.hparams.bias)
+        self.full_1 = nn.Linear(self.flat.shape, 32, self.params.bias)
+        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)

-        self.activation = self.hparams.activation()
-        self.full_out = nn.Linear(self.full_2.out_features, 1, self.hparams.bias)
+        self.activation = self.params.activation()
+        self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)
         self.sigmoid = nn.Sigmoid()

     def forward(self, batch, **kwargs):
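
The layer stack implies a forward pass of three strided convolutions, a flatten, two fully connected layers wrapped in the configured activation, and a final sigmoid yielding one probability per sample. The ml_lib ConvModule internals are not shown in this diff, so the sketch below substitutes bare Conv2d layers and a LazyLinear to absorb the unknown flattened size; it approximates the architecture rather than reproducing it:

import torch
from torch import nn


class BinaryClassifierSketch(nn.Module):
    def __init__(self, in_channels=1):
        super().__init__()
        self.conv_1 = nn.Conv2d(in_channels, 32, 3, stride=2)
        self.conv_2 = nn.Conv2d(32, 64, 5, stride=2)
        self.conv_3 = nn.Conv2d(64, 128, 7, stride=2)
        self.flat = nn.Flatten()
        self.full_1 = nn.LazyLinear(32)   # infers the flattened size on first call
        self.full_2 = nn.Linear(32, 16)
        self.activation = nn.LeakyReLU()
        self.full_out = nn.Linear(16, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, batch, **kwargs):
        tensor = self.conv_3(self.conv_2(self.conv_1(batch)))
        tensor = self.activation(self.full_1(self.flat(tensor)))
        tensor = self.activation(self.full_2(tensor))
        return self.sigmoid(self.full_out(tensor))


# probs = BinaryClassifierSketch()(torch.randn(4, 1, 128, 128))  # -> shape (4, 1)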

util/config.py

@@ -5,5 +5,5 @@ from models.binary_classifier import BinaryClassifier
 class MConfig(Config):

     @property
-    def model_map(self):
+    def _model_map(self):
         return dict(BinaryClassifier=BinaryClassifier)
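
The renamed _model_map property is a small string-to-class registry: the config resolves the --model_type value to a class, which is presumably how build_and_init_model() picks what to instantiate. Since the Config base class lives in ml_lib and is not shown here, the sketch below uses illustrative names for the same pattern:

class DummyModel:
    def __init__(self, **params):
        self.params = params


class ConfigSketch:
    def __init__(self, model_type, **model_params):
        self.model_type = model_type
        self.model_params = model_params

    @property
    def _model_map(self):
        # maps the --model_type string to an actual class
        return dict(BinaryClassifier=DummyModel)

    def build_and_init_model(self):
        model_class = self._model_map[self.model_type]
        return model_class(**self.model_params)


# model = ConfigSketch('BinaryClassifier', lr=1e-3).build_and_init_model()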

util/logging.py (new file, 11 lines)

@@ -0,0 +1,11 @@
from pathlib import Path
from ml_lib.utils.logging import Logger
class MLogger(Logger):
@property
def outpath(self):
# FIXME: Specify a special path
return Path(self.config.train.outpath)

variables.py

@@ -1,3 +1,6 @@
# Labels
 CLEAR = 0
 MASK = 1
+
+# Dataset Options
+DATA_OPTIONS = ['test', 'devel', 'train']