Dataset ready

This commit is contained in:
Steffen Illium
2021-02-16 10:18:04 +01:00
parent 151b22a2c3
commit 7edd3834a1
11 changed files with 350 additions and 15 deletions

74
main.py
View File

@@ -1,12 +1,68 @@
from pathlib import Path
import configparser
from argparse import ArgumentParser, Namespace
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from ml_lib.utils.logging import Logger
from ml_lib.utils.tools import locate_and_import_class, auto_cast
import variables as v
from datasets.primates_librosa_datamodule import PrimatesLibrosaDatamodule
# Root directory for all dataset files (relative to the working directory).
data_root = Path('data')
if __name__ == '__main__':
    # Argument Parser and default Values
    # =============================================================================
    # Load defaults from the _parameters.ini file.
    config = configparser.ConfigParser()
    config.read('_parameters.ini')

    project = config['project']
    # Resolve the configured datamodule / model classes by name from their packages.
    data_class = locate_and_import_class(project['data_name'], 'datasets')
    model_class = locate_and_import_class(project['model_name'], 'models')

    # Flatten the selected ini sections into one flat parameter dict.
    # NOTE: later sections silently overwrite identically-named keys from
    # earlier ones — section order matters.
    tmp_params = dict()
    for section in ['project', 'train', 'data', 'model_cnn']:
        # `section` renamed from `key`: the original reused `key` both as the
        # section-loop variable and as the comprehension variable, shadowing it.
        tmp_params.update({key: auto_cast(val) for key, val in config[section].items()})

    # Parse Command Line — CLI values (when actually given) override ini defaults.
    parser = ArgumentParser()
    for module in [Logger, Trainer, data_class, model_class]:
        parser = module.add_argparse_args(parser)
    cmd_args, _ = parser.parse_known_args()
    tmp_params.update({key: val for key, val in vars(cmd_args).items() if val is not None})

    hparams = Namespace(**tmp_params)

    with Logger.from_argparse_args(hparams) as logger:
        # Callbacks
        # =============================================================================
        # Checkpoint Saving — keep the three best checkpoints by `mean_loss`.
        ckpt_callback = ModelCheckpoint(
            monitor='mean_loss',
            filepath=str(logger.log_dir / 'ckpt_weights'),
            verbose=False,
            save_top_k=3,
        )
        # Learning Rate Logger
        lr_logger = LearningRateMonitor(logging_interval='epoch')

        # START
        # =============================================================================
        # Let Datamodule pull what it wants
        datamodule = data_class.from_argparse_args(hparams)
        datamodule.setup()
        model_in_shape = datamodule.shape

        # Let Trainer pull what it wants and add callbacks
        trainer = Trainer.from_argparse_args(hparams, callbacks=[ckpt_callback, lr_logger])

        # Let Model pull what it wants (reuse the shape captured above instead
        # of re-reading datamodule.shape a second time).
        model = model_class.from_argparse_args(hparams, in_shape=model_in_shape,
                                               n_classes=v.N_CLASS_multi)
        model.init_weights()
        logger.log_hyperparams(dict(model.params))

        trainer.fit(model, datamodule)
        # Trainer.save_checkpoint expects a checkpoint *file* path; the original
        # passed the save_dir directory itself. NOTE(review): confirm the
        # desired filename convention against the rest of the project.
        trainer.save_checkpoint(str(Path(trainer.logger.save_dir) / 'last.ckpt'))