VAE Debugged and Running
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import transforms
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class MyMNIST(MNIST):
|
||||
@@ -9,12 +11,12 @@ class MyMNIST(MNIST):
|
||||
return np.asarray(self.test_dataset[0][0]).shape
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MyMNIST, self).__init__('res', train=False, download=True)
|
||||
super(MyMNIST, self).__init__('res', train=False, download=True, transform=transforms.ToTensor())
|
||||
pass
|
||||
|
||||
def __getitem__(self, item):
|
||||
image = super(MyMNIST, self).__getitem__(item)
|
||||
return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
|
||||
image, label = super(MyMNIST, self).__getitem__(item)
|
||||
return image, label
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
|
||||
Reference in New Issue
Block a user