diff --git a/point_toolset/point_io.py b/point_toolset/point_io.py index 5fe37c0..1253e59 100644 --- a/point_toolset/point_io.py +++ b/point_toolset/point_io.py @@ -11,9 +11,9 @@ class BatchToData(object): # Convert to torch_geometric.data.Data type batch_pos = batch_dict['pos'] - batch_norm = batch_dict['norm'] - batch_y = batch_dict['y'] - batch_y_c = batch_dict['y_c'] + batch_norm = batch_dict.get('norm', None) + batch_y = batch_dict.get('y', None) + batch_y_c = batch_dict.get('y_c', None) batch_size, num_points, _ = batch_pos.shape # (batch_size, num_points, 3) diff --git a/utils/model_io.py b/utils/model_io.py index fc82c4a..8724c11 100644 --- a/utils/model_io.py +++ b/utils/model_io.py @@ -62,6 +62,8 @@ class ModelParameters(Namespace, Mapping): ) def __init__(self, parameter_mapping): + if isinstance(parameter_mapping, Namespace): + parameter_mapping = parameter_mapping.__dict__ super(ModelParameters, self).__init__(**parameter_mapping) @@ -80,7 +82,7 @@ class SavedLightningModels(object): model = torch.load(models_root_path / 'model_class.obj') assert model is not None - return cls(weights=checkpoint_path, model=model) + return cls(weights=str(checkpoint_path), model=model) def __init__(self, **kwargs): self.weights: str = kwargs.get('weights', '')