train running dataset fixed
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -69,3 +69,4 @@ fabric.properties
|
||||
|
||||
# Android studio 3.1+ serialized cache file
|
||||
.idea/caches/build_file_checksums.ser
|
||||
/.idea/inspectionProfiles/
|
||||
|
11
.idea/deployment.xml
generated
11
.idea/deployment.xml
generated
@ -1,15 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine" showAutoUploadSettingsWarning="false">
|
||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22" showAutoUploadSettingsWarning="false">
|
||||
<serverData>
|
||||
<paths name="ErLoWa-AiMachine">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="traj_gen-AiMachine">
|
||||
<paths name="steffen@aimachine:22">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
||||
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="hom_traj_gen@aimachine" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
1
.idea/inspectionProfiles/profiles_settings.xml
generated
1
.idea/inspectionProfiles/profiles_settings.xml
generated
@ -1,5 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="PROJECT_PROFILE" value="Default" />
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
|
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
||||
<component name="JavaScriptSettings">
|
||||
<option name="languageLevel" value="ES6" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
|
||||
</project>
|
@ -87,7 +87,7 @@ class TrajData(object):
|
||||
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||
all_in_map=self.all_in_map, embedding_size=max_map_size,
|
||||
preserve_equal_samples=True)
|
||||
preserve_equal_samples=False)
|
||||
for map_file in map_files])
|
||||
|
||||
@property
|
||||
|
@ -11,6 +11,7 @@ from datasets.trajectory_dataset import TrajData
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
from lib.modules.blocks import ConvModule, ResidualModule
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class ConvHomDetector(LightningBaseModule):
|
||||
@ -36,10 +37,9 @@ class ConvHomDetector(LightningBaseModule):
|
||||
predictions = torch.stack([x['prediction'] for x in outputs])
|
||||
labels = torch.stack([x['label'] for x in outputs])
|
||||
|
||||
scores = evaluation(predictions.numpy(), labels.numpy())
|
||||
self.logger.log_metrics()
|
||||
|
||||
|
||||
scores = evaluation(predictions.numpy(), labels.numpy(), )
|
||||
self.logger.log_metrics({key:value for key, value in zip(['roc_auc', 'tpr', 'fpr'], scores)})
|
||||
self.logger.log_image(f'{self.name}', plt.gcf())
|
||||
pass
|
||||
|
||||
def __init__(self, *params):
|
||||
@ -88,6 +88,19 @@ class ConvHomDetector(LightningBaseModule):
|
||||
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
|
||||
self.out_activation = nn.Sigmoid() # nn.Softmax
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.map_conv_0(x)
|
||||
tensor = self.map_res_1(tensor)
|
||||
tensor = self.map_conv_1(tensor)
|
||||
tensor = self.map_res_2(tensor)
|
||||
tensor = self.map_conv_2(tensor)
|
||||
tensor = self.map_conv_3(tensor)
|
||||
tensor = self.flatten(tensor)
|
||||
tensor = self.linear(tensor)
|
||||
tensor = self.classifier(tensor)
|
||||
tensor = self.out_activation(tensor)
|
||||
return tensor
|
||||
|
||||
# Dataloaders
|
||||
# ================================================================================
|
||||
# Train Dataloader
|
||||
@ -107,16 +120,3 @@ class ConvHomDetector(LightningBaseModule):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.map_conv_0(x)
|
||||
tensor = self.map_res_1(tensor)
|
||||
tensor = self.map_conv_1(tensor)
|
||||
tensor = self.map_res_2(tensor)
|
||||
tensor = self.map_conv_2(tensor)
|
||||
tensor = self.map_conv_3(tensor)
|
||||
tensor = self.flatten(tensor)
|
||||
tensor = self.linear(tensor)
|
||||
tensor = self.classifier(tensor)
|
||||
tensor = self.out_activation(tensor)
|
||||
return tensor
|
||||
|
@ -45,7 +45,7 @@ class Map(object):
|
||||
|
||||
@property
|
||||
def as_2d_array(self):
|
||||
return self.map_array[1:]
|
||||
return self.map_array.squeeze()
|
||||
|
||||
def __init__(self, name='', array_like_map_representation=None):
|
||||
if array_like_map_representation is not None:
|
||||
@ -145,9 +145,9 @@ class Map(object):
|
||||
|
||||
img = Image.new('L', (self.height, self.width), 0)
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.polygon(polyline, outline=255, fill=255)
|
||||
draw.polygon(polyline, outline=1, fill=1)
|
||||
|
||||
a = (np.asarray(img) * np.where(self.as_2d_array == self.white, 0, 1)).sum()
|
||||
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.white, 1, 0)).sum()
|
||||
|
||||
if a:
|
||||
return False # Non-Homotoph
|
||||
@ -159,7 +159,7 @@ class Map(object):
|
||||
# The standard colormaps also all have reversed versions.
|
||||
# They have the same names with _r tacked on to the end.
|
||||
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
|
||||
img = ax.imshow(self.as_array, cmap='Greys_r')
|
||||
img = ax.imshow(self.as_2d_array, cmap='Greys_r')
|
||||
return dict(img=img, fig=fig, ax=ax)
|
||||
|
||||
|
||||
|
@ -14,7 +14,7 @@ class Trajectory(object):
|
||||
|
||||
@property
|
||||
def xy_vertices(self):
|
||||
return [(x,y) for _, x,y in self._vertices]
|
||||
return [(x, y) for _, y, x in self._vertices]
|
||||
|
||||
@property
|
||||
def endpoints(self):
|
||||
@ -30,11 +30,11 @@ class Trajectory(object):
|
||||
|
||||
@property
|
||||
def xs(self):
|
||||
return [x[1] for x in self._vertices]
|
||||
return [x[2] for x in self._vertices]
|
||||
|
||||
@property
|
||||
def ys(self):
|
||||
return [x[0] for x in self._vertices]
|
||||
return [x[1] for x in self._vertices]
|
||||
|
||||
@property
|
||||
def as_paired_list(self):
|
||||
@ -59,7 +59,7 @@ class Trajectory(object):
|
||||
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
|
||||
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative')
|
||||
if highlights:
|
||||
kwargs.update(marker='bo')
|
||||
kwargs.update(marker='o')
|
||||
fig, ax = plt.gcf(), plt.gca()
|
||||
img = plt.plot(self.xs, self.ys, **kwargs)
|
||||
return dict(img=img, fig=fig, ax=ax)
|
||||
|
@ -76,7 +76,7 @@ class Logger(LightningLoggerBase):
|
||||
self.neptunelogger.close()
|
||||
|
||||
def log_config_as_ini(self):
|
||||
self.config.write(self.log_dir)
|
||||
self.config.write(self.log_dir / 'config.ini')
|
||||
|
||||
def log_metric(self, metric_name, metric_value, **kwargs):
|
||||
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
|
||||
@ -91,8 +91,8 @@ class Logger(LightningLoggerBase):
|
||||
self.neptunelogger.save()
|
||||
|
||||
def finalize(self, status):
|
||||
self.testtubelogger.finalize()
|
||||
self.neptunelogger.finalize()
|
||||
self.testtubelogger.finalize(status)
|
||||
self.neptunelogger.finalize(status)
|
||||
self.log_config_as_ini()
|
||||
|
||||
def __enter__(self):
|
||||
|
15
multi_run.py
15
multi_run.py
@ -8,7 +8,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
||||
# Imports
|
||||
# =============================================================================
|
||||
|
||||
from main import run_training, args
|
||||
from main import run_lightning_loop, args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -16,17 +16,14 @@ if __name__ == '__main__':
|
||||
# Model Settings
|
||||
config = Config().read_namespace(args)
|
||||
# use_bias, activation, model, use_norm, max_epochs, filters
|
||||
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
|
||||
cnn_classifier = dict(train_epochs=100, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
|
||||
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
|
||||
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
|
||||
|
||||
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
|
||||
for arg_dict in [cnn_classifier]:
|
||||
for seed in range(5):
|
||||
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
|
||||
model_use_bias=use_bias, model_use_norm=use_norm,
|
||||
model_activation=activation, model_type=model,
|
||||
model_filters=filters,
|
||||
data_batch_size=512)
|
||||
arg_dict.update(main_seed=seed)
|
||||
|
||||
config = config.update(arg_dict)
|
||||
|
||||
run_training(config)
|
||||
run_lightning_loop(config)
|
||||
|
Reference in New Issue
Block a user