DataSet Modifications & Checks

This commit is contained in:
Si11ium 2020-07-02 08:58:02 +02:00
parent 3c1202d5b6
commit e9d0591b11
3 changed files with 28 additions and 8 deletions

View File

@ -94,6 +94,30 @@ class CustomShapeNet(InMemoryDataset):
def num_classes(self):
return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2)
@property
def class_map_all(self):
return {0: 0,
1: 1,
2: None,
3: 2,
4: 3,
5: None,
6: 4,
7: None
}
@property
def class_map_poly_as_plane(self):
return {0: 0,
1: 1,
2: None,
3: 2,
4: 2,
5: None,
6: 2,
7: None
}
def _load_dataset(self):
data, slices = None, None
filepath = self.processed_paths[0]
@ -154,6 +178,7 @@ class CustomShapeNet(InMemoryDataset):
datasets = defaultdict(list)
path_to_clouds = self.raw_dir / self.mode
found_clouds = list(path_to_clouds.glob('*.xyz'))
class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane
if len(found_clouds):
for pointcloud in tqdm(found_clouds):
if self.cluster_type not in pointcloud.name:
@ -171,7 +196,7 @@ class CustomShapeNet(InMemoryDataset):
raise ValueError('Check the Input!!!!!!')
# Expand the values from the csv by fake labels if non are provided.
vals = vals + [0] * (8 - len(vals))
vals[-2] = float(class_map[int(vals[-2])])
src[vals[-1]].append(vals)
# Switch from un-pickable Defaultdict to Standard Dict
@ -193,10 +218,6 @@ class CustomShapeNet(InMemoryDataset):
if all([x == 0 for x in tensor]):
continue
tensor = tensor.unsqueeze(0)
if self.poly_as_plane:
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 2.0
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 2.0
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 2.0
src[key] = tensor
for key, values in src.items():

View File

@ -54,7 +54,7 @@ def run_lightning_loop(config_obj):
show_progress_bar=True,
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
check_val_every_n_epoch=10,
check_val_every_n_epoch=2,
# num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting

View File

@ -25,7 +25,7 @@ class PointNet2(BaseValMixin,
# Dataset
# =============================================================================
# rot_max_angle = 15
trans_max_distance = 0.02
trans_max_distance = 0.01
transforms = Compose(
[
RandomFlip(0, p=0.8),
@ -54,7 +54,6 @@ class PointNet2(BaseValMixin,
self.n_classes = len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2)
# Modules
self.point_net_core = ()
self.lin3 = torch.nn.Linear(128, self.n_classes)
# Utility