Grid Clusters.
This commit is contained in:
@@ -1 +1,4 @@
|
||||
from .point_net_2 import PointNet2
|
||||
from .point_net_2_grid_clusters import PointNet2GridClusters
|
||||
from .point_net_2_prim_clusters import PointNet2PrimClusters
|
||||
|
||||
|
||||
62
models/_point_net_2.py
Normal file
62
models/_point_net_2.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
|
||||
from ml_lib.modules.util import LightningBaseModule, F_x
|
||||
|
||||
|
||||
class _PointNetCore(LightningBaseModule):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(_PointNetCore, self).__init__(hparams=hparams)
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
|
||||
# Modules
|
||||
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
|
||||
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
|
||||
self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
|
||||
|
||||
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
|
||||
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
|
||||
self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))
|
||||
|
||||
self.lin1 = torch.nn.Linear(128, 128)
|
||||
self.lin2 = torch.nn.Linear(128, 128)
|
||||
|
||||
# Utility
|
||||
self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
|
||||
self.activation = self.params.activation()
|
||||
|
||||
def forward(self, sa0_out, **kwargs):
|
||||
"""
|
||||
data: a batch of input torch_geometric.data.Data type
|
||||
- torch_geometric.data.Data, as torch_geometric batch input:
|
||||
data.x: (batch_size * ~num_points, C), batch nodes/points feature,
|
||||
~num_points means each sample can have different number of points/nodes
|
||||
|
||||
data.pos: (batch_size * ~num_points, 3)
|
||||
|
||||
data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
|
||||
idendifiers for all nodes of all graphs/pointclouds in the batch. See
|
||||
pytorch_gemometric documentation for more information
|
||||
"""
|
||||
|
||||
sa1_out = self.sa1_module(*sa0_out)
|
||||
sa2_out = self.sa2_module(*sa1_out)
|
||||
sa3_out = self.sa3_module(*sa2_out)
|
||||
|
||||
fp3_out = self.fp3_module(*sa3_out, *sa2_out)
|
||||
fp2_out = self.fp2_module(*fp3_out, *sa1_out)
|
||||
tensor, _, _ = self.fp1_module(*fp2_out, *sa0_out)
|
||||
|
||||
tensor = tensor.float()
|
||||
|
||||
tensor = self.activation(tensor)
|
||||
tensor = self.lin1(tensor)
|
||||
tensor = self.dropout(tensor)
|
||||
tensor = self.lin2(tensor)
|
||||
tensor = self.dropout(tensor)
|
||||
return tensor
|
||||
@@ -1,13 +1,10 @@
|
||||
from argparse import Namespace
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from datasets.full_pointclouds import FullCloudsDataset
|
||||
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
|
||||
from ml_lib.modules.util import LightningBaseModule, F_x
|
||||
from models._point_net_2 import _PointNetCore
|
||||
|
||||
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
|
||||
from utils.project_config import GlobalVar
|
||||
@@ -18,7 +15,7 @@ class PointNet2(BaseValMixin,
|
||||
BaseOptimizerMixin,
|
||||
DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
LightningBaseModule
|
||||
_PointNetCore
|
||||
):
|
||||
|
||||
def __init__(self, hparams):
|
||||
@@ -26,28 +23,18 @@ class PointNet2(BaseValMixin,
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset(FullCloudsDataset)
|
||||
self.dataset = self.build_dataset(FullCloudsDataset, setting='pc')
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.n_classes = len(GlobalVar.classes)
|
||||
|
||||
# Modules
|
||||
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
|
||||
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
|
||||
self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
|
||||
|
||||
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
|
||||
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
|
||||
self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))
|
||||
|
||||
self.lin1 = torch.nn.Linear(128, 128)
|
||||
self.lin2 = torch.nn.Linear(128, 128)
|
||||
self.point_net_core = ()
|
||||
self.lin3 = torch.nn.Linear(128, len(GlobalVar.classes))
|
||||
|
||||
# Utility
|
||||
self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
|
||||
self.activation = self.params.activation()
|
||||
self.log_softmax = nn.LogSoftmax(dim=-1)
|
||||
|
||||
def forward(self, data, **kwargs):
|
||||
@@ -65,21 +52,7 @@ class PointNet2(BaseValMixin,
|
||||
"""
|
||||
|
||||
sa0_out = (data.x, data.pos, data.batch)
|
||||
sa1_out = self.sa1_module(*sa0_out)
|
||||
sa2_out = self.sa2_module(*sa1_out)
|
||||
sa3_out = self.sa3_module(*sa2_out)
|
||||
|
||||
fp3_out = self.fp3_module(*sa3_out, *sa2_out)
|
||||
fp2_out = self.fp2_module(*fp3_out, *sa1_out)
|
||||
tensor, _, _ = self.fp1_module(*fp2_out, *sa0_out)
|
||||
|
||||
tensor = tensor.float()
|
||||
|
||||
tensor = self.activation(tensor)
|
||||
tensor = self.lin1(tensor)
|
||||
tensor = self.dropout(tensor)
|
||||
tensor = self.lin2(tensor)
|
||||
tensor = self.dropout(tensor)
|
||||
tensor = super(PointNet2, self).forward(sa0_out)
|
||||
tensor = self.lin3(tensor)
|
||||
tensor = self.log_softmax(tensor)
|
||||
return Namespace(main_out=tensor)
|
||||
|
||||
79
models/point_net_2_grid_clusters.py
Normal file
79
models/point_net_2_grid_clusters.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch_geometric.data import Data
|
||||
|
||||
from datasets.full_pointclouds import FullCloudsDataset
|
||||
from models._point_net_2 import _PointNetCore
|
||||
|
||||
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
|
||||
from utils.project_config import GlobalVar
|
||||
|
||||
|
||||
class PointNet2GridClusters(BaseValMixin,
|
||||
BaseTrainMixin,
|
||||
BaseOptimizerMixin,
|
||||
DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
_PointNetCore
|
||||
):
|
||||
|
||||
def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__):
|
||||
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
|
||||
y = self(data)
|
||||
nll_main_loss = self.nll_loss(y.main_out, data.yl)
|
||||
nll_cluster_loss = self.nll_loss(y.grid_out, data.yc)
|
||||
nll_loss = nll_main_loss + nll_cluster_loss
|
||||
return dict(loss=nll_loss, log=dict(batch_nb=batch_nb),
|
||||
nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss)
|
||||
|
||||
def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
|
||||
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
|
||||
y = self(data)
|
||||
nll_main_loss = self.nll_loss(y.main_out, data.yl)
|
||||
nll_cluster_loss = self.nll_loss(y.grid_out, data.yc)
|
||||
nll_loss = nll_main_loss + nll_cluster_loss
|
||||
return dict(val_nll_loss=nll_loss, nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss,
|
||||
batch_idx=batch_idx, y=y.main_out, batch_y=data.yl)
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(PointNet2GridClusters, self).__init__(hparams=hparams)
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset(FullCloudsDataset, setting='grid')
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.n_classes = len(GlobalVar.classes)
|
||||
|
||||
# Modules
|
||||
self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes))
|
||||
self.grid_lin = torch.nn.Linear(128, GlobalVar.grid_count)
|
||||
|
||||
# Utility
|
||||
self.log_softmax = nn.LogSoftmax(dim=-1)
|
||||
|
||||
def forward(self, data, **kwargs):
|
||||
"""
|
||||
data: a batch of input torch_geometric.data.Data type
|
||||
- torch_geometric.data.Data, as torch_geometric batch input:
|
||||
data.x: (batch_size * ~num_points, C), batch nodes/points feature,
|
||||
~num_points means each sample can have different number of points/nodes
|
||||
|
||||
data.pos: (batch_size * ~num_points, 3)
|
||||
|
||||
data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
|
||||
idendifiers for all nodes of all graphs/pointclouds in the batch. See
|
||||
pytorch_gemometric documentation for more information
|
||||
"""
|
||||
|
||||
sa0_out = (data.x, data.pos, data.batch)
|
||||
tensor = super(PointNet2GridClusters, self).forward(sa0_out)
|
||||
point_tensor = self.point_lin(tensor)
|
||||
point_tensor = self.log_softmax(point_tensor)
|
||||
grid_tensor = self.grid_lin(tensor)
|
||||
grid_tensor = self.log_softmax(grid_tensor)
|
||||
return Namespace(main_out=point_tensor, grid_out=grid_tensor)
|
||||
59
models/point_net_2_prim_clusters.py
Normal file
59
models/point_net_2_prim_clusters.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from datasets.full_pointclouds import FullCloudsDataset
|
||||
from models._point_net_2 import _PointNetCore
|
||||
|
||||
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
|
||||
from utils.project_config import GlobalVar
|
||||
|
||||
|
||||
class PointNet2PrimClusters(BaseValMixin,
|
||||
BaseTrainMixin,
|
||||
BaseOptimizerMixin,
|
||||
DatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
_PointNetCore
|
||||
):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(PointNet2PrimClusters, self).__init__(hparams=hparams)
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset(FullCloudsDataset, setting='prim')
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
|
||||
# Modules
|
||||
self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes))
|
||||
self.prim_lin = torch.nn.Linear(128, len(GlobalVar.prims))
|
||||
|
||||
# Utility
|
||||
self.log_softmax = nn.LogSoftmax(dim=-1)
|
||||
|
||||
def forward(self, data, **kwargs):
|
||||
"""
|
||||
data: a batch of input torch_geometric.data.Data type
|
||||
- torch_geometric.data.Data, as torch_geometric batch input:
|
||||
data.x: (batch_size * ~num_points, C), batch nodes/points feature,
|
||||
~num_points means each sample can have different number of points/nodes
|
||||
|
||||
data.pos: (batch_size * ~num_points, 3)
|
||||
|
||||
data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
|
||||
idendifiers for all nodes of all graphs/pointclouds in the batch. See
|
||||
pytorch_gemometric documentation for more information
|
||||
"""
|
||||
|
||||
sa0_out = (data.x, data.pos, data.batch)
|
||||
tensor = super(PointNet2PrimClusters, self).forward(sa0_out)
|
||||
point_tensor = self.point_lin(tensor)
|
||||
point_tensor = self.log_softmax(point_tensor)
|
||||
prim_tensor = self.prim_lin(tensor)
|
||||
prim_tensor = self.log_softmax(prim_tensor)
|
||||
return Namespace(main_out=point_tensor, prim_out=prim_tensor)
|
||||
Reference in New Issue
Block a user