6D prediction files now working

This commit is contained in:
Si11ium
2020-06-25 12:03:08 +02:00
parent 965b805ee9
commit 2a7a236b89
5 changed files with 113 additions and 80 deletions

View File

@@ -24,13 +24,13 @@ class _PointNetCore(LightningBaseModule, ABC):
self.cord_dims = 6 if self.params.normals_as_cords else 3
# Modules
self.sa1_module = SAModule(0.2, 0.2, MLP([self.cord_dims, 64, 64, 128]))
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
self.fp1_module = FPModule(3, MLP([128, 128, 128, 128]))
self.fp1_module = FPModule(3, MLP([128 + (3 if not self.params.normals_as_cords else 0), 128, 128, 128]))
self.lin1 = torch.nn.Linear(128, 128)
self.lin2 = torch.nn.Linear(128, 128)