New dataset for per-spatial-cluster training
parent 821b2d1961
commit 23f3aa878d
@@ -22,6 +22,7 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 # Data Parameters
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
+main_arg_parser.add_argument("--data_dataset_type", type=str, default='GridClusters', help="")
 main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
 
 # Transformations
@@ -41,7 +42,7 @@ main_arg_parser.add_argument("--train_opt_reset_interval", type=strtobool, defau
 
 # Model
 # Possible Model arguments are: P2P, PN2, P2G
-main_arg_parser.add_argument("--model_type", type=str, default="P2G", help="")
+main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="")
 main_arg_parser.add_argument("--model_norm_as_feature", type=strtobool, default=True, help="")
 
 main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
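A quick sanity check of the new defaults (a sketch; importing `main_arg_parser` from `_parameters` is an assumption inferred from the `from _parameters import args` line in main.py below):

    from _parameters import main_arg_parser  # assumed module-level parser

    args = main_arg_parser.parse_args([])            # no CLI overrides
    assert args.data_dataset_type == 'GridClusters'  # argument added in this commit
    assert args.model_type == 'PN2'                  # default switched from 'P2G'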
@@ -4,7 +4,7 @@ from collections import defaultdict
 from abc import ABC
 from pathlib import Path
 
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, ConcatDataset
 from ml_lib.point_toolset.sampling import FarthestpointSampling, RandomSampling
 
 import numpy as np
@@ -12,6 +12,10 @@ import numpy as np
 
 class _Point_Dataset(ABC, Dataset):
 
+    @property
+    def name(self):
+        raise NotImplementedError
+
     @property
     def sample_shape(self):
         # FixMe: This does not work when more than x/y tuples are returned
@@ -9,6 +9,7 @@ from ._point_dataset import _Point_Dataset
 class FullCloudsDataset(_Point_Dataset):
 
     split: str
+    name = 'FullCloudsDataset'
 
     def __init__(self, *args, setting='pc', **kwargs):
         self.setting = setting
datasets/grid_clusters.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+import pickle
+from collections import defaultdict
+
+import numpy as np
+from torch.utils.data import ConcatDataset
+from tqdm import trange
+
+from ._point_dataset import _Point_Dataset
+
+
+class GridClusters(_Point_Dataset):
+
+    split: str
+    name = 'GridClusters'
+
+    def __init__(self, *args, n_spatial_clusters=3*3*3, setting='pc', **kwargs):
+        self.n_spatial_clusters = n_spatial_clusters
+        self.setting = setting
+        super(GridClusters, self).__init__(*args, **kwargs)
+
+    def __len__(self):
+        return len(self._files)
+
+    def _read_or_load(self, item):
+        raw_file_path = self._files[item]
+        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)
+
+        if not self.load_preprocessed:
+            processed_file_path.unlink(missing_ok=True)
+        if not processed_file_path.exists():
+            # nested default dict
+            pointcloud = defaultdict(lambda: defaultdict(list))
+
+            with raw_file_path.open('r') as raw_file:
+                for row in raw_file:
+                    values = [float(x) for x in row.strip().split(' ')]
+                    for header, value in zip(self.headers, values):
+                        pointcloud[int(values[-1])][header].append(value)
+            for cluster in pointcloud.keys():
+                for key in pointcloud[cluster].keys():
+                    pointcloud[cluster][key] = np.asarray(pointcloud[cluster][key])
+                pointcloud[cluster] = dict(pointcloud[cluster])
+            pointcloud = dict(pointcloud)
+
+            with processed_file_path.open('wb') as processed_file:
+                pickle.dump(pointcloud, processed_file)
+        return processed_file_path
+
+    def __getitem__(self, item):
+        processed_file_path = self._read_or_load(item)
+
+        with processed_file_path.open('rb') as processed_file:
+            pointcloud = pickle.load(processed_file)
+
+        # By number Variant
+        # cl_idx_list = np.cumsum([[len(self) // self.n_spatial_clusters, ] * self.n_spatial_clusters])
+        # cl_idx = [idx for idx, x in enumerate(cl_idx_list) if item <= x][0]
+
+        # Random Variant
+        cl_idx = np.random.randint(0, len(pointcloud))
+        pointcloud = pointcloud[list(pointcloud.keys())[cl_idx]]
+
+        position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
+
+        normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
+
+        label = pointcloud['label']
+
+        cl_label = pointcloud['cl_idx']
+
+        sample_idxs = self.sampling(position)
+        while sample_idxs.shape[0] < self.sampling_k:
+            sample_idxs = np.concatenate((sample_idxs, sample_idxs))[:self.sampling_k]
+
+        return (normal[sample_idxs].astype(float),
+                position[sample_idxs].astype(float),
+                label[sample_idxs].astype(int),
+                cl_label[sample_idxs].astype(int)
+                )
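A minimal usage sketch for the new dataset (the _Point_Dataset constructor is not part of this diff, so the positional data root and the exact keyword names here are assumptions):

    from torch.utils.data import DataLoader
    from datasets.grid_clusters import GridClusters

    dataset = GridClusters('data', setting='pc', n_spatial_clusters=3*3*3)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    # every access picks one spatial cluster at random and returns fixed-size
    # (normal, position, label, cl_label) arrays of length sampling_k
    normal, position, label, cl_label = next(iter(loader))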
main.py (6 lines changed)
@@ -25,8 +25,9 @@ def run_lightning_loop(config_obj):
     # =============================================================================
     # Checkpoint Saving
     checkpoint_callback = ModelCheckpoint(
+        monitor='mean_loss',
         filepath=str(logger.log_dir / 'ckpt_weights'),
-        verbose=True, save_top_k=0,
+        verbose=True, save_top_k=10,
     )
 
     # =============================================================================
@@ -80,6 +81,9 @@ if __name__ == "__main__":
     from _parameters import args
     from ml_lib.utils.tools import fix_all_random_seeds
 
+    # When debugging, use the following parameters:
+    # --main_debug=True --data_worker=0
+
     config = ThisConfig.read_namespace(args)
     fix_all_random_seeds(config)
     trained_model = run_lightning_loop(config)
@@ -32,7 +32,7 @@ class _PointNetCore(LightningBaseModule):
 
     def forward(self, sa0_out, **kwargs):
         """
-        data: a batch of input torch_geometric.data.Data type
+        sa0_out: a batch of input torch_geometric.data.Data type
             - torch_geometric.data.Data, as torch_geometric batch input:
                 data.x: (batch_size * ~num_points, C), batch nodes/points feature,
                         ~num_points means each sample can have a different number of points/nodes
@@ -3,7 +3,7 @@ from argparse import Namespace
 import torch
 from torch import nn
 
-from datasets.full_pointclouds import FullCloudsDataset
+from datasets.grid_clusters import GridClusters
 from models._point_net_2 import _PointNetCore
 
 from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
@@ -23,7 +23,7 @@ class PointNet2(BaseValMixin,
 
         # Dataset
         # =============================================================================
-        self.dataset = self.build_dataset(FullCloudsDataset, setting='pc')
+        self.dataset = self.build_dataset(GridClusters, setting='pc')
 
         # Model Parameters
         # =============================================================================
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from torch_geometric.data import Data
 
-from datasets.full_pointclouds import FullCloudsDataset
+from datasets.grid_clusters import GridClusters
 from models._point_net_2 import _PointNetCore
 
 from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
@@ -42,7 +42,7 @@ class PointNet2GridClusters(BaseValMixin,
 
         # Dataset
         # =============================================================================
-        self.dataset = self.build_dataset(FullCloudsDataset, setting='grid')
+        self.dataset = self.build_dataset(GridClusters, setting='grid')
 
         # Model Parameters
         # =============================================================================
@@ -16,11 +16,11 @@ if __name__ == '__main__':
     # Model Settings
     config = ThisConfig().read_namespace(args)
     # bias, activation, model, norm, max_epochs
-    pn2 = dict(model_type='PN2',model_use_bias=True, model_use_norm=True, data_batchsize=250)
-    p2g = dict(model_type='P2G', model_use_bias=True, model_use_norm=True, data_batchsize=250)
+    pn2 = dict(model_type='PN2', model_use_bias=True, model_use_norm=True, data_batchsize=250)
+    # p2g = dict(model_type='P2G', model_use_bias=True, model_use_norm=True, data_batchsize=250)
     # bias, activation, model, norm, max_epochs
 
-    for arg_dict in [p2g]:
+    for arg_dict in [pn2]:
         for seed in range(10):
             arg_dict.update(main_seed=seed)
 
@@ -222,6 +222,9 @@ class DatasetMixin:
 
     def build_dataset(self, dataset_class, **kwargs):
         assert isinstance(self, LightningBaseModule)
+        assert dataset_class.name == self.params.dataset_type, f'Check the dataset! ' + \
+                                                               f'Expected was {self.params.dataset_type}, ' + \
+                                                               f'given: {dataset_class.name}'
 
         # Dataset
         # =============================================================================
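The new assertion makes build_dataset fail fast when the configured --data_dataset_type does not match the dataset class, which is why both FullCloudsDataset and GridClusters gain a class-level name in this commit. A minimal sketch of the check (the `configured` value stands in for self.params.dataset_type):

    from datasets.full_pointclouds import FullCloudsDataset
    from datasets.grid_clusters import GridClusters

    configured = 'GridClusters'                  # stand-in for self.params.dataset_type
    assert GridClusters.name == configured       # passes
    assert FullCloudsDataset.name == configured  # AssertionError: fail fast on a mismatch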
@@ -258,7 +261,7 @@ class BaseDataloadersMixin(ABC):
         # In case you want to implement bootstrapping
         # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
         sampler = None
-        return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
+        return DataLoader(dataset=self.dataset.train_dataset, shuffle=False if not sampler else None, sampler=sampler,
                           batch_size=self.params.batch_size,
                           num_workers=self.params.worker)
 
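For context: torch's DataLoader raises a ValueError when shuffle=True is combined with an explicit sampler, hence the `... if not sampler else None` guard. A minimal sketch of the two modes (`train_dataset` is a placeholder for self.dataset.train_dataset):

    from torch.utils.data import DataLoader, RandomSampler

    sampler = None
    # bootstrap variant, as in the comment above:
    # sampler = RandomSampler(train_dataset, replacement=True, num_samples=len(train_dataset))
    loader = DataLoader(train_dataset,
                        shuffle=False if not sampler else None,  # shuffle is only legal without a sampler
                        sampler=sampler)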