eval running - offline logger implemented -> Test it!
This commit is contained in:
parent
77ea043907
commit
5987efb169
0
logging/__init__.py
Normal file
0
logging/__init__.py
Normal file
488
logging/local_logging.py
Normal file
488
logging/local_logging.py
Normal file
@ -0,0 +1,488 @@
|
|||||||
|
##########################
|
||||||
|
# constants
|
||||||
|
import argparse
|
||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Union, Any
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
# ToDo: Check this
|
||||||
|
import shutil
|
||||||
|
from imageio import imwrite
|
||||||
|
from natsort import natsorted
|
||||||
|
from pytorch_lightning.loggers import LightningLoggerBase
|
||||||
|
from pytorch_lightning import _logger as log
|
||||||
|
from test_tube.log import DDPExperiment
|
||||||
|
|
||||||
|
_ROOT = Path(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Experiment object
|
||||||
|
# -----------------------------
|
||||||
|
class Experiment(object):
|
||||||
|
|
||||||
|
def __init__(self, save_dir=None, name='default', debug=False, version=None, autosave=False, description=None):
|
||||||
|
"""
|
||||||
|
A new Experiment object defaults to 'default' unless a specific name is provided
|
||||||
|
If a known name is already provided, then the file version is changed
|
||||||
|
:param name:
|
||||||
|
:param debug:
|
||||||
|
"""
|
||||||
|
|
||||||
|
# change where the save dir is if requested
|
||||||
|
|
||||||
|
if save_dir is not None:
|
||||||
|
global _ROOT
|
||||||
|
_ROOT = save_dir
|
||||||
|
|
||||||
|
self.save_dir = save_dir
|
||||||
|
self.no_save_dir = save_dir is None
|
||||||
|
self.metrics = []
|
||||||
|
self.tags = {}
|
||||||
|
self.name = name
|
||||||
|
self.debug = debug
|
||||||
|
self.version = version
|
||||||
|
self.autosave = autosave
|
||||||
|
self.description = description
|
||||||
|
self.exp_hash = '{}_v{}'.format(self.name, version)
|
||||||
|
self.created_at = str(datetime.utcnow())
|
||||||
|
self.process = os.getpid()
|
||||||
|
|
||||||
|
# when debugging don't do anything else
|
||||||
|
if debug:
|
||||||
|
return
|
||||||
|
|
||||||
|
# update version hash if we need to increase version on our own
|
||||||
|
# we will increase the previous version, so do it now so the hash
|
||||||
|
# is accurate
|
||||||
|
if version is None:
|
||||||
|
old_version = self.__get_last_experiment_version()
|
||||||
|
self.exp_hash = '{}_v{}'.format(self.name, old_version + 1)
|
||||||
|
self.version = old_version + 1
|
||||||
|
|
||||||
|
# create a new log file
|
||||||
|
self.__init_cache_file_if_needed()
|
||||||
|
|
||||||
|
# when we have a version, load it
|
||||||
|
if self.version is not None:
|
||||||
|
|
||||||
|
# when no version and no file, create it
|
||||||
|
if not os.path.exists(self.__get_log_name()):
|
||||||
|
self.__create_exp_file(self.version)
|
||||||
|
else:
|
||||||
|
# otherwise load it
|
||||||
|
self.__load()
|
||||||
|
|
||||||
|
else:
|
||||||
|
# if no version given, increase the version to a new exp
|
||||||
|
# create the file if not exists
|
||||||
|
old_version = self.__get_last_experiment_version()
|
||||||
|
self.version = old_version
|
||||||
|
self.__create_exp_file(self.version + 1)
|
||||||
|
|
||||||
|
def get_meta_copy(self):
|
||||||
|
"""
|
||||||
|
Gets a meta-version only copy of this module
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return DDPExperiment(self)
|
||||||
|
|
||||||
|
def on_exit(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __clean_dir(self):
|
||||||
|
files = os.listdir(self.save_dir)
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
if str(self.process) in f:
|
||||||
|
os.remove(os.path.join(self.save_dir, f))
|
||||||
|
|
||||||
|
def argparse(self, argparser):
|
||||||
|
parsed = vars(argparser)
|
||||||
|
to_add = {}
|
||||||
|
|
||||||
|
# don't store methods
|
||||||
|
for k, v in parsed.items():
|
||||||
|
if not callable(v):
|
||||||
|
to_add[k] = v
|
||||||
|
|
||||||
|
self.tag(to_add)
|
||||||
|
|
||||||
|
def add_meta_from_hyperopt(self, hypo):
|
||||||
|
"""
|
||||||
|
Transfers meta data about all the params from the
|
||||||
|
hyperoptimizer to the log
|
||||||
|
:param hypo:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
meta = hypo.get_current_trial_meta()
|
||||||
|
for tag in meta:
|
||||||
|
self.tag(tag)
|
||||||
|
|
||||||
|
# --------------------------------
|
||||||
|
# FILE IO UTILS
|
||||||
|
# --------------------------------
|
||||||
|
def __init_cache_file_if_needed(self):
|
||||||
|
"""
|
||||||
|
Inits a file that we log historical experiments
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
exp_cache_file = self.get_data_path(self.name, self.version)
|
||||||
|
if not os.path.isdir(exp_cache_file):
|
||||||
|
os.makedirs(exp_cache_file, exist_ok=True)
|
||||||
|
except FileExistsError:
|
||||||
|
# file already exists (likely written by another exp. In this case disable the experiment
|
||||||
|
self.debug = True
|
||||||
|
|
||||||
|
def __create_exp_file(self, version):
|
||||||
|
"""
|
||||||
|
Recreates the old file with this exp and version
|
||||||
|
:param version:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
exp_cache_file = self.get_data_path(self.name, self.version)
|
||||||
|
# if no exp, then make it
|
||||||
|
path = exp_cache_file / 'meta.experiment'
|
||||||
|
path.touch(exist_ok=True)
|
||||||
|
|
||||||
|
self.version = version
|
||||||
|
|
||||||
|
# make the directory for the experiment media assets name
|
||||||
|
self.get_media_path(self.name, self.version).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
except FileExistsError:
|
||||||
|
# file already exists (likely written by another exp. In this case disable the experiment
|
||||||
|
self.debug = True
|
||||||
|
|
||||||
|
def __get_last_experiment_version(self):
|
||||||
|
|
||||||
|
exp_cache_file = self.get_data_path(self.name, self.version).parent
|
||||||
|
last_version = -1
|
||||||
|
|
||||||
|
version = natsorted([x.name for x in exp_cache_file.iterdir() if 'version_' in x.name])[-1]
|
||||||
|
last_version = max(last_version, int(version.split('_')[1]))
|
||||||
|
|
||||||
|
return last_version
|
||||||
|
|
||||||
|
def __get_log_name(self):
|
||||||
|
return self.get_data_path(self.name, self.version) / 'meta.experiment'
|
||||||
|
|
||||||
|
def tag(self, tag_dict):
|
||||||
|
"""
|
||||||
|
Adds a tag to the experiment.
|
||||||
|
Tags are metadata for the exp.
|
||||||
|
|
||||||
|
>> e.tag({"model": "Convnet A"})
|
||||||
|
|
||||||
|
:param tag_dict:
|
||||||
|
:type tag_dict: dict
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.debug:
|
||||||
|
return
|
||||||
|
|
||||||
|
# parse tags
|
||||||
|
for k, v in tag_dict.items():
|
||||||
|
self.tags[k] = v
|
||||||
|
|
||||||
|
# save if needed
|
||||||
|
if self.autosave:
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
def log(self, metrics_dict):
|
||||||
|
"""
|
||||||
|
Adds a json dict of metrics.
|
||||||
|
|
||||||
|
>> e.log({"loss": 23, "coeff_a": 0.2})
|
||||||
|
|
||||||
|
:param metrics_dict:
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.debug:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_metrics_dict = metrics_dict.copy()
|
||||||
|
for k, v in metrics_dict.items():
|
||||||
|
tmp_metrics_dict = new_metrics_dict.pop(k)
|
||||||
|
new_metrics_dict.update(tmp_metrics_dict)
|
||||||
|
|
||||||
|
metrics_dict = new_metrics_dict
|
||||||
|
|
||||||
|
# timestamp
|
||||||
|
if 'created_at' not in metrics_dict:
|
||||||
|
metrics_dict['created_at'] = str(datetime.utcnow())
|
||||||
|
|
||||||
|
self.__convert_numpy_types(metrics_dict)
|
||||||
|
|
||||||
|
self.metrics.append(metrics_dict)
|
||||||
|
|
||||||
|
if self.autosave:
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __convert_numpy_types(metrics_dict):
|
||||||
|
for k, v in metrics_dict.items():
|
||||||
|
if v.__class__.__name__ == 'float32':
|
||||||
|
metrics_dict[k] = float(v)
|
||||||
|
|
||||||
|
if v.__class__.__name__ == 'float64':
|
||||||
|
metrics_dict[k] = float(v)
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""
|
||||||
|
Saves current experiment progress
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.debug:
|
||||||
|
return
|
||||||
|
|
||||||
|
# save images and replace the image array with the
|
||||||
|
# file name
|
||||||
|
self.__save_images(self.metrics)
|
||||||
|
metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv'
|
||||||
|
meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv'
|
||||||
|
|
||||||
|
obj = {
|
||||||
|
'name': self.name,
|
||||||
|
'version': self.version,
|
||||||
|
'tags_path': meta_tags_path,
|
||||||
|
'metrics_path': metrics_file_path,
|
||||||
|
'autosave': self.autosave,
|
||||||
|
'description': self.description,
|
||||||
|
'created_at': self.created_at,
|
||||||
|
'exp_hash': self.exp_hash
|
||||||
|
}
|
||||||
|
|
||||||
|
# save the experiment meta file
|
||||||
|
with atomic_write(self.__get_log_name()) as tmp_path:
|
||||||
|
with open(tmp_path, 'w') as file:
|
||||||
|
json.dump(obj, file, ensure_ascii=False)
|
||||||
|
|
||||||
|
# save the metatags file
|
||||||
|
df = pd.DataFrame({'key': list(self.tags.keys()), 'value': list(self.tags.values())})
|
||||||
|
with atomic_write(meta_tags_path) as tmp_path:
|
||||||
|
df.to_csv(tmp_path, index=False)
|
||||||
|
|
||||||
|
# save the metrics data
|
||||||
|
df = pd.DataFrame(self.metrics)
|
||||||
|
with atomic_write(metrics_file_path) as tmp_path:
|
||||||
|
df.to_csv(tmp_path, index=False)
|
||||||
|
|
||||||
|
def __save_images(self, metrics):
|
||||||
|
"""
|
||||||
|
Save tags that have a png_ prefix (as images)
|
||||||
|
and replace the meta tag with the file name
|
||||||
|
:param metrics:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# iterate all metrics and find keys with a specific prefix
|
||||||
|
for i, metric in enumerate(metrics):
|
||||||
|
for k, v in metric.items():
|
||||||
|
# if the prefix is a png, save the image and replace the value with the path
|
||||||
|
img_extension = None
|
||||||
|
img_extension = 'png' if 'png_' in k else img_extension
|
||||||
|
img_extension = 'jpg' if 'jpg' in k else img_extension
|
||||||
|
img_extension = 'jpeg' if 'jpeg' in k else img_extension
|
||||||
|
|
||||||
|
if img_extension is not None:
|
||||||
|
# determine the file name
|
||||||
|
img_name = '_'.join(k.split('_')[1:])
|
||||||
|
save_path = self.get_media_path(self.name, self.version)
|
||||||
|
save_path = '{}/{}_{}.{}'.format(save_path, img_name, i, img_extension)
|
||||||
|
|
||||||
|
# save image to disk
|
||||||
|
if type(metric[k]) is not str:
|
||||||
|
imwrite(save_path, metric[k])
|
||||||
|
|
||||||
|
# replace the image in the metric with the file path
|
||||||
|
metric[k] = save_path
|
||||||
|
|
||||||
|
def __load(self):
|
||||||
|
# load .experiment file
|
||||||
|
with open(self.__get_log_name(), 'r') as file:
|
||||||
|
data = json.load(file)
|
||||||
|
self.name = data['name']
|
||||||
|
self.version = data['version']
|
||||||
|
self.autosave = data['autosave']
|
||||||
|
self.created_at = data['created_at']
|
||||||
|
self.description = data['description']
|
||||||
|
self.exp_hash = data['exp_hash']
|
||||||
|
|
||||||
|
# load .tags file
|
||||||
|
meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv'
|
||||||
|
df = pd.read_csv(meta_tags_path)
|
||||||
|
self.tags_list = df.to_dict(orient='records')
|
||||||
|
self.tags = {}
|
||||||
|
for d in self.tags_list:
|
||||||
|
k, v = d['key'], d['value']
|
||||||
|
self.tags[k] = v
|
||||||
|
|
||||||
|
# load metrics
|
||||||
|
metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv'
|
||||||
|
try:
|
||||||
|
df = pd.read_csv(metrics_file_path)
|
||||||
|
self.metrics = df.to_dict(orient='records')
|
||||||
|
|
||||||
|
# remove nans and infs
|
||||||
|
for metric in self.metrics:
|
||||||
|
to_delete = []
|
||||||
|
for k, v in metric.items():
|
||||||
|
if np.isnan(v) or np.isinf(v):
|
||||||
|
to_delete.append(k)
|
||||||
|
for k in to_delete:
|
||||||
|
del metric[k]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# metrics was empty...
|
||||||
|
self.metrics = []
|
||||||
|
|
||||||
|
def get_data_path(self, exp_name, exp_version):
|
||||||
|
"""
|
||||||
|
Returns the path to the local package cache
|
||||||
|
:param exp_name:
|
||||||
|
:param exp_version:
|
||||||
|
:return:
|
||||||
|
Path
|
||||||
|
"""
|
||||||
|
if self.no_save_dir:
|
||||||
|
return _ROOT / 'local_experiment_data' / exp_name, f'version_{exp_version}'
|
||||||
|
else:
|
||||||
|
return _ROOT / exp_name / f'version_{exp_version}'
|
||||||
|
|
||||||
|
def get_media_path(self, exp_name, exp_version):
|
||||||
|
"""
|
||||||
|
Returns the path to the local package cache
|
||||||
|
:param exp_version:
|
||||||
|
:param exp_name:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.get_data_path(exp_name, exp_version) / 'media'
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# OVERWRITES
|
||||||
|
# ----------------------------
|
||||||
|
def __str__(self):
|
||||||
|
return 'Exp: {}, v: {}'.format(self.name, self.version)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return 'Exp: {}, v: {}'.format(self.name, self.version)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def atomic_write(dst_path):
|
||||||
|
"""A context manager to simplify atomic writing.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
>>> with atomic_write(dst_path) as tmp_path:
|
||||||
|
>>> # write to tmp_path
|
||||||
|
>>> # Here tmp_path renamed to dst_path, if no exception happened.
|
||||||
|
"""
|
||||||
|
tmp_path = dst_path / '.tmp'
|
||||||
|
try:
|
||||||
|
yield tmp_path
|
||||||
|
except:
|
||||||
|
if tmp_path.exists():
|
||||||
|
tmp_path.unlink()
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# If everything is fine, move tmp file to the destination.
|
||||||
|
shutil.move(tmp_path, str(dst_path))
|
||||||
|
|
||||||
|
|
||||||
|
##########################
|
||||||
|
class LocalLogger(LightningLoggerBase):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def experiment(self) -> Experiment:
|
||||||
|
r"""
|
||||||
|
|
||||||
|
Actual TestTube object. To use TestTube features in your
|
||||||
|
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
self.logger.experiment.some_test_tube_function()
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._experiment is not None:
|
||||||
|
return self._experiment
|
||||||
|
|
||||||
|
self._experiment = Experiment(
|
||||||
|
save_dir=self.save_dir,
|
||||||
|
name=self._name,
|
||||||
|
debug=self.debug,
|
||||||
|
version=self.version,
|
||||||
|
description=self.description
|
||||||
|
)
|
||||||
|
return self._experiment
|
||||||
|
|
||||||
|
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def log_hyperparams(self, params: argparse.Namespace):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def version(self) -> Union[int, str]:
|
||||||
|
if self._version is None:
|
||||||
|
self._version = self._get_next_version()
|
||||||
|
return self._version
|
||||||
|
|
||||||
|
def _get_next_version(self):
|
||||||
|
root_dir = self.save_dir / self.name
|
||||||
|
|
||||||
|
if not root_dir.is_dir():
|
||||||
|
log.warning(f'Missing logger folder: {root_dir}')
|
||||||
|
return 0
|
||||||
|
|
||||||
|
existing_versions = []
|
||||||
|
for d in os.listdir(root_dir):
|
||||||
|
if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
|
||||||
|
existing_versions.append(int(d.split("_")[1]))
|
||||||
|
|
||||||
|
if len(existing_versions) == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return max(existing_versions) + 1
|
||||||
|
|
||||||
|
def __init__(self, save_dir: str, name: str = "default", description: Optional[str] = None,
|
||||||
|
debug: bool = False, version: Optional[int] = None, **kwargs):
|
||||||
|
super(LocalLogger, self).__init__(**kwargs)
|
||||||
|
self.save_dir = Path(save_dir)
|
||||||
|
self._name = name
|
||||||
|
self.description = description
|
||||||
|
self.debug = debug
|
||||||
|
self._version = version
|
||||||
|
self._experiment = None
|
||||||
|
|
||||||
|
# Test tube experiments are not pickleable, so we need to override a few
|
||||||
|
# methods to get DDP working. See
|
||||||
|
# https://docs.python.org/3/library/pickle.html#handling-stateful-objects
|
||||||
|
# for more info.
|
||||||
|
def __getstate__(self) -> Dict[Any, Any]:
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state["_experiment"] = self.experiment.get_meta_copy()
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state: Dict[Any, Any]):
|
||||||
|
self._experiment = state["_experiment"].get_non_ddp_exp()
|
||||||
|
del state["_experiment"]
|
||||||
|
self.__dict__.update(state)
|
@ -2,7 +2,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import ReLU
|
from torch.nn import ReLU
|
||||||
|
|
||||||
from torch_geometric.nn import PointConv, fps, radius, global_max_pool
|
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
|
||||||
|
|
||||||
|
|
||||||
class SAModule(torch.nn.Module):
|
class SAModule(torch.nn.Module):
|
||||||
@ -23,14 +23,15 @@ class SAModule(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GlobalSAModule(nn.Module):
|
class GlobalSAModule(nn.Module):
|
||||||
def __init__(self, nn):
|
def __init__(self, nn, channels=3):
|
||||||
super(GlobalSAModule, self).__init__()
|
super(GlobalSAModule, self).__init__()
|
||||||
self.nn = nn
|
self.nn = nn
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
def forward(self, x, pos, batch):
|
def forward(self, x, pos, batch):
|
||||||
x = self.nn(torch.cat([x, pos], dim=1))
|
x = self.nn(torch.cat([x, pos], dim=1))
|
||||||
x = global_max_pool(x, batch)
|
x = global_max_pool(x, batch)
|
||||||
pos = pos.new_zeros((x.size(0), 3))
|
pos = pos.new_zeros((x.size(0), self.channels))
|
||||||
batch = torch.arange(x.size(0), device=batch.device)
|
batch = torch.arange(x.size(0), device=batch.device)
|
||||||
return x, pos, batch
|
return x, pos, batch
|
||||||
|
|
||||||
@ -45,3 +46,17 @@ class MLP(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
def forward(self, x, *args, **kwargs):
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class FPModule(torch.nn.Module):
|
||||||
|
def __init__(self, k, nn):
|
||||||
|
super(FPModule, self).__init__()
|
||||||
|
self.k = k
|
||||||
|
self.nn = nn
|
||||||
|
|
||||||
|
def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
|
||||||
|
x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
|
||||||
|
if x_skip is not None:
|
||||||
|
x = torch.cat([x, x_skip], dim=1)
|
||||||
|
x = self.nn(x)
|
||||||
|
return x, pos_skip, batch_skip
|
24
point_toolset/point_io.py
Normal file
24
point_toolset/point_io.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
|
||||||
|
|
||||||
|
class BatchToData(object):
|
||||||
|
def __init__(self):
|
||||||
|
super(BatchToData, self).__init__()
|
||||||
|
|
||||||
|
def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, batch_y: torch.Tensor):
|
||||||
|
# Convert to torch_geometric.data.Data type
|
||||||
|
# data = data.transpose(1, 2).contiguous()
|
||||||
|
batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3)
|
||||||
|
|
||||||
|
x = batch_x.reshape(batch_size * num_points, -1)
|
||||||
|
pos = batch_pos.reshape(batch_size * num_points, -1)
|
||||||
|
batch_y = batch_y.reshape(batch_size * num_points)
|
||||||
|
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
|
||||||
|
for i in range(batch_size):
|
||||||
|
batch[i] = i
|
||||||
|
batch = batch.view(-1)
|
||||||
|
|
||||||
|
data = Data()
|
||||||
|
data.x, data.pos, data.batch, data.y = x, pos, batch, batch_y
|
||||||
|
return data
|
@ -1,10 +1,36 @@
|
|||||||
|
from abc import ABC
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class FarthestpointSampling():
|
class _Sampler(ABC):
|
||||||
|
|
||||||
def __init__(self, K):
|
def __init__(self, K, **kwargs):
|
||||||
self.k = K
|
self.k = K
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSampling(_Sampler):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(RandomSampling, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def __call__(self, pts, *args, **kwargs):
|
||||||
|
if pts.shape[0] < self.k:
|
||||||
|
return pts
|
||||||
|
|
||||||
|
else:
|
||||||
|
rnd_indexs = np.random.choice(np.arange(pts.shape[0]), self.k, replace=False)
|
||||||
|
return rnd_indexs
|
||||||
|
|
||||||
|
|
||||||
|
class FarthestpointSampling(_Sampler):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(FarthestpointSampling, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def calc_distances(p0, points):
|
def calc_distances(p0, points):
|
||||||
@ -15,14 +41,15 @@ class FarthestpointSampling():
|
|||||||
if pts.shape[0] < self.k:
|
if pts.shape[0] < self.k:
|
||||||
return pts
|
return pts
|
||||||
|
|
||||||
farthest_pts = np.zeros((self.k, pts.shape[1]))
|
else:
|
||||||
farthest_pts_idx = np.zeros(self.k, dtype=np.int)
|
farthest_pts = np.zeros((self.k, pts.shape[1]))
|
||||||
farthest_pts[0] = pts[np.random.randint(len(pts))]
|
farthest_pts_idx = np.zeros(self.k, dtype=np.int)
|
||||||
distances = self.calc_distances(farthest_pts[0], pts)
|
farthest_pts[0] = pts[np.random.randint(len(pts))]
|
||||||
for i in range(1, self.k):
|
distances = self.calc_distances(farthest_pts[0], pts)
|
||||||
farthest_pts_idx[i] = np.argmax(distances)
|
for i in range(1, self.k):
|
||||||
farthest_pts[i] = pts[farthest_pts_idx[i]]
|
farthest_pts_idx[i] = np.argmax(distances)
|
||||||
|
farthest_pts[i] = pts[farthest_pts_idx[i]]
|
||||||
|
|
||||||
distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
|
distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
|
||||||
|
|
||||||
return farthest_pts_idx
|
return farthest_pts_idx
|
||||||
|
41
utils/data_util.py
Normal file
41
utils/data_util.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
def chunks(l, n):
|
||||||
|
"""Yield successive n-sized chunks from l."""
|
||||||
|
for i in range(0, len(l), n):
|
||||||
|
yield l[i:i + n]
|
||||||
|
|
||||||
|
|
||||||
|
class ReMapDataset(Dataset):
|
||||||
|
@property
|
||||||
|
def sample_shape(self):
|
||||||
|
return list(self[0][0].shape)
|
||||||
|
|
||||||
|
def __init__(self, ds, mapping):
|
||||||
|
super(ReMapDataset, self).__init__()
|
||||||
|
# here is a mapping from this index to the mother ds index
|
||||||
|
self.mapping = mapping
|
||||||
|
self.ds = ds
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.ds[self.mapping[index]]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.mapping.shape[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def do_train_vali_split(cls, ds, split_fold=0.1):
|
||||||
|
|
||||||
|
indices = torch.randperm(len(ds))
|
||||||
|
|
||||||
|
valid_size = int(len(ds) * split_fold)
|
||||||
|
|
||||||
|
train_mapping = indices[valid_size:]
|
||||||
|
valid_mapping = indices[:valid_size]
|
||||||
|
|
||||||
|
train = cls(ds, train_mapping)
|
||||||
|
valid = cls(ds, valid_mapping)
|
||||||
|
|
||||||
|
return train, valid
|
@ -1,3 +1,6 @@
|
|||||||
|
import argparse
|
||||||
|
from typing import Union, Dict, Optional, Any
|
||||||
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -2,6 +2,16 @@ import pickle
|
|||||||
import shelve
|
import shelve
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from utils.project_config import GlobalVar
|
||||||
|
|
||||||
|
|
||||||
|
def to_one_hot(idx_array):
|
||||||
|
one_hot = np.zeros((idx_array.size, len(GlobalVar.classes)))
|
||||||
|
one_hot[np.arange(idx_array.size), idx_array] = 1
|
||||||
|
return one_hot
|
||||||
|
|
||||||
|
|
||||||
def fix_all_random_seeds(config_obj):
|
def fix_all_random_seeds(config_obj):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -9,12 +9,13 @@ from pathlib import Path
|
|||||||
|
|
||||||
class Plotter(object):
|
class Plotter(object):
|
||||||
def __init__(self, root_path=''):
|
def __init__(self, root_path=''):
|
||||||
self.root_path = Path(root_path)
|
if not root_path:
|
||||||
|
self.root_path = Path(root_path)
|
||||||
|
|
||||||
def save_current_figure(self, path, extention='.png', naked=True):
|
def save_current_figure(self, filename: str, extention='.png', naked=False):
|
||||||
fig, _ = plt.gcf(), plt.gca()
|
fig, _ = plt.gcf(), plt.gca()
|
||||||
# Prepare save location and check img file extention
|
# Prepare save location and check img file extention
|
||||||
path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
|
path = self.root_path / Path(filename if filename.endswith(extention) else f'{filename}{extention}')
|
||||||
path.parent.mkdir(exist_ok=True, parents=True)
|
path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
if naked:
|
if naked:
|
||||||
plt.axis('off')
|
plt.axis('off')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user