Normalization and transforms for batch_to_data class

This commit is contained in:
Si11ium
2020-06-15 15:14:08 +02:00
parent bc70f42c74
commit 4898e98851
8 changed files with 26 additions and 24 deletions

View File

@@ -16,13 +16,12 @@ from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch_geometric.data import Data
from torchcontrib.optim import SWA
from torchvision.transforms import Compose
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot
from ml_lib.utils.transforms import ToTensor
from ml_lib.point_toolset.point_io import BatchToData
from .project_config import GlobalVar
@@ -59,12 +58,10 @@ class BaseTrainMixin:
nll_loss = nn.NLLLoss()
# Binary Cross Entropy
bce_loss = nn.BCELoss()
# Batch To Data
batch_to_data = BatchToData()
def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__):
def training_step(self, batch_norm_pos_y_c, batch_nb, *_, **__):
assert isinstance(self, LightningBaseModule)
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
data = self.batch_to_data(*batch_norm_pos_y_c) if not isinstance(batch_norm_pos_y_c, Data) else batch_norm_pos_y_c
y = self(data).main_out
nll_loss = self.nll_loss(y, data.yl)
return dict(loss=nll_loss, log=dict(batch_nb=batch_nb))
@@ -87,8 +84,6 @@ class BaseValMixin:
nll_loss = nn.NLLLoss()
# Binary Cross Entropy
bce_loss = nn.BCELoss()
# Batch To Data
batch_to_data = BatchToData()
def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
assert isinstance(self, LightningBaseModule)
@@ -230,14 +225,11 @@ class DatasetMixin:
# =============================================================================
# Data Augmentations or Utility Transformations
transforms = Compose([ToTensor()])
# Dataset
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.train,
transforms=transforms, **kwargs),
train_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.train, **kwargs),
# VALIDATION DATASET
val_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.devel,