Merge remote-tracking branch 'origin/master'

Author: Si11ium
Date:   2020-03-05 09:58:50 +01:00
8 changed files with 125 additions and 63 deletions

.idea/deployment.xml (generated)

@@ -1,15 +1,8 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
+  <component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22">
     <serverData>
-      <paths name="ErLoWa-AiMachine">
-        <serverdata>
-          <mappings>
-            <mapping local="$PROJECT_DIR$" web="/" />
-          </mappings>
-        </serverdata>
-      </paths>
-      <paths name="traj_gen-AiMachine">
+      <paths name="steffen@aimachine:22">
         <serverdata>
           <mappings>
             <mapping deploy="/" local="$PROJECT_DIR$" web="/" />

(IDE module .iml file, name not shown)

@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>

.idea/misc.xml (generated)

@@ -3,5 +3,5 @@
   <component name="JavaScriptSettings">
     <option name="languageLevel" value="ES6" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
 </project>

(dataset module, name not shown)

@@ -3,7 +3,6 @@ from pathlib import Path
 from typing import Union, List

 import torch
-from random import choice
 from torch.utils.data import ConcatDataset, Dataset

 from lib.objects.map import Map
@@ -17,12 +16,14 @@ class TrajDataset(Dataset):
         return self.map.as_array.shape

     def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
-                 length=100000, all_in_map=True, embedding_size=None, **kwargs):
+                 length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
         super(TrajDataset, self).__init__()
+        self.preserve_equal_samples = preserve_equal_samples
         self.all_in_map = all_in_map
         self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
         self.maps_root = maps_root
         self._len = length
+        self.last_label = -1

         self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
@@ -31,8 +32,16 @@ class TrajDataset(Dataset):
     def __getitem__(self, item):
         trajectory = self.map.get_random_trajectory()
-        alternative = self.map.generate_alternative(trajectory)
-        label = choice([0, 1])
+        while True:
+            # TODO: Sanity Check this while true loop...
+            alternative = self.map.generate_alternative(trajectory)
+            label = self.map.are_homotopic(trajectory, alternative)
+            if self.preserve_equal_samples and label == self.last_label:
+                continue
+            else:
+                break
+
+        self.last_label = label
         if self.all_in_map:
             blank_trajectory_space = torch.zeros(self.map.shape)
             blank_alternative_space = torch.zeros(self.map.shape)
@@ -41,7 +50,6 @@ class TrajDataset(Dataset):
                 blank_alternative_space[index] = 1

             map_array = torch.as_tensor(self.map.as_array).float()
-            label = self.map.are_homotopic(trajectory, alternative)
             return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
         else:
             return trajectory.vertices, alternative.vertices, label, self.mapname
@@ -78,7 +86,8 @@ class TrajData(object):
         # find max image size among available maps:
         max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
         return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
-                                          all_in_map=self.all_in_map, embedding_size=max_map_size)
+                                          all_in_map=self.all_in_map, embedding_size=max_map_size,
+                                          preserve_equal_samples=True)
                               for map_file in map_files])

     @property
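Note: the new preserve_equal_samples flag implements rejection sampling on the label stream; a freshly drawn trajectory pair is discarded whenever its homotopy label repeats the previous sample's, so consecutive samples alternate labels. A minimal standalone sketch of the idea, with a hypothetical draw_pair() standing in for the get_random_trajectory / generate_alternative / are_homotopic calls above:

import random

def draw_pair():
    # Hypothetical stand-in for the map calls above: returns a
    # (sample, label) tuple whose label distribution is skewed.
    label = int(random.random() < 0.9)   # mostly label 1, as a stress test
    return object(), label

def balanced_stream(n, preserve_equal_samples=True):
    # Yield n samples, redrawing whenever the label repeats the last one.
    last_label = -1
    for _ in range(n):
        while True:
            sample, label = draw_pair()
            if preserve_equal_samples and label == last_label:
                continue   # same label as before: reject and redraw
            break
        last_label = label
        yield sample, label

print([label for _, label in balanced_stream(10)])   # strictly alternating 0/1

As the TODO in the hunk hints, the loop can spin for many iterations when one label is rare, since every repeated draw is rejected until the minority label appears.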

(logger module, name not shown)

@@ -18,9 +18,7 @@ class Logger(LightningLoggerBase):
     @property
     def log_dir(self):
-        if self.debug:
-            return Path(self.outpath)
-        return Path(self.experiment.log_dir).parent
+        return Path(self.testtubelogger.experiment.get_logdir()).parent

     @property
     def name(self):
@@ -32,14 +30,14 @@
     @property
     def version(self):
-        return f"version_{self.config.get('main', 'seed')}"
+        return self.config.get('main', 'seed')

     @property
     def outpath(self):
         # ToDo: Add further path modification such as dataset config etc.
         return Path(self.config.train.outpath)

-    def __init__(self, config: Config, debug=False):
+    def __init__(self, config: Config):
         """
         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
             Parameters are displayed in the experiments Parameters section and each key-value pair can be
@@ -53,8 +51,8 @@
         """
         super(Logger, self).__init__()
-        self.debug = debug
         self.config = config
+        self.debug = self.config.main.debug
         self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
         self._neptune_kwargs = dict(offline_mode=self.debug,
                                     api_key=self.config.project.neptune_key,
@@ -68,10 +66,30 @@
         self.testtubelogger.log_hyperparams(params)
         pass

-    def log_metrics(self, metrics, step_num):
-        self.neptunelogger.log_metrics(metrics, step_num)
-        self.testtubelogger.log_metrics(metrics, step_num)
+    def log_metrics(self, metrics, step=None):
+        self.neptunelogger.log_metrics(metrics, step=step)
+        self.testtubelogger.log_metrics(metrics, step=step)
         pass

     def close(self):
         self.testtubelogger.close()
+        self.neptunelogger.close()
+
+    def log_config_as_ini(self):
+        self.config.write(self.log_dir)
+
+    def save(self):
+        self.testtubelogger.save()
+        self.neptunelogger.save()
+
+    def finalize(self, status):
+        self.testtubelogger.finalize()
+        self.neptunelogger.finalize()
+        self.log_config_as_ini()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.finalize('success')
+        pass

main.py

@@ -9,7 +9,7 @@ import warnings
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

 from lib.modules.utils import LightningBaseModule
 from lib.utils.config import Config
@@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
 main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")

 # Model
@@ -64,21 +64,29 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 args = main_arg_parser.parse_args()
 config = Config.read_namespace(args)

-# Logger
-# =============================================================================
-logger = Logger(config, debug=True)
-
-# Checkpoint Callback
-# =============================================================================
-checkpoint_callback = ModelCheckpoint(
-    filepath=str(logger.log_dir / 'ckpt_weights'),
-    verbose=True,
-    period=1
-)
-
 if __name__ == "__main__":
+    # Logging
+    # =============================================================================
+    # Logger
+    with Logger(config) as logger:
+        # Callbacks
+        # =============================================================================
+        # Checkpoint Saving
+        checkpoint_callback = ModelCheckpoint(
+            filepath=str(logger.log_dir / 'ckpt_weights'),
+            verbose=True, save_top_k=5,
+        )
+
+        # =============================================================================
+        # Early Stopping
+        # TODO: For This to work, one must set a validation step and End Eval and Score
+        early_stopping_callback = EarlyStopping(
+            monitor='val_loss',
+            min_delta=0.0,
+            patience=0,
+        )

         # Model
         # =============================================================================
         # Init
@@ -87,15 +95,15 @@
         # Trainer
         # =============================================================================
-        trainer = Trainer(max_nb_epochs=config.train.epochs,
+        trainer = Trainer(max_epochs=config.train.epochs,
                           show_progress_bar=True,
                           weights_save_path=logger.log_dir,
                           gpus=[0] if torch.cuda.is_available() else None,
-                          row_log_interval=model.data_len // 40, # TODO: Better Value / Setting
-                          log_save_interval=model.data_len // 10, # TODO: Better Value / Setting
+                          row_log_interval=(model.data_len * 0.01), # TODO: Better Value / Setting
+                          log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
                           checkpoint_callback=checkpoint_callback,
                           logger=logger,
-                          fast_dev_run=config.get('main', 'debug'),
+                          fast_dev_run=config.main.debug,
                           early_stop_callback=None
                           )
@@ -103,8 +111,7 @@
         trainer.fit(model)

         # Save the last state & all parameters
         config.exp_path.mkdir(parents=True, exist_ok=True) # Todo: do i need this?
         trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
         model.save_to_disk(logger.log_dir)
-        pass
+    # TODO: Eval here!
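Note: early_stopping_callback monitors val_loss, but nothing in the diff produces that metric yet (hence the TODO above), and the Trainer still receives early_stop_callback=None, so the callback is inert as committed. A sketch of the missing piece, assuming the pytorch-lightning hook names of that era (validation_step / validation_end) and a model whose forward returns class logits; this is not the repository's actual code:

import torch
from lib.modules.utils import LightningBaseModule

class ValidatingModule(LightningBaseModule):
    # Hypothetical subclass: supplies the 'val_loss' key that
    # EarlyStopping(monitor='val_loss') needs to observe.

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'val_loss': val_loss}

    def validation_end(self, outputs):
        # Aggregate the per-batch losses into the monitored key.
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': avg_loss}

A val_dataloader() returning a held-out split is also required, and the Trainer call would need early_stop_callback=early_stopping_callback for the callback to take effect.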

main_post.py (new, empty file)

multi_run.py (new file)

@@ -0,0 +1,35 @@
+import warnings
+
+from lib.utils.config import Config
+
+warnings.filterwarnings('ignore', category=FutureWarning)
+warnings.filterwarnings('ignore', category=UserWarning)
+
+# Imports
+# =============================================================================
+from pathlib import Path
+import os
+
+
+if __name__ == '__main__':
+    # Model Settings
+    warnings.filterwarnings('ignore', category=FutureWarning)
+    # use_bias, activation, model, use_norm, max_epochs, filters
+    cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
+    # use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
+
+    # Data Settings
+    data_shortcodes = ['mid', 'mid_5']
+
+    # Iteration over
+    for data_shortcode in data_shortcodes:
+        for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
+            for seed in range(5):
+                arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
+                                model_use_bias=use_bias, model_use_norm=use_norm,
+                                model_activation=activation, model_type=model,
+                                model_filters=filters,
+                                data_batch_size=512)
+                os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')
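Note: as committed, the f-string interpolates the repr of arg_dict (a Python dict literal), which main.py's argparse cannot parse. A sketch of the flag serialization presumably intended, reusing the committed values; the "--main_seed"-style flag spelling follows main.py's "--train_batch_size" pattern, and list-valued flags assume an nargs='+' argument, both unverified here:

import os

arg_dict = dict(main_seed=0, train_max_epochs=2,
                model_use_bias=True, model_use_norm=False,
                model_activation='leaky_relu', model_type='classifier_cnn',
                model_filters=[16, 32, 64],
                data_batch_size=512)

# One '--key value' pair per entry; list values become space-separated.
flags = []
for key, value in arg_dict.items():
    values = value if isinstance(value, list) else [value]
    flags += [f'--{key}'] + [str(v) for v in values]

command = '/home/steffen/envs/traj_gen/bin/python main.py ' + ' '.join(flags)
print(command)   # inspect the command line before dispatching
os.system(command)

With distutils' strtobool-typed flags, boolean values serialize cleanly as "True"/"False", since strtobool lowercases its input before matching.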