27 lines
658 B
Python
27 lines
658 B
Python
import torch
|
|
from torch_geometric.transforms import NormalizeScale
|
|
|
|
|
|
class NormalizePositions(NormalizeScale):
    """Center ``data.pos`` and rescale it to lie strictly inside (-1, 1).

    Works like ``NormalizeScale`` (centering is delegated to the parent's
    ``self.center`` transform), except that the scale factor
    ``1 / max(|pos|)`` is additionally multiplied by ``shrink``. With the
    default ``shrink=0.999999`` the largest absolute coordinate ends up just
    below 1 instead of exactly 1.

    Args:
        shrink: Extra multiplier applied to ``1 / max(|pos|)``. Values below 1
            keep every coordinate strictly inside (-1, 1).
    """

    def __init__(self, shrink=0.999999):
        super().__init__()
        self.shrink = shrink

    def __call__(self, data):
        """Center ``data.pos`` and scale it by ``shrink / max(|pos|)``.

        Args:
            data: Graph/point-cloud object with a ``pos`` tensor attribute.

        Returns:
            The object returned by ``self.center(data)``, with ``pos`` rescaled.
            If the centered positions are empty or all zero (every point
            coincides), they are returned centered but unscaled.

        Warns:
            RuntimeWarning: if the incoming ``data.pos`` contains NaN or Inf.
                Such values are not repaired and propagate into the output.
        """
        import warnings  # local import keeps this block self-contained

        if not torch.isfinite(data.pos).all():
            warnings.warn(
                "NormalizePositions: data.pos contains NaN or Inf values",
                RuntimeWarning,
                stacklevel=2,
            )

        data = self.center(data)

        # ``.max()`` raises on an empty tensor, so there is nothing to scale.
        if data.pos.numel() == 0:
            return data

        max_abs = data.pos.abs().max()
        if max_abs == 0:
            # All points sit at the origin after centering. Scaling by 1 / 0
            # would give an infinite factor, and 0 * inf produces NaN
            # positions, so leave them at zero.
            return data

        # Same operation order as before ((1 / max) * factor) so results are
        # bit-identical for ordinary inputs.
        data.pos = data.pos * ((1 / max_abs) * self.shrink)
        return data
|