Debugging Validation and testing
This commit is contained in:
parent
4ae333fe5d
commit
6b9696c98e
2
.idea/.gitignore
generated
vendored
2
.idea/.gitignore
generated
vendored
@ -1,2 +0,0 @@
|
|||||||
# Default ignored files
|
|
||||||
/workspace.xml
|
|
22
.idea/deployment.xml
generated
22
.idea/deployment.xml
generated
@ -1,22 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22" showAutoUploadSettingsWarning="false">
|
|
||||||
<serverData>
|
|
||||||
<paths name="erlowa@aimachine">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="steffen@aimachine:22">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
</serverData>
|
|
||||||
<option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
23
.idea/dictionaries/steffen.xml
generated
23
.idea/dictionaries/steffen.xml
generated
@ -1,23 +0,0 @@
|
|||||||
<component name="ProjectDictionaryState">
|
|
||||||
<dictionary name="steffen">
|
|
||||||
<words>
|
|
||||||
<w>autopad</w>
|
|
||||||
<w>conv</w>
|
|
||||||
<w>convolutional</w>
|
|
||||||
<w>dataloader</w>
|
|
||||||
<w>dataloaders</w>
|
|
||||||
<w>datasets</w>
|
|
||||||
<w>homotopic</w>
|
|
||||||
<w>hparams</w>
|
|
||||||
<w>hyperparamter</w>
|
|
||||||
<w>kingma</w>
|
|
||||||
<w>logvar</w>
|
|
||||||
<w>mapname</w>
|
|
||||||
<w>mapnames</w>
|
|
||||||
<w>numlayers</w>
|
|
||||||
<w>reparameterize</w>
|
|
||||||
<w>softmax</w>
|
|
||||||
<w>traj</w>
|
|
||||||
</words>
|
|
||||||
</dictionary>
|
|
||||||
</component>
|
|
8
.idea/hom_traj_gen.iml
generated
8
.idea/hom_traj_gen.iml
generated
@ -1,8 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<module type="PYTHON_MODULE" version="4">
|
|
||||||
<component name="NewModuleRootManager">
|
|
||||||
<content url="file://$MODULE_DIR$" />
|
|
||||||
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
|
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
|
||||||
</component>
|
|
||||||
</module>
|
|
7
.idea/inspectionProfiles/profiles_settings.xml
generated
7
.idea/inspectionProfiles/profiles_settings.xml
generated
@ -1,7 +0,0 @@
|
|||||||
<component name="InspectionProjectProfileManager">
|
|
||||||
<settings>
|
|
||||||
<option name="PROJECT_PROFILE" value="Default" />
|
|
||||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
|
||||||
<version value="1.0" />
|
|
||||||
</settings>
|
|
||||||
</component>
|
|
10
.idea/misc.xml
generated
10
.idea/misc.xml
generated
@ -1,10 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="JavaScriptSettings">
|
|
||||||
<option name="languageLevel" value="ES6" />
|
|
||||||
</component>
|
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@ai-machine" project-jdk-type="Python SDK" />
|
|
||||||
<component name="PyPackaging">
|
|
||||||
<option name="earlyReleasesAsUpgrades" value="true" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
8
.idea/modules.xml
generated
8
.idea/modules.xml
generated
@ -1,8 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="ProjectModuleManager">
|
|
||||||
<modules>
|
|
||||||
<module fileurl="file://$PROJECT_DIR$/.idea/hom_traj_gen.iml" filepath="$PROJECT_DIR$/.idea/hom_traj_gen.iml" />
|
|
||||||
</modules>
|
|
||||||
</component>
|
|
||||||
</project>
|
|
6
.idea/vcs.xml
generated
6
.idea/vcs.xml
generated
@ -1,6 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="VcsDirectoryMappings">
|
|
||||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
15
.idea/webResources.xml
generated
15
.idea/webResources.xml
generated
@ -1,15 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="WebResourcesPaths">
|
|
||||||
<contentEntries>
|
|
||||||
<entry url="file://$PROJECT_DIR$">
|
|
||||||
<entryData>
|
|
||||||
<resourceRoots>
|
|
||||||
<path value="file://$PROJECT_DIR$/res" />
|
|
||||||
<path value="file://$PROJECT_DIR$/data" />
|
|
||||||
</resourceRoots>
|
|
||||||
</entryData>
|
|
||||||
</entry>
|
|
||||||
</contentEntries>
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,3 +1,5 @@
|
|||||||
|
from statistics import mean
|
||||||
|
|
||||||
from random import choice
|
from random import choice
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -65,13 +67,12 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
def validation_epoch_end(self, outputs):
|
||||||
evaluation = ROCEvaluation(plot_roc=True)
|
evaluation = ROCEvaluation(plot_roc=True)
|
||||||
predictions = torch.cat([x['prediction'] for x in outputs])
|
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
||||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||||
losses = torch.cat([x['discriminated_bce_loss'] for x in outputs]).unsqueeze(1)
|
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
|
||||||
mean_losses = losses.mean()
|
|
||||||
|
|
||||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
|
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
||||||
# self.logger.log_metrics(score_dict)
|
# self.logger.log_metrics(score_dict)
|
||||||
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
|
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
|
||||||
plt.clf()
|
plt.clf()
|
||||||
@ -103,7 +104,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
|
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
|
||||||
length=self.hparams.train_param.batch_size * 1000)
|
length=self.hparams.data_param.dataset_length)
|
||||||
|
|
||||||
# Additional Attributes
|
# Additional Attributes
|
||||||
self.in_shape = self.dataset.map_shapes_max
|
self.in_shape = self.dataset.map_shapes_max
|
||||||
@ -159,6 +160,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
|
|
||||||
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
|
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Mixed Encoder
|
||||||
|
self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Variational Bottleneck
|
# Variational Bottleneck
|
||||||
self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
|
self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
|
||||||
@ -242,7 +247,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
return z, mu, logvar
|
return z, mu, logvar
|
||||||
|
|
||||||
def generate_random(self, n=6):
|
def generate_random(self, n=6):
|
||||||
maps = [self.map_storage[choice(self.map_storage.keys())] for _ in range(n)]
|
maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)]
|
||||||
trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
|
trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
|
||||||
maps = torch.stack([x.as_2d_array for x in maps] * 2)
|
maps = torch.stack([x.as_2d_array for x in maps] * 2)
|
||||||
labels = torch.as_tensor([0] * n + [1] * n)
|
labels = torch.as_tensor([0] * n + [1] * n)
|
||||||
|
@ -57,7 +57,8 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
# Model Parameters
|
# Model Parameters
|
||||||
self.in_shape = self.dataset.map_shapes_max
|
self.in_shape = self.dataset.map_shapes_max
|
||||||
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
||||||
self.criterion = nn.BCEWithLogitsLoss()
|
self.criterion = nn.BCELoss()
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
# NN Nodes
|
# NN Nodes
|
||||||
# ============================
|
# ============================
|
||||||
@ -100,4 +101,5 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
tensor = self.flatten(tensor)
|
tensor = self.flatten(tensor)
|
||||||
tensor = self.linear(tensor)
|
tensor = self.linear(tensor)
|
||||||
tensor = self.classifier(tensor)
|
tensor = self.classifier(tensor)
|
||||||
|
tensor = self.sigmoid(tensor)
|
||||||
return tensor
|
return tensor
|
||||||
|
@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
# Data loading
|
# Data loading
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Map Object
|
# Map Object
|
||||||
self.map_storage = MapStorage(self.hparams.data_param.map_root)
|
self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True)
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return self.shape
|
return self.shape
|
||||||
@ -143,19 +143,19 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
# Train Dataloader
|
# Train Dataloader
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||||
batch_size=self.hparams.data_param.batchsize,
|
batch_size=self.hparams.train_param.batch_size,
|
||||||
num_workers=self.hparams.data_param.worker)
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
# Test Dataloader
|
# Test Dataloader
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||||
batch_size=self.hparams.data_param.batchsize,
|
batch_size=self.hparams.train_param.batch_size,
|
||||||
num_workers=self.hparams.data_param.worker)
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
# Validation Dataloader
|
# Validation Dataloader
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||||
batch_size=self.hparams.data_param.batchsize,
|
batch_size=self.hparams.train_param.batch_size,
|
||||||
num_workers=self.hparams.data_param.worker)
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
|
|
||||||
|
@ -167,6 +167,10 @@ class Map(object):
|
|||||||
|
|
||||||
class MapStorage(object):
|
class MapStorage(object):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def keys(self):
|
||||||
|
return list(self.data.keys())
|
||||||
|
|
||||||
def __init__(self, map_root, load_all=False):
|
def __init__(self, map_root, load_all=False):
|
||||||
self.data = dict()
|
self.data = dict()
|
||||||
self.map_root = Path(map_root)
|
self.map_root = Path(map_root)
|
||||||
@ -175,11 +179,11 @@ class MapStorage(object):
|
|||||||
_ = self[map_file.name]
|
_ = self[map_file.name]
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
if item in hasattr(self, item):
|
if item in self.data.keys():
|
||||||
return self.__getattribute__(item)
|
return self.data.get(item)
|
||||||
else:
|
else:
|
||||||
with shelve.open(self.map_root / f'{item}.pik', flag='r') as d:
|
current_map = Map().from_image(self.map_root / item)
|
||||||
self.__setattr__(item, d['map']['map'])
|
self.data.__setitem__(item, np.asarray(current_map))
|
||||||
return self[item]
|
return self[item]
|
||||||
|
|
||||||
|
|
||||||
|
2
main.py
2
main.py
@ -33,6 +33,7 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
|||||||
|
|
||||||
# Data Parameters
|
# Data Parameters
|
||||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||||
|
main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
|
||||||
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||||
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
|
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
|
||||||
|
|
||||||
@ -105,6 +106,7 @@ def run_lightning_loop(config_obj):
|
|||||||
show_progress_bar=True,
|
show_progress_bar=True,
|
||||||
weights_save_path=logger.log_dir,
|
weights_save_path=logger.log_dir,
|
||||||
gpus=[0] if torch.cuda.is_available() else None,
|
gpus=[0] if torch.cuda.is_available() else None,
|
||||||
|
check_val_every_n_epoch=1,
|
||||||
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
||||||
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
||||||
checkpoint_callback=checkpoint_callback,
|
checkpoint_callback=checkpoint_callback,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user