6D prediction files now working
This commit is contained in:
parent
965b805ee9
commit
2a7a236b89
@ -23,11 +23,10 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
|
|||||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||||
main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
|
main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
|
||||||
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||||
|
main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
|
||||||
main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
|
main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
|
||||||
main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
|
main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
|
||||||
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=False, help="")
|
||||||
main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
|
|
||||||
main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
|
|
||||||
main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=True, help="")
|
||||||
|
|
||||||
# Transformations
|
# Transformations
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
@ -96,19 +97,32 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
def _load_dataset(self):
|
def _load_dataset(self):
|
||||||
data, slices = None, None
|
data, slices = None, None
|
||||||
filepath = self.processed_paths[0]
|
filepath = self.processed_paths[0]
|
||||||
|
config_path = Path(filepath).parent / f'{self.mode}_params.ini'
|
||||||
|
if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict:
|
||||||
|
with config_path.open('rb') as f:
|
||||||
|
config = pickle.load(f)
|
||||||
|
if config == self._build_config():
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print('The given data parameters seem to differ from the one used to process the dataset:')
|
||||||
|
self.refresh = True
|
||||||
if self.refresh:
|
if self.refresh:
|
||||||
try:
|
try:
|
||||||
os.remove(filepath)
|
os.remove(filepath)
|
||||||
|
try:
|
||||||
|
config_path.unlink()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
print('Processed Location "Refreshed" (We deleted the Files)')
|
print('Processed Location "Refreshed" (We deleted the Files)')
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print('You meant to refresh the allready processed dataset, but there were none...')
|
print('The allready processed dataset was meant to be refreshed, but there was none...')
|
||||||
print('continue processing')
|
print('continue processing')
|
||||||
pass
|
pass
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
data, slices = torch.load(filepath)
|
data, slices = torch.load(filepath)
|
||||||
print('Dataset Loaded')
|
print(f'{self.mode.title()}-Dataset Loaded')
|
||||||
break
|
break
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
status = self.check_and_resolve_cloud_count()
|
status = self.check_and_resolve_cloud_count()
|
||||||
@ -117,8 +131,18 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
break
|
break
|
||||||
self.process()
|
self.process()
|
||||||
continue
|
continue
|
||||||
|
if not self.mode == DataSplit().predict:
|
||||||
|
config = self._build_config()
|
||||||
|
with config_path.open('wb') as f:
|
||||||
|
pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
|
||||||
return data, slices
|
return data, slices
|
||||||
|
|
||||||
|
def _build_config(self):
|
||||||
|
conf_dict = {key:str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
|
||||||
|
'classes', 'refresh', 'transform', 'data', 'slices'
|
||||||
|
]}
|
||||||
|
return conf_dict
|
||||||
|
|
||||||
def _pre_transform_and_filter(self, data):
|
def _pre_transform_and_filter(self, data):
|
||||||
if self.pre_filter is not None and not self.pre_filter(data):
|
if self.pre_filter is not None and not self.pre_filter(data):
|
||||||
data = self.pre_filter(data)
|
data = self.pre_filter(data)
|
||||||
@ -129,76 +153,83 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
def process(self, delimiter=' '):
|
def process(self, delimiter=' '):
|
||||||
datasets = defaultdict(list)
|
datasets = defaultdict(list)
|
||||||
path_to_clouds = self.raw_dir / self.mode
|
path_to_clouds = self.raw_dir / self.mode
|
||||||
for pointcloud in tqdm(path_to_clouds.glob('*.xyz')):
|
found_clouds = list(path_to_clouds.glob('*.xyz'))
|
||||||
if self.cluster_type not in pointcloud.name:
|
if len(found_clouds):
|
||||||
continue
|
for pointcloud in tqdm(found_clouds):
|
||||||
data = None
|
if self.cluster_type not in pointcloud.name:
|
||||||
|
|
||||||
with pointcloud.open('r') as f:
|
|
||||||
src = defaultdict(list)
|
|
||||||
# Iterate over all rows
|
|
||||||
for row in f:
|
|
||||||
if row != '':
|
|
||||||
vals = row.rstrip().split(delimiter)[None:None]
|
|
||||||
vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
|
|
||||||
src[vals[-1]].append(vals)
|
|
||||||
|
|
||||||
# Switch from un-pickable Defaultdict to Standard Dict
|
|
||||||
src = dict(src)
|
|
||||||
|
|
||||||
# Transform the Dict[List] to Dict[torch.Tensor]
|
|
||||||
for key, values in src.items():
|
|
||||||
src[key] = torch.tensor(values, dtype=torch.double).squeeze()
|
|
||||||
|
|
||||||
# Screw the Sorting and make it a FullCloud rather than a seperated
|
|
||||||
if not self.collate_per_segment:
|
|
||||||
src = dict(
|
|
||||||
all=torch.cat(tuple(src.values()))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Transform Box and Polytope to Plane if poly_as_plane is set
|
|
||||||
for key, tensor in src.items():
|
|
||||||
if tensor.ndim == 1:
|
|
||||||
if all([x == 0 for x in tensor]):
|
|
||||||
continue
|
|
||||||
tensor = tensor.unsqueeze(0)
|
|
||||||
if self.poly_as_plane:
|
|
||||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
|
|
||||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
|
|
||||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
|
|
||||||
tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
|
|
||||||
src[key] = tensor
|
|
||||||
|
|
||||||
for key, values in src.items():
|
|
||||||
try:
|
|
||||||
points = values[:, :-2]
|
|
||||||
except IndexError:
|
|
||||||
continue
|
continue
|
||||||
y = torch.as_tensor(values[:, -2], dtype=torch.long)
|
data = None
|
||||||
y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
|
|
||||||
####################################
|
|
||||||
# This is where you define the keys
|
|
||||||
attr_dict = dict(
|
|
||||||
y=y,
|
|
||||||
y_c=y_c,
|
|
||||||
pos=points[:, :3],
|
|
||||||
norm=points[:, 3:6]
|
|
||||||
)
|
|
||||||
|
|
||||||
####################################
|
with pointcloud.open('r') as f:
|
||||||
if self.collate_per_segment:
|
src = defaultdict(list)
|
||||||
data = Data(**attr_dict)
|
# Iterate over all rows
|
||||||
else:
|
for row in f:
|
||||||
if data is None:
|
if row != '':
|
||||||
data = defaultdict(list)
|
vals = row.rstrip().split(delimiter)[None:None]
|
||||||
for attr_key, val in attr_dict.items():
|
vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
|
||||||
data[attr_key].append(val)
|
if len(vals) < 6:
|
||||||
|
raise ValueError('Check the Input!!!!!!')
|
||||||
|
# Expand the values from the csv by fake labels if non are provided.
|
||||||
|
vals = vals + [0] * (8 - len(vals))
|
||||||
|
|
||||||
# data = self._pre_transform_and_filter(data)
|
src[vals[-1]].append(vals)
|
||||||
if self.collate_per_segment:
|
|
||||||
datasets[self.mode].append(data)
|
# Switch from un-pickable Defaultdict to Standard Dict
|
||||||
if not self.collate_per_segment:
|
src = dict(src)
|
||||||
datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
|
||||||
|
# Transform the Dict[List] to Dict[torch.Tensor]
|
||||||
|
for key, values in src.items():
|
||||||
|
src[key] = torch.tensor(values, dtype=torch.double).squeeze()
|
||||||
|
|
||||||
|
# Screw the Sorting and make it a FullCloud rather than a seperated
|
||||||
|
if not self.collate_per_segment:
|
||||||
|
src = dict(
|
||||||
|
all=torch.cat(tuple(src.values()))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Transform Box and Polytope to Plane if poly_as_plane is set
|
||||||
|
for key, tensor in src.items():
|
||||||
|
if tensor.ndim == 1:
|
||||||
|
if all([x == 0 for x in tensor]):
|
||||||
|
continue
|
||||||
|
tensor = tensor.unsqueeze(0)
|
||||||
|
if self.poly_as_plane:
|
||||||
|
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
|
||||||
|
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
|
||||||
|
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
|
||||||
|
tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
|
||||||
|
src[key] = tensor
|
||||||
|
|
||||||
|
for key, values in src.items():
|
||||||
|
try:
|
||||||
|
points = values[:, :-2]
|
||||||
|
except IndexError:
|
||||||
|
continue
|
||||||
|
y = torch.as_tensor(values[:, -2], dtype=torch.long)
|
||||||
|
y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
|
||||||
|
####################################
|
||||||
|
# This is where you define the keys
|
||||||
|
attr_dict = dict(
|
||||||
|
y=y,
|
||||||
|
y_c=y_c,
|
||||||
|
pos=points[:, :3],
|
||||||
|
norm=points[:, 3:6]
|
||||||
|
)
|
||||||
|
|
||||||
|
####################################
|
||||||
|
if self.collate_per_segment:
|
||||||
|
data = Data(**attr_dict)
|
||||||
|
else:
|
||||||
|
if data is None:
|
||||||
|
data = defaultdict(list)
|
||||||
|
for attr_key, val in attr_dict.items():
|
||||||
|
data[attr_key].append(val)
|
||||||
|
|
||||||
|
# data = self._pre_transform_and_filter(data)
|
||||||
|
if self.collate_per_segment:
|
||||||
|
datasets[self.mode].append(data)
|
||||||
|
if not self.collate_per_segment:
|
||||||
|
datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||||
|
|
||||||
if datasets[self.mode]:
|
if datasets[self.mode]:
|
||||||
os.makedirs(self.processed_dir, exist_ok=True)
|
os.makedirs(self.processed_dir, exist_ok=True)
|
||||||
|
@ -54,9 +54,9 @@ def predict_prim_type(input_pc, model):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
input_pc_path = Path('data') / 'pc' / 'test.xyz'
|
# input_pc_path = Path('data') / 'pc' / 'test.xyz'
|
||||||
|
|
||||||
model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0'
|
model_path = Path('output') / 'PN2' / 'PN_14628b734c5b651b013ad9e36c406934' / 'version_0'
|
||||||
# config_filename = 'config.ini'
|
# config_filename = 'config.ini'
|
||||||
# config = ThisConfig()
|
# config = ThisConfig()
|
||||||
# config.read_file((Path(model_path) / config_filename).open('r'))
|
# config.read_file((Path(model_path) / config_filename).open('r'))
|
||||||
@ -72,7 +72,7 @@ if __name__ == '__main__':
|
|||||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
||||||
refresh=True, transform=transforms)
|
refresh=True, transform=transforms)
|
||||||
|
|
||||||
grid_clusters = cluster_cubes(test_dataset[1], [1, 1, 1], max_points_per_cluster=32768)
|
grid_clusters = cluster_cubes(test_dataset[0], [1, 1, 1], max_points_per_cluster=8192)
|
||||||
|
|
||||||
ps.init()
|
ps.init()
|
||||||
|
|
||||||
|
@ -24,13 +24,13 @@ class _PointNetCore(LightningBaseModule, ABC):
|
|||||||
self.cord_dims = 6 if self.params.normals_as_cords else 3
|
self.cord_dims = 6 if self.params.normals_as_cords else 3
|
||||||
|
|
||||||
# Modules
|
# Modules
|
||||||
self.sa1_module = SAModule(0.2, 0.2, MLP([self.cord_dims, 64, 64, 128]))
|
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
|
||||||
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
|
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
|
||||||
self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)
|
self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)
|
||||||
|
|
||||||
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
|
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
|
||||||
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
|
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
|
||||||
self.fp1_module = FPModule(3, MLP([128, 128, 128, 128]))
|
self.fp1_module = FPModule(3, MLP([128 + (3 if not self.params.normals_as_cords else 0), 128, 128, 128]))
|
||||||
|
|
||||||
self.lin1 = torch.nn.Linear(128, 128)
|
self.lin1 = torch.nn.Linear(128, 128)
|
||||||
self.lin2 = torch.nn.Linear(128, 128)
|
self.lin2 = torch.nn.Linear(128, 128)
|
||||||
|
@ -75,9 +75,12 @@ def write_pointcloud(file, pc, numCols=6):
|
|||||||
|
|
||||||
|
|
||||||
def farthest_point_sampling(pts, K):
|
def farthest_point_sampling(pts, K):
|
||||||
if isinstance(pts, Data):
|
if K > 0:
|
||||||
pts = pts.pos.numpy()
|
if isinstance(pts, Data):
|
||||||
if pts.shape[0] < K:
|
pts = pts.pos.numpy()
|
||||||
|
if pts.shape[0] < K:
|
||||||
|
return pts
|
||||||
|
else:
|
||||||
return pts
|
return pts
|
||||||
|
|
||||||
def calc_distances(p0, points):
|
def calc_distances(p0, points):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user