Merge remote-tracking branch 'origin/master'
This commit is contained in:
11
.idea/deployment.xml
generated
11
.idea/deployment.xml
generated
@ -1,15 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
|
||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22">
|
||||
<serverData>
|
||||
<paths name="ErLoWa-AiMachine">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="traj_gen-AiMachine">
|
||||
<paths name="steffen@aimachine:22">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
||||
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
||||
<component name="JavaScriptSettings">
|
||||
<option name="languageLevel" value="ES6" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
|
||||
</project>
|
@ -3,7 +3,6 @@ from pathlib import Path
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
from random import choice
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
|
||||
from lib.objects.map import Map
|
||||
@ -17,12 +16,14 @@ class TrajDataset(Dataset):
|
||||
return self.map.as_array.shape
|
||||
|
||||
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
||||
length=100000, all_in_map=True, embedding_size=None, **kwargs):
|
||||
length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
|
||||
super(TrajDataset, self).__init__()
|
||||
self.preserve_equal_samples = preserve_equal_samples
|
||||
self.all_in_map = all_in_map
|
||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||
self.maps_root = maps_root
|
||||
self._len = length
|
||||
self.last_label = -1
|
||||
|
||||
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
|
||||
|
||||
@ -31,8 +32,16 @@ class TrajDataset(Dataset):
|
||||
|
||||
def __getitem__(self, item):
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
while True:
|
||||
# TODO: Sanity Check this while true loop...
|
||||
alternative = self.map.generate_alternative(trajectory)
|
||||
label = choice([0, 1])
|
||||
label = self.map.are_homotopic(trajectory, alternative)
|
||||
if self.preserve_equal_samples and label == self.last_label:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
self.last_label = label
|
||||
if self.all_in_map:
|
||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||
blank_alternative_space = torch.zeros(self.map.shape)
|
||||
@ -41,7 +50,6 @@ class TrajDataset(Dataset):
|
||||
blank_alternative_space[index] = 1
|
||||
|
||||
map_array = torch.as_tensor(self.map.as_array).float()
|
||||
label = self.map.are_homotopic(trajectory, alternative)
|
||||
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
|
||||
else:
|
||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||
@ -78,7 +86,8 @@ class TrajData(object):
|
||||
# find max image size among available maps:
|
||||
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||
all_in_map=self.all_in_map, embedding_size=max_map_size)
|
||||
all_in_map=self.all_in_map, embedding_size=max_map_size,
|
||||
preserve_equal_samples=True)
|
||||
for map_file in map_files])
|
||||
|
||||
@property
|
||||
|
@ -18,9 +18,7 @@ class Logger(LightningLoggerBase):
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
if self.debug:
|
||||
return Path(self.outpath)
|
||||
return Path(self.experiment.log_dir).parent
|
||||
return Path(self.testtubelogger.experiment.get_logdir()).parent
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -32,14 +30,14 @@ class Logger(LightningLoggerBase):
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return f"version_{self.config.get('main', 'seed')}"
|
||||
return self.config.get('main', 'seed')
|
||||
|
||||
@property
|
||||
def outpath(self):
|
||||
# ToDo: Add further path modification such as dataset config etc.
|
||||
return Path(self.config.train.outpath)
|
||||
|
||||
def __init__(self, config: Config, debug=False):
|
||||
def __init__(self, config: Config):
|
||||
"""
|
||||
params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
|
||||
Parameters are displayed in the experiment’s Parameters section and each key-value pair can be
|
||||
@ -53,8 +51,8 @@ class Logger(LightningLoggerBase):
|
||||
"""
|
||||
super(Logger, self).__init__()
|
||||
|
||||
self.debug = debug
|
||||
self.config = config
|
||||
self.debug = self.config.main.debug
|
||||
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||
self._neptune_kwargs = dict(offline_mode=self.debug,
|
||||
api_key=self.config.project.neptune_key,
|
||||
@ -68,10 +66,30 @@ class Logger(LightningLoggerBase):
|
||||
self.testtubelogger.log_hyperparams(params)
|
||||
pass
|
||||
|
||||
def log_metrics(self, metrics, step_num):
|
||||
self.neptunelogger.log_metrics(metrics, step_num)
|
||||
self.testtubelogger.log_metrics(metrics, step_num)
|
||||
def log_metrics(self, metrics, step=None):
|
||||
self.neptunelogger.log_metrics(metrics, step=step)
|
||||
self.testtubelogger.log_metrics(metrics, step=step)
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self.testtubelogger.close()
|
||||
self.neptunelogger.close()
|
||||
|
||||
def log_config_as_ini(self):
|
||||
self.config.write(self.log_dir)
|
||||
|
||||
def save(self):
|
||||
self.testtubelogger.save()
|
||||
self.neptunelogger.save()
|
||||
|
||||
def finalize(self, status):
|
||||
self.testtubelogger.finalize()
|
||||
self.neptunelogger.finalize()
|
||||
self.log_config_as_ini()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.finalize('success')
|
||||
pass
|
||||
|
47
main.py
47
main.py
@ -9,7 +9,7 @@ import warnings
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
|
||||
from lib.modules.utils import LightningBaseModule
|
||||
from lib.utils.config import Config
|
||||
@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
|
||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||
|
||||
# Model
|
||||
@ -64,21 +64,29 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
|
||||
args = main_arg_parser.parse_args()
|
||||
config = Config.read_namespace(args)
|
||||
|
||||
# Logger
|
||||
# =============================================================================
|
||||
logger = Logger(config, debug=True)
|
||||
|
||||
# Checkpoint Callback
|
||||
# =============================================================================
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||
verbose=True,
|
||||
period=1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Logging
|
||||
# =============================================================================
|
||||
# Logger
|
||||
with Logger(config) as logger:
|
||||
# Callbacks
|
||||
# =============================================================================
|
||||
# Checkpoint Saving
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||
verbose=True, save_top_k=5,
|
||||
)
|
||||
# =============================================================================
|
||||
# Early Stopping
|
||||
# TODO: For This to work, one must set a validation step and End Eval and Score
|
||||
early_stopping_callback = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
min_delta=0.0,
|
||||
patience=0,
|
||||
)
|
||||
|
||||
# Model
|
||||
# =============================================================================
|
||||
# Init
|
||||
@ -87,15 +95,15 @@ if __name__ == "__main__":
|
||||
|
||||
# Trainer
|
||||
# =============================================================================
|
||||
trainer = Trainer(max_nb_epochs=config.train.epochs,
|
||||
trainer = Trainer(max_epochs=config.train.epochs,
|
||||
show_progress_bar=True,
|
||||
weights_save_path=logger.log_dir,
|
||||
gpus=[0] if torch.cuda.is_available() else None,
|
||||
row_log_interval=model.data_len // 40, # TODO: Better Value / Setting
|
||||
log_save_interval=model.data_len // 10, # TODO: Better Value / Setting
|
||||
row_log_interval=(model.data_len * 0.01), # TODO: Better Value / Setting
|
||||
log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
logger=logger,
|
||||
fast_dev_run=config.get('main', 'debug'),
|
||||
fast_dev_run=config.main.debug,
|
||||
early_stop_callback=None
|
||||
)
|
||||
|
||||
@ -103,8 +111,7 @@ if __name__ == "__main__":
|
||||
trainer.fit(model)
|
||||
|
||||
# Save the last state & all parameters
|
||||
config.exp_path.mkdir(parents=True, exist_ok=True) # Todo: do i need this?
|
||||
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
||||
model.save_to_disk(logger.log_dir)
|
||||
|
||||
pass
|
||||
# TODO: Eval here!
|
||||
|
0
main_post.py
Normal file
0
main_post.py
Normal file
35
multi_run.py
Normal file
35
multi_run.py
Normal file
@ -0,0 +1,35 @@
|
||||
import warnings
|
||||
|
||||
from lib.utils.config import Config
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
# Imports
|
||||
# =============================================================================
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Model Settings
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
# use_bias, activation, model, use_norm, max_epochs, filters
|
||||
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
|
||||
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
|
||||
|
||||
# Data Settings
|
||||
data_shortcodes = ['mid', 'mid_5']
|
||||
|
||||
# Iteration over
|
||||
for data_shortcode in data_shortcodes:
|
||||
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
|
||||
for seed in range(5):
|
||||
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
|
||||
model_use_bias=use_bias, model_use_norm=use_norm,
|
||||
model_activation=activation, model_type=model,
|
||||
model_filters=filters,
|
||||
data_batch_size=512)
|
||||
|
||||
os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')
|
Reference in New Issue
Block a user