Dataset Redone

Si11ium
2020-06-19 08:17:35 +02:00
parent 4898e98851
commit 63605ae33a
14 changed files with 239 additions and 362 deletions

View File

@@ -1,4 +1 @@
 from .point_net_2 import PointNet2
-from .point_net_2_grid_clusters import PointNet2GridClusters
-from .point_net_2_prim_clusters import PointNet2PrimClusters

View File

@@ -3,7 +3,7 @@ from argparse import Namespace
 import torch
 from torch import nn
-from datasets.grid_clusters import GridClusters
+from datasets.shapenet import ShapeNetPartSegDataset
 from models._point_net_2 import _PointNetCore
 from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
@@ -23,7 +23,8 @@ class PointNet2(BaseValMixin,
         # Dataset
         # =============================================================================
-        self.dataset = self.build_dataset(GridClusters, setting='pc', sampling_k=self.params.sampling_k)
+        self.dataset = self.build_dataset(ShapeNetPartSegDataset, collate_per_segment=True,
+                                          npoints=self.params.npoints)
         # Model Parameters
         # =============================================================================
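For orientation, a minimal sketch of what the DatasetMixin.build_dataset call above is assumed to do with the new arguments: instantiate ShapeNetPartSegDataset once per split and hand the splits back to the model. The per-split handling, the 'mode' keyword, and the returned Namespace fields are assumptions, not the repo's actual implementation.

# Hedged sketch only; the real DatasetMixin.build_dataset may differ.
from argparse import Namespace

def build_dataset(dataset_class, **kwargs):
    # Assumed behaviour: one dataset instance per split, keyword arguments
    # (here collate_per_segment / npoints) forwarded unchanged.
    return Namespace(
        train_dataset=dataset_class(mode='train', **kwargs),  # 'mode' keyword is an assumption
        test_dataset=dataset_class(mode='test', **kwargs),
    )

# Usage corresponding to the added lines above (inside PointNet2.__init__):
# self.dataset = build_dataset(ShapeNetPartSegDataset,
#                              collate_per_segment=True,
#                              npoints=self.params.npoints)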

View File

@@ -1,80 +0,0 @@
from argparse import Namespace

import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip

from datasets.grid_clusters import GridClusters
from models._point_net_2 import _PointNetCore

from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar


class PointNet2GridClusters(BaseValMixin,
                            BaseTrainMixin,
                            BaseOptimizerMixin,
                            DatasetMixin,
                            BaseDataloadersMixin,
                            _PointNetCore
                            ):

    def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__):
        data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
        y = self(data)
        nll_main_loss = self.nll_loss(y.main_out, data.yl)
        nll_cluster_loss = self.nll_loss(y.grid_out, data.yc)
        nll_loss = nll_main_loss + nll_cluster_loss
        return dict(loss=nll_loss, log=dict(batch_nb=batch_nb),
                    nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss)

    def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
        data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
        y = self(data)
        nll_main_loss = self.nll_loss(y.main_out, data.yl)
        nll_cluster_loss = self.nll_loss(y.grid_out, data.yc)
        nll_loss = nll_main_loss + nll_cluster_loss
        return dict(val_nll_loss=nll_loss, nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss,
                    batch_idx=batch_idx, y=y.main_out, batch_y=data.yl)

    def __init__(self, hparams):
        super(PointNet2GridClusters, self).__init__(hparams=hparams)
        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset(GridClusters, setting='grid', sampling_k=self.params.sampling_k)

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.n_classes = len(GlobalVar.classes)

        # Modules
        self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes))
        self.grid_lin = torch.nn.Linear(128, GlobalVar.grid_count)

        # Utility
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, data, **kwargs):
        """
        data: a batch of input torch_geometric.data.Data type
            - torch_geometric.data.Data, as torch_geometric batch input:
                data.x: (batch_size * ~num_points, C), batch nodes/points feature,
                    ~num_points means each sample can have a different number of points/nodes
                data.pos: (batch_size * ~num_points, 3)
                data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
                    identifiers for all nodes of all graphs/pointclouds in the batch. See the
                    pytorch_geometric documentation for more information.
        """
        sa0_out = (data.norm, data.pos, data.batch)
        tensor = super(PointNet2GridClusters, self).forward(sa0_out)
        point_tensor = self.point_lin(tensor)
        point_tensor = self.log_softmax(point_tensor)
        grid_tensor = self.grid_lin(tensor)
        grid_tensor = self.log_softmax(grid_tensor)
        return Namespace(main_out=point_tensor, grid_out=grid_tensor)
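The forward docstring above describes the flattened torch_geometric batching that this (now removed) model consumed. A small standalone illustration, independent of this repo, of how data.pos and data.batch relate when two clouds of different sizes are batched:

import torch
from torch_geometric.data import Data, Batch

# Two point clouds with different point counts, each carrying normals.
cloud_a = Data(pos=torch.rand(100, 3), norm=torch.rand(100, 3))
cloud_b = Data(pos=torch.rand(140, 3), norm=torch.rand(140, 3))

batch = Batch.from_data_list([cloud_a, cloud_b])
print(batch.pos.shape)       # torch.Size([240, 3])  -> batch_size * ~num_points rows
print(batch.batch.shape)     # torch.Size([240])     -> one cloud identifier per point
print(batch.batch.unique())  # tensor([0, 1])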

View File

@@ -1,59 +0,0 @@
from argparse import Namespace

import torch
from torch import nn

from datasets.full_pointclouds import FullCloudsDataset
from models._point_net_2 import _PointNetCore

from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar


class PointNet2PrimClusters(BaseValMixin,
                            BaseTrainMixin,
                            BaseOptimizerMixin,
                            DatasetMixin,
                            BaseDataloadersMixin,
                            _PointNetCore
                            ):

    def __init__(self, hparams):
        super(PointNet2PrimClusters, self).__init__(hparams=hparams)
        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset(FullCloudsDataset, setting='prim')

        # Model Parameters
        # =============================================================================
        # Additional parameters

        # Modules
        self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes))
        self.prim_lin = torch.nn.Linear(128, len(GlobalVar.prims))

        # Utility
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, data, **kwargs):
        """
        data: a batch of input torch_geometric.data.Data type
            - torch_geometric.data.Data, as torch_geometric batch input:
                data.x: (batch_size * ~num_points, C), batch nodes/points feature,
                    ~num_points means each sample can have a different number of points/nodes
                data.pos: (batch_size * ~num_points, 3)
                data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
                    identifiers for all nodes of all graphs/pointclouds in the batch. See the
                    pytorch_geometric documentation for more information.
        """
        sa0_out = (data.norm, data.pos, data.batch)
        tensor = super(PointNet2PrimClusters, self).forward(sa0_out)
        point_tensor = self.point_lin(tensor)
        point_tensor = self.log_softmax(point_tensor)
        prim_tensor = self.prim_lin(tensor)
        prim_tensor = self.log_softmax(prim_tensor)
        return Namespace(main_out=point_tensor, prim_out=prim_tensor)
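For completeness, a hedged sketch of how the two heads returned by this removed model could have been trained, mirroring the training_step of PointNet2GridClusters earlier in this commit. data.yl follows that file; data.yp (per-point primitive labels) and the direct use of torch.nn.functional.nll_loss instead of the mixin's self.nll_loss are assumptions.

import torch.nn.functional as F

def training_step_sketch(model, data):
    y = model(data)                                  # Namespace(main_out=..., prim_out=...)
    nll_main_loss = F.nll_loss(y.main_out, data.yl)  # point-class head, labels as in the grid variant
    nll_prim_loss = F.nll_loss(y.prim_out, data.yp)  # primitive head; 'yp' label field is assumed
    return nll_main_loss + nll_prim_loss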