eval running - offline logger implemented -> Test it!

This commit is contained in:
Si11ium 2020-05-30 18:12:42 +02:00
parent ba7c0280ae
commit 8d0577b756
9 changed files with 212 additions and 102 deletions

View File

@ -11,7 +11,7 @@ main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
# Main Parameters # Main Parameters
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="") main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="") main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="") main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Project # Project
@ -21,7 +21,6 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
# Data Parameters # Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="") main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
@ -29,25 +28,28 @@ main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=
# main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="") # main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
# main_arg_parser.add_argument("--transformations_normalize", type=strtobool, default=False, help="") # main_arg_parser.add_argument("--transformations_normalize", type=strtobool, default=False, help="")
# Transformations # Training
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="") main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="") main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="") main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="") main_arg_parser.add_argument("--train_batch_size", type=int, default=10, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="") main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="") main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="") main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--train_opt_reset_interval", type=strtobool, default=False, help="")
# Model # Model
main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="") main_arg_parser.add_argument("--model_type", type=str, default="PN2", help="")
main_arg_parser.add_argument("--model_norm_as_feature", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="") main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=False, help="") main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="") main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
# Model 2: Layer Specific Stuff # Model 2: Layer Specific Stuff
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
main_arg_parser.add_argument("--model_features", type=int, default=16, help="") main_arg_parser.add_argument("--model_features", type=int, default=16, help="")
# Parse it # Parse it

View File

@ -5,7 +5,7 @@ from abc import ABC
from pathlib import Path from pathlib import Path
from torch.utils.data import Dataset from torch.utils.data import Dataset
from ml_lib.point_toolset.sampling import FarthestpointSampling from ml_lib.point_toolset.sampling import FarthestpointSampling, RandomSampling
import numpy as np import numpy as np
@ -22,16 +22,21 @@ class _Point_Dataset(ABC, Dataset):
raise NotImplementedError raise NotImplementedError
headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx'] headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx']
samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling)
def __init__(self, root=Path('data'), sampling_k=2048, transforms=None, load_preprocessed=True, *args, **kwargs): def __init__(self, root=Path('data'), norm_as_feature=True, sampling_k=2048, sampling='rnd',
transforms=None, load_preprocessed=True, split='train', dense_output=False, *args, **kwargs):
super(_Point_Dataset, self).__init__() super(_Point_Dataset, self).__init__()
self.dense_output = dense_output
self.split = split
self.norm_as_feature = norm_as_feature
self.load_preprocessed = load_preprocessed self.load_preprocessed = load_preprocessed
self.transforms = transforms if transforms else lambda x: x self.transforms = transforms if transforms else lambda x: x
self.sampling_k = sampling_k self.sampling_k = sampling_k
self.sampling = FarthestpointSampling(K=self.sampling_k) self.sampling = self.samplers[sampling](K=self.sampling_k)
self.root = Path(root) self.root = Path(root)
self.raw = self.root / 'raw' self.raw = self.root / 'raw' / self.split
self.processed_ext = '.pik' self.processed_ext = '.pik'
self.raw_ext = '.xyz' self.raw_ext = '.xyz'
self.processed = self.root / self.setting self.processed = self.root / self.setting

View File

@ -9,6 +9,7 @@ from ._point_dataset import _Point_Dataset
class FullCloudsDataset(_Point_Dataset): class FullCloudsDataset(_Point_Dataset):
setting = 'pc' setting = 'pc'
split: str
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(FullCloudsDataset, self).__init__(*args, **kwargs) super(FullCloudsDataset, self).__init__(*args, **kwargs)
@ -21,13 +22,15 @@ class FullCloudsDataset(_Point_Dataset):
with processed_file_path.open('rb') as processed_file: with processed_file_path.open('rb') as processed_file:
pointcloud = pickle.load(processed_file) pointcloud = pickle.load(processed_file)
points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
),
axis=-1)
# When yopu want to return points and normal seperately
# normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['label']
sample_idxs = self.sampling(points)
return points[sample_idxs].astype(np.float), label[sample_idxs].astype(np.int) position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['label']
sample_idxs = self.sampling(position)
return (normal[sample_idxs].astype(np.float),
position[sample_idxs].astype(np.float),
label[sample_idxs].astype(np.int))

View File

@ -26,7 +26,7 @@ class FullCloudsDataset(_Point_Dataset):
# When yopu want to return points and normal seperately # When yopu want to return points and normal seperately
# normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1) # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = np.stack((pointcloud['label'], pointcloud['cl_idx'])) label = pointcloud['cl_idx']
sample_idxs = self.sampling(points) sample_idxs = self.sampling(points)
return points[sample_idxs], label[sample_idxs] return points[sample_idxs], label[sample_idxs]

