6D prediction files now working
@@ -1,3 +1,4 @@
+import pickle
 from pathlib import Path
 from typing import Union
 from warnings import warn
@@ -96,19 +97,32 @@ class CustomShapeNet(InMemoryDataset):
     def _load_dataset(self):
         data, slices = None, None
         filepath = self.processed_paths[0]
+        config_path = Path(filepath).parent / f'{self.mode}_params.ini'
+        if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict:
+            with config_path.open('rb') as f:
+                config = pickle.load(f)
+            if config == self._build_config():
+                pass
+            else:
+                print('The given data parameters seem to differ from the ones used to process the dataset:')
+                self.refresh = True
         if self.refresh:
             try:
                 os.remove(filepath)
+                try:
+                    config_path.unlink()
+                except FileNotFoundError:
+                    pass
                 print('Processed Location "Refreshed" (We deleted the Files)')
             except FileNotFoundError:
-                print('You meant to refresh the already processed dataset, but there was none...')
+                print('The already processed dataset was meant to be refreshed, but there was none...')
                 print('continue processing')
                 pass
 
         while True:
             try:
                 data, slices = torch.load(filepath)
-                print('Dataset Loaded')
+                print(f'{self.mode.title()}-Dataset Loaded')
                 break
             except FileNotFoundError:
                 status = self.check_and_resolve_cloud_count()
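
Note: the load loop above follows a simple load-or-process retry pattern. A minimal standalone sketch of that pattern (load_or_process and its max_tries parameter are illustrative, not part of this commit; the real loop also consults check_and_resolve_cloud_count() before reprocessing):

import torch

def load_or_process(filepath, process, max_tries=2):
    # Try the processed file first; on a cache miss, (re)build it and retry.
    for _ in range(max_tries):
        try:
            return torch.load(filepath)  # yields the (data, slices) pair stored on disk
        except FileNotFoundError:
            process()  # expected to write `filepath`
    raise FileNotFoundError(f'could not produce {filepath}')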
@@ -117,8 +131,18 @@ class CustomShapeNet(InMemoryDataset):
                     break
                 self.process()
                 continue
+        if not self.mode == DataSplit().predict:
+            config = self._build_config()
+            with config_path.open('wb') as f:
+                pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
         return data, slices
 
+    def _build_config(self):
+        conf_dict = {key: str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
+            'classes', 'refresh', 'transform', 'data', 'slices'
+        ]}
+        return conf_dict
+
     def _pre_transform_and_filter(self, data):
         if self.pre_filter is not None and not self.pre_filter(data):
             data = self.pre_filter(data)
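
Taken together, the two hunks above add a parameter fingerprint next to the processed files: _build_config() stringifies the dataset attributes, the fingerprint is pickled on save, and a mismatch on load forces a refresh. A self-contained sketch of the idea (CachedProcessor and its attribute names are invented for illustration; unlink(missing_ok=True) needs Python 3.8+):

import pickle
from pathlib import Path

class CachedProcessor:
    def __init__(self, processed_file, refresh=False):
        self.processed_file = Path(processed_file)
        self.refresh = refresh

    def _build_config(self):
        # Stringify every attribute that influences processing.
        return {key: str(val) for key, val in self.__dict__.items()
                if '__' not in key and key != 'refresh'}

    def load(self):
        config_path = self.processed_file.parent / 'params.ini'
        if config_path.exists() and not self.refresh:
            with config_path.open('rb') as f:
                if pickle.load(f) != self._build_config():
                    self.refresh = True  # parameters changed, cache is stale
        if self.refresh:
            self.processed_file.unlink(missing_ok=True)
            config_path.unlink(missing_ok=True)
        # ... (re)process and load self.processed_file here ...
        with config_path.open('wb') as f:
            pickle.dump(self._build_config(), f, pickle.HIGHEST_PROTOCOL)

Excluding volatile attributes such as refresh (and, in the hunk, classes, transform, data and slices) keeps the fingerprint stable across runs.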
@@ -129,76 +153,83 @@ class CustomShapeNet(InMemoryDataset):
     def process(self, delimiter=' '):
         datasets = defaultdict(list)
         path_to_clouds = self.raw_dir / self.mode
-        for pointcloud in tqdm(path_to_clouds.glob('*.xyz')):
-            if self.cluster_type not in pointcloud.name:
-                continue
-            data = None
-
-            with pointcloud.open('r') as f:
-                src = defaultdict(list)
-                # Iterate over all rows
-                for row in f:
-                    if row != '':
-                        vals = row.rstrip().split(delimiter)[None:None]
-                        vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
-                        src[vals[-1]].append(vals)
-
-            # Switch from un-picklable Defaultdict to Standard Dict
-            src = dict(src)
-
-            # Transform the Dict[List] to Dict[torch.Tensor]
-            for key, values in src.items():
-                src[key] = torch.tensor(values, dtype=torch.double).squeeze()
-
-            # Screw the sorting and make it a FullCloud rather than a separated one
-            if not self.collate_per_segment:
-                src = dict(
-                    all=torch.cat(tuple(src.values()))
-                )
-
-            # Transform Box and Polytope to Plane if poly_as_plane is set
-            for key, tensor in src.items():
-                if tensor.ndim == 1:
-                    if all([x == 0 for x in tensor]):
-                        continue
-                    tensor = tensor.unsqueeze(0)
-                if self.poly_as_plane:
-                    tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
-                    tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
-                    tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
-                    tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
-                src[key] = tensor
-
-            for key, values in src.items():
-                try:
-                    points = values[:, :-2]
-                except IndexError:
-                    continue
-                y = torch.as_tensor(values[:, -2], dtype=torch.long)
-                y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
-                ####################################
-                # This is where you define the keys
-                attr_dict = dict(
-                    y=y,
-                    y_c=y_c,
-                    pos=points[:, :3],
-                    norm=points[:, 3:6]
-                )
-
-                ####################################
-                if self.collate_per_segment:
-                    data = Data(**attr_dict)
-                else:
-                    if data is None:
-                        data = defaultdict(list)
-                    for attr_key, val in attr_dict.items():
-                        data[attr_key].append(val)
-
-            # data = self._pre_transform_and_filter(data)
-            if self.collate_per_segment:
-                datasets[self.mode].append(data)
-            if not self.collate_per_segment:
-                datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
+        found_clouds = list(path_to_clouds.glob('*.xyz'))
+        if len(found_clouds):
+            for pointcloud in tqdm(found_clouds):
+                if self.cluster_type not in pointcloud.name:
+                    continue
+                data = None
+
+                with pointcloud.open('r') as f:
+                    src = defaultdict(list)
+                    # Iterate over all rows
+                    for row in f:
+                        if row != '':
+                            vals = row.rstrip().split(delimiter)[None:None]
+                            vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
+                            if len(vals) < 6:
+                                raise ValueError('Check the Input!!!!!!')
+                            # Expand the values from the csv by fake labels if none are provided.
+                            vals = vals + [0] * (8 - len(vals))
+                            src[vals[-1]].append(vals)
+
+                # Switch from un-picklable Defaultdict to Standard Dict
+                src = dict(src)
+
+                # Transform the Dict[List] to Dict[torch.Tensor]
+                for key, values in src.items():
+                    src[key] = torch.tensor(values, dtype=torch.double).squeeze()
+
+                # Screw the sorting and make it a FullCloud rather than a separated one
+                if not self.collate_per_segment:
+                    src = dict(
+                        all=torch.cat(tuple(src.values()))
+                    )
+
+                # Transform Box and Polytope to Plane if poly_as_plane is set
+                for key, tensor in src.items():
+                    if tensor.ndim == 1:
+                        if all([x == 0 for x in tensor]):
+                            continue
+                        tensor = tensor.unsqueeze(0)
+                    if self.poly_as_plane:
+                        tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
+                        tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
+                        tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
+                        tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
+                    src[key] = tensor
+
+                for key, values in src.items():
+                    try:
+                        points = values[:, :-2]
+                    except IndexError:
+                        continue
+                    y = torch.as_tensor(values[:, -2], dtype=torch.long)
+                    y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
+                    ####################################
+                    # This is where you define the keys
+                    attr_dict = dict(
+                        y=y,
+                        y_c=y_c,
+                        pos=points[:, :3],
+                        norm=points[:, 3:6]
+                    )
+
+                    ####################################
+                    if self.collate_per_segment:
+                        data = Data(**attr_dict)
+                    else:
+                        if data is None:
+                            data = defaultdict(list)
+                        for attr_key, val in attr_dict.items():
+                            data[attr_key].append(val)
+
+                # data = self._pre_transform_and_filter(data)
+                if self.collate_per_segment:
+                    datasets[self.mode].append(data)
+                if not self.collate_per_segment:
+                    datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
 
         if datasets[self.mode]:
             os.makedirs(self.processed_dir, exist_ok=True)
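
The row parser in the new process() is what makes the prediction files work: prediction clouds carry only six columns (position and normal), so the missing class and cluster labels are zero-padded. A sketch of that parsing as a standalone helper (parse_xyz_rows is a hypothetical name, not part of the commit):

import torch
from collections import defaultdict

def parse_xyz_rows(rows, delimiter=' '):
    # Each row: x y z nx ny nz [class cluster]; short rows get zero-padded labels.
    src = defaultdict(list)
    for row in rows:
        if not row.strip():
            continue
        vals = [0.0 if x in ('-nan(ind)', 'nan(ind)') else float(x)
                for x in row.rstrip().split(delimiter)]
        if len(vals) < 6:
            raise ValueError('every row needs at least a position and a normal')
        vals = vals + [0.0] * (8 - len(vals))  # fake labels for unlabeled prediction clouds
        src[vals[-1]].append(vals)             # group rows by cluster id (last column)
    return {key: torch.tensor(v, dtype=torch.double) for key, v in src.items()}

Called as parse_xyz_rows(cloud_path.open('r')), this returns per-cluster (N, 8) double tensors whose last two columns are the class and cluster labels, mirroring what process() feeds into Data().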