6D prediction files now working

This commit is contained in:
Si11ium 2020-06-25 12:03:08 +02:00
parent 965b805ee9
commit 2a7a236b89
5 changed files with 113 additions and 80 deletions

View File

@ -23,11 +23,10 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="") main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="") main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="") main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="") main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=True, help="")
# Transformations # Transformations

View File

@ -1,3 +1,4 @@
import pickle
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
from warnings import warn from warnings import warn
@ -96,19 +97,32 @@ class CustomShapeNet(InMemoryDataset):
def _load_dataset(self): def _load_dataset(self):
data, slices = None, None data, slices = None, None
filepath = self.processed_paths[0] filepath = self.processed_paths[0]
config_path = Path(filepath).parent / f'{self.mode}_params.ini'
if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict:
with config_path.open('rb') as f:
config = pickle.load(f)
if config == self._build_config():
pass
else:
print('The given data parameters seem to differ from the one used to process the dataset:')
self.refresh = True
if self.refresh: if self.refresh:
try: try:
os.remove(filepath) os.remove(filepath)
try:
config_path.unlink()
except:
pass
print('Processed Location "Refreshed" (We deleted the Files)') print('Processed Location "Refreshed" (We deleted the Files)')
except FileNotFoundError: except FileNotFoundError:
print('You meant to refresh the allready processed dataset, but there were none...') print('The allready processed dataset was meant to be refreshed, but there was none...')
print('continue processing') print('continue processing')
pass pass
while True: while True:
try: try:
data, slices = torch.load(filepath) data, slices = torch.load(filepath)
print('Dataset Loaded') print(f'{self.mode.title()}-Dataset Loaded')
break break
except FileNotFoundError: except FileNotFoundError:
status = self.check_and_resolve_cloud_count() status = self.check_and_resolve_cloud_count()
@ -117,8 +131,18 @@ class CustomShapeNet(InMemoryDataset):
break break
self.process() self.process()
continue continue
if not self.mode == DataSplit().predict:
config = self._build_config()
with config_path.open('wb') as f:
pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
return data, slices return data, slices
def _build_config(self):
conf_dict = {key:str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
'classes', 'refresh', 'transform', 'data', 'slices'
]}
return conf_dict
def _pre_transform_and_filter(self, data): def _pre_transform_and_filter(self, data):
if self.pre_filter is not None and not self.pre_filter(data): if self.pre_filter is not None and not self.pre_filter(data):
data = self.pre_filter(data) data = self.pre_filter(data)
@ -129,76 +153,83 @@ class CustomShapeNet(InMemoryDataset):
def process(self, delimiter=' '): def process(self, delimiter=' '):
datasets = defaultdict(list) datasets = defaultdict(list)
path_to_clouds = self.raw_dir / self.mode path_to_clouds = self.raw_dir / self.mode
for pointcloud in tqdm(path_to_clouds.glob('*.xyz')): found_clouds = list(path_to_clouds.glob('*.xyz'))
if self.cluster_type not in pointcloud.name: if len(found_clouds):
continue for pointcloud in tqdm(found_clouds):
data = None if self.cluster_type not in pointcloud.name:
with pointcloud.open('r') as f:
src = defaultdict(list)
# Iterate over all rows
for row in f:
if row != '':
vals = row.rstrip().split(delimiter)[None:None]
vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
src[vals[-1]].append(vals)
# Switch from un-pickable Defaultdict to Standard Dict
src = dict(src)
# Transform the Dict[List] to Dict[torch.Tensor]
for key, values in src.items():
src[key] = torch.tensor(values, dtype=torch.double).squeeze()
# Screw the Sorting and make it a FullCloud rather than a seperated
if not self.collate_per_segment:
src = dict(
all=torch.cat(tuple(src.values()))
)
# Transform Box and Polytope to Plane if poly_as_plane is set
for key, tensor in src.items():
if tensor.ndim == 1:
if all([x == 0 for x in tensor]):
continue
tensor = tensor.unsqueeze(0)
if self.poly_as_plane:
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
src[key] = tensor
for key, values in src.items():
try:
points = values[:, :-2]
except IndexError:
continue continue
y = torch.as_tensor(values[:, -2], dtype=torch.long) data = None
y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
####################################
# This is where you define the keys
attr_dict = dict(
y=y,
y_c=y_c,
pos=points[:, :3],
norm=points[:, 3:6]
)
#################################### with pointcloud.open('r') as f:
if self.collate_per_segment: src = defaultdict(list)
data = Data(**attr_dict) # Iterate over all rows
else: for row in f:
if data is None: if row != '':
data = defaultdict(list) vals = row.rstrip().split(delimiter)[None:None]
for attr_key, val in attr_dict.items(): vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
data[attr_key].append(val) if len(vals) < 6:
raise ValueError('Check the Input!!!!!!')
# Expand the values from the csv by fake labels if non are provided.
vals = vals + [0] * (8 - len(vals))
# data = self._pre_transform_and_filter(data) src[vals[-1]].append(vals)
if self.collate_per_segment:
datasets[self.mode].append(data) # Switch from un-pickable Defaultdict to Standard Dict
if not self.collate_per_segment: src = dict(src)
datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
# Transform the Dict[List] to Dict[torch.Tensor]
for key, values in src.items():
src[key] = torch.tensor(values, dtype=torch.double).squeeze()
# Screw the Sorting and make it a FullCloud rather than a seperated
if not self.collate_per_segment:
src = dict(
all=torch.cat(tuple(src.values()))
)
# Transform Box and Polytope to Plane if poly_as_plane is set
for key, tensor in src.items():
if tensor.ndim == 1:
if all([x == 0 for x in tensor]):
continue
tensor = tensor.unsqueeze(0)
if self.poly_as_plane:
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
src[key] = tensor
for key, values in src.items():
try:
points = values[:, :-2]
except IndexError:
continue
y = torch.as_tensor(values[:, -2], dtype=torch.long)
y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
####################################
# This is where you define the keys
attr_dict = dict(
y=y,
y_c=y_c,
pos=points[:, :3],
norm=points[:, 3:6]
)
####################################
if self.collate_per_segment:
data = Data(**attr_dict)
else:
if data is None:
data = defaultdict(list)
for attr_key, val in attr_dict.items():
data[attr_key].append(val)
# data = self._pre_transform_and_filter(data)
if self.collate_per_segment:
datasets[self.mode].append(data)
if not self.collate_per_segment:
datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
if datasets[self.mode]: if datasets[self.mode]:
os.makedirs(self.processed_dir, exist_ok=True) os.makedirs(self.processed_dir, exist_ok=True)

