pointnet2 working - TODO: Eval!
This commit is contained in:
132
modules/util.py
132
modules/util.py
@ -13,6 +13,72 @@ import pytorch_lightning as pl
|
||||
from ..utils.model_io import ModelParameters
|
||||
|
||||
|
||||
class LightningBaseModule(pl.LightningModule, ABC):
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
return cls.__name__
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
try:
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
|
||||
# Set Parameters
|
||||
################################
|
||||
self.hparams = hparams
|
||||
self.params = ModelParameters(hparams)
|
||||
|
||||
# Dataset Loading
|
||||
################################
|
||||
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
|
||||
def save_to_disk(self, model_path):
|
||||
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
||||
if not (model_path / 'model_class.obj').exists():
|
||||
with (model_path / 'model_class.obj').open('wb') as f:
|
||||
torch.save(self.__class__, f)
|
||||
return True
|
||||
|
||||
@property
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
|
||||
@property
|
||||
def n_train_batches(self):
|
||||
return len(self.train_dataloader())
|
||||
|
||||
def configure_optimizers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
|
||||
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
||||
self.apply(weight_initializer)
|
||||
|
||||
|
||||
class ShapeMixin:
|
||||
|
||||
@property
|
||||
@ -99,72 +165,6 @@ class WeightInit:
|
||||
m.bias.data.fill_(0.01)
|
||||
|
||||
|
||||
class LightningBaseModule(pl.LightningModule, ABC):
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
return cls.__name__
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
try:
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
|
||||
# Set Parameters
|
||||
################################
|
||||
self.hparams = hparams
|
||||
self.params = ModelParameters(hparams)
|
||||
|
||||
# Dataset Loading
|
||||
################################
|
||||
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
|
||||
def save_to_disk(self, model_path):
|
||||
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
||||
if not (model_path / 'model_class.obj').exists():
|
||||
with (model_path / 'model_class.obj').open('wb') as f:
|
||||
torch.save(self.__class__, f)
|
||||
return True
|
||||
|
||||
@property
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
|
||||
@property
|
||||
def n_train_batches(self):
|
||||
return len(self.train_dataloader())
|
||||
|
||||
def configure_optimizers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
|
||||
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
||||
self.apply(weight_initializer)
|
||||
|
||||
|
||||
class FilterLayer(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
Reference in New Issue
Block a user