6D prediction files now working
This commit is contained in:
@@ -24,13 +24,13 @@ class _PointNetCore(LightningBaseModule, ABC):
|
||||
self.cord_dims = 6 if self.params.normals_as_cords else 3
|
||||
|
||||
# Modules
|
||||
self.sa1_module = SAModule(0.2, 0.2, MLP([self.cord_dims, 64, 64, 128]))
|
||||
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
|
||||
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
|
||||
self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)
|
||||
|
||||
self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
|
||||
self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
|
||||
self.fp1_module = FPModule(3, MLP([128, 128, 128, 128]))
|
||||
self.fp1_module = FPModule(3, MLP([128 + (3 if not self.params.normals_as_cords else 0), 128, 128, 128]))
|
||||
|
||||
self.lin1 = torch.nn.Linear(128, 128)
|
||||
self.lin2 = torch.nn.Linear(128, 128)
|
||||
|
||||
Reference in New Issue
Block a user