6D prediction files now working

Si11ium
2020-06-25 12:03:08 +02:00
parent 965b805ee9
commit 2a7a236b89
5 changed files with 113 additions and 80 deletions

View File

@@ -23,11 +23,10 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
 main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
+main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
 main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
-main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
+main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=True, help="")
 # Transformations
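
For reference, a minimal sketch of how the strtobool-typed flags touched in this hunk (e.g. --data_normals_as_cords, --data_refresh) behave; the reduced parser below is hypothetical and only mirrors those two arguments:

from argparse import ArgumentParser
from distutils.util import strtobool

# Hypothetical, reduced parser mirroring two of the flags above.
parser = ArgumentParser()
parser.add_argument("--data_normals_as_cords", type=strtobool, default=False)
parser.add_argument("--data_refresh", type=strtobool, default=False)

# strtobool maps 'true'/'yes'/'1' to 1 and 'false'/'no'/'0' to 0,
# so explicit values arrive as ints while untouched flags keep their bool default.
args = parser.parse_args(['--data_normals_as_cords', 'true'])
print(bool(args.data_normals_as_cords), bool(args.data_refresh))  # True False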

View File

@@ -1,3 +1,4 @@
+import pickle
 from pathlib import Path
 from typing import Union
 from warnings import warn
@@ -96,19 +97,32 @@ class CustomShapeNet(InMemoryDataset):
     def _load_dataset(self):
         data, slices = None, None
         filepath = self.processed_paths[0]
+        config_path = Path(filepath).parent / f'{self.mode}_params.ini'
+        if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict:
+            with config_path.open('rb') as f:
+                config = pickle.load(f)
+            if config == self._build_config():
+                pass
+            else:
+                print('The given data parameters seem to differ from the ones used to process the dataset:')
+                self.refresh = True
         if self.refresh:
             try:
                 os.remove(filepath)
+                try:
+                    config_path.unlink()
+                except FileNotFoundError:
+                    pass
                 print('Processed Location "Refreshed" (We deleted the Files)')
             except FileNotFoundError:
-                print('You meant to refresh the allready processed dataset, but there were none...')
+                print('The already processed dataset was meant to be refreshed, but there was none...')
                 print('continue processing')
                 pass
         while True:
             try:
                 data, slices = torch.load(filepath)
-                print('Dataset Loaded')
+                print(f'{self.mode.title()}-Dataset Loaded')
                 break
             except FileNotFoundError:
                 status = self.check_and_resolve_cloud_count()
@@ -117,8 +131,18 @@ class CustomShapeNet(InMemoryDataset):
                        break
                    self.process()
                    continue
+        if not self.mode == DataSplit().predict:
+            config = self._build_config()
+            with config_path.open('wb') as f:
+                pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
         return data, slices
+
+    def _build_config(self):
+        conf_dict = {key: str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
+            'classes', 'refresh', 'transform', 'data', 'slices'
+        ]}
+        return conf_dict

     def _pre_transform_and_filter(self, data):
         if self.pre_filter is not None and not self.pre_filter(data):
             data = self.pre_filter(data)
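
The gist of the new snapshot caching in _load_dataset / _build_config, as a standalone sketch: a pickled dict of the data parameters sits next to the processed file, and a mismatch on the next load forces a refresh. The params dict and cache path below are made up; the real code snapshots a filtered copy of the instance's __dict__.

import pickle
from pathlib import Path

def needs_refresh(params: dict, config_path: Path) -> bool:
    # True when a stored snapshot exists and differs from the current parameters.
    if not config_path.exists():
        return False
    with config_path.open('rb') as f:
        stored = pickle.load(f)
    return stored != {k: str(v) for k, v in params.items()}

def store_config(params: dict, config_path: Path) -> None:
    with config_path.open('wb') as f:
        pickle.dump({k: str(v) for k, v in params.items()}, f, pickle.HIGHEST_PROTOCOL)

# Usage with made-up parameters and a local cache file:
params = {'npoints': 1024, 'cluster_type': 'grid', 'poly_as_plane': True}
cache = Path('train_params.ini')
store_config(params, cache)
print(needs_refresh(params, cache))                        # False
print(needs_refresh({**params, 'npoints': 2048}, cache))   # True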
@@ -129,7 +153,9 @@ class CustomShapeNet(InMemoryDataset):
     def process(self, delimiter=' '):
         datasets = defaultdict(list)
         path_to_clouds = self.raw_dir / self.mode
-        for pointcloud in tqdm(path_to_clouds.glob('*.xyz')):
+        found_clouds = list(path_to_clouds.glob('*.xyz'))
+        if len(found_clouds):
+            for pointcloud in tqdm(found_clouds):
                 if self.cluster_type not in pointcloud.name:
                     continue
                 data = None
@@ -141,6 +167,11 @@ class CustomShapeNet(InMemoryDataset):
                     if row != '':
                         vals = row.rstrip().split(delimiter)[None:None]
                         vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
+                        if len(vals) < 6:
+                            raise ValueError('Expected at least 6 values (xyz + normals) per row, check the input!')
+                        # Pad the values from the csv with fake labels if none are provided.
+                        vals = vals + [0] * (8 - len(vals))
                         src[vals[-1]].append(vals)
         # Switch from un-pickable Defaultdict to Standard Dict
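
A self-contained sketch of the row padding added above, using made-up row strings instead of the repository's .xyz files: each row must carry at least xyz plus normals (6 values) and is padded to 8 columns so the last column can always serve as the label:

# Hypothetical rows: xyz + normals, with and without trailing label columns.
rows = [
    '0.1 0.2 0.3 0.0 0.0 1.0',      # no labels, gets padded with fake label 0
    '0.4 0.5 0.6 0.0 1.0 0.0 2 7',  # already carries two label columns
]

parsed = []
for row in rows:
    vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in row.split()]
    if len(vals) < 6:
        raise ValueError('Expected at least 6 values (xyz + normals) per row.')
    vals = vals + [0] * (8 - len(vals))  # pad with fake labels up to 8 columns
    parsed.append(vals)

print([len(v) for v in parsed])  # [8, 8]
print([v[-1] for v in parsed])   # last column used as label: [0, 7.0]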

View File

@@ -54,9 +54,9 @@ def predict_prim_type(input_pc, model):
 if __name__ == '__main__':
-    input_pc_path = Path('data') / 'pc' / 'test.xyz'
-    model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0'
+    # input_pc_path = Path('data') / 'pc' / 'test.xyz'
+    model_path = Path('output') / 'PN2' / 'PN_14628b734c5b651b013ad9e36c406934' / 'version_0'
     # config_filename = 'config.ini'
     # config = ThisConfig()
     # config.read_file((Path(model_path) / config_filename).open('r'))
@@ -72,7 +72,7 @@ if __name__ == '__main__':
     test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
                                           refresh=True, transform=transforms)
-    grid_clusters = cluster_cubes(test_dataset[1], [1, 1, 1], max_points_per_cluster=32768)
+    grid_clusters = cluster_cubes(test_dataset[0], [1, 1, 1], max_points_per_cluster=8192)
     ps.init()
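
A rough stand-in for the cube-clustering step, assuming cluster_cubes splits the cloud into a regular grid over its bounding box and caps the number of points per cell; the helper below is a simplified sketch, not the repository's implementation:

import numpy as np

def cluster_cubes_sketch(points, grid=(1, 1, 1), max_points_per_cluster=8192, seed=0):
    # Assign each point to a grid cell over the bounding box, then subsample
    # cells that exceed max_points_per_cluster.
    rng = np.random.default_rng(seed)
    mins, maxs = points.min(axis=0), points.max(axis=0)
    spans = np.where(maxs > mins, maxs - mins, 1.0)
    idx = np.minimum(((points - mins) / spans * np.array(grid)).astype(int),
                     np.array(grid) - 1)  # clip points on the upper faces into the last cell
    flat = idx[:, 0] * grid[1] * grid[2] + idx[:, 1] * grid[2] + idx[:, 2]
    clusters = []
    for cell in np.unique(flat):
        members = np.flatnonzero(flat == cell)
        if members.size > max_points_per_cluster:
            members = rng.choice(members, max_points_per_cluster, replace=False)
        clusters.append(points[members])
    return clusters

clouds = cluster_cubes_sketch(np.random.rand(20000, 3), grid=(1, 1, 1))
print(len(clouds), clouds[0].shape)  # 1 cluster capped at 8192 points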

View File

@@ -24,13 +24,13 @@ class _PointNetCore(LightningBaseModule, ABC):
         self.cord_dims = 6 if self.params.normals_as_cords else 3
         # Modules
-        self.sa1_module = SAModule(0.2, 0.2, MLP([self.cord_dims, 64, 64, 128]))
+        self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
         self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
         self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)
         self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
         self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
-        self.fp1_module = FPModule(3, MLP([128, 128, 128, 128]))
+        self.fp1_module = FPModule(3, MLP([128 + (3 if not self.params.normals_as_cords else 0), 128, 128, 128]))
         self.lin1 = torch.nn.Linear(128, 128)
         self.lin2 = torch.nn.Linear(128, 128)
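
A small sketch of the channel bookkeeping behind these MLP widths (plain Python, no torch_geometric); the helper is hypothetical, but the numbers mirror how cord_dims and the normal features feed the modules above:

def pointnet2_mlp_input_widths(normals_as_cords: bool):
    # First-layer input width of each SA/FP MLP in the block above.
    cord_dims = 6 if normals_as_cords else 3     # xyz (+ normals) used as coordinates
    feature_dims = 0 if normals_as_cords else 3  # otherwise the normals enter as per-point features
    return {
        'sa1': 3 + 3,               # position + normals always give 6 input channels
        'sa2': 128 + cord_dims,
        'sa3': 256 + cord_dims,
        'fp3': 1024 + 256,
        'fp2': 256 + 128,
        'fp1': 128 + feature_dims,  # skip connection back to the raw input features
    }

print(pointnet2_mlp_input_widths(normals_as_cords=True))
print(pointnet2_mlp_input_widths(normals_as_cords=False))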

View File

@@ -75,10 +75,13 @@ def write_pointcloud(file, pc, numCols=6):
 def farthest_point_sampling(pts, K):
+    if K > 0:
         if isinstance(pts, Data):
             pts = pts.pos.numpy()
         if pts.shape[0] < K:
             return pts
+    else:
+        return pts

     def calc_distances(p0, points):
         return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
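
For context, a minimal NumPy sketch of farthest point sampling with the same early exits as above (K <= 0 or fewer points than K returns the input unchanged); a simplified stand-in, not the repository's farthest_point_sampling:

import numpy as np

def farthest_point_sampling_sketch(pts, K):
    # Greedy farthest point sampling on the first three columns (xyz).
    if K <= 0 or pts.shape[0] <= K:
        return pts
    chosen = np.zeros(K, dtype=int)  # start from the first point
    dists = ((pts[chosen[0], :3] - pts[:, :3]) ** 2).sum(axis=1)
    for i in range(1, K):
        chosen[i] = int(dists.argmax())
        dists = np.minimum(dists, ((pts[chosen[i], :3] - pts[:, :3]) ** 2).sum(axis=1))
    return pts[chosen]

sampled = farthest_point_sampling_sketch(np.random.rand(1000, 6), K=64)
print(sampled.shape)  # (64, 6)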