bringing brances up to date

This commit is contained in:
Steffen Illium
2021-02-15 11:39:54 +01:00
parent 010176e80b
commit a966321576
11 changed files with 216 additions and 197 deletions

View File

@ -1,3 +1,6 @@
import inspect
from argparse import ArgumentParser
from functools import reduce
from abc import ABC
@ -5,13 +8,14 @@ from pathlib import Path
import torch
from operator import mul
from pytorch_lightning.utilities import argparse_utils
from torch import nn
from torch.nn import functional as F, Unfold
# Utility - Modules
###################
from ..utils.model_io import ModelParameters
from ..utils.tools import locate_and_import_class
from ..utils.tools import locate_and_import_class, add_argparse_args
try:
import pytorch_lightning as pl
@ -32,14 +36,18 @@ try:
print(e)
return -1
def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
@classmethod
def from_argparse_args(cls, args, **kwargs):
return argparse_utils.from_argparse_args(cls, args, **kwargs)
# Set Parameters
################################
self.hparams = hparams
self.params = ModelParameters(hparams)
self.lr = self.params.lr or 1e-4
@classmethod
def add_argparse_args(cls, parent_parser):
return add_argparse_args(cls, parent_parser)
def __init__(self, model_parameters, weight_init='xavier_normal_'):
super(LightningBaseModule, self).__init__()
self._weight_init = weight_init
self.params = ModelParameters(model_parameters)
def size(self):
return self.shape
@ -47,15 +55,6 @@ try:
def additional_scores(self, outputs):
raise NotImplementedError
@property
def dataset_class(self):
try:
return locate_and_import_class(self.params.class_name, folder_path='datasets')
except AttributeError as e:
raise AttributeError(f'The dataset alias you provided ("{self.params.class_name}") ' +
f'was not found!\n' +
f'{e}')
def save_to_disk(self, model_path):
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
if not (model_path / 'model_class.obj').exists():
@ -86,8 +85,12 @@ try:
def test_epoch_end(self, outputs):
raise NotImplementedError
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
def init_weights(self):
if isinstance(self._weight_init, str):
mod = __import__('torch.nn.init', fromlist=[self._weight_init])
self._weight_init = getattr(mod, self._weight_init)
assert callable(self._weight_init)
weight_initializer = WeightInit(in_place_init_function=self._weight_init)
self.apply(weight_initializer)
module_types = (LightningBaseModule, nn.Module,)