Classes Fixed and debugging

Si11ium
2020-07-03 14:40:28 +02:00
parent e9d0591b11
commit 5353220890
10 changed files with 66 additions and 59 deletions

@@ -13,7 +13,7 @@ import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
from utils.project_settings import Classes, DataSplit, ClusterTypes
from utils.project_settings import classesAll, classesPolyAsPlane, dataSplit, clusterTypes
def save_names(name_list, path):
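Note on the hunk above: the lowercase names suggest that utils/project_settings now exports pre-instantiated settings objects, so call sites drop the constructor call (DataSplit().predict becomes dataSplit.predict). A minimal sketch of what that module might look like; apart from dataSplit.predict, the 'pc' cluster type and the box/polytope/plane hint in a later comment, every name and value below is an assumption, not taken from this diff.

class _Settings:
    # Tiny stand-in: attribute access plus .items(), mirroring how the diff uses these objects.
    def __init__(self, **mapping):
        self.__dict__.update(mapping)

    def items(self):
        return self.__dict__.items()

# Module-level instances, so callers write dataSplit.predict instead of DataSplit().predict.
dataSplit = _Settings(train='train', test='test', predict='predict')   # split names assumed
clusterTypes = _Settings(prim='prim', pc='pc')                          # 'pc' appears in the diff
classesAll = _Settings(box=0, sphere=1, polytope=2)                     # indices made up
classesPolyAsPlane = _Settings(box=0, sphere=1, plane=2)                # indices made up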
@@ -34,11 +34,11 @@ class CustomShapeNet(InMemoryDataset):
@property
def modes(self):
return {key: val for val, key in DataSplit().items()}
return {key: val for val, key in dataSplit.items()}
@property
def cluster_types(self):
return {key: val for val, key in ClusterTypes().items()}
return {key: val for val, key in clusterTypes.items()}
@property
def raw_dir(self):
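Note on the modes/cluster_types properties above: the rewritten comprehension inverts whatever mapping .items() yields. A stand-in illustration (the real contents of dataSplit are not visible in this diff):

# The comprehension unpacks each pair as (val, key), i.e. it swaps keys and values.
mapping = {'train': 0, 'test': 1, 'predict': 2}           # stand-in contents
inverted = {key: val for val, key in mapping.items()}
assert inverted == {0: 'train', 1: 'test', 2: 'predict'}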
@@ -62,8 +62,8 @@ class CustomShapeNet(InMemoryDataset):
# Set the Dataset Parameters
self.cluster_type = cluster_type if cluster_type else 'pc'
self.classes = Classes()
self.poly_as_plane = poly_as_plane
self.classes = classesAll if not self.poly_as_plane else classesPolyAsPlane
self.collate_per_segment = collate_per_segment
self.mode = mode
self.refresh = refresh
@@ -92,10 +92,10 @@ class CustomShapeNet(InMemoryDataset):
@property
def num_classes(self):
return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2)
return len(self.categories)
@property
def class_map_all(self):
def _class_map_all(self):
return {0: 0,
1: 1,
2: None,
@@ -107,7 +107,7 @@ class CustomShapeNet(InMemoryDataset):
}
@property
def class_map_poly_as_plane(self):
def _class_map_poly_as_plane(self):
return {0: 0,
1: 1,
2: None,
@@ -118,11 +118,15 @@ class CustomShapeNet(InMemoryDataset):
7: None
}
@property
def class_remap(self):
return self._class_map_all if not self.poly_as_plane else self._class_map_poly_as_plane
def _load_dataset(self):
data, slices = None, None
filepath = self.processed_paths[0]
config_path = Path(filepath).parent / f'{self.mode}_params.ini'
if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict:
if config_path.exists() and not self.refresh and not self.mode == dataSplit.predict:
with config_path.open('rb') as f:
config = pickle.load(f)
if config == self._build_config():
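Note on the new class_remap property: it selects one of the two private label maps depending on poly_as_plane, and entries that map to None presumably mark classes that are merged away or skipped. A small illustration; the map values for classes 3 to 6 are assumptions, only 0, 1, 2 and 7 are visible in this diff.

_class_map_all = {0: 0, 1: 1, 2: None, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7}                  # 3-7 assumed
_class_map_poly_as_plane = {0: 0, 1: 1, 2: None, 3: 2, 4: 2, 5: 2, 6: None, 7: None}  # 3-6 assumed

poly_as_plane = True
class_remap = _class_map_all if not poly_as_plane else _class_map_poly_as_plane

raw_labels = [0, 3, 1]
remapped = [class_remap[lbl] for lbl in raw_labels]   # -> [0, 2, 1]
# Labels that remap to None must not reach the float(...) cast in process(),
# since float(None) raises a TypeError.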
@@ -155,7 +159,7 @@ class CustomShapeNet(InMemoryDataset):
break
self.process()
continue
if not self.mode == DataSplit().predict:
if not self.mode == dataSplit.predict:
config = self._build_config()
with config_path.open('wb') as f:
pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
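Note on the _load_dataset changes: the method compares a pickled parameter snapshot against self._build_config() and reprocesses when they differ, when refresh is set, or in predict mode. A minimal sketch of that cache-invalidation pattern, assuming _build_config() returns a plain picklable dict of the dataset parameters (its body is not shown in this diff):

import pickle
from pathlib import Path

def load_or_reprocess(config_path: Path, build_config, process):
    # Reuse processed data only if the stored parameter snapshot still matches.
    if config_path.exists():
        with config_path.open('rb') as f:
            cached = pickle.load(f)
        if cached == build_config():
            return True                      # cached data is still valid
    process()                                # parameters changed or no snapshot yet
    with config_path.open('wb') as f:
        pickle.dump(build_config(), f, pickle.HIGHEST_PROTOCOL)
    return False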
@@ -178,7 +182,6 @@ class CustomShapeNet(InMemoryDataset):
datasets = defaultdict(list)
path_to_clouds = self.raw_dir / self.mode
found_clouds = list(path_to_clouds.glob('*.xyz'))
class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane
if len(found_clouds):
for pointcloud in tqdm(found_clouds):
if self.cluster_type not in pointcloud.name:
@@ -196,21 +199,28 @@ class CustomShapeNet(InMemoryDataset):
raise ValueError('Check the Input!!!!!!')
# Expand the values from the csv with fake labels if none are provided.
vals = vals + [0] * (8 - len(vals))
vals[-2] = float(class_map[int(vals[-2])])
vals[-2] = float(self.class_remap[int(vals[-2])])
src[vals[-1]].append(vals)
# Switch from the un-picklable defaultdict to a standard dict
src = dict(src)
# Transform the Dict[List] to Dict[torch.Tensor]
for key, values in src.items():
for key, values in list(src.items()):
src[key] = torch.tensor(values, dtype=torch.double).squeeze()
if src[key].ndim == 2:
pass
else:
del src[key]
# Screw the sorting and make it a FullCloud rather than a separated one
if not self.collate_per_segment:
src = dict(
all=torch.cat(tuple(src.values()))
)
try:
src = dict(
all=torch.cat(tuple(src.values()))
)
except RuntimeError:
print('debugg')
# Transform Box and Polytope to Plane if poly_as_plane is set
for key, tensor in src.items():
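Note on the list(src.items()) change above: entries whose tensor is not 2-dimensional are deleted from src inside the loop, and deleting from a dict while iterating it directly raises "RuntimeError: dictionary changed size during iteration". Iterating over a snapshot avoids that; a minimal illustration (tensor shapes are made up):

import torch

src = {'seg_a': torch.zeros(4, 8), 'seg_b': torch.zeros(8)}   # 'seg_b' has ndim == 1

# list(src.items()) takes a snapshot, so deleting from src inside the loop is safe.
for key, values in list(src.items()):
    if values.ndim != 2:
        del src[key]

assert list(src) == ['seg_a']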
@@ -274,6 +284,7 @@ class ShapeNetPartSegDataset(Dataset):
kwargs.update(dict(root_dir=root_dir, mode=self.mode))
# self.npoints = npoints
self.dataset = CustomShapeNet(**kwargs)
self.classes = self.dataset.classes
def __getitem__(self, index):
data = self.dataset[index]
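Note on the last hunk: ShapeNetPartSegDataset now re-exports the class collection of the wrapped CustomShapeNet, so downstream code no longer has to reach into the inner dataset. Hypothetical usage; apart from root_dir, mode and poly_as_plane, the constructor arguments and values below are assumptions:

dataset = ShapeNetPartSegDataset(root_dir='data', mode='train', poly_as_plane=True)
print(dataset.classes)     # same object as dataset.dataset.classes after this commit
sample = dataset[0]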