ml_lib/utils/transforms.py
2020-12-17 08:02:28 +01:00

23 lines
522 B
Python

from abc import ABC
from torchvision.transforms import ToTensor as TorchVisionToTensor
class _BaseTransformation(ABC):
    """Common base for the callable transformations in this module.

    Subclasses are expected to override ``__call__``; the base implementation
    only raises ``NotImplementedError``.

    NOTE(review): although this derives from ``ABC``, no method is marked
    ``@abstractmethod``, so instantiating the base class (or a subclass that
    forgets ``__call__``) is not prevented at construction time.
    """

    def __init__(self, *args):
        """Accept and ignore any positional arguments (keyword args are not accepted)."""

    def __repr__(self):
        # Render as ``ClassName({...instance attribute dict...})``.
        class_name = type(self).__name__
        return f'{class_name}({vars(self)})'

    def __call__(self, *args, **kwargs):
        # Placeholder: concrete transformations must provide their own __call__.
        raise NotImplementedError
class ToTensor(TorchVisionToTensor):
    """torchvision ``ToTensor`` whose output is always cast with ``.float()``.

    The conversion itself is delegated unchanged to the torchvision parent
    class. The resulting tensor is then passed through ``.float()``, i.e. cast
    to 32-bit floating point.
    """

    def __call__(self, pic):
        converted = super().__call__(pic)
        # .float() yields a 32-bit float tensor regardless of the parent's output dtype.
        return converted.float()