23 lines
522 B
Python
23 lines
522 B
Python
from abc import ABC
|
|
from torchvision.transforms import ToTensor as TorchVisionToTensor
|
|
|
|
|
|
class _BaseTransformation(ABC):
    """Common base for callable transformation objects.

    Supplies a do-nothing constructor, a ``repr`` built from the instance
    attributes, and a ``__call__`` that raises until a subclass overrides it.
    """

    def __init__(self, *args):
        """Accept any positional arguments and ignore them."""

    def __repr__(self):
        """Return ``ClassName({attributes})`` from the instance ``__dict__``."""
        class_name = self.__class__.__name__
        attributes = self.__dict__
        return f'{class_name}({attributes})'

    def __call__(self, *args, **kwargs):
        """Apply the transformation; this base version is a placeholder.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
class ToTensor(TorchVisionToTensor):
    """torchvision ``ToTensor`` variant that casts its result with ``.float()``."""

    def __call__(self, pic):
        """Convert ``pic`` with the parent transform and cast the result.

        Args:
            pic: Input forwarded unchanged to the parent ``__call__``.

        Returns:
            The parent's result after ``.float()``, i.e. 32-bit floating point.
        """
        converted = super().__call__(pic)
        # .float() yields a 32-bit float tensor whatever dtype the parent produced.
        return converted.float()
|