Grid Clusters.

This commit is contained in:
Si11ium 2020-06-07 16:47:52 +02:00
parent 8d0577b756
commit 2a767bead2
14 changed files with 278 additions and 149 deletions

View File

@ -40,7 +40,8 @@ main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=F
main_arg_parser.add_argument("--train_opt_reset_interval", type=strtobool, default=False, help="") main_arg_parser.add_argument("--train_opt_reset_interval", type=strtobool, default=False, help="")
# Model # Model
main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="") # Possible Model arguments are: P2P, PN2, P2G
main_arg_parser.add_argument("--model_type", type=str, default="P2G", help="")
main_arg_parser.add_argument("--model_norm_as_feature", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_norm_as_feature", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="") main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")

View File

@ -17,10 +17,6 @@ class _Point_Dataset(ABC, Dataset):
# FixMe: This does not work when more then x/y tuples are returned # FixMe: This does not work when more then x/y tuples are returned
return self[0][0].shape return self[0][0].shape
@property
def setting(self) -> str:
raise NotImplementedError
headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx'] headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx']
samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling) samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling)
@ -28,6 +24,8 @@ class _Point_Dataset(ABC, Dataset):
transforms=None, load_preprocessed=True, split='train', dense_output=False, *args, **kwargs): transforms=None, load_preprocessed=True, split='train', dense_output=False, *args, **kwargs):
super(_Point_Dataset, self).__init__() super(_Point_Dataset, self).__init__()
self.setting: str
self.dense_output = dense_output self.dense_output = dense_output
self.split = split self.split = split
self.norm_as_feature = norm_as_feature self.norm_as_feature = norm_as_feature
@ -67,4 +65,23 @@ class _Point_Dataset(ABC, Dataset):
raise NotImplementedError raise NotImplementedError
def __getitem__(self, item): def __getitem__(self, item):
raise NotImplementedError processed_file_path = self._read_or_load(item)
with processed_file_path.open('rb') as processed_file:
pointcloud = pickle.load(processed_file)
position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['label']
cl_label = pointcloud['cl_idx']
sample_idxs = self.sampling(position)
return (normal[sample_idxs].astype(np.float),
position[sample_idxs].astype(np.float),
label[sample_idxs].astype(np.int),
cl_label[sample_idxs].astype(np.int)
)

View File

@ -8,29 +8,11 @@ from ._point_dataset import _Point_Dataset
class FullCloudsDataset(_Point_Dataset): class FullCloudsDataset(_Point_Dataset):
setting = 'pc'
split: str split: str
def __init__(self, *args, **kwargs): def __init__(self, *args, setting='pc', **kwargs):
self.setting = setting
super(FullCloudsDataset, self).__init__(*args, **kwargs) super(FullCloudsDataset, self).__init__(*args, **kwargs)
def __len__(self): def __len__(self):
return len(self._files) return len(self._files)
def __getitem__(self, item):
processed_file_path = self._read_or_load(item)
with processed_file_path.open('rb') as processed_file:
pointcloud = pickle.load(processed_file)
position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['label']
sample_idxs = self.sampling(position)
return (normal[sample_idxs].astype(np.float),
position[sample_idxs].astype(np.float),
label[sample_idxs].astype(np.int))

View File

@ -1,32 +0,0 @@
import pickle
import numpy as np
from ._point_dataset import _Point_Dataset
class FullCloudsDataset(_Point_Dataset):
setting = 'grid'
def __init__(self, *args, **kwargs):
super(FullCloudsDataset, self).__init__(*args, **kwargs)
def __len__(self):
return len(self._files)
def __getitem__(self, item):
processed_file_path = self._read_or_load(item)
with processed_file_path.open('rb') as processed_file:
pointcloud = pickle.load(processed_file)
points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
),
axis=-1)
# When yopu want to return points and normal seperately
# normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['cl_idx']
sample_idxs = self.sampling(points)
return points[sample_idxs], label[sample_idxs]

View File

