pointnet2 working - TODO: Eval!

This commit is contained in:
Si11ium
2020-05-26 21:44:57 +02:00
parent e04ef2f8b9
commit ba7c0280ae
11 changed files with 232 additions and 58 deletions

View File

@@ -11,13 +11,10 @@ from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose
from _templates.new_project.datasets.template_dataset import TemplateDataset
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.transforms import ToTensor
from audio_toolset.audio_io import NormalizeLocal
from modules.utils import LightningBaseModule
from utils.transforms import ToTensor
from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
from .project_config import GlobalVar
class BaseOptimizerMixin:
@@ -110,31 +107,31 @@ class BaseValMixin:
return summary_dict
class BinaryMaskDatasetMixin:
class DatasetMixin:
def build_dataset(self):
def build_dataset(self, dataset_class):
assert isinstance(self, LightningBaseModule)
# Dataset
# =============================================================================
# Data Augmentations or Utility Transformations
transforms = Compose([NormalizeLocal(), ToTensor()])
transforms = Compose([ToTensor()])
# Dataset
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=TemplateDataset(self.params.root, setting=GlobalVars.DATA_OPTIONS.train,
train_dataset=dataset_class(self.params.root, setting=GlobalVar.train,
transforms=transforms
),
# VALIDATION DATASET
val_dataset=TemplateDataset(self.params.root, setting=GlobalVars.vali,
val_dataset=dataset_class(self.params.root, setting=GlobalVar.vali,
),
# TEST DATASET
test_dataset=TemplateDataset(self.params.root, setting=GlobalVars.test,
test_dataset=dataset_class(self.params.root, setting=GlobalVar.test,
),
)

View File

@@ -1,6 +1,6 @@
from argparse import Namespace
from utils.config import Config
from ml_lib.utils.config import Config
class GlobalVar(Namespace):
@@ -18,13 +18,16 @@ class GlobalVar(Namespace):
DPI = 50
# DATAOPTIONS
train='train',
vali='vali',
test='test'
train ='train',
vali ='vali',
test ='test'
from models import *
class ThisConfig(Config):
@property
def _model_map(self):
return dict()
return dict(PN2=PointNet2)