View File

@ -26,7 +26,7 @@ class FullCloudsDataset(_Point_Dataset):
# When yopu want to return points and normal seperately # When yopu want to return points and normal seperately
# normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1) # normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = np.stack((pointcloud['label'], pointcloud['cl_idx'])) label = pointcloud['cl_idx']
sample_idxs = self.sampling(points) sample_idxs = self.sampling(points)
return points[sample_idxs], label[sample_idxs] return points[sample_idxs], label[sample_idxs]

View File

@ -66,8 +66,8 @@ def run_lightning_loop(config_obj):
trainer.fit(model) trainer.fit(model)
# Save the last state & all parameters # Save the last state & all parameters
trainer.save_checkpoint(config_obj.exp_path.log_dir / 'weights.ckpt') trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
model.save_to_disk(config_obj.exp_path) model.save_to_disk(logger.log_dir)
# Evaluate It # Evaluate It
if config_obj.main.eval: if config_obj.main.eval:

View File

@ -1,15 +1,16 @@
from argparse import Namespace from argparse import Namespace
import torch.nn.functional as F
import torch import torch
from torch.optim import Adam
from torch import nn from torch import nn
from torch_geometric.data import Data
from datasets.full_pointclouds import FullCloudsDataset from datasets.full_pointclouds import FullCloudsDataset
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
from ml_lib.modules.util import LightningBaseModule, F_x from ml_lib.modules.util import LightningBaseModule, F_x
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar
class PointNet2(BaseValMixin, class PointNet2(BaseValMixin,
@ -31,31 +32,27 @@ class PointNet2(BaseValMixin,
# ============================================================================= # =============================================================================
# Additional parameters # Additional parameters
self.in_shape = self.dataset.train_dataset.sample_shape
self.channels = self.in_shape[-1]
# Modules # Modules
self.sa1_module = SAModule(0.5, 0.2, MLP([self.channels, 64, 64, 128])) self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
self.sa2_module = SAModule(0.25, 0.4, MLP([128 + self.channels, 128, 128, 256])) self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
self.sa3_module = GlobalSAModule(MLP([256 + self.channels, 256, 512, 1024])) self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
self.lin1 = nn.Linear(1024, 512) self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
self.lin2 = nn.Linear(512, 256) self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
self.lin3 = nn.Linear(256, 10) self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))
self.lin1 = torch.nn.Linear(128, 128)
self.lin2 = torch.nn.Linear(128, 128)
self.lin3 = torch.nn.Linear(128, len(GlobalVar.classes))
# Utility # Utility
self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None) self.dropout = nn.Dropout(self.params.dropout) if self.params.dropout else F_x(None)
self.activation = self.params.activation() self.activation = self.params.activation()
self.log_softmax = nn.LogSoftmax(dim=-1) self.log_softmax = nn.LogSoftmax(dim=-1)
def configure_optimizers(self):
return Adam(self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
def forward(self, data, **kwargs): def forward(self, data, **kwargs):
""" """
data: a batch of input, torch.Tensor or torch_geometric.data.Data type data: a batch of input torch_geometric.data.Data type
- torch.Tensor: (batch_size, 3, num_points), as common batch input
- torch_geometric.data.Data, as torch_geometric batch input: - torch_geometric.data.Data, as torch_geometric batch input:
data.x: (batch_size * ~num_points, C), batch nodes/points feature, data.x: (batch_size * ~num_points, C), batch nodes/points feature,
~num_points means each sample can have different number of points/nodes ~num_points means each sample can have different number of points/nodes
@ -66,37 +63,22 @@ class PointNet2(BaseValMixin,
idendifiers for all nodes of all graphs/pointclouds in the batch. See idendifiers for all nodes of all graphs/pointclouds in the batch. See
pytorch_gemometric documentation for more information pytorch_gemometric documentation for more information
""" """
dense_input = True if isinstance(data, torch.Tensor) else False
if dense_input:
# Convert to torch_geometric.data.Data type
# data = data.transpose(1, 2).contiguous()
batch_size, N, _ = data.shape # (batch_size, num_points, 6)
pos = data.view(batch_size*N, -1)
batch = torch.zeros((batch_size, N), device=pos.device, dtype=torch.long)
for i in range(batch_size):
batch[i] = i
batch = batch.view(-1)
data = Data()
data.pos, data.batch = pos, batch
if not hasattr(data, 'x'):
data.x = None
sa0_out = (data.x, data.pos, data.batch) sa0_out = (data.x, data.pos, data.batch)
sa1_out = self.sa1_module(*sa0_out) sa1_out = self.sa1_module(*sa0_out)
sa2_out = self.sa2_module(*sa1_out) sa2_out = self.sa2_module(*sa1_out)
sa3_out = self.sa3_module(*sa2_out) sa3_out = self.sa3_module(*sa2_out)
tensor, pos, batch = sa3_out fp3_out = self.fp3_module(*sa3_out, *sa2_out)
fp2_out = self.fp2_module(*fp3_out, *sa1_out)
tensor, _, _ = self.fp1_module(*fp2_out, *sa0_out)
tensor = tensor.float() tensor = tensor.float()
tensor = self.lin1(tensor)
tensor = self.activation(tensor) tensor = self.activation(tensor)
tensor = self.lin1(tensor)
tensor = self.dropout(tensor) tensor = self.dropout(tensor)
tensor = self.lin2(tensor) tensor = self.lin2(tensor)
tensor = self.activation(tensor)
tensor = self.dropout(tensor) tensor = self.dropout(tensor)
tensor = self.lin3(tensor) tensor = self.lin3(tensor)
tensor = self.log_softmax(tensor) tensor = self.log_softmax(tensor)

View File

@ -1,10 +1,17 @@
from collections import defaultdict from collections import defaultdict
from itertools import cycle
from abc import ABC from abc import ABC
from argparse import Namespace from argparse import Namespace
import torch import torch
import numpy as np
from numpy import interp
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay, f1_score, roc_auc_score
import matplotlib.pyplot as plt
from torch import nn from torch import nn
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -12,7 +19,9 @@ from torchcontrib.optim import SWA
from torchvision.transforms import Compose from torchvision.transforms import Compose
from ml_lib.modules.util import LightningBaseModule from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot
from ml_lib.utils.transforms import ToTensor from ml_lib.utils.transforms import ToTensor
from ml_lib.point_toolset.point_io import BatchToData
from .project_config import GlobalVar from .project_config import GlobalVar
@ -43,16 +52,21 @@ class BaseOptimizerMixin:
class BaseTrainMixin: class BaseTrainMixin:
# Absolute Error
absolute_loss = nn.L1Loss() absolute_loss = nn.L1Loss()
# negative Log Likelyhood
nll_loss = nn.NLLLoss() nll_loss = nn.NLLLoss()
# Binary Cross Entropy
bce_loss = nn.BCELoss() bce_loss = nn.BCELoss()
# Batch To Data
batch_to_data = BatchToData()
def training_step(self, batch_xy, batch_nb, *_, **__): def training_step(self, batch_pos_x_y, batch_nb, *_, **__):
assert isinstance(self, LightningBaseModule) assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy data = self.batch_to_data(*batch_pos_x_y)
y = self(batch_x).main_out y = self(data).main_out
bce_loss = self.bce_loss(y, batch_y) nll_loss = self.nll_loss(y, data.y)
return dict(loss=bce_loss, log=dict(batch_nb=batch_nb)) return dict(loss=nll_loss, log=dict(batch_nb=batch_nb))
def training_epoch_end(self, outputs): def training_epoch_end(self, outputs):
assert isinstance(self, LightningBaseModule) assert isinstance(self, LightningBaseModule)
@ -66,17 +80,20 @@ class BaseTrainMixin:
class BaseValMixin: class BaseValMixin:
# Absolute Error
absolute_loss = nn.L1Loss() absolute_loss = nn.L1Loss()
# negative Log Likelyhood
nll_loss = nn.NLLLoss() nll_loss = nn.NLLLoss()
# Binary Cross Entropy
bce_loss = nn.BCELoss() bce_loss = nn.BCELoss()
def validation_step(self, batch_xy, batch_idx, _, *__, **___): def validation_step(self, batch_pos_x_y, batch_idx, *_, **__):
assert isinstance(self, LightningBaseModule) assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy data = self.batch_to_data(*batch_pos_x_y)
y = self(batch_x).main_out y = self(data).main_out
val_bce_loss = self.bce_loss(y, batch_y) nll_loss = self.nll_loss(y, data.y)
return dict(val_bce_loss=val_bce_loss, return dict(val_nll_loss=nll_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y) batch_idx=batch_idx, y=y, batch_y=data.y)
def validation_epoch_end(self, outputs, *_, **__): def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule) assert isinstance(self, LightningBaseModule)
@ -84,25 +101,107 @@ class BaseValMixin:
# In case of Multiple given dataloader this will outputs will be: list[list[dict[]]] # In case of Multiple given dataloader this will outputs will be: list[list[dict[]]]
# for output_idx, output in enumerate(outputs): # for output_idx, output in enumerate(outputs):
# else:list[dict[]] # else:list[dict[]]
keys = list(outputs.keys()) keys = list(outputs[0].keys())
# Add Every Value das has a "loss" in it, by calc. mean over all occurences. # Add Every Value das has a "loss" in it, by calc. mean over all occurences.
summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key] summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs])) for output in outputs]))
for key in keys if 'loss' in key} for key in keys if 'loss' in key}
) )
"""
# Additional Score like the unweighted Average Recall: #######################################################################################
# UnweightedAverageRecall # Additional Score - UAR - ROC - Conf. Matrix - F1
#######################################################################################
#
# INIT
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy() y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
y_true_one_hot = to_one_hot(y_true)
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy() y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
y_pred_max = np.argmax(y_pred, axis=1)
y_pred = (y_pred >= 0.5).astype(np.float32) class_names = {val: key for key, val in GlobalVar.classes.__dict__().items()}
######################################################################################
#
# F1 SCORE
micro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='micro', sample_weight=None,
zero_division=True)
macro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='macro', sample_weight=None,
zero_division=True)
summary_dict['log'].update(dict(micro_f1_score=micro_f1_score, macro_f1_score=macro_f1_score))
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro', #######################################################################################
sample_weight=None, zero_division='warn') #
# ROC Curve
summary_dict['log'].update({f'uar_score': uar_score}) # Compute ROC curve and ROC area for each class
""" fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(len(GlobalVar.classes)):
fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(len(GlobalVar.classes))]))
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(len(GlobalVar.classes)):
mean_tpr += interp(all_fpr, fpr[i], tpr[i])
# Finally average it and compute AUC
mean_tpr /= len(GlobalVar.classes)
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
# Plot all ROC curves
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
label=f'micro ROC ({round(roc_auc["micro"], 2)})',
color='deeppink', linestyle=':', linewidth=4)
plt.plot(fpr["macro"], tpr["macro"],
label=f'macro ROC({round(roc_auc["macro"], 2)})]',
color='navy', linestyle=':', linewidth=4)
colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], )
for i, color in zip(range(len(GlobalVar.classes)), colors):
plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i],2 )})')
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc="lower right")
self.logger.log_image('ROC', image=plt.gcf(), step=self.current_epoch)
plt.clf()
#######################################################################################
#
# ROC SCORE
macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
average="macro")
summary_dict['log'].update(macro_roc_auc_ovr=macro_roc_auc_ovr)
#######################################################################################
#
# Confusion matrix
cm = confusion_matrix(y_true, y_pred_max, labels=[class_name for class_name in class_names], normalize='all')
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(include_values=True)
self.logger.log_image('Confusion Matrix', image=plt.gcf(), step=self.current_epoch)
return summary_dict return summary_dict
@ -122,18 +221,17 @@ class DatasetMixin:
dataset = Namespace( dataset = Namespace(
**dict( **dict(
# TRAIN DATASET # TRAIN DATASET
train_dataset=dataset_class(self.params.root, setting=GlobalVar.train, train_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.train,
transforms=transforms transforms=transforms
),
# VALIDATION DATASET
val_dataset=dataset_class(self.params.root, setting=GlobalVar.vali,
), ),
# TEST DATASET # VALIDATION DATASET
test_dataset=dataset_class(self.params.root, setting=GlobalVar.test, val_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.devel,
), ),
# TEST DATASET
test_dataset=dataset_class(self.params.root, setting=GlobalVar.data_split.test,
),
) )
) )
return dataset return dataset

View File

@ -3,24 +3,44 @@ from argparse import Namespace
from ml_lib.utils.config import Config from ml_lib.utils.config import Config
class GlobalVar(Namespace): class DataClass(Namespace):
# Labels for classes
LEFT = 1
RIGHT = 0
WRONG = -1
# Colors for img files def __len__(self):
WHITE = 255 return len(self.__dict__())
BLACK = 0
def __dict__(self):
return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
class Classes(DataClass):
# Object Classes for Point Segmentation
Sphere = 0
Cylinder = 1
Cone = 2
Box = 3
Polytope = 4
Torus = 5
Plane = 6
class DataSplit(DataClass):
# DATA SPLIT OPTIONS
train = 'train',
devel = 'devel',
test = 'test'
class GlobalVar(DataClass):
# Variables for plotting # Variables for plotting
PADDING = 0.25 PADDING = 0.25
DPI = 50 DPI = 50
# DATAOPTIONS data_split = DataSplit()
train ='train',
vali ='vali', classes = Classes()
test ='test'
from models import * from models import *