Code Comments, Getting Dirty Env, Naming
This commit is contained in:
parent
faa27c3cf9
commit
ab01006eae
@ -9,7 +9,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
|||||||
|
|
||||||
from ml_lib.modules.util import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
from ml_lib.utils.config import Config
|
from ml_lib.utils.config import Config
|
||||||
from ml_lib.utils.loggers import Logger
|
from ml_lib.utils.loggers import LightningLogger
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
@ -20,7 +20,7 @@ def run_lightning_loop(config_obj):
|
|||||||
# Logging
|
# Logging
|
||||||
# ================================================================================
|
# ================================================================================
|
||||||
# Logger
|
# Logger
|
||||||
with Logger(config_obj) as logger:
|
with LightningLogger(config_obj) as logger:
|
||||||
# Callbacks
|
# Callbacks
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Checkpoint Saving
|
# Checkpoint Saving
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -41,3 +43,33 @@ class FocalLossRob(nn.Module):
|
|||||||
return x.sum()
|
return x.sum()
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DQN_MSELoss(object):
|
||||||
|
|
||||||
|
def __init__(self, agent_net, target_net, gamma):
|
||||||
|
self.agent_net = agent_net
|
||||||
|
self.target_net = target_net
|
||||||
|
self.gamma = gamma
|
||||||
|
|
||||||
|
def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Calculates the mse loss using a mini batch from the replay buffer
|
||||||
|
Args:
|
||||||
|
batch: current mini batch of replay data
|
||||||
|
Returns:
|
||||||
|
loss
|
||||||
|
"""
|
||||||
|
states, actions, rewards, dones, next_states = batch
|
||||||
|
|
||||||
|
actions = actions.to(torch.int64)
|
||||||
|
state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
next_state_values = self.target_net(next_states).max(1)[0]
|
||||||
|
next_state_values[dones] = 0.0
|
||||||
|
next_state_values = next_state_values.detach()
|
||||||
|
|
||||||
|
expected_state_action_values = next_state_values * self.gamma + rewards
|
||||||
|
|
||||||
|
return F.mse_loss(state_action_values, expected_state_action_values)
|
||||||
|
@ -8,8 +8,6 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.path.append(str(Path(__file__).parent))
|
sys.path.append(str(Path(__file__).parent))
|
||||||
|
|
||||||
@ -40,7 +38,7 @@ class LinearModule(ShapeMixin, nn.Module):
|
|||||||
tensor = self.flat(x)
|
tensor = self.flat(x)
|
||||||
tensor = self.dropout(tensor)
|
tensor = self.dropout(tensor)
|
||||||
tensor = self.norm(tensor)
|
tensor = self.norm(tensor)
|
||||||
tensor = self.linear(tensor)
|
tensor = self.linear(tensor.float())
|
||||||
tensor = self.activation(tensor)
|
tensor = self.activation(tensor)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
@ -249,6 +247,7 @@ class Attention(nn.Module):
|
|||||||
) if project_out else nn.Identity()
|
) if project_out else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, mask=None, return_attn_weights=False):
|
def forward(self, x, mask=None, return_attn_weights=False):
|
||||||
|
from einops import rearrange, repeat
|
||||||
# noinspection PyTupleAssignmentBalance
|
# noinspection PyTupleAssignmentBalance
|
||||||
b, n, _, h = *x.shape, self.heads
|
b, n, _, h = *x.shape, self.heads
|
||||||
|
|
||||||
|
@ -129,7 +129,9 @@ try:
|
|||||||
self._weight_init = weight_init
|
self._weight_init = weight_init
|
||||||
self.params = ModelParameters(model_parameters)
|
self.params = ModelParameters(model_parameters)
|
||||||
|
|
||||||
|
if hasattr(self.params, 'n_classes'):
|
||||||
self.metrics = PLMetrics(self.params.n_classes, tag='PL')
|
self.metrics = PLMetrics(self.params.n_classes, tag='PL')
|
||||||
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import ast
|
import ast
|
||||||
import configparser
|
import configparser
|
||||||
|
from distutils.util import strtobool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Mapping, Dict
|
from typing import Mapping, Dict
|
||||||
|
|
||||||
@ -14,7 +15,7 @@ from configparser import ConfigParser, DuplicateSectionError
|
|||||||
import hashlib
|
import hashlib
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
|
|
||||||
from ml_lib.utils.loggers import Logger
|
from ml_lib.utils.loggers import LightningLogger
|
||||||
from ml_lib.utils.tools import locate_and_import_class, auto_cast
|
from ml_lib.utils.tools import locate_and_import_class, auto_cast
|
||||||
|
|
||||||
|
|
||||||
@ -27,6 +28,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
|||||||
parser.add_argument('--model_name', type=str)
|
parser.add_argument('--model_name', type=str)
|
||||||
parser.add_argument('--data_name', type=str)
|
parser.add_argument('--data_name', type=str)
|
||||||
parser.add_argument('--seed', type=str)
|
parser.add_argument('--seed', type=str)
|
||||||
|
parser.add_argument('--debug', type=strtobool)
|
||||||
|
|
||||||
# Load Defaults from _parameters.ini file
|
# Load Defaults from _parameters.ini file
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
@ -52,12 +54,9 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
|||||||
found_data_class = locate_and_import_class(data_name, 'datasets')
|
found_data_class = locate_and_import_class(data_name, 'datasets')
|
||||||
found_model_class = locate_and_import_class(model_name, 'models')
|
found_model_class = locate_and_import_class(model_name, 'models')
|
||||||
|
|
||||||
for module in [Logger, Trainer, found_data_class, found_model_class]:
|
for module in [LightningLogger, Trainer, found_data_class, found_model_class]:
|
||||||
parser = module.add_argparse_args(parser)
|
parser = module.add_argparse_args(parser)
|
||||||
|
|
||||||
# This is obsolete
|
|
||||||
# new_defaults.update(data_name=data_name, model_name=model_name)
|
|
||||||
|
|
||||||
args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
|
args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
|
||||||
|
|
||||||
args = vars(args)
|
args = vars(args)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import inspect
|
||||||
|
from argparse import ArgumentParser
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
@ -5,16 +7,17 @@ from pathlib import Path
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||||
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
|
||||||
from neptune.api_exceptions import ProjectNotFound
|
from neptune.api_exceptions import ProjectNotFound
|
||||||
|
|
||||||
|
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
||||||
|
|
||||||
from pytorch_lightning.loggers.csv_logs import CSVLogger
|
from pytorch_lightning.loggers.csv_logs import CSVLogger
|
||||||
from pytorch_lightning.utilities import argparse_utils
|
from pytorch_lightning.utilities import argparse_utils
|
||||||
|
|
||||||
from ml_lib.utils.tools import add_argparse_args
|
from ml_lib.utils.tools import add_argparse_args
|
||||||
|
|
||||||
|
|
||||||
class Logger(LightningLoggerBase):
|
class LightningLogger(LightningLoggerBase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_argparse_args(cls, args, **kwargs):
|
def from_argparse_args(cls, args, **kwargs):
|
||||||
@ -97,7 +100,7 @@ class Logger(LightningLoggerBase):
|
|||||||
They are editable after experiment is created (see: append_tag() and remove_tag()).
|
They are editable after experiment is created (see: append_tag() and remove_tag()).
|
||||||
Tags are displayed in the experiment’s Details and can be viewed in experiments view as a column.
|
Tags are displayed in the experiment’s Details and can be viewed in experiments view as a column.
|
||||||
"""
|
"""
|
||||||
super(Logger, self).__init__()
|
super(LightningLogger, self).__init__()
|
||||||
|
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.owner = owner if not self.debug else 'testuser'
|
self.owner = owner if not self.debug else 'testuser'
|
||||||
|
@ -67,7 +67,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
|||||||
return model_class
|
return model_class
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
|
raise AttributeError(f'Check the {folder_path.name} name. Possible files are:\n{[x.name for x in module_paths]}')
|
||||||
|
|
||||||
|
|
||||||
def add_argparse_args(cls, parent_parser):
|
def add_argparse_args(cls, parent_parser):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user