dataset fixing
This commit is contained in:
parent
3f8122484b
commit
aea34de964
@ -11,9 +11,9 @@ class BatchToData(object):
|
|||||||
# Convert to torch_geometric.data.Data type
|
# Convert to torch_geometric.data.Data type
|
||||||
|
|
||||||
batch_pos = batch_dict['pos']
|
batch_pos = batch_dict['pos']
|
||||||
batch_norm = batch_dict['norm']
|
batch_norm = batch_dict.get('norm', None)
|
||||||
batch_y = batch_dict['y']
|
batch_y = batch_dict.get('y', None)
|
||||||
batch_y_c = batch_dict['y_c']
|
batch_y_c = batch_dict.get('y_c', None)
|
||||||
|
|
||||||
batch_size, num_points, _ = batch_pos.shape # (batch_size, num_points, 3)
|
batch_size, num_points, _ = batch_pos.shape # (batch_size, num_points, 3)
|
||||||
|
|
||||||
|
@ -62,6 +62,8 @@ class ModelParameters(Namespace, Mapping):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, parameter_mapping):
|
def __init__(self, parameter_mapping):
|
||||||
|
if isinstance(parameter_mapping, Namespace):
|
||||||
|
parameter_mapping = parameter_mapping.__dict__
|
||||||
super(ModelParameters, self).__init__(**parameter_mapping)
|
super(ModelParameters, self).__init__(**parameter_mapping)
|
||||||
|
|
||||||
|
|
||||||
@ -80,7 +82,7 @@ class SavedLightningModels(object):
|
|||||||
model = torch.load(models_root_path / 'model_class.obj')
|
model = torch.load(models_root_path / 'model_class.obj')
|
||||||
assert model is not None
|
assert model is not None
|
||||||
|
|
||||||
return cls(weights=checkpoint_path, model=model)
|
return cls(weights=str(checkpoint_path), model=model)
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.weights: str = kwargs.get('weights', '')
|
self.weights: str = kwargs.get('weights', '')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user