pointnet2 working - TODO: Eval!

This commit is contained in:
Si11ium
2020-05-26 21:44:57 +02:00
parent e04ef2f8b9
commit ba7c0280ae
11 changed files with 232 additions and 58 deletions

View File

@ -11,13 +11,10 @@ from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose
from _templates.new_project.datasets.template_dataset import TemplateDataset
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.transforms import ToTensor
from audio_toolset.audio_io import NormalizeLocal
from modules.utils import LightningBaseModule
from utils.transforms import ToTensor
from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
from .project_config import GlobalVar
class BaseOptimizerMixin:
@ -110,31 +107,31 @@ class BaseValMixin:
return summary_dict
class BinaryMaskDatasetMixin:
class DatasetMixin:
def build_dataset(self):
def build_dataset(self, dataset_class):
assert isinstance(self, LightningBaseModule)
# Dataset
# =============================================================================
# Data Augmentations or Utility Transformations
transforms = Compose([NormalizeLocal(), ToTensor()])
transforms = Compose([ToTensor()])
# Dataset
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=TemplateDataset(self.params.root, setting=GlobalVars.DATA_OPTIONS.train,
train_dataset=dataset_class(self.params.root, setting=GlobalVar.train,
transforms=transforms
),
# VALIDATION DATASET
val_dataset=TemplateDataset(self.params.root, setting=GlobalVars.vali,
val_dataset=dataset_class(self.params.root, setting=GlobalVar.vali,
),
# TEST DATASET
test_dataset=TemplateDataset(self.params.root, setting=GlobalVars.test,
test_dataset=dataset_class(self.params.root, setting=GlobalVar.test,
),
)