VAE Debugged and Running

This commit is contained in:
Si11ium
2020-03-25 09:39:59 +01:00
parent defa232bf2
commit 934dadb558
5 changed files with 171 additions and 193 deletions

View File

@@ -1,5 +1,7 @@
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
import numpy as np
import torch
class MyMNIST(MNIST):
@@ -9,12 +11,12 @@ class MyMNIST(MNIST):
return np.asarray(self.test_dataset[0][0]).shape
def __init__(self, *args, **kwargs):
super(MyMNIST, self).__init__('res', train=False, download=True)
super(MyMNIST, self).__init__('res', train=False, download=True, transform=transforms.ToTensor())
pass
def __getitem__(self, item):
image = super(MyMNIST, self).__getitem__(item)
return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
image, label = super(MyMNIST, self).__getitem__(item)
return image, label
@property
def train_dataset(self):