DataSet Modifications & Checks
This commit is contained in:
parent
3c1202d5b6
commit
e9d0591b11
@ -94,6 +94,30 @@ class CustomShapeNet(InMemoryDataset):
|
||||
def num_classes(self):
|
||||
return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2)
|
||||
|
||||
@property
|
||||
def class_map_all(self):
|
||||
return {0: 0,
|
||||
1: 1,
|
||||
2: None,
|
||||
3: 2,
|
||||
4: 3,
|
||||
5: None,
|
||||
6: 4,
|
||||
7: None
|
||||
}
|
||||
|
||||
@property
|
||||
def class_map_poly_as_plane(self):
|
||||
return {0: 0,
|
||||
1: 1,
|
||||
2: None,
|
||||
3: 2,
|
||||
4: 2,
|
||||
5: None,
|
||||
6: 2,
|
||||
7: None
|
||||
}
|
||||
|
||||
def _load_dataset(self):
|
||||
data, slices = None, None
|
||||
filepath = self.processed_paths[0]
|
||||
@ -154,6 +178,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
datasets = defaultdict(list)
|
||||
path_to_clouds = self.raw_dir / self.mode
|
||||
found_clouds = list(path_to_clouds.glob('*.xyz'))
|
||||
class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane
|
||||
if len(found_clouds):
|
||||
for pointcloud in tqdm(found_clouds):
|
||||
if self.cluster_type not in pointcloud.name:
|
||||
@ -171,7 +196,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
raise ValueError('Check the Input!!!!!!')
|
||||
# Expand the values from the csv by fake labels if non are provided.
|
||||
vals = vals + [0] * (8 - len(vals))
|
||||
|
||||
vals[-2] = float(class_map[int(vals[-2])])
|
||||
src[vals[-1]].append(vals)
|
||||
|
||||
# Switch from un-pickable Defaultdict to Standard Dict
|
||||
@ -193,10 +218,6 @@ class CustomShapeNet(InMemoryDataset):
|
||||
if all([x == 0 for x in tensor]):
|
||||
continue
|
||||
tensor = tensor.unsqueeze(0)
|
||||
if self.poly_as_plane:
|
||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 2.0
|
||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 2.0
|
||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 2.0
|
||||
src[key] = tensor
|
||||
|
||||
for key, values in src.items():
|
||||
|
2
main.py
2
main.py
@ -54,7 +54,7 @@ def run_lightning_loop(config_obj):
|
||||
show_progress_bar=True,
|
||||
weights_save_path=logger.log_dir,
|
||||
gpus=[0] if torch.cuda.is_available() else None,
|
||||
check_val_every_n_epoch=10,
|
||||
check_val_every_n_epoch=2,
|
||||
# num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
|
||||
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
||||
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
||||
|
@ -25,7 +25,7 @@ class PointNet2(BaseValMixin,
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
# rot_max_angle = 15
|
||||
trans_max_distance = 0.02
|
||||
trans_max_distance = 0.01
|
||||
transforms = Compose(
|
||||
[
|
||||
RandomFlip(0, p=0.8),
|
||||
@ -54,7 +54,6 @@ class PointNet2(BaseValMixin,
|
||||
self.n_classes = len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2)
|
||||
|
||||
# Modules
|
||||
self.point_net_core = ()
|
||||
self.lin3 = torch.nn.Linear(128, self.n_classes)
|
||||
|
||||
# Utility
|
||||
|
Loading…
x
Reference in New Issue
Block a user