SubSpectral and Lightning 0.9 Update

This commit is contained in:
Si11ium 2020-09-25 15:35:15 +02:00
parent 6bc9447ce1
commit 5848b528f0
13 changed files with 197 additions and 630 deletions

View File

@ -7,10 +7,9 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from modules.utils import LightningBaseModule
from utils.config import Config
from utils.logging import Logger
from utils.model_io import SavedLightningModels
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.config import Config
from ml_lib.utils.logging import Logger
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

View File

@ -1,6 +1,6 @@
import warnings
from _templates.new_project.utils.project_config import Config
from ml_lib._templates.new_project.utils.project_config import Config
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@ -8,7 +8,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
# Imports
# =============================================================================
from _templates.new_project.main import run_lightning_loop, args
from ml_lib._templates.new_project.main import run_lightning_loop, args
if __name__ == '__main__':

View File

@ -11,13 +11,13 @@ from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose
from _templates.new_project.datasets.template_dataset import TemplateDataset
from ml_lib._templates.new_project.datasets.template_dataset import TemplateDataset
from audio_toolset.audio_io import NormalizeLocal
from modules.utils import LightningBaseModule
from utils.transforms import ToTensor
from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.transforms import ToTensor
from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
from ml_lib._templates.new_project.utils.project_config import GlobalVar as GlobalVars
class BaseOptimizerMixin:

View File

@ -1,6 +1,6 @@
from argparse import Namespace
from utils.config import Config
from ml_lib.utils.config import Config
class GlobalVar(Namespace):

View File

View File

