pointnet2 working - TODO: Eval!
This commit is contained in:
@@ -11,13 +11,10 @@ from torch.utils.data import DataLoader
|
||||
from torchcontrib.optim import SWA
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from _templates.new_project.datasets.template_dataset import TemplateDataset
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.transforms import ToTensor
|
||||
|
||||
from audio_toolset.audio_io import NormalizeLocal
|
||||
from modules.utils import LightningBaseModule
|
||||
from utils.transforms import ToTensor
|
||||
|
||||
from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
|
||||
from .project_config import GlobalVar
|
||||
|
||||
|
||||
class BaseOptimizerMixin:
|
||||
@@ -110,31 +107,31 @@ class BaseValMixin:
|
||||
return summary_dict
|
||||
|
||||
|
||||
class BinaryMaskDatasetMixin:
|
||||
class DatasetMixin:
|
||||
|
||||
def build_dataset(self):
|
||||
def build_dataset(self, dataset_class):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
# Data Augmentations or Utility Transformations
|
||||
|
||||
transforms = Compose([NormalizeLocal(), ToTensor()])
|
||||
transforms = Compose([ToTensor()])
|
||||
|
||||
# Dataset
|
||||
dataset = Namespace(
|
||||
**dict(
|
||||
# TRAIN DATASET
|
||||
train_dataset=TemplateDataset(self.params.root, setting=GlobalVars.DATA_OPTIONS.train,
|
||||
train_dataset=dataset_class(self.params.root, setting=GlobalVar.train,
|
||||
transforms=transforms
|
||||
),
|
||||
|
||||
# VALIDATION DATASET
|
||||
val_dataset=TemplateDataset(self.params.root, setting=GlobalVars.vali,
|
||||
val_dataset=dataset_class(self.params.root, setting=GlobalVar.vali,
|
||||
),
|
||||
|
||||
# TEST DATASET
|
||||
test_dataset=TemplateDataset(self.params.root, setting=GlobalVars.test,
|
||||
test_dataset=dataset_class(self.params.root, setting=GlobalVar.test,
|
||||
),
|
||||
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from utils.config import Config
|
||||
from ml_lib.utils.config import Config
|
||||
|
||||
|
||||
class GlobalVar(Namespace):
|
||||
@@ -18,13 +18,16 @@ class GlobalVar(Namespace):
|
||||
DPI = 50
|
||||
|
||||
# DATAOPTIONS
|
||||
train='train',
|
||||
vali='vali',
|
||||
test='test'
|
||||
train ='train',
|
||||
vali ='vali',
|
||||
test ='test'
|
||||
|
||||
|
||||
from models import *
|
||||
|
||||
|
||||
class ThisConfig(Config):
|
||||
|
||||
@property
|
||||
def _model_map(self):
|
||||
return dict()
|
||||
return dict(PN2=PointNet2)
|
||||
|
||||
Reference in New Issue
Block a user