New Model running
This commit is contained in:
26
point_toolset/point_transforms.py
Normal file
26
point_toolset/point_transforms.py
Normal file
@ -0,0 +1,26 @@
|
||||
import torch
|
||||
from torch_geometric.transforms import NormalizeScale
|
||||
|
||||
|
||||
class NormalizePositions(NormalizeScale):
|
||||
|
||||
def __init__(self):
|
||||
super(NormalizePositions, self).__init__()
|
||||
|
||||
def __call__(self, data):
|
||||
if torch.isnan(data.pos).any():
|
||||
print('debug')
|
||||
|
||||
data = self.center(data)
|
||||
if torch.isnan(data.pos).any():
|
||||
print('debug')
|
||||
|
||||
scale = (1 / data.pos.abs().max()) * 0.999999
|
||||
if torch.isnan(scale).any() or torch.isinf(scale).any():
|
||||
print('debug')
|
||||
|
||||
data.pos = data.pos * scale
|
||||
if torch.isnan(data.pos).any():
|
||||
print('debug')
|
||||
|
||||
return data
|
Reference in New Issue
Block a user