eval running - offline logger implemented -> Test it!

This commit is contained in:
Si11ium
2020-05-30 18:12:42 +02:00
parent ba7c0280ae
commit 8d0577b756
9 changed files with 212 additions and 102 deletions

View File

@@ -1,15 +1,16 @@
from argparse import Namespace
import torch.nn.functional as F
import torch
from torch.optim import Adam
from torch import nn
from torch_geometric.data import Data
from datasets.full_pointclouds import FullCloudsDataset
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
from ml_lib.modules.util import LightningBaseModule, F_x
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar
class PointNet2(BaseValMixin,
@@ -31,31 +32,27 @@ class PointNet2(BaseValMixin,
# =============================================================================
# Additional parameters
self.in_shape = self.dataset.train_dataset.sample_shape
self.channels = self.in_shape[-1]
# Modules
self.sa1_module = SAModule(0.5, 0.2, MLP([self.channels, 64, 64, 128]))
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.channels, 128, 128, 256]))
self.sa3_module = GlobalSAModule(MLP([256 + self.channels, 256, 512, 1024]))
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
self.lin1 = nn.Linear(1024, 512)
self.lin2 = nn.Linear(512, 256)
self.lin3 = nn.Linear(256, 10)
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))
self.lin1 = torch.nn.Linear(128, 128)
self.lin2 = torch.nn.Linear(128, 128)
self.lin3 = torch.nn.Linear(128, len(GlobalVar.classes))
# Utility
self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
self.activation = self.params.activation()
self.log_softmax = nn.LogSoftmax(dim=-1)
def configure_optimizers(self):
return Adam(self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
def forward(self, data, **kwargs):
"""
data: a batch of input, torch.Tensor or torch_geometric.data.Data type
- torch.Tensor: (batch_size, 3, num_points), as common batch input
data: a batch of input torch_geometric.data.Data type
- torch_geometric.data.Data, as torch_geometric batch input:
data.x: (batch_size * ~num_points, C), batch nodes/points feature,
~num_points means each sample can have different number of points/nodes
@@ -66,37 +63,22 @@ class PointNet2(BaseValMixin,
idendifiers for all nodes of all graphs/pointclouds in the batch. See
pytorch_gemometric documentation for more information
"""
dense_input = True if isinstance(data, torch.Tensor) else False
if dense_input:
# Convert to torch_geometric.data.Data type
# data = data.transpose(1, 2).contiguous()
batch_size, N, _ = data.shape # (batch_size, num_points, 6)
pos = data.view(batch_size*N, -1)
batch = torch.zeros((batch_size, N), device=pos.device, dtype=torch.long)
for i in range(batch_size):
batch[i] = i
batch = batch.view(-1)
data = Data()
data.pos, data.batch = pos, batch
if not hasattr(data, 'x'):
data.x = None
sa0_out = (data.x, data.pos, data.batch)
sa1_out = self.sa1_module(*sa0_out)
sa2_out = self.sa2_module(*sa1_out)
sa3_out = self.sa3_module(*sa2_out)
tensor, pos, batch = sa3_out
fp3_out = self.fp3_module(*sa3_out, *sa2_out)
fp2_out = self.fp2_module(*fp3_out, *sa1_out)
tensor, _, _ = self.fp1_module(*fp2_out, *sa0_out)
tensor = tensor.float()
tensor = self.lin1(tensor)
tensor = self.activation(tensor)
tensor = self.lin1(tensor)
tensor = self.dropout(tensor)
tensor = self.lin2(tensor)
tensor = self.activation(tensor)
tensor = self.dropout(tensor)
tensor = self.lin3(tensor)
tensor = self.log_softmax(tensor)