@ -1,32 +0,0 @@
import pickle
import numpy as np
from ._point_dataset import _Point_Dataset
class FullCloudsDataset(_Point_Dataset):
setting = 'prim'
def __init__(self, *args, **kwargs):
super(FullCloudsDataset, self).__init__(*args, **kwargs)
def __len__(self):
return len(self._files)
def __getitem__(self, item):
processed_file_path = self._read_or_load(item)
with processed_file_path.open('rb') as processed_file:
pointcloud = pickle.load(processed_file)
points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
),
axis=-1)
# When yopu want to return points and normal seperately
# normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['cl_idx']
sample_idxs = self.sampling(points)
return points[sample_idxs], label[sample_idxs]

View File

@ -1,6 +1,7 @@
from torch.utils.data import Dataset from torch.utils.data import Dataset
from._point_dataset import _Point_Dataset from._point_dataset import _Point_Dataset
class TemplateDataset(_Point_Dataset): class TemplateDataset(_Point_Dataset):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TemplateDataset, self).__init__() super(TemplateDataset, self).__init__()

View File

@ -1 +1,4 @@
from .point_net_2 import PointNet2 from .point_net_2 import PointNet2
from .point_net_2_grid_clusters import PointNet2GridClusters
from .point_net_2_prim_clusters import PointNet2PrimClusters

62
models/_point_net_2.py Normal file
View File

@ -0,0 +1,62 @@
import torch
from torch import nn
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
from ml_lib.modules.util import LightningBaseModule, F_x
class _PointNetCore(LightningBaseModule):
def __init__(self, hparams):
super(_PointNetCore, self).__init__(hparams=hparams)
# Model Paramters
# =============================================================================
# Additional parameters
# Modules
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))
self.lin1 = torch.nn.Linear(128, 128)
self.lin2 = torch.nn.Linear(128, 128)
# Utility
self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
self.activation = self.params.activation()
def forward(self, sa0_out, **kwargs):
"""
data: a batch of input torch_geometric.data.Data type
- torch_geometric.data.Data, as torch_geometric batch input:
data.x: (batch_size * ~num_points, C), batch nodes/points feature,
~num_points means each sample can have different number of points/nodes
data.pos: (batch_size * ~num_points, 3)
data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
idendifiers for all nodes of all graphs/pointclouds in the batch. See
pytorch_gemometric documentation for more information
"""
sa1_out = self.sa1_module(*sa0_out)
sa2_out = self.sa2_module(*sa1_out)
sa3_out = self.sa3_module(*sa2_out)
fp3_out = self.fp3_module(*sa3_out, *sa2_out)
fp2_out = self.fp2_module(*fp3_out, *sa1_out)
tensor, _, _ = self.fp1_module(*fp2_out, *sa0_out)
tensor = tensor.float()
tensor = self.activation(tensor)
tensor = self.lin1(tensor)
tensor = self.dropout(tensor)
tensor = self.lin2(tensor)
tensor = self.dropout(tensor)
return tensor

View File

