Fingerprinting should now work correctly

This commit is contained in:
Si11ium 2020-05-19 08:33:04 +02:00
parent e423d6fe31
commit 206aca10b3
7 changed files with 16 additions and 7 deletions

View File

View File

@@ -1,11 +1,13 @@
 from pathlib import Path
 from typing import Union
 import torch
 import warnings
 from torch import nn
-from modules.utils import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
+import sys
+sys.path.append(str(Path(__file__).parent))
+from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
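The hunk above swaps the absolute `modules.utils` import for a relative `.util` import and additionally puts the file's own directory on `sys.path`. A minimal, self-contained sketch of that `sys.path` adjustment follows; the sibling module name `util` is an assumption mirroring the import above, not a confirmed file of this repository.

# Sketch only: appending this file's directory to sys.path lets a sibling
# module be imported by its bare name, e.g. when the file is executed as a
# plain script rather than as part of the package.
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent))

# With the directory on sys.path, a bare "import util" would resolve to the
# sibling util.py (hypothetical name); the relative "from .util import ..."
# form still applies when the module is loaded as a package member.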

View File

@@ -4,7 +4,7 @@
 import torch
 from torch import nn
-from modules.utils import ShapeMixin
+from .util import ShapeMixin
 class Generator(nn.Module):

View File

@@ -10,7 +10,7 @@ import pytorch_lightning as pl
 # Utility - Modules
 ###################
-from utils.model_io import ModelParameters
+from ..utils.model_io import ModelParameters
 class ShapeMixin:

View File

@@ -39,10 +39,15 @@ class Config(ConfigParser, ABC):
 h = hashlib.md5()
 params = deepcopy(self.as_dict)
 del params['model']['type']
+del params['model']['secondary_type']
 del params['data']['worker']
-del params['main']
-h.update(str(params).encode())
+del params['project']
+# Flatten the dict of dicts
+for section in list(params.keys()):
+    params.update({f'{section}_{key}': val for key, val in params[section].items()})
+    del params[section]
+_, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
+h.update(str(vals).encode())
 fingerprint = h.hexdigest()
 return fingerprint
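This hunk is the substance of the commit: instead of hashing the string form of the whole nested dict, whose key order and run-specific entries made the digest unstable, the config is stripped of volatile keys, flattened into section_key entries, sorted, and only the sorted values are hashed. A standalone sketch of that logic with hypothetical config contents (the original lives as a property on Config):

import hashlib
from copy import deepcopy


def fingerprint(config_dict):
    # Sketch of the fingerprint logic above: drop keys that vary between runs,
    # flatten the dict of dicts into 'section_key' entries, sort them, and hash
    # only the value tuple so key insertion order cannot change the digest.
    h = hashlib.md5()
    params = deepcopy(config_dict)
    del params['model']['type']
    del params['model']['secondary_type']
    del params['data']['worker']
    del params['project']
    # Flatten the dict of dicts
    for section in list(params.keys()):
        params.update({f'{section}_{key}': val for key, val in params[section].items()})
        del params[section]
    _, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
    h.update(str(vals).encode())
    return h.hexdigest()


# Two configs that differ only in the ignored keys yield the same fingerprint.
a = {'project': {'name': 'run_a'},
     'model': {'type': 'CNN', 'secondary_type': 'A', 'lr': 0.1},
     'data': {'worker': 4, 'batch_size': 32}}
b = {'project': {'name': 'run_b'},
     'model': {'type': 'RNN', 'secondary_type': 'B', 'lr': 0.1},
     'data': {'worker': 8, 'batch_size': 32}}
assert fingerprint(a) == fingerprint(b)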

View File

@@ -5,7 +5,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.neptune import NeptuneLogger
 from pytorch_lightning.loggers.test_tube import TestTubeLogger
-from utils.config import Config
+from .config import Config
 class Logger(LightningLoggerBase, ABC):

View File

@@ -23,3 +23,5 @@ def run_n_in_parallel(f, n, processes=0, **kwargs):
 p.join()
 return results
+raise NotImplementedError()
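The final hunk touches a small multiprocessing helper; only its last lines and the newly added raise NotImplementedError() are visible. For orientation, a minimal sketch of the fan-out pattern those visible lines suggest; the Manager-backed result list and everything beyond the signature shown in the hunk header are assumptions, not the repository's code:

from multiprocessing import Manager, Process


def run_n_in_parallel(f, n, **kwargs):
    # Assumed shape of the helper: start n processes running f, let each one
    # append to a shared list, join them all, then hand the results back.
    with Manager() as manager:
        results = manager.list()
        workers = [Process(target=f, args=(results,), kwargs=kwargs) for _ in range(n)]
        for p in workers:
            p.start()
        for p in workers:
            p.join()
        return list(results)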