eval written

This commit is contained in:
Si11ium 2020-03-05 16:58:23 +01:00
parent 8d06c179c9
commit 1f25bf599b
12 changed files with 127 additions and 74 deletions

11
.idea/deployment.xml generated
View File

@ -1,8 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22"> <component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine" showAutoUploadSettingsWarning="false">
<serverData> <serverData>
<paths name="steffen@aimachine:22"> <paths name="ErLoWa-AiMachine">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="traj_gen-AiMachine">
<serverdata> <serverdata>
<mappings> <mappings>
<mapping deploy="/" local="$PROJECT_DIR$" web="/" /> <mapping deploy="/" local="$PROJECT_DIR$" web="/" />

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
</module> </module>

2
.idea/misc.xml generated
View File

@ -3,5 +3,5 @@
<component name="JavaScriptSettings"> <component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" /> <option name="languageLevel" value="ES6" />
</component> </component>
<component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
</project> </project>

View File

@ -1,29 +1,24 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc from sklearn.metrics import roc_curve, auc
class ROCEvaluation(object): class ROCEvaluation(object):
BINARY_PROBLEM = 2
linewidth = 2 linewidth = 2
def __init__(self, save_fig=True): def __init__(self, prepare_figure=False):
self.prepare_figure = prepare_figure
self.epoch = 0 self.epoch = 0
pass
def __call__(self, prediction, label, prepare_fig=True): def __call__(self, prediction, label, plotting=False):
# Compute ROC curve and ROC area # Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label) fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr) roc_auc = auc(fpr, tpr)
if plotting:
if prepare_fig: fig = plt.gcf()
fig = self._prepare_fig() fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
fig.plot(fpr, tpr, color='darkorange', return roc_auc, fpr, tpr
lw=2, label=f'ROC curve (area = {roc_auc})')
self._prepare_fig()
return roc_auc
def _prepare_fig(self): def _prepare_fig(self):
fig = plt.gcf() fig = plt.gcf()
@ -32,6 +27,6 @@ class ROCEvaluation(object):
fig.ylim([0.0, 1.05]) fig.ylim([0.0, 1.05])
fig.xlabel('False Positive Rate') fig.xlabel('False Positive Rate')
fig.ylabel('True Positive Rate') fig.ylabel('True Positive Rate')
fig.legend(loc="lower right") fig.legend(loc="lower right")
return fig return fig

View File

@ -1,4 +1,4 @@
from dataset.dataset import TrajPairData from datasets.paired_dataset import TrajPairData
from lib.modules.blocks import ConvModule from lib.modules.blocks import ConvModule
from lib.modules.utils import LightningBaseModule from lib.modules.utils import LightningBaseModule

View File

@ -5,8 +5,10 @@ import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader
from datasets.trajectory_dataset import TrajData from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.utils import LightningBaseModule, Flatten from lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule from lib.modules.blocks import ConvModule, ResidualModule
@ -24,6 +26,22 @@ class ConvHomDetector(LightningBaseModule):
loss = F.binary_cross_entropy(pred_y, batch_y.float()) loss = F.binary_cross_entropy(pred_y, batch_y.float())
return {'loss': loss, 'log': dict(loss=loss)} return {'loss': loss, 'log': dict(loss=loss)}
def test_step(self, batch_xy, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
return dict(prediction=pred_y, label=batch_y)
def test_end(self, outputs):
evaluation = ROCEvaluation()
predictions = torch.stack([x['prediction'] for x in outputs])
labels = torch.stack([x['label'] for x in outputs])
scores = evaluation(predictions.numpy(), labels.numpy())
self.logger.log_metrics()
pass
def __init__(self, *params): def __init__(self, *params):
super(ConvHomDetector, self).__init__(*params) super(ConvHomDetector, self).__init__(*params)
@ -70,6 +88,26 @@ class ConvHomDetector(LightningBaseModule):
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes) self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
self.out_activation = nn.Sigmoid() # nn.Softmax self.out_activation = nn.Sigmoid() # nn.Softmax
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
def forward(self, x): def forward(self, x):
tensor = self.map_conv_0(x) tensor = self.map_conv_0(x)
tensor = self.map_res_1(tensor) tensor = self.map_res_1(tensor)

View File

@ -105,24 +105,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
torch.save(self.__class__, f) torch.save(self.__class__, f)
return True return True
@pl.data_loader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
@pl.data_loader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
@pl.data_loader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
@property @property
def data_len(self): def data_len(self):
return len(self.dataset.train_dataset) return len(self.dataset.train_dataset)

View File

