inference build

This commit is contained in:
Si11ium 2020-06-23 21:05:49 +02:00
parent 1033b26195
commit 85cf3128f1
2 changed files with 4 additions and 7 deletions

View File

@ -198,10 +198,7 @@ class CustomShapeNet(InMemoryDataset):
if self.collate_per_segment: if self.collate_per_segment:
datasets[self.mode].append(data) datasets[self.mode].append(data)
if not self.collate_per_segment: if not self.collate_per_segment:
# This is just to be sure, but should not be needed, since src[all] == all datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
raise TypeError('FIX THIS')
# old Code
# datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
if datasets[self.mode]: if datasets[self.mode]:
os.makedirs(self.processed_dir, exist_ok=True) os.makedirs(self.processed_dir, exist_ok=True)

View File

@ -25,7 +25,7 @@ from utils.project_settings import GlobalVar
def restore_logger_and_model(log_dir): def restore_logger_and_model(log_dir):
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-5) model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
model = model.restore() model = model.restore()
if torch.cuda.is_available(): if torch.cuda.is_available():
model.cuda() model.cuda()
@ -69,10 +69,10 @@ if __name__ == '__main__':
# TEST DATASET # TEST DATASET
transforms = Compose([NormalizeScale(), ]) transforms = Compose([NormalizeScale(), ])
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=True, test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
refresh=True, transform=transforms) refresh=True, transform=transforms)
grid_clusters = cluster_cubes(test_dataset[0], [3, 3, 3], max_points_per_cluster=1024) grid_clusters = cluster_cubes(test_dataset[1], [3, 3, 3], max_points_per_cluster=1024)
ps.init() ps.init()