initial commit - just template files

This commit is contained in:
Si11ium 2020-05-19 09:20:54 +02:00
parent 499691fbc9
commit 9ccbec9d7c
9 changed files with 379 additions and 0 deletions

.gitignore vendored Normal file (+5)

@@ -0,0 +1,5 @@
# my own stuff
/data
/.idea
/ml_lib

__init__.py Normal file (+0)

_parameters.py Normal file (+57)

@@ -0,0 +1,57 @@
# Imports
# =============================================================================
import os
from distutils.util import strtobool
from argparse import ArgumentParser, Namespace
# Parameter Configuration
# =============================================================================
# Argument Parser
main_arg_parser = ArgumentParser(description="parser for the new-project template")
# Main Parameters
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Project
main_arg_parser.add_argument("--project_name", type=str, default='traj-gen', help="")
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_additional_resource_root", type=str, default='res/resource/root', help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
# Transformations
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--transformations_normalize", type=strtobool, default=False, help="")
# Training
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
# Model 2: Layer Specific Stuff
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
main_arg_parser.add_argument("--model_features", type=int, default=16, help="")
# Parse it
args: Namespace = main_arg_parser.parse_args()
if __name__ == '__main__':
pass
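Note: --model_filters defaults to the string '[16, 32, 64]' rather than a list, so downstream code has to parse it. A minimal sketch of such parsing, assuming ast.literal_eval is acceptable (the helper name parse_filters is hypothetical, not part of the template):

import ast

def parse_filters(filters_arg: str) -> list:
    # Hypothetical helper: turn the "--model_filters" string, e.g. "[16, 32, 64]",
    # into a list of ints without resorting to eval().
    filters = ast.literal_eval(filters_arg)
    assert isinstance(filters, list), 'model_filters must parse to a list'
    return filters

# e.g. parse_filters(args.model_filters) -> [16, 32, 64]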

datasets/template_dataset.py Normal file (+6)

@@ -0,0 +1,6 @@
from torch.utils.data import Dataset
class TemplateDataset(Dataset):
def __init__(self, *args, **kwargs):
super(TemplateDataset, self).__init__()
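The stub above subclasses torch.utils.data.Dataset but does not yet implement the mapping protocol; a concrete dataset also needs __len__ and __getitem__. A minimal sketch under that assumption (class name, shapes and labels are illustrative placeholders):

import torch
from torch.utils.data import Dataset

class ExampleDataset(Dataset):
    # Illustrative concrete dataset; random data stands in for real samples.
    def __init__(self, length=100):
        super(ExampleDataset, self).__init__()
        self.data = torch.randn(length, 3)
        self.labels = torch.zeros(length, 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]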

main.py Normal file (+81)

@@ -0,0 +1,81 @@
# Imports
# =============================================================================
import warnings
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from modules.utils import LightningBaseModule
from utils.config import Config
from utils.logging import Logger
from utils.model_io import SavedLightningModels
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
def run_lightning_loop(config_obj):
# Logging
# ================================================================================
# Logger
with Logger(config_obj) as logger:
# Callbacks
# =============================================================================
# Checkpoint Saving
checkpoint_callback = ModelCheckpoint(
filepath=str(logger.log_dir / 'ckpt_weights'),
            verbose=True, save_top_k=0,  # keep no intermediate top-k checkpoints; the final state is saved explicitly below
)
# =============================================================================
# Early Stopping
        # TODO: For this to work, a validation step and an end-of-training evaluation score must be defined
early_stopping_callback = EarlyStopping(
monitor='val_loss',
min_delta=0.0,
patience=0,
)
# Model
# =============================================================================
# Init
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights(torch.nn.init.xavier_normal_)
# Trainer
# =============================================================================
trainer = Trainer(max_epochs=config_obj.train.epochs,
show_progress_bar=True,
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
check_val_every_n_epoch=10,
# num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,
logger=logger,
fast_dev_run=config_obj.main.debug,
                          early_stop_callback=None  # early_stopping_callback is not wired in until the TODO above is resolved
)
# Train It
trainer.fit(model)
# Save the last state & all parameters
trainer.save_checkpoint(config_obj.exp_path.log_dir / 'weights.ckpt')
model.save_to_disk(config_obj.exp_path)
# Evaluate It
if config_obj.main.eval:
trainer.test()
return model
if __name__ == "__main__":
from _templates.new_project._parameters import args
config = Config.read_namespace(args)
trained_model = run_lightning_loop(config)
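Note that run_lightning_loop reads grouped attributes (config_obj.train.epochs, config_obj.main.debug) while _parameters.py defines flat names (--train_epochs, --main_debug); Config.read_namespace evidently splits each argument name at its first underscore into a section and a key. A minimal sketch of that convention (illustrative only, not the actual Config implementation):

from argparse import Namespace

def group_namespace(flat_args: Namespace) -> Namespace:
    # Illustrative re-implementation: "train_epochs" -> grouped.train.epochs
    sections = {}
    for key, value in vars(flat_args).items():
        section, _, name = key.partition('_')
        sections.setdefault(section, {})[name] = value
    return Namespace(**{s: Namespace(**kv) for s, kv in sections.items()})

# group_namespace(args).train.epochs == args.train_epochs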

multi_run.py Normal file (+28)

@@ -0,0 +1,28 @@
import warnings
from _templates.new_project.utils.project_config import Config
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
# Imports
# =============================================================================
from _templates.new_project.main import run_lightning_loop
from _templates.new_project._parameters import args
if __name__ == '__main__':
# Model Settings
config = Config().read_namespace(args)
    # bias, norm, batch size, epochs per configuration
    cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, train_batch_size=512)
for arg_dict in [cnn_classifier]:
for seed in range(5):
arg_dict.update(main_seed=seed)
config = config.update(arg_dict)
run_lightning_loop(config)
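Each pass through the inner loop re-runs the full training loop under a different seed. Since run_lightning_loop returns the fitted model (see main.py), a sweep's outcomes can be collected; a small sketch under that assumption:

# Sketch: keep the trained model per seed for later comparison.
results = dict()
for seed in range(5):
    cnn_classifier.update(main_seed=seed)
    config = config.update(cnn_classifier)
    results[seed] = run_lightning_loop(config)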

utils/__init__.py Normal file (+0)

utils/module_mixins.py Normal file (+172)

@@ -0,0 +1,172 @@
from collections import defaultdict
from abc import ABC
from argparse import Namespace
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose
from _templates.new_project.datasets.template_dataset import TemplateDataset
from audio_toolset.audio_io import NormalizeLocal
from modules.utils import LightningBaseModule
from utils.transforms import ToTensor
from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
class BaseOptimizerMixin:
def configure_optimizers(self):
assert isinstance(self, LightningBaseModule)
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
if self.params.sto_weight_avg:
            # TODO: Make this globally available.
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
return opt
def on_train_end(self):
assert isinstance(self, LightningBaseModule)
for opt in self.trainer.optimizers:
if isinstance(opt, SWA):
opt.swap_swa_sgd()
def on_epoch_end(self):
assert isinstance(self, LightningBaseModule)
if self.params.opt_reset_interval:
if self.current_epoch % self.params.opt_reset_interval == 0:
for opt in self.trainer.optimizers:
opt.state = defaultdict(dict)
class BaseTrainMixin:
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
def training_step(self, batch_xy, batch_nb, *_, **__):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
bce_loss = self.bce_loss(y, batch_y)
return dict(loss=bce_loss, log=dict(batch_nb=batch_nb))
def training_epoch_end(self, outputs):
assert isinstance(self, LightningBaseModule)
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
return summary_dict
class BaseValMixin:
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
    def validation_step(self, batch_xy, batch_idx, *_, **__):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
val_bce_loss = self.bce_loss(y, batch_y)
return dict(val_bce_loss=val_bce_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
summary_dict = dict(log=dict())
        # With multiple dataloaders, `outputs` is a list[list[dict]] and needs an outer
        # loop (e.g. for output_idx, output in enumerate(outputs): ...);
        # with a single dataloader, as assumed here, it is a plain list[dict].
        keys = list(outputs[0].keys())
        # Add every value whose key contains "loss", by taking the mean over all occurrences.
summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
)
"""
        # Additional score, e.g. the Unweighted Average Recall (UAR);
        # requires numpy (np) and sklearn.metrics imports when enabled:
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
y_pred = (y_pred >= 0.5).astype(np.float32)
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
summary_dict['log'].update({f'uar_score': uar_score})
"""
return summary_dict
class BinaryMaskDatasetMixin:
def build_dataset(self):
assert isinstance(self, LightningBaseModule)
# Dataset
# =============================================================================
# Data Augmentations or Utility Transformations
transforms = Compose([NormalizeLocal(), ToTensor()])
# Dataset
dataset = Namespace(
**dict(
# TRAIN DATASET
                train_dataset=TemplateDataset(self.params.root, setting=GlobalVars.train,
transforms=transforms
),
# VALIDATION DATASET
val_dataset=TemplateDataset(self.params.root, setting=GlobalVars.vali,
),
# TEST DATASET
test_dataset=TemplateDataset(self.params.root, setting=GlobalVars.test,
),
)
)
return dataset
class BaseDataloadersMixin(ABC):
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
assert isinstance(self, LightningBaseModule)
        # In case you want to implement bootstrapping:
# sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
sampler = None
        return DataLoader(dataset=self.dataset.train_dataset, shuffle=sampler is None, sampler=sampler,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Test Dataloader
def test_dataloader(self):
assert isinstance(self, LightningBaseModule)
return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Validation Dataloader
def val_dataloader(self):
assert isinstance(self, LightningBaseModule)
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
batch_size=self.params.batch_size, num_workers=self.params.worker)
        # Alternatively, return [val_dataloader, other_dataloader]; validation_step then receives a dataloader_idx argument.
return val_dataloader
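All mixins above assert isinstance(self, LightningBaseModule): they are meant to be composed into one concrete model via multiple inheritance, not used on their own. A minimal sketch of such a composition, assuming LightningBaseModule accepts the hparams namespace (the class name TemplateModel and the __init__/forward bodies are hypothetical):

from modules.utils import LightningBaseModule

class TemplateModel(BaseOptimizerMixin, BaseTrainMixin, BaseValMixin,
                    BinaryMaskDatasetMixin, BaseDataloadersMixin,
                    LightningBaseModule):
    # Hypothetical composition: the mixins contribute optimizers, train/val
    # steps and dataloaders; only the network itself is left to define.
    def __init__(self, hparams):
        super(TemplateModel, self).__init__(hparams)
        self.dataset = self.build_dataset()  # from BinaryMaskDatasetMixin

    def forward(self, x):
        raise NotImplementedError  # network definition goes here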

utils/project_config.py Normal file (+30)

@@ -0,0 +1,30 @@
from argparse import Namespace
from utils.config import Config
class GlobalVar(Namespace):
# Labels for classes
LEFT = 1
RIGHT = 0
WRONG = -1
# Colors for img files
WHITE = 255
BLACK = 0
# Variables for plotting
PADDING = 0.25
DPI = 50
    # Data options (plain strings; the original trailing commas made these tuples)
    train = 'train'
    vali = 'vali'
    test = 'test'
class ThisConfig(Config):
@property
def _model_map(self):
return dict()
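_model_map is left empty in this template; judging by config_obj.model_class in main.py, it presumably maps the --model_type string to the class to instantiate. A minimal sketch of a filled-in map (the import path and the CNNRouteGenerator class are assumptions derived from the --model_type default):

class MyConfig(Config):
    @property
    def _model_map(self):
        # Hypothetical: resolve "--model_type" names to model classes;
        # the module path below is illustrative only.
        from models.cnn_route_generator import CNNRouteGenerator
        return dict(CNNRouteGenerator=CNNRouteGenerator)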