diff --git a/_templates/new_project/main.py b/_templates/new_project/main.py
index 692072e..2901ed3 100644
--- a/_templates/new_project/main.py
+++ b/_templates/new_project/main.py
@@ -9,7 +9,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 
 from ml_lib.modules.util import LightningBaseModule
 from ml_lib.utils.config import Config
-from ml_lib.utils.loggers import Logger
+from ml_lib.utils.loggers import LightningLogger
 
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -20,7 +20,7 @@ def run_lightning_loop(config_obj):
     # Logging
     # ================================================================================
     # Logger
-    with Logger(config_obj) as logger:
+    with LightningLogger(config_obj) as logger:
         # Callbacks
         # =============================================================================
         # Checkpoint Saving
diff --git a/additions/losses.py b/additions/losses.py
index 347770b..5af9cc2 100644
--- a/additions/losses.py
+++ b/additions/losses.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -41,3 +43,33 @@ class FocalLossRob(nn.Module):
             return x.sum()
         else:
             return x
+
+
+class DQN_MSELoss(object):
+
+    def __init__(self, agent_net, target_net, gamma):
+        self.agent_net = agent_net
+        self.target_net = target_net
+        self.gamma = gamma
+
+    def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
+        """
+        Calculates the MSE loss using a mini-batch from the replay buffer.
+        Args:
+            batch: current mini-batch of replay data
+        Returns:
+            loss
+        """
+        states, actions, rewards, dones, next_states = batch
+
+        actions = actions.to(torch.int64)
+        state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
+
+        with torch.no_grad():
+            next_state_values = self.target_net(next_states).max(1)[0]
+            next_state_values[dones] = 0.0
+            next_state_values = next_state_values.detach()
+
+        expected_state_action_values = next_state_values * self.gamma + rewards
+
+        return F.mse_loss(state_action_values, expected_state_action_values)
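Reviewer note on the new DQN_MSELoss above: it regresses Q(s, a) from agent_net onto the one-step TD target r + gamma * max_a' Q_target(s', a'), with terminal states masked to zero. Below is a minimal usage sketch, not part of the patch; the QNet module, layer sizes, and batch shapes are illustrative assumptions only.

    import torch
    import torch.nn as nn

    # Hypothetical Q-network, for illustration only -- the patch does not define one.
    class QNet(nn.Module):
        def __init__(self, n_states=4, n_actions=2):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(n_states, 32), nn.ReLU(), nn.Linear(32, n_actions))

        def forward(self, x):
            return self.net(x)

    agent_net, target_net = QNet(), QNet()
    target_net.load_state_dict(agent_net.state_dict())  # target net starts as a copy of the agent net

    loss_fn = DQN_MSELoss(agent_net, target_net, gamma=0.99)

    # Fake replay-buffer mini-batch: (states, actions, rewards, dones, next_states)
    batch = (
        torch.randn(8, 4),                 # states
        torch.randint(0, 2, (8,)),         # actions taken
        torch.randn(8),                    # rewards
        torch.zeros(8, dtype=torch.bool),  # episode-done flags
        torch.randn(8, 4),                 # next states
    )
    loss = loss_fn(batch)  # MSE between Q(s, a) and r + gamma * max_a' Q_target(s', a')
    loss.backward()        # gradients flow only through agent_net; the target side is under no_grad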
diff --git a/modules/blocks.py b/modules/blocks.py
index 602bdb9..a3b71fd 100644
--- a/modules/blocks.py
+++ b/modules/blocks.py
@@ -8,8 +8,6 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from einops import rearrange, repeat
-
 import sys
 sys.path.append(str(Path(__file__).parent))
 
@@ -40,7 +38,7 @@ class LinearModule(ShapeMixin, nn.Module):
         tensor = self.flat(x)
         tensor = self.dropout(tensor)
         tensor = self.norm(tensor)
-        tensor = self.linear(tensor)
+        tensor = self.linear(tensor.float())
         tensor = self.activation(tensor)
         return tensor
 
@@ -249,6 +247,7 @@ class Attention(nn.Module):
         ) if project_out else nn.Identity()
 
     def forward(self, x, mask=None, return_attn_weights=False):
+        from einops import rearrange, repeat
         # noinspection PyTupleAssignmentBalance
         b, n, _, h = *x.shape, self.heads
 
diff --git a/modules/util.py b/modules/util.py
index a615bbf..e1f338e 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -129,8 +129,10 @@ try:
             self._weight_init = weight_init
             self.params = ModelParameters(model_parameters)
-            self.metrics = PLMetrics(self.params.n_classes, tag='PL')
-            pass
+            if hasattr(self.params, 'n_classes'):
+                self.metrics = PLMetrics(self.params.n_classes, tag='PL')
+            else:
+                pass
 
         def size(self):
             return self.shape
 
diff --git a/utils/config.py b/utils/config.py
index 95a76d0..4e1860a 100644
--- a/utils/config.py
+++ b/utils/config.py
@@ -1,5 +1,6 @@
 import ast
 import configparser
+from distutils.util import strtobool
 from pathlib import Path
 from typing import Mapping, Dict
 
@@ -14,7 +15,7 @@
 from configparser import ConfigParser, DuplicateSectionError
 
 import hashlib
 
 from pytorch_lightning import Trainer
-from ml_lib.utils.loggers import Logger
+from ml_lib.utils.loggers import LightningLogger
 from ml_lib.utils.tools import locate_and_import_class, auto_cast
@@ -27,6 +28,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
     parser.add_argument('--model_name', type=str)
     parser.add_argument('--data_name', type=str)
     parser.add_argument('--seed', type=str)
+    parser.add_argument('--debug', type=strtobool)
 
     # Load Defaults from _parameters.ini file
     config = configparser.ConfigParser()
@@ -52,12 +54,9 @@
     found_data_class = locate_and_import_class(data_name, 'datasets')
     found_model_class = locate_and_import_class(model_name, 'models')
 
-    for module in [Logger, Trainer, found_data_class, found_model_class]:
+    for module in [LightningLogger, Trainer, found_data_class, found_model_class]:
         parser = module.add_argparse_args(parser)
 
-    # This is obsolete
-    # new_defaults.update(data_name=data_name, model_name=model_name)
-
     args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
     args = vars(args)
 
diff --git a/utils/loggers.py b/utils/loggers.py
index 5d9118b..2b3bc58 100644
--- a/utils/loggers.py
+++ b/utils/loggers.py
@@ -1,3 +1,5 @@
+import inspect
+from argparse import ArgumentParser
 from copy import deepcopy
 import hashlib
 
@@ -5,16 +7,17 @@
 from pathlib import Path
 import os
 
 from pytorch_lightning.loggers.base import LightningLoggerBase
-from pytorch_lightning.loggers.neptune import NeptuneLogger
 from neptune.api_exceptions import ProjectNotFound
+from pytorch_lightning.loggers.neptune import NeptuneLogger
+
 from pytorch_lightning.loggers.csv_logs import CSVLogger
 from pytorch_lightning.utilities import argparse_utils
 
 from ml_lib.utils.tools import add_argparse_args
 
 
-class Logger(LightningLoggerBase):
+class LightningLogger(LightningLoggerBase):
 
     @classmethod
     def from_argparse_args(cls, args, **kwargs):
@@ -97,7 +100,7 @@
             They are editable after experiment is created (see: append_tag() and remove_tag()).
             Tags are displayed in the experiment’s Details and can be viewed in experiments view as a column.
         """
-        super(Logger, self).__init__()
+        super(LightningLogger, self).__init__()
 
         self.debug = debug
         self.owner = owner if not self.debug else 'testuser'
diff --git a/utils/tools.py b/utils/tools.py
index 012a496..04a50a5 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -67,7 +67,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
             return model_class
         except AttributeError:
             continue
-    raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
+    raise AttributeError(f'Check the {folder_path.name} name. Possible files are:\n{[x.name for x in module_paths]}')
 
 
 def add_argparse_args(cls, parent_parser):
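Reviewer note on the new --debug argument in parse_comandline_args_add_defaults: distutils.util.strtobool maps 'y'/'yes'/'t'/'true'/'on'/'1' to 1 and 'n'/'no'/'f'/'false'/'off'/'0' to 0, so the flag accepts human-readable booleans on the command line. A standalone sketch of that behaviour follows (not part of the patch); be aware that distutils is deprecated since Python 3.10 and removed in 3.12, so newer interpreters would need a small replacement helper.

    from argparse import ArgumentParser
    from distutils.util import strtobool  # NB: removed from the stdlib in Python 3.12

    parser = ArgumentParser()
    parser.add_argument('--debug', type=strtobool, default=0)

    # strtobool returns int 1/0 rather than True/False, and raises ValueError on anything else.
    assert parser.parse_args(['--debug', 'true']).debug == 1
    assert parser.parse_args(['--debug', 'no']).debug == 0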