inference build
This commit is contained in:
parent
1033b26195
commit
85cf3128f1
@ -198,10 +198,7 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
if self.collate_per_segment:
|
if self.collate_per_segment:
|
||||||
datasets[self.mode].append(data)
|
datasets[self.mode].append(data)
|
||||||
if not self.collate_per_segment:
|
if not self.collate_per_segment:
|
||||||
# This is just to be sure, but should not be needed, since src[all] == all
|
datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||||
raise TypeError('FIX THIS')
|
|
||||||
# old Code
|
|
||||||
# datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
|
||||||
|
|
||||||
if datasets[self.mode]:
|
if datasets[self.mode]:
|
||||||
os.makedirs(self.processed_dir, exist_ok=True)
|
os.makedirs(self.processed_dir, exist_ok=True)
|
||||||
|
@ -25,7 +25,7 @@ from utils.project_settings import GlobalVar
|
|||||||
|
|
||||||
|
|
||||||
def restore_logger_and_model(log_dir):
|
def restore_logger_and_model(log_dir):
|
||||||
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-5)
|
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
|
||||||
model = model.restore()
|
model = model.restore()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
model.cuda()
|
model.cuda()
|
||||||
@ -69,10 +69,10 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# TEST DATASET
|
# TEST DATASET
|
||||||
transforms = Compose([NormalizeScale(), ])
|
transforms = Compose([NormalizeScale(), ])
|
||||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=True,
|
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
||||||
refresh=True, transform=transforms)
|
refresh=True, transform=transforms)
|
||||||
|
|
||||||
grid_clusters = cluster_cubes(test_dataset[0], [3, 3, 3], max_points_per_cluster=1024)
|
grid_clusters = cluster_cubes(test_dataset[1], [3, 3, 3], max_points_per_cluster=1024)
|
||||||
|
|
||||||
ps.init()
|
ps.init()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user