bringing brances up to date
This commit is contained in:
@ -1,3 +1,6 @@
|
||||
import inspect
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from functools import reduce
|
||||
|
||||
from abc import ABC
|
||||
@ -5,13 +8,14 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
from operator import mul
|
||||
from pytorch_lightning.utilities import argparse_utils
|
||||
from torch import nn
|
||||
from torch.nn import functional as F, Unfold
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
from ..utils.model_io import ModelParameters
|
||||
from ..utils.tools import locate_and_import_class
|
||||
from ..utils.tools import locate_and_import_class, add_argparse_args
|
||||
|
||||
try:
|
||||
import pytorch_lightning as pl
|
||||
@ -32,14 +36,18 @@ try:
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
@classmethod
|
||||
def from_argparse_args(cls, args, **kwargs):
|
||||
return argparse_utils.from_argparse_args(cls, args, **kwargs)
|
||||
|
||||
# Set Parameters
|
||||
################################
|
||||
self.hparams = hparams
|
||||
self.params = ModelParameters(hparams)
|
||||
self.lr = self.params.lr or 1e-4
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser):
|
||||
return add_argparse_args(cls, parent_parser)
|
||||
|
||||
def __init__(self, model_parameters, weight_init='xavier_normal_'):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
self._weight_init = weight_init
|
||||
self.params = ModelParameters(model_parameters)
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
@ -47,15 +55,6 @@ try:
|
||||
def additional_scores(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def dataset_class(self):
|
||||
try:
|
||||
return locate_and_import_class(self.params.class_name, folder_path='datasets')
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f'The dataset alias you provided ("{self.params.class_name}") ' +
|
||||
f'was not found!\n' +
|
||||
f'{e}')
|
||||
|
||||
def save_to_disk(self, model_path):
|
||||
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
||||
if not (model_path / 'model_class.obj').exists():
|
||||
@ -86,8 +85,12 @@ try:
|
||||
def test_epoch_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
|
||||
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
||||
def init_weights(self):
|
||||
if isinstance(self._weight_init, str):
|
||||
mod = __import__('torch.nn.init', fromlist=[self._weight_init])
|
||||
self._weight_init = getattr(mod, self._weight_init)
|
||||
assert callable(self._weight_init)
|
||||
weight_initializer = WeightInit(in_place_init_function=self._weight_init)
|
||||
self.apply(weight_initializer)
|
||||
|
||||
module_types = (LightningBaseModule, nn.Module,)
|
||||
|
Reference in New Issue
Block a user