dataset fixing

This commit is contained in:
Si11ium 2020-06-19 15:37:43 +02:00
parent 3f8122484b
commit aea34de964
2 changed files with 6 additions and 4 deletions

View File

@ -11,9 +11,9 @@ class BatchToData(object):
# Convert to torch_geometric.data.Data type # Convert to torch_geometric.data.Data type
batch_pos = batch_dict['pos'] batch_pos = batch_dict['pos']
batch_norm = batch_dict['norm'] batch_norm = batch_dict.get('norm', None)
batch_y = batch_dict['y'] batch_y = batch_dict.get('y', None)
batch_y_c = batch_dict['y_c'] batch_y_c = batch_dict.get('y_c', None)
batch_size, num_points, _ = batch_pos.shape # (batch_size, num_points, 3) batch_size, num_points, _ = batch_pos.shape # (batch_size, num_points, 3)

View File

@ -62,6 +62,8 @@ class ModelParameters(Namespace, Mapping):
) )
def __init__(self, parameter_mapping): def __init__(self, parameter_mapping):
if isinstance(parameter_mapping, Namespace):
parameter_mapping = parameter_mapping.__dict__
super(ModelParameters, self).__init__(**parameter_mapping) super(ModelParameters, self).__init__(**parameter_mapping)
@ -80,7 +82,7 @@ class SavedLightningModels(object):
model = torch.load(models_root_path / 'model_class.obj') model = torch.load(models_root_path / 'model_class.obj')
assert model is not None assert model is not None
return cls(weights=checkpoint_path, model=model) return cls(weights=str(checkpoint_path), model=model)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.weights: str = kwargs.get('weights', '') self.weights: str = kwargs.get('weights', '')