2020-03-13 21:52:33 +01:00

30 lines
813 B
Python

from torchvision.datasets import MNIST
import numpy as np
class MyMNIST(MNIST):
    """torchvision ``MNIST`` whose samples are ``(float32 array, target)`` pairs.

    Defaults differ from ``MNIST``: files live under ``'res'``, the test split is
    loaded, and missing files are downloaded. The ``train_dataset``,
    ``test_dataset`` and ``val_dataset`` properties build a new sibling dataset on
    every access, sharing this instance's root directory and transforms.
    """

    def __init__(self, root='res', train=False, transform=None,
                 target_transform=None, download=True, **kwargs):
        """Load one MNIST split.

        Args:
            root: directory holding the dataset files.
            train: load the training split if True, otherwise the test split.
            transform: optional callable applied to each image by ``MNIST``.
                NOTE(review): if it returns a tensor, ``__getitem__`` still
                prepends an extra axis to it — confirm that is intended.
            target_transform: optional callable applied to each label.
            download: fetch the files into ``root`` if they are missing.
            **kwargs: forwarded unchanged to ``MNIST.__init__``.
        """
        # Every argument must be forwarded: the split properties below depend on
        # ``train`` (and ``root``) actually being honoured.
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download,
                         **kwargs)

    @property
    def map_shapes_max(self):
        """Shape of one image array as returned by ``__getitem__``.

        Reads this instance's first sample rather than constructing and loading a
        whole new test split on each access. Assumes all samples share one shape
        (presumably true for MNIST images) — TODO confirm if transforms vary size.
        """
        return np.asarray(self[0][0]).shape

    def __getitem__(self, item):
        """Return ``(image, target)`` for index ``item``.

        ``image`` is the base-class image converted to a float32 NumPy array with
        a new leading axis of size 1 (presumably ``(1, 28, 28)`` for untransformed
        MNIST). Pixel values are cast, not rescaled.
        """
        img, target = super().__getitem__(item)
        # Prepend a singleton leading axis, then cast to float32.
        return np.expand_dims(np.asarray(img), axis=0).astype(np.float32), target

    def _split(self, train):
        """Build a sibling dataset for the requested split with the same root and transforms."""
        return type(self)(self.root, train=train, transform=self.transform,
                          target_transform=self.target_transform, download=True)

    @property
    def train_dataset(self):
        """New dataset over the training split."""
        return self._split(train=True)

    @property
    def test_dataset(self):
        """New dataset over the test split."""
        return self._split(train=False)

    @property
    def val_dataset(self):
        """New dataset used for validation; there is no separate split, so this is the test split."""
        return self._split(train=False)