@ -79,20 +79,37 @@ class Config(ConfigParser):
super(Config, self).__init__(**kwargs) super(Config, self).__init__(**kwargs)
pass pass
@staticmethod
def _sort_combined_section_key_mapping(dict_obj):
sorted_dict = defaultdict(dict)
for key in dict_obj:
section, *attr_name = key.split('_')
attr_name = '_'.join(attr_name)
value = str(dict_obj[key])
sorted_dict[section][attr_name] = value
# noinspection PyTypeChecker
return dict(sorted_dict)
@classmethod @classmethod
def read_namespace(cls, namespace: Namespace): def read_namespace(cls, namespace: Namespace):
space_dict = defaultdict(dict) sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__)
for key in namespace.__dict__:
section, *attr_name = key.split('_')
attr_name = '_'.join(attr_name)
value = str(namespace.__getattribute__(key))
space_dict[section][attr_name] = value
new_config = cls() new_config = cls()
new_config.read_dict(space_dict) new_config.read_dict(sorted_dict)
return new_config return new_config
def update(self, mapping):
sorted_dict = self._sort_combined_section_key_mapping(mapping)
for section in sorted_dict:
if self.has_section(section):
pass
else:
self.add_section(section)
for option, value in sorted_dict[section].items():
self.set(section, option, value)
return self
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
item = super(Config, self).get(*args, **kwargs) item = super(Config, self).get(*args, **kwargs)
try: try:
@ -108,5 +125,4 @@ class Config(ConfigParser):
with path.open('w') as configfile: with path.open('w') as configfile:
super().write(configfile) super().write(configfile)
return True return True

View File

@ -78,6 +78,14 @@ class Logger(LightningLoggerBase):
def log_config_as_ini(self): def log_config_as_ini(self):
self.config.write(self.log_dir) self.config.write(self.log_dir)
def log_metric(self, metric_name, metric_value, **kwargs):
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
def log_image(self, name, image, **kwargs):
self.neptunelogger.log_image(name, image, **kwargs)
image.savefig(self.log_dir / name)
def save(self): def save(self):
self.testtubelogger.save() self.testtubelogger.save()
self.neptunelogger.save() self.neptunelogger.save()

38
main.py
View File

@ -3,18 +3,21 @@
import os import os
from distutils.util import strtobool from distutils.util import strtobool
from pathlib import Path from pathlib import Path
from argparse import ArgumentParser from argparse import ArgumentParser, Namespace
import warnings import warnings
import torch import torch
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from lib.modules.utils import LightningBaseModule from lib.modules.utils import LightningBaseModule
from lib.utils.config import Config from lib.utils.config import Config
from lib.utils.logging import Logger from lib.utils.logging import Logger
from lib.evaluation.classification import ROCEvaluation
warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings('ignore', category=UserWarning)
@ -31,7 +34,7 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help=
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="") main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters # Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=0, help="") main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="") main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="") main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="") main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
@ -61,16 +64,15 @@ main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', hel
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="") main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
# Parse it # Parse it
args = main_arg_parser.parse_args() args: Namespace = main_arg_parser.parse_args()
config = Config.read_namespace(args)
if __name__ == "__main__": def run_lightning_loop(config_obj):
# Logging # Logging
# ============================================================================= # ================================================================================
# Logger # Logger
with Logger(config) as logger: with Logger(config_obj) as logger:
# Callbacks # Callbacks
# ============================================================================= # =============================================================================
# Checkpoint Saving # Checkpoint Saving
@ -90,12 +92,12 @@ if __name__ == "__main__":
# Model # Model
# ============================================================================= # =============================================================================
# Init # Init
model: LightningBaseModule = config.model_class(config.model_paramters) model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights() model.init_weights()
# Trainer # Trainer
# ============================================================================= # =============================================================================
trainer = Trainer(max_epochs=config.train.epochs, trainer = Trainer(max_epochs=config_obj.train.epochs,
show_progress_bar=True, show_progress_bar=True,
weights_save_path=logger.log_dir, weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None, gpus=[0] if torch.cuda.is_available() else None,
@ -103,15 +105,23 @@ if __name__ == "__main__":
log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback, checkpoint_callback=checkpoint_callback,
logger=logger, logger=logger,
fast_dev_run=config.main.debug, fast_dev_run=config_obj.main.debug,
early_stop_callback=None early_stop_callback=None
) )
# Train it # Train It
trainer.fit(model) trainer.fit(model,)
# Save the last state & all parameters # Save the last state & all parameters
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt') trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
model.save_to_disk(logger.log_dir) model.save_to_disk(logger.log_dir)
pass
# TODO: Eval here! # Evaluate It
trainer.test()
return model
if __name__ == "__main__":
config = Config.read_namespace(args)
trained_model = run_lightning_loop(config)

View File

View File

@ -7,29 +7,26 @@ warnings.filterwarnings('ignore', category=UserWarning)
# Imports # Imports
# ============================================================================= # =============================================================================
from pathlib import Path
import os from main import run_training, args
if __name__ == '__main__': if __name__ == '__main__':
# Model Settings # Model Settings
warnings.filterwarnings('ignore', category=FutureWarning) config = Config().read_namespace(args)
# use_bias, activation, model, use_norm, max_epochs, filters # use_bias, activation, model, use_norm, max_epochs, filters
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]] cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters # use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
# Data Settings for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
data_shortcodes = ['mid', 'mid_5'] for seed in range(5):
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
model_use_bias=use_bias, model_use_norm=use_norm,
model_activation=activation, model_type=model,
model_filters=filters,
data_batch_size=512)
# Iteration over config = config.update(arg_dict)
for data_shortcode in data_shortcodes:
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
for seed in range(5):
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
model_use_bias=use_bias, model_use_norm=use_norm,
model_activation=activation, model_type=model,
model_filters=filters,
data_batch_size=512)
os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}') run_training(config)