View File

@ -54,9 +54,9 @@ def predict_prim_type(input_pc, model):
if __name__ == '__main__': if __name__ == '__main__':
input_pc_path = Path('data') / 'pc' / 'test.xyz' # input_pc_path = Path('data') / 'pc' / 'test.xyz'
model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0' model_path = Path('output') / 'PN2' / 'PN_14628b734c5b651b013ad9e36c406934' / 'version_0'
# config_filename = 'config.ini' # config_filename = 'config.ini'
# config = ThisConfig() # config = ThisConfig()
# config.read_file((Path(model_path) / config_filename).open('r')) # config.read_file((Path(model_path) / config_filename).open('r'))
@ -72,7 +72,7 @@ if __name__ == '__main__':
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False, test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
refresh=True, transform=transforms) refresh=True, transform=transforms)
grid_clusters = cluster_cubes(test_dataset[1], [1, 1, 1], max_points_per_cluster=32768) grid_clusters = cluster_cubes(test_dataset[0], [1, 1, 1], max_points_per_cluster=8192)
ps.init() ps.init()

View File

@ -24,13 +24,13 @@ class _PointNetCore(LightningBaseModule, ABC):
self.cord_dims = 6 if self.params.normals_as_cords else 3 self.cord_dims = 6 if self.params.normals_as_cords else 3
# Modules # Modules
self.sa1_module = SAModule(0.2, 0.2, MLP([self.cord_dims, 64, 64, 128])) self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256])) self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims) self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256])) self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128])) self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
self.fp1_module = FPModule(3, MLP([128, 128, 128, 128])) self.fp1_module = FPModule(3, MLP([128 + (3 if not self.params.normals_as_cords else 0), 128, 128, 128]))
self.lin1 = torch.nn.Linear(128, 128) self.lin1 = torch.nn.Linear(128, 128)
self.lin2 = torch.nn.Linear(128, 128) self.lin2 = torch.nn.Linear(128, 128)

View File

@ -75,9 +75,12 @@ def write_pointcloud(file, pc, numCols=6):
def farthest_point_sampling(pts, K): def farthest_point_sampling(pts, K):
if isinstance(pts, Data): if K > 0:
pts = pts.pos.numpy() if isinstance(pts, Data):
if pts.shape[0] < K: pts = pts.pos.numpy()
if pts.shape[0] < K:
return pts
else:
return pts return pts
def calc_distances(p0, points): def calc_distances(p0, points):