@ -1,13 +1,10 @@
from argparse import Namespace from argparse import Namespace
import torch.nn.functional as F
import torch import torch
from torch import nn from torch import nn
from datasets.full_pointclouds import FullCloudsDataset from datasets.full_pointclouds import FullCloudsDataset
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule from models._point_net_2 import _PointNetCore
from ml_lib.modules.util import LightningBaseModule, F_x
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar from utils.project_config import GlobalVar
@ -18,7 +15,7 @@ class PointNet2(BaseValMixin,
BaseOptimizerMixin, BaseOptimizerMixin,
DatasetMixin, DatasetMixin,
BaseDataloadersMixin, BaseDataloadersMixin,
LightningBaseModule _PointNetCore
): ):
def __init__(self, hparams): def __init__(self, hparams):
@ -26,28 +23,18 @@ class PointNet2(BaseValMixin,
# Dataset # Dataset
# ============================================================================= # =============================================================================
self.dataset = self.build_dataset(FullCloudsDataset) self.dataset = self.build_dataset(FullCloudsDataset, setting='pc')
# Model Paramters # Model Paramters
# ============================================================================= # =============================================================================
# Additional parameters # Additional parameters
self.n_classes = len(GlobalVar.classes)
# Modules # Modules
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128])) self.point_net_core = ()
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))
self.lin1 = torch.nn.Linear(128, 128)
self.lin2 = torch.nn.Linear(128, 128)
self.lin3 = torch.nn.Linear(128, len(GlobalVar.classes)) self.lin3 = torch.nn.Linear(128, len(GlobalVar.classes))
# Utility # Utility
self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
self.activation = self.params.activation()
self.log_softmax = nn.LogSoftmax(dim=-1) self.log_softmax = nn.LogSoftmax(dim=-1)
def forward(self, data, **kwargs): def forward(self, data, **kwargs):
@ -65,21 +52,7 @@ class PointNet2(BaseValMixin,
""" """
sa0_out = (data.x, data.pos, data.batch) sa0_out = (data.x, data.pos, data.batch)
sa1_out = self.sa1_module(*sa0_out) tensor = super(PointNet2, self).forward(sa0_out)
sa2_out = self.sa2_module(*sa1_out)
sa3_out = self.sa3_module(*sa2_out)
fp3_out = self.fp3_module(*sa3_out, *sa2_out)
fp2_out = self.fp2_module(*fp3_out, *sa1_out)
tensor, _, _ = self.fp1_module(*fp2_out, *sa0_out)
tensor = tensor.float()
tensor = self.activation(tensor)
tensor = self.lin1(tensor)
tensor = self.dropout(tensor)
tensor = self.lin2(tensor)
tensor = self.dropout(tensor)
tensor = self.lin3(tensor) tensor = self.lin3(tensor)
tensor = self.log_softmax(tensor) tensor = self.log_softmax(tensor)
return Namespace(main_out=tensor) return Namespace(main_out=tensor)

View File

@ -0,0 +1,79 @@
from argparse import Namespace
import torch
from torch import nn
from torch_geometric.data import Data
from datasets.full_pointclouds import FullCloudsDataset
from models._point_net_2 import _PointNetCore
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar
class PointNet2GridClusters(BaseValMixin,
BaseTrainMixin,
BaseOptimizerMixin,
DatasetMixin,
BaseDataloadersMixin,
_PointNetCore
):
def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__):
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
y = self(data)
nll_main_loss = self.nll_loss(y.main_out, data.yl)
nll_cluster_loss = self.nll_loss(y.grid_out, data.yc)
nll_loss = nll_main_loss + nll_cluster_loss
return dict(loss=nll_loss, log=dict(batch_nb=batch_nb),
nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss)
def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
y = self(data)
nll_main_loss = self.nll_loss(y.main_out, data.yl)
nll_cluster_loss = self.nll_loss(y.grid_out, data.yc)
nll_loss = nll_main_loss + nll_cluster_loss
return dict(val_nll_loss=nll_loss, nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss,
batch_idx=batch_idx, y=y.main_out, batch_y=data.yl)
def __init__(self, hparams):
super(PointNet2GridClusters, self).__init__(hparams=hparams)
# Dataset
# =============================================================================
self.dataset = self.build_dataset(FullCloudsDataset, setting='grid')
# Model Paramters
# =============================================================================
# Additional parameters
self.n_classes = len(GlobalVar.classes)
# Modules
self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes))
self.grid_lin = torch.nn.Linear(128, GlobalVar.grid_count)
# Utility
self.log_softmax = nn.LogSoftmax(dim=-1)
def forward(self, data, **kwargs):
"""
data: a batch of input torch_geometric.data.Data type
- torch_geometric.data.Data, as torch_geometric batch input:
data.x: (batch_size * ~num_points, C), batch nodes/points feature,
~num_points means each sample can have different number of points/nodes
data.pos: (batch_size * ~num_points, 3)
data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
idendifiers for all nodes of all graphs/pointclouds in the batch. See
pytorch_gemometric documentation for more information
"""
sa0_out = (data.x, data.pos, data.batch)
tensor = super(PointNet2GridClusters, self).forward(sa0_out)
point_tensor = self.point_lin(tensor)
point_tensor = self.log_softmax(point_tensor)
grid_tensor = self.grid_lin(tensor)
grid_tensor = self.log_softmax(grid_tensor)
return Namespace(main_out=point_tensor, grid_out=grid_tensor)

