fingerprinted now should work correctly
This commit is contained in:
parent
e423d6fe31
commit
206aca10b3
@ -1,11 +1,13 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from torch import nn
|
||||
|
||||
from modules.utils import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
|
||||
import sys
|
||||
sys.path.append(str(Path(__file__).parent))
|
||||
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from modules.utils import ShapeMixin
|
||||
from .util import ShapeMixin
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
|
@ -10,7 +10,7 @@ import pytorch_lightning as pl
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
from utils.model_io import ModelParameters
|
||||
from ..utils.model_io import ModelParameters
|
||||
|
||||
|
||||
class ShapeMixin:
|
@ -39,10 +39,15 @@ class Config(ConfigParser, ABC):
|
||||
h = hashlib.md5()
|
||||
params = deepcopy(self.as_dict)
|
||||
del params['model']['type']
|
||||
del params['model']['secondary_type']
|
||||
del params['data']['worker']
|
||||
del params['main']
|
||||
h.update(str(params).encode())
|
||||
del params['project']
|
||||
# Flatten the dict of dicts
|
||||
for section in list(params.keys()):
|
||||
params.update({f'{section}_{key}': val for key, val in params[section].items()})
|
||||
del params[section]
|
||||
_, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
|
||||
h.update(str(vals).encode())
|
||||
fingerprint = h.hexdigest()
|
||||
return fingerprint
|
||||
|
||||
|
@ -5,7 +5,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
||||
from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
||||
|
||||
from utils.config import Config
|
||||
from .config import Config
|
||||
|
||||
|
||||
class Logger(LightningLoggerBase, ABC):
|
||||
|
@ -23,3 +23,5 @@ def run_n_in_parallel(f, n, processes=0, **kwargs):
|
||||
p.join()
|
||||
|
||||
return results
|
||||
|
||||
raise NotImplementedError()
|
||||
|
Loading…
x
Reference in New Issue
Block a user