diff --git a/_templates/new_project/datasets/template_dataset.py b/_templates/new_project/datasets/template_dataset.py
index b391a56..e1e98a1 100644
--- a/_templates/new_project/datasets/template_dataset.py
+++ b/_templates/new_project/datasets/template_dataset.py
@@ -2,6 +2,11 @@
 from torch.utils.data import Dataset
 
 
 class TemplateDataset(Dataset):
+
+    @property
+    def sample_shape(self):
+        return self[0][0].shape
+
     def __init__(self, *args, **kwargs):
         super(TemplateDataset, self).__init__()
diff --git a/modules/geometric_blocks.py b/modules/geometric_blocks.py
new file mode 100644
index 0000000..9fd4479
--- /dev/null
+++ b/modules/geometric_blocks.py
@@ -0,0 +1,53 @@
+import torch
+from torch import nn
+from torch.nn import ReLU
+
+from torch_geometric.nn import PointConv, fps, radius, global_max_pool
+
+
+class SAModule(torch.nn.Module):
+    def __init__(self, ratio, r, nn):
+        # `nn` is the point-wise network wrapped by PointConv (it shadows torch.nn locally).
+        super(SAModule, self).__init__()
+        self.ratio = ratio
+        self.r = r
+        self.conv = PointConv(nn)
+
+    def forward(self, x, pos, batch):
+        # Subsample the cloud with farthest point sampling.
+        idx = fps(pos, batch, ratio=self.ratio)
+        # Group up to 64 neighbors within radius r around every sampled point.
+        row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
+                          max_num_neighbors=64)
+        edge_index = torch.stack([col, row], dim=0)
+        x = self.conv(x, (pos, pos[idx]), edge_index)
+        pos, batch = pos[idx], batch[idx]
+        return x, pos, batch
+
+
+class GlobalSAModule(nn.Module):
+    def __init__(self, nn):
+        super(GlobalSAModule, self).__init__()
+        self.nn = nn
+
+    def forward(self, x, pos, batch):
+        # Pool all points of a sample into a single global feature vector.
+        x = self.nn(torch.cat([x, pos], dim=1))
+        x = global_max_pool(x, batch)
+        pos = pos.new_zeros((x.size(0), 3))
+        batch = torch.arange(x.size(0), device=batch.device)
+        return x, pos, batch
+
+
+class MLP(nn.Module):
+    def __init__(self, channels, norm=True):
+        super(MLP, self).__init__()
+        # BatchNorm1d is applied only when `norm` is set; previously the flag was ignored.
+        self.net = nn.Sequential(*[
+            nn.Sequential(nn.Linear(channels[i - 1], channels[i]), ReLU(),
+                          nn.BatchNorm1d(channels[i]) if norm else nn.Identity())
+            for i in range(1, len(channels))
+        ]).double()
+
+    def forward(self, x, *args, **kwargs):
+        return self.net(x)
diff --git a/modules/util.py b/modules/util.py
index 9b25913..0b34277 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -13,6 +13,72 @@ import pytorch_lightning as pl
 
 from ..utils.model_io import ModelParameters
 
+class LightningBaseModule(pl.LightningModule, ABC):
+
+    @classmethod
+    def name(cls):
+        return cls.__name__
+
+    @property
+    def shape(self):
+        try:
+            x = torch.randn(self.in_shape).unsqueeze(0)
+            output = self(x)
+            return output.shape[1:]
+        except Exception as e:
+            print(e)
+            return -1
+
+    def __init__(self, hparams):
+        super(LightningBaseModule, self).__init__()
+
+        # Set Parameters
+        ################################
+        self.hparams = hparams
+        self.params = ModelParameters(hparams)
+
+        # Dataset Loading
+        ################################
+        # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
+
+    def size(self):
+        return self.shape
+
+    def save_to_disk(self, model_path):
+        Path(model_path).mkdir(parents=True, exist_ok=True)
+        if not (model_path / 'model_class.obj').exists():
+            with (model_path / 'model_class.obj').open('wb') as f:
+                torch.save(self.__class__, f)
+        return True
+
+    @property
+    def data_len(self):
+        return len(self.dataset.train_dataset)
+
+    @property
+    def n_train_batches(self):
+        return len(self.train_dataloader())
+
+    def configure_optimizers(self):
+        raise NotImplementedError
+
+    def forward(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+        raise NotImplementedError
+
+    def test_step(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def test_epoch_end(self, outputs):
+        raise NotImplementedError
+
+    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
+        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
+        self.apply(weight_initializer)
+
+
 class ShapeMixin:
 
     @property
@@ -99,72 +165,6 @@ class WeightInit:
             m.bias.data.fill_(0.01)
 
 
-class LightningBaseModule(pl.LightningModule, ABC):
-
-    @classmethod
-    def name(cls):
-        return cls.__name__
-
-    @property
-    def shape(self):
-        try:
-            x = torch.randn(self.in_shape).unsqueeze(0)
-            output = self(x)
-            return output.shape[1:]
-        except Exception as e:
-            print(e)
-            return -1
-
-    def __init__(self, hparams):
-        super(LightningBaseModule, self).__init__()
-
-        # Set Parameters
-        ################################
-        self.hparams = hparams
-        self.params = ModelParameters(hparams)
-
-        # Dataset Loading
-        ################################
-        # TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
-
-    def size(self):
-        return self.shape
-
-    def save_to_disk(self, model_path):
-        Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
-        if not (model_path / 'model_class.obj').exists():
-            with (model_path / 'model_class.obj').open('wb') as f:
-                torch.save(self.__class__, f)
-        return True
-
-    @property
-    def data_len(self):
-        return len(self.dataset.train_dataset)
-
-    @property
-    def n_train_batches(self):
-        return len(self.train_dataloader())
-
-    def configure_optimizers(self):
-        raise NotImplementedError
-
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        raise NotImplementedError
-
-    def test_step(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def test_epoch_end(self, outputs):
-        raise NotImplementedError
-
-    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
-        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
-        self.apply(weight_initializer)
-
-
 class FilterLayer(nn.Module):
 
     def __init__(self):
diff --git a/point_toolset/sampling.py b/point_toolset/sampling.py
index 37f309c..0a9c2c5 100644
--- a/point_toolset/sampling.py
+++ b/point_toolset/sampling.py
@@ -6,19 +6,24 @@ class FarthestpointSampling():
     def __init__(self, K):
         self.k = K
 
+    @staticmethod
+    def calc_distances(p0, points):
+        return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
+
     def __call__(self, pts, *args, **kwargs):
 
         if pts.shape[0] < self.k:
-            return pts
+            # Fewer points than requested: return all indices, keeping the return type consistent.
+            return np.arange(pts.shape[0])
 
-        def calc_distances(p0, points):
-            return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
-
         farthest_pts = np.zeros((self.k, pts.shape[1]))
-        farthest_pts[0] = pts[np.random.randint(len(pts))]
-        distances = calc_distances(farthest_pts[0], pts)
+        farthest_pts_idx = np.zeros(self.k, dtype=int)
+        farthest_pts_idx[0] = np.random.randint(len(pts))
+        farthest_pts[0] = pts[farthest_pts_idx[0]]
+        distances = self.calc_distances(farthest_pts[0], pts)
         for i in range(1, self.k):
-            farthest_pts[i] = pts[np.argmax(distances)]
-            distances = np.minimum(distances, calc_distances(farthest_pts[i], pts))
+            farthest_pts_idx[i] = np.argmax(distances)
+            farthest_pts[i] = pts[farthest_pts_idx[i]]
+            distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
 
-        return farthest_pts
+        return farthest_pts_idx
diff --git a/utils/tools.py b/utils/tools.py
index 594fff4..e5b73c1 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -3,6 +3,15 @@ import shelve
 
 from pathlib import Path
 
+def fix_all_random_seeds(config_obj):
+    import numpy as np
+    import torch
+    import random
+    np.random.seed(config_obj.main.seed)
+    torch.manual_seed(config_obj.main.seed)
+    random.seed(config_obj.main.seed)
+
+
 def write_to_shelve(file_path, value):
     check_path(file_path)
     file_path.parent.mkdir(exist_ok=True, parents=True)
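
A note on the `TemplateDataset` change: `sample_shape` reads `self[0][0]`, so it assumes `__getitem__` yields `(data, target)` tuples. A minimal sanity check with a hypothetical subclass (`ToyDataset` and the import path are made up for illustration):

import torch

from _templates.new_project.datasets.template_dataset import TemplateDataset


class ToyDataset(TemplateDataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.zeros(4, 32), 0  # (data, target) pair


assert ToyDataset().sample_shape == torch.Size([4, 32])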
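
The blocks in modules/geometric_blocks.py follow the PointNet++ set-abstraction pattern from the torch_geometric examples. A sketch of how they chain into an encoder; the ratios, radii and channel widths here are illustrative, not taken from this repo:

from modules.geometric_blocks import SAModule, GlobalSAModule, MLP

# PointConv feeds [point features, relative xyz] to its MLP, hence the `+ 3`
# on the input widths of the deeper stages.
sa1 = SAModule(ratio=0.5, r=0.2, nn=MLP([3, 64, 128]))
sa2 = SAModule(ratio=0.25, r=0.4, nn=MLP([128 + 3, 128, 256]))
sa3 = GlobalSAModule(MLP([256 + 3, 512, 1024]))

Each stage maps `(x, pos, batch)` to `(x, pos, batch)` on a progressively smaller cloud; since `MLP` calls `.double()`, inputs have to be float64 tensors.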
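
With the sampling rewrite, `FarthestpointSampling.__call__` now returns indices into the input cloud rather than the sampled points themselves, so callers gather rows from the original array (which also lets them carry labels or other per-point data along). A minimal usage sketch; the sizes are made up:

import numpy as np

from point_toolset.sampling import FarthestpointSampling

pts = np.random.rand(2048, 4)      # xyz plus one extra channel

sampler = FarthestpointSampling(K=512)
idx = sampler(pts)                 # (512,) integer indices into pts
sampled = pts[idx]                 # gather the actual points when needed
assert sampled.shape == (512, 4)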
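
`fix_all_random_seeds` reads a single attribute path, `config_obj.main.seed`, so any object shaped like that will do. A sketch with a stand-in config (the project's real config class is not part of this diff):

from types import SimpleNamespace

from utils.tools import fix_all_random_seeds

config = SimpleNamespace(main=SimpleNamespace(seed=1337))  # stand-in config object
fix_all_random_seeds(config)  # numpy, torch and random now share one seed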