ml_lib/point_toolset/point_transforms.py
2020-06-23 14:37:33 +02:00

27 lines
658 B
Python

import torch
from torch_geometric.transforms import NormalizeScale
class NormalizePositions(NormalizeScale):
    """Center ``data.pos`` and scale it so every coordinate lies inside (-1, 1).

    Same recipe as ``NormalizeScale`` (center, then multiply by
    ``scale_factor / max(|pos|)``), but with guards for the degenerate inputs
    that previously turned positions into NaN or raised:

    * empty ``data.pos``: returned after centering, without taking ``max`` of
      an empty tensor (which raises);
    * all centered coordinates equal to zero (e.g. a single point, or all
      points identical): ``1 / 0`` would be ``inf`` and ``0 * inf`` is NaN,
      so the scaling step is skipped and the zero positions are kept.

    Args:
        scale_factor: Multiplier applied after dividing by the largest absolute
            coordinate. The default ``0.999999`` puts the largest coordinate
            just below 1 in magnitude.
    """

    def __init__(self, scale_factor=0.999999):
        super().__init__()
        self.scale_factor = scale_factor

    def __call__(self, data):
        """Center and rescale ``data.pos`` and return the resulting data object.

        NaN/inf values already present in ``data.pos`` are not cleaned up; they
        propagate through centering and scaling as before.
        """
        # ``self.center`` is created by NormalizeScale.__init__ (presumably a
        # ``Center`` transform subtracting the mean position -- confirm against
        # the installed torch_geometric version).
        data = self.center(data)

        if data.pos.numel() == 0:
            # Nothing to scale; ``.max()`` on an empty tensor would raise.
            return data

        max_abs = data.pos.abs().max()
        if max_abs == 0:
            # All points sit at the origin after centering; dividing by zero
            # here is what used to produce NaN positions.
            return data

        scale = (1 / max_abs) * self.scale_factor
        data.pos = data.pos * scale
        return data