32 lines
851 B
Python
32 lines
851 B
Python
from torchvision.datasets import MNIST
|
|
from torchvision.transforms import transforms
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
class MyMNIST(MNIST):
    """MNIST dataset with project defaults and convenience split accessors.

    Compared with ``torchvision.datasets.MNIST`` the defaults are: root
    directory ``'res'``, the test split (``train=False``), ``download=True``
    and a ``transforms.ToTensor()`` image transform.
    """

    def __init__(self, root='res', train=False, transform=None,
                 target_transform=None, download=True, **kwargs):
        """Load (and if needed download) the requested MNIST split.

        Args:
            root: Directory the dataset is stored in / downloaded to.
            train: ``True`` for the training split, ``False`` for the test
                split.
            transform: Image transform; ``None`` selects
                ``transforms.ToTensor()`` (the historical default).
            target_transform: Optional transform applied to the label.
            download: Download the data into ``root`` if it is missing.
            **kwargs: Forwarded unchanged to ``MNIST.__init__``.
        """
        if transform is None:
            # Preserve the original behaviour of always yielding tensor images.
            transform = transforms.ToTensor()
        # Previously every argument was ignored and the test split was always
        # loaded, so ``train_dataset`` silently returned test data. The
        # arguments are now honoured; the defaults match the old behaviour.
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform,
                         download=download, **kwargs)

    @property
    def map_shapes_max(self):
        """Return the shape of the first test-split image as a NumPy array.

        NOTE(review): presumably ``(1, 28, 28)`` with the default ``ToTensor``
        transform. Each access constructs a fresh test dataset (see
        ``test_dataset``), so avoid calling this in a hot loop.
        """
        first_image = self.test_dataset[0][0]
        return np.asarray(first_image).shape

    @property
    def train_dataset(self):
        """Return a new instance loading the training split from ``self.root``."""
        return type(self)(self.root, train=True, download=True)

    @property
    def test_dataset(self):
        """Return a new instance loading the test split from ``self.root``."""
        return type(self)(self.root, train=False, download=True)

    @property
    def val_dataset(self):
        """Return the validation split.

        MNIST ships no separate validation split, so this deliberately
        returns the test split (same as ``test_dataset``).
        """
        return type(self)(self.root, train=False, download=True)
|