New dataset for per-spatial-cluster training

Si11ium 2020-06-09 14:08:35 +02:00
parent 821b2d1961
commit 23f3aa878d
10 changed files with 104 additions and 12 deletions

View File

@@ -22,6 +22,7 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 # Data Parameters
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
+main_arg_parser.add_argument("--data_dataset_type", type=str, default='GridClusters', help="")
 main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
 # Transformations
@@ -41,7 +42,7 @@ main_arg_parser.add_argument("--train_opt_reset_interval", type=strtobool, defau
 # Model
 # Possible Model arguments are: P2P, PN2, P2G
-main_arg_parser.add_argument("--model_type", type=str, default="P2G", help="")
+main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="")
 main_arg_parser.add_argument("--model_norm_as_feature", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
 from abc import ABC
 from pathlib import Path
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, ConcatDataset
 from ml_lib.point_toolset.sampling import FarthestpointSampling, RandomSampling
 import numpy as np
@@ -12,6 +12,10 @@ import numpy as np
 class _Point_Dataset(ABC, Dataset):
+    @property
+    def name(self):
+        raise NotImplementedError
+
     @property
     def sample_shape(self):
         # FixMe: This does not work when more then x/y tuples are returned
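
The base class declares name as a property that raises NotImplementedError; the subclasses below satisfy it with a plain class attribute. That works because the subclass attribute shadows the inherited property and is readable on the class itself, which is exactly what the build_dataset assert relies on. A minimal sketch (classes simplified to the relevant piece):

# Why a plain class attribute satisfies the abstract `name` property.
class _Point_Dataset:
    @property
    def name(self):
        raise NotImplementedError

class GridClusters(_Point_Dataset):
    name = 'GridClusters'   # class attribute shadows the inherited property

print(GridClusters.name)    # 'GridClusters', accessible on the class itself,
                            # without ever instantiating the dataset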

View File

@@ -9,6 +9,7 @@ from ._point_dataset import _Point_Dataset
 class FullCloudsDataset(_Point_Dataset):
     split: str
+    name = 'FullCloudsDataset'
     def __init__(self, *args, setting='pc', **kwargs):
         self.setting = setting

datasets/grid_clusters.py (new file, 79 additions)
View File

@@ -0,0 +1,79 @@
import pickle
from collections import defaultdict

import numpy as np
from torch.utils.data import ConcatDataset
from tqdm import trange

from ._point_dataset import _Point_Dataset


class GridClusters(_Point_Dataset):

    split: str
    name = 'GridClusters'

    def __init__(self, *args, n_spatial_clusters=3*3*3, setting='pc', **kwargs):
        self.n_spatial_clusters = n_spatial_clusters
        self.setting = setting
        super(GridClusters, self).__init__(*args, **kwargs)

    def __len__(self):
        return len(self._files)

    def _read_or_load(self, item):
        raw_file_path = self._files[item]
        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)

        if not self.load_preprocessed:
            processed_file_path.unlink(missing_ok=True)
        if not processed_file_path.exists():
            # Nested default dict: cluster id -> column header -> list of values
            pointcloud = defaultdict(lambda: defaultdict(list))
            with raw_file_path.open('r') as raw_file:
                for row in raw_file:
                    values = [float(x) for x in row.strip().split(' ')]
                    for header, value in zip(self.headers, values):
                        # The last column of each row is the spatial cluster id
                        pointcloud[int(values[-1])][header].append(value)
            for cluster in pointcloud.keys():
                for key in pointcloud[cluster].keys():
                    pointcloud[cluster][key] = np.asarray(pointcloud[cluster][key])
                pointcloud[cluster] = dict(pointcloud[cluster])
            pointcloud = dict(pointcloud)
            with processed_file_path.open('wb') as processed_file:
                pickle.dump(pointcloud, processed_file)
        return processed_file_path

    def __getitem__(self, item):
        processed_file_path = self._read_or_load(item)
        with processed_file_path.open('rb') as processed_file:
            pointcloud = pickle.load(processed_file)

        # By-number variant:
        # cl_idx_list = np.cumsum([[len(self) // self.n_spatial_clusters, ] * self.n_spatial_clusters])
        # cl_idx = [idx for idx, x in enumerate(cl_idx_list) if item <= x][0]
        # Random variant: pick one spatial cluster of this cloud at random
        cl_idx = np.random.randint(0, len(pointcloud))
        pointcloud = pointcloud[list(pointcloud.keys())[cl_idx]]

        position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
        normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
        label = pointcloud['label']
        cl_label = pointcloud['cl_idx']

        sample_idxs = self.sampling(position)
        # Pad by repetition if the cluster holds fewer points than sampling_k
        while sample_idxs.shape[0] < self.sampling_k:
            sample_idxs = np.concatenate((sample_idxs, sample_idxs))[:self.sampling_k]

        return (normal[sample_idxs].astype(np.float),
                position[sample_idxs].astype(np.float),
                label[sample_idxs].astype(np.int),
                cl_label[sample_idxs].astype(np.int)
                )
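
For intuition, here is how the nested-defaultdict grouping in _read_or_load behaves on a couple of raw rows. The header layout and values are made up for illustration; the real header list comes from _Point_Dataset.headers:

from collections import defaultdict

# Hypothetical header layout; the real one is defined by _Point_Dataset.
headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx']
rows = ['0.1 0.2 0.3 0.0 0.0 1.0 2 0',
        '0.4 0.5 0.6 0.0 1.0 0.0 2 1']

pointcloud = defaultdict(lambda: defaultdict(list))
for row in rows:
    values = [float(x) for x in row.strip().split(' ')]
    for header, value in zip(headers, values):
        # Rows are grouped by their last column, the spatial cluster id
        pointcloud[int(values[-1])][header].append(value)

print(sorted(pointcloud.keys()))   # [0, 1]: one entry per spatial cluster
print(pointcloud[0]['x'])          # [0.1]:  per-cluster columns of values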

View File

@@ -25,8 +25,9 @@ def run_lightning_loop(config_obj):
     # =============================================================================
     # Checkpoint Saving
     checkpoint_callback = ModelCheckpoint(
+        monitor='mean_loss',
         filepath=str(logger.log_dir / 'ckpt_weights'),
-        verbose=True, save_top_k=0,
+        verbose=True, save_top_k=10,
     )
     # =============================================================================
@@ -80,6 +81,9 @@ if __name__ == "__main__":
     from _parameters import args
     from ml_lib.utils.tools import fix_all_random_seeds
+    # When debugging, use the following parameters:
+    # --main_debug=True --data_worker=0
+
     config = ThisConfig.read_namespace(args)
     fix_all_random_seeds(config)
     trained_model = run_lightning_loop(config)
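
A note on the checkpoint hunk above: in PyTorch Lightning, save_top_k=0 saves no checkpoints at all, while save_top_k=10 keeps the ten best checkpoints ranked by the newly required monitor metric, which the LightningModule must log as 'mean_loss'. A sketch of the resulting callback; the filepath argument follows the older Lightning API used in this repo (current versions take dirpath/filename), and the path here is illustrative:

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor='mean_loss',   # metric to rank checkpoints by; must be logged in training
    filepath='logs/ckpt_weights',
    verbose=True,
    save_top_k=10,         # previously 0, which disabled checkpoint saving entirely
)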

View File

@@ -32,7 +32,7 @@ class _PointNetCore(LightningBaseModule):
     def forward(self, sa0_out, **kwargs):
         """
-        data: a batch of input torch_geometric.data.Data type
+        sa0_out: a batch of input torch_geometric.data.Data type
             - torch_geometric.data.Data, as torch_geometric batch input:
                 data.x: (batch_size * ~num_points, C), batch nodes/points feature,
                     ~num_points means each sample can have different number of points/nodes

View File

@@ -3,7 +3,7 @@ from argparse import Namespace
 import torch
 from torch import nn
-from datasets.full_pointclouds import FullCloudsDataset
+from datasets.grid_clusters import GridClusters
 from models._point_net_2 import _PointNetCore
 from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
@@ -23,7 +23,7 @@ class PointNet2(BaseValMixin,
         # Dataset
         # =============================================================================
-        self.dataset = self.build_dataset(FullCloudsDataset, setting='pc')
+        self.dataset = self.build_dataset(GridClusters, setting='pc')
         # Model Paramters
         # =============================================================================
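
Both PointNet2 variants now build their dataset from GridClusters; only the setting string differs ('pc' here, 'grid' in the next file). For orientation, a hypothetical direct instantiation; the positional root argument is an assumption, since the actual constructor arguments are defined by _Point_Dataset and filled in by DatasetMixin.build_dataset:

# Hypothetical usage; 'data' as the dataset root is an assumption, and the
# remaining constructor arguments are handled by the _Point_Dataset base class.
from datasets.grid_clusters import GridClusters

train_set = GridClusters('data', setting='pc', n_spatial_clusters=3 * 3 * 3)
normal, position, label, cl_label = train_set[0]  # arrays for one random cluster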

View File

@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from torch_geometric.data import Data
-from datasets.full_pointclouds import FullCloudsDataset
+from datasets.grid_clusters import GridClusters
 from models._point_net_2 import _PointNetCore
 from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
@@ -42,7 +42,7 @@ class PointNet2GridClusters(BaseValMixin,
         # Dataset
         # =============================================================================
-        self.dataset = self.build_dataset(FullCloudsDataset, setting='grid')
+        self.dataset = self.build_dataset(GridClusters, setting='grid')
         # Model Paramters
         # =============================================================================

View File

@@ -16,11 +16,11 @@ if __name__ == '__main__':
     # Model Settings
     config = ThisConfig().read_namespace(args)
     # bias, activation, model, norm, max_epochs
-    pn2 = dict(model_type='PN2',model_use_bias=True, model_use_norm=True, data_batchsize=250)
-    p2g = dict(model_type='P2G', model_use_bias=True, model_use_norm=True, data_batchsize=250)
+    pn2 = dict(model_type='PN2', model_use_bias=True, model_use_norm=True, data_batchsize=250)
+    # p2g = dict(model_type='P2G', model_use_bias=True, model_use_norm=True, data_batchsize=250)
     # bias, activation, model, norm, max_epochs
-    for arg_dict in [p2g]:
+    for arg_dict in [pn2]:
         for seed in range(10):
             arg_dict.update(main_seed=seed)
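
The sweep now covers only the PN2 configuration, repeated over ten seeds; since arg_dict.update(main_seed=seed) mutates the same dict in place, successive runs differ only in main_seed. A self-contained sketch of what the loop enumerates (config plumbing omitted):

# Sketch of the run matrix produced by the loop above.
pn2 = dict(model_type='PN2', model_use_bias=True, model_use_norm=True, data_batchsize=250)

for arg_dict in [pn2]:
    for seed in range(10):
        arg_dict.update(main_seed=seed)
        print(arg_dict['model_type'], arg_dict['main_seed'])  # PN2 0 ... PN2 9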

View File

@@ -222,6 +222,9 @@ class DatasetMixin:
     def build_dataset(self, dataset_class, **kwargs):
         assert isinstance(self, LightningBaseModule)
+        assert dataset_class.name == self.params.dataset_type, f'Check the dataset! ' + \
+                                                               f'Expected was {self.params.dataset_type}, ' + \
+                                                               f'given:{dataset_class.name}'
         # Dataset
         # =============================================================================
@@ -258,7 +261,7 @@ class BaseDataloadersMixin(ABC):
         # In case you want to implement bootstraping
         # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
         sampler = None
-        return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
+        return DataLoader(dataset=self.dataset.train_dataset, shuffle=False if not sampler else None, sampler=sampler,
                           batch_size=self.params.batch_size,
                           num_workers=self.params.worker)
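
One subtlety in the changed DataLoader call: with sampler = None, not sampler is truthy, so training now runs with shuffle=False; if a sampler were supplied, shuffle would be None and the sampler would control ordering. Plausibly the loader-level shuffle is redundant now that GridClusters already draws a random cluster per item, though the commit does not say so. The expression in isolation:

# The ternary from the DataLoader call, evaluated for both cases.
sampler = None
print(False if not sampler else None)   # False: no sampler, shuffling disabled

class DummySampler:                     # hypothetical stand-in for a real Sampler
    pass

sampler = DummySampler()
print(False if not sampler else None)   # None: the sampler decides the order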