Normalization and transforms for batch_to_data class

This commit is contained in:
Si11ium
2020-06-15 15:14:08 +02:00
parent bc70f42c74
commit 4898e98851
8 changed files with 26 additions and 24 deletions

View File

@@ -1,8 +1,10 @@
import torch
from torch import nn
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
from ml_lib.modules.util import LightningBaseModule, F_x
from ml_lib.point_toolset.point_io import BatchToData
class _PointNetCore(LightningBaseModule):
@@ -10,6 +12,11 @@ class _PointNetCore(LightningBaseModule):
def __init__(self, hparams):
super(_PointNetCore, self).__init__(hparams=hparams)
# Transforms
# =============================================================================
transforms = Compose([NormalizeScale(), RandomFlip(0, p=0.8), ])
self.batch_to_data = BatchToData(transforms=transforms)
# Model Paramters
# =============================================================================
# Additional parameters

View File

@@ -23,7 +23,7 @@ class PointNet2(BaseValMixin,
# Dataset
# =============================================================================
self.dataset = self.build_dataset(GridClusters, setting='pc')
self.dataset = self.build_dataset(GridClusters, setting='pc', sampling_k=self.params.sampling_k)
# Model Paramters
# =============================================================================
@@ -50,8 +50,7 @@ class PointNet2(BaseValMixin,
idendifiers for all nodes of all graphs/pointclouds in the batch. See
pytorch_gemometric documentation for more information
"""
sa0_out = (data.x, data.pos, data.batch)
sa0_out = (data.norm, data.pos, data.batch)
tensor = super(PointNet2, self).forward(sa0_out)
tensor = self.lin3(tensor)
tensor = self.log_softmax(tensor)

View File

@@ -3,6 +3,7 @@ from argparse import Namespace
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
from datasets.grid_clusters import GridClusters
from models._point_net_2 import _PointNetCore
@@ -42,7 +43,7 @@ class PointNet2GridClusters(BaseValMixin,
# Dataset
# =============================================================================
self.dataset = self.build_dataset(GridClusters, setting='grid')
self.dataset = self.build_dataset(GridClusters, setting='grid', sampling_k=self.params.sampling_k)
# Model Paramters
# =============================================================================
@@ -70,7 +71,7 @@ class PointNet2GridClusters(BaseValMixin,
pytorch_gemometric documentation for more information
"""
sa0_out = (data.x, data.pos, data.batch)
sa0_out = (data.norm, data.pos, data.batch)
tensor = super(PointNet2GridClusters, self).forward(sa0_out)
point_tensor = self.point_lin(tensor)
point_tensor = self.log_softmax(point_tensor)

View File

@@ -50,7 +50,7 @@ class PointNet2PrimClusters(BaseValMixin,
pytorch_gemometric documentation for more information
"""
sa0_out = (data.x, data.pos, data.batch)
sa0_out = (data.norm, data.pos, data.batch)
tensor = super(PointNet2PrimClusters, self).forward(sa0_out)
point_tensor = self.point_lin(tensor)
point_tensor = self.log_softmax(point_tensor)