diff --git a/_parameters.py b/_parameters.py index dadc253..e33de02 100644 --- a/_parameters.py +++ b/_parameters.py @@ -21,9 +21,9 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten # Data Parameters main_arg_parser.add_argument("--data_worker", type=int, default=10, help="") -main_arg_parser.add_argument("--data_sampling_k", type=int, default=1024, help="") +main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="") -main_arg_parser.add_argument("--data_dataset_type", type=str, default='GridClusters', help="") +main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="") main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") # Transformations diff --git a/datasets/_point_dataset.py b/datasets/_point_dataset.py deleted file mode 100644 index c00efff..0000000 --- a/datasets/_point_dataset.py +++ /dev/null @@ -1,89 +0,0 @@ -import pickle -from collections import defaultdict - -from abc import ABC -from pathlib import Path - -from torch.utils.data import Dataset, ConcatDataset -from ml_lib.point_toolset.sampling import FarthestpointSampling, RandomSampling - -import numpy as np - - -class _Point_Dataset(ABC, Dataset): - - @property - def name(self): - raise NotImplementedError - - @property - def sample_shape(self): - # FixMe: This does not work when more then x/y tuples are returned - return self[0][0].shape - - headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx'] - samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling) - - def __init__(self, root=Path('data'), norm_as_feature=True, sampling_k=2048, sampling='rnd', - transforms=None, load_preprocessed=True, split='train', *args, **kwargs): - super(_Point_Dataset, self).__init__() - - self.setting: str - self.split = split - self.norm_as_feature = norm_as_feature - self.load_preprocessed = load_preprocessed - self.transforms = transforms if transforms else lambda x: x - self.sampling_k = sampling_k - self.sampling = self.samplers[sampling](K=self.sampling_k) - self.root = Path(root) - self.raw = self.root / 'raw' / self.split - self.processed_ext = '.pik' - self.raw_ext = '.xyz' - self.processed = self.root / self.setting - self.processed.mkdir(parents=True, exist_ok=True) - - self._files = list(self.raw.glob(f'*{self.setting}*')) - - def _read_or_load(self, item): - raw_file_path = self._files[item] - processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext) - - if not self.load_preprocessed: - processed_file_path.unlink(missing_ok=True) - if not processed_file_path.exists(): - pointcloud = defaultdict(list) - with raw_file_path.open('r') as raw_file: - for row in raw_file: - values = [float(x) for x in row.strip().split(' ')] - for header, value in zip(self.headers, values): - pointcloud[header].append(value) - for key in pointcloud.keys(): - pointcloud[key] = np.asarray(pointcloud[key]) - with processed_file_path.open('wb') as processed_file: - pickle.dump(pointcloud, processed_file) - return processed_file_path - - def __len__(self): - raise NotImplementedError - - def __getitem__(self, item): - processed_file_path = self._read_or_load(item) - - with processed_file_path.open('rb') as processed_file: - pointcloud = pickle.load(processed_file) - - position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1) - - normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1) - - label = pointcloud['label'] - - cl_label = pointcloud['cl_idx'] - - sample_idxs = self.sampling(position) - - return (normal[sample_idxs].astype(np.float), - position[sample_idxs].astype(np.float), - label[sample_idxs].astype(np.int), - cl_label[sample_idxs].astype(np.int) - ) diff --git a/datasets/full_pointclouds.py b/datasets/full_pointclouds.py deleted file mode 100644 index a20ed3e..0000000 --- a/datasets/full_pointclouds.py +++ /dev/null @@ -1,19 +0,0 @@ -import pickle -from collections import defaultdict - -import numpy as np - -from ._point_dataset import _Point_Dataset - - -class FullCloudsDataset(_Point_Dataset): - - split: str - name = 'FullCloudsDataset' - - def __init__(self, *args, setting='pc', **kwargs): - self.setting = setting - super(FullCloudsDataset, self).__init__(*args, **kwargs) - - def __len__(self): - return len(self._files) diff --git a/datasets/grid_clusters.py b/datasets/grid_clusters.py deleted file mode 100644 index 4053c35..0000000 --- a/datasets/grid_clusters.py +++ /dev/null @@ -1,83 +0,0 @@ -import pickle -from collections import defaultdict - -import numpy as np -from torch.utils.data import ConcatDataset -from tqdm import trange - -from ._point_dataset import _Point_Dataset - - -class GridClusters(_Point_Dataset): - - split: str - name = 'GridClusters' - - def __init__(self, *args, n_spatial_clusters=3*3*3, setting='pc', **kwargs): - self.n_spatial_clusters = n_spatial_clusters - self.setting = setting - super(GridClusters, self).__init__(*args, **kwargs) - - def __len__(self): - return len(self._files) - - def _read_or_load(self, item): - raw_file_path = self._files[item] - processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext) - - if not self.load_preprocessed: - processed_file_path.unlink(missing_ok=True) - if not processed_file_path.exists(): - # nested default dict - pointcloud = defaultdict(lambda: defaultdict(list)) - - with raw_file_path.open('r') as raw_file: - for row in raw_file: - values = [float(x) for x in row.strip().split(' ')] - for header, value in zip(self.headers, values): - pointcloud[int(values[-1])][header].append(value) - for cluster in pointcloud.keys(): - for key in pointcloud[cluster].keys(): - pointcloud[cluster][key] = np.asarray(pointcloud[cluster][key]) - pointcloud[cluster] = dict(pointcloud[cluster]) - pointcloud = dict(pointcloud) - - with processed_file_path.open('wb') as processed_file: - pickle.dump(pointcloud, processed_file) - return processed_file_path - - def __getitem__(self, item): - processed_file_path = self._read_or_load(item) - - with processed_file_path.open('rb') as processed_file: - pointcloud = pickle.load(processed_file) - - # By number Variant - # cl_idx_list = np.cumsum([[len(self) // self.n_spatial_clusters, ] * self.n_spatial_clusters]) - # cl_idx = [idx for idx, x in enumerate(cl_idx_list) if item <= x][0] - - # Random Variant - cl_idx = np.random.randint(0, len(pointcloud)) - pointcloud = pointcloud[list(pointcloud.keys())[cl_idx]] - - position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1) - - normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1) - - label = pointcloud['label'] - - cl_label = pointcloud['cl_idx'] - - sample_idxs = self.sampling(position) - while sample_idxs.shape[0] < self.sampling_k: - sample_idxs = np.concatenate((sample_idxs, sample_idxs))[:self.sampling_k] - - normal = normal[sample_idxs].astype(np.float) - position = position[sample_idxs].astype(np.float) - - normal = self.transforms(normal) - position = self.transforms(position) - return (normal, position, - label[sample_idxs].astype(np.int), - cl_label[sample_idxs].astype(np.int) - ) diff --git a/datasets/shapenet.py b/datasets/shapenet.py new file mode 100644 index 0000000..0097ce3 --- /dev/null +++ b/datasets/shapenet.py @@ -0,0 +1,212 @@ +from pathlib import Path + +import numpy as np + +from collections import defaultdict + +import os +from tqdm import tqdm +import glob + +import torch +from torch_geometric.data import InMemoryDataset +from torch_geometric.data import Data +from torch.utils.data import Dataset +import re + +from utils.project_config import Classes, DataSplit + + +def save_names(name_list, path): + with open(path, 'wb') as f: + f.writelines(name_list) + + +class CustomShapeNet(InMemoryDataset): + + categories = {key: val for val, key in Classes().items()} + modes = {key: val for val, key in DataSplit().items()} + name = 'CustomShapeNet' + + @property + def raw_dir(self): + return self.root / 'raw' + + @property + def raw_file_names(self): + return [self.mode] + + @property + def processed_dir(self): + return self.root / 'processed' + + def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None, + pre_transform=None, refresh=False, with_normals=False): + assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}' + + # Set the Dataset Parameters + self.collate_per_segment, self.mode, self.refresh = collate_per_segment, mode, refresh + self.with_normals = with_normals + root_dir = Path(root_dir) + super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter) + self.data, self.slices = self._load_dataset() + print("Initialized") + + @property + def processed_file_names(self): + return [f'{self.mode}.pt'] + + def download(self): + dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))]) + + if dir_count: + print(f'{dir_count} folders have been found....') + return dir_count + raise IOError("No raw pointclouds have been found.") + + @property + def num_classes(self): + return len(self.categories) + + def _load_dataset(self): + data, slices = None, None + filepath = self.processed_paths[0] + if self.refresh: + try: + os.remove(filepath) + print('Processed Location "Refreshed" (We deleted the Files)') + except FileNotFoundError: + print('You meant to refresh the allready processed dataset, but there were none...') + print('continue processing') + pass + + while True: + try: + data, slices = torch.load(filepath) + print('Dataset Loaded') + break + except FileNotFoundError: + self.process() + continue + return data, slices + + def _transform_and_filter(self, data): + # ToDo: ANy filter to apply? Then do it here. + if self.pre_filter is not None and not self.pre_filter(data): + data = self.pre_filter(data) + raise NotImplementedError + # ToDo: ANy transformation to apply? Then do it here. + if self.pre_transform is not None: + data = self.pre_transform(data) + raise NotImplementedError + return data + + def process(self, delimiter=' '): + datasets = defaultdict(list) + path_to_clouds = self.raw_dir / self.mode + + for pointcloud in tqdm(path_to_clouds.glob('*.xyz')): + if 'grid' not in pointcloud.name: + continue + data = None + + with pointcloud.open('r') as f: + src = defaultdict(list) + # Iterate over all rows + for row in f: + if row != '': + vals = row.rstrip().split(delimiter)[None:None] + vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals] + src[vals[-1]].append(vals) + + src = dict(src) + for key, values in src.items(): + src[key] = torch.tensor(values, dtype=torch.double).squeeze() + + if not self.collate_per_segment: + src = dict(all=torch.stack([x for x in src.values()])) + + for key, values in src.items(): + try: + points = values[:, :-2] + except IndexError: + continue + y = torch.as_tensor(values[:, -2], dtype=torch.long) + y_c = torch.as_tensor(values[:, -1], dtype=torch.long) + #################################### + # This is where you define the keys + attr_dict = dict(y=y, y_c=y_c) + if self.with_normals: + pos = points[:, :6] + norm = None + attr_dict.update(pos=pos, norm=norm) + if not self.with_normals: + pos = points[:, :3] + norm = points[:, 3:6] + attr_dict.update(pos=pos, norm=norm) + #################################### + if self.collate_per_segment: + data = Data(**attr_dict) + else: + if not data: + data = defaultdict(list) + # points=points, norm=points[:, 3:] + for key, val in attr_dict.items(): + data[key].append(val) + + data = self._transform_and_filter(data) + if self.collate_per_segment: + datasets[self.mode].append(data) + if not self.collate_per_segment: + # Todo: What is this? + datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()})) + + if datasets[self.mode]: + os.makedirs(self.processed_dir, exist_ok=True) + torch.save(self.collate(datasets[self.mode]), self.processed_paths[0]) + + def __repr__(self): + return f'{self.__class__.__name__}({len(self)})' + + +class ShapeNetPartSegDataset(Dataset): + """ + Resample raw point cloud to fixed number of points. + Map raw label from range [1, N] to [0, N-1]. + """ + + name = 'ShapeNetPartSegDataset' + + def __init__(self, root_dir, npoints=1024, mode='train', **kwargs): + super(ShapeNetPartSegDataset, self).__init__() + self.mode = mode + kwargs.update(dict(root_dir=root_dir, mode=self.mode)) + self.npoints = npoints + self.dataset = CustomShapeNet(**kwargs) + + def __getitem__(self, index): + data = self.dataset[index] + + # Resample to fixed number of points + try: + npoints = self.npoints if self.mode != 'predict' else data.pos.shape[0] + choice = np.random.choice(data.pos.shape[0], npoints, replace=False if self.mode == 'predict' else True) + except ValueError: + choice = [] + + pos, norm, y = data.pos[choice, :], data.norm[choice], data.y[choice] + + # y -= 1 if self.num_classes() in y else 0 # Map label from [1, C] to [0, C-1] + + sample = Data(**dict(pos=pos, # torch.Tensor (n, 3/6) + y=y, # torch.Tensor (n,) + norm=norm # torch.Tensor (n, 3/0) + ) + ) + return sample + + def __len__(self): + return len(self.dataset) + + def num_classes(self): + return self.dataset.num_classes diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index ffeb9c3..6bd6258 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -1,8 +1,6 @@ -from torch.utils.data import Dataset -from._point_dataset import _Point_Dataset +# Template - -class TemplateDataset(_Point_Dataset): +class TemplateDataset(object): def __init__(self, *args, **kwargs): super(TemplateDataset, self).__init__() diff --git a/main_inference.py b/main_inference.py index 7b72616..7e9c86c 100644 --- a/main_inference.py +++ b/main_inference.py @@ -15,12 +15,12 @@ from ml_lib.utils.model_io import SavedLightningModels # Datasets -from datasets.full_pointclouds import FullCloudsDataset +from datasets.shapenet import ShapeNetPartSegDataset from utils.project_config import GlobalVar, ThisConfig def prepare_dataloader(config_obj): - dataset = FullCloudsDataset(config_obj.data.root, split=GlobalVar.data_split.test, + dataset = ShapeNetPartSegDataset(config_obj.data.root, split=GlobalVar.data_split.test, setting=GlobalVar.settings[config_obj.model.type]) # noinspection PyTypeChecker return DataLoader(dataset, batch_size=config_obj.train.batch_size, diff --git a/main_pipeline.py b/main_pipeline.py index 89c3cbf..8e33455 100644 --- a/main_pipeline.py +++ b/main_pipeline.py @@ -18,14 +18,14 @@ from ml_lib.utils.model_io import SavedLightningModels # Datasets -from datasets.full_pointclouds import FullCloudsDataset +from datasets.shapenet import ShapeNetPartSegDataset from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \ label2color, polytopes_to_planes from utils.project_config import GlobalVar, ThisConfig def prepare_dataloader(config_obj): - dataset = FullCloudsDataset(config_obj.data.root, split=GlobalVar.data_split.test, + dataset = ShapeNetPartSegDataset(config_obj.data.root, split=GlobalVar.data_split.test, setting=GlobalVar.settings[config_obj.model.type]) # noinspection PyTypeChecker return DataLoader(dataset, batch_size=config_obj.train.batch_size, @@ -43,15 +43,14 @@ def restore_logger_and_model(log_dir): def predict_prim_type(input_pc, model): - input_data = ( - torch.tensor(np.array([input_pc[:, 3:6]], np.float)), - torch.tensor(input_pc[:, 0:3]), - np.zeros(input_pc.shape[0]) - ) + input_data = dict(norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)), + pos=torch.tensor(input_pc[:, 0:3]), + y=np.zeros(input_pc.shape[0]) + ) batch_to_data = BatchToData() - data = batch_to_data(input_data[0], input_data[1], input_data[2]) + data = batch_to_data(input_data) y = loaded_model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu')) y_primary = torch.argmax(y.main_out, dim=-1).squeeze().cpu().numpy() diff --git a/models/__init__.py b/models/__init__.py index d8bd278..da3273f 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,4 +1 @@ from .point_net_2 import PointNet2 -from .point_net_2_grid_clusters import PointNet2GridClusters -from .point_net_2_prim_clusters import PointNet2PrimClusters - diff --git a/models/point_net_2.py b/models/point_net_2.py index 22736b6..4b13c39 100644 --- a/models/point_net_2.py +++ b/models/point_net_2.py @@ -3,7 +3,7 @@ from argparse import Namespace import torch from torch import nn -from datasets.grid_clusters import GridClusters +from datasets.shapenet import ShapeNetPartSegDataset from models._point_net_2 import _PointNetCore from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin @@ -23,7 +23,8 @@ class PointNet2(BaseValMixin, # Dataset # ============================================================================= - self.dataset = self.build_dataset(GridClusters, setting='pc', sampling_k=self.params.sampling_k) + self.dataset = self.build_dataset(ShapeNetPartSegDataset, collate_per_segment=True, + npoints=self.params.npoints) # Model Paramters # ============================================================================= diff --git a/models/point_net_2_grid_clusters.py b/models/point_net_2_grid_clusters.py deleted file mode 100644 index a47bc86..0000000 --- a/models/point_net_2_grid_clusters.py +++ /dev/null @@ -1,80 +0,0 @@ -from argparse import Namespace - -import torch -from torch import nn -from torch_geometric.data import Data -from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip - -from datasets.grid_clusters import GridClusters -from models._point_net_2 import _PointNetCore - -from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin -from utils.project_config import GlobalVar - - -class PointNet2GridClusters(BaseValMixin, - BaseTrainMixin, - BaseOptimizerMixin, - DatasetMixin, - BaseDataloadersMixin, - _PointNetCore - ): - - def training_step(self, batch_pos_x_n_y_c, batch_nb, *_, **__): - data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c - y = self(data) - nll_main_loss = self.nll_loss(y.main_out, data.yl) - nll_cluster_loss = self.nll_loss(y.grid_out, data.yc) - nll_loss = nll_main_loss + nll_cluster_loss - return dict(loss=nll_loss, log=dict(batch_nb=batch_nb), - nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss) - - def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__): - data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c - y = self(data) - nll_main_loss = self.nll_loss(y.main_out, data.yl) - nll_cluster_loss = self.nll_loss(y.grid_out, data.yc) - nll_loss = nll_main_loss + nll_cluster_loss - return dict(val_nll_loss=nll_loss, nll_cluster_loss=nll_cluster_loss, nll_main_loss=nll_main_loss, - batch_idx=batch_idx, y=y.main_out, batch_y=data.yl) - - def __init__(self, hparams): - super(PointNet2GridClusters, self).__init__(hparams=hparams) - - # Dataset - # ============================================================================= - self.dataset = self.build_dataset(GridClusters, setting='grid', sampling_k=self.params.sampling_k) - - # Model Paramters - # ============================================================================= - # Additional parameters - self.n_classes = len(GlobalVar.classes) - - # Modules - self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes)) - self.grid_lin = torch.nn.Linear(128, GlobalVar.grid_count) - - # Utility - self.log_softmax = nn.LogSoftmax(dim=-1) - - def forward(self, data, **kwargs): - """ - data: a batch of input torch_geometric.data.Data type - - torch_geometric.data.Data, as torch_geometric batch input: - data.x: (batch_size * ~num_points, C), batch nodes/points feature, - ~num_points means each sample can have different number of points/nodes - - data.pos: (batch_size * ~num_points, 3) - - data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud - idendifiers for all nodes of all graphs/pointclouds in the batch. See - pytorch_gemometric documentation for more information - """ - - sa0_out = (data.norm, data.pos, data.batch) - tensor = super(PointNet2GridClusters, self).forward(sa0_out) - point_tensor = self.point_lin(tensor) - point_tensor = self.log_softmax(point_tensor) - grid_tensor = self.grid_lin(tensor) - grid_tensor = self.log_softmax(grid_tensor) - return Namespace(main_out=point_tensor, grid_out=grid_tensor) diff --git a/models/point_net_2_prim_clusters.py b/models/point_net_2_prim_clusters.py deleted file mode 100644 index 0eed09c..0000000 --- a/models/point_net_2_prim_clusters.py +++ /dev/null @@ -1,59 +0,0 @@ -from argparse import Namespace - -import torch -from torch import nn - -from datasets.full_pointclouds import FullCloudsDataset -from models._point_net_2 import _PointNetCore - -from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin -from utils.project_config import GlobalVar - - -class PointNet2PrimClusters(BaseValMixin, - BaseTrainMixin, - BaseOptimizerMixin, - DatasetMixin, - BaseDataloadersMixin, - _PointNetCore - ): - - def __init__(self, hparams): - super(PointNet2PrimClusters, self).__init__(hparams=hparams) - - # Dataset - # ============================================================================= - self.dataset = self.build_dataset(FullCloudsDataset, setting='prim') - - # Model Paramters - # ============================================================================= - # Additional parameters - - # Modules - self.point_lin = torch.nn.Linear(128, len(GlobalVar.classes)) - self.prim_lin = torch.nn.Linear(128, len(GlobalVar.prims)) - - # Utility - self.log_softmax = nn.LogSoftmax(dim=-1) - - def forward(self, data, **kwargs): - """ - data: a batch of input torch_geometric.data.Data type - - torch_geometric.data.Data, as torch_geometric batch input: - data.x: (batch_size * ~num_points, C), batch nodes/points feature, - ~num_points means each sample can have different number of points/nodes - - data.pos: (batch_size * ~num_points, 3) - - data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud - idendifiers for all nodes of all graphs/pointclouds in the batch. See - pytorch_gemometric documentation for more information - """ - - sa0_out = (data.norm, data.pos, data.batch) - tensor = super(PointNet2PrimClusters, self).forward(sa0_out) - point_tensor = self.point_lin(tensor) - point_tensor = self.log_softmax(point_tensor) - prim_tensor = self.prim_lin(tensor) - prim_tensor = self.log_softmax(prim_tensor) - return Namespace(main_out=point_tensor, prim_out=prim_tensor) diff --git a/utils/module_mixins.py b/utils/module_mixins.py index 41ea033..3940163 100644 --- a/utils/module_mixins.py +++ b/utils/module_mixins.py @@ -14,8 +14,7 @@ import matplotlib.pyplot as plt from torch import nn from torch.optim import Adam -from torch.utils.data import DataLoader -from torch_geometric.data import Data +from torch_geometric.data import Data, DataLoader from torchcontrib.optim import SWA @@ -59,11 +58,11 @@ class BaseTrainMixin: # Binary Cross Entropy bce_loss = nn.BCELoss() - def training_step(self, batch_norm_pos_y_c, batch_nb, *_, **__): + def training_step(self, batch_norm_pos_y, batch_nb, *_, **__): assert isinstance(self, LightningBaseModule) - data = self.batch_to_data(*batch_norm_pos_y_c) if not isinstance(batch_norm_pos_y_c, Data) else batch_norm_pos_y_c + data = self.batch_to_data(batch_norm_pos_y) if not isinstance(batch_norm_pos_y, Data) else batch_norm_pos_y y = self(data).main_out - nll_loss = self.nll_loss(y, data.yl) + nll_loss = self.nll_loss(y, data.y) return dict(loss=nll_loss, log=dict(batch_nb=batch_nb)) def training_epoch_end(self, outputs): @@ -89,9 +88,9 @@ class BaseValMixin: assert isinstance(self, LightningBaseModule) data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c y = self(data).main_out - nll_loss = self.nll_loss(y, data.yl) + nll_loss = self.nll_loss(y, data.y) return dict(val_nll_loss=nll_loss, - batch_idx=batch_idx, y=y, batch_y=data.yl) + batch_idx=batch_idx, y=y, batch_y=data.y) def validation_epoch_end(self, outputs, *_, **__): assert isinstance(self, LightningBaseModule) @@ -229,14 +228,15 @@ class DatasetMixin: dataset = Namespace( **dict( # TRAIN DATASET - train_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.train, **kwargs), + train_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.train, + **kwargs), # VALIDATION DATASET - val_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.devel, + val_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.devel, **kwargs), # TEST DATASET - test_dataset=dataset_class(self.params.root, split=GlobalVar.data_split.test, + test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.test, **kwargs), ) ) diff --git a/utils/project_config.py b/utils/project_config.py index c2c96d1..d0c23d1 100644 --- a/utils/project_config.py +++ b/utils/project_config.py @@ -68,4 +68,4 @@ class ThisConfig(Config): @property def _model_map(self): - return dict(PN2=PointNet2, P2P=PointNet2PrimClusters, P2G=PointNet2GridClusters) + return dict(PN2=PointNet2)