30 lines
813 B
Python
30 lines
813 B
Python
from torchvision.datasets import MNIST
|
|
import numpy as np
|
|
|
|
|
|
class MyMNIST(MNIST):
    """MNIST dataset whose images are returned as float32 NumPy arrays.

    Indexing yields ``(image, label)`` where ``image`` is the underlying
    sample converted to a NumPy array with one extra axis inserted at
    position 0 and cast to ``float32``.  The ``train_dataset``,
    ``test_dataset`` and ``val_dataset`` properties each build a fresh
    instance of this class for the requested split.
    """

    def __init__(self, *args, **kwargs):
        """Create the dataset, forwarding arguments to ``MNIST.__init__``.

        Any argument the caller omits falls back to the values that used to
        be hard-coded here: ``root='res'``, ``train=False``,
        ``download=True``.

        Previously every argument was discarded, so requesting
        ``train=True`` (as ``train_dataset`` does) still loaded the test
        split.
        """
        # Positional slots follow the torchvision signature
        # MNIST(root, train, transform, target_transform, download).
        # A default is only injected as a keyword when the caller did not
        # already supply that parameter positionally, to avoid a
        # "multiple values for argument" TypeError.
        fallbacks = ((0, 'root', 'res'), (1, 'train', False), (4, 'download', True))
        for position, name, value in fallbacks:
            if len(args) <= position:
                kwargs.setdefault(name, value)
        super(MyMNIST, self).__init__(*args, **kwargs)

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``.

        The image from the parent class is converted with ``np.asarray``,
        given a new leading axis, and cast to ``np.float32``; the label is
        passed through unchanged.
        """
        image, label = super(MyMNIST, self).__getitem__(item)
        pixels = np.expand_dims(np.asarray(image), axis=0)
        return pixels.astype(np.float32), label

    @property
    def map_shapes_max(self):
        """Shape of the first image of the test split, as produced by ``__getitem__``.

        For standard 28x28 MNIST images this is expected to be ``(1, 28, 28)``.
        """
        first_image = self.test_dataset[0][0]
        return np.asarray(first_image).shape

    @property
    def train_dataset(self):
        """A new instance of this class loading the training split from ``'res'``."""
        return self.__class__('res', train=True, download=True)

    @property
    def test_dataset(self):
        """A new instance of this class loading the test split from ``'res'``."""
        return self.__class__('res', train=False, download=True)

    @property
    def val_dataset(self):
        """A new instance of this class for validation.

        This loads the test split (``train=False``), i.e. the same data as
        ``test_dataset``.
        """
        return self.__class__('res', train=False, download=True)
|