Normalization and transforms for batch_to_data class
This commit is contained in:
@ -21,6 +21,7 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
|
|||||||
|
|
||||||
# Data Parameters
|
# Data Parameters
|
||||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||||
|
main_arg_parser.add_argument("--data_sampling_k", type=int, default=1024, help="")
|
||||||
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||||
main_arg_parser.add_argument("--data_dataset_type", type=str, default='GridClusters', help="")
|
main_arg_parser.add_argument("--data_dataset_type", type=str, default='GridClusters', help="")
|
||||||
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
|
||||||
|
@ -25,12 +25,10 @@ class _Point_Dataset(ABC, Dataset):
|
|||||||
samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling)
|
samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling)
|
||||||
|
|
||||||
def __init__(self, root=Path('data'), norm_as_feature=True, sampling_k=2048, sampling='rnd',
|
def __init__(self, root=Path('data'), norm_as_feature=True, sampling_k=2048, sampling='rnd',
|
||||||
transforms=None, load_preprocessed=True, split='train', dense_output=False, *args, **kwargs):
|
transforms=None, load_preprocessed=True, split='train', *args, **kwargs):
|
||||||
super(_Point_Dataset, self).__init__()
|
super(_Point_Dataset, self).__init__()
|
||||||
|
|
||||||
self.setting: str
|
self.setting: str
|
||||||
|
|
||||||
self.dense_output = dense_output
|
|
||||||
self.split = split
|
self.split = split
|
||||||
self.norm_as_feature = norm_as_feature
|
self.norm_as_feature = norm_as_feature
|
||||||
self.load_preprocessed = load_preprocessed
|
self.load_preprocessed = load_preprocessed
|
||||||
|
@ -72,8 +72,12 @@ class GridClusters(_Point_Dataset):
|
|||||||
while sample_idxs.shape[0] < self.sampling_k:
|
while sample_idxs.shape[0] < self.sampling_k:
|
||||||
sample_idxs = np.concatenate((sample_idxs, sample_idxs))[:self.sampling_k]
|
sample_idxs = np.concatenate((sample_idxs, sample_idxs))[:self.sampling_k]
|
||||||
|
|
||||||
return (normal[sample_idxs].astype(np.float),
|
normal = normal[sample_idxs].astype(np.float)
|
||||||
position[sample_idxs].astype(np.float),
|
position = position[sample_idxs].astype(np.float)
|
||||||
|
|
||||||
|
normal = self.transforms(normal)
|
||||||
|
position = self.transforms(position)
|
||||||
|
return (normal, position,
|
||||||
label[sample_idxs].astype(np.int),
|
label[sample_idxs].astype(np.int),
|
||||||
cl_label[sample_idxs].astype(np.int)
|
cl_label[sample_idxs].astype(np.int)
|
||||||
)
|
)
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
||||||
|
|
||||||
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
|
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
|
||||||
from ml_lib.modules.util import LightningBaseModule, F_x
|
from ml_lib.modules.util import LightningBaseModule, F_x
|
||||||
|
from ml_lib.point_toolset.point_io import BatchToData
|
||||||
|
|
||||||
|
|
||||||
class _PointNetCore(LightningBaseModule):
|
class _PointNetCore(LightningBaseModule):
|
||||||
@ -10,6 +12,11 @@ class _PointNetCore(LightningBaseModule):
|
|||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
super(_PointNetCore, self).__init__(hparams=hparams)
|
super(_PointNetCore, self).__init__(hparams=hparams)
|
||||||
|
|
||||||
|
# Transforms
|
||||||
|
# =============================================================================
|
||||||
|
transforms = Compose([NormalizeScale(), RandomFlip(0, p=0.8), ])
|
||||||
|
self.batch_to_data = BatchToData(transforms=transforms)
|
||||||
|
|
||||||
# Model Paramters
|
# Model Paramters
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
|
@ -23,7 +23,7 @@ class PointNet2(BaseValMixin,
|
|||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
self.dataset = self.build_dataset(GridClusters, setting='pc')
|
self.dataset = self.build_dataset(GridClusters, setting='pc', sampling_k=self.params.sampling_k)
|
||||||
|
|
||||||
# Model Paramters
|
# Model Paramters
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -50,8 +50,7 @@ class PointNet2(BaseValMixin,
|
|||||||
idendifiers for all nodes of all graphs/pointclouds in the batch. See
|
idendifiers for all nodes of all graphs/pointclouds in the batch. See
|
||||||
pytorch_gemometric documentation for more information
|
pytorch_gemometric documentation for more information
|
||||||
"""
|
"""
|
||||||
|
sa0_out = (data.norm, data.pos, data.batch)
|
||||||
sa0_out = (data.x, data.pos, data.batch)
|
|
||||||
tensor = super(PointNet2, self).forward(sa0_out)
|
tensor = super(PointNet2, self).forward(sa0_out)
|
||||||
tensor = self.lin3(tensor)
|
tensor = self.lin3(tensor)
|
||||||
tensor = self.log_softmax(tensor)
|
tensor = self.log_softmax(tensor)
|
||||||
|
@ -3,6 +3,7 @@ from argparse import Namespace
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
|
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
||||||
|
|
||||||
from datasets.grid_clusters import GridClusters
|
from datasets.grid_clusters import GridClusters
|
||||||
from models._point_net_2 import _PointNetCore
|
from models._point_net_2 import _PointNetCore
|
||||||
@ -42,7 +43,7 @@ class PointNet2GridClusters(BaseValMixin,
|
|||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
self.dataset = self.build_dataset(GridClusters, setting='grid')
|
self.dataset = self.build_dataset(GridClusters, setting='grid', sampling_k=self.params.sampling_k)
|
||||||
|
|
||||||
# Model Paramters
|
# Model Paramters
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -70,7 +71,7 @@ class PointNet2GridClusters(BaseValMixin,
|
|||||||
pytorch_gemometric documentation for more information
|
pytorch_gemometric documentation for more information
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sa0_out = (data.x, data.pos, data.batch)
|
sa0_out = (data.norm, data.pos, data.batch)
|
||||||
tensor = super(PointNet2GridClusters, self).forward(sa0_out)
|
tensor = super(PointNet2GridClusters, self).forward(sa0_out)
|
||||||
point_tensor = self.point_lin(tensor)
|
point_tensor = self.point_lin(tensor)
|
||||||
point_tensor = self.log_softmax(point_tensor)
|
point_tensor = self.log_softmax(point_tensor)
|
||||||
|
@ -50,7 +50,7 @@ class PointNet2PrimClusters(BaseValMixin,
|
|||||||
pytorch_gemometric documentation for more information
|
pytorch_gemometric documentation for more information
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sa0_out = (data.x, data.pos, data.batch)
|
sa0_out = (data.norm, data.pos, data.batch)
|
||||||
tensor = super(PointNet2PrimClusters, self).forward(sa0_out)
|
tensor = super(PointNet2PrimClusters, self).forward(sa0_out)
|
||||||
point_tensor = self.point_lin(tensor)
|
point_tensor = self.point_lin(tensor)
|
||||||
point_tensor = self.log_softmax(point_tensor)
|
point_tensor = self.log_softmax(point_tensor)
|
||||||
|
@ -16,13 +16,12 @@ from torch import nn
|
|||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
|
|
||||||
from torchcontrib.optim import SWA
|
from torchcontrib.optim import SWA
|
||||||
from torchvision.transforms import Compose
|
|
||||||
|
|
||||||
from ml_lib.modules.util import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
from ml_lib.utils.tools import to_one_hot
|
from ml_lib.utils.tools import to_one_hot
|
||||||
from ml_lib.utils.transforms import ToTensor
|
|
||||||
from ml_lib.point_toolset.point_io import BatchToData
|
|
||||||
|
|
||||||
from .project_config import GlobalVar
|
from .project_config import GlobalVar
|
||||||
|
|
||||||
@ -59,12 +58,10 @@ class BaseTrainMixin:
|
|||||||
nll_loss = nn.NLLLoss()
|
nll_loss = nn.NLLLoss()
|
||||||
# Binary Cross Entropy
|
# Binary Cross Entropy
|
||||||
bce_loss = nn.BCELoss()
|
bce_loss = nn.BCELoss()
|
||||||
# Batch To Data
|
|
||||||
batch_to_data = BatchToData()
|
|
||||||
|
|
||||||
def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__):
|
def training_step(self, batch_norm_pos_y_c, batch_nb, *_, **__):
|
||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
|
data = self.batch_to_data(*batch_norm_pos_y_c) if not isinstance(batch_norm_pos_y_c, Data) else batch_norm_pos_y_c
|
||||||
y = self(data).main_out
|
y = self(data).main_out
|
||||||
nll_loss = self.nll_loss(y, data.yl)
|
nll_loss = self.nll_loss(y, data.yl)
|
||||||
return dict(loss=nll_loss, log=dict(batch_nb=batch_nb))
|
return dict(loss=nll_loss, log=dict(batch_nb=batch_nb))
|
||||||
@ -87,8 +84,6 @@ class BaseValMixin:
|
|||||||
nll_loss = nn.NLLLoss()
|
nll_loss = nn.NLLLoss()
|
||||||
# Binary Cross Entropy
|
# Binary Cross Entropy
|
||||||
bce_loss = nn.BCELoss()
|
bce_loss = nn.BCELoss()
|
||||||
# Batch To Data
|
|
||||||
batch_to_data = BatchToData()
|
|
||||||
|
|
||||||
def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
|
def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
|
||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
@ -230,14 +225,11 @@ class DatasetMixin:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Data Augmentations or Utility Transformations
|
# Data Augmentations or Utility Transformations
|
||||||
|
|
||||||
transforms = Compose([ToTensor()])
|
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
dataset = Namespace(
|
dataset = Namespace(
|
||||||
**dict(
|
**dict(
|
||||||
# TRAIN DATASET
|
# TRAIN DATASET
|
||||||
train_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.train,
|
train_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.train, **kwargs),
|
||||||
transforms=transforms, **kwargs),
|
|
||||||
|
|
||||||
# VALIDATION DATASET
|
# VALIDATION DATASET
|
||||||
val_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.devel,
|
val_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.devel,
|
||||||
|
Reference in New Issue
Block a user