View File

@ -0,0 +1,59 @@
from argparse import Namespace
import torch
from torch import nn
from datasets.full_pointclouds import FullCloudsDataset
from models._point_net_2 import _PointNetCore
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar
class PointNet2PrimClusters(BaseValMixin,
BaseTrainMixin,
BaseOptimizerMixin,
DatasetMixin,
BaseDataloadersMixin,
_PointNetCore
):
def __init__(self, hparams):
super(PointNet2PrimClusters, self).__init__(hparams=hparams)
# Dataset
# =============================================================================
self.dataset = self.build_dataset(FullCloudsDataset, setting='prim')
# Model Paramters
# =============================================================================
# Additional parameters
# Modules
self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes))
self.prim_lin = torch.nn.Linear(128, len(GlobalVar.prims))
# Utility
self.log_softmax = nn.LogSoftmax(dim=-1)
def forward(self, data, **kwargs):
"""
data: a batch of input torch_geometric.data.Data type
- torch_geometric.data.Data, as torch_geometric batch input:
data.x: (batch_size * ~num_points, C), batch nodes/points feature,
~num_points means each sample can have different number of points/nodes
data.pos: (batch_size * ~num_points, 3)
data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
idendifiers for all nodes of all graphs/pointclouds in the batch. See
pytorch_gemometric documentation for more information
"""
sa0_out = (data.x, data.pos, data.batch)
tensor = super(PointNet2PrimClusters, self).forward(sa0_out)
point_tensor = self.point_lin(tensor)
point_tensor = self.log_softmax(point_tensor)
prim_tensor = self.prim_lin(tensor)
prim_tensor = self.log_softmax(prim_tensor)
return Namespace(main_out=point_tensor, prim_out=prim_tensor)

View File

