eval written
This commit is contained in:
parent
8d06c179c9
commit
1f25bf599b
11
.idea/deployment.xml
generated
11
.idea/deployment.xml
generated
@ -1,8 +1,15 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22">
|
||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine" showAutoUploadSettingsWarning="false">
|
||||
<serverData>
|
||||
<paths name="steffen@aimachine:22">
|
||||
<paths name="ErLoWa-AiMachine">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="traj_gen-AiMachine">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
||||
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
||||
<component name="JavaScriptSettings">
|
||||
<option name="languageLevel" value="ES6" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
|
||||
</project>
|
@ -1,29 +1,24 @@
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
|
||||
|
||||
class ROCEvaluation(object):
|
||||
|
||||
BINARY_PROBLEM = 2
|
||||
linewidth = 2
|
||||
|
||||
def __init__(self, save_fig=True):
|
||||
def __init__(self, prepare_figure=False):
|
||||
self.prepare_figure = prepare_figure
|
||||
self.epoch = 0
|
||||
pass
|
||||
|
||||
def __call__(self, prediction, label, prepare_fig=True):
|
||||
def __call__(self, prediction, label, plotting=False):
|
||||
|
||||
# Compute ROC curve and ROC area
|
||||
fpr, tpr, _ = roc_curve(prediction, label)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
if prepare_fig:
|
||||
fig = self._prepare_fig()
|
||||
fig.plot(fpr, tpr, color='darkorange',
|
||||
lw=2, label=f'ROC curve (area = {roc_auc})')
|
||||
self._prepare_fig()
|
||||
return roc_auc
|
||||
if plotting:
|
||||
fig = plt.gcf()
|
||||
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
|
||||
return roc_auc, fpr, tpr
|
||||
|
||||
def _prepare_fig(self):
|
||||
fig = plt.gcf()
|
||||
@ -32,6 +27,6 @@ class ROCEvaluation(object):
|
||||
fig.ylim([0.0, 1.05])
|
||||
fig.xlabel('False Positive Rate')
|
||||
fig.ylabel('True Positive Rate')
|
||||
|
||||
fig.legend(loc="lower right")
|
||||
|
||||
return fig
|
||||
|
@ -1,4 +1,4 @@
|
||||
from dataset.dataset import TrajPairData
|
||||
from datasets.paired_dataset import TrajPairData
|
||||
from lib.modules.blocks import ConvModule
|
||||
from lib.modules.utils import LightningBaseModule
|
||||
|
||||
|
@ -5,8 +5,10 @@ import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from datasets.trajectory_dataset import TrajData
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
from lib.modules.blocks import ConvModule, ResidualModule
|
||||
|
||||
@ -24,6 +26,22 @@ class ConvHomDetector(LightningBaseModule):
|
||||
loss = F.binary_cross_entropy(pred_y, batch_y.float())
|
||||
return {'loss': loss, 'log': dict(loss=loss)}
|
||||
|
||||
def test_step(self, batch_xy, **kwargs):
|
||||
batch_x, batch_y = batch_xy
|
||||
pred_y = self(batch_x)
|
||||
return dict(prediction=pred_y, label=batch_y)
|
||||
|
||||
def test_end(self, outputs):
|
||||
evaluation = ROCEvaluation()
|
||||
predictions = torch.stack([x['prediction'] for x in outputs])
|
||||
labels = torch.stack([x['label'] for x in outputs])
|
||||
|
||||
scores = evaluation(predictions.numpy(), labels.numpy())
|
||||
self.logger.log_metrics()
|
||||
|
||||
|
||||
pass
|
||||
|
||||
def __init__(self, *params):
|
||||
super(ConvHomDetector, self).__init__(*params)
|
||||
|
||||
@ -70,6 +88,26 @@ class ConvHomDetector(LightningBaseModule):
|
||||
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
|
||||
self.out_activation = nn.Sigmoid() # nn.Softmax
|
||||
|
||||
# Dataloaders
|
||||
# ================================================================================
|
||||
# Train Dataloader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.map_conv_0(x)
|
||||
tensor = self.map_res_1(tensor)
|
||||
|
@ -105,24 +105,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
torch.save(self.__class__, f)
|
||||
return True
|
||||
|
||||
@pl.data_loader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
@pl.data_loader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
@pl.data_loader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
@property
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
|
@ -79,20 +79,37 @@ class Config(ConfigParser):
|
||||
super(Config, self).__init__(**kwargs)
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _sort_combined_section_key_mapping(dict_obj):
|
||||
sorted_dict = defaultdict(dict)
|
||||
for key in dict_obj:
|
||||
section, *attr_name = key.split('_')
|
||||
attr_name = '_'.join(attr_name)
|
||||
value = str(dict_obj[key])
|
||||
|
||||
sorted_dict[section][attr_name] = value
|
||||
# noinspection PyTypeChecker
|
||||
return dict(sorted_dict)
|
||||
|
||||
@classmethod
|
||||
def read_namespace(cls, namespace: Namespace):
|
||||
|
||||
space_dict = defaultdict(dict)
|
||||
for key in namespace.__dict__:
|
||||
section, *attr_name = key.split('_')
|
||||
attr_name = '_'.join(attr_name)
|
||||
value = str(namespace.__getattribute__(key))
|
||||
|
||||
space_dict[section][attr_name] = value
|
||||
sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__)
|
||||
new_config = cls()
|
||||
new_config.read_dict(space_dict)
|
||||
new_config.read_dict(sorted_dict)
|
||||
return new_config
|
||||
|
||||
def update(self, mapping):
|
||||
sorted_dict = self._sort_combined_section_key_mapping(mapping)
|
||||
for section in sorted_dict:
|
||||
if self.has_section(section):
|
||||
pass
|
||||
else:
|
||||
self.add_section(section)
|
||||
for option, value in sorted_dict[section].items():
|
||||
self.set(section, option, value)
|
||||
return self
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
item = super(Config, self).get(*args, **kwargs)
|
||||
try:
|
||||
@ -108,5 +125,4 @@ class Config(ConfigParser):
|
||||
|
||||
with path.open('w') as configfile:
|
||||
super().write(configfile)
|
||||
|
||||
return True
|
||||
|
@ -78,6 +78,14 @@ class Logger(LightningLoggerBase):
|
||||
def log_config_as_ini(self):
|
||||
self.config.write(self.log_dir)
|
||||
|
||||
def log_metric(self, metric_name, metric_value, **kwargs):
|
||||
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
|
||||
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
|
||||
|
||||
def log_image(self, name, image, **kwargs):
|
||||
self.neptunelogger.log_image(name, image, **kwargs)
|
||||
image.savefig(self.log_dir / name)
|
||||
|
||||
def save(self):
|
||||
self.testtubelogger.save()
|
||||
self.neptunelogger.save()
|
||||
|
38
main.py
38
main.py
@ -3,18 +3,21 @@
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lib.modules.utils import LightningBaseModule
|
||||
from lib.utils.config import Config
|
||||
from lib.utils.logging import Logger
|
||||
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
@ -31,7 +34,7 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help=
|
||||
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
||||
|
||||
# Data Parameters
|
||||
main_arg_parser.add_argument("--data_worker", type=int, default=0, help="")
|
||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
||||
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
|
||||
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
|
||||
@ -61,16 +64,15 @@ main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', hel
|
||||
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
|
||||
|
||||
# Parse it
|
||||
args = main_arg_parser.parse_args()
|
||||
config = Config.read_namespace(args)
|
||||
args: Namespace = main_arg_parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def run_lightning_loop(config_obj):
|
||||
|
||||
# Logging
|
||||
# =============================================================================
|
||||
# ================================================================================
|
||||
# Logger
|
||||
with Logger(config) as logger:
|
||||
with Logger(config_obj) as logger:
|
||||
# Callbacks
|
||||
# =============================================================================
|
||||
# Checkpoint Saving
|
||||
@ -90,12 +92,12 @@ if __name__ == "__main__":
|
||||
# Model
|
||||
# =============================================================================
|
||||
# Init
|
||||
model: LightningBaseModule = config.model_class(config.model_paramters)
|
||||
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
|
||||
model.init_weights()
|
||||
|
||||
# Trainer
|
||||
# =============================================================================
|
||||
trainer = Trainer(max_epochs=config.train.epochs,
|
||||
trainer = Trainer(max_epochs=config_obj.train.epochs,
|
||||
show_progress_bar=True,
|
||||
weights_save_path=logger.log_dir,
|
||||
gpus=[0] if torch.cuda.is_available() else None,
|
||||
@ -103,15 +105,23 @@ if __name__ == "__main__":
|
||||
log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
logger=logger,
|
||||
fast_dev_run=config.main.debug,
|
||||
fast_dev_run=config_obj.main.debug,
|
||||
early_stop_callback=None
|
||||
)
|
||||
|
||||
# Train it
|
||||
trainer.fit(model)
|
||||
# Train It
|
||||
trainer.fit(model,)
|
||||
|
||||
# Save the last state & all parameters
|
||||
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
||||
model.save_to_disk(logger.log_dir)
|
||||
pass
|
||||
# TODO: Eval here!
|
||||
|
||||
# Evaluate It
|
||||
trainer.test()
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
config = Config.read_namespace(args)
|
||||
trained_model = run_lightning_loop(config)
|
||||
|
27
multi_run.py
27
multi_run.py
@ -7,29 +7,26 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
# Imports
|
||||
# =============================================================================
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
from main import run_training, args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Model Settings
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
config = Config().read_namespace(args)
|
||||
# use_bias, activation, model, use_norm, max_epochs, filters
|
||||
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
|
||||
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
|
||||
|
||||
# Data Settings
|
||||
data_shortcodes = ['mid', 'mid_5']
|
||||
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
|
||||
for seed in range(5):
|
||||
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
|
||||
model_use_bias=use_bias, model_use_norm=use_norm,
|
||||
model_activation=activation, model_type=model,
|
||||
model_filters=filters,
|
||||
data_batch_size=512)
|
||||
|
||||
# Iteration over
|
||||
for data_shortcode in data_shortcodes:
|
||||
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
|
||||
for seed in range(5):
|
||||
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
|
||||
model_use_bias=use_bias, model_use_norm=use_norm,
|
||||
model_activation=activation, model_type=model,
|
||||
model_filters=filters,
|
||||
data_batch_size=512)
|
||||
config = config.update(arg_dict)
|
||||
|
||||
os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')
|
||||
run_training(config)
|
||||
|
Loading…
x
Reference in New Issue
Block a user