# 2020-03-25 09:39:59 +01:00

# 32 lines
# 851 B
# Python

from torchvision.datasets import MNIST
from torchvision.transforms import transforms
import numpy as np
import torch
class MyMNIST(MNIST):
    """MNIST dataset with project defaults and per-split accessors.

    Constructing ``MyMNIST()`` with no arguments yields the test split stored
    under ``'res'``, downloaded on demand, with ``transforms.ToTensor()``
    applied to each image. Any of these defaults can be overridden by passing
    the corresponding ``MNIST`` constructor argument.
    """

    # Positional parameter order of ``MNIST.__init__``; used to fold positional
    # arguments into keywords so the defaults below never collide with them.
    _POSITIONAL_PARAMS = ('root', 'train', 'transform', 'target_transform', 'download')

    def __init__(self, *args, **kwargs):
        """Build the dataset, filling in defaults for omitted arguments.

        Args:
            *args: Positional arguments in ``MNIST`` order
                (root, train, transform, target_transform, download).
            **kwargs: Keyword arguments forwarded to ``MNIST``.

        Raises:
            TypeError: If too many positional arguments are given, or an
                argument is supplied both positionally and by keyword.
        """
        if len(args) > len(self._POSITIONAL_PARAMS):
            raise TypeError(
                '__init__() takes at most {} positional arguments ({} given)'.format(
                    len(self._POSITIONAL_PARAMS), len(args)))

        # Defaults reproduce what the constructor previously hard-coded.
        options = {'root': 'res', 'train': False, 'download': True}
        for name, value in zip(self._POSITIONAL_PARAMS, args):
            if name in kwargs:
                raise TypeError(
                    "__init__() got multiple values for argument '{}'".format(name))
            options[name] = value
        options.update(kwargs)

        # Only create the default transform when the caller gave none, so an
        # explicit ``transform=None`` is respected.
        if 'transform' not in options:
            options['transform'] = transforms.ToTensor()

        super(MyMNIST, self).__init__(**options)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair at index ``item``."""
        image, label = super(MyMNIST, self).__getitem__(item)
        return image, label

    def _split(self, train):
        """Create a sibling dataset for one split, reusing this instance's
        root directory and transforms."""
        return self.__class__(
            self.root,
            train=train,
            download=True,
            transform=self.transform,
            target_transform=self.target_transform,
        )

    @property
    def map_shapes_max(self):
        """Shape of the first test-split sample after conversion to a NumPy array.

        Note: each access builds a new test dataset via ``test_dataset``.
        """
        return np.asarray(self.test_dataset[0][0]).shape

    @property
    def train_dataset(self):
        """A new dataset instance over the training split."""
        return self._split(train=True)

    @property
    def test_dataset(self):
        """A new dataset instance over the test split."""
        return self._split(train=False)

    @property
    def val_dataset(self):
        """A new dataset instance used for validation.

        This is built with ``train=False``, i.e. it is the same split that
        ``test_dataset`` returns.
        """
        return self._split(train=False)