Fingerprinting should now work correctly
@@ -1,11 +1,13 @@
+from pathlib import Path
 from typing import Union
 
 import torch
 import warnings
 
 from torch import nn
-from modules.utils import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
+import sys
+sys.path.append(str(Path(__file__).parent))
+from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
 
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
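Note on the import change above: prepending the file's own directory to sys.path is a common way to keep sibling imports resolvable when the module is also run as a standalone script. A minimal sketch of that pattern; the try/except fallback is an illustration of the idiom, not something this commit adds:

    import sys
    from pathlib import Path

    # Make the module's own directory importable, e.g. when run as a script.
    sys.path.append(str(Path(__file__).parent))

    try:
        from .util import ShapeMixin    # works when imported as part of a package
    except ImportError:
        from util import ShapeMixin     # falls back to the sys.path entry above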
@@ -4,7 +4,7 @@
 import torch
 from torch import nn
 
-from modules.utils import ShapeMixin
+from .util import ShapeMixin
 
 
 class Generator(nn.Module):
@@ -10,7 +10,7 @@ import pytorch_lightning as pl
 
 # Utility - Modules
 ###################
-from utils.model_io import ModelParameters
+from ..utils.model_io import ModelParameters
 
 
 class ShapeMixin:
@@ -39,10 +39,15 @@ class Config(ConfigParser, ABC):
         h = hashlib.md5()
         params = deepcopy(self.as_dict)
         del params['model']['type']
-        del params['model']['secondary_type']
         del params['data']['worker']
         del params['main']
-        h.update(str(params).encode())
+        del params['project']
+        # Flatten the dict of dicts
+        for section in list(params.keys()):
+            params.update({f'{section}_{key}': val for key, val in params[section].items()})
+            del params[section]
+        _, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
+        h.update(str(vals).encode())
         fingerprint = h.hexdigest()
         return fingerprint
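The rewritten fingerprint hashes a flat, key-sorted view of the remaining config values, so two configs with the same settings produce the same digest regardless of section nesting or key insertion order. A self-contained sketch of the same logic, runnable outside the Config class (the example dicts are illustrative):

    import hashlib
    from copy import deepcopy

    def fingerprint(params: dict) -> str:
        params = deepcopy(params)
        # Flatten the dict of dicts so the hash depends only on the
        # (section_key, value) pairs, not on how sections nest.
        for section in list(params.keys()):
            params.update({f'{section}_{key}': val for key, val in params[section].items()})
            del params[section]
        # Sorting by key makes the digest independent of insertion order.
        _, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
        h = hashlib.md5()
        h.update(str(vals).encode())
        return h.hexdigest()

    # Same settings, different insertion order -> same fingerprint.
    assert fingerprint({'a': {'x': 1}, 'b': {'y': 2}}) == fingerprint({'b': {'y': 2}, 'a': {'x': 1}})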
@@ -5,7 +5,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.neptune import NeptuneLogger
 from pytorch_lightning.loggers.test_tube import TestTubeLogger
 
-from utils.config import Config
+from .config import Config
 
 
 class Logger(LightningLoggerBase, ABC):
@@ -23,3 +23,5 @@ def run_n_in_parallel(f, n, processes=0, **kwargs):
         p.join()
 
     return results
+
+raise NotImplementedError()
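For context, run_n_in_parallel's body lies mostly outside this hunk; a hedged sketch of what a helper with this signature typically looks like. Everything except the signature, the p.join() call, and return results is an assumption:

    import multiprocessing as mp

    def _worker(f, queue, kwargs):
        # Run f once and ship its return value back to the parent process.
        queue.put(f(**kwargs))

    def run_n_in_parallel(f, n, processes=0, **kwargs):
        # Run f n times in separate processes and collect the results.
        # (processes is kept for signature fidelity; its handling is not
        # visible in the diff.)
        queue = mp.Queue()
        procs = [mp.Process(target=_worker, args=(f, queue, kwargs)) for _ in range(n)]
        for p in procs:
            p.start()
        # Drain the queue before joining to avoid blocking on large items.
        results = [queue.get() for _ in procs]
        for p in procs:
            p.join()
        return results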