Code Comments, Getting Dirty Env, Naming

This commit is contained in:
Steffen Illium 2021-05-11 10:31:34 +02:00
parent faa27c3cf9
commit ab01006eae
7 changed files with 51 additions and 16 deletions

View File

@ -9,7 +9,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from ml_lib.modules.util import LightningBaseModule from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.config import Config from ml_lib.utils.config import Config
from ml_lib.utils.loggers import Logger from ml_lib.utils.loggers import LightningLogger
warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings('ignore', category=UserWarning)
@ -20,7 +20,7 @@ def run_lightning_loop(config_obj):
# Logging # Logging
# ================================================================================ # ================================================================================
# Logger # Logger
with Logger(config_obj) as logger: with LightningLogger(config_obj) as logger:
# Callbacks # Callbacks
# ============================================================================= # =============================================================================
# Checkpoint Saving # Checkpoint Saving

View File

@ -1,3 +1,5 @@
from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -41,3 +43,33 @@ class FocalLossRob(nn.Module):
return x.sum() return x.sum()
else: else:
return x return x
class DQN_MSELoss(object):
def __init__(self, agent_net, target_net, gamma):
self.agent_net = agent_net
self.target_net = target_net
self.gamma = gamma
def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
"""
Calculates the mse loss using a mini batch from the replay buffer
Args:
batch: current mini batch of replay data
Returns:
loss
"""
states, actions, rewards, dones, next_states = batch
actions = actions.to(torch.int64)
state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
with torch.no_grad():
next_state_values = self.target_net(next_states).max(1)[0]
next_state_values[dones] = 0.0
next_state_values = next_state_values.detach()
expected_state_action_values = next_state_values * self.gamma + rewards
return F.mse_loss(state_action_values, expected_state_action_values)

View File

@ -8,8 +8,6 @@ import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from einops import rearrange, repeat
import sys import sys
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
@ -40,7 +38,7 @@ class LinearModule(ShapeMixin, nn.Module):
tensor = self.flat(x) tensor = self.flat(x)
tensor = self.dropout(tensor) tensor = self.dropout(tensor)
tensor = self.norm(tensor) tensor = self.norm(tensor)
tensor = self.linear(tensor) tensor = self.linear(tensor.float())
tensor = self.activation(tensor) tensor = self.activation(tensor)
return tensor return tensor
@ -249,6 +247,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity() ) if project_out else nn.Identity()
def forward(self, x, mask=None, return_attn_weights=False): def forward(self, x, mask=None, return_attn_weights=False):
from einops import rearrange, repeat
# noinspection PyTupleAssignmentBalance # noinspection PyTupleAssignmentBalance
b, n, _, h = *x.shape, self.heads b, n, _, h = *x.shape, self.heads

View File

@ -129,8 +129,10 @@ try:
self._weight_init = weight_init self._weight_init = weight_init
self.params = ModelParameters(model_parameters) self.params = ModelParameters(model_parameters)
self.metrics = PLMetrics(self.params.n_classes, tag='PL') if hasattr(self.params, 'n_classes'):
pass self.metrics = PLMetrics(self.params.n_classes, tag='PL')
else:
pass
def size(self): def size(self):
return self.shape return self.shape

View File

@ -1,5 +1,6 @@
import ast import ast
import configparser import configparser
from distutils.util import strtobool
from pathlib import Path from pathlib import Path
from typing import Mapping, Dict from typing import Mapping, Dict
@ -14,7 +15,7 @@ from configparser import ConfigParser, DuplicateSectionError
import hashlib import hashlib
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from ml_lib.utils.loggers import Logger from ml_lib.utils.loggers import LightningLogger
from ml_lib.utils.tools import locate_and_import_class, auto_cast from ml_lib.utils.tools import locate_and_import_class, auto_cast
@ -27,6 +28,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
parser.add_argument('--model_name', type=str) parser.add_argument('--model_name', type=str)
parser.add_argument('--data_name', type=str) parser.add_argument('--data_name', type=str)
parser.add_argument('--seed', type=str) parser.add_argument('--seed', type=str)
parser.add_argument('--debug', type=strtobool)
# Load Defaults from _parameters.ini file # Load Defaults from _parameters.ini file
config = configparser.ConfigParser() config = configparser.ConfigParser()
@ -52,12 +54,9 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
found_data_class = locate_and_import_class(data_name, 'datasets') found_data_class = locate_and_import_class(data_name, 'datasets')
found_model_class = locate_and_import_class(model_name, 'models') found_model_class = locate_and_import_class(model_name, 'models')
for module in [Logger, Trainer, found_data_class, found_model_class]: for module in [LightningLogger, Trainer, found_data_class, found_model_class]:
parser = module.add_argparse_args(parser) parser = module.add_argparse_args(parser)
# This is obsolete
# new_defaults.update(data_name=data_name, model_name=model_name)
args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults)) args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
args = vars(args) args = vars(args)

View File

@ -1,3 +1,5 @@
import inspect
from argparse import ArgumentParser
from copy import deepcopy from copy import deepcopy
import hashlib import hashlib
@ -5,16 +7,17 @@ from pathlib import Path
import os import os
from pytorch_lightning.loggers.base import LightningLoggerBase from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from neptune.api_exceptions import ProjectNotFound from neptune.api_exceptions import ProjectNotFound
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.csv_logs import CSVLogger from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.utilities import argparse_utils from pytorch_lightning.utilities import argparse_utils
from ml_lib.utils.tools import add_argparse_args from ml_lib.utils.tools import add_argparse_args
class Logger(LightningLoggerBase): class LightningLogger(LightningLoggerBase):
@classmethod @classmethod
def from_argparse_args(cls, args, **kwargs): def from_argparse_args(cls, args, **kwargs):
@ -97,7 +100,7 @@ class Logger(LightningLoggerBase):
They are editable after experiment is created (see: append_tag() and remove_tag()). They are editable after experiment is created (see: append_tag() and remove_tag()).
Tags are displayed in the experiments Details and can be viewed in experiments view as a column. Tags are displayed in the experiments Details and can be viewed in experiments view as a column.
""" """
super(Logger, self).__init__() super(LightningLogger, self).__init__()
self.debug = debug self.debug = debug
self.owner = owner if not self.debug else 'testuser' self.owner = owner if not self.debug else 'testuser'

View File

@ -67,7 +67,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
return model_class return model_class
except AttributeError: except AttributeError:
continue continue
raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}') raise AttributeError(f'Check the {folder_path.name} name. Possible files are:\n{[x.name for x in module_paths]}')
def add_argparse_args(cls, parent_parser): def add_argparse_args(cls, parent_parser):