pointnet2 working - TODO: Eval!
This commit is contained in:
parent
4b4051c045
commit
77ea043907
@ -2,6 +2,11 @@ from torch.utils.data import Dataset
|
|||||||
|
|
||||||
|
|
||||||
class TemplateDataset(Dataset):
|
class TemplateDataset(Dataset):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sample_shape(self):
|
||||||
|
return self[0][0].shape
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(TemplateDataset, self).__init__()
|
super(TemplateDataset, self).__init__()
|
||||||
|
|
||||||
|
47
modules/geometric_blocks.py
Normal file
47
modules/geometric_blocks.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import ReLU
|
||||||
|
|
||||||
|
from torch_geometric.nn import PointConv, fps, radius, global_max_pool
|
||||||
|
|
||||||
|
|
||||||
|
class SAModule(torch.nn.Module):
|
||||||
|
def __init__(self, ratio, r, nn):
|
||||||
|
super(SAModule, self).__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.r = r
|
||||||
|
self.conv = PointConv(nn)
|
||||||
|
|
||||||
|
def forward(self, x, pos, batch):
|
||||||
|
idx = fps(pos, batch, ratio=self.ratio)
|
||||||
|
row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
|
||||||
|
max_num_neighbors=64)
|
||||||
|
edge_index = torch.stack([col, row], dim=0)
|
||||||
|
x = self.conv(x, (pos, pos[idx]), edge_index)
|
||||||
|
pos, batch = pos[idx], batch[idx]
|
||||||
|
return x, pos, batch
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalSAModule(nn.Module):
|
||||||
|
def __init__(self, nn):
|
||||||
|
super(GlobalSAModule, self).__init__()
|
||||||
|
self.nn = nn
|
||||||
|
|
||||||
|
def forward(self, x, pos, batch):
|
||||||
|
x = self.nn(torch.cat([x, pos], dim=1))
|
||||||
|
x = global_max_pool(x, batch)
|
||||||
|
pos = pos.new_zeros((x.size(0), 3))
|
||||||
|
batch = torch.arange(x.size(0), device=batch.device)
|
||||||
|
return x, pos, batch
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, channels, norm=True):
|
||||||
|
super(MLP, self).__init__()
|
||||||
|
self.net = nn.Sequential(*[
|
||||||
|
nn.Sequential(nn.Linear(channels[i - 1], channels[i]), ReLU(), nn.BatchNorm1d(channels[i]))
|
||||||
|
for i in range(1, len(channels))
|
||||||
|
]).double()
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
return self.net(x)
|
132
modules/util.py
132
modules/util.py
@ -13,6 +13,72 @@ import pytorch_lightning as pl
|
|||||||
from ..utils.model_io import ModelParameters
|
from ..utils.model_io import ModelParameters
|
||||||
|
|
||||||
|
|
||||||
|
class LightningBaseModule(pl.LightningModule, ABC):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def name(cls):
|
||||||
|
return cls.__name__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
try:
|
||||||
|
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||||
|
output = self(x)
|
||||||
|
return output.shape[1:]
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def __init__(self, hparams):
|
||||||
|
super(LightningBaseModule, self).__init__()
|
||||||
|
|
||||||
|
# Set Parameters
|
||||||
|
################################
|
||||||
|
self.hparams = hparams
|
||||||
|
self.params = ModelParameters(hparams)
|
||||||
|
|
||||||
|
# Dataset Loading
|
||||||
|
################################
|
||||||
|
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
return self.shape
|
||||||
|
|
||||||
|
def save_to_disk(self, model_path):
|
||||||
|
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
||||||
|
if not (model_path / 'model_class.obj').exists():
|
||||||
|
with (model_path / 'model_class.obj').open('wb') as f:
|
||||||
|
torch.save(self.__class__, f)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data_len(self):
|
||||||
|
return len(self.dataset.train_dataset)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_train_batches(self):
|
||||||
|
return len(self.train_dataloader())
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def test_step(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def test_epoch_end(self, outputs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
|
||||||
|
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
||||||
|
self.apply(weight_initializer)
|
||||||
|
|
||||||
|
|
||||||
class ShapeMixin:
|
class ShapeMixin:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -99,72 +165,6 @@ class WeightInit:
|
|||||||
m.bias.data.fill_(0.01)
|
m.bias.data.fill_(0.01)
|
||||||
|
|
||||||
|
|
||||||
class LightningBaseModule(pl.LightningModule, ABC):
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def name(cls):
|
|
||||||
return cls.__name__
|
|
||||||
|
|
||||||
@property
|
|
||||||
def shape(self):
|
|
||||||
try:
|
|
||||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
|
||||||
output = self(x)
|
|
||||||
return output.shape[1:]
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
return -1
|
|
||||||
|
|
||||||
def __init__(self, hparams):
|
|
||||||
super(LightningBaseModule, self).__init__()
|
|
||||||
|
|
||||||
# Set Parameters
|
|
||||||
################################
|
|
||||||
self.hparams = hparams
|
|
||||||
self.params = ModelParameters(hparams)
|
|
||||||
|
|
||||||
# Dataset Loading
|
|
||||||
################################
|
|
||||||
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
|
|
||||||
|
|
||||||
def size(self):
|
|
||||||
return self.shape
|
|
||||||
|
|
||||||
def save_to_disk(self, model_path):
|
|
||||||
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
|
||||||
if not (model_path / 'model_class.obj').exists():
|
|
||||||
with (model_path / 'model_class.obj').open('wb') as f:
|
|
||||||
torch.save(self.__class__, f)
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def data_len(self):
|
|
||||||
return len(self.dataset.train_dataset)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_train_batches(self):
|
|
||||||
return len(self.train_dataloader())
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def test_step(self, *args, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def test_epoch_end(self, outputs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
|
|
||||||
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
|
||||||
self.apply(weight_initializer)
|
|
||||||
|
|
||||||
|
|
||||||
class FilterLayer(nn.Module):
|
class FilterLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -6,19 +6,23 @@ class FarthestpointSampling():
|
|||||||
def __init__(self, K):
|
def __init__(self, K):
|
||||||
self.k = K
|
self.k = K
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calc_distances(p0, points):
|
||||||
|
return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
|
||||||
|
|
||||||
def __call__(self, pts, *args, **kwargs):
|
def __call__(self, pts, *args, **kwargs):
|
||||||
|
|
||||||
if pts.shape[0] < self.k:
|
if pts.shape[0] < self.k:
|
||||||
return pts
|
return pts
|
||||||
|
|
||||||
def calc_distances(p0, points):
|
|
||||||
return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
|
|
||||||
|
|
||||||
farthest_pts = np.zeros((self.k, pts.shape[1]))
|
farthest_pts = np.zeros((self.k, pts.shape[1]))
|
||||||
|
farthest_pts_idx = np.zeros(self.k, dtype=np.int)
|
||||||
farthest_pts[0] = pts[np.random.randint(len(pts))]
|
farthest_pts[0] = pts[np.random.randint(len(pts))]
|
||||||
distances = calc_distances(farthest_pts[0], pts)
|
distances = self.calc_distances(farthest_pts[0], pts)
|
||||||
for i in range(1, self.k):
|
for i in range(1, self.k):
|
||||||
farthest_pts[i] = pts[np.argmax(distances)]
|
farthest_pts_idx[i] = np.argmax(distances)
|
||||||
distances = np.minimum(distances, calc_distances(farthest_pts[i], pts))
|
farthest_pts[i] = pts[farthest_pts_idx[i]]
|
||||||
|
|
||||||
return farthest_pts
|
distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
|
||||||
|
|
||||||
|
return farthest_pts_idx
|
||||||
|
@ -3,6 +3,15 @@ import shelve
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def fix_all_random_seeds(config_obj):
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import random
|
||||||
|
np.random.seed(config_obj.main.seed)
|
||||||
|
torch.manual_seed(config_obj.main.seed)
|
||||||
|
random.seed(config_obj.main.seed)
|
||||||
|
|
||||||
|
|
||||||
def write_to_shelve(file_path, value):
|
def write_to_shelve(file_path, value):
|
||||||
check_path(file_path)
|
check_path(file_path)
|
||||||
file_path.parent.mkdir(exist_ok=True, parents=True)
|
file_path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user