pointnet2 working - TODO: Eval!

Si11ium 2020-05-26 21:44:57 +02:00
parent e04ef2f8b9
commit ba7c0280ae
11 changed files with 232 additions and 58 deletions

View File

@@ -15,15 +15,14 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Project
-main_arg_parser.add_argument("--project_name", type=str, default='traj-gen', help="")
+main_arg_parser.add_argument("--project_name", type=str, default='point-to-primitive', help="")
main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', help="")
-main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="")
+main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_API_TOKEN'), help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
-main_arg_parser.add_argument("--data_additional_resource_root", type=str, default='res/resource/root', help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
# Transformations
@@ -36,10 +35,11 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
+main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model
-main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
+main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")

View File

@@ -1,17 +1,27 @@
+import pickle
+from collections import defaultdict
+
from abc import ABC
from pathlib import Path
from torch.utils.data import Dataset
from ml_lib.point_toolset.sampling import FarthestpointSampling
+import numpy as np


class _Point_Dataset(ABC, Dataset):

+    @property
+    def sample_shape(self):
+        # FixMe: This does not work when more than x/y tuples are returned
+        return self[0][0].shape
+
    @property
    def setting(self) -> str:
        raise NotImplementedError

-    headers = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'label', 'cl_idx']
+    headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx']

    def __init__(self, root=Path('data'), sampling_k=2048, transforms=None, load_preprocessed=True, *args, **kwargs):
        super(_Point_Dataset, self).__init__()
@@ -21,13 +31,32 @@ class _Point_Dataset(ABC, Dataset):
        self.sampling_k = sampling_k
        self.sampling = FarthestpointSampling(K=self.sampling_k)
        self.root = Path(root)
-        self.raw = root / 'raw'
+        self.raw = self.root / 'raw'
        self.processed_ext = '.pik'
        self.raw_ext = '.xyz'
-        self.processed = root / self.setting
+        self.processed = self.root / self.setting
+        self.processed.mkdir(parents=True, exist_ok=True)

        self._files = list(self.raw.glob(f'*{self.setting}*'))

+    def _read_or_load(self, item):
+        raw_file_path = self._files[item]
+        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)
+        if not self.load_preprocessed:
+            processed_file_path.unlink(missing_ok=True)
+        if not processed_file_path.exists():
+            pointcloud = defaultdict(list)
+            with raw_file_path.open('r') as raw_file:
+                for row in raw_file:
+                    values = [float(x) for x in row.strip().split(' ')]
+                    for header, value in zip(self.headers, values):
+                        pointcloud[header].append(value)
+            for key in pointcloud.keys():
+                pointcloud[key] = np.asarray(pointcloud[key])
+            with processed_file_path.open('wb') as processed_file:
+                pickle.dump(pointcloud, processed_file)
+        return processed_file_path
+
    def __len__(self):
        raise NotImplementedError
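Note: the new `_read_or_load` helper gives every subclass a parse-once cache: the first access splits the raw `.xyz` rows into the `headers` columns, converts them to NumPy arrays, and pickles them under `<root>/<setting>/`; later accesses just unpickle that file. A minimal usage sketch, assuming the `FullCloudsDataset` from `datasets/full_pointclouds.py` below and raw files already present under `data/raw/` (not part of the commit):

# Sketch only: illustrates the caching behaviour of _read_or_load.
from datasets.full_pointclouds import FullCloudsDataset

dataset = FullCloudsDataset(root='data', load_preprocessed=True)
first = dataset[0]    # parses the .xyz text file and writes the .pik cache
second = dataset[0]   # unpickles the cached file instead of re-parsing
# With load_preprocessed=False the cached .pik is unlinked first, forcing a fresh parse.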

View File

@@ -1,9 +1,7 @@
import pickle
from collections import defaultdict
-from pathlib import Path

import numpy as np
-from torch.utils.data import Dataset

from ._point_dataset import _Point_Dataset
@@ -19,27 +17,17 @@ class FullCloudsDataset(_Point_Dataset):
        return len(self._files)

    def __getitem__(self, item):
-        raw_file_path = self._files[item]
-        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)
-        if not self.load_preprocessed:
-            processed_file_path.unlink(missing_ok=True)
-        if not processed_file_path.exists():
-            pointcloud = defaultdict(list)
-            with raw_file_path.open('r') as raw_file:
-                for row in raw_file:
-                    values = [float(x) for x in row.split(' ')]
-                    for header, value in zip(self.headers, values):
-                        pointcloud[header].append(value)
-            for key in pointcloud.keys():
-                pointcloud[key] = np.asarray(pointcloud[key])
-            with processed_file_path.open('wb') as processed_file:
-                pickle.dump(pointcloud, processed_file)
+        processed_file_path = self._read_or_load(item)

        with processed_file_path.open('rb') as processed_file:
            pointcloud = pickle.load(processed_file)
-        points = np.stack(pointcloud['x'], pointcloud['y'], pointcloud['z'])
-        normal = np.stack(pointcloud['xn'], pointcloud['yn'], pointcloud['zn'])
-        label = points['label']
-        samples = self.sampling(points)
-        return points[samples], normal[samples], label[samples]
+        points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
+                           pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
+                           ),
+                          axis=-1)
+        # When you want to return points and normals separately:
+        # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
+        label = pointcloud['label']
+        sample_idxs = self.sampling(points)
+        return points[sample_idxs].astype(np.float), label[sample_idxs].astype(np.int)
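Note: with the stacking fixed (tuple argument plus axis=-1), each sample is a (sampling_k, 6) float array of x, y, z, xn, yn, zn columns and a (sampling_k,) integer label vector. A rough shape check, assuming the default sampling_k=2048 and that FarthestpointSampling returns an index array of length K (neither assumption is shown in this hunk):

# Sketch only: expected shapes of one sample from FullCloudsDataset.
from datasets.full_pointclouds import FullCloudsDataset

dataset = FullCloudsDataset(root='data', sampling_k=2048)
points, label = dataset[0]
print(points.shape)   # (2048, 6): x, y, z, xn, yn, zn per sampled point
print(label.shape)    # (2048,): one class label per sampled point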

View File

@@ -1,6 +1,32 @@
-from torch.utils.data import Dataset
+import pickle
+
+import numpy as np
+
+from ._point_dataset import _Point_Dataset


-class TemplateDataset(_Point_Dataset):
+class FullCloudsDataset(_Point_Dataset):
+
+    setting = 'grid'
+
    def __init__(self, *args, **kwargs):
-        super(TemplateDataset, self).__init__()
+        super(FullCloudsDataset, self).__init__(*args, **kwargs)
+
+    def __len__(self):
+        return len(self._files)
+
+    def __getitem__(self, item):
+        processed_file_path = self._read_or_load(item)
+
+        with processed_file_path.open('rb') as processed_file:
+            pointcloud = pickle.load(processed_file)
+        points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
+                           pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
+                           ),
+                          axis=-1)
+        # When you want to return points and normals separately:
+        # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
+        label = np.stack((pointcloud['label'], pointcloud['cl_idx']), axis=-1)
+        sample_idxs = self.sampling(points)
+        return points[sample_idxs], label[sample_idxs]

View File

@@ -1,8 +1,32 @@
-from torch.utils.data import Dataset
+import pickle
+
+import numpy as np
+
from ._point_dataset import _Point_Dataset


-class TemplateDataset(_Point_Dataset):
+class FullCloudsDataset(_Point_Dataset):
+
+    setting = 'prim'
+
    def __init__(self, *args, **kwargs):
-        super(TemplateDataset, self).__init__()
+        super(FullCloudsDataset, self).__init__(*args, **kwargs)
+
+    def __len__(self):
+        return len(self._files)
+
+    def __getitem__(self, item):
+        processed_file_path = self._read_or_load(item)
+
+        with processed_file_path.open('rb') as processed_file:
+            pointcloud = pickle.load(processed_file)
+        points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
+                           pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
+                           ),
+                          axis=-1)
+        # When you want to return points and normals separately:
+        # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
+        label = np.stack((pointcloud['label'], pointcloud['cl_idx']), axis=-1)
+        sample_idxs = self.sampling(points)
+        return points[sample_idxs], label[sample_idxs]

View File

@@ -10,4 +10,3 @@ class TemplateDataset(_Point_Dataset):
    def __getitem__(self, item):
        return item

main.py (14 changed lines)
View File

@@ -5,12 +5,11 @@ import warnings
import torch
from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.callbacks import ModelCheckpoint  # , EarlyStopping
from ml_lib.modules.util import LightningBaseModule
-from ml_lib.utils.config import Config
from ml_lib.utils.logging import Logger
-from ml_lib.utils.model_io import SavedLightningModels
+from utils.project_config import ThisConfig

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@@ -33,11 +32,13 @@ def run_lightning_loop(config_obj):
    # =============================================================================
    # Early Stopping
    # TODO: For This to work, one must set a validation step and End Eval and Score
+    """
    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.0,
        patience=0,
    )
+    """

    # Model
    # =============================================================================
@@ -76,6 +77,9 @@ def run_lightning_loop(config_obj):

if __name__ == "__main__":
-    from ._parameters import args
-    config = Config.read_namespace(args)
+    from _parameters import args
+    from ml_lib.utils.tools import fix_all_random_seeds
+
+    config = ThisConfig.read_namespace(args)
+    fix_all_random_seeds(config)
    trained_model = run_lightning_loop(config)
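Note: the early-stopping callback is parked in a block comment because, as the TODO above says, no validation step logs a `val_loss` yet. Once one exists, it can be re-activated roughly as sketched below; how the callback is handed to the Trainer depends on the pytorch_lightning version in use, so the `callbacks=` wiring is an assumption:

# Sketch only: re-enabling early stopping once the model's validation loop logs 'val_loss'.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stopping_callback = EarlyStopping(
    monitor='val_loss',   # must be logged by the validation step / epoch end
    min_delta=0.0,
    patience=0,
)
trainer = Trainer(callbacks=[early_stopping_callback])  # assumed wiring; other Trainer args omitted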

models/__init__.py (new file, 1 line)
View File

@@ -0,0 +1 @@
+from .point_net_2 import PointNet2

models/point_net_2.py (new file, 103 lines)
View File

@@ -0,0 +1,103 @@
+from argparse import Namespace
+
+import torch
+from torch.optim import Adam
+from torch import nn
+from torch_geometric.data import Data
+
+from datasets.full_pointclouds import FullCloudsDataset
+from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP
+from ml_lib.modules.util import LightningBaseModule, F_x
+
+from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
+
+
+class PointNet2(BaseValMixin,
+                BaseTrainMixin,
+                BaseOptimizerMixin,
+                DatasetMixin,
+                BaseDataloadersMixin,
+                LightningBaseModule
+                ):
+
+    def __init__(self, hparams):
+        super(PointNet2, self).__init__(hparams=hparams)
+
+        # Dataset
+        # =============================================================================
+        self.dataset = self.build_dataset(FullCloudsDataset)
+
+        # Model Parameters
+        # =============================================================================
+        # Additional parameters
+        self.in_shape = self.dataset.train_dataset.sample_shape
+        self.channels = self.in_shape[-1]
+
+        # Modules
+        self.sa1_module = SAModule(0.5, 0.2, MLP([self.channels, 64, 64, 128]))
+        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.channels, 128, 128, 256]))
+        self.sa3_module = GlobalSAModule(MLP([256 + self.channels, 256, 512, 1024]))
+
+        self.lin1 = nn.Linear(1024, 512)
+        self.lin2 = nn.Linear(512, 256)
+        self.lin3 = nn.Linear(256, 10)
+
+        # Utility
+        self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
+        self.activation = self.params.activation()
+        self.log_softmax = nn.LogSoftmax(dim=-1)
+
+    def configure_optimizers(self):
+        return Adam(self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
+
+    def forward(self, data, **kwargs):
+        """
+        data: a batch of input, torch.Tensor or torch_geometric.data.Data type
+            - torch.Tensor: (batch_size, num_points, channels), as dense batch input
+            - torch_geometric.data.Data, as torch_geometric batch input:
+                data.x: (batch_size * ~num_points, C), batch nodes/points feature,
+                    ~num_points means each sample can have a different number of points/nodes
+                data.pos: (batch_size * ~num_points, 3)
+                data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
+                    identifiers for all nodes of all graphs/pointclouds in the batch. See
+                    the pytorch_geometric documentation for more information
+        """
+        dense_input = True if isinstance(data, torch.Tensor) else False
+
+        if dense_input:
+            # Convert to torch_geometric.data.Data type
+            # data = data.transpose(1, 2).contiguous()
+            batch_size, N, _ = data.shape  # (batch_size, num_points, 6)
+            pos = data.view(batch_size*N, -1)
+            batch = torch.zeros((batch_size, N), device=pos.device, dtype=torch.long)
+            for i in range(batch_size):
+                batch[i] = i
+            batch = batch.view(-1)
+
+            data = Data()
+            data.pos, data.batch = pos, batch
+
+        if not hasattr(data, 'x'):
+            data.x = None
+
+        sa0_out = (data.x, data.pos, data.batch)
+        sa1_out = self.sa1_module(*sa0_out)
+        sa2_out = self.sa2_module(*sa1_out)
+        sa3_out = self.sa3_module(*sa2_out)
+        tensor, pos, batch = sa3_out
+        tensor = tensor.float()
+
+        tensor = self.lin1(tensor)
+        tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
+        tensor = self.lin2(tensor)
+        tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
+        tensor = self.lin3(tensor)
+        tensor = self.log_softmax(tensor)
+        return Namespace(main_out=tensor)
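Note: in the dense branch of forward(), the batch-assignment vector is built by looping over a zero tensor; the loop is equivalent to a repeat_interleave over the sample indices. A small standalone sketch of what data.pos and data.batch end up containing for a (batch_size, num_points, 6) input (values are illustrative only, not part of the commit):

# Sketch only: the dense-input conversion done in forward(), written out standalone.
import torch

batch_size, num_points = 4, 2048
data = torch.rand(batch_size, num_points, 6)            # shape as described in the docstring
pos = data.view(batch_size * num_points, -1)            # (batch_size * num_points, 6)
batch = torch.arange(batch_size).repeat_interleave(num_points)
# batch == [0, 0, ..., 0, 1, 1, ..., 3], shape (batch_size * num_points,),
# matching the loop that sets batch[i] = i row-wise and then flattens with .view(-1).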

View File

@@ -11,13 +11,10 @@ from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose

-from _templates.new_project.datasets.template_dataset import TemplateDataset
-from audio_toolset.audio_io import NormalizeLocal
-from modules.utils import LightningBaseModule
-from utils.transforms import ToTensor
-from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
+from ml_lib.modules.util import LightningBaseModule
+from ml_lib.utils.transforms import ToTensor
+from .project_config import GlobalVar


class BaseOptimizerMixin:
@@ -110,31 +107,31 @@ class BaseValMixin:
        return summary_dict


-class BinaryMaskDatasetMixin:
+class DatasetMixin:

-    def build_dataset(self):
+    def build_dataset(self, dataset_class):
        assert isinstance(self, LightningBaseModule)

        # Dataset
        # =============================================================================
        # Data Augmentations or Utility Transformations
-        transforms = Compose([NormalizeLocal(), ToTensor()])
+        transforms = Compose([ToTensor()])

        # Dataset
        dataset = Namespace(
            **dict(
                # TRAIN DATASET
-                train_dataset=TemplateDataset(self.params.root, setting=GlobalVars.DATA_OPTIONS.train,
+                train_dataset=dataset_class(self.params.root, setting=GlobalVar.train,
                                            transforms=transforms
                                            ),
                # VALIDATION DATASET
-                val_dataset=TemplateDataset(self.params.root, setting=GlobalVars.vali,
+                val_dataset=dataset_class(self.params.root, setting=GlobalVar.vali,
                                            ),
                # TEST DATASET
-                test_dataset=TemplateDataset(self.params.root, setting=GlobalVars.test,
+                test_dataset=dataset_class(self.params.root, setting=GlobalVar.test,
                                            ),
            )

View File

@@ -1,6 +1,6 @@
from argparse import Namespace

-from utils.config import Config
+from ml_lib.utils.config import Config


class GlobalVar(Namespace):
@@ -18,13 +18,16 @@ class GlobalVar(Namespace):
    DPI = 50

    # DATAOPTIONS
-    train='train',
-    vali='vali',
-    test='test'
+    train ='train',
+    vali ='vali',
+    test ='test'

+from models import *


class ThisConfig(Config):

    @property
    def _model_map(self):
-        return dict()
+        return dict(PN2=PointNet2)
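Note: `_model_map` is what ties the new `--model_type` default of "PN2" (see the parameter diff above) to an actual class. The lookup itself happens inside ml_lib's Config, which is not part of this commit, so the sketch below only shows the mapping:

# Sketch only: how the "PN2" string is meant to resolve to a model class.
from models import PointNet2

model_class = dict(PN2=PointNet2)['PN2']   # mirrors ThisConfig._model_map
# model = model_class(hparams)             # hparams: the parsed arguments; their construction is not shown here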