@ -1,488 +0,0 @@
##########################
# constants
import argparse
import contextlib
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Union, Any
import numpy as np
import pandas as pd
import os
# ToDo: Check this
import shutil
from imageio import imwrite
from natsort import natsorted
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning import _logger as log
from test_tube.log import DDPExperiment
_ROOT = Path(os.path.abspath(__file__))
# -----------------------------
# Experiment object
# -----------------------------
class Experiment(object):
def __init__(self, save_dir=None, name='default', debug=False, version=None, autosave=False, description=None):
"""
A new Experiment object defaults to 'default' unless a specific name is provided
If a known name is already provided, then the file version is changed
:param name:
:param debug:
"""
# change where the save dir is if requested
if save_dir is not None:
global _ROOT
_ROOT = save_dir
self.save_dir = save_dir
self.no_save_dir = save_dir is None
self.metrics = []
self.tags = {}
self.name = name
self.debug = debug
self.version = version
self.autosave = autosave
self.description = description
self.exp_hash = '{}_v{}'.format(self.name, version)
self.created_at = str(datetime.utcnow())
self.process = os.getpid()
# when debugging don't do anything else
if debug:
return
# update version hash if we need to increase version on our own
# we will increase the previous version, so do it now so the hash
# is accurate
if version is None:
old_version = self.__get_last_experiment_version()
self.exp_hash = '{}_v{}'.format(self.name, old_version + 1)
self.version = old_version + 1
# create a new log file
self.__init_cache_file_if_needed()
# when we have a version, load it
if self.version is not None:
# when no version and no file, create it
if not os.path.exists(self.__get_log_name()):
self.__create_exp_file(self.version)
else:
# otherwise load it
self.__load()
else:
# if no version given, increase the version to a new exp
# create the file if not exists
old_version = self.__get_last_experiment_version()
self.version = old_version
self.__create_exp_file(self.version + 1)
def get_meta_copy(self):
"""
Gets a meta-version only copy of this module
:return:
"""
return DDPExperiment(self)
def on_exit(self):
pass
def __clean_dir(self):
files = os.listdir(self.save_dir)
for f in files:
if str(self.process) in f:
os.remove(os.path.join(self.save_dir, f))
def argparse(self, argparser):
parsed = vars(argparser)
to_add = {}
# don't store methods
for k, v in parsed.items():
if not callable(v):
to_add[k] = v
self.tag(to_add)
def add_meta_from_hyperopt(self, hypo):
"""
Transfers meta data about all the params from the
hyperoptimizer to the log
:param hypo:
:return:
"""
meta = hypo.get_current_trial_meta()
for tag in meta:
self.tag(tag)
# --------------------------------
# FILE IO UTILS
# --------------------------------
def __init_cache_file_if_needed(self):
"""
Inits a file that we log historical experiments
:return:
"""
try:
exp_cache_file = self.get_data_path(self.name, self.version)
if not os.path.isdir(exp_cache_file):
os.makedirs(exp_cache_file, exist_ok=True)
except FileExistsError:
# file already exists (likely written by another exp. In this case disable the experiment
self.debug = True
def __create_exp_file(self, version):
"""
Recreates the old file with this exp and version
:param version:
:return:
"""
try:
exp_cache_file = self.get_data_path(self.name, self.version)
# if no exp, then make it
path = exp_cache_file / 'meta.experiment'
path.touch(exist_ok=True)
self.version = version
# make the directory for the experiment media assets name
self.get_media_path(self.name, self.version).mkdir(parents=True, exist_ok=True)
except FileExistsError:
# file already exists (likely written by another exp. In this case disable the experiment
self.debug = True
def __get_last_experiment_version(self):
exp_cache_file = self.get_data_path(self.name, self.version).parent
last_version = -1
version = natsorted([x.name for x in exp_cache_file.iterdir() if 'version_' in x.name])[-1]
last_version = max(last_version, int(version.split('_')[1]))
return last_version
def __get_log_name(self):
return self.get_data_path(self.name, self.version) / 'meta.experiment'
def tag(self, tag_dict):
"""
Adds a tag to the experiment.
Tags are metadata for the exp.
>> e.tag({"model": "Convnet A"})
:param tag_dict:
:type tag_dict: dict
:return:
"""
if self.debug:
return
# parse tags
for k, v in tag_dict.items():
self.tags[k] = v
# save if needed
if self.autosave:
self.save()
def log(self, metrics_dict):
"""
Adds a json dict of metrics.
>> e.log({"loss": 23, "coeff_a": 0.2})
:param metrics_dict:
:return:
"""
if self.debug:
return
new_metrics_dict = metrics_dict.copy()
for k, v in metrics_dict.items():
tmp_metrics_dict = new_metrics_dict.pop(k)
new_metrics_dict.update(tmp_metrics_dict)
metrics_dict = new_metrics_dict
# timestamp
if 'created_at' not in metrics_dict:
metrics_dict['created_at'] = str(datetime.utcnow())
self.__convert_numpy_types(metrics_dict)
self.metrics.append(metrics_dict)
if self.autosave:
self.save()
@staticmethod
def __convert_numpy_types(metrics_dict):
for k, v in metrics_dict.items():
if v.__class__.__name__ == 'float32':
metrics_dict[k] = float(v)
if v.__class__.__name__ == 'float64':
metrics_dict[k] = float(v)
def save(self):
"""
Saves current experiment progress
:return:
"""
if self.debug:
return
# save images and replace the image array with the
# file name
self.__save_images(self.metrics)
metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv'
meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv'
obj = {
'name': self.name,
'version': self.version,
'tags_path': meta_tags_path,
'metrics_path': metrics_file_path,
'autosave': self.autosave,
'description': self.description,
'created_at': self.created_at,
'exp_hash': self.exp_hash
}
# save the experiment meta file
with atomic_write(self.__get_log_name()) as tmp_path:
with open(tmp_path, 'w') as file:
json.dump(obj, file, ensure_ascii=False)
# save the metatags file
df = pd.DataFrame({'key': list(self.tags.keys()), 'value': list(self.tags.values())})
with atomic_write(meta_tags_path) as tmp_path:
df.to_csv(tmp_path, index=False)
# save the metrics data
df = pd.DataFrame(self.metrics)
with atomic_write(metrics_file_path) as tmp_path:
df.to_csv(tmp_path, index=False)
def __save_images(self, metrics):
"""
Save tags that have a png_ prefix (as images)
and replace the meta tag with the file name
:param metrics:
:return:
"""
# iterate all metrics and find keys with a specific prefix
for i, metric in enumerate(metrics):
for k, v in metric.items():
# if the prefix is a png, save the image and replace the value with the path
img_extension = None
img_extension = 'png' if 'png_' in k else img_extension
img_extension = 'jpg' if 'jpg' in k else img_extension
img_extension = 'jpeg' if 'jpeg' in k else img_extension
if img_extension is not None:
# determine the file name
img_name = '_'.join(k.split('_')[1:])
save_path = self.get_media_path(self.name, self.version)
save_path = '{}/{}_{}.{}'.format(save_path, img_name, i, img_extension)
# save image to disk
if type(metric[k]) is not str:
imwrite(save_path, metric[k])
# replace the image in the metric with the file path
metric[k] = save_path
def __load(self):
# load .experiment file
with open(self.__get_log_name(), 'r') as file:
data = json.load(file)
self.name = data['name']
self.version = data['version']
self.autosave = data['autosave']
self.created_at = data['created_at']
self.description = data['description']
self.exp_hash = data['exp_hash']
# load .tags file
meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv'
df = pd.read_csv(meta_tags_path)
self.tags_list = df.to_dict(orient='records')
self.tags = {}
for d in self.tags_list:
k, v = d['key'], d['value']
self.tags[k] = v
# load metrics
metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv'
try:
df = pd.read_csv(metrics_file_path)
self.metrics = df.to_dict(orient='records')
# remove nans and infs
for metric in self.metrics:
to_delete = []
for k, v in metric.items():
if np.isnan(v) or np.isinf(v):
to_delete.append(k)
for k in to_delete:
del metric[k]
except Exception:
# metrics was empty...
self.metrics = []
def get_data_path(self, exp_name, exp_version):
"""
Returns the path to the local package cache
:param exp_name:
:param exp_version:
:return:
Path
"""
if self.no_save_dir:
return _ROOT / 'local_experiment_data' / exp_name, f'version_{exp_version}'
else:
return _ROOT / exp_name / f'version_{exp_version}'
def get_media_path(self, exp_name, exp_version):
"""
Returns the path to the local package cache
:param exp_version:
:param exp_name:
:return:
"""
return self.get_data_path(exp_name, exp_version) / 'media'
# ----------------------------
# OVERWRITES
# ----------------------------
def __str__(self):
return 'Exp: {}, v: {}'.format(self.name, self.version)
def __hash__(self):
return 'Exp: {}, v: {}'.format(self.name, self.version)
@contextlib.contextmanager
def atomic_write(dst_path):
"""A context manager to simplify atomic writing.
Usage:
>>> with atomic_write(dst_path) as tmp_path:
>>> # write to tmp_path
>>> # Here tmp_path renamed to dst_path, if no exception happened.
"""
tmp_path = dst_path / '.tmp'
try:
yield tmp_path
except:
if tmp_path.exists():
tmp_path.unlink()
raise
else:
# If everything is fine, move tmp file to the destination.
shutil.move(tmp_path, str(dst_path))
##########################
class LocalLogger(LightningLoggerBase):
@property
def name(self) -> str:
return self._name
@property
def experiment(self) -> Experiment:
r"""
Actual TestTube object. To use TestTube features in your
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
Example::
self.logger.experiment.some_test_tube_function()
"""
if self._experiment is not None:
return self._experiment
self._experiment = Experiment(
save_dir=self.save_dir,
name=self._name,
debug=self.debug,
version=self.version,
description=self.description
)
return self._experiment
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
pass
def log_hyperparams(self, params: argparse.Namespace):
pass
@property
def version(self) -> Union[int, str]:
if self._version is None:
self._version = self._get_next_version()
return self._version
def _get_next_version(self):
root_dir = self.save_dir / self.name
if not root_dir.is_dir():
log.warning(f'Missing logger folder: {root_dir}')
return 0
existing_versions = []
for d in os.listdir(root_dir):
if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
existing_versions.append(int(d.split("_")[1]))
if len(existing_versions) == 0:
return 0
return max(existing_versions) + 1
def __init__(self, save_dir: str, name: str = "default", description: Optional[str] = None,
debug: bool = False, version: Optional[int] = None, **kwargs):
super(LocalLogger, self).__init__(**kwargs)
self.save_dir = Path(save_dir)
self._name = name
self.description = description
self.debug = debug
self._version = version
self._experiment = None
# Test tube experiments are not pickleable, so we need to override a few
# methods to get DDP working. See
# https://docs.python.org/3/library/pickle.html#handling-stateful-objects
# for more info.
def __getstate__(self) -> Dict[Any, Any]:
state = self.__dict__.copy()
state["_experiment"] = self.experiment.get_meta_copy()
return state
def __setstate__(self, state: Dict[Any, Any]):
self._experiment = state["_experiment"].get_non_ddp_exp()
del state["_experiment"]
self.__dict__.update(state)

