From ba7c0280ae3a026ed925e7c4c4eff36e05f55d00 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Tue, 26 May 2020 21:44:57 +0200
Subject: [PATCH] pointnet2 working - TODO: Eval!

---
 _parameters.py               |   8 +--
 datasets/_point_dataset.py   |  35 +++++++++++-
 datasets/full_pointclouds.py |  32 ++++-------
 datasets/grid_clustered.py   |  32 ++++++++++-
 datasets/prim_clustered.py   |  30 +++++++++-
 datasets/template_dataset.py |   1 -
 main.py                      |  14 +++--
 models/__init__.py           |   1 +
 models/point_net_2.py        | 103 +++++++++++++++++++++++++++++++++++
 utils/module_mixins.py       |  21 +++----
 utils/project_config.py      |  13 +++--
 11 files changed, 232 insertions(+), 58 deletions(-)
 create mode 100644 models/__init__.py
 create mode 100644 models/point_net_2.py

diff --git a/_parameters.py b/_parameters.py
index 16dcfee..f94c489 100644
--- a/_parameters.py
+++ b/_parameters.py
@@ -15,15 +15,14 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="
 main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
 
 # Project
-main_arg_parser.add_argument("--project_name", type=str, default='traj-gen', help="")
+main_arg_parser.add_argument("--project_name", type=str, default='point-to-primitive', help="")
 main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
-main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
+main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_API_TOKEN'), help="")
 
 # Data Parameters
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
 main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
-main_arg_parser.add_argument("--data_additional_resource_root", type=str, default='res/resource/root', help="")
 main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
 
 # Transformations
@@ -36,10 +35,11 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
 main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
 main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
+main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")
 main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
 
 # Model
-main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
+main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
 
diff --git a/datasets/_point_dataset.py b/datasets/_point_dataset.py
index 79a7fc9..7b986ac 100644
--- a/datasets/_point_dataset.py
+++ b/datasets/_point_dataset.py
@@ -1,17 +1,27 @@
+import pickle
+from collections import defaultdict
+
 from abc import ABC
 from pathlib import Path
 
 from torch.utils.data import Dataset
 from ml_lib.point_toolset.sampling import FarthestpointSampling
 
+import numpy as np
+
 
 class _Point_Dataset(ABC, Dataset):
 
+    @property
+    def sample_shape(self):
+        # FixMe: This does not work when more than x/y tuples are returned
+        return self[0][0].shape
+
     @property
     def setting(self) -> str:
         raise NotImplementedError
 
-    headers = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'label', 'cl_idx']
+    headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx']
 
     def __init__(self, root=Path('data'), sampling_k=2048, transforms=None, load_preprocessed=True, *args, **kwargs):
         super(_Point_Dataset, self).__init__()
@@ -21,13 +31,32 @@ class _Point_Dataset(ABC, Dataset):
         self.sampling_k = sampling_k
         self.sampling = FarthestpointSampling(K=self.sampling_k)
         self.root = Path(root)
-        self.raw = root / 'raw'
+        self.raw = self.root / 'raw'
         self.processed_ext = '.pik'
         self.raw_ext = '.xyz'
-        self.processed = root / self.setting
+        self.processed = self.root / self.setting
+        self.processed.mkdir(parents=True, exist_ok=True)
 
         self._files = list(self.raw.glob(f'*{self.setting}*'))
 
+    def _read_or_load(self, item):
+        raw_file_path = self._files[item]
+        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)
+
+        if not self.load_preprocessed:
+            processed_file_path.unlink(missing_ok=True)
+        if not processed_file_path.exists():
+            pointcloud = defaultdict(list)
+            with raw_file_path.open('r') as raw_file:
+                for row in raw_file:
+                    values = [float(x) for x in row.strip().split(' ')]
+                    for header, value in zip(self.headers, values):
+                        pointcloud[header].append(value)
+            for key in pointcloud.keys():
+                pointcloud[key] = np.asarray(pointcloud[key])
+            with processed_file_path.open('wb') as processed_file:
+                pickle.dump(pointcloud, processed_file)
+        return processed_file_path
+
     def __len__(self):
         raise NotImplementedError
 
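[Annotation] A minimal sketch of the column-wise parsing that the new _read_or_load helper performs before pickling; the sample rows below are made up, and the header order matches the class's headers attribute:

    from collections import defaultdict

    import numpy as np

    headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx']
    pointcloud = defaultdict(list)
    for row in ('0.1 0.2 0.3 0.0 0.0 1.0 2 0',   # stand-ins for raw .xyz lines
                '0.4 0.5 0.6 0.0 1.0 0.0 2 0'):
        values = [float(x) for x in row.strip().split(' ')]
        for header, value in zip(headers, values):
            pointcloud[header].append(value)
    # one numpy column per header, as pickled to the processed .pik file
    columns = {key: np.asarray(value) for key, value in pointcloud.items()}
    print(columns['x'])  # [0.1 0.4]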
diff --git a/datasets/full_pointclouds.py b/datasets/full_pointclouds.py
index f7f1b7d..9a433cc 100644
--- a/datasets/full_pointclouds.py
+++ b/datasets/full_pointclouds.py
@@ -1,9 +1,7 @@
 import pickle
 from collections import defaultdict
-from pathlib import Path
 
 import numpy as np
-from torch.utils.data import Dataset
 
 from ._point_dataset import _Point_Dataset
 
@@ -19,27 +17,17 @@ class FullCloudsDataset(_Point_Dataset):
         return len(self._files)
 
     def __getitem__(self, item):
-        raw_file_path = self._files[item]
-        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)
-        if not self.load_preprocessed:
-            processed_file_path.unlink(missing_ok=True)
-        if not processed_file_path.exists():
-            pointcloud = defaultdict(list)
-            with raw_file_path.open('r') as raw_file:
-                for row in raw_file:
-                    values = [float(x) for x in row.split(' ')]
-                    for header, value in zip(self.headers, values):
-                        pointcloud[header].append(value)
-            for key in pointcloud.keys():
-                pointcloud[key] = np.asarray(pointcloud[key])
-            with processed_file_path.open('wb') as processed_file:
-                pickle.dump(pointcloud, processed_file)
+        processed_file_path = self._read_or_load(item)
 
         with processed_file_path.open('rb') as processed_file:
             pointcloud = pickle.load(processed_file)
-        points = np.stack(pointcloud['x'], pointcloud['y'], pointcloud['z'])
-        normal = np.stack(pointcloud['xn'], pointcloud['yn'], pointcloud['zn'])
-        label = points['label']
-        samples = self.sampling(points)
+        points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
+                           pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
+                           ),
+                          axis=-1)
+        # When you want to return points and normals separately:
+        # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
+        label = pointcloud['label']
+        sample_idxs = self.sampling(points)
 
-        return points[samples], normal[samples], label[samples]
+        return points[sample_idxs].astype(np.float), label[sample_idxs].astype(np.int)
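[Annotation] A hypothetical usage of the reworked __getitem__ (the root path is an assumption; the shapes follow from the default sampling_k=2048 and the six stacked columns):

    from datasets.full_pointclouds import FullCloudsDataset

    dataset = FullCloudsDataset(root='data', sampling_k=2048)  # 'data' is a placeholder root
    points, labels = dataset[0]
    print(points.shape, labels.shape)  # expected: (2048, 6) and (2048,)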
diff --git a/datasets/grid_clustered.py b/datasets/grid_clustered.py
index e9e96ef..26a088b 100644
--- a/datasets/grid_clustered.py
+++ b/datasets/grid_clustered.py
@@ -1,6 +1,32 @@
-from torch.utils.data import Dataset
+import pickle
+import numpy as np
+
+from ._point_dataset import _Point_Dataset
 
-class TemplateDataset(_Point_Dataset):
+class FullCloudsDataset(_Point_Dataset):
+
+    setting = 'grid'
+
     def __init__(self, *args, **kwargs):
-        super(TemplateDataset, self).__init__()
\ No newline at end of file
+        super(FullCloudsDataset, self).__init__(*args, **kwargs)
+
+    def __len__(self):
+        return len(self._files)
+
+    def __getitem__(self, item):
+        processed_file_path = self._read_or_load(item)
+
+        with processed_file_path.open('rb') as processed_file:
+            pointcloud = pickle.load(processed_file)
+        points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
+                           pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
+                           ),
+                          axis=-1)
+
+        # When you want to return points and normals separately:
+        # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
+        label = np.stack((pointcloud['label'], pointcloud['cl_idx']))
+        sample_idxs = self.sampling(points)
+
+        return points[sample_idxs], label[sample_idxs]
\ No newline at end of file
diff --git a/datasets/prim_clustered.py b/datasets/prim_clustered.py
index 612c520..a5e7397 100644
--- a/datasets/prim_clustered.py
+++ b/datasets/prim_clustered.py
@@ -1,8 +1,32 @@
-from torch.utils.data import Dataset
+import pickle
+import numpy as np
 
 from ._point_dataset import _Point_Dataset
 
-class TemplateDataset(_Point_Dataset):
+class FullCloudsDataset(_Point_Dataset):
+
+    setting = 'prim'
+
     def __init__(self, *args, **kwargs):
-        super(TemplateDataset, self).__init__()
\ No newline at end of file
+        super(FullCloudsDataset, self).__init__(*args, **kwargs)
+
+    def __len__(self):
+        return len(self._files)
+
+    def __getitem__(self, item):
+        processed_file_path = self._read_or_load(item)
+
+        with processed_file_path.open('rb') as processed_file:
+            pointcloud = pickle.load(processed_file)
+        points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
+                           pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
+                           ),
+                          axis=-1)
+
+        # When you want to return points and normals separately:
+        # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
+        label = np.stack((pointcloud['label'], pointcloud['cl_idx']))
+        sample_idxs = self.sampling(points)
+
+        return points[sample_idxs], label[sample_idxs]
\ No newline at end of file
diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py
index 8318b5a..7886edd 100644
--- a/datasets/template_dataset.py
+++ b/datasets/template_dataset.py
@@ -10,4 +10,3 @@ class TemplateDataset(_Point_Dataset):
 
     def __getitem__(self, item):
         return item
-
diff --git a/main.py b/main.py
index 9b1010b..5f73b32 100644
--- a/main.py
+++ b/main.py
@@ -5,12 +5,11 @@ import warnings
 
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.callbacks import ModelCheckpoint  # , EarlyStopping
 
 from ml_lib.modules.util import LightningBaseModule
-from ml_lib.utils.config import Config
 from ml_lib.utils.logging import Logger
-from ml_lib.utils.model_io import SavedLightningModels
+from utils.project_config import ThisConfig
 
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -33,11 +32,13 @@ def run_lightning_loop(config_obj):
     # =============================================================================
     # Early Stopping
     # TODO: For this to work, one must set up a validation step and end-of-eval scoring
+    """
     early_stopping_callback = EarlyStopping(
         monitor='val_loss',
         min_delta=0.0,
         patience=0,
     )
+    """
 
     # Model
     # =============================================================================
@@ -76,6 +77,9 @@ def run_lightning_loop(config_obj):
 
 
 if __name__ == "__main__":
-    from ._parameters import args
-    config = Config.read_namespace(args)
+    from _parameters import args
+    from ml_lib.utils.tools import fix_all_random_seeds
+
+    config = ThisConfig.read_namespace(args)
+    fix_all_random_seeds(config)
     trained_model = run_lightning_loop(config)
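[Annotation] Once a validation step and metric exist, re-enabling the commented-out callback could look like the sketch below (pytorch_lightning of this era wired it via the Trainer's early_stop_callback argument; the monitored key and patience are assumptions):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    early_stopping_callback = EarlyStopping(
        monitor='val_loss',  # assumes the validation loop logs a 'val_loss' metric
        min_delta=0.0,
        patience=5,          # tolerate a few stagnant epochs rather than none
    )
    trainer = Trainer(early_stop_callback=early_stopping_callback)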
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..da3273f
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1 @@
+from .point_net_2 import PointNet2
diff --git a/models/point_net_2.py b/models/point_net_2.py
new file mode 100644
index 0000000..652ff95
--- /dev/null
+++ b/models/point_net_2.py
@@ -0,0 +1,103 @@
+from argparse import Namespace
+
+import torch
+from torch.optim import Adam
+from torch import nn
+from torch_geometric.data import Data
+
+from datasets.full_pointclouds import FullCloudsDataset
+from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP
+from ml_lib.modules.util import LightningBaseModule, F_x
+
+from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
+
+
+class PointNet2(BaseValMixin,
+                BaseTrainMixin,
+                BaseOptimizerMixin,
+                DatasetMixin,
+                BaseDataloadersMixin,
+                LightningBaseModule
+                ):
+
+    def __init__(self, hparams):
+        super(PointNet2, self).__init__(hparams=hparams)
+
+        # Dataset
+        # =============================================================================
+        self.dataset = self.build_dataset(FullCloudsDataset)
+
+        # Model Parameters
+        # =============================================================================
+        # Additional parameters
+        self.in_shape = self.dataset.train_dataset.sample_shape
+        self.channels = self.in_shape[-1]
+
+        # Modules
+        self.sa1_module = SAModule(0.5, 0.2, MLP([self.channels, 64, 64, 128]))
+        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.channels, 128, 128, 256]))
+        self.sa3_module = GlobalSAModule(MLP([256 + self.channels, 256, 512, 1024]))
+
+        self.lin1 = nn.Linear(1024, 512)
+        self.lin2 = nn.Linear(512, 256)
+        self.lin3 = nn.Linear(256, 10)
+
+        # Utility
+        self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
+        self.activation = self.params.activation()
+        self.log_softmax = nn.LogSoftmax(dim=-1)
+
+    def configure_optimizers(self):
+        return Adam(self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
+
+    def forward(self, data, **kwargs):
+        """
+        data: a batch of input, torch.Tensor or torch_geometric.data.Data type
+            - torch.Tensor: (batch_size, 3, num_points), as common batch input
+
+            - torch_geometric.data.Data, as torch_geometric batch input:
+                data.x: (batch_size * ~num_points, C), batch nodes/points feature;
+                    ~num_points means each sample can have a different number of points/nodes
+
+                data.pos: (batch_size * ~num_points, 3)
+
+                data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
+                    identifiers for all nodes of all graphs/pointclouds in the batch. See the
+                    pytorch_geometric documentation for more information
+        """
+        dense_input = True if isinstance(data, torch.Tensor) else False
+
+        if dense_input:
+            # Convert to torch_geometric.data.Data type
+            # data = data.transpose(1, 2).contiguous()
+            batch_size, N, _ = data.shape  # (batch_size, num_points, 6)
+            pos = data.view(batch_size*N, -1)
+            batch = torch.zeros((batch_size, N), device=pos.device, dtype=torch.long)
+            for i in range(batch_size):
+                batch[i] = i
+            batch = batch.view(-1)
+
+            data = Data()
+            data.pos, data.batch = pos, batch
+
+        if not hasattr(data, 'x'):
+            data.x = None
+
+        sa0_out = (data.x, data.pos, data.batch)
+        sa1_out = self.sa1_module(*sa0_out)
+        sa2_out = self.sa2_module(*sa1_out)
+        sa3_out = self.sa3_module(*sa2_out)
+
+        tensor, pos, batch = sa3_out
+        tensor = tensor.float()
+
+        tensor = self.lin1(tensor)
+        tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
+        tensor = self.lin2(tensor)
+        tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
+        tensor = self.lin3(tensor)
+        tensor = self.log_softmax(tensor)
+        return Namespace(main_out=tensor)
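[Annotation] The dense-input branch above builds the batch vector with a Python loop; an equivalent loop-free formulation (a sketch with assumed shapes, not part of the patch) would be:

    import torch
    from torch_geometric.data import Data

    batch_size, num_points, channels = 8, 2048, 6            # assumed dense batch shape
    points = torch.randn(batch_size, num_points, channels)   # stand-in input batch

    data = Data()
    data.pos = points.view(batch_size * num_points, channels)
    # one pointcloud identifier per point, replacing the per-sample fill loop
    data.batch = torch.arange(batch_size, device=points.device).repeat_interleave(num_points)
    data.x = None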
diff --git a/utils/module_mixins.py b/utils/module_mixins.py
index cc98998..658aa36 100644
--- a/utils/module_mixins.py
+++ b/utils/module_mixins.py
@@ -11,13 +11,10 @@ from torch.utils.data import DataLoader
 from torchcontrib.optim import SWA
 from torchvision.transforms import Compose
 
-from _templates.new_project.datasets.template_dataset import TemplateDataset
+from ml_lib.modules.util import LightningBaseModule
+from ml_lib.utils.transforms import ToTensor
 
-from audio_toolset.audio_io import NormalizeLocal
-from modules.utils import LightningBaseModule
-from utils.transforms import ToTensor
-
-from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
+from .project_config import GlobalVar
 
 
 class BaseOptimizerMixin:
@@ -110,31 +107,31 @@ class BaseValMixin:
         return summary_dict
 
 
-class BinaryMaskDatasetMixin:
+class DatasetMixin:
 
-    def build_dataset(self):
+    def build_dataset(self, dataset_class):
         assert isinstance(self, LightningBaseModule)
 
         # Dataset
         # =============================================================================
         # Data Augmentations or Utility Transformations
-        transforms = Compose([NormalizeLocal(), ToTensor()])
+        transforms = Compose([ToTensor()])
 
         # Dataset
        dataset = Namespace(
             **dict(
                 # TRAIN DATASET
-                train_dataset=TemplateDataset(self.params.root, setting=GlobalVars.DATA_OPTIONS.train,
-                                              transforms=transforms
-                                              ),
+                train_dataset=dataset_class(self.params.root, setting=GlobalVar.train,
+                                            transforms=transforms
+                                            ),
                 # VALIDATION DATASET
-                val_dataset=TemplateDataset(self.params.root, setting=GlobalVars.vali,
-                                            ),
+                val_dataset=dataset_class(self.params.root, setting=GlobalVar.vali,
+                                          ),
                 # TEST DATASET
-                test_dataset=TemplateDataset(self.params.root, setting=GlobalVars.test,
-                                             ),
+                test_dataset=dataset_class(self.params.root, setting=GlobalVar.test,
+                                           ),
             )
         )
diff --git a/utils/project_config.py b/utils/project_config.py
index 78774db..1648789 100644
--- a/utils/project_config.py
+++ b/utils/project_config.py
@@ -1,6 +1,6 @@
 from argparse import Namespace
 
-from utils.config import Config
+from ml_lib.utils.config import Config
 
 
 class GlobalVar(Namespace):
@@ -18,13 +18,16 @@ class GlobalVar(Namespace):
     DPI = 50
 
     # DATAOPTIONS
-    train='train',
-    vali='vali',
-    test='test'
+    train = 'train'
+    vali = 'vali'
+    test = 'test'
+
+
+from models import *
 
 
 class ThisConfig(Config):
 
     @property
     def _model_map(self):
-        return dict()
+        return dict(PN2=PointNet2)
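[Annotation] For illustration, the lookup that _model_map presumably enables; how ml_lib's Config consumes the mapping is not shown in this patch, so only the mapping itself is sketched:

    from models import PointNet2

    model_map = dict(PN2=PointNet2)
    model_class = model_map['PN2']  # selected via --model_type PN2
    assert model_class is PointNet2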