From aea34de964616fbbffb42495696e0d9d029c4128 Mon Sep 17 00:00:00 2001
From: Si11ium <steffen.illium@ifi.lmu.de>
Date: Fri, 19 Jun 2020 15:37:43 +0200
Subject: [PATCH] dataset fixing

---
 point_toolset/point_io.py | 6 +++---
 utils/model_io.py         | 4 +++-
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/point_toolset/point_io.py b/point_toolset/point_io.py
index 5fe37c0..1253e59 100644
--- a/point_toolset/point_io.py
+++ b/point_toolset/point_io.py
@@ -11,9 +11,9 @@ class BatchToData(object):
         # Convert to torch_geometric.data.Data type
 
         batch_pos = batch_dict['pos']
-        batch_norm = batch_dict['norm']
-        batch_y = batch_dict['y']
-        batch_y_c = batch_dict['y_c']
+        batch_norm = batch_dict.get('norm', None)
+        batch_y = batch_dict.get('y', None)
+        batch_y_c = batch_dict.get('y_c', None)
 
         batch_size, num_points, _ = batch_pos.shape  # (batch_size, num_points, 3)
 
diff --git a/utils/model_io.py b/utils/model_io.py
index fc82c4a..8724c11 100644
--- a/utils/model_io.py
+++ b/utils/model_io.py
@@ -62,6 +62,8 @@ class ModelParameters(Namespace, Mapping):
     )
 
     def __init__(self, parameter_mapping):
+        if isinstance(parameter_mapping, Namespace):
+            parameter_mapping = parameter_mapping.__dict__
         super(ModelParameters, self).__init__(**parameter_mapping)
 
 
@@ -80,7 +82,7 @@ class SavedLightningModels(object):
             model = torch.load(models_root_path / 'model_class.obj')
         assert model is not None
 
-        return cls(weights=checkpoint_path, model=model)
+        return cls(weights=str(checkpoint_path), model=model)
 
     def __init__(self, **kwargs):
         self.weights: str = kwargs.get('weights', '')