Dataset Redone

This commit is contained in:
Si11ium
2020-06-19 08:17:35 +02:00
parent 4898e98851
commit 63605ae33a
14 changed files with 239 additions and 362 deletions

View File

@@ -14,8 +14,7 @@ import matplotlib.pyplot as plt
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.data import Data, DataLoader
from torchcontrib.optim import SWA
@@ -59,11 +58,11 @@ class BaseTrainMixin:
# Binary Cross Entropy
bce_loss = nn.BCELoss()
def training_step(self, batch_norm_pos_y_c, batch_nb, *_, **__):
def training_step(self, batch_norm_pos_y, batch_nb, *_, **__):
assert isinstance(self, LightningBaseModule)
data = self.batch_to_data(*batch_norm_pos_y_c) if not isinstance(batch_norm_pos_y_c, Data) else batch_norm_pos_y_c
data = self.batch_to_data(batch_norm_pos_y) if not isinstance(batch_norm_pos_y, Data) else batch_norm_pos_y
y = self(data).main_out
nll_loss = self.nll_loss(y, data.yl)
nll_loss = self.nll_loss(y, data.y)
return dict(loss=nll_loss, log=dict(batch_nb=batch_nb))
def training_epoch_end(self, outputs):
@@ -89,9 +88,9 @@ class BaseValMixin:
assert isinstance(self, LightningBaseModule)
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
y = self(data).main_out
nll_loss = self.nll_loss(y, data.yl)
nll_loss = self.nll_loss(y, data.y)
return dict(val_nll_loss=nll_loss,
batch_idx=batch_idx, y=y, batch_y=data.yl)
batch_idx=batch_idx, y=y, batch_y=data.y)
def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
@@ -229,14 +228,15 @@ class DatasetMixin:
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.train, **kwargs),
train_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.train,
**kwargs),
# VALIDATION DATASET
val_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.devel,
val_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.devel,
**kwargs),
# TEST DATASET
test_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.test,
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.test,
**kwargs),
)
)

View File

@@ -68,4 +68,4 @@ class ThisConfig(Config):
@property
def _model_map(self):
return dict(PN2=PointNet2, P2P=PointNet2PrimClusters, P2G=PointNet2GridClusters)
return dict(PN2=PointNet2)