View File

@ -130,8 +130,9 @@ class DeConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
dropout: Union[int, float] = 0, autopad=0,
activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
bias=True, norm=False):
bias=True, norm=False, **kwargs):
super(DeConvModule, self).__init__()
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.padding = conv_padding
self.conv_kernel = conv_kernel

View File

@ -1,8 +1,10 @@
import torch
from torch import nn
from torch.nn import ReLU
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
try:
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
except ImportError:
print('Install torch-geometric to use this package.')
class SAModule(torch.nn.Module):

View File

@ -1,10 +1,77 @@
#
# Full Model Parts
###################
import torch
from torch import nn
from argparse import Namespace
from typing import Union, List, Tuple
from .util import ShapeMixin
import torch
from abc import ABC
from torch import nn
from torch.utils.data import DataLoader
from .util import ShapeMixin, LightningBaseModule
class AEBaseModule(LightningBaseModule, ABC):
def generate_random_image(self, dataloader: Union[None, str, DataLoader] = None,
lat_min: Union[Tuple, List, None] = None,
lat_max: Union[Tuple, List, None] = None):
assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
min_max = self._find_min_max(dataloader) if dataloader else [None, None]
# assert not any([x is None for x in min_max])
lat_min = torch.as_tensor(lat_min or min_max[0])
lat_max = lat_max or min_max[1]
random_z = torch.rand((1, self.lat_dim))
random_z = random_z * (abs(lat_min) + lat_max) - abs(lat_min)
return self.decoder(random_z).squeeze()
def encode(self, x):
if len(x.shape) == 3:
x = x.unsqueeze(0)
return self.encoder(x).squeeze()
def _find_min_max(self, dataloader):
encodings = list()
for batch in dataloader:
encodings.append(self.encode(batch))
encodings = torch.cat(encodings, dim=0)
min_lat = encodings.min(dim=1)
max_lat = encodings.max(dim=1)
return min_lat, max_lat
def decode_lat_evenly(self, n: int,
dataloader: Union[None, str, DataLoader] = None,
lat_min: Union[Tuple, List, None] = None,
lat_max: Union[Tuple, List, None] = None):
assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
min_max = self._find_min_max(dataloader) if dataloader else [None, None]
lat_min = lat_min or min_max[0]
lat_max = lat_max or min_max[1]
random_latent_samples = torch.stack([torch.linspace(lat_min[i].item(), lat_max[i].item(), n)
for i in range(self.params.lat_dim)], dim=-1).cpu().detach()
return self.decode(random_latent_samples).cpu().detach()
def decode(self, z):
if len(z.shape) == 1:
z = z.unsqueeze(0)
return self.decoder(z).squeeze()
def encode_and_restore(self, x):
x = x.to(self.device)
if len(x.shape) == 3:
x = x.unsqueeze(0)
z = self.encode(x)
x_hat = self.decode(z)
return Namespace(main_out=x_hat.squeeze(), latent_out=z)
class Generator(nn.Module):
@ -16,9 +83,12 @@ class Generator(nn.Module):
# noinspection PyUnresolvedReferences
def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
filters: List[int] = None, activation=nn.ReLU):
filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
super(Generator, self).__init__()
assert filters, '"Filters" has to be a list of int len 3'
assert filters, '"Filters" has to be a list of int.'
assert filters, '"Filters" has to be a list of int.'
assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
self.filters = filters
self.activation = activation
self.inner_activation = activation()
@ -29,52 +99,35 @@ class Generator(nn.Module):
# re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
self.flat = Flatten(to=re_shape)
self.de_conv_list = nn.ModuleList()
self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
conv_kernel=5,
conv_padding=2,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
last_shape = re_shape
for conv_filter, conv_kernel in zip(filters, kernels):
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
conv_kernel=conv_kernel,
conv_padding=conv_kernel-2,
conv_stride=conv_filter,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
)
last_shape = self.de_conv_list[-1].shape
self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
conv_kernel=3,
conv_padding=1,
# normalize=norm,
activation=self.out_activation
)
self.de_conv_out = DeConvModule(self.de_conv_list[-1].shape, conv_filters=out_channels, conv_kernel=3,
conv_padding=1, activation=self.out_activation
)
def forward(self, z):
tensor = self.l1(z)
tensor = self.inner_activation(tensor)
tensor = self.flat(tensor)
tensor = self.deconv1(tensor)
tensor = self.deconv2(tensor)
tensor = self.deconv3(tensor)
tensor = self.deconv4(tensor)
for de_conv in self.de_conv_list:
tensor = de_conv(tensor)
tensor = self.de_conv_out(tensor)
return tensor
def size(self):
@ -119,12 +172,14 @@ class BaseEncoder(ShapeMixin, nn.Module):
# noinspection PyUnresolvedReferences
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
filters: List[int] = None):
filters: List[int] = None, kernels: List[int] = None, **kwargs):
super(BaseEncoder, self).__init__()
assert filters, '"Filters" has to be a list of int len 3'
assert filters, '"Filters" has to be a list of int'
assert kernels, '"Kernels" has to be a list of int'
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
# Optional Padding for odd image-sizes
# Obsolet, already Done by autopadding module on incoming tensors
# Obsolet, cdan be done by autopadding module on incoming tensors
# in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
# Parameters
@ -133,43 +188,29 @@ class BaseEncoder(ShapeMixin, nn.Module):
self.use_bias = use_bias
self.latent_activation = latent_activation() if latent_activation else None
self.conv_list = nn.ModuleList()
# Modules
self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
conv_kernel=5,
conv_padding=2,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
last_shape = self.in_shape
for conv_filter, conv_kernel in zip(filters, kernels):
self.conv_list.append(ConvModule(last_shape, conv_filters=conv_filter,
conv_kernel=conv_kernel,
conv_padding=conv_kernel-2,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
)
last_shape = self.conv_list[-1].shape
self.flat = Flatten()
def forward(self, x):
tensor = self.conv1(x)
tensor = self.conv2(tensor)
tensor = self.conv3(tensor)
tensor = x
for conv in self.conv_list:
tensor = conv(tensor)
tensor = self.flat(tensor)
return tensor

View File

@ -1,7 +1,10 @@
from functools import reduce
from abc import ABC
from pathlib import Path
import torch
from operator import mul
from torch import nn
from torch import functional as F
@ -102,6 +105,14 @@ class ShapeMixin:
else:
return -1
@property
def flat_shape(self):
shape = self.shape
try:
return reduce(mul, shape)
except TypeError:
return shape
class F_x(ShapeMixin, nn.Module):
def __init__(self, in_shape):
@ -175,7 +186,7 @@ class WeightInit:
m.bias.data.fill_(0.01)
class Filter(nn.Module):
class Filter(nn.Module, ShapeMixin):
def __init__(self, in_shape, pos, dim=-1):
super(Filter, self).__init__()
@ -210,11 +221,15 @@ class AutoPadToShape(object):
def __call__(self, x):
if not torch.is_tensor(x):
x = torch.as_tensor(x)
if x.shape[1:] == self.shape:
if x.shape[1:] == self.shape or x.shape == self.shape:
return x
embedding = torch.zeros((x.shape[0], *self.shape))
embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x
return embedding
for i in range(-1, -len(self.shape), -1):
idx = [0] * len(x.shape)
idx[i] = self.shape[i] - x.shape[i]
idx = tuple(idx)
x = torch.nn.functional.pad(x, idx)
return x
def __repr__(self):
return f'AutoPadTransform({self.shape})'
@ -233,9 +248,9 @@ class Splitter(nn.Module):
def __init__(self, in_shape, n, dim=-1):
super(Splitter, self).__init__()
self.n = n
self.dim = dim
self.in_shape = in_shape
self.n = n
self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
@ -243,22 +258,23 @@ class Splitter(nn.Module):
self.autopad = AutoPadToShape(self._out_shape)
def forward(self, x: torch.Tensor):
x = x.transpose(0, self.dim)
dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
x = x.transpose(0, dim)
n_blocks = list()
for block_idx in range(self.n):
start = block_idx * self.new_dim_size
end = (block_idx + 1) * self.new_dim_size
block = self.autopad(x[:, :, start:end, :])
n_blocks.append(block.transpose(0, self.dim))
block = x[start:end].transpose(0, dim)
block = self.autopad(block)
n_blocks.append(block)
return n_blocks
class Merger(nn.Module):
class Merger(nn.Module, ShapeMixin):
@property
def shape(self):
y = self.forward([torch.randn(self.in_shape)])
y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)])
return y.shape
def __init__(self, in_shape, n, dim=-1):

