Normalization and transforms for batch_to_data class
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
||||
|
||||
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
|
||||
from ml_lib.modules.util import LightningBaseModule, F_x
|
||||
from ml_lib.point_toolset.point_io import BatchToData
|
||||
|
||||
|
||||
class _PointNetCore(LightningBaseModule):
|
||||
@@ -10,6 +12,11 @@ class _PointNetCore(LightningBaseModule):
|
||||
def __init__(self, hparams):
|
||||
super(_PointNetCore, self).__init__(hparams=hparams)
|
||||
|
||||
# Transforms
|
||||
# =============================================================================
|
||||
transforms = Compose([NormalizeScale(), RandomFlip(0, p=0.8), ])
|
||||
self.batch_to_data = BatchToData(transforms=transforms)
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
|
||||
@@ -23,7 +23,7 @@ class PointNet2(BaseValMixin,
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset(GridClusters, setting='pc')
|
||||
self.dataset = self.build_dataset(GridClusters, setting='pc', sampling_k=self.params.sampling_k)
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
@@ -50,8 +50,7 @@ class PointNet2(BaseValMixin,
|
||||
idendifiers for all nodes of all graphs/pointclouds in the batch. See
|
||||
pytorch_gemometric documentation for more information
|
||||
"""
|
||||
|
||||
sa0_out = (data.x, data.pos, data.batch)
|
||||
sa0_out = (data.norm, data.pos, data.batch)
|
||||
tensor = super(PointNet2, self).forward(sa0_out)
|
||||
tensor = self.lin3(tensor)
|
||||
tensor = self.log_softmax(tensor)
|
||||
|
||||
@@ -3,6 +3,7 @@ from argparse import Namespace
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
||||
|
||||
from datasets.grid_clusters import GridClusters
|
||||
from models._point_net_2 import _PointNetCore
|
||||
@@ -42,7 +43,7 @@ class PointNet2GridClusters(BaseValMixin,
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset(GridClusters, setting='grid')
|
||||
self.dataset = self.build_dataset(GridClusters, setting='grid', sampling_k=self.params.sampling_k)
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
@@ -70,7 +71,7 @@ class PointNet2GridClusters(BaseValMixin,
|
||||
pytorch_gemometric documentation for more information
|
||||
"""
|
||||
|
||||
sa0_out = (data.x, data.pos, data.batch)
|
||||
sa0_out = (data.norm, data.pos, data.batch)
|
||||
tensor = super(PointNet2GridClusters, self).forward(sa0_out)
|
||||
point_tensor = self.point_lin(tensor)
|
||||
point_tensor = self.log_softmax(point_tensor)
|
||||
|
||||
@@ -50,7 +50,7 @@ class PointNet2PrimClusters(BaseValMixin,
|
||||
pytorch_gemometric documentation for more information
|
||||
"""
|
||||
|
||||
sa0_out = (data.x, data.pos, data.batch)
|
||||
sa0_out = (data.norm, data.pos, data.batch)
|
||||
tensor = super(PointNet2PrimClusters, self).forward(sa0_out)
|
||||
point_tensor = self.point_lin(tensor)
|
||||
point_tensor = self.log_softmax(point_tensor)
|
||||
|
||||
Reference in New Issue
Block a user