diff --git a/_parameters.py b/_parameters.py index 3ea14d9..0806eed 100644 --- a/_parameters.py +++ b/_parameters.py @@ -40,7 +40,8 @@ main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=F main_arg_parser.add_argument("--train_opt_reset_interval", type=strtobool, default=False, help="") # Model -main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="") +# Possible Model arguments are: P2P, PN2, P2G +main_arg_parser.add_argument("--model_type", type=str, default="P2G", help="") main_arg_parser.add_argument("--model_norm_as_feature", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="") diff --git a/datasets/_point_dataset.py b/datasets/_point_dataset.py index 1175bde..13be173 100644 --- a/datasets/_point_dataset.py +++ b/datasets/_point_dataset.py @@ -17,10 +17,6 @@ class _Point_Dataset(ABC, Dataset): # FixMe: This does not work when more then x/y tuples are returned return self[0][0].shape - @property - def setting(self) -> str: - raise NotImplementedError - headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx'] samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling) @@ -28,6 +24,8 @@ class _Point_Dataset(ABC, Dataset): transforms=None, load_preprocessed=True, split='train', dense_output=False, *args, **kwargs): super(_Point_Dataset, self).__init__() + self.setting: str + self.dense_output = dense_output self.split = split self.norm_as_feature = norm_as_feature @@ -67,4 +65,23 @@ class _Point_Dataset(ABC, Dataset): raise NotImplementedError def __getitem__(self, item): - raise NotImplementedError + processed_file_path = self._read_or_load(item) + + with processed_file_path.open('rb') as processed_file: + pointcloud = pickle.load(processed_file) + + position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1) + + normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1) + + label = pointcloud['label'] + + cl_label = pointcloud['cl_idx'] + + sample_idxs = self.sampling(position) + + return (normal[sample_idxs].astype(np.float), + position[sample_idxs].astype(np.float), + label[sample_idxs].astype(np.int), + cl_label[sample_idxs].astype(np.int) + ) diff --git a/datasets/full_pointclouds.py b/datasets/full_pointclouds.py index e453a88..4248eb4 100644 --- a/datasets/full_pointclouds.py +++ b/datasets/full_pointclouds.py @@ -8,29 +8,11 @@ from ._point_dataset import _Point_Dataset class FullCloudsDataset(_Point_Dataset): - setting = 'pc' split: str - def __init__(self, *args, **kwargs): + def __init__(self, *args, setting='pc', **kwargs): + self.setting = setting super(FullCloudsDataset, self).__init__(*args, **kwargs) def __len__(self): return len(self._files) - - def __getitem__(self, item): - processed_file_path = self._read_or_load(item) - - with processed_file_path.open('rb') as processed_file: - pointcloud = pickle.load(processed_file) - - position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1) - - normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1) - - label = pointcloud['label'] - - sample_idxs = self.sampling(position) - - return (normal[sample_idxs].astype(np.float), - position[sample_idxs].astype(np.float), - label[sample_idxs].astype(np.int)) diff --git a/datasets/grid_clustered.py b/datasets/grid_clustered.py deleted file mode 100644 index a077b84..0000000 --- a/datasets/grid_clustered.py +++ /dev/null @@ -1,32 +0,0 @@ -import pickle -import numpy as np - -from ._point_dataset import _Point_Dataset - - -class FullCloudsDataset(_Point_Dataset): - - setting = 'grid' - - def __init__(self, *args, **kwargs): - super(FullCloudsDataset, self).__init__(*args, **kwargs) - - def __len__(self): - return len(self._files) - - def __getitem__(self, item): - processed_file_path = self._read_or_load(item) - - with processed_file_path.open('rb') as processed_file: - pointcloud = pickle.load(processed_file) - points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'], - pointcloud['xn'], pointcloud['yn'], pointcloud['zn'] - ), - axis=-1) - - # When yopu want to return points and normal seperately - # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1) - label = pointcloud['cl_idx'] - sample_idxs = self.sampling(points) - - return points[sample_idxs], label[sample_idxs] \ No newline at end of file diff --git a/datasets/prim_clustered.py b/datasets/prim_clustered.py deleted file mode 100644 index 7923cfe..0000000 --- a/datasets/prim_clustered.py +++ /dev/null @@ -1,32 +0,0 @@ -import pickle -import numpy as np - -from ._point_dataset import _Point_Dataset - - -class FullCloudsDataset(_Point_Dataset): - - setting = 'prim' - - def __init__(self, *args, **kwargs): - super(FullCloudsDataset, self).__init__(*args, **kwargs) - - def __len__(self): - return len(self._files) - - def __getitem__(self, item): - processed_file_path = self._read_or_load(item) - - with processed_file_path.open('rb') as processed_file: - pointcloud = pickle.load(processed_file) - points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'], - pointcloud['xn'], pointcloud['yn'], pointcloud['zn'] - ), - axis=-1) - - # When yopu want to return points and normal seperately - # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1) - label = pointcloud['cl_idx'] - sample_idxs = self.sampling(points) - - return points[sample_idxs], label[sample_idxs] \ No newline at end of file diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index 7886edd..ffeb9c3 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -1,6 +1,7 @@ from torch.utils.data import Dataset from._point_dataset import _Point_Dataset + class TemplateDataset(_Point_Dataset): def __init__(self, *args, **kwargs): super(TemplateDataset, self).__init__() diff --git a/models/__init__.py b/models/__init__.py index da3273f..d8bd278 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1 +1,4 @@ from .point_net_2 import PointNet2 +from .point_net_2_grid_clusters import PointNet2GridClusters +from .point_net_2_prim_clusters import PointNet2PrimClusters + diff --git a/models/_point_net_2.py b/models/_point_net_2.py new file mode 100644 index 0000000..681b547 --- /dev/null +++ b/models/_point_net_2.py @@ -0,0 +1,62 @@ +import torch +from torch import nn + +from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule +from ml_lib.modules.util import LightningBaseModule, F_x + + +class _PointNetCore(LightningBaseModule): + + def __init__(self, hparams): + super(_PointNetCore, self).__init__(hparams=hparams) + + # Model Paramters + # ============================================================================= + # Additional parameters + + # Modules + self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128])) + self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256])) + self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024])) + + self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256])) + self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128])) + self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128])) + + self.lin1 = torch.nn.Linear(128, 128) + self.lin2 = torch.nn.Linear(128, 128) + + # Utility + self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None) + self.activation = self.params.activation() + + def forward(self, sa0_out, **kwargs): + """ + data: a batch of input torch_geometric.data.Data type + - torch_geometric.data.Data, as torch_geometric batch input: + data.x: (batch_size * ~num_points, C), batch nodes/points feature, + ~num_points means each sample can have different number of points/nodes + + data.pos: (batch_size * ~num_points, 3) + + data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud + idendifiers for all nodes of all graphs/pointclouds in the batch. See + pytorch_gemometric documentation for more information + """ + + sa1_out = self.sa1_module(*sa0_out) + sa2_out = self.sa2_module(*sa1_out) + sa3_out = self.sa3_module(*sa2_out) + + fp3_out = self.fp3_module(*sa3_out, *sa2_out) + fp2_out = self.fp2_module(*fp3_out, *sa1_out) + tensor, _, _ = self.fp1_module(*fp2_out, *sa0_out) + + tensor = tensor.float() + + tensor = self.activation(tensor) + tensor = self.lin1(tensor) + tensor = self.dropout(tensor) + tensor = self.lin2(tensor) + tensor = self.dropout(tensor) + return tensor diff --git a/models/point_net_2.py b/models/point_net_2.py index 847d6f8..e41e752 100644 --- a/models/point_net_2.py +++ b/models/point_net_2.py @@ -1,13 +1,10 @@ from argparse import Namespace -import torch.nn.functional as F - import torch from torch import nn from datasets.full_pointclouds import FullCloudsDataset -from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule -from ml_lib.modules.util import LightningBaseModule, F_x +from models._point_net_2 import _PointNetCore from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin from utils.project_config import GlobalVar @@ -18,7 +15,7 @@ class PointNet2(BaseValMixin, BaseOptimizerMixin, DatasetMixin, BaseDataloadersMixin, - LightningBaseModule + _PointNetCore ): def __init__(self, hparams): @@ -26,28 +23,18 @@ class PointNet2(BaseValMixin, # Dataset # ============================================================================= - self.dataset = self.build_dataset(FullCloudsDataset) + self.dataset = self.build_dataset(FullCloudsDataset, setting='pc') # Model Paramters # ============================================================================= # Additional parameters + self.n_classes = len(GlobalVar.classes) # Modules - self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128])) - self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256])) - self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024])) - - self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256])) - self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128])) - self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128])) - - self.lin1 = torch.nn.Linear(128, 128) - self.lin2 = torch.nn.Linear(128, 128) + self.point_net_core = () self.lin3 = torch.nn.Linear(128, len(GlobalVar.classes)) # Utility - self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None) - self.activation = self.params.activation() self.log_softmax = nn.LogSoftmax(dim=-1) def forward(self, data, **kwargs): @@ -65,21 +52,7 @@ class PointNet2(BaseValMixin, """ sa0_out = (data.x, data.pos, data.batch) - sa1_out = self.sa1_module(*sa0_out) - sa2_out = self.sa2_module(*sa1_out) - sa3_out = self.sa3_module(*sa2_out) - - fp3_out = self.fp3_module(*sa3_out, *sa2_out) - fp2_out = self.fp2_module(*fp3_out, *sa1_out) - tensor, _, _ = self.fp1_module(*fp2_out, *sa0_out) - - tensor = tensor.float() - - tensor = self.activation(tensor) - tensor = self.lin1(tensor) - tensor = self.dropout(tensor) - tensor = self.lin2(tensor) - tensor = self.dropout(tensor) + tensor = super(PointNet2, self).forward(sa0_out) tensor = self.lin3(tensor) tensor = self.log_softmax(tensor) return Namespace(main_out=tensor) diff --git a/models/point_net_2_grid_clusters.py b/models/point_net_2_grid_clusters.py new file mode 100644 index 0000000..bc59f2a --- /dev/null +++ b/models/point_net_2_grid_clusters.py @@ -0,0 +1,79 @@ +from argparse import Namespace + +import torch +from torch import nn +from torch_geometric.data import Data + +from datasets.full_pointclouds import FullCloudsDataset +from models._point_net_2 import _PointNetCore + +from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin +from utils.project_config import GlobalVar + + +class PointNet2GridClusters(BaseValMixin, + BaseTrainMixin, + BaseOptimizerMixin, + DatasetMixin, + BaseDataloadersMixin, + _PointNetCore + ): + + def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__): + data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c + y = self(data) + nll_main_loss = self.nll_loss(y.main_out, data.yl) + nll_cluster_loss = self.nll_loss(y.grid_out, data.yc) + nll_loss = nll_main_loss + nll_cluster_loss + return dict(loss=nll_loss, log=dict(batch_nb=batch_nb), + nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss) + + def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__): + data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c + y = self(data) + nll_main_loss = self.nll_loss(y.main_out, data.yl) + nll_cluster_loss = self.nll_loss(y.grid_out, data.yc) + nll_loss = nll_main_loss + nll_cluster_loss + return dict(val_nll_loss=nll_loss, nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss, + batch_idx=batch_idx, y=y.main_out, batch_y=data.yl) + + def __init__(self, hparams): + super(PointNet2GridClusters, self).__init__(hparams=hparams) + + # Dataset + # ============================================================================= + self.dataset = self.build_dataset(FullCloudsDataset, setting='grid') + + # Model Paramters + # ============================================================================= + # Additional parameters + self.n_classes = len(GlobalVar.classes) + + # Modules + self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes)) + self.grid_lin = torch.nn.Linear(128, GlobalVar.grid_count) + + # Utility + self.log_softmax = nn.LogSoftmax(dim=-1) + + def forward(self, data, **kwargs): + """ + data: a batch of input torch_geometric.data.Data type + - torch_geometric.data.Data, as torch_geometric batch input: + data.x: (batch_size * ~num_points, C), batch nodes/points feature, + ~num_points means each sample can have different number of points/nodes + + data.pos: (batch_size * ~num_points, 3) + + data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud + idendifiers for all nodes of all graphs/pointclouds in the batch. See + pytorch_gemometric documentation for more information + """ + + sa0_out = (data.x, data.pos, data.batch) + tensor = super(PointNet2GridClusters, self).forward(sa0_out) + point_tensor = self.point_lin(tensor) + point_tensor = self.log_softmax(point_tensor) + grid_tensor = self.grid_lin(tensor) + grid_tensor = self.log_softmax(grid_tensor) + return Namespace(main_out=point_tensor, grid_out=grid_tensor) diff --git a/models/point_net_2_prim_clusters.py b/models/point_net_2_prim_clusters.py new file mode 100644 index 0000000..c72501e --- /dev/null +++ b/models/point_net_2_prim_clusters.py @@ -0,0 +1,59 @@ +from argparse import Namespace + +import torch +from torch import nn + +from datasets.full_pointclouds import FullCloudsDataset +from models._point_net_2 import _PointNetCore + +from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin +from utils.project_config import GlobalVar + + +class PointNet2PrimClusters(BaseValMixin, + BaseTrainMixin, + BaseOptimizerMixin, + DatasetMixin, + BaseDataloadersMixin, + _PointNetCore + ): + + def __init__(self, hparams): + super(PointNet2PrimClusters, self).__init__(hparams=hparams) + + # Dataset + # ============================================================================= + self.dataset = self.build_dataset(FullCloudsDataset, setting='prim') + + # Model Paramters + # ============================================================================= + # Additional parameters + + # Modules + self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes)) + self.prim_lin = torch.nn.Linear(128, len(GlobalVar.prims)) + + # Utility + self.log_softmax = nn.LogSoftmax(dim=-1) + + def forward(self, data, **kwargs): + """ + data: a batch of input torch_geometric.data.Data type + - torch_geometric.data.Data, as torch_geometric batch input: + data.x: (batch_size * ~num_points, C), batch nodes/points feature, + ~num_points means each sample can have different number of points/nodes + + data.pos: (batch_size * ~num_points, 3) + + data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud + idendifiers for all nodes of all graphs/pointclouds in the batch. See + pytorch_gemometric documentation for more information + """ + + sa0_out = (data.x, data.pos, data.batch) + tensor = super(PointNet2PrimClusters, self).forward(sa0_out) + point_tensor = self.point_lin(tensor) + point_tensor = self.log_softmax(point_tensor) + prim_tensor = self.prim_lin(tensor) + prim_tensor = self.log_softmax(prim_tensor) + return Namespace(main_out=point_tensor, prim_out=prim_tensor) diff --git a/utils/module_mixins.py b/utils/module_mixins.py index 70ebab3..1337f5f 100644 --- a/utils/module_mixins.py +++ b/utils/module_mixins.py @@ -15,6 +15,7 @@ import matplotlib.pyplot as plt from torch import nn from torch.optim import Adam from torch.utils.data import DataLoader +from torch_geometric.data import Data from torchcontrib.optim import SWA from torchvision.transforms import Compose @@ -61,11 +62,11 @@ class BaseTrainMixin: # Batch To Data batch_to_data = BatchToData() - def training_step(self, batch_pos_x_y, batch_nb, *_, **__): + def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__): assert isinstance(self, LightningBaseModule) - data = self.batch_to_data(*batch_pos_x_y) + data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c y = self(data).main_out - nll_loss = self.nll_loss(y, data.y) + nll_loss = self.nll_loss(y, data.yl) return dict(loss=nll_loss, log=dict(batch_nb=batch_nb)) def training_epoch_end(self, outputs): @@ -86,14 +87,16 @@ class BaseValMixin: nll_loss = nn.NLLLoss() # Binary Cross Entropy bce_loss = nn.BCELoss() + # Batch To Data + batch_to_data = BatchToData() - def validation_step(self, batch_pos_x_y, batch_idx, *_, **__): + def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__): assert isinstance(self, LightningBaseModule) - data = self.batch_to_data(*batch_pos_x_y) + data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c y = self(data).main_out - nll_loss = self.nll_loss(y, data.y) + nll_loss = self.nll_loss(y, data.yl) return dict(val_nll_loss=nll_loss, - batch_idx=batch_idx, y=y, batch_y=data.y) + batch_idx=batch_idx, y=y, batch_y=data.yl) def validation_epoch_end(self, outputs, *_, **__): assert isinstance(self, LightningBaseModule) @@ -114,12 +117,12 @@ class BaseValMixin: # # INIT y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy() - y_true_one_hot = to_one_hot(y_true) + y_true_one_hot = to_one_hot(y_true, self.n_classes) y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy() y_pred_max = np.argmax(y_pred, axis=1) - class_names = {val: key for key, val in GlobalVar.classes.__dict__().items()} + class_names = {val: key for key, val in GlobalVar.classes.items()} ###################################################################################### # # F1 SCORE @@ -167,7 +170,7 @@ class BaseValMixin: color='deeppink', linestyle=':', linewidth=4) plt.plot(fpr["macro"], tpr["macro"], - label=f'macro ROC({round(roc_auc["macro"], 2)})]', + label=f'macro ROC({round(roc_auc["macro"], 2)})', color='navy', linestyle=':', linewidth=4) colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua', @@ -190,25 +193,32 @@ class BaseValMixin: # # ROC SCORE - macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr", - average="macro") - summary_dict['log'].update(macro_roc_auc_ovr=macro_roc_auc_ovr) + try: + macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr", + average="macro") + summary_dict['log'].update(macro_roc_auc_ovr=macro_roc_auc_ovr) + except ValueError: + micro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr", + average="micro") + summary_dict['log'].update(micro_roc_auc_ovr=micro_roc_auc_ovr) ####################################################################################### # # Confusion matrix - cm = confusion_matrix(y_true, y_pred_max, labels=[class_name for class_name in class_names], normalize='all') + cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max], + labels=[class_names[key] for key in class_names.keys()], + normalize='all') disp = ConfusionMatrixDisplay(confusion_matrix=cm) disp.plot(include_values=True) - self.logger.log_image('Confusion Matrix', image=plt.gcf(), step=self.current_epoch) + self.logger.log_image('Confusion Matrix', image=disp.figure_, step=self.current_epoch) return summary_dict class DatasetMixin: - def build_dataset(self, dataset_class): + def build_dataset(self, dataset_class, **kwargs): assert isinstance(self, LightningBaseModule) # Dataset @@ -221,17 +231,16 @@ class DatasetMixin: dataset = Namespace( **dict( # TRAIN DATASET - train_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.train, - transforms=transforms - ), + train_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.train, + transforms=transforms, **kwargs), # VALIDATION DATASET - val_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.devel, - ), + val_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.devel, + **kwargs), # TEST DATASET - test_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.test, - ), + test_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.test, + **kwargs), ) ) return dataset diff --git a/utils/project_config.py b/utils/project_config.py index bb5d3e3..5cd0b11 100644 --- a/utils/project_config.py +++ b/utils/project_config.py @@ -11,6 +11,9 @@ class DataClass(Namespace): def __dict__(self): return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key} + def items(self): + return self.__dict__().items() + def __repr__(self): return f'{self.__class__.__name__}({self.__dict__().__repr__()})' @@ -28,8 +31,8 @@ class Classes(DataClass): class DataSplit(DataClass): # DATA SPLIT OPTIONS - train = 'train', - devel = 'devel', + train = 'train' + devel = 'devel' test = 'test' @@ -42,6 +45,10 @@ class GlobalVar(DataClass): classes = Classes() + grid_count = 12 + + prim_count = -1 + from models import * @@ -50,4 +57,4 @@ class ThisConfig(Config): @property def _model_map(self): - return dict(PN2=PointNet2) + return dict(PN2=PointNet2, P2P=PointNet2PrimClusters, P2G=PointNet2GridClusters) diff --git a/utils/validation_mixins.py b/utils/validation_mixins.py new file mode 100644 index 0000000..e69de29