View File

@ -3,7 +3,8 @@ from pathlib import Path
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
# noinspection PyUnresolvedReferences
from pytorch_lightning.loggers.csv_logs import CSVLogger
from .config import Config
@ -15,13 +16,13 @@ class Logger(LightningLoggerBase, ABC):
@property
def experiment(self):
if self.debug:
return self.testtubelogger.experiment
return self.csvlogger.experiment
else:
return self.neptunelogger.experiment
@property
def log_dir(self):
return Path(self.testtubelogger.experiment.get_logdir()).parent
return Path(self.csvlogger.experiment.log_dir)
@property
def name(self):
@ -64,55 +65,56 @@ class Logger(LightningLoggerBase, ABC):
self.config.set('project', 'owner', 'testuser')
self.config.set('project', 'name', 'test')
self.config.set('project', 'neptune_key', 'XXX')
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._neptune_kwargs = dict(offline_mode=self.debug,
api_key=self.config.project.neptune_key,
experiment_name=self.name,
project_name=self.project_name,
params=self.config.model_paramters)
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
self.log_config_as_ini()
def log_hyperparams(self, params):
self.neptunelogger.log_hyperparams(params)
self.testtubelogger.log_hyperparams(params)
self.csvlogger.log_hyperparams(params)
pass
def log_metrics(self, metrics, step=None):
self.neptunelogger.log_metrics(metrics, step=step)
self.testtubelogger.log_metrics(metrics, step=step)
self.csvlogger.log_metrics(metrics, step=step)
pass
def close(self):
self.testtubelogger.close()
self.csvlogger.close()
self.neptunelogger.close()
def log_config_as_ini(self):
self.config.write(self.log_dir / 'config.ini')
def log_text(self, name, text, step_nb=0, **kwargs):
def log_text(self, name, text, step_nb=0, **_):
# TODO Implement Offline variant.
self.neptunelogger.log_text(name, text, step_nb)
def log_metric(self, metric_name, metric_value, **kwargs):
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
self.csvlogger.log_metrics(dict(metric_name=metric_value))
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
def log_image(self, name, image, ext='png', **kwargs):
self.neptunelogger.log_image(name, image, **kwargs)
step = kwargs.get('step', None)
name = f'{step}_{name}' if step is not None else name
name = f'{name}.{ext[1:] if ext.startswith(".") else ext}'
image_name = f'{step}_{name}' if step is not None else name
image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
(self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
image.savefig(self.log_dir / self.media_dir / name)
image.savefig(image_path, bbox_inches='tight', pad_inches=0)
self.neptunelogger.log_image(name, str(image_path), **kwargs)
def save(self):
self.testtubelogger.save()
self.csvlogger.save()
self.neptunelogger.save()
def finalize(self, status):
self.testtubelogger.finalize(status)
self.csvlogger.finalize(status)
self.neptunelogger.finalize(status)
def __enter__(self):

View File

@ -20,7 +20,7 @@ class ModelParameters(Namespace, Mapping):
paramter_mapping.update(
dict(
activation=self._activations[self['activation']]
activation=self.__getattribute__('activation')
)
)
@ -44,7 +44,7 @@ class ModelParameters(Namespace, Mapping):
def __getattribute__(self, name):
if name == 'activation':
return self._activations[self['activation']]
return self._activations[self['activation'].lower()]
else:
try:
return super(ModelParameters, self).__getattribute__(name)
@ -56,6 +56,7 @@ class ModelParameters(Namespace, Mapping):
_activations = dict(
leaky_relu=nn.LeakyReLU,
elu=nn.ELU,
relu=nn.ReLU,
sigmoid=nn.Sigmoid,
tanh=nn.Tanh

View File

@ -1,5 +1,5 @@
try:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
except ImportError: # pragma: no-cover
raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install matplotlib`.')
@ -8,30 +8,23 @@ from pathlib import Path
class Plotter(object):
def __init__(self, root_path=''):
if not root_path:
self.root_path = Path(root_path)
def save_current_figure(self, filename: str, extention='.png', naked=False):
fig, _ = plt.gcf(), plt.gca()
def save_figure(self, figure, title, extention='.png', naked=False):
canvas = FigureCanvas(figure)
# Prepare save location and check img file extention
path = self.root_path / Path(filename if filename.endswith(extention) else f'{filename}{extention}')
path = self.root_path / f'{title}{extention}'
path.parent.mkdir(exist_ok=True, parents=True)
if naked:
plt.axis('off')
fig.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
fig.clf()
figure.axis('off)')
figure.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
canvas.print_figure(path)
else:
fig.savefig(path)
fig.clf()
def show_current_figure(self):
fig, _ = plt.gcf(), plt.gca()
fig.show()
fig.clf()
canvas.print_figure(path)
if __name__ == '__main__':
output_root = Path('..') / 'output'
p = Plotter(output_root)
p.save_current_figure('test.png')
raise PermissionError('Get out of here.')