New Model running
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
||||
@@ -7,28 +9,28 @@ from ml_lib.modules.util import LightningBaseModule, F_x
|
||||
from ml_lib.point_toolset.point_io import BatchToData
|
||||
|
||||
|
||||
class _PointNetCore(LightningBaseModule):
|
||||
class _PointNetCore(LightningBaseModule, ABC):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(_PointNetCore, self).__init__(hparams=hparams)
|
||||
|
||||
# Transforms
|
||||
# =============================================================================
|
||||
transforms = Compose([NormalizeScale(), RandomFlip(0, p=0.8), ])
|
||||
self.batch_to_data = BatchToData(transforms=transforms)
|
||||
self.batch_to_data = BatchToData(transforms=None)
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.cord_dims = 6 if self.params.normals_as_cords else 3
|
||||
|
||||
# Modules
|
||||
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
|
||||
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
|
||||
self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
|
||||
self.sa1_module = SAModule(0.2, 0.2, MLP([self.cord_dims, 64, 64, 128]))
|
||||
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
|
||||
self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)
|
||||
|
||||
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
|
||||
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
|
||||
self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))
|
||||
self.fp1_module = FPModule(3, MLP([128, 128, 128, 128]))
|
||||
|
||||
self.lin1 = torch.nn.Linear(128, 128)
|
||||
self.lin2 = torch.nn.Linear(128, 128)
|
||||
|
||||
@@ -2,6 +2,7 @@ from argparse import Namespace
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch_geometric.transforms import Compose, RandomFlip, FixedPoints, RandomTranslate, NormalizeScale
|
||||
|
||||
from datasets.shapenet import ShapeNetPartSegDataset
|
||||
from models._point_net_2 import _PointNetCore
|
||||
@@ -21,21 +22,40 @@ class PointNet2(BaseValMixin,
|
||||
def __init__(self, hparams):
|
||||
super(PointNet2, self).__init__(hparams=hparams)
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
# rot_max_angle = 15
|
||||
trans_max_distance = 0.01
|
||||
transforms = Compose(
|
||||
[
|
||||
RandomFlip(0, p=0.8),
|
||||
FixedPoints(self.params.npoints),
|
||||
# This is not available with 6-dim cords
|
||||
# RandomRotate(rot_max_angle, 0), RandomRotate(rot_max_angle, 1), RandomRotate(rot_max_angle, 2),
|
||||
RandomTranslate(trans_max_distance),
|
||||
NormalizeScale()
|
||||
# NormalizePositions()
|
||||
]
|
||||
)
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset(ShapeNetPartSegDataset,
|
||||
collate_per_segment=True,
|
||||
npoints=self.params.npoints
|
||||
transform=transforms,
|
||||
cluster_type=self.params.cluster_type,
|
||||
refresh=self.params.refresh,
|
||||
poly_as_plane=self.params.poly_as_plane
|
||||
)
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.n_classes = len(GlobalVar.classes)
|
||||
self.n_classes = len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2)
|
||||
|
||||
# Modules
|
||||
self.point_net_core = ()
|
||||
self.lin3 = torch.nn.Linear(128, len(GlobalVar.classes))
|
||||
self.lin3 = torch.nn.Linear(128, self.n_classes)
|
||||
|
||||
# Utility
|
||||
self.log_softmax = nn.LogSoftmax(dim=-1)
|
||||
@@ -53,7 +73,11 @@ class PointNet2(BaseValMixin,
|
||||
idendifiers for all nodes of all graphs/pointclouds in the batch. See
|
||||
pytorch_gemometric documentation for more information
|
||||
"""
|
||||
sa0_out = (data.norm, data.pos, data.batch)
|
||||
if not self.params.normals_as_cords:
|
||||
sa0_out = (data.norm, data.pos, data.batch)
|
||||
else:
|
||||
pos_cat_norm = torch.cat((data.pos, data.norm), dim=-1)
|
||||
sa0_out = (None, pos_cat_norm, data.batch)
|
||||
tensor = super(PointNet2, self).forward(sa0_out)
|
||||
tensor = self.lin3(tensor)
|
||||
tensor = self.log_softmax(tensor)
|
||||
|
||||
Reference in New Issue
Block a user