From 5987efb169ce812021d7ff3dfc758dda8a34a399 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Sat, 30 May 2020 18:12:41 +0200
Subject: [PATCH] eval running - offline logger implemented -> Test it!

---
 logging/__init__.py         |   0
 logging/local_logging.py    | 488 ++++++++++++++++++++++++++++++++++++
 modules/geometric_blocks.py |  21 +-
 point_toolset/point_io.py   |  24 ++
 point_toolset/sampling.py   |  49 +++-
 utils/data_util.py          |  41 +++
 utils/logging.py            |   3 +
 utils/tools.py              |  10 +
 visualization/tools.py      |   7 +-
 9 files changed, 626 insertions(+), 17 deletions(-)
 create mode 100644 logging/__init__.py
 create mode 100644 logging/local_logging.py
 create mode 100644 point_toolset/point_io.py
 create mode 100644 utils/data_util.py

diff --git a/logging/__init__.py b/logging/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/logging/local_logging.py b/logging/local_logging.py
new file mode 100644
index 0000000..42933d1
--- /dev/null
+++ b/logging/local_logging.py
@@ -0,0 +1,488 @@
+##########################
+# constants
+import argparse
+import contextlib
+import json
+import os
+import shutil
+from datetime import datetime
+from pathlib import Path
+from typing import Dict, Optional, Union, Any
+
+import numpy as np
+import pandas as pd
+
+# ToDo: Check this
+from imageio import imwrite
+from natsort import natsorted
+from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning import _logger as log
+from test_tube.log import DDPExperiment
+
+# Default experiment root: the directory containing this file, not the file itself.
+_ROOT = Path(os.path.abspath(__file__)).parent
+
+
+# -----------------------------
+# Experiment object
+# -----------------------------
+class Experiment(object):
+
+    def __init__(self, save_dir=None, name='default', debug=False, version=None, autosave=False, description=None):
+        """
+        A new Experiment object defaults to the name 'default' unless a specific name is provided.
+        If a known name is provided, the file version is incremented.
+        :param name: experiment name, used as the directory name
+        :param debug: when True, nothing is written to disk
+        """
+
+        # change where the save dir is if requested
+        if save_dir is not None:
+            global _ROOT
+            _ROOT = Path(save_dir)
+
+        self.save_dir = save_dir
+        self.no_save_dir = save_dir is None
+        self.metrics = []
+        self.tags = {}
+        self.name = name
+        self.debug = debug
+        self.version = version
+        self.autosave = autosave
+        self.description = description
+        self.exp_hash = '{}_v{}'.format(self.name, version)
+        self.created_at = str(datetime.utcnow())
+        self.process = os.getpid()
+
+        # when debugging don't do anything else
+        if debug:
+            return
+
+        # update version hash if we need to increase version on our own
+        # we will increase the previous version, so do it now so the hash
+        # is accurate
+        if version is None:
+            old_version = self.__get_last_experiment_version()
+            self.exp_hash = '{}_v{}'.format(self.name, old_version + 1)
+            self.version = old_version + 1
+
+        # create a new log file
+        self.__init_cache_file_if_needed()
+
+        # when we have a version, load it
+        if self.version is not None:
+
+            # when no version and no file, create it
+            if not os.path.exists(self.__get_log_name()):
+                self.__create_exp_file(self.version)
+            else:
+                # otherwise load it
+                self.__load()
+
+        else:
+            # if no version given, increase the version to a new exp
+            # create the file if not exists
+            old_version = self.__get_last_experiment_version()
+            self.version = old_version
+            self.__create_exp_file(self.version + 1)
+
+    def get_meta_copy(self):
+        """
+        Gets a meta-version only copy of this module
+        :return:
+        """
+        return DDPExperiment(self)
+
+    def on_exit(self):
+        pass
+
+    def __clean_dir(self):
+        files = os.listdir(self.save_dir)
+
+        for f in files:
+            if str(self.process) in f:
+                os.remove(os.path.join(self.save_dir, f))
+
+    def argparse(self, argparser):
+        parsed = vars(argparser)
+        to_add = {}
+
+        # don't store methods
+        for k, v in parsed.items():
+            if not callable(v):
+                to_add[k] = v
+
+        self.tag(to_add)
+
+    def add_meta_from_hyperopt(self, hypo):
+        """
+        Transfers meta data about all the params from the
+        hyperoptimizer to the log
+        :param hypo:
+        :return:
+        """
+        meta = hypo.get_current_trial_meta()
+        for tag in meta:
+            self.tag(tag)
+
+    # --------------------------------
+    # FILE IO UTILS
+    # --------------------------------
+    def __init_cache_file_if_needed(self):
+        """
+        Inits a file in which we log historical experiments
+        :return:
+        """
+        try:
+            exp_cache_file = self.get_data_path(self.name, self.version)
+            if not os.path.isdir(exp_cache_file):
+                os.makedirs(exp_cache_file, exist_ok=True)
+        except FileExistsError:
+            # file already exists (likely written by another exp). In this case, disable the experiment.
+            self.debug = True
+
+    def __create_exp_file(self, version):
+        """
+        Recreates the old file with this exp and version
+        :param version:
+        :return:
+        """
+
+        try:
+            # set the version before deriving any paths from it
+            self.version = version
+            exp_cache_file = self.get_data_path(self.name, self.version)
+            # if no exp, then make it
+            path = exp_cache_file / 'meta.experiment'
+            path.touch(exist_ok=True)
+
+            # make the directory for the experiment media assets
+            self.get_media_path(self.name, self.version).mkdir(parents=True, exist_ok=True)
+
+        except FileExistsError:
+            # file already exists (likely written by another exp). In this case, disable the experiment.
+            self.debug = True
+
+    def __get_last_experiment_version(self):
+
+        exp_cache_file = self.get_data_path(self.name, self.version).parent
+        last_version = -1
+
+        # guard against a missing or empty experiment dir, which would make `[-1]` fail
+        if not exp_cache_file.is_dir():
+            return last_version
+
+        versions = natsorted([x.name for x in exp_cache_file.iterdir() if 'version_' in x.name])
+        if versions:
+            last_version = max(last_version, int(versions[-1].split('_')[1]))
+
+        return last_version
+
+    def __get_log_name(self):
+        return self.get_data_path(self.name, self.version) / 'meta.experiment'
+
+    def tag(self, tag_dict):
+        """
+        Adds a tag to the experiment.
+        Tags are metadata for the exp.
+
+        >> e.tag({"model": "Convnet A"})
+
+        :param tag_dict:
+        :type tag_dict: dict
+
+        :return:
+        """
+        if self.debug:
+            return
+
+        # parse tags
+        for k, v in tag_dict.items():
+            self.tags[k] = v
+
+        # save if needed
+        if self.autosave:
+            self.save()
+
+    def log(self, metrics_dict):
+        """
+        Adds a json dict of metrics.
+
+        >> e.log({"loss": 23, "coeff_a": 0.2})
+
+        :param metrics_dict:
+
+        :return:
+        """
+        if self.debug:
+            return
+
+        # flatten nested metric dicts one level; leave scalar values untouched
+        new_metrics_dict = metrics_dict.copy()
+        for k, v in metrics_dict.items():
+            if isinstance(v, dict):
+                tmp_metrics_dict = new_metrics_dict.pop(k)
+                new_metrics_dict.update(tmp_metrics_dict)
+
+        metrics_dict = new_metrics_dict
+
+        # timestamp
+        if 'created_at' not in metrics_dict:
+            metrics_dict['created_at'] = str(datetime.utcnow())
+
+        self.__convert_numpy_types(metrics_dict)
+
+        self.metrics.append(metrics_dict)
+
+        if self.autosave:
+            self.save()
+
+    @staticmethod
+    def __convert_numpy_types(metrics_dict):
+        for k, v in metrics_dict.items():
+            if v.__class__.__name__ in ('float32', 'float64'):
+                metrics_dict[k] = float(v)
+
+    def save(self):
+        """
+        Saves current experiment progress
+        :return:
+        """
+        if self.debug:
+            return
+
+        # save images and replace the image array with the
+        # file name
+        self.__save_images(self.metrics)
+        metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv'
+        meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv'
+
+        # paths must be strings, otherwise json.dump raises a TypeError
+        obj = {
+            'name': self.name,
+            'version': self.version,
+            'tags_path': str(meta_tags_path),
+            'metrics_path': str(metrics_file_path),
+            'autosave': self.autosave,
+            'description': self.description,
+            'created_at': self.created_at,
+            'exp_hash': self.exp_hash
+        }
+
+        # save the experiment meta file
+        with atomic_write(self.__get_log_name()) as tmp_path:
+            with open(tmp_path, 'w') as file:
+                json.dump(obj, file, ensure_ascii=False)
+
+        # save the metatags file
+        df = pd.DataFrame({'key': list(self.tags.keys()), 'value': list(self.tags.values())})
+        with atomic_write(meta_tags_path) as tmp_path:
+            df.to_csv(tmp_path, index=False)
+
+        # save the metrics data
+        df = pd.DataFrame(self.metrics)
+        with atomic_write(metrics_file_path) as tmp_path:
+            df.to_csv(tmp_path, index=False)
+
+    def __save_images(self, metrics):
+        """
+        Save metric values that have a png_/jpg_/jpeg_ prefix (as images)
+        and replace the value with the file name
+        :param metrics:
+        :return:
+        """
+        # iterate all metrics and find keys with a specific prefix
+        for i, metric in enumerate(metrics):
+            for k, v in metric.items():
+                # if the key carries an image prefix, save the image and replace the value with the path
+                img_extension = None
+                for ext in ('png', 'jpg', 'jpeg'):
+                    if k.startswith(ext + '_'):
+                        img_extension = ext
+                        break
+
+                if img_extension is not None:
+                    # determine the file name
+                    img_name = '_'.join(k.split('_')[1:])
+                    save_path = self.get_media_path(self.name, self.version)
+                    save_path = '{}/{}_{}.{}'.format(save_path, img_name, i, img_extension)
+
+                    # save image to disk
+                    if not isinstance(metric[k], str):
+                        imwrite(save_path, metric[k])
+
+                    # replace the image in the metric with the file path
+                    metric[k] = save_path
+
+    def __load(self):
+        # load .experiment file
+        with open(self.__get_log_name(), 'r') as file:
+            data = json.load(file)
+            self.name = data['name']
+            self.version = data['version']
+            self.autosave = data['autosave']
+            self.created_at = data['created_at']
+            self.description = data['description']
+            self.exp_hash = data['exp_hash']
+
+        # load .tags file
+        meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv'
+        df = pd.read_csv(meta_tags_path)
+        self.tags_list = df.to_dict(orient='records')
+        self.tags = {}
+        for d in self.tags_list:
+            k, v = d['key'], d['value']
+            self.tags[k] = v
+
+        # load metrics
+        metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv'
+        try:
+            df = pd.read_csv(metrics_file_path)
+            self.metrics = df.to_dict(orient='records')
+
+            # remove nans and infs; only floats can hold them, strings would raise
+            for metric in self.metrics:
+                to_delete = []
+                for k, v in metric.items():
+                    if isinstance(v, float) and (np.isnan(v) or np.isinf(v)):
+                        to_delete.append(k)
+                for k in to_delete:
+                    del metric[k]
+
+        except Exception:
+            # metrics file was empty or missing
+            self.metrics = []
+
+    def get_data_path(self, exp_name, exp_version):
+        """
+        Returns the path where data for this experiment name/version is stored
+        :param exp_name:
+        :param exp_version:
+        :return: Path
+        """
+        if self.no_save_dir:
+            return _ROOT / 'local_experiment_data' / exp_name / f'version_{exp_version}'
+        else:
+            return _ROOT / exp_name / f'version_{exp_version}'
+
+    def get_media_path(self, exp_name, exp_version):
+        """
+        Returns the media directory below the experiment data path
+        :param exp_version:
+        :param exp_name:
+        :return:
+        """
+
+        return self.get_data_path(exp_name, exp_version) / 'media'
+
+    # ----------------------------
+    # OVERWRITES
+    # ----------------------------
+    def __str__(self):
+        return 'Exp: {}, v: {}'.format(self.name, self.version)
+
+    def __hash__(self):
+        # __hash__ must return an int, not a str
+        return hash('Exp: {}, v: {}'.format(self.name, self.version))
+
+
+@contextlib.contextmanager
+def atomic_write(dst_path):
+    """A context manager to simplify atomic writing.
+
+    Usage:
+    >>> with atomic_write(dst_path) as tmp_path:
+    >>>     # write to tmp_path
+    >>> # Here tmp_path is renamed to dst_path, if no exception happened.
+    """
+    # write next to the destination, not into a '.tmp' subdirectory
+    tmp_path = dst_path.with_suffix(dst_path.suffix + '.tmp')
+    try:
+        yield tmp_path
+    except:
+        if tmp_path.exists():
+            tmp_path.unlink()
+        raise
+    else:
+        # If everything is fine, move tmp file to the destination.
+        shutil.move(str(tmp_path), str(dst_path))
+
+
+##########################
+class LocalLogger(LightningLoggerBase):
+
+    def __init__(self, save_dir: str, name: str = "default", description: Optional[str] = None,
+                 debug: bool = False, version: Optional[int] = None, **kwargs):
+        super(LocalLogger, self).__init__(**kwargs)
+        self.save_dir = Path(save_dir)
+        self._name = name
+        self.description = description
+        self.debug = debug
+        self._version = version
+        self._experiment = None
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def experiment(self) -> Experiment:
+        r"""
+
+        The actual local Experiment object. To use it in your
+        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
+
+        Example::
+
+            self.logger.experiment.log({"loss": 23})
+
+        """
+        if self._experiment is not None:
+            return self._experiment
+
+        self._experiment = Experiment(
+            save_dir=self.save_dir,
+            name=self._name,
+            debug=self.debug,
+            version=self.version,
+            description=self.description
+        )
+        return self._experiment
+
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+        pass
+
+    def log_hyperparams(self, params: argparse.Namespace):
+        pass
+
+    @property
+    def version(self) -> Union[int, str]:
+        if self._version is None:
+            self._version = self._get_next_version()
+        return self._version
+
+    def _get_next_version(self):
+        root_dir = self.save_dir / self.name
+
+        if not root_dir.is_dir():
+            log.warning(f'Missing logger folder: {root_dir}')
+            return 0
+
+        existing_versions = []
+        for d in os.listdir(root_dir):
+            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
+                existing_versions.append(int(d.split("_")[1]))
+
+        if len(existing_versions) == 0:
+            return 0
+
+        return max(existing_versions) + 1
+
+    # Test tube experiments are not pickleable, so we need to override a few
+    # methods to get DDP working. See
+    # https://docs.python.org/3/library/pickle.html#handling-stateful-objects
+    # for more info.
+    def __getstate__(self) -> Dict[Any, Any]:
+        state = self.__dict__.copy()
+        state["_experiment"] = self.experiment.get_meta_copy()
+        return state
+
+    def __setstate__(self, state: Dict[Any, Any]):
+        self._experiment = state["_experiment"].get_non_ddp_exp()
+        del state["_experiment"]
+        self.__dict__.update(state)
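+
+# Example usage (a sketch added for illustration, not part of the tested API):
+# wiring the logger into a pytorch_lightning.Trainer. The resulting layout is
+# <save_dir>/<name>/version_<n>/{meta.experiment, metrics.csv, meta_tags.csv, media/}.
+#
+#   from pytorch_lightning import Trainer
+#   logger = LocalLogger(save_dir='logs', name='my_experiment')
+#   trainer = Trainer(logger=logger)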
diff --git a/modules/geometric_blocks.py b/modules/geometric_blocks.py
index 9fd4479..dfd77d0 100644
--- a/modules/geometric_blocks.py
+++ b/modules/geometric_blocks.py
@@ -2,7 +2,7 @@
 import torch
 from torch import nn
 from torch.nn import ReLU
-from torch_geometric.nn import PointConv, fps, radius, global_max_pool
+from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
 
 
 class SAModule(torch.nn.Module):
@@ -23,14 +23,15 @@ class SAModule(torch.nn.Module):
 
 
 class GlobalSAModule(nn.Module):
-    def __init__(self, nn):
+    def __init__(self, nn, channels=3):
         super(GlobalSAModule, self).__init__()
         self.nn = nn
+        self.channels = channels
 
     def forward(self, x, pos, batch):
         x = self.nn(torch.cat([x, pos], dim=1))
         x = global_max_pool(x, batch)
-        pos = pos.new_zeros((x.size(0), 3))
+        pos = pos.new_zeros((x.size(0), self.channels))
         batch = torch.arange(x.size(0), device=batch.device)
         return x, pos, batch
 
@@ -45,3 +46,17 @@ class MLP(nn.Module):
 
     def forward(self, x, *args, **kwargs):
         return self.net(x)
+
+
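+# Feature propagation (FP) module as used in PointNet++-style decoders
+# (comment added for clarity; this mirrors the torch_geometric PointNet++
+# segmentation example): features from the coarser level are carried back onto
+# the denser skip-level points via k-NN interpolation, concatenated with the
+# skip features, and refined by an MLP.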
+class FPModule(torch.nn.Module):
+    def __init__(self, k, nn):
+        super(FPModule, self).__init__()
+        self.k = k
+        self.nn = nn
+
+    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
+        x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
+        if x_skip is not None:
+            x = torch.cat([x, x_skip], dim=1)
+        x = self.nn(x)
+        return x, pos_skip, batch_skip
\ No newline at end of file
diff --git a/point_toolset/point_io.py b/point_toolset/point_io.py
new file mode 100644
index 0000000..49e1ebc
--- /dev/null
+++ b/point_toolset/point_io.py
@@ -0,0 +1,24 @@
+import torch
+from torch_geometric.data import Data
+
+
+class BatchToData(object):
+    def __init__(self):
+        super(BatchToData, self).__init__()
+
+    def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, batch_y: torch.Tensor):
+        # Convert dense batched tensors to the torch_geometric.data.Data format
+        batch_size, num_points, _ = batch_x.shape  # (batch_size, num_points, 3)
+
+        x = batch_x.reshape(batch_size * num_points, -1)
+        pos = batch_pos.reshape(batch_size * num_points, -1)
+        batch_y = batch_y.reshape(batch_size * num_points)
+        # one batch index per point: [0, ..., 0, 1, ..., 1, ...]
+        batch = torch.arange(batch_size, device=pos.device).repeat_interleave(num_points)
+
+        data = Data()
+        data.x, data.pos, data.batch, data.y = x, pos, batch, batch_y
+        return data
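+
+# Usage sketch (shapes inferred from the reshape logic above): given dense
+# (B, N, C) feature, (B, N, 3) position, and (B, N) label tensors from a plain
+# DataLoader,
+#   data = BatchToData()(batch_x, batch_pos, batch_y)
+# yields a Data object with flattened x/pos/y and a per-point batch vector.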
diff --git a/point_toolset/sampling.py b/point_toolset/sampling.py
index 0a9c2c5..0c2ee8f 100644
--- a/point_toolset/sampling.py
+++ b/point_toolset/sampling.py
@@ -1,10 +1,36 @@
+from abc import ABC
+
 import numpy as np
 
 
-class FarthestpointSampling():
+class _Sampler(ABC):
 
-    def __init__(self, K):
+    def __init__(self, K, **kwargs):
         self.k = K
+        self.kwargs = kwargs
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class RandomSampling(_Sampler):
+
+    def __init__(self, *args, **kwargs):
+        super(RandomSampling, self).__init__(*args, **kwargs)
+
+    def __call__(self, pts, *args, **kwargs):
+        if pts.shape[0] < self.k:
+            # fewer points than requested; return all indices, not the points
+            return np.arange(pts.shape[0])
+
+        else:
+            rnd_indexs = np.random.choice(np.arange(pts.shape[0]), self.k, replace=False)
+            return rnd_indexs
+
+
+class FarthestpointSampling(_Sampler):
+
+    def __init__(self, *args, **kwargs):
+        super(FarthestpointSampling, self).__init__(*args, **kwargs)
 
     @staticmethod
     def calc_distances(p0, points):
@@ -15,14 +41,15 @@
         if pts.shape[0] < self.k:
-            return pts
+            # fewer points than requested; return all indices for consistency
+            return np.arange(pts.shape[0])
 
-        farthest_pts = np.zeros((self.k, pts.shape[1]))
-        farthest_pts_idx = np.zeros(self.k, dtype=np.int)
-        farthest_pts[0] = pts[np.random.randint(len(pts))]
-        distances = self.calc_distances(farthest_pts[0], pts)
-        for i in range(1, self.k):
-            farthest_pts_idx[i] = np.argmax(distances)
-            farthest_pts[i] = pts[farthest_pts_idx[i]]
+        else:
+            farthest_pts = np.zeros((self.k, pts.shape[1]))
+            farthest_pts_idx = np.zeros(self.k, dtype=int)
+            farthest_pts[0] = pts[np.random.randint(len(pts))]
+            distances = self.calc_distances(farthest_pts[0], pts)
+            for i in range(1, self.k):
+                farthest_pts_idx[i] = np.argmax(distances)
+                farthest_pts[i] = pts[farthest_pts_idx[i]]
 
-                distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
+                distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
 
-        return farthest_pts_idx
+            return farthest_pts_idx
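+
+# Usage sketch: with the fixes above, both samplers consistently return
+# *indices* into `pts`:
+#   sampler = FarthestpointSampling(K=1024)
+#   idx = sampler(pts)          # pts: (N, 3) ndarray
+#   sampled = pts[idx]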
diff --git a/utils/data_util.py b/utils/data_util.py
new file mode 100644
index 0000000..a22f979
--- /dev/null
+++ b/utils/data_util.py
@@ -0,0 +1,41 @@
+import torch
+from torch.utils.data import Dataset
+
+
+def chunks(l, n):
+    """Yield successive n-sized chunks from l."""
+    for i in range(0, len(l), n):
+        yield l[i:i + n]
+
+
+class ReMapDataset(Dataset):
+
+    def __init__(self, ds, mapping):
+        super(ReMapDataset, self).__init__()
+        # mapping from this dataset's indices to the parent dataset's indices
+        self.mapping = mapping
+        self.ds = ds
+
+    @property
+    def sample_shape(self):
+        return list(self[0][0].shape)
+
+    def __getitem__(self, index):
+        return self.ds[self.mapping[index]]
+
+    def __len__(self):
+        return self.mapping.shape[0]
+
+    @classmethod
+    def do_train_vali_split(cls, ds, split_fold=0.1):
+
+        indices = torch.randperm(len(ds))
+
+        valid_size = int(len(ds) * split_fold)
+
+        train_mapping = indices[valid_size:]
+        valid_mapping = indices[:valid_size]
+
+        train = cls(ds, train_mapping)
+        valid = cls(ds, valid_mapping)
+
+        return train, valid
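+
+# Usage sketch: split any map-style Dataset into random train/validation views
+# that share the underlying storage:
+#   train_ds, valid_ds = ReMapDataset.do_train_vali_split(ds, split_fold=0.1)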
diff --git a/utils/logging.py b/utils/logging.py
index fde0f47..8406ebb 100644
--- a/utils/logging.py
+++ b/utils/logging.py
@@ -1,3 +1,6 @@
+import argparse
+from typing import Union, Dict, Optional, Any
+
 from abc import ABC
 
 from pathlib import Path
diff --git a/utils/tools.py b/utils/tools.py
index e5b73c1..68f5a21 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -2,6 +2,16 @@ import pickle
 import shelve
 from pathlib import Path
 
+import numpy as np
+
+from utils.project_config import GlobalVar
+
+
+def to_one_hot(idx_array):
+    # one row per entry, one column per class in GlobalVar.classes
+    one_hot = np.zeros((idx_array.size, len(GlobalVar.classes)))
+    one_hot[np.arange(idx_array.size), idx_array] = 1
+    return one_hot
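+
+# Example (assumes GlobalVar.classes is the project's class list):
+#   to_one_hot(np.array([0, 2])) -> [[1, 0, 0, ...], [0, 0, 1, ...]]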
+
 
 def fix_all_random_seeds(config_obj):
     import numpy as np
diff --git a/visualization/tools.py b/visualization/tools.py
index 1110f50..224215a 100644
--- a/visualization/tools.py
+++ b/visualization/tools.py
@@ -9,12 +9,13 @@
 
 class Plotter(object):
     def __init__(self, root_path=''):
-        self.root_path = Path(root_path)
+        # always set the attribute; an empty root_path falls back to the cwd
+        self.root_path = Path(root_path) if root_path else Path()
 
-    def save_current_figure(self, path, extention='.png', naked=True):
+    def save_current_figure(self, filename: str, extention='.png', naked=False):
         fig, _ = plt.gcf(), plt.gca()
         # Prepare save location and check img file extention
-        path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
+        path = self.root_path / Path(filename if filename.endswith(extention) else f'{filename}{extention}')
         path.parent.mkdir(exist_ok=True, parents=True)
         if naked:
             plt.axis('off')
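+
+# Usage sketch: with the default (empty) root_path, figures are written
+# relative to the current working directory:
+#   plotter = Plotter()
+#   plt.plot([1, 2, 3])
+#   plotter.save_current_figure('curves/loss')   # -> ./curves/loss.png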