New Dataset for per spatial cluster training
This commit is contained in:
@@ -32,7 +32,7 @@ class _PointNetCore(LightningBaseModule):
|
||||
|
||||
def forward(self, sa0_out, **kwargs):
|
||||
"""
|
||||
data: a batch of input torch_geometric.data.Data type
|
||||
sa0_out: a batch of input torch_geometric.data.Data type
|
||||
- torch_geometric.data.Data, as torch_geometric batch input:
|
||||
data.x: (batch_size * ~num_points, C), batch nodes/points feature,
|
||||
~num_points means each sample can have different number of points/nodes
|
||||
|
||||
@@ -3,7 +3,7 @@ from argparse import Namespace
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from datasets.full_pointclouds import FullCloudsDataset
|
||||
from datasets.grid_clusters import GridClusters
|
||||
from models._point_net_2 import _PointNetCore
|
||||
|
||||
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
|
||||
@@ -23,7 +23,7 @@ class PointNet2(BaseValMixin,
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset(FullCloudsDataset, setting='pc')
|
||||
self.dataset = self.build_dataset(GridClusters, setting='pc')
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch_geometric.data import Data
|
||||
|
||||
from datasets.full_pointclouds import FullCloudsDataset
|
||||
from datasets.grid_clusters import GridClusters
|
||||
from models._point_net_2 import _PointNetCore
|
||||
|
||||
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
|
||||
@@ -42,7 +42,7 @@ class PointNet2GridClusters(BaseValMixin,
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset(FullCloudsDataset, setting='grid')
|
||||
self.dataset = self.build_dataset(GridClusters, setting='grid')
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
|
||||
Reference in New Issue
Block a user