New Model running

This commit is contained in:
parent a19bd9cafd
commit 1033b26195
@@ -24,7 +24,11 @@ main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
 main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
 main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
+main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
 main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
+main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
+main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
+main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=True, help="")

 # Transformations
 # main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
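Note: the four new --data_* switches use strtobool, so booleans can be passed on the command line as 1/0, true/false or yes/no. A minimal sketch of that pattern (parser and flag names from the diff, the rest is illustrative):

    from argparse import ArgumentParser
    from distutils.util import strtobool  # returns 1 or 0, stored as int by argparse

    parser = ArgumentParser()
    parser.add_argument("--data_poly_as_plane", type=strtobool, default=True)
    args = parser.parse_args(["--data_poly_as_plane", "false"])
    assert not args.data_poly_as_plane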
@@ -33,7 +37,7 @@ main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=
 # Training
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=250, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
 main_arg_parser.add_argument("--train_batch_size", type=int, default=10, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
 main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")
@@ -44,7 +48,6 @@ main_arg_parser.add_argument("--train_opt_reset_interval", type=strtobool, defau
 # Model
 # Possible Model arguments are: P2P, PN2, P2G
 main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="")
-main_arg_parser.add_argument("--model_norm_as_feature", type=strtobool, default=True, help="")

 main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
@@ -52,11 +55,10 @@ main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, h
 main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")

 # Model 2: Layer Specific Stuff

 main_arg_parser.add_argument("--model_features", type=int, default=16, help="")

 # Parse it
 args: Namespace = main_arg_parser.parse_args()

 if __name__ == '__main__':
     pass
@@ -1,8 +1,7 @@
 from pathlib import Path
+from typing import Union
 from warnings import warn

-import numpy as np
-
 from collections import defaultdict

 import os
@@ -13,7 +12,7 @@ import torch
 from torch_geometric.data import InMemoryDataset
 from torch_geometric.data import Data

-from utils.project_settings import Classes, DataSplit
+from utils.project_settings import Classes, DataSplit, ClusterTypes


 def save_names(name_list, path):
@@ -23,10 +22,23 @@ def save_names(name_list, path):

 class CustomShapeNet(InMemoryDataset):

-    categories = {key: val for val, key in Classes().items()}
-    modes = {key: val for val, key in DataSplit().items()}
     name = 'CustomShapeNet'

+    def download(self):
+        pass
+
+    @property
+    def categories(self):
+        return {key: val for val, key in self.classes.items()}
+
+    @property
+    def modes(self):
+        return {key: val for val, key in DataSplit().items()}
+
+    @property
+    def cluster_types(self):
+        return {key: val for val, key in ClusterTypes().items()}
+
     @property
     def raw_dir(self):
         return self.root / 'raw'
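Note: categories and modes change from class attributes to properties, presumably so they can reflect per-instance state; categories now reads self.classes, which __init__ sets below. download() becomes an explicit no-op since the raw .xyz files are expected to exist locally rather than be fetched.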
@@ -40,14 +52,21 @@ class CustomShapeNet(InMemoryDataset):
         return self.root / 'processed'

     def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
-                 pre_transform=None, refresh=False, with_normals=False):
-        assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}'
+                 pre_transform=None, refresh=False, cluster_type: Union[str, None] = '',
+                 poly_as_plane=False):
+        assert mode in self.modes.keys(), \
+            f'"mode" must be one of {self.modes.keys()}'
+        assert cluster_type in self.cluster_types.keys() or cluster_type is None, \
+            f'"cluster_type" must be one of {self.cluster_types.keys()} or None, but was: {cluster_type}'

         # Set the Dataset Parameters
+        self.cluster_type = cluster_type if cluster_type else 'pc'
+        self.classes = Classes()
+        self.poly_as_plane = poly_as_plane
         self.collate_per_segment = collate_per_segment
         self.mode = mode
         self.refresh = refresh
-        self.with_normals = with_normals
         root_dir = Path(root_dir)
         super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
         self.data, self.slices = self._load_dataset()
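Note: a sketch of the new constructor surface (argument names from the diff; the data path is illustrative):

    dataset = CustomShapeNet('data', mode='train', collate_per_segment=True,
                             cluster_type='grid', refresh=False, poly_as_plane=True)

Both '' and None pass the cluster_type assert and fall back to the 'pc' file tag via `cluster_type if cluster_type else 'pc'`.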
@@ -72,7 +91,7 @@ class CustomShapeNet(InMemoryDataset):

     @property
     def num_classes(self):
-        return len(self.categories)
+        return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2)

     def _load_dataset(self):
         data, slices = None, None
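Note: this conditional looks inverted. Merging Box and Polytope into Plane should shrink the label set, and PointNet2 below computes the same quantity as `len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2)`. If that reading is right, the intended line is presumably:

    # presumed intent, mirroring PointNet2.n_classes (an assumption, not part of the commit)
    return (len(self.categories) - 2) if self.poly_as_plane else len(self.categories)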
@@ -101,22 +120,17 @@ class CustomShapeNet(InMemoryDataset):
         return data, slices

     def _pre_transform_and_filter(self, data):
-        # ToDo: ANy filter to apply? Then do it here.
         if self.pre_filter is not None and not self.pre_filter(data):
             data = self.pre_filter(data)
-            raise NotImplementedError
-        # ToDo: ANy transformation to apply? Then do it here.
         if self.pre_transform is not None:
             data = self.pre_transform(data)
-            raise NotImplementedError
         return data

     def process(self, delimiter=' '):
         datasets = defaultdict(list)
         path_to_clouds = self.raw_dir / self.mode

         for pointcloud in tqdm(path_to_clouds.glob('*.xyz')):
-            if 'grid' not in pointcloud.name:
+            if self.cluster_type not in pointcloud.name:
                 continue
             data = None

@@ -129,15 +143,32 @@ class CustomShapeNet(InMemoryDataset):
                     vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
                     src[vals[-1]].append(vals)

+            # Switch from un-pickable Defaultdict to Standard Dict
             src = dict(src)

+            # Transform the Dict[List] to Dict[torch.Tensor]
             for key, values in src.items():
                 src[key] = torch.tensor(values, dtype=torch.double).squeeze()

+            # Screw the Sorting and make it a FullCloud rather than a seperated
             if not self.collate_per_segment:
                 src = dict(
                     all=torch.cat(tuple(src.values()))
                 )

+            # Transform Box and Polytope to Plane if poly_as_plane is set
+            for key, tensor in src.items():
+                if tensor.ndim == 1:
+                    if all([x == 0 for x in tensor]):
+                        continue
+                    tensor = tensor.unsqueeze(0)
+                if self.poly_as_plane:
+                    tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
+                    tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
+                    tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
+                    tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
+                src[key] = tensor

             for key, values in src.items():
                 try:
                     points = values[:, :-2]
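Note: with Classes defining Sphere=0, Cylinder=1, Cone=2, Box=3, Polytope=4, Torus=5, Plane=6, the new block collapses the three plane-like classes onto label 4.0 and pulls Torus down to 3.0, keeping labels contiguous (0..4) in the reduced five-class setting. The same remap as a lookup, assuming those ids and the label column at index -2 as in the diff:

    remap = {3.0: 4.0, 4.0: 4.0, 6.0: 4.0, 5.0: 3.0}  # Box, Polytope, Plane -> 4; Torus -> 3
    for old, new in remap.items():
        tensor[:, -2][tensor[:, -2] == old] = new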
@@ -147,36 +178,35 @@ class CustomShapeNet(InMemoryDataset):
                     y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
                     ####################################
                     # This is where you define the keys
-                    attr_dict = dict(y=y, y_c=y_c)
-                    if self.with_normals:
-                        pos = points[:, :6]
-                        norm = None
-                        attr_dict.update(pos=pos, norm=norm)
-                    if not self.with_normals:
-                        pos = points[:, :3]
-                        norm = points[:, 3:6]
-                        attr_dict.update(pos=pos, norm=norm)
+                    attr_dict = dict(
+                        y=y,
+                        y_c=y_c,
+                        pos=points[:, :3],
+                        norm=points[:, 3:6]
+                    )
                     ####################################
                     if self.collate_per_segment:
                         data = Data(**attr_dict)
                     else:
                         if data is None:
                             data = defaultdict(list)
-                        # points=points, norm=points[:, 3:]
-                        for key, val in attr_dict.items():
-                            data[key].append(val)
-                        # data = Data(**data)
+                        for attr_key, val in attr_dict.items():
+                            data[attr_key].append(val)

                 # data = self._pre_transform_and_filter(data)
                 if self.collate_per_segment:
                     datasets[self.mode].append(data)
                 if not self.collate_per_segment:
-                    # This is just to be sure, but should not be needed, since src[all] == all there is in this cloud
-                    datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
+                    # This is just to be sure, but should not be needed, since src[all] == all
+                    raise TypeError('FIX THIS')
+                    # old Code
+                    # datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))

         if datasets[self.mode]:
             os.makedirs(self.processed_dir, exist_ok=True)
-            torch.save(self.collate(datasets[self.mode]), self.processed_paths[0])
+            collated_dataset = self.collate(datasets[self.mode])
+            torch.save(collated_dataset, self.processed_paths[0])

     def __repr__(self):
         return f'{self.__class__.__name__}({len(self)})'
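Note: the non-collated code path now ends in `raise TypeError('FIX THIS')`, so only collate_per_segment=True is usable after this commit; consistently, predict.py below switches its test dataset from collate_per_segment=False to True.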
@@ -190,17 +220,18 @@ class ShapeNetPartSegDataset(Dataset):

     name = 'ShapeNetPartSegDataset'

-    def __init__(self, root_dir, npoints=1024, mode='train', **kwargs):
+    def __init__(self, root_dir, mode='train', **kwargs):
         super(ShapeNetPartSegDataset, self).__init__()
         self.mode = mode
         kwargs.update(dict(root_dir=root_dir, mode=self.mode))
-        self.npoints = npoints
+        # self.npoints = npoints
         self.dataset = CustomShapeNet(**kwargs)

     def __getitem__(self, index):
         data = self.dataset[index]

         # Resample to fixed number of points
+        '''
         try:
             npoints = self.npoints if self.mode != DataSplit.predict else data.pos.shape[0]
             choice = np.random.choice(data.pos.shape[0], npoints,
@@ -209,16 +240,16 @@ class ShapeNetPartSegDataset(Dataset):
         except ValueError:
             choice = []

         pos, norm, y = data.pos[choice, :], data.norm[choice], data.y[choice]

         # y -= 1 if self.num_classes() in y else 0  # Map label from [1, C] to [0, C-1]

         data = Data(**dict(pos=pos,    # torch.Tensor (n, 3/6)
                            y=y,        # torch.Tensor (n,)
                            norm=norm   # torch.Tensor (n, 3/0)
                            )
                     )
+        '''
         return data

     def __len__(self):
main.py
@@ -46,6 +46,7 @@ def run_lightning_loop(config_obj):
     # Init
     model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
     model.init_weights(torch.nn.init.xavier_normal_)
+    model.save_to_disk(logger.log_dir)

     # Trainer
     # =============================================================================
@@ -21,7 +21,7 @@ from utils.project_settings import GlobalVar


 def prepare_dataloader(config_obj):
-    dataset = ShapeNetPartSegDataset(config_obj.data.root, split=GlobalVar.data_split.test,
+    dataset = ShapeNetPartSegDataset(config_obj.data.root, mode=GlobalVar.data_split.test,
                                      setting=GlobalVar.settings[config_obj.model.type])
     # noinspection PyTypeChecker
     return DataLoader(dataset, batch_size=config_obj.train.batch_size,
@@ -11,6 +11,8 @@ from torch.utils.data import DataLoader
 # =============================================================================

 # Transforms
+from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
+
 from ml_lib.point_toolset.point_io import BatchToData
 from ml_lib.utils.model_io import SavedLightningModels

@@ -18,21 +20,12 @@ from ml_lib.utils.model_io import SavedLightningModels
 # Datasets
 from datasets.shapenet import ShapeNetPartSegDataset
 from models import PointNet2
-from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
-    label2color
+from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color
 from utils.project_settings import GlobalVar


-def prepare_dataloader(config_obj):
-    dataset = ShapeNetPartSegDataset(config_obj.data.root, split=GlobalVar.data_split.test,
-                                     setting=GlobalVar.settings[config_obj.model.type])
-    # noinspection PyTypeChecker
-    return DataLoader(dataset, batch_size=config_obj.train.batch_size,
-                      num_workers=config_obj.data.worker, shuffle=False)
-
-
 def restore_logger_and_model(log_dir):
-    model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
+    model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-5)
     model = model.restore()
     if torch.cuda.is_available():
         model.cuda()
@@ -40,26 +33,30 @@ def restore_logger_and_model(log_dir):
         model.cpu()
     return model


 def predict_prim_type(input_pc, model):

-    input_data = dict(norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)).unsqueeze(0),
-                      pos=torch.tensor(input_pc[:, 0:3]).unsqueeze(0),
-                      )
+    input_data = dict(
+        norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)).unsqueeze(0),
+        pos=torch.tensor(input_pc[:, 0:3]).unsqueeze(0),
+    )

     batch_to_data = BatchToData()

     data = batch_to_data(input_data)
     y = loaded_model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu'))
-    y_primary = torch.argmax(y.main_out, dim=-1).squeeze().cpu().numpy()
+    y_primary = torch.argmax(y.main_out, dim=-1).cpu().numpy()

-    return np.concatenate((input_pc, y_primary.reshape(-1,1)), axis=1)
+    if input_pc.shape[1] > 6:
+        input_pc = input_pc[:, :6]
+    return np.concatenate((input_pc, y_primary.reshape(-1, 1)), axis=-1)


 if __name__ == '__main__':

     input_pc_path = Path('data') / 'pc' / 'test.xyz'

-    model_path = Path('output') / 'PN2' / 'PN_26512907a2de0664bfad2349a6bffee3' / 'version_0'
+    model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0'
     # config_filename = 'config.ini'
     # config = ThisConfig()
     # config.read_file((Path(model_path) / config_filename).open('r'))
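Note: trimming input_pc to its first six columns drops any previously appended one-hot columns before the fresh prediction column is attached, keeping the output at a fixed xyz + normal + label width. np.float here is just the builtin float alias; it works on the pinned numpy==1.18.5 but is deprecated from NumPy 1.20 on, where plain float is the drop-in replacement.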
@@ -71,8 +68,9 @@ if __name__ == '__main__':
     # input_pc = normalize_pointcloud(input_pc)

     # TEST DATASET
-    test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
-                                          npoints=1024, refresh=True)
+    transforms = Compose([NormalizeScale(), ])
+    test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=True,
+                                          refresh=True, transform=transforms)

     grid_clusters = cluster_cubes(test_dataset[0], [3, 3, 3], max_points_per_cluster=1024)

@@ -84,8 +82,7 @@ if __name__ == '__main__':

     pc_with_prim_type = predict_prim_type(grid_cluster_pc, loaded_model)

-    #pc_with_prim_type = polytopes_to_planes(pc_with_prim_type)
+    # pc_with_prim_type = polytopes_to_planes(pc_with_prim_type)

     pc_with_prim_type = append_onehotencoded_type(pc_with_prim_type)

     pc = ps.register_point_cloud("points_" + str(i), pc_with_prim_type[:, :3], radius=0.01)
@ -1,3 +1,5 @@
|
|||||||
|
from abc import ABC
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
||||||
@@ -7,28 +9,28 @@ from ml_lib.modules.util import LightningBaseModule, F_x
 from ml_lib.point_toolset.point_io import BatchToData


-class _PointNetCore(LightningBaseModule):
+class _PointNetCore(LightningBaseModule, ABC):

     def __init__(self, hparams):
         super(_PointNetCore, self).__init__(hparams=hparams)

         # Transforms
         # =============================================================================
-        transforms = Compose([NormalizeScale(), RandomFlip(0, p=0.8), ])
-        self.batch_to_data = BatchToData(transforms=transforms)
+        self.batch_to_data = BatchToData(transforms=None)

         # Model Paramters
         # =============================================================================
         # Additional parameters
+        self.cord_dims = 6 if self.params.normals_as_cords else 3

         # Modules
-        self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
-        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
-        self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
+        self.sa1_module = SAModule(0.2, 0.2, MLP([self.cord_dims, 64, 64, 128]))
+        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.cord_dims, 128, 128, 256]))
+        self.sa3_module = GlobalSAModule(MLP([256 + self.cord_dims, 256, 512, 1024]), channels=self.cord_dims)

         self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
         self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
-        self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))
+        self.fp1_module = FPModule(3, MLP([128, 128, 128, 128]))

         self.lin1 = torch.nn.Linear(128, 128)
         self.lin2 = torch.nn.Linear(128, 128)
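Note: cord_dims threads the normals-as-coordinates choice through the set-abstraction widths: each SAModule MLP consumes its incoming features plus the 3- or 6-dim point coordinates. Dropping the "+ 3" from the fp1 MLP fits the same change: with normals folded into the coordinates there are no separate per-point input features left for FP1 to concatenate, only the 128 skip channels. Channel arithmetic under that assumption, for cord_dims = 6:

    # sa1: no features yet  -> MLP in = cord_dims        = 6
    # sa2: 128 features     -> MLP in = 128 + cord_dims  = 134
    # sa3: 256 features     -> MLP in = 256 + cord_dims  = 262
    # fp1: 128 skip channels, no raw per-point features -> MLP in = 128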
@@ -2,6 +2,7 @@ from argparse import Namespace

 import torch
 from torch import nn
+from torch_geometric.transforms import Compose, RandomFlip, FixedPoints, RandomTranslate, NormalizeScale

 from datasets.shapenet import ShapeNetPartSegDataset
 from models._point_net_2 import _PointNetCore
@@ -21,21 +22,40 @@ class PointNet2(BaseValMixin,
     def __init__(self, hparams):
         super(PointNet2, self).__init__(hparams=hparams)

+        # Dataset
+        # =============================================================================
+        # rot_max_angle = 15
+        trans_max_distance = 0.01
+        transforms = Compose(
+            [
+                RandomFlip(0, p=0.8),
+                FixedPoints(self.params.npoints),
+                # This is not available with 6-dim cords
+                # RandomRotate(rot_max_angle, 0), RandomRotate(rot_max_angle, 1), RandomRotate(rot_max_angle, 2),
+                RandomTranslate(trans_max_distance),
+                NormalizeScale()
+                # NormalizePositions()
+            ]
+        )
+
         # Dataset
         # =============================================================================
         self.dataset = self.build_dataset(ShapeNetPartSegDataset,
                                           collate_per_segment=True,
-                                          npoints=self.params.npoints
+                                          transform=transforms,
+                                          cluster_type=self.params.cluster_type,
+                                          refresh=self.params.refresh,
+                                          poly_as_plane=self.params.poly_as_plane
                                           )

         # Model Paramters
         # =============================================================================
         # Additional parameters
-        self.n_classes = len(GlobalVar.classes)
+        self.n_classes = len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2)

         # Modules
         self.point_net_core = ()
-        self.lin3 = torch.nn.Linear(128, len(GlobalVar.classes))
+        self.lin3 = torch.nn.Linear(128, self.n_classes)

         # Utility
         self.log_softmax = nn.LogSoftmax(dim=-1)
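Note: point resampling moves out of the dataset (the removed npoints argument) and into the transform pipeline via FixedPoints, making the point count an augmentation-time decision. A standalone sketch of how such a pipeline behaves (shapes illustrative):

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import (Compose, RandomFlip, FixedPoints,
                                            RandomTranslate, NormalizeScale)

    transforms = Compose([RandomFlip(0, p=0.8), FixedPoints(1024),
                          RandomTranslate(0.01), NormalizeScale()])
    cloud = Data(pos=torch.rand(5000, 3), norm=torch.rand(5000, 3),
                 y=torch.zeros(5000, dtype=torch.long))
    out = transforms(cloud)  # pos, norm and y are all resampled to 1024 rows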
@@ -53,7 +73,11 @@ class PointNet2(BaseValMixin,
         idendifiers for all nodes of all graphs/pointclouds in the batch. See
         pytorch_gemometric documentation for more information
         """
-        sa0_out = (data.norm, data.pos, data.batch)
+        if not self.params.normals_as_cords:
+            sa0_out = (data.norm, data.pos, data.batch)
+        else:
+            pos_cat_norm = torch.cat((data.pos, data.norm), dim=-1)
+            sa0_out = (None, pos_cat_norm, data.batch)
         tensor = super(PointNet2, self).forward(sa0_out)
         tensor = self.lin3(tensor)
         tensor = self.log_softmax(tensor)
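Note: sa0_out follows the (features, positions, batch) triple that the PointNet++ set-abstraction stack expects; with normals_as_cords the feature slot is None and the position tensor becomes 6-dim (xyz concatenated with the normal), matching cord_dims = 6 in _PointNetCore above.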
@@ -21,7 +21,7 @@ if __name__ == '__main__':
     # bias, activation, model, norm, max_epochs

     for arg_dict in [pn2]:
-        for seed in range(10):
+        for seed in range(2):
             arg_dict.update(main_seed=seed)

             config = config.update(arg_dict)
@@ -28,7 +28,7 @@ jsonschema==3.2.0
 kiwisolver==1.2.0
 llvmlite==0.32.1
 Markdown==3.2.2
-matplotlib==3.2.1
+matplotlib~=3.2.2
 monotonic==1.5
 msgpack==1.0.0
 msgpack-python==0.5.6
@@ -38,7 +38,7 @@ networkx==2.4
 numba==0.49.1
 numpy==1.18.5
 oauthlib==3.1.0
-pandas==1.0.4
+pandas~=1.0.5
 Pillow==7.1.2
 plyfile==0.7.2
 polyscope==0.1.2
@@ -74,14 +74,14 @@ tensorboard-plugin-wit==1.6.0.post3
 test-tube==0.7.5
 threadpoolctl==2.1.0
 tifffile==2020.6.3
-torch==1.4.0+cpu
+torch~=1.5.1+cpu
 torch-cluster==1.5.4
 torch-geometric==1.4.3
 torch-scatter==2.0.4
 torch-sparse==0.6.1
 torchcontrib==0.0.2
-torchvision==0.5.0
+torchvision~=0.4.1
-tqdm==4.45.0
+tqdm~=4.46.1
 typing-extensions==3.7.4.2
 urllib3==1.25.9
 webcolors==1.11.1
@@ -89,3 +89,5 @@ websocket-client==0.57.0
 Werkzeug==1.0.1
 xmltodict==0.12.0
 zipp==3.1.0
+
+open3d~=0.10.0.1
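Note: several pins switch from == to the compatible-release operator ~= (PEP 440); pandas~=1.0.5, for example, accepts any 1.0.x at or above 1.0.5 but not 1.1. One oddity worth flagging: torch moves up to ~=1.5.1+cpu while torchvision moves down to ~=0.4.1, although torchvision 0.4.x was built against torch 1.2/1.3; the release matching torch 1.5.1 would be torchvision 0.6.1.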
@@ -110,10 +110,10 @@ class BaseValMixin:
        #######################################################################################
        #
        # INIT
-        y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
+        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
         y_true_one_hot = to_one_hot(y_true, self.n_classes)

-        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
+        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
         y_pred_max = np.argmax(y_pred, axis=1)

         class_names = {val: key for key, val in GlobalVar.classes.items()}
@@ -134,7 +134,7 @@ class BaseValMixin:
         fpr = dict()
         tpr = dict()
         roc_auc = dict()
-        for i in range(len(GlobalVar.classes)):
+        for i in range(self.n_classes):
             fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
             roc_auc[i] = auc(fpr[i], tpr[i])

@@ -143,15 +143,15 @@ class BaseValMixin:
         roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

         # First aggregate all false positive rates
-        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(len(GlobalVar.classes))]))
+        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.n_classes)]))

         # Then interpolate all ROC curves at this points
         mean_tpr = np.zeros_like(all_fpr)
-        for i in range(len(GlobalVar.classes)):
+        for i in range(self.n_classes):
             mean_tpr += interp(all_fpr, fpr[i], tpr[i])

         # Finally average it and compute AUC
-        mean_tpr /= len(GlobalVar.classes)
+        mean_tpr /= self.n_classes

         fpr["macro"] = all_fpr
         tpr["macro"] = mean_tpr
@@ -170,7 +170,7 @@ class BaseValMixin:
         colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
                         'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], )

-        for i, color in zip(range(len(GlobalVar.classes)), colors):
+        for i, color in zip(range(self.n_classes), colors):
             plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})')

         plt.plot([0, 1], [0, 1], 'k--', lw=2)
@@ -236,7 +236,7 @@ class DatasetMixin:
                                        **kwargs),

             # TEST DATASET
-            test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.test,
+            test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict,
                                        **kwargs),
         )
         )
@@ -3,27 +3,15 @@ from sklearn.cluster import DBSCAN

 import open3d as o3d

-from pyod.models.knn import KNN
-from pyod.models.sod import SOD
-from pyod.models.abod import ABOD
-from pyod.models.sos import SOS
-from pyod.models.pca import PCA
-from pyod.models.ocsvm import OCSVM
-from pyod.models.mcd import MCD
 from pyod.models.lof import LOF
-from pyod.models.cof import COF
-from pyod.models.cblof import CBLOF
-from pyod.models.loci import LOCI
-from pyod.models.hbos import HBOS
-from pyod.models.lscp import LSCP
-from pyod.models.feature_bagging import FeatureBagging
 from torch_geometric.data import Data

 from utils.project_settings import Classes


 def polytopes_to_planes(pc):
-    pc[(pc[:, 6] == float(Classes.Box)) | (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane);
+    pc[(pc[:, 6] == float(Classes.Box)) or (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane)
     return pc

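Note: dropping the stray semicolon is fine, but replacing | with or is almost certainly a regression: both sides are element-wise boolean NumPy masks, and `or` forces each through bool(), which raises "The truth value of an array with more than one element is ambiguous". Element-wise union needs | (or np.logical_or):

    import numpy as np
    a = np.array([True, False])
    b = np.array([False, False])
    # a or b        # ValueError: truth value of an array ... is ambiguous
    mask = a | b    # element-wise union, what the old code computed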
@@ -49,7 +37,7 @@ def mini_color_table(index, norm=True):
 def cluster2Color(cluster, cluster_idx):
     colors = np.zeros(shape=(len(cluster), 3))
     point_idx = 0
-    for point in cluster:
+    for _ in cluster:
         colors[point_idx, :] = mini_color_table(cluster_idx)
         point_idx += 1

@@ -87,6 +75,8 @@ def write_pointcloud(file, pc, numCols=6):


 def farthest_point_sampling(pts, K):
+    if isinstance(pts, Data):
+        pts = pts.pos.numpy()
     if pts.shape[0] < K:
         return pts

@@ -119,7 +109,15 @@ def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_

     if isinstance(data, Data):
         import torch
-        data = torch.cat((data.pos, data.norm, data.y.double().unsqueeze(-1)), dim=-1).numpy()
+        candidate_list = list()
+        if data.pos is not None:
+            candidate_list.append(data.pos)
+        if data.norm is not None:
+            candidate_list.append(data.norm)
+        if data.y is not None:
+            candidate_list.append(data.y.double().unsqueeze(-1))
+
+        data = torch.cat(candidate_list, dim=-1).numpy()

     max = data[:, :3].max(axis=0)
     max += max * 0.01
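Note: building the array from whichever of pos/norm/y is present means clouds without normals or labels no longer crash the concatenation; the trade-off is that downstream column positions (e.g. the label column assumed at index 6 elsewhere in this file) silently shift when an attribute is missing.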
@@ -1,7 +1,5 @@
 from argparse import Namespace

-from ml_lib.utils.config import Config
-

 class DataClass(Namespace):

@@ -18,18 +16,19 @@ class DataClass(Namespace):
         return f'{self.__class__.__name__}({self.__dict__().__repr__()})'

     def __getitem__(self, item):
-        return self.__getattribute__(item)
+        return self.__dict__()[item]


 class Classes(DataClass):

     # Object Classes for Point Segmentation
     Sphere = 0
     Cylinder = 1
     Cone = 2
-    Box = 3
-    Polytope = 4
+    Box = 3  # All SubTypes of Planes
+    Polytope = 4  #
     Torus = 5
-    Plane = 6
+    Plane = 6  #


 class Settings(DataClass):
@@ -38,6 +37,11 @@ class Settings(DataClass):
     PN2 = 'pc'


+class ClusterTypes(DataClass):
+    prim = 'prim'
+    grid = 'grid'
+    none = ''
+
 class DataSplit(DataClass):
     # DATA SPLIT OPTIONS
     train = 'train'
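Note: ClusterTypes backs the new --data_cluster_type flag; CustomShapeNet matches its value against filenames with `self.cluster_type not in pointcloud.name`. The empty-string member would match every file, since '' is a substring of any string, which is presumably why the constructor maps falsy values to the 'pc' tag instead:

    name = 'grid_cloud_0.xyz'
    assert '' in name            # empty string matches everything
    cluster_type = '' or 'pc'    # falsy value falls back to 'pc', as in __init__
    assert cluster_type == 'pc'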
@@ -59,4 +63,4 @@ class GlobalVar(DataClass):

     prim_count = -1

     settings = Settings()