From 5848b528f01eaa999574f0af7c9557eb69b679b7 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Fri, 25 Sep 2020 15:35:15 +0200 Subject: [PATCH] SubSpectral and Lightning 0.9 Update --- _templates/new_project/main.py | 7 +- _templates/new_project/multi_run.py | 4 +- _templates/new_project/utils/module_mixins.py | 10 +- .../new_project/utils/project_config.py | 2 +- logging/__init__.py | 0 logging/local_logging.py | 488 ------------------ modules/blocks.py | 3 +- modules/geometric_blocks.py | 6 +- modules/model_parts.py | 199 ++++--- modules/util.py | 42 +- utils/logging.py | 34 +- utils/model_io.py | 5 +- visualization/tools.py | 27 +- 13 files changed, 197 insertions(+), 630 deletions(-) delete mode 100644 logging/__init__.py delete mode 100644 logging/local_logging.py diff --git a/_templates/new_project/main.py b/_templates/new_project/main.py index 1c7b357..c60cc87 100644 --- a/_templates/new_project/main.py +++ b/_templates/new_project/main.py @@ -7,10 +7,9 @@ import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from modules.utils import LightningBaseModule -from utils.config import Config -from utils.logging import Logger -from utils.model_io import SavedLightningModels +from ml_lib.modules.util import LightningBaseModule +from ml_lib.utils.config import Config +from ml_lib.utils.logging import Logger warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) diff --git a/_templates/new_project/multi_run.py b/_templates/new_project/multi_run.py index f0ba3ce..73b02ae 100644 --- a/_templates/new_project/multi_run.py +++ b/_templates/new_project/multi_run.py @@ -1,6 +1,6 @@ import warnings -from _templates.new_project.utils.project_config import Config +from ml_lib._templates.new_project.utils.project_config import Config warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) @@ -8,7 +8,7 @@ warnings.filterwarnings('ignore', category=UserWarning) # Imports # ============================================================================= -from _templates.new_project.main import run_lightning_loop, args +from ml_lib._templates.new_project.main import run_lightning_loop, args if __name__ == '__main__': diff --git a/_templates/new_project/utils/module_mixins.py b/_templates/new_project/utils/module_mixins.py index cc98998..609977b 100644 --- a/_templates/new_project/utils/module_mixins.py +++ b/_templates/new_project/utils/module_mixins.py @@ -11,13 +11,13 @@ from torch.utils.data import DataLoader from torchcontrib.optim import SWA from torchvision.transforms import Compose -from _templates.new_project.datasets.template_dataset import TemplateDataset +from ml_lib._templates.new_project.datasets.template_dataset import TemplateDataset -from audio_toolset.audio_io import NormalizeLocal -from modules.utils import LightningBaseModule -from utils.transforms import ToTensor +from ml_lib.audio_toolset.audio_io import NormalizeLocal +from ml_lib.modules.util import LightningBaseModule +from ml_lib.utils.transforms import ToTensor -from _templates.new_project.utils.project_config import GlobalVar as GlobalVars +from ml_lib._templates.new_project.utils.project_config import GlobalVar as GlobalVars class BaseOptimizerMixin: diff --git a/_templates/new_project/utils/project_config.py b/_templates/new_project/utils/project_config.py index 78774db..8b651b3 100644 --- a/_templates/new_project/utils/project_config.py +++ b/_templates/new_project/utils/project_config.py @@ -1,6 +1,6 @@ from argparse import Namespace -from utils.config import Config +from ml_lib.utils.config import Config class GlobalVar(Namespace): diff --git a/logging/__init__.py b/logging/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/logging/local_logging.py b/logging/local_logging.py deleted file mode 100644 index 42933d1..0000000 --- a/logging/local_logging.py +++ /dev/null @@ -1,488 +0,0 @@ -########################## -# constants -import argparse -import contextlib -import json -from datetime import datetime -from pathlib import Path -from typing import Dict, Optional, Union, Any -import numpy as np - -import pandas as pd - -import os - - -# ToDo: Check this -import shutil -from imageio import imwrite -from natsort import natsorted -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning import _logger as log -from test_tube.log import DDPExperiment - -_ROOT = Path(os.path.abspath(__file__)) - - -# ----------------------------- -# Experiment object -# ----------------------------- -class Experiment(object): - - def __init__(self, save_dir=None, name='default', debug=False, version=None, autosave=False, description=None): - """ - A new Experiment object defaults to 'default' unless a specific name is provided - If a known name is already provided, then the file version is changed - :param name: - :param debug: - """ - - # change where the save dir is if requested - - if save_dir is not None: - global _ROOT - _ROOT = save_dir - - self.save_dir = save_dir - self.no_save_dir = save_dir is None - self.metrics = [] - self.tags = {} - self.name = name - self.debug = debug - self.version = version - self.autosave = autosave - self.description = description - self.exp_hash = '{}_v{}'.format(self.name, version) - self.created_at = str(datetime.utcnow()) - self.process = os.getpid() - - # when debugging don't do anything else - if debug: - return - - # update version hash if we need to increase version on our own - # we will increase the previous version, so do it now so the hash - # is accurate - if version is None: - old_version = self.__get_last_experiment_version() - self.exp_hash = '{}_v{}'.format(self.name, old_version + 1) - self.version = old_version + 1 - - # create a new log file - self.__init_cache_file_if_needed() - - # when we have a version, load it - if self.version is not None: - - # when no version and no file, create it - if not os.path.exists(self.__get_log_name()): - self.__create_exp_file(self.version) - else: - # otherwise load it - self.__load() - - else: - # if no version given, increase the version to a new exp - # create the file if not exists - old_version = self.__get_last_experiment_version() - self.version = old_version - self.__create_exp_file(self.version + 1) - - def get_meta_copy(self): - """ - Gets a meta-version only copy of this module - :return: - """ - return DDPExperiment(self) - - def on_exit(self): - pass - - def __clean_dir(self): - files = os.listdir(self.save_dir) - - for f in files: - if str(self.process) in f: - os.remove(os.path.join(self.save_dir, f)) - - def argparse(self, argparser): - parsed = vars(argparser) - to_add = {} - - # don't store methods - for k, v in parsed.items(): - if not callable(v): - to_add[k] = v - - self.tag(to_add) - - def add_meta_from_hyperopt(self, hypo): - """ - Transfers meta data about all the params from the - hyperoptimizer to the log - :param hypo: - :return: - """ - meta = hypo.get_current_trial_meta() - for tag in meta: - self.tag(tag) - - # -------------------------------- - # FILE IO UTILS - # -------------------------------- - def __init_cache_file_if_needed(self): - """ - Inits a file that we log historical experiments - :return: - """ - try: - exp_cache_file = self.get_data_path(self.name, self.version) - if not os.path.isdir(exp_cache_file): - os.makedirs(exp_cache_file, exist_ok=True) - except FileExistsError: - # file already exists (likely written by another exp. In this case disable the experiment - self.debug = True - - def __create_exp_file(self, version): - """ - Recreates the old file with this exp and version - :param version: - :return: - """ - - try: - exp_cache_file = self.get_data_path(self.name, self.version) - # if no exp, then make it - path = exp_cache_file / 'meta.experiment' - path.touch(exist_ok=True) - - self.version = version - - # make the directory for the experiment media assets name - self.get_media_path(self.name, self.version).mkdir(parents=True, exist_ok=True) - - except FileExistsError: - # file already exists (likely written by another exp. In this case disable the experiment - self.debug = True - - def __get_last_experiment_version(self): - - exp_cache_file = self.get_data_path(self.name, self.version).parent - last_version = -1 - - version = natsorted([x.name for x in exp_cache_file.iterdir() if 'version_' in x.name])[-1] - last_version = max(last_version, int(version.split('_')[1])) - - return last_version - - def __get_log_name(self): - return self.get_data_path(self.name, self.version) / 'meta.experiment' - - def tag(self, tag_dict): - """ - Adds a tag to the experiment. - Tags are metadata for the exp. - - >> e.tag({"model": "Convnet A"}) - - :param tag_dict: - :type tag_dict: dict - - :return: - """ - if self.debug: - return - - # parse tags - for k, v in tag_dict.items(): - self.tags[k] = v - - # save if needed - if self.autosave: - self.save() - - def log(self, metrics_dict): - """ - Adds a json dict of metrics. - - >> e.log({"loss": 23, "coeff_a": 0.2}) - - :param metrics_dict: - - :return: - """ - if self.debug: - return - - new_metrics_dict = metrics_dict.copy() - for k, v in metrics_dict.items(): - tmp_metrics_dict = new_metrics_dict.pop(k) - new_metrics_dict.update(tmp_metrics_dict) - - metrics_dict = new_metrics_dict - - # timestamp - if 'created_at' not in metrics_dict: - metrics_dict['created_at'] = str(datetime.utcnow()) - - self.__convert_numpy_types(metrics_dict) - - self.metrics.append(metrics_dict) - - if self.autosave: - self.save() - - @staticmethod - def __convert_numpy_types(metrics_dict): - for k, v in metrics_dict.items(): - if v.__class__.__name__ == 'float32': - metrics_dict[k] = float(v) - - if v.__class__.__name__ == 'float64': - metrics_dict[k] = float(v) - - def save(self): - """ - Saves current experiment progress - :return: - """ - if self.debug: - return - - # save images and replace the image array with the - # file name - self.__save_images(self.metrics) - metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv' - meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv' - - obj = { - 'name': self.name, - 'version': self.version, - 'tags_path': meta_tags_path, - 'metrics_path': metrics_file_path, - 'autosave': self.autosave, - 'description': self.description, - 'created_at': self.created_at, - 'exp_hash': self.exp_hash - } - - # save the experiment meta file - with atomic_write(self.__get_log_name()) as tmp_path: - with open(tmp_path, 'w') as file: - json.dump(obj, file, ensure_ascii=False) - - # save the metatags file - df = pd.DataFrame({'key': list(self.tags.keys()), 'value': list(self.tags.values())}) - with atomic_write(meta_tags_path) as tmp_path: - df.to_csv(tmp_path, index=False) - - # save the metrics data - df = pd.DataFrame(self.metrics) - with atomic_write(metrics_file_path) as tmp_path: - df.to_csv(tmp_path, index=False) - - def __save_images(self, metrics): - """ - Save tags that have a png_ prefix (as images) - and replace the meta tag with the file name - :param metrics: - :return: - """ - # iterate all metrics and find keys with a specific prefix - for i, metric in enumerate(metrics): - for k, v in metric.items(): - # if the prefix is a png, save the image and replace the value with the path - img_extension = None - img_extension = 'png' if 'png_' in k else img_extension - img_extension = 'jpg' if 'jpg' in k else img_extension - img_extension = 'jpeg' if 'jpeg' in k else img_extension - - if img_extension is not None: - # determine the file name - img_name = '_'.join(k.split('_')[1:]) - save_path = self.get_media_path(self.name, self.version) - save_path = '{}/{}_{}.{}'.format(save_path, img_name, i, img_extension) - - # save image to disk - if type(metric[k]) is not str: - imwrite(save_path, metric[k]) - - # replace the image in the metric with the file path - metric[k] = save_path - - def __load(self): - # load .experiment file - with open(self.__get_log_name(), 'r') as file: - data = json.load(file) - self.name = data['name'] - self.version = data['version'] - self.autosave = data['autosave'] - self.created_at = data['created_at'] - self.description = data['description'] - self.exp_hash = data['exp_hash'] - - # load .tags file - meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv' - df = pd.read_csv(meta_tags_path) - self.tags_list = df.to_dict(orient='records') - self.tags = {} - for d in self.tags_list: - k, v = d['key'], d['value'] - self.tags[k] = v - - # load metrics - metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv' - try: - df = pd.read_csv(metrics_file_path) - self.metrics = df.to_dict(orient='records') - - # remove nans and infs - for metric in self.metrics: - to_delete = [] - for k, v in metric.items(): - if np.isnan(v) or np.isinf(v): - to_delete.append(k) - for k in to_delete: - del metric[k] - - except Exception: - # metrics was empty... - self.metrics = [] - - def get_data_path(self, exp_name, exp_version): - """ - Returns the path to the local package cache - :param exp_name: - :param exp_version: - :return: - Path - """ - if self.no_save_dir: - return _ROOT / 'local_experiment_data' / exp_name, f'version_{exp_version}' - else: - return _ROOT / exp_name / f'version_{exp_version}' - - def get_media_path(self, exp_name, exp_version): - """ - Returns the path to the local package cache - :param exp_version: - :param exp_name: - :return: - """ - - return self.get_data_path(exp_name, exp_version) / 'media' - - # ---------------------------- - # OVERWRITES - # ---------------------------- - def __str__(self): - return 'Exp: {}, v: {}'.format(self.name, self.version) - - def __hash__(self): - return 'Exp: {}, v: {}'.format(self.name, self.version) - - -@contextlib.contextmanager -def atomic_write(dst_path): - """A context manager to simplify atomic writing. - - Usage: - >>> with atomic_write(dst_path) as tmp_path: - >>> # write to tmp_path - >>> # Here tmp_path renamed to dst_path, if no exception happened. - """ - tmp_path = dst_path / '.tmp' - try: - yield tmp_path - except: - if tmp_path.exists(): - tmp_path.unlink() - raise - else: - # If everything is fine, move tmp file to the destination. - shutil.move(tmp_path, str(dst_path)) - - -########################## -class LocalLogger(LightningLoggerBase): - - @property - def name(self) -> str: - return self._name - - @property - def experiment(self) -> Experiment: - r""" - - Actual TestTube object. To use TestTube features in your - :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. - - Example:: - - self.logger.experiment.some_test_tube_function() - - """ - if self._experiment is not None: - return self._experiment - - self._experiment = Experiment( - save_dir=self.save_dir, - name=self._name, - debug=self.debug, - version=self.version, - description=self.description - ) - return self._experiment - - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): - pass - - def log_hyperparams(self, params: argparse.Namespace): - pass - - @property - def version(self) -> Union[int, str]: - if self._version is None: - self._version = self._get_next_version() - return self._version - - def _get_next_version(self): - root_dir = self.save_dir / self.name - - if not root_dir.is_dir(): - log.warning(f'Missing logger folder: {root_dir}') - return 0 - - existing_versions = [] - for d in os.listdir(root_dir): - if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): - existing_versions.append(int(d.split("_")[1])) - - if len(existing_versions) == 0: - return 0 - - return max(existing_versions) + 1 - - def __init__(self, save_dir: str, name: str = "default", description: Optional[str] = None, - debug: bool = False, version: Optional[int] = None, **kwargs): - super(LocalLogger, self).__init__(**kwargs) - self.save_dir = Path(save_dir) - self._name = name - self.description = description - self.debug = debug - self._version = version - self._experiment = None - - # Test tube experiments are not pickleable, so we need to override a few - # methods to get DDP working. See - # https://docs.python.org/3/library/pickle.html#handling-stateful-objects - # for more info. - def __getstate__(self) -> Dict[Any, Any]: - state = self.__dict__.copy() - state["_experiment"] = self.experiment.get_meta_copy() - return state - - def __setstate__(self, state: Dict[Any, Any]): - self._experiment = state["_experiment"].get_non_ddp_exp() - del state["_experiment"] - self.__dict__.update(state) diff --git a/modules/blocks.py b/modules/blocks.py index fe19b8b..cb85886 100644 --- a/modules/blocks.py +++ b/modules/blocks.py @@ -130,8 +130,9 @@ class DeConvModule(ShapeMixin, nn.Module): def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0, dropout: Union[int, float] = 0, autopad=0, activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0, - bias=True, norm=False): + bias=True, norm=False, **kwargs): super(DeConvModule, self).__init__() + warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}') in_channels, height, width = in_shape[0], in_shape[1], in_shape[2] self.padding = conv_padding self.conv_kernel = conv_kernel diff --git a/modules/geometric_blocks.py b/modules/geometric_blocks.py index 6a05068..f058628 100644 --- a/modules/geometric_blocks.py +++ b/modules/geometric_blocks.py @@ -1,8 +1,10 @@ import torch from torch import nn from torch.nn import ReLU - -from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate +try: + from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate +except ImportError: + print('Install torch-geometric to use this package.') class SAModule(torch.nn.Module): diff --git a/modules/model_parts.py b/modules/model_parts.py index 0eb92ec..9afd028 100644 --- a/modules/model_parts.py +++ b/modules/model_parts.py @@ -1,10 +1,77 @@ # # Full Model Parts ################### -import torch -from torch import nn +from argparse import Namespace +from typing import Union, List, Tuple -from .util import ShapeMixin +import torch +from abc import ABC +from torch import nn +from torch.utils.data import DataLoader + +from .util import ShapeMixin, LightningBaseModule + + +class AEBaseModule(LightningBaseModule, ABC): + + def generate_random_image(self, dataloader: Union[None, str, DataLoader] = None, + lat_min: Union[Tuple, List, None] = None, + lat_max: Union[Tuple, List, None] = None): + + assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.' + + min_max = self._find_min_max(dataloader) if dataloader else [None, None] + # assert not any([x is None for x in min_max]) + lat_min = torch.as_tensor(lat_min or min_max[0]) + lat_max = lat_max or min_max[1] + + random_z = torch.rand((1, self.lat_dim)) + random_z = random_z * (abs(lat_min) + lat_max) - abs(lat_min) + + return self.decoder(random_z).squeeze() + + def encode(self, x): + if len(x.shape) == 3: + x = x.unsqueeze(0) + return self.encoder(x).squeeze() + + def _find_min_max(self, dataloader): + encodings = list() + for batch in dataloader: + encodings.append(self.encode(batch)) + encodings = torch.cat(encodings, dim=0) + min_lat = encodings.min(dim=1) + max_lat = encodings.max(dim=1) + return min_lat, max_lat + + def decode_lat_evenly(self, n: int, + dataloader: Union[None, str, DataLoader] = None, + lat_min: Union[Tuple, List, None] = None, + lat_max: Union[Tuple, List, None] = None): + assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.' + + min_max = self._find_min_max(dataloader) if dataloader else [None, None] + + lat_min = lat_min or min_max[0] + lat_max = lat_max or min_max[1] + + random_latent_samples = torch.stack([torch.linspace(lat_min[i].item(), lat_max[i].item(), n) + for i in range(self.params.lat_dim)], dim=-1).cpu().detach() + return self.decode(random_latent_samples).cpu().detach() + + def decode(self, z): + if len(z.shape) == 1: + z = z.unsqueeze(0) + return self.decoder(z).squeeze() + + def encode_and_restore(self, x): + x = x.to(self.device) + if len(x.shape) == 3: + x = x.unsqueeze(0) + z = self.encode(x) + x_hat = self.decode(z) + + return Namespace(main_out=x_hat.squeeze(), latent_out=z) class Generator(nn.Module): @@ -16,9 +83,12 @@ class Generator(nn.Module): # noinspection PyUnresolvedReferences def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0, - filters: List[int] = None, activation=nn.ReLU): + filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs): super(Generator, self).__init__() - assert filters, '"Filters" has to be a list of int len 3' + assert filters, '"Filters" has to be a list of int.' + assert filters, '"Filters" has to be a list of int.' + assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.' + self.filters = filters self.activation = activation self.inner_activation = activation() @@ -29,52 +99,35 @@ class Generator(nn.Module): # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:]) self.flat = Flatten(to=re_shape) + self.de_conv_list = nn.ModuleList() - self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0], - conv_kernel=5, - conv_padding=2, - conv_stride=1, - normalize=use_norm, - activation=self.activation, - interpolation_scale=2, - dropout=self.dropout - ) + last_shape = re_shape + for conv_filter, conv_kernel in zip(filters, kernels): + self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0], + conv_kernel=conv_kernel, + conv_padding=conv_kernel-2, + conv_stride=conv_filter, + normalize=use_norm, + activation=self.activation, + interpolation_scale=2, + dropout=self.dropout + ) + ) + last_shape = self.de_conv_list[-1].shape - self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1], - conv_kernel=3, - conv_padding=1, - conv_stride=1, - normalize=use_norm, - activation=self.activation, - interpolation_scale=2, - dropout=self.dropout - ) - - self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2], - conv_kernel=3, - conv_padding=1, - conv_stride=1, - normalize=use_norm, - activation=self.activation, - interpolation_scale=2, - dropout=self.dropout - ) - - self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels, - conv_kernel=3, - conv_padding=1, - # normalize=norm, - activation=self.out_activation - ) + self.de_conv_out = DeConvModule(self.de_conv_list[-1].shape, conv_filters=out_channels, conv_kernel=3, + conv_padding=1, activation=self.out_activation + ) def forward(self, z): tensor = self.l1(z) tensor = self.inner_activation(tensor) tensor = self.flat(tensor) - tensor = self.deconv1(tensor) - tensor = self.deconv2(tensor) - tensor = self.deconv3(tensor) - tensor = self.deconv4(tensor) + + for de_conv in self.de_conv_list: + tensor = de_conv(tensor) + + tensor = self.de_conv_out(tensor) return tensor def size(self): @@ -119,12 +172,14 @@ class BaseEncoder(ShapeMixin, nn.Module): # noinspection PyUnresolvedReferences def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0, latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU, - filters: List[int] = None): + filters: List[int] = None, kernels: List[int] = None, **kwargs): super(BaseEncoder, self).__init__() - assert filters, '"Filters" has to be a list of int len 3' + assert filters, '"Filters" has to be a list of int' + assert kernels, '"Kernels" has to be a list of int' + assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.' # Optional Padding for odd image-sizes - # Obsolet, already Done by autopadding module on incoming tensors + # Obsolet, cdan be done by autopadding module on incoming tensors # in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)] # Parameters @@ -133,43 +188,29 @@ class BaseEncoder(ShapeMixin, nn.Module): self.use_bias = use_bias self.latent_activation = latent_activation() if latent_activation else None + self.conv_list = nn.ModuleList() + # Modules - self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0], - conv_kernel=3, - conv_padding=1, - conv_stride=1, - pooling_size=2, - use_norm=use_norm, - dropout=dropout, - activation=activation - ) - - self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1], - conv_kernel=3, - conv_padding=1, - conv_stride=1, - pooling_size=2, - use_norm=use_norm, - dropout=dropout, - activation=activation - ) - - self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2], - conv_kernel=5, - conv_padding=2, - conv_stride=1, - pooling_size=2, - use_norm=use_norm, - dropout=dropout, - activation=activation - ) + last_shape = self.in_shape + for conv_filter, conv_kernel in zip(filters, kernels): + self.conv_list.append(ConvModule(last_shape, conv_filters=conv_filter, + conv_kernel=conv_kernel, + conv_padding=conv_kernel-2, + conv_stride=1, + pooling_size=2, + use_norm=use_norm, + dropout=dropout, + activation=activation + ) + ) + last_shape = self.conv_list[-1].shape self.flat = Flatten() def forward(self, x): - tensor = self.conv1(x) - tensor = self.conv2(tensor) - tensor = self.conv3(tensor) + tensor = x + for conv in self.conv_list: + tensor = conv(tensor) tensor = self.flat(tensor) return tensor diff --git a/modules/util.py b/modules/util.py index 1b29853..87548f7 100644 --- a/modules/util.py +++ b/modules/util.py @@ -1,7 +1,10 @@ +from functools import reduce + from abc import ABC from pathlib import Path import torch +from operator import mul from torch import nn from torch import functional as F @@ -102,6 +105,14 @@ class ShapeMixin: else: return -1 + @property + def flat_shape(self): + shape = self.shape + try: + return reduce(mul, shape) + except TypeError: + return shape + class F_x(ShapeMixin, nn.Module): def __init__(self, in_shape): @@ -175,7 +186,7 @@ class WeightInit: m.bias.data.fill_(0.01) -class Filter(nn.Module): +class Filter(nn.Module, ShapeMixin): def __init__(self, in_shape, pos, dim=-1): super(Filter, self).__init__() @@ -210,11 +221,15 @@ class AutoPadToShape(object): def __call__(self, x): if not torch.is_tensor(x): x = torch.as_tensor(x) - if x.shape[1:] == self.shape: + if x.shape[1:] == self.shape or x.shape == self.shape: return x - embedding = torch.zeros((x.shape[0], *self.shape)) - embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x - return embedding + + for i in range(-1, -len(self.shape), -1): + idx = [0] * len(x.shape) + idx[i] = self.shape[i] - x.shape[i] + idx = tuple(idx) + x = torch.nn.functional.pad(x, idx) + return x def __repr__(self): return f'AutoPadTransform({self.shape})' @@ -233,9 +248,9 @@ class Splitter(nn.Module): def __init__(self, in_shape, n, dim=-1): super(Splitter, self).__init__() - self.n = n - self.dim = dim self.in_shape = in_shape + self.n = n + self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim) self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0) self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)]) @@ -243,22 +258,23 @@ class Splitter(nn.Module): self.autopad = AutoPadToShape(self._out_shape) def forward(self, x: torch.Tensor): - x = x.transpose(0, self.dim) + dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim + x = x.transpose(0, dim) n_blocks = list() for block_idx in range(self.n): start = block_idx * self.new_dim_size end = (block_idx + 1) * self.new_dim_size - block = self.autopad(x[:, :, start:end, :]) - - n_blocks.append(block.transpose(0, self.dim)) + block = x[start:end].transpose(0, dim) + block = self.autopad(block) + n_blocks.append(block) return n_blocks -class Merger(nn.Module): +class Merger(nn.Module, ShapeMixin): @property def shape(self): - y = self.forward([torch.randn(self.in_shape)]) + y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)]) return y.shape def __init__(self, in_shape, n, dim=-1): diff --git a/utils/logging.py b/utils/logging.py index 3c27474..ec4119d 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -3,7 +3,8 @@ from pathlib import Path from pytorch_lightning.loggers.base import LightningLoggerBase from pytorch_lightning.loggers.neptune import NeptuneLogger -from pytorch_lightning.loggers.test_tube import TestTubeLogger +# noinspection PyUnresolvedReferences +from pytorch_lightning.loggers.csv_logs import CSVLogger from .config import Config @@ -15,13 +16,13 @@ class Logger(LightningLoggerBase, ABC): @property def experiment(self): if self.debug: - return self.testtubelogger.experiment + return self.csvlogger.experiment else: return self.neptunelogger.experiment @property def log_dir(self): - return Path(self.testtubelogger.experiment.get_logdir()).parent + return Path(self.csvlogger.experiment.log_dir) @property def name(self): @@ -64,55 +65,56 @@ class Logger(LightningLoggerBase, ABC): self.config.set('project', 'owner', 'testuser') self.config.set('project', 'name', 'test') self.config.set('project', 'neptune_key', 'XXX') - self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name) + self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name) self._neptune_kwargs = dict(offline_mode=self.debug, api_key=self.config.project.neptune_key, experiment_name=self.name, project_name=self.project_name, params=self.config.model_paramters) self.neptunelogger = NeptuneLogger(**self._neptune_kwargs) - self.testtubelogger = TestTubeLogger(**self._testtube_kwargs) + self.csvlogger = CSVLogger(**self._csvlogger_kwargs) self.log_config_as_ini() def log_hyperparams(self, params): self.neptunelogger.log_hyperparams(params) - self.testtubelogger.log_hyperparams(params) + self.csvlogger.log_hyperparams(params) pass def log_metrics(self, metrics, step=None): self.neptunelogger.log_metrics(metrics, step=step) - self.testtubelogger.log_metrics(metrics, step=step) + self.csvlogger.log_metrics(metrics, step=step) pass def close(self): - self.testtubelogger.close() + self.csvlogger.close() self.neptunelogger.close() def log_config_as_ini(self): self.config.write(self.log_dir / 'config.ini') - def log_text(self, name, text, step_nb=0, **kwargs): + def log_text(self, name, text, step_nb=0, **_): # TODO Implement Offline variant. self.neptunelogger.log_text(name, text, step_nb) def log_metric(self, metric_name, metric_value, **kwargs): - self.testtubelogger.log_metrics(dict(metric_name=metric_value)) + self.csvlogger.log_metrics(dict(metric_name=metric_value)) self.neptunelogger.log_metric(metric_name, metric_value, **kwargs) def log_image(self, name, image, ext='png', **kwargs): - self.neptunelogger.log_image(name, image, **kwargs) + step = kwargs.get('step', None) - name = f'{step}_{name}' if step is not None else name - name = f'{name}.{ext[1:] if ext.startswith(".") else ext}' + image_name = f'{step}_{name}' if step is not None else name + image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}' (self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True) - image.savefig(self.log_dir / self.media_dir / name) + image.savefig(image_path, bbox_inches='tight', pad_inches=0) + self.neptunelogger.log_image(name, str(image_path), **kwargs) def save(self): - self.testtubelogger.save() + self.csvlogger.save() self.neptunelogger.save() def finalize(self, status): - self.testtubelogger.finalize(status) + self.csvlogger.finalize(status) self.neptunelogger.finalize(status) def __enter__(self): diff --git a/utils/model_io.py b/utils/model_io.py index 9e025fe..fe9374f 100644 --- a/utils/model_io.py +++ b/utils/model_io.py @@ -20,7 +20,7 @@ class ModelParameters(Namespace, Mapping): paramter_mapping.update( dict( - activation=self._activations[self['activation']] + activation=self.__getattribute__('activation') ) ) @@ -44,7 +44,7 @@ class ModelParameters(Namespace, Mapping): def __getattribute__(self, name): if name == 'activation': - return self._activations[self['activation']] + return self._activations[self['activation'].lower()] else: try: return super(ModelParameters, self).__getattribute__(name) @@ -56,6 +56,7 @@ class ModelParameters(Namespace, Mapping): _activations = dict( leaky_relu=nn.LeakyReLU, + elu=nn.ELU, relu=nn.ReLU, sigmoid=nn.Sigmoid, tanh=nn.Tanh diff --git a/visualization/tools.py b/visualization/tools.py index 224215a..87079ba 100644 --- a/visualization/tools.py +++ b/visualization/tools.py @@ -1,5 +1,5 @@ try: - import matplotlib.pyplot as plt + from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas except ImportError: # pragma: no-cover raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover ' install it with `pip install matplotlib`.') @@ -8,30 +8,23 @@ from pathlib import Path class Plotter(object): + def __init__(self, root_path=''): if not root_path: self.root_path = Path(root_path) - def save_current_figure(self, filename: str, extention='.png', naked=False): - fig, _ = plt.gcf(), plt.gca() + def save_figure(self, figure, title, extention='.png', naked=False): + canvas = FigureCanvas(figure) # Prepare save location and check img file extention - path = self.root_path / Path(filename if filename.endswith(extention) else f'{filename}{extention}') + path = self.root_path / f'{title}{extention}' path.parent.mkdir(exist_ok=True, parents=True) if naked: - plt.axis('off') - fig.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0) - fig.clf() + figure.axis('off)') + figure.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0) + canvas.print_figure(path) else: - fig.savefig(path) - fig.clf() - - def show_current_figure(self): - fig, _ = plt.gcf(), plt.gca() - fig.show() - fig.clf() + canvas.print_figure(path) if __name__ == '__main__': - output_root = Path('..') / 'output' - p = Plotter(output_root) - p.save_current_figure('test.png') + raise PermissionError('Get out of here.')