Merge remote-tracking branch 'origin/master'

.idea/deployment.xml (generated, 11 changed lines)

@@ -1,15 +1,8 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
+  <component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22">
     <serverData>
-      <paths name="ErLoWa-AiMachine">
-        <serverdata>
-          <mappings>
-            <mapping local="$PROJECT_DIR$" web="/" />
-          </mappings>
-        </serverdata>
-      </paths>
-      <paths name="traj_gen-AiMachine">
+      <paths name="steffen@aimachine:22">
         <serverdata>
           <mappings>
             <mapping deploy="/" local="$PROJECT_DIR$" web="/" />

.idea/hom_traj_gen.iml (generated, 2 changed lines)

@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>

.idea/misc.xml (generated, 2 changed lines)

@@ -3,5 +3,5 @@
   <component name="JavaScriptSettings">
     <option name="languageLevel" value="ES6" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
 </project>

@@ -3,7 +3,6 @@ from pathlib import Path
 from typing import Union, List
 
 import torch
-from random import choice
 from torch.utils.data import ConcatDataset, Dataset
 
 from lib.objects.map import Map
@@ -17,12 +16,14 @@ class TrajDataset(Dataset):
         return self.map.as_array.shape
 
     def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
-                 length=100000, all_in_map=True, embedding_size=None, **kwargs):
+                 length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
         super(TrajDataset, self).__init__()
+        self.preserve_equal_samples = preserve_equal_samples
         self.all_in_map = all_in_map
         self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
         self.maps_root = maps_root
         self._len = length
+        self.last_label = -1
 
         self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
 
@@ -31,8 +32,16 @@ class TrajDataset(Dataset):
 
     def __getitem__(self, item):
         trajectory = self.map.get_random_trajectory()
-        alternative = self.map.generate_alternative(trajectory)
-        label = choice([0, 1])
+        while True:
+            # TODO: Sanity Check this while true loop...
+            alternative = self.map.generate_alternative(trajectory)
+            label = self.map.are_homotopic(trajectory, alternative)
+            if self.preserve_equal_samples and label == self.last_label:
+                continue
+            else:
+                break
+
+        self.last_label = label
         if self.all_in_map:
             blank_trajectory_space = torch.zeros(self.map.shape)
             blank_alternative_space = torch.zeros(self.map.shape)
@@ -41,7 +50,6 @@ class TrajDataset(Dataset):
             blank_alternative_space[index] = 1
 
             map_array = torch.as_tensor(self.map.as_array).float()
-            label = self.map.are_homotopic(trajectory, alternative)
             return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
         else:
             return trajectory.vertices, alternative.vertices, label, self.mapname
@@ -78,7 +86,8 @@ class TrajData(object):
         # find max image size among available maps:
         max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
         return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
-                                          all_in_map=self.all_in_map, embedding_size=max_map_size)
+                                          all_in_map=self.all_in_map, embedding_size=max_map_size,
+                                          preserve_equal_samples=True)
                               for map_file in map_files])
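
Note on the new __getitem__ loop: it is rejection sampling in disguise. With preserve_equal_samples set, a freshly drawn alternative is discarded whenever its homotopy label repeats the previous sample's label, so consecutive samples alternate between homotopic and non-homotopic pairs. The in-code TODO is warranted: if the map can only ever produce one label, the loop never exits. Below is a minimal, self-contained sketch of the idea; the Map methods are stand-ins for lib.objects.map.Map, and the retry cap is an added assumption to bound the loop, not code from the repo.

from random import random

class _FakeMap:
    """Stand-in for lib.objects.map.Map, for demonstration only."""
    def get_random_trajectory(self):
        return object()

    def generate_alternative(self, trajectory):
        return object()

    def are_homotopic(self, trajectory, alternative):
        # Pretend roughly half of all generated pairs are homotopic.
        return random() < 0.5

def draw_balanced_sample(map_obj, last_label, max_retries=100):
    """Redraw alternatives until the label differs from the previous one."""
    trajectory = map_obj.get_random_trajectory()
    alternative = map_obj.generate_alternative(trajectory)
    label = map_obj.are_homotopic(trajectory, alternative)
    for _ in range(max_retries):
        if label != last_label:
            break
        alternative = map_obj.generate_alternative(trajectory)
        label = map_obj.are_homotopic(trajectory, alternative)
    # Falls back to the last draw instead of spinning forever.
    return trajectory, alternative, label

map_obj = _FakeMap()
last_label = -1
for _ in range(4):
    _, _, last_label = draw_balanced_sample(map_obj, last_label)
    print(last_label)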

@@ -18,9 +18,7 @@ class Logger(LightningLoggerBase):
 
     @property
     def log_dir(self):
-        if self.debug:
-            return Path(self.outpath)
-        return Path(self.experiment.log_dir).parent
+        return Path(self.testtubelogger.experiment.get_logdir()).parent
 
     @property
     def name(self):
@@ -32,14 +30,14 @@ class Logger(LightningLoggerBase):
 
     @property
     def version(self):
-        return f"version_{self.config.get('main', 'seed')}"
+        return self.config.get('main', 'seed')
 
     @property
     def outpath(self):
         # ToDo: Add further path modification such as dataset config etc.
         return Path(self.config.train.outpath)
 
-    def __init__(self, config: Config, debug=False):
+    def __init__(self, config: Config):
         """
         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
            Parameters are displayed in the experiment’s Parameters section and each key-value pair can be
@@ -53,8 +51,8 @@ class Logger(LightningLoggerBase):
         """
         super(Logger, self).__init__()
 
-        self.debug = debug
         self.config = config
+        self.debug = self.config.main.debug
         self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
         self._neptune_kwargs = dict(offline_mode=self.debug,
                                     api_key=self.config.project.neptune_key,
@@ -68,10 +66,30 @@ class Logger(LightningLoggerBase):
         self.testtubelogger.log_hyperparams(params)
         pass
 
-    def log_metrics(self, metrics, step_num):
-        self.neptunelogger.log_metrics(metrics, step_num)
-        self.testtubelogger.log_metrics(metrics, step_num)
+    def log_metrics(self, metrics, step=None):
+        self.neptunelogger.log_metrics(metrics, step=step)
+        self.testtubelogger.log_metrics(metrics, step=step)
         pass
 
+    def close(self):
+        self.testtubelogger.close()
+        self.neptunelogger.close()
+
     def log_config_as_ini(self):
         self.config.write(self.log_dir)
+
+    def save(self):
+        self.testtubelogger.save()
+        self.neptunelogger.save()
+
+    def finalize(self, status):
+        self.testtubelogger.finalize()
+        self.neptunelogger.finalize()
+        self.log_config_as_ini()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.finalize('success')
+        pass
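
Note on the Logger changes: the class now forwards save/finalize/close to both of its backends (test-tube and Neptune) and gains __enter__/__exit__, so main.py can drive it as a context manager and finalization is guaranteed when the with block ends. A stripped-down sketch of that pattern follows; the backend classes here are stand-ins, not the real test-tube or Neptune APIs.

class _Backend:
    """Stand-in for a test-tube or Neptune logger backend."""
    def __init__(self, name):
        self.name = name

    def save(self):
        print(f'{self.name}: saved')

    def finalize(self):
        print(f'{self.name}: finalized')

class CompositeLogger:
    def __init__(self):
        self.backends = [_Backend('testtube'), _Backend('neptune')]

    def save(self):
        for backend in self.backends:
            backend.save()

    def finalize(self, status):
        for backend in self.backends:
            backend.finalize()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Mirrors Logger.__exit__ in the diff: finalize unconditionally.
        self.finalize('success')

with CompositeLogger() as logger:
    logger.save()
# Both backends are finalized here, even if the block body raised.

One design wrinkle worth noting: __exit__ passes 'success' to finalize() unconditionally, so a run that dies with an exception is still finalized as successful; inspecting exc_type would allow reporting the real status.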

main.py (81 changed lines)

@@ -9,7 +9,7 @@ import warnings
 
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 
 from lib.modules.utils import LightningBaseModule
 from lib.utils.config import Config
@@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
 main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
 
 # Model
@@ -64,47 +64,54 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 args = main_arg_parser.parse_args()
 config = Config.read_namespace(args)
 
-# Logger
-# =============================================================================
-logger = Logger(config, debug=True)
-
-# Checkpoint Callback
-# =============================================================================
-checkpoint_callback = ModelCheckpoint(
-    filepath=str(logger.log_dir / 'ckpt_weights'),
-    verbose=True,
-    period=1
-)
-
 
 if __name__ == "__main__":
 
-    # Model
+    # Logging
     # =============================================================================
-    # Init
-    model: LightningBaseModule = config.model_class(config.model_paramters)
-    model.init_weights()
+    # Logger
+    with Logger(config) as logger:
+        # Callbacks
+        # =============================================================================
+        # Checkpoint Saving
+        checkpoint_callback = ModelCheckpoint(
+            filepath=str(logger.log_dir / 'ckpt_weights'),
+            verbose=True, save_top_k=5,
+        )
+        # =============================================================================
+        # Early Stopping
+        # TODO: For This to work, one must set a validation step and End Eval and Score
+        early_stopping_callback = EarlyStopping(
+            monitor='val_loss',
+            min_delta=0.0,
+            patience=0,
+        )
+
+        # Model
+        # =============================================================================
+        # Init
+        model: LightningBaseModule = config.model_class(config.model_paramters)
+        model.init_weights()
 
-    # Trainer
-    # =============================================================================
-    trainer = Trainer(max_nb_epochs=config.train.epochs,
-                      show_progress_bar=True,
-                      weights_save_path=logger.log_dir,
-                      gpus=[0] if torch.cuda.is_available() else None,
-                      row_log_interval=model.data_len // 40,  # TODO: Better Value / Setting
-                      log_save_interval=model.data_len // 10,  # TODO: Better Value / Setting
-                      checkpoint_callback=checkpoint_callback,
-                      logger=logger,
-                      fast_dev_run=config.get('main', 'debug'),
-                      early_stop_callback=None
-                      )
+        # Trainer
+        # =============================================================================
+        trainer = Trainer(max_epochs=config.train.epochs,
+                          show_progress_bar=True,
+                          weights_save_path=logger.log_dir,
+                          gpus=[0] if torch.cuda.is_available() else None,
+                          row_log_interval=(model.data_len * 0.01),  # TODO: Better Value / Setting
+                          log_save_interval=(model.data_len * 0.04),  # TODO: Better Value / Setting
+                          checkpoint_callback=checkpoint_callback,
+                          logger=logger,
+                          fast_dev_run=config.main.debug,
+                          early_stop_callback=None
+                          )
 
-    # Train it
-    trainer.fit(model)
+        # Train it
+        trainer.fit(model)
 
-    # Save the last state & all parameters
-    config.exp_path.mkdir(parents=True, exist_ok=True)  # Todo: do i need this?
-    trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
-    model.save_to_disk(logger.log_dir)
-
+        # Save the last state & all parameters
+        trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
+        model.save_to_disk(logger.log_dir)
+        pass
     # TODO: Eval here!
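
Note on the EarlyStopping callback: it is constructed but never wired up, since the Trainer still receives early_stop_callback=None, and the in-code TODO explains why: nothing defines val_loss yet. Under the pre-1.0 Lightning API this main.py appears to target (early_stop_callback, show_progress_bar, row_log_interval), the model would need validation hooks roughly like the sketch below; the loss function and batch layout are assumptions, not code from this repo, and a val_dataloader would also be required.

import torch
import torch.nn.functional as F

class ValidationHooksMixin:
    """Minimal validation hooks for a LightningModule of that era (sketch)."""

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # Aggregate per-batch losses into the single 'val_loss' that
        # EarlyStopping(monitor='val_loss') watches.
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}

With hooks like these in place, passing early_stop_callback=early_stopping_callback would activate the monitor.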

main_post.py (new file, empty)

multi_run.py (new file, 35 added lines)

@@ -0,0 +1,35 @@
+import warnings
+
+from lib.utils.config import Config
+
+warnings.filterwarnings('ignore', category=FutureWarning)
+warnings.filterwarnings('ignore', category=UserWarning)
+
+# Imports
+# =============================================================================
+from pathlib import Path
+import os
+
+
+if __name__ == '__main__':
+
+    # Model Settings
+    warnings.filterwarnings('ignore', category=FutureWarning)
+    # use_bias, activation, model, use_norm, max_epochs, filters
+    cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
+    # use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
+
+    # Data Settings
+    data_shortcodes = ['mid', 'mid_5']
+
+    # Iteration over
+    for data_shortcode in data_shortcodes:
+        for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
+            for seed in range(5):
+                arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
+                                model_use_bias=use_bias, model_use_norm=use_norm,
+                                model_activation=activation, model_type=model,
+                                model_filters=filters,
+                                data_batch_size=512)
+
+                os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')
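
Note on the final line of multi_run.py: the f-string pastes the repr of arg_dict (e.g. {'main_seed': 0, ...}) into the shell command, which the argparse setup in main.py cannot parse. One plausible fix, sketched under the assumption that main.py exposes one flag per key (e.g. --main_seed), is to flatten the dict into flag/value pairs and invoke the interpreter through subprocess with an argument list, which also sidesteps shell quoting.

import subprocess
import sys

arg_dict = dict(main_seed=0, train_max_epochs=2, data_batch_size=512)

# Build ['python', 'main.py', '--main_seed', '0', ...] instead of pasting
# the dict repr into a shell string.
cmd = [sys.executable, 'main.py']
for key, value in arg_dict.items():
    cmd += [f'--{key}', str(value)]

print(cmd)
# subprocess.run(cmd, check=True)  # uncomment where main.py is available

List-valued settings such as model_filters would still need their own encoding before they survive the round trip through the command line.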