eval written
This commit is contained in:
parent
8d06c179c9
commit
1f25bf599b
11
.idea/deployment.xml
generated
11
.idea/deployment.xml
generated
@ -1,8 +1,15 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22">
|
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine" showAutoUploadSettingsWarning="false">
|
||||||
<serverData>
|
<serverData>
|
||||||
<paths name="steffen@aimachine:22">
|
<paths name="ErLoWa-AiMachine">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="traj_gen-AiMachine">
|
||||||
<serverdata>
|
<serverdata>
|
||||||
<mappings>
|
<mappings>
|
||||||
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
||||||
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
</module>
|
</module>
|
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
|||||||
<component name="JavaScriptSettings">
|
<component name="JavaScriptSettings">
|
||||||
<option name="languageLevel" value="ES6" />
|
<option name="languageLevel" value="ES6" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
@ -1,29 +1,24 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from sklearn.metrics import roc_curve, auc
|
from sklearn.metrics import roc_curve, auc
|
||||||
|
|
||||||
|
|
||||||
class ROCEvaluation(object):
|
class ROCEvaluation(object):
|
||||||
|
|
||||||
BINARY_PROBLEM = 2
|
|
||||||
linewidth = 2
|
linewidth = 2
|
||||||
|
|
||||||
def __init__(self, save_fig=True):
|
def __init__(self, prepare_figure=False):
|
||||||
|
self.prepare_figure = prepare_figure
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
pass
|
|
||||||
|
|
||||||
def __call__(self, prediction, label, prepare_fig=True):
|
def __call__(self, prediction, label, plotting=False):
|
||||||
|
|
||||||
# Compute ROC curve and ROC area
|
# Compute ROC curve and ROC area
|
||||||
fpr, tpr, _ = roc_curve(prediction, label)
|
fpr, tpr, _ = roc_curve(prediction, label)
|
||||||
roc_auc = auc(fpr, tpr)
|
roc_auc = auc(fpr, tpr)
|
||||||
|
if plotting:
|
||||||
if prepare_fig:
|
fig = plt.gcf()
|
||||||
fig = self._prepare_fig()
|
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
|
||||||
fig.plot(fpr, tpr, color='darkorange',
|
return roc_auc, fpr, tpr
|
||||||
lw=2, label=f'ROC curve (area = {roc_auc})')
|
|
||||||
self._prepare_fig()
|
|
||||||
return roc_auc
|
|
||||||
|
|
||||||
def _prepare_fig(self):
|
def _prepare_fig(self):
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
@ -32,6 +27,6 @@ class ROCEvaluation(object):
|
|||||||
fig.ylim([0.0, 1.05])
|
fig.ylim([0.0, 1.05])
|
||||||
fig.xlabel('False Positive Rate')
|
fig.xlabel('False Positive Rate')
|
||||||
fig.ylabel('True Positive Rate')
|
fig.ylabel('True Positive Rate')
|
||||||
|
|
||||||
fig.legend(loc="lower right")
|
fig.legend(loc="lower right")
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from dataset.dataset import TrajPairData
|
from datasets.paired_dataset import TrajPairData
|
||||||
from lib.modules.blocks import ConvModule
|
from lib.modules.blocks import ConvModule
|
||||||
from lib.modules.utils import LightningBaseModule
|
from lib.modules.utils import LightningBaseModule
|
||||||
|
|
||||||
|
@ -5,8 +5,10 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from datasets.trajectory_dataset import TrajData
|
from datasets.trajectory_dataset import TrajData
|
||||||
|
from lib.evaluation.classification import ROCEvaluation
|
||||||
from lib.modules.utils import LightningBaseModule, Flatten
|
from lib.modules.utils import LightningBaseModule, Flatten
|
||||||
from lib.modules.blocks import ConvModule, ResidualModule
|
from lib.modules.blocks import ConvModule, ResidualModule
|
||||||
|
|
||||||
@ -24,6 +26,22 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
loss = F.binary_cross_entropy(pred_y, batch_y.float())
|
loss = F.binary_cross_entropy(pred_y, batch_y.float())
|
||||||
return {'loss': loss, 'log': dict(loss=loss)}
|
return {'loss': loss, 'log': dict(loss=loss)}
|
||||||
|
|
||||||
|
def test_step(self, batch_xy, **kwargs):
|
||||||
|
batch_x, batch_y = batch_xy
|
||||||
|
pred_y = self(batch_x)
|
||||||
|
return dict(prediction=pred_y, label=batch_y)
|
||||||
|
|
||||||
|
def test_end(self, outputs):
|
||||||
|
evaluation = ROCEvaluation()
|
||||||
|
predictions = torch.stack([x['prediction'] for x in outputs])
|
||||||
|
labels = torch.stack([x['label'] for x in outputs])
|
||||||
|
|
||||||
|
scores = evaluation(predictions.numpy(), labels.numpy())
|
||||||
|
self.logger.log_metrics()
|
||||||
|
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
def __init__(self, *params):
|
def __init__(self, *params):
|
||||||
super(ConvHomDetector, self).__init__(*params)
|
super(ConvHomDetector, self).__init__(*params)
|
||||||
|
|
||||||
@ -70,6 +88,26 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
|
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
|
||||||
self.out_activation = nn.Sigmoid() # nn.Softmax
|
self.out_activation = nn.Sigmoid() # nn.Softmax
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
# ================================================================================
|
||||||
|
# Train Dataloader
|
||||||
|
def train_dataloader(self):
|
||||||
|
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||||
|
batch_size=self.hparams.data_param.batchsize,
|
||||||
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
|
# Test Dataloader
|
||||||
|
def test_dataloader(self):
|
||||||
|
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||||
|
batch_size=self.hparams.data_param.batchsize,
|
||||||
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
|
# Validation Dataloader
|
||||||
|
def val_dataloader(self):
|
||||||
|
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||||
|
batch_size=self.hparams.data_param.batchsize,
|
||||||
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
tensor = self.map_conv_0(x)
|
tensor = self.map_conv_0(x)
|
||||||
tensor = self.map_res_1(tensor)
|
tensor = self.map_res_1(tensor)
|
||||||
|
@ -105,24 +105,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
torch.save(self.__class__, f)
|
torch.save(self.__class__, f)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@pl.data_loader
|
|
||||||
def train_dataloader(self):
|
|
||||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
|
||||||
batch_size=self.hparams.data_param.batchsize,
|
|
||||||
num_workers=self.hparams.data_param.worker)
|
|
||||||
|
|
||||||
@pl.data_loader
|
|
||||||
def test_dataloader(self):
|
|
||||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
|
||||||
batch_size=self.hparams.data_param.batchsize,
|
|
||||||
num_workers=self.hparams.data_param.worker)
|
|
||||||
|
|
||||||
@pl.data_loader
|
|
||||||
def val_dataloader(self):
|
|
||||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
|
||||||
batch_size=self.hparams.data_param.batchsize,
|
|
||||||
num_workers=self.hparams.data_param.worker)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_len(self):
|
def data_len(self):
|
||||||
return len(self.dataset.train_dataset)
|
return len(self.dataset.train_dataset)
|
||||||
|
@ -79,20 +79,37 @@ class Config(ConfigParser):
|
|||||||
super(Config, self).__init__(**kwargs)
|
super(Config, self).__init__(**kwargs)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sort_combined_section_key_mapping(dict_obj):
|
||||||
|
sorted_dict = defaultdict(dict)
|
||||||
|
for key in dict_obj:
|
||||||
|
section, *attr_name = key.split('_')
|
||||||
|
attr_name = '_'.join(attr_name)
|
||||||
|
value = str(dict_obj[key])
|
||||||
|
|
||||||
|
sorted_dict[section][attr_name] = value
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
return dict(sorted_dict)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def read_namespace(cls, namespace: Namespace):
|
def read_namespace(cls, namespace: Namespace):
|
||||||
|
|
||||||
space_dict = defaultdict(dict)
|
sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__)
|
||||||
for key in namespace.__dict__:
|
|
||||||
section, *attr_name = key.split('_')
|
|
||||||
attr_name = '_'.join(attr_name)
|
|
||||||
value = str(namespace.__getattribute__(key))
|
|
||||||
|
|
||||||
space_dict[section][attr_name] = value
|
|
||||||
new_config = cls()
|
new_config = cls()
|
||||||
new_config.read_dict(space_dict)
|
new_config.read_dict(sorted_dict)
|
||||||
return new_config
|
return new_config
|
||||||
|
|
||||||
|
def update(self, mapping):
|
||||||
|
sorted_dict = self._sort_combined_section_key_mapping(mapping)
|
||||||
|
for section in sorted_dict:
|
||||||
|
if self.has_section(section):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.add_section(section)
|
||||||
|
for option, value in sorted_dict[section].items():
|
||||||
|
self.set(section, option, value)
|
||||||
|
return self
|
||||||
|
|
||||||
def get(self, *args, **kwargs):
|
def get(self, *args, **kwargs):
|
||||||
item = super(Config, self).get(*args, **kwargs)
|
item = super(Config, self).get(*args, **kwargs)
|
||||||
try:
|
try:
|
||||||
@ -108,5 +125,4 @@ class Config(ConfigParser):
|
|||||||
|
|
||||||
with path.open('w') as configfile:
|
with path.open('w') as configfile:
|
||||||
super().write(configfile)
|
super().write(configfile)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -78,6 +78,14 @@ class Logger(LightningLoggerBase):
|
|||||||
def log_config_as_ini(self):
|
def log_config_as_ini(self):
|
||||||
self.config.write(self.log_dir)
|
self.config.write(self.log_dir)
|
||||||
|
|
||||||
|
def log_metric(self, metric_name, metric_value, **kwargs):
|
||||||
|
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
|
||||||
|
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
|
||||||
|
|
||||||
|
def log_image(self, name, image, **kwargs):
|
||||||
|
self.neptunelogger.log_image(name, image, **kwargs)
|
||||||
|
image.savefig(self.log_dir / name)
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
self.testtubelogger.save()
|
self.testtubelogger.save()
|
||||||
self.neptunelogger.save()
|
self.neptunelogger.save()
|
||||||
|
38
main.py
38
main.py
@ -3,18 +3,21 @@
|
|||||||
import os
|
import os
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser, Namespace
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from lib.modules.utils import LightningBaseModule
|
from lib.modules.utils import LightningBaseModule
|
||||||
from lib.utils.config import Config
|
from lib.utils.config import Config
|
||||||
from lib.utils.logging import Logger
|
from lib.utils.logging import Logger
|
||||||
|
|
||||||
|
from lib.evaluation.classification import ROCEvaluation
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
@ -31,7 +34,7 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help=
|
|||||||
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
||||||
|
|
||||||
# Data Parameters
|
# Data Parameters
|
||||||
main_arg_parser.add_argument("--data_worker", type=int, default=0, help="")
|
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||||
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
||||||
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
|
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
|
||||||
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
|
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
|
||||||
@ -61,16 +64,15 @@ main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', hel
|
|||||||
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
|
main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
|
||||||
|
|
||||||
# Parse it
|
# Parse it
|
||||||
args = main_arg_parser.parse_args()
|
args: Namespace = main_arg_parser.parse_args()
|
||||||
config = Config.read_namespace(args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def run_lightning_loop(config_obj):
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
# =============================================================================
|
# ================================================================================
|
||||||
# Logger
|
# Logger
|
||||||
with Logger(config) as logger:
|
with Logger(config_obj) as logger:
|
||||||
# Callbacks
|
# Callbacks
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Checkpoint Saving
|
# Checkpoint Saving
|
||||||
@ -90,12 +92,12 @@ if __name__ == "__main__":
|
|||||||
# Model
|
# Model
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Init
|
# Init
|
||||||
model: LightningBaseModule = config.model_class(config.model_paramters)
|
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
trainer = Trainer(max_epochs=config.train.epochs,
|
trainer = Trainer(max_epochs=config_obj.train.epochs,
|
||||||
show_progress_bar=True,
|
show_progress_bar=True,
|
||||||
weights_save_path=logger.log_dir,
|
weights_save_path=logger.log_dir,
|
||||||
gpus=[0] if torch.cuda.is_available() else None,
|
gpus=[0] if torch.cuda.is_available() else None,
|
||||||
@ -103,15 +105,23 @@ if __name__ == "__main__":
|
|||||||
log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
|
log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
|
||||||
checkpoint_callback=checkpoint_callback,
|
checkpoint_callback=checkpoint_callback,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
fast_dev_run=config.main.debug,
|
fast_dev_run=config_obj.main.debug,
|
||||||
early_stop_callback=None
|
early_stop_callback=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Train it
|
# Train It
|
||||||
trainer.fit(model)
|
trainer.fit(model,)
|
||||||
|
|
||||||
# Save the last state & all parameters
|
# Save the last state & all parameters
|
||||||
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
||||||
model.save_to_disk(logger.log_dir)
|
model.save_to_disk(logger.log_dir)
|
||||||
pass
|
|
||||||
# TODO: Eval here!
|
# Evaluate It
|
||||||
|
trainer.test()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
config = Config.read_namespace(args)
|
||||||
|
trained_model = run_lightning_loop(config)
|
||||||
|
27
multi_run.py
27
multi_run.py
@ -7,29 +7,26 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
|||||||
|
|
||||||
# Imports
|
# Imports
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
from pathlib import Path
|
|
||||||
import os
|
from main import run_training, args
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
# Model Settings
|
# Model Settings
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
config = Config().read_namespace(args)
|
||||||
# use_bias, activation, model, use_norm, max_epochs, filters
|
# use_bias, activation, model, use_norm, max_epochs, filters
|
||||||
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
|
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
|
||||||
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
|
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
|
||||||
|
|
||||||
# Data Settings
|
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
|
||||||
data_shortcodes = ['mid', 'mid_5']
|
for seed in range(5):
|
||||||
|
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
|
||||||
|
model_use_bias=use_bias, model_use_norm=use_norm,
|
||||||
|
model_activation=activation, model_type=model,
|
||||||
|
model_filters=filters,
|
||||||
|
data_batch_size=512)
|
||||||
|
|
||||||
# Iteration over
|
config = config.update(arg_dict)
|
||||||
for data_shortcode in data_shortcodes:
|
|
||||||
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
|
|
||||||
for seed in range(5):
|
|
||||||
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
|
|
||||||
model_use_bias=use_bias, model_use_norm=use_norm,
|
|
||||||
model_activation=activation, model_type=model,
|
|
||||||
model_filters=filters,
|
|
||||||
data_batch_size=512)
|
|
||||||
|
|
||||||
os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')
|
run_training(config)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user