From 23f3aa878dfcb97691ed7e67ce3ac820075f4102 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Tue, 9 Jun 2020 14:08:35 +0200
Subject: [PATCH] New dataset for per-spatial-cluster training

---
 _parameters.py                      |  3 +-
 datasets/_point_dataset.py          |  6 ++-
 datasets/full_pointclouds.py        |  1 +
 datasets/grid_clusters.py           | 79 +++++++++++++++++++++++++++++
 main.py                             |  6 ++-
 models/_point_net_2.py              |  2 +-
 models/point_net_2.py               |  4 +-
 models/point_net_2_grid_clusters.py |  4 +-
 multi_run.py                        |  6 +--
 utils/module_mixins.py              |  5 +-
 10 files changed, 104 insertions(+), 12 deletions(-)
 create mode 100644 datasets/grid_clusters.py

diff --git a/_parameters.py b/_parameters.py
index 3d4d30f..3f7caa5 100644
--- a/_parameters.py
+++ b/_parameters.py
@@ -22,6 +22,7 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 # Data Parameters
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
+main_arg_parser.add_argument("--data_dataset_type", type=str, default='GridClusters', help="")
 main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
 
 # Transformations
@@ -41,7 +42,7 @@ main_arg_parser.add_argument("--train_opt_reset_interval", type=strtobool, defau
 
 # Model
 # Possible Model arguments are: P2P, PN2, P2G
-main_arg_parser.add_argument("--model_type", type=str, default="P2G", help="")
+main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="")
 main_arg_parser.add_argument("--model_norm_as_feature", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
 
diff --git a/datasets/_point_dataset.py b/datasets/_point_dataset.py
index 13be173..5c71982 100644
--- a/datasets/_point_dataset.py
+++ b/datasets/_point_dataset.py
@@ -4,7 +4,7 @@ from collections import defaultdict
 from abc import ABC
 from pathlib import Path
 
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, ConcatDataset
 
 from ml_lib.point_toolset.sampling import FarthestpointSampling, RandomSampling
 import numpy as np
@@ -12,6 +12,10 @@ import numpy as np
 
 class _Point_Dataset(ABC, Dataset):
 
+    @property
+    def name(self):
+        raise NotImplementedError
+
     @property
     def sample_shape(self):
         # FixMe: This does not work when more than x/y tuples are returned
diff --git a/datasets/full_pointclouds.py b/datasets/full_pointclouds.py
index 4248eb4..a20ed3e 100644
--- a/datasets/full_pointclouds.py
+++ b/datasets/full_pointclouds.py
@@ -9,6 +9,7 @@ from ._point_dataset import _Point_Dataset
 class FullCloudsDataset(_Point_Dataset):
 
     split: str
+    name = 'FullCloudsDataset'
 
     def __init__(self, *args, setting='pc', **kwargs):
         self.setting = setting
diff --git a/datasets/grid_clusters.py b/datasets/grid_clusters.py
new file mode 100644
index 0000000..5a72500
--- /dev/null
+++ b/datasets/grid_clusters.py
@@ -0,0 +1,79 @@
+import pickle
+from collections import defaultdict
+
+import numpy as np
+from torch.utils.data import ConcatDataset
+from tqdm import trange
+
+from ._point_dataset import _Point_Dataset
+
+
+class GridClusters(_Point_Dataset):
+
+    split: str
+    name = 'GridClusters'
+
+    def __init__(self, *args, n_spatial_clusters=3*3*3, setting='pc', **kwargs):
+        self.n_spatial_clusters = n_spatial_clusters
+        self.setting = setting
+        super(GridClusters, self).__init__(*args, **kwargs)
+
+    def __len__(self):
+        return len(self._files)
+
+    def _read_or_load(self, item):
+        raw_file_path = self._files[item]
+        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)
+
+        if not self.load_preprocessed:
+            processed_file_path.unlink(missing_ok=True)
+        if not processed_file_path.exists():
+            # nested default dict
+            pointcloud = defaultdict(lambda: defaultdict(list))
+
+            with raw_file_path.open('r') as raw_file:
+                for row in raw_file:
+                    values = [float(x) for x in row.strip().split(' ')]
+                    for header, value in zip(self.headers, values):
+                        pointcloud[int(values[-1])][header].append(value)
+            for cluster in pointcloud.keys():
+                for key in pointcloud[cluster].keys():
+                    pointcloud[cluster][key] = np.asarray(pointcloud[cluster][key])
+                pointcloud[cluster] = dict(pointcloud[cluster])
+            pointcloud = dict(pointcloud)
+
+            with processed_file_path.open('wb') as processed_file:
+                pickle.dump(pointcloud, processed_file)
+        return processed_file_path
+
+    def __getitem__(self, item):
+        processed_file_path = self._read_or_load(item)
+
+        with processed_file_path.open('rb') as processed_file:
+            pointcloud = pickle.load(processed_file)
+
+        # By number Variant
+        # cl_idx_list = np.cumsum([[len(self) // self.n_spatial_clusters, ] * self.n_spatial_clusters])
+        # cl_idx = [idx for idx, x in enumerate(cl_idx_list) if item <= x][0]
+
+        # Random Variant
+        cl_idx = np.random.randint(0, len(pointcloud))
+        pointcloud = pointcloud[list(pointcloud.keys())[cl_idx]]
+
+        position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
+
+        normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
+
+        label = pointcloud['label']
+
+        cl_label = pointcloud['cl_idx']
+
+        sample_idxs = self.sampling(position)
+        while sample_idxs.shape[0] < self.sampling_k:
+            sample_idxs = np.concatenate((sample_idxs, sample_idxs))[:self.sampling_k]
+
+        return (normal[sample_idxs].astype(float),
+                position[sample_idxs].astype(float),
+                label[sample_idxs].astype(int),
+                cl_label[sample_idxs].astype(int)
+                )
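
Note on datasets/grid_clusters.py above: _read_or_load() parses a raw file once
into a nested mapping {cluster id -> {column name -> np.ndarray}} keyed by the
last column of each row, then caches the result as a pickle; __getitem__()
draws one spatial cluster at random and pads the sampled indices up to
sampling_k by repetition (note this assumes the sampler returns at least one
index, otherwise the while-loop never terminates). The grouping reduces to the
following self-contained sketch; the header list and the two sample rows are
illustrative stand-ins, not values taken from this repository:

    from collections import defaultdict

    import numpy as np

    # Illustrative column layout; the real one comes from _Point_Dataset.headers.
    # The last column is assumed to hold the integer grid-cluster id.
    headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx', 'grid_idx']
    rows = ['0.1 0.2 0.3 0.0 0.0 1.0 2 1 4',
            '0.4 0.5 0.6 0.0 1.0 0.0 1 1 4']

    # Nested default dict: outer key = spatial cluster id, inner key = column name.
    pointcloud = defaultdict(lambda: defaultdict(list))
    for row in rows:
        values = [float(x) for x in row.strip().split(' ')]
        for header, value in zip(headers, values):
            pointcloud[int(values[-1])][header].append(value)

    # Freeze into plain dicts of numpy arrays, as _read_or_load() does before pickling.
    pointcloud = {cl: {key: np.asarray(col) for key, col in sub.items()}
                  for cl, sub in pointcloud.items()}
    assert list(pointcloud[4]['x']) == [0.1, 0.4]
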
diff --git a/main.py b/main.py
index 6c4f905..953b58d 100644
--- a/main.py
+++ b/main.py
@@ -25,8 +25,9 @@ def run_lightning_loop(config_obj):
     # =============================================================================
     # Checkpoint Saving
     checkpoint_callback = ModelCheckpoint(
+        monitor='mean_loss',
         filepath=str(logger.log_dir / 'ckpt_weights'),
-        verbose=True, save_top_k=0,
+        verbose=True, save_top_k=10,
     )
 
     # =============================================================================
@@ -80,6 +81,9 @@ if __name__ == "__main__":
     from _parameters import args
     from ml_lib.utils.tools import fix_all_random_seeds
 
+    # When debugging, use the following parameters:
+    # --main_debug=True --data_worker=0
+
     config = ThisConfig.read_namespace(args)
     fix_all_random_seeds(config)
     trained_model = run_lightning_loop(config)
diff --git a/models/_point_net_2.py b/models/_point_net_2.py
index 681b547..7da24e5 100644
--- a/models/_point_net_2.py
+++ b/models/_point_net_2.py
@@ -32,7 +32,7 @@ class _PointNetCore(LightningBaseModule):
 
     def forward(self, sa0_out, **kwargs):
         """
-        data: a batch of input torch_geometric.data.Data type
+        sa0_out: a batch of input torch_geometric.data.Data type
             - torch_geometric.data.Data, as torch_geometric batch input:
                 data.x: (batch_size * ~num_points, C), batch nodes/points feature,
                 ~num_points means each sample can have different number of points/nodes
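
Note on the ModelCheckpoint change in main.py above: save_top_k=0 disables
checkpoint saving entirely, so nothing was persisted before; with
monitor='mean_loss' and save_top_k=10 Lightning keeps the ten checkpoints with
the best monitored value. A minimal standalone sketch of the equivalent
configuration, assuming the pre-1.0 Lightning API this codebase uses (the
target path is illustrative, and the module must log a metric under exactly
the name 'mean_loss'):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        filepath='logs/ckpt_weights',  # illustrative target path
        monitor='mean_loss',           # must match the logged metric name
        mode='min',                    # 'mean_loss' is a loss, so lower is better
        save_top_k=10,                 # keep the ten best checkpoints
        verbose=True,
    )
    trainer = Trainer(checkpoint_callback=checkpoint_callback)
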
diff --git a/models/point_net_2.py b/models/point_net_2.py
index e41e752..42c45f3 100644
--- a/models/point_net_2.py
+++ b/models/point_net_2.py
@@ -3,7 +3,7 @@ from argparse import Namespace
 
 import torch
 from torch import nn
 
-from datasets.full_pointclouds import FullCloudsDataset
+from datasets.grid_clusters import GridClusters
 from models._point_net_2 import _PointNetCore
 from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
@@ -23,7 +23,7 @@ class PointNet2(BaseValMixin,
 
         # Dataset
         # =============================================================================
-        self.dataset = self.build_dataset(FullCloudsDataset, setting='pc')
+        self.dataset = self.build_dataset(GridClusters, setting='pc')
 
         # Model Parameters
         # =============================================================================
diff --git a/models/point_net_2_grid_clusters.py b/models/point_net_2_grid_clusters.py
index bc59f2a..ae1651f 100644
--- a/models/point_net_2_grid_clusters.py
+++ b/models/point_net_2_grid_clusters.py
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from torch_geometric.data import Data
 
-from datasets.full_pointclouds import FullCloudsDataset
+from datasets.grid_clusters import GridClusters
 from models._point_net_2 import _PointNetCore
 
 from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
@@ -42,7 +42,7 @@ class PointNet2GridClusters(BaseValMixin,
 
         # Dataset
         # =============================================================================
-        self.dataset = self.build_dataset(FullCloudsDataset, setting='grid')
+        self.dataset = self.build_dataset(GridClusters, setting='grid')
 
         # Model Parameters
         # =============================================================================
diff --git a/multi_run.py b/multi_run.py
index 05fe96c..66327c5 100644
--- a/multi_run.py
+++ b/multi_run.py
@@ -16,11 +16,11 @@ if __name__ == '__main__':
 
     # Model Settings
     config = ThisConfig().read_namespace(args)
     # bias, activation, model, norm, max_epochs
-    pn2 = dict(model_type='PN2',model_use_bias=True, model_use_norm=True, data_batchsize=250)
-    p2g = dict(model_type='P2G', model_use_bias=True, model_use_norm=True, data_batchsize=250)
+    pn2 = dict(model_type='PN2', model_use_bias=True, model_use_norm=True, data_batchsize=250)
+    # p2g = dict(model_type='P2G', model_use_bias=True, model_use_norm=True, data_batchsize=250)
 
     # bias, activation, model, norm, max_epochs
-    for arg_dict in [p2g]:
+    for arg_dict in [pn2]:
         for seed in range(10):
             arg_dict.update(main_seed=seed)
diff --git a/utils/module_mixins.py b/utils/module_mixins.py
index 3806649..0355e30 100644
--- a/utils/module_mixins.py
+++ b/utils/module_mixins.py
@@ -222,6 +222,9 @@ class DatasetMixin:
 
     def build_dataset(self, dataset_class, **kwargs):
         assert isinstance(self, LightningBaseModule)
+        assert dataset_class.name == self.params.dataset_type, f'Check the dataset! ' + \
+                                                               f'Expected was {self.params.dataset_type}, ' + \
+                                                               f'given: {dataset_class.name}'
 
         # Dataset
         # =============================================================================
@@ -258,7 +261,7 @@ class BaseDataloadersMixin(ABC):
         # In case you want to implement bootstrapping
         # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
         sampler = None
-        return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
+        return DataLoader(dataset=self.dataset.train_dataset, shuffle=False if not sampler else None, sampler=sampler,
                           batch_size=self.params.batch_size, num_workers=self.params.worker)
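
Note on the new guard in utils/module_mixins.py above: every dataset class now
carries a class-level name, and build_dataset() asserts that it matches the
configured --data_dataset_type, so a model wired to the wrong dataset fails
fast instead of silently training on the wrong data. Stripped to its
essentials (the free function and its dataset_type argument are hypothetical
stand-ins for the mixin method and self.params.dataset_type):

    class FullCloudsDataset:
        name = 'FullCloudsDataset'

    class GridClusters:
        name = 'GridClusters'

    def build_dataset(dataset_class, dataset_type):
        # Fail fast when configuration and code disagree about the dataset.
        assert dataset_class.name == dataset_type, f'Check the dataset! ' + \
                                                   f'Expected was {dataset_type}, ' + \
                                                   f'given: {dataset_class.name}'
        return dataset_class()

    build_dataset(GridClusters, 'GridClusters')         # passes
    # build_dataset(FullCloudsDataset, 'GridClusters')  # raises AssertionError
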