pointnet2 working - TODO: Eval!
models/__init__.py | 1 | Normal file
@@ -0,0 +1 @@
from .point_net_2 import PointNet2
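Since the package __init__.py re-exports the model class, downstream code can import it from the package root; a trivial usage note (not part of the diff):

    from models import PointNet2  # equivalent to: from models.point_net_2 import PointNet2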
models/point_net_2.py | 103 | Normal file
@@ -0,0 +1,103 @@
from argparse import Namespace

import torch
from torch.optim import Adam
from torch import nn
from torch_geometric.data import Data

from datasets.full_pointclouds import FullCloudsDataset
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP
from ml_lib.modules.util import LightningBaseModule, F_x

from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin


class PointNet2(BaseValMixin,
                BaseTrainMixin,
                BaseOptimizerMixin,
                DatasetMixin,
                BaseDataloadersMixin,
                LightningBaseModule
                ):

    def __init__(self, hparams):
        super(PointNet2, self).__init__(hparams=hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset(FullCloudsDataset)

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.in_shape = self.dataset.train_dataset.sample_shape
        self.channels = self.in_shape[-1]

        # Modules
        # Set-abstraction stages; the first two SAModule arguments are presumably the
        # sampling ratio and the grouping radius, followed by the per-point shared MLP.
        self.sa1_module = SAModule(0.5, 0.2, MLP([self.channels, 64, 64, 128]))
        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.channels, 128, 128, 256]))
        self.sa3_module = GlobalSAModule(MLP([256 + self.channels, 256, 512, 1024]))

        # Classification head: 1024 -> 512 -> 256 -> 10 classes
        self.lin1 = nn.Linear(1024, 512)
        self.lin2 = nn.Linear(512, 256)
        self.lin3 = nn.Linear(256, 10)

        # Utility
        self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
        self.activation = self.params.activation()
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)

    def forward(self, data, **kwargs):
        """
        data: a batch of input, torch.Tensor or torch_geometric.data.Data type
              - torch.Tensor: (batch_size, num_points, C), a dense batch of equally sized point clouds

              - torch_geometric.data.Data, as torch_geometric batch input:
                data.x: (batch_size * ~num_points, C), batch node/point features, where
                        ~num_points means each sample can have a different number of points/nodes

                data.pos: (batch_size * ~num_points, 3)

                data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
                            identifiers for all nodes of all graphs/pointclouds in the batch. See the
                            pytorch_geometric documentation for more information.
        """
        dense_input = isinstance(data, torch.Tensor)

        if dense_input:
            # Convert the dense (batch_size, num_points, C) tensor to torch_geometric.data.Data
            # data = data.transpose(1, 2).contiguous()
            batch_size, N, _ = data.shape  # (batch_size, num_points, 6)
            pos = data.view(batch_size * N, -1)
            # One pointcloud identifier per point: [0, ..., 0, 1, ..., 1, ...]
            batch = torch.arange(batch_size, device=pos.device).repeat_interleave(N)

            data = Data()
            data.pos, data.batch = pos, batch

        if not hasattr(data, 'x'):
            data.x = None

        sa0_out = (data.x, data.pos, data.batch)
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        sa3_out = self.sa3_module(*sa2_out)

        tensor, pos, batch = sa3_out
        tensor = tensor.float()

        tensor = self.lin1(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.lin2(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.lin3(tensor)
        tensor = self.log_softmax(tensor)
        return Namespace(main_out=tensor)
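Not part of the commit, but a self-contained sketch of the dense-to-Data conversion that forward performs for plain tensors; the batch size, point count, and channel count below are made up for illustration, and only torch plus torch_geometric are assumed to be installed:

    import torch
    from torch_geometric.data import Data

    batch_size, N, C = 2, 1024, 6                           # e.g. xyz coordinates + normals
    dense = torch.rand(batch_size, N, C)                    # (batch_size, num_points, C)

    pos = dense.view(batch_size * N, -1)                    # flatten all points into one node list
    batch = torch.arange(batch_size).repeat_interleave(N)   # pointcloud identifier per point

    data = Data(pos=pos, batch=batch)                       # the format the SAModules consume
    assert data.pos.shape == (batch_size * N, C)
    assert data.batch.shape == (batch_size * N,)

The same Data object (or a dense tensor of that shape) can then be passed to PointNet2.forward, which returns a Namespace whose main_out field holds per-cloud log-probabilities of shape (batch_size, 10).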