eval written

This commit is contained in:
Si11ium 2020-03-05 16:58:23 +01:00
parent 8d06c179c9
commit 1f25bf599b
12 changed files with 127 additions and 74 deletions

11
.idea/deployment.xml generated
View File

@ -1,8 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22">
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine" showAutoUploadSettingsWarning="false">
<serverData>
<paths name="steffen@aimachine:22">
<paths name="ErLoWa-AiMachine">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="traj_gen-AiMachine">
<serverdata>
<mappings>
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

2
.idea/misc.xml generated
View File

@ -3,5 +3,5 @@
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
</project>

View File

@ -1,29 +1,24 @@
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
class ROCEvaluation(object):
BINARY_PROBLEM = 2
linewidth = 2
def __init__(self, save_fig=True):
def __init__(self, prepare_figure=False):
self.prepare_figure = prepare_figure
self.epoch = 0
pass
def __call__(self, prediction, label, prepare_fig=True):
def __call__(self, prediction, label, plotting=False):
# Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr)
if prepare_fig:
fig = self._prepare_fig()
fig.plot(fpr, tpr, color='darkorange',
lw=2, label=f'ROC curve (area = {roc_auc})')
self._prepare_fig()
return roc_auc
if plotting:
fig = plt.gcf()
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
return roc_auc, fpr, tpr
def _prepare_fig(self):
fig = plt.gcf()
@ -32,6 +27,6 @@ class ROCEvaluation(object):
fig.ylim([0.0, 1.05])
fig.xlabel('False Positive Rate')
fig.ylabel('True Positive Rate')
fig.legend(loc="lower right")
return fig

View File

@ -1,4 +1,4 @@
from dataset.dataset import TrajPairData
from datasets.paired_dataset import TrajPairData
from lib.modules.blocks import ConvModule
from lib.modules.utils import LightningBaseModule

View File

@ -5,8 +5,10 @@ import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule
@ -24,6 +26,22 @@ class ConvHomDetector(LightningBaseModule):
loss = F.binary_cross_entropy(pred_y, batch_y.float())
return {'loss': loss, 'log': dict(loss=loss)}
def test_step(self, batch_xy, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
return dict(prediction=pred_y, label=batch_y)
def test_end(self, outputs):
evaluation = ROCEvaluation()
predictions = torch.stack([x['prediction'] for x in outputs])
labels = torch.stack([x['label'] for x in outputs])
scores = evaluation(predictions.numpy(), labels.numpy())
self.logger.log_metrics()
pass
def __init__(self, *params):
super(ConvHomDetector, self).__init__(*params)
@ -70,6 +88,26 @@ class ConvHomDetector(LightningBaseModule):
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
self.out_activation = nn.Sigmoid() # nn.Softmax
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
def forward(self, x):
tensor = self.map_conv_0(x)
tensor = self.map_res_1(tensor)

View File

@ -105,24 +105,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
torch.save(self.__class__, f)
return True
@pl.data_loader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
@pl.data_loader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
@pl.data_loader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
@property
def data_len(self):
return len(self.dataset.train_dataset)

View File

@ -79,20 +79,37 @@ class Config(ConfigParser):
super(Config, self).__init__(**kwargs)
pass
@staticmethod
def _sort_combined_section_key_mapping(dict_obj):
sorted_dict = defaultdict(dict)
for key in dict_obj:
section, *attr_name = key.split('_')
attr_name = '_'.join(attr_name)
value = str(dict_obj[key])
sorted_dict[section][attr_name] = value
# noinspection PyTypeChecker
return dict(sorted_dict)
@classmethod
def read_namespace(cls, namespace: Namespace):
space_dict = defaultdict(dict)
for key in namespace.__dict__:
section, *attr_name = key.split('_')
attr_name = '_'.join(attr_name)
value = str(namespace.__getattribute__(key))
space_dict[section][attr_name] = value
sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__)
new_config = cls()
new_config.read_dict(space_dict)
new_config.read_dict(sorted_dict)
return new_config
def update(self, mapping):
sorted_dict = self._sort_combined_section_key_mapping(mapping)
for section in sorted_dict:
if self.has_section(section):
pass
else:
self.add_section(section)
for option, value in sorted_dict[section].items():
self.set(section, option, value)
return self
def get(self, *args, **kwargs):
item = super(Config, self).get(*args, **kwargs)
try:
@ -108,5 +125,4 @@ class Config(ConfigParser):
with path.open('w') as configfile:
super().write(configfile)
return True

View File

@ -78,6 +78,14 @@ class Logger(LightningLoggerBase):
def log_config_as_ini(self):
self.config.write(self.log_dir)
def log_metric(self, metric_name, metric_value, **kwargs):
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
def log_image(self, name, image, **kwargs):
self.neptunelogger.log_image(name, image, **kwargs)
image.savefig(self.log_dir / name)
def save(self):
self.testtubelogger.save()
self.neptunelogger.save()

38
main.py
View File

@ -3,18 +3,21 @@
import os
from distutils.util import strtobool
from pathlib import Path
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
import warnings
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from lib.modules.utils import LightningBaseModule
from lib.utils.config import Config
from lib.utils.logging import Logger
from lib.evaluation.classification import ROCEvaluation
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@ -31,7 +34,7 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help=
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=0, help="")
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
@ -61,16 +64,15 @@ main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', hel
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
# Parse it
args = main_arg_parser.parse_args()
config = Config.read_namespace(args)
args: Namespace = main_arg_parser.parse_args()
if __name__ == "__main__":
def run_lightning_loop(config_obj):
# Logging
# =============================================================================
# ================================================================================
# Logger
with Logger(config) as logger:
with Logger(config_obj) as logger:
# Callbacks
# =============================================================================
# Checkpoint Saving
@ -90,12 +92,12 @@ if __name__ == "__main__":
# Model
# =============================================================================
# Init
model: LightningBaseModule = config.model_class(config.model_paramters)
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights()
# Trainer
# =============================================================================
trainer = Trainer(max_epochs=config.train.epochs,
trainer = Trainer(max_epochs=config_obj.train.epochs,
show_progress_bar=True,
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
@ -103,15 +105,23 @@ if __name__ == "__main__":
log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,
logger=logger,
fast_dev_run=config.main.debug,
fast_dev_run=config_obj.main.debug,
early_stop_callback=None
)
# Train it
trainer.fit(model)
# Train It
trainer.fit(model,)
# Save the last state & all parameters
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
model.save_to_disk(logger.log_dir)
pass
# TODO: Eval here!
# Evaluate It
trainer.test()
return model
if __name__ == "__main__":
config = Config.read_namespace(args)
trained_model = run_lightning_loop(config)

View File

View File

@ -7,23 +7,18 @@ warnings.filterwarnings('ignore', category=UserWarning)
# Imports
# =============================================================================
from pathlib import Path
import os
from main import run_training, args
if __name__ == '__main__':
# Model Settings
warnings.filterwarnings('ignore', category=FutureWarning)
config = Config().read_namespace(args)
# use_bias, activation, model, use_norm, max_epochs, filters
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
# Data Settings
data_shortcodes = ['mid', 'mid_5']
# Iteration over
for data_shortcode in data_shortcodes:
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
for seed in range(5):
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
@ -32,4 +27,6 @@ if __name__ == '__main__':
model_filters=filters,
data_batch_size=512)
os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')
config = config.update(arg_dict)
run_training(config)