fingerprinted now should work correctly

This commit is contained in:
Si11ium
2020-05-19 08:33:04 +02:00
parent e423d6fe31
commit 206aca10b3
7 changed files with 16 additions and 7 deletions

View File

View File

@@ -1,11 +1,13 @@
from pathlib import Path
from typing import Union from typing import Union
import torch import torch
import warnings import warnings
from torch import nn from torch import nn
import sys
from modules.utils import AutoPad, Interpolate, ShapeMixin, F_x, Flatten sys.path.append(str(Path(__file__).parent))
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

View File

@@ -4,7 +4,7 @@
import torch import torch
from torch import nn from torch import nn
from modules.utils import ShapeMixin from .util import ShapeMixin
class Generator(nn.Module): class Generator(nn.Module):

View File

@@ -10,7 +10,7 @@ import pytorch_lightning as pl
# Utility - Modules # Utility - Modules
################### ###################
from utils.model_io import ModelParameters from ..utils.model_io import ModelParameters
class ShapeMixin: class ShapeMixin:

View File

@@ -39,10 +39,15 @@ class Config(ConfigParser, ABC):
h = hashlib.md5() h = hashlib.md5()
params = deepcopy(self.as_dict) params = deepcopy(self.as_dict)
del params['model']['type'] del params['model']['type']
del params['model']['secondary_type']
del params['data']['worker'] del params['data']['worker']
del params['main'] del params['main']
h.update(str(params).encode()) del params['project']
# Flatten the dict of dicts
for section in list(params.keys()):
params.update({f'{section}_{key}': val for key, val in params[section].items()})
del params[section]
_, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
h.update(str(vals).encode())
fingerprint = h.hexdigest() fingerprint = h.hexdigest()
return fingerprint return fingerprint

View File

@@ -5,7 +5,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger from pytorch_lightning.loggers.test_tube import TestTubeLogger
from utils.config import Config from .config import Config
class Logger(LightningLoggerBase, ABC): class Logger(LightningLoggerBase, ABC):

View File

@@ -23,3 +23,5 @@ def run_n_in_parallel(f, n, processes=0, **kwargs):
p.join() p.join()
return results return results
raise NotImplementedError()