CNN Classifier
This commit is contained in:
@ -6,7 +6,6 @@ from torch import nn
|
||||
from torch import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset.dataset import TrajDataset, TrajPairDataset
|
||||
from lib.objects.map import MapStorage
|
||||
|
||||
import pytorch_lightning as pl
|
||||
@ -17,8 +16,20 @@ import pytorch_lightning as pl
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def __init__(self, to=(-1, )):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
try:
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
def __init__(self, in_shape, to=(-1, )):
|
||||
super(Flatten, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
self.to = to
|
||||
|
||||
def forward(self, x):
|
||||
|
Reference in New Issue
Block a user