pointnet2 working - TODO: Eval!

This commit is contained in:
Si11ium 2020-05-26 21:44:56 +02:00
parent 4b4051c045
commit 77ea043907
5 changed files with 138 additions and 73 deletions

View File

@ -2,6 +2,11 @@ from torch.utils.data import Dataset
class TemplateDataset(Dataset):
@property
def sample_shape(self):
return self[0][0].shape
def __init__(self, *args, **kwargs):
super(TemplateDataset, self).__init__()

View File

@ -0,0 +1,47 @@
import torch
from torch import nn
from torch.nn import ReLU
from torch_geometric.nn import PointConv, fps, radius, global_max_pool
class SAModule(torch.nn.Module):
def __init__(self, ratio, r, nn):
super(SAModule, self).__init__()
self.ratio = ratio
self.r = r
self.conv = PointConv(nn)
def forward(self, x, pos, batch):
idx = fps(pos, batch, ratio=self.ratio)
row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
max_num_neighbors=64)
edge_index = torch.stack([col, row], dim=0)
x = self.conv(x, (pos, pos[idx]), edge_index)
pos, batch = pos[idx], batch[idx]
return x, pos, batch
class GlobalSAModule(nn.Module):
def __init__(self, nn):
super(GlobalSAModule, self).__init__()
self.nn = nn
def forward(self, x, pos, batch):
x = self.nn(torch.cat([x, pos], dim=1))
x = global_max_pool(x, batch)
pos = pos.new_zeros((x.size(0), 3))
batch = torch.arange(x.size(0), device=batch.device)
return x, pos, batch
class MLP(nn.Module):
def __init__(self, channels, norm=True):
super(MLP, self).__init__()
self.net = nn.Sequential(*[
nn.Sequential(nn.Linear(channels[i - 1], channels[i]), ReLU(), nn.BatchNorm1d(channels[i]))
for i in range(1, len(channels))
]).double()
def forward(self, x, *args, **kwargs):
return self.net(x)

View File

@ -13,6 +13,72 @@ import pytorch_lightning as pl
from ..utils.model_io import ModelParameters
class LightningBaseModule(pl.LightningModule, ABC):
@classmethod
def name(cls):
return cls.__name__
@property
def shape(self):
try:
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
except Exception as e:
print(e)
return -1
def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
# Set Parameters
################################
self.hparams = hparams
self.params = ModelParameters(hparams)
# Dataset Loading
################################
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
def size(self):
return self.shape
def save_to_disk(self, model_path):
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
if not (model_path / 'model_class.obj').exists():
with (model_path / 'model_class.obj').open('wb') as f:
torch.save(self.__class__, f)
return True
@property
def data_len(self):
return len(self.dataset.train_dataset)
@property
def n_train_batches(self):
return len(self.train_dataloader())
def configure_optimizers(self):
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
raise NotImplementedError
def test_step(self, *args, **kwargs):
raise NotImplementedError
def test_epoch_end(self, outputs):
raise NotImplementedError
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
self.apply(weight_initializer)
class ShapeMixin:
@property
@ -99,72 +165,6 @@ class WeightInit:
m.bias.data.fill_(0.01)
class LightningBaseModule(pl.LightningModule, ABC):
@classmethod
def name(cls):
return cls.__name__
@property
def shape(self):
try:
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
except Exception as e:
print(e)
return -1
def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
# Set Parameters
################################
self.hparams = hparams
self.params = ModelParameters(hparams)
# Dataset Loading
################################
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
def size(self):
return self.shape
def save_to_disk(self, model_path):
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
if not (model_path / 'model_class.obj').exists():
with (model_path / 'model_class.obj').open('wb') as f:
torch.save(self.__class__, f)
return True
@property
def data_len(self):
return len(self.dataset.train_dataset)
@property
def n_train_batches(self):
return len(self.train_dataloader())
def configure_optimizers(self):
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
raise NotImplementedError
def test_step(self, *args, **kwargs):
raise NotImplementedError
def test_epoch_end(self, outputs):
raise NotImplementedError
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
self.apply(weight_initializer)
class FilterLayer(nn.Module):
def __init__(self):

View File

@ -6,19 +6,23 @@ class FarthestpointSampling():
def __init__(self, K):
self.k = K
@staticmethod
def calc_distances(p0, points):
return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
def __call__(self, pts, *args, **kwargs):
if pts.shape[0] < self.k:
return pts
def calc_distances(p0, points):
return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
farthest_pts = np.zeros((self.k, pts.shape[1]))
farthest_pts_idx = np.zeros(self.k, dtype=np.int)
farthest_pts[0] = pts[np.random.randint(len(pts))]
distances = calc_distances(farthest_pts[0], pts)
distances = self.calc_distances(farthest_pts[0], pts)
for i in range(1, self.k):
farthest_pts[i] = pts[np.argmax(distances)]
distances = np.minimum(distances, calc_distances(farthest_pts[i], pts))
farthest_pts_idx[i] = np.argmax(distances)
farthest_pts[i] = pts[farthest_pts_idx[i]]
return farthest_pts
distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
return farthest_pts_idx

View File

@ -3,6 +3,15 @@ import shelve
from pathlib import Path
def fix_all_random_seeds(config_obj):
import numpy as np
import torch
import random
np.random.seed(config_obj.main.seed)
torch.manual_seed(config_obj.main.seed)
random.seed(config_obj.main.seed)
def write_to_shelve(file_path, value):
check_path(file_path)
file_path.parent.mkdir(exist_ok=True, parents=True)