CNN Classifier

This commit is contained in:
Si11ium
2020-02-21 09:44:09 +01:00
parent 537e5371c9
commit 7b3f781d19
12 changed files with 247 additions and 109 deletions

View File

@ -6,7 +6,6 @@ from torch import nn
from torch import functional as F
from torch.utils.data import DataLoader
from dataset.dataset import TrajDataset, TrajPairDataset
from lib.objects.map import MapStorage
import pytorch_lightning as pl
@ -17,8 +16,20 @@ import pytorch_lightning as pl
class Flatten(nn.Module):
def __init__(self, to=(-1, )):
@property
def shape(self):
try:
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
except Exception as e:
print(e)
return -1
def __init__(self, in_shape, to=(-1, )):
super(Flatten, self).__init__()
self.in_shape = in_shape
self.to = to
def forward(self, x):