dataset fixing
This commit is contained in:
@ -62,6 +62,8 @@ class ModelParameters(Namespace, Mapping):
|
||||
)
|
||||
|
||||
def __init__(self, parameter_mapping):
|
||||
if isinstance(parameter_mapping, Namespace):
|
||||
parameter_mapping = parameter_mapping.__dict__
|
||||
super(ModelParameters, self).__init__(**parameter_mapping)
|
||||
|
||||
|
||||
@ -80,7 +82,7 @@ class SavedLightningModels(object):
|
||||
model = torch.load(models_root_path / 'model_class.obj')
|
||||
assert model is not None
|
||||
|
||||
return cls(weights=checkpoint_path, model=model)
|
||||
return cls(weights=str(checkpoint_path), model=model)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.weights: str = kwargs.get('weights', '')
|
||||
|
Reference in New Issue
Block a user