@ -15,6 +15,7 @@ import matplotlib.pyplot as plt
from torch import nn from torch import nn
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch_geometric.data import Data
from torchcontrib.optim import SWA from torchcontrib.optim import SWA
from torchvision.transforms import Compose from torchvision.transforms import Compose
@ -61,11 +62,11 @@ class BaseTrainMixin:
# Batch To Data # Batch To Data
batch_to_data = BatchToData() batch_to_data = BatchToData()
def training_step(self, batch_pos_x_y, batch_nb, *_, **__): def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__):
assert isinstance(self, LightningBaseModule) assert isinstance(self, LightningBaseModule)
data = self.batch_to_data(*batch_pos_x_y) data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
y = self(data).main_out y = self(data).main_out
nll_loss = self.nll_loss(y, data.y) nll_loss = self.nll_loss(y, data.yl)
return dict(loss=nll_loss, log=dict(batch_nb=batch_nb)) return dict(loss=nll_loss, log=dict(batch_nb=batch_nb))
def training_epoch_end(self, outputs): def training_epoch_end(self, outputs):
@ -86,14 +87,16 @@ class BaseValMixin:
nll_loss = nn.NLLLoss() nll_loss = nn.NLLLoss()
# Binary Cross Entropy # Binary Cross Entropy
bce_loss = nn.BCELoss() bce_loss = nn.BCELoss()
# Batch To Data
batch_to_data = BatchToData()
def validation_step(self, batch_pos_x_y, batch_idx, *_, **__): def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
assert isinstance(self, LightningBaseModule) assert isinstance(self, LightningBaseModule)
data = self.batch_to_data(*batch_pos_x_y) data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
y = self(data).main_out y = self(data).main_out
nll_loss = self.nll_loss(y, data.y) nll_loss = self.nll_loss(y, data.yl)
return dict(val_nll_loss=nll_loss, return dict(val_nll_loss=nll_loss,
batch_idx=batch_idx, y=y, batch_y=data.y) batch_idx=batch_idx, y=y, batch_y=data.yl)
def validation_epoch_end(self, outputs, *_, **__): def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule) assert isinstance(self, LightningBaseModule)
@ -114,12 +117,12 @@ class BaseValMixin:
# #
# INIT # INIT
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy() y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
y_true_one_hot = to_one_hot(y_true) y_true_one_hot = to_one_hot(y_true, self.n_classes)
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy() y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
y_pred_max = np.argmax(y_pred, axis=1) y_pred_max = np.argmax(y_pred, axis=1)
class_names = {val: key for key, val in GlobalVar.classes.__dict__().items()} class_names = {val: key for key, val in GlobalVar.classes.items()}
###################################################################################### ######################################################################################
# #
# F1 SCORE # F1 SCORE
@ -167,7 +170,7 @@ class BaseValMixin:
color='deeppink', linestyle=':', linewidth=4) color='deeppink', linestyle=':', linewidth=4)
plt.plot(fpr["macro"], tpr["macro"], plt.plot(fpr["macro"], tpr["macro"],
label=f'macro ROC({round(roc_auc["macro"], 2)})]', label=f'macro ROC({round(roc_auc["macro"], 2)})',
color='navy', linestyle=':', linewidth=4) color='navy', linestyle=':', linewidth=4)
colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua', colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
@ -190,25 +193,32 @@ class BaseValMixin:
# #
# ROC SCORE # ROC SCORE
macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr", try:
average="macro") macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
summary_dict['log'].update(macro_roc_auc_ovr=macro_roc_auc_ovr) average="macro")
summary_dict['log'].update(macro_roc_auc_ovr=macro_roc_auc_ovr)
except ValueError:
micro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
average="micro")
summary_dict['log'].update(micro_roc_auc_ovr=micro_roc_auc_ovr)
####################################################################################### #######################################################################################
# #
# Confusion matrix # Confusion matrix
cm = confusion_matrix(y_true, y_pred_max, labels=[class_name for class_name in class_names], normalize='all') cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
labels=[class_names[key] for key in class_names.keys()],
normalize='all')
disp = ConfusionMatrixDisplay(confusion_matrix=cm) disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(include_values=True) disp.plot(include_values=True)
self.logger.log_image('Confusion Matrix', image=plt.gcf(), step=self.current_epoch) self.logger.log_image('Confusion Matrix', image=disp.figure_, step=self.current_epoch)
return summary_dict return summary_dict
class DatasetMixin: class DatasetMixin:
def build_dataset(self, dataset_class): def build_dataset(self, dataset_class, **kwargs):
assert isinstance(self, LightningBaseModule) assert isinstance(self, LightningBaseModule)
# Dataset # Dataset
@ -221,17 +231,16 @@ class DatasetMixin:
dataset = Namespace( dataset = Namespace(
**dict( **dict(
# TRAIN DATASET # TRAIN DATASET
train_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.train, train_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.train,
transforms=transforms transforms=transforms, **kwargs),
),
# VALIDATION DATASET # VALIDATION DATASET
val_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.devel, val_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.devel,
), **kwargs),
# TEST DATASET # TEST DATASET
test_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.test, test_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.test,
), **kwargs),
) )
) )
return dataset return dataset

View File

@ -11,6 +11,9 @@ class DataClass(Namespace):
def __dict__(self): def __dict__(self):
return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key} return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
def items(self):
return self.__dict__().items()
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__().__repr__()})' return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
@ -28,8 +31,8 @@ class Classes(DataClass):
class DataSplit(DataClass): class DataSplit(DataClass):
# DATA SPLIT OPTIONS # DATA SPLIT OPTIONS
train = 'train', train = 'train'
devel = 'devel', devel = 'devel'
test = 'test' test = 'test'
@ -42,6 +45,10 @@ class GlobalVar(DataClass):
classes = Classes() classes = Classes()
grid_count = 12
prim_count = -1
from models import * from models import *
@ -50,4 +57,4 @@ class ThisConfig(Config):
@property @property
def _model_map(self): def _model_map(self):
return dict(PN2=PointNet2) return dict(PN2=PointNet2, P2P=PointNet2PrimClusters, P2G=PointNet2GridClusters)

View File