Classes Fixed abnd debugging
This commit is contained in:
@ -13,7 +13,7 @@ import torch
|
||||
from torch_geometric.data import InMemoryDataset
|
||||
from torch_geometric.data import Data
|
||||
|
||||
from utils.project_settings import Classes, DataSplit, ClusterTypes
|
||||
from utils.project_settings import classesAll, classesPolyAsPlane, dataSplit, clusterTypes
|
||||
|
||||
|
||||
def save_names(name_list, path):
|
||||
@ -34,11 +34,11 @@ class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
@property
|
||||
def modes(self):
|
||||
return {key: val for val, key in DataSplit().items()}
|
||||
return {key: val for val, key in dataSplit.items()}
|
||||
|
||||
@property
|
||||
def cluster_types(self):
|
||||
return {key: val for val, key in ClusterTypes().items()}
|
||||
return {key: val for val, key in clusterTypes.items()}
|
||||
|
||||
@property
|
||||
def raw_dir(self):
|
||||
@ -62,8 +62,8 @@ class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
# Set the Dataset Parameters
|
||||
self.cluster_type = cluster_type if cluster_type else 'pc'
|
||||
self.classes = Classes()
|
||||
self.poly_as_plane = poly_as_plane
|
||||
self.classes = classesAll if not self.poly_as_plane else classesPolyAsPlane
|
||||
self.collate_per_segment = collate_per_segment
|
||||
self.mode = mode
|
||||
self.refresh = refresh
|
||||
@ -92,10 +92,10 @@ class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2)
|
||||
return len(self.categories)
|
||||
|
||||
@property
|
||||
def class_map_all(self):
|
||||
def _class_map_all(self):
|
||||
return {0: 0,
|
||||
1: 1,
|
||||
2: None,
|
||||
@ -107,7 +107,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
}
|
||||
|
||||
@property
|
||||
def class_map_poly_as_plane(self):
|
||||
def _class_map_poly_as_plane(self):
|
||||
return {0: 0,
|
||||
1: 1,
|
||||
2: None,
|
||||
@ -118,11 +118,15 @@ class CustomShapeNet(InMemoryDataset):
|
||||
7: None
|
||||
}
|
||||
|
||||
@property
|
||||
def class_remap(self):
|
||||
return self._class_map_all if not self.poly_as_plane else self._class_map_poly_as_plane
|
||||
|
||||
def _load_dataset(self):
|
||||
data, slices = None, None
|
||||
filepath = self.processed_paths[0]
|
||||
config_path = Path(filepath).parent / f'{self.mode}_params.ini'
|
||||
if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict:
|
||||
if config_path.exists() and not self.refresh and not self.mode == dataSplit.predict:
|
||||
with config_path.open('rb') as f:
|
||||
config = pickle.load(f)
|
||||
if config == self._build_config():
|
||||
@ -155,7 +159,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
break
|
||||
self.process()
|
||||
continue
|
||||
if not self.mode == DataSplit().predict:
|
||||
if not self.mode == dataSplit.predict:
|
||||
config = self._build_config()
|
||||
with config_path.open('wb') as f:
|
||||
pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
|
||||
@ -178,7 +182,6 @@ class CustomShapeNet(InMemoryDataset):
|
||||
datasets = defaultdict(list)
|
||||
path_to_clouds = self.raw_dir / self.mode
|
||||
found_clouds = list(path_to_clouds.glob('*.xyz'))
|
||||
class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane
|
||||
if len(found_clouds):
|
||||
for pointcloud in tqdm(found_clouds):
|
||||
if self.cluster_type not in pointcloud.name:
|
||||
@ -196,21 +199,28 @@ class CustomShapeNet(InMemoryDataset):
|
||||
raise ValueError('Check the Input!!!!!!')
|
||||
# Expand the values from the csv by fake labels if non are provided.
|
||||
vals = vals + [0] * (8 - len(vals))
|
||||
vals[-2] = float(class_map[int(vals[-2])])
|
||||
vals[-2] = float(self.class_remap[int(vals[-2])])
|
||||
src[vals[-1]].append(vals)
|
||||
|
||||
# Switch from un-pickable Defaultdict to Standard Dict
|
||||
src = dict(src)
|
||||
|
||||
# Transform the Dict[List] to Dict[torch.Tensor]
|
||||
for key, values in src.items():
|
||||
for key, values in list(src.items()):
|
||||
src[key] = torch.tensor(values, dtype=torch.double).squeeze()
|
||||
if src[key].ndim == 2:
|
||||
pass
|
||||
else:
|
||||
del src[key]
|
||||
|
||||
# Screw the Sorting and make it a FullCloud rather than a seperated
|
||||
if not self.collate_per_segment:
|
||||
src = dict(
|
||||
all=torch.cat(tuple(src.values()))
|
||||
)
|
||||
try:
|
||||
src = dict(
|
||||
all=torch.cat(tuple(src.values()))
|
||||
)
|
||||
except RuntimeError:
|
||||
print('debugg')
|
||||
|
||||
# Transform Box and Polytope to Plane if poly_as_plane is set
|
||||
for key, tensor in src.items():
|
||||
@ -274,6 +284,7 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
kwargs.update(dict(root_dir=root_dir, mode=self.mode))
|
||||
# self.npoints = npoints
|
||||
self.dataset = CustomShapeNet(**kwargs)
|
||||
self.classes = self.dataset.classes
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = self.dataset[index]
|
||||
|
Reference in New Issue
Block a user