bringing branches up to date
This commit is contained in:
parent 010176e80b
commit a966321576
@@ -10,24 +10,31 @@ from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage
 from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
 
 
-class _AudioToMelDataset(Dataset, ABC):
+import librosa
+
+
+class LibrosaAudioToMelDataset(Dataset):
 
     @property
     def audio_file_duration(self):
-        raise NotImplementedError
+        return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
 
     @property
     def sampling_rate(self):
-        raise NotImplementedError
+        return self.mel_kwargs.get('sr', None)
 
     def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False,
                  audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs):
         self.ignored_kwargs = kwargs
         super(LibrosaAudioToMelDataset, self).__init__()
 
         # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
+        mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
         self.mel_kwargs = mel_kwargs
         self.reset = reset
         self.audio_path = Path(audio_file_path)
 
         mel_folder_suffix = self.audio_path.parent.parent.name
 
         self.mel_file_path = Path(str(self.audio_path)
                                   .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')
                                   .replace(self.audio_path.suffix, '.npy'))
@@ -38,59 +45,25 @@ class _AudioToMelDataset(Dataset, ABC):
                                        self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'],
                                        mel_kwargs['n_mels'], transform=mel_augmentations)
 
-    def _build_mel(self):
-        raise NotImplementedError
-
     def __getitem__(self, item):
         try:
             return self.dataset[item]
         except FileNotFoundError:
             assert self._build_mel()
             return self.dataset[item]
 
     def __len__(self):
         return len(self.dataset)
 
 
-import librosa
-
-
-class LibrosaAudioToMelDataset(_AudioToMelDataset):
-
-    @property
-    def audio_file_duration(self):
-        return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
-
-    @property
-    def sampling_rate(self):
-        return self.mel_kwargs.get('sr', None)
-
-    def __init__(self, audio_file_path, *args, **kwargs):
-
-        audio_file_path = Path(audio_file_path)
-        # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
-        mel_kwargs = kwargs.get('mel_kwargs', dict())
-        mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
-        kwargs.update(mel_kwargs=mel_kwargs)
-
-        super(LibrosaAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs)
-
         self._mel_transform = Compose([LibrosaAudioToMel(**mel_kwargs),
                                        MelToImage()
                                        ])
 
-    def __getitem__(self, item):
-        return self.dataset[item]
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def _build_mel(self):
+    def build_mel(self):
         if self.reset:
             self.mel_file_path.unlink(missing_ok=True)
         if not self.mel_file_path.exists():
             lockfile = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock'))
             self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
             lockfile.touch(exist_ok=False)
             raw_sample, _ = librosa.core.load(self.audio_path, sr=self.sampling_rate)
             mel_sample = self._mel_transform(raw_sample)
             with self.mel_file_path.open('wb') as mel_file:
                 pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
             lockfile.unlink(missing_ok=False)
         else:
             pass
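
A rough usage sketch of the merged dataset class after this change; the file path, label and mel parameters below are illustrative, not taken from the repository:

# Illustrative values only -- path, label and mel parameters are assumptions.
dataset = LibrosaAudioToMelDataset(
    'data/train/class_a/sample_001.wav',
    label=0,
    sample_segment_len=128,      # mel frames per training snippet
    sample_hop_len=64,           # frames between snippet starts
    reset=False,                 # True deletes and rebuilds the cached mel file
    mel_kwargs=dict(sr=16000, n_mels=64, hop_length=256),
)
dataset.build_mel()              # optional warm-up; __getitem__ also builds lazily on FileNotFoundError
path, snippet, label = dataset[0]
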
@@ -11,13 +11,16 @@ class TorchMelDataset(Dataset):
     def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len,
                  sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
         super(TorchMelDataset, self).__init__()
-        self.sampling_rate = sampling_rate
-        self.audio_file_len = audio_file_len
-        self.padding = AutoPadToShape((n_mels, sub_segment_len)) if auto_pad_to_shape and sub_segment_len else None
+        self.sampling_rate = int(sampling_rate)
+        self.audio_file_len = int(audio_file_len)
+        if auto_pad_to_shape and sub_segment_len:
+            self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
+        else:
+            self.padding = None
         self.path = Path(mel_path)
-        self.sub_segment_len = sub_segment_len
-        self.mel_hop_len = mel_hop_len
-        self.sub_segment_hop_len = sub_segment_hop_len
+        self.sub_segment_len = int(sub_segment_len)
+        self.mel_hop_len = int(mel_hop_len)
+        self.sub_segment_hop_len = int(sub_segment_hop_len)
         self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)
         if self.sub_segment_len and self.sub_segment_hop_len:
             self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
@@ -27,8 +30,6 @@ class TorchMelDataset(Dataset):
         self.transform = transform
 
     def __getitem__(self, item):
-        while Path(str(self.path).replace(self.path.suffix, '.lock')).exists():
-            time.sleep(0.01)
         with self.path.open('rb') as mel_file:
             mel_spec = pickle.load(mel_file, fix_imports=True)
         start = self.offsets[item]
@@ -38,7 +39,7 @@ class TorchMelDataset(Dataset):
             snippet = self.transform(snippet)
         if self.padding:
             snippet = self.padding(snippet)
-        return snippet, self.label
+        return self.path.__str__(), snippet, self.label
 
     def __len__(self):
         return len(self.offsets)
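
The `self.n` / `self.offsets` arithmetic above determines how many snippets one file yields; a quick worked example with assumed values:

sampling_rate = 16000        # Hz
mel_hop_len = 256            # audio samples per mel frame
audio_file_len = 10.0        # seconds
sub_segment_len = 128        # mel frames per snippet
sub_segment_hop_len = 64     # frames between snippet starts

n = int((sampling_rate / mel_hop_len) * audio_file_len + 1)          # 626 mel frames
offsets = list(range(0, n - sub_segment_len, sub_segment_hop_len))   # 8 snippet start positions
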
@@ -1,67 +0,0 @@
-import sys
-
-import torchaudio
-if sys.platform == 'windows':
-    torchaudio.set_audio_backend('soundfile')
-else:
-    torchaudio.set_audio_backend('sox_io')
-
-
-class PyTorchAudioToMelDataset(_AudioToMelDataset):
-
-    @property
-    def audio_file_duration(self):
-        info_obj = torchaudio.info(self.audio_path)
-        return info_obj.num_frames / info_obj.sample_rate
-
-    @property
-    def sampling_rate(self):
-        return self.mel_kwargs['sample_rate']
-
-    def __init__(self, audio_file_path, *args, **kwargs):
-        super(PyTorchAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs)
-
-        audio_file_path = Path(audio_file_path)
-        # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
-
-        from torchaudio.transforms import MelSpectrogram
-        self._mel_transform = Compose([MelSpectrogram(**self.mel_kwargs),
-                                       MelToImage()
-                                       ])
-
-    def _build_mel(self):
-        if self.reset:
-            self.mel_file_path.unlink(missing_ok=True)
-        if not self.mel_file_path.exists():
-            self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
-            lock_file = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock'))
-            lock_file.touch(exist_ok=False)
-
-            try:
-                audio_sample, sample_rate = torchaudio.load(self.audio_path)
-            except RuntimeError:
-                import soundfile
-
-                data, samplerate = soundfile.read(self.audio_path)
-                # sf.available_formats()
-                # sf.available_subtypes()
-                soundfile.write(self.audio_path, data, samplerate, subtype='PCM_32')
-
-                audio_sample, sample_rate = torchaudio.load(self.audio_path)
-            if sample_rate != self.sampling_rate:
-                resample = torchaudio.transforms.Resample(orig_freq=int(sample_rate), new_freq=int(self.sampling_rate))
-                audio_sample = resample(audio_sample)
-            if audio_sample.shape[0] > 1:
-                # Transform Stereo to Mono
-                audio_sample = audio_sample.mean(dim=0, keepdim=True)
-            mel_sample = self._mel_transform(audio_sample)
-            with self.mel_file_path.open('wb') as mel_file:
-                pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
-            lock_file.unlink()
-        else:
-            # print(f"Already existed.. Skipping {filename}")
-            # mel_file = mel_file
-            pass
-
-        # with mel_file.open(mode='rb') as f:
-        #     mel_sample = pickle.load(f, fix_imports=True)
-        return self.mel_file_path.exists()
metrics/attention_rollout.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+import numpy as np
+
+from einops import reduce
+
+import torch
+from sklearn.ensemble import IsolationForest
+from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
+
+from ml_lib.metrics._base_score import _BaseScores
+
+
+class AttentionRollout(_BaseScores):
+
+    def __init__(self, *args):
+        super(AttentionRollout, self).__init__(*args)
+        pass
+
+    def __call__(self, outputs):
+        summary_dict = dict()
+        #######################################################################################
+        # Additional Score - Histogram Distances - Image Plotting
+        #######################################################################################
+        #
+        # INIT
+        attn_weights = [output['attn_weights'].cpu().numpy() for output in outputs]
+        # NOTE: the committed stub left the einops pattern empty and never defined att_mat;
+        # the head-averaging pattern and the (heads, tokens, tokens) layout below are assumptions.
+        att_mat = np.stack([reduce(x, 'h t1 t2 -> t1 t2', 'mean') for x in attn_weights])
+
+        if self.model.params.use_residual:
+            residual_att = np.eye(att_mat.shape[1])[None, ...]
+            aug_att_mat = att_mat + residual_att
+            aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
+        else:
+            aug_att_mat = att_mat
+
+        joint_attentions = np.zeros(aug_att_mat.shape)
+
+        layers = joint_attentions.shape[0]
+        joint_attentions[0] = aug_att_mat[0]
+        for i in np.arange(1, layers):
+            joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])
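
The new file stops after computing joint_attentions; a hedged sketch of how the rollout would typically be consumed (the leading-CLS-token convention is an assumption, not in the commit):

# Assuming joint_attentions has shape (layers, tokens, tokens):
rollout = joint_attentions[-1]                        # accumulated attention flow through all layers
cls_attention = rollout[0, 1:]                        # e.g. flow from a leading CLS token to the rest
cls_attention = cls_attention / cls_attention.max()   # normalized map, ready for plotting
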
@@ -113,17 +113,17 @@ class MultiClassScores(_BaseScores):
         #######################################################################################
         #
         # Confusion matrix
 
+        fig1, ax1 = plt.subplots(dpi=96)
         cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
                               labels=[class_names[key] for key in class_names.keys()],
                               normalize='all')
         disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                       display_labels=[class_names[i] for i in range(self.model.n_classes)]
                                       )
-        disp.plot(include_values=True)
+        disp.plot(include_values=True, ax=ax1)
 
-        self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch)
+        self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch)
         # self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')
 
         plt.close('all')
-        return summary_dict
+        return summary_dict
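
Creating the figure explicitly and passing ax=ax1 ensures the figure handed to the logger is the one the matrix was actually drawn on; the standalone pattern (labels and data are placeholders):

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

fig, ax = plt.subplots(dpi=96)
cm = confusion_matrix(['a', 'b', 'a'], ['a', 'b', 'b'], labels=['a', 'b'], normalize='all')
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['a', 'b']).plot(include_values=True, ax=ax)
fig.savefig('confusion_matrix.png')   # or hand `fig` to the experiment logger
plt.close(fig)
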
@@ -291,19 +291,17 @@ class TransformerModule(ShapeMixin, nn.Module):
 
         for attn, mlp in zip(self.attns, self.mlps):
             # Attention
-            skip_connection = tensor.clone()
-            tensor = self.norm(tensor)
+            attn_tensor = self.norm(tensor)
             if return_attn_weights:
-                tensor, attn_weight = attn(tensor, mask=mask, return_attn_weights=return_attn_weights)
+                attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights)
                 attn_weights.append(attn_weight)
             else:
-                tensor = attn(tensor, mask=mask)
-            tensor = tensor + skip_connection
+                attn_tensor = attn(attn_tensor, mask=mask)
+            tensor = attn_tensor + tensor
 
             # MLP
-            skip_connection = tensor.clone()
-            tensor = self.norm(tensor)
-            tensor = mlp(tensor)
-            tensor = tensor + skip_connection
+            mlp_tensor = self.norm(tensor)
+            mlp_tensor = mlp(mlp_tensor)
+            tensor = tensor + mlp_tensor
 
         return (tensor, attn_weights) if return_attn_weights else tensor
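
The rewrite replaces the clone()-based skip connections with the usual pre-norm residual layout: normalize a branch copy, transform it, add it back onto the untouched stream. Distilled into a minimal sketch (dimensions and sub-modules here are illustrative, not the repository's attention implementation):

import torch
from torch import nn

class PreNormBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))

    def forward(self, tensor):
        attn_tensor = self.norm(tensor)            # normalize a branch, not the stream
        attn_tensor, _ = self.attn(attn_tensor, attn_tensor, attn_tensor)
        tensor = tensor + attn_tensor              # residual around attention
        mlp_tensor = self.mlp(self.norm(tensor))
        return tensor + mlp_tensor                 # residual around the MLP
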
@@ -1,3 +1,6 @@
+import inspect
+from argparse import ArgumentParser
+
 from functools import reduce
 
 from abc import ABC
@@ -5,13 +8,14 @@ from pathlib import Path
 
 import torch
 from operator import mul
+from pytorch_lightning.utilities import argparse_utils
 from torch import nn
 from torch.nn import functional as F, Unfold
 
 # Utility - Modules
 ###################
 from ..utils.model_io import ModelParameters
-from ..utils.tools import locate_and_import_class
+from ..utils.tools import locate_and_import_class, add_argparse_args
 
 try:
     import pytorch_lightning as pl
@@ -32,14 +36,18 @@
             print(e)
             return -1
 
-        def __init__(self, hparams):
-            super(LightningBaseModule, self).__init__()
+        @classmethod
+        def from_argparse_args(cls, args, **kwargs):
+            return argparse_utils.from_argparse_args(cls, args, **kwargs)
 
-            # Set Parameters
-            ################################
-            self.hparams = hparams
-            self.params = ModelParameters(hparams)
-            self.lr = self.params.lr or 1e-4
+        @classmethod
+        def add_argparse_args(cls, parent_parser):
+            return add_argparse_args(cls, parent_parser)
+
+        def __init__(self, model_parameters, weight_init='xavier_normal_'):
+            super(LightningBaseModule, self).__init__()
+            self._weight_init = weight_init
+            self.params = ModelParameters(model_parameters)
 
         def size(self):
             return self.shape
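
With both classmethods in place, models follow the usual Lightning-style CLI wiring; MyModel is a hypothetical subclass:

from argparse import ArgumentParser

parser = ArgumentParser()
parser = MyModel.add_argparse_args(parser)    # one flag per __init__ argument
args = parser.parse_args()
model = MyModel.from_argparse_args(args)      # matching kwargs are picked out of the namespace
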
@@ -47,15 +55,6 @@
         def additional_scores(self, outputs):
             raise NotImplementedError
 
-        @property
-        def dataset_class(self):
-            try:
-                return locate_and_import_class(self.params.class_name, folder_path='datasets')
-            except AttributeError as e:
-                raise AttributeError(f'The dataset alias you provided ("{self.params.class_name}") ' +
-                                     f'was not found!\n' +
-                                     f'{e}')
-
         def save_to_disk(self, model_path):
             Path(model_path).mkdir(parents=True, exist_ok=True)
             if not (model_path / 'model_class.obj').exists():
@@ -86,8 +85,12 @@
         def test_epoch_end(self, outputs):
             raise NotImplementedError
 
-        def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
-            weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
+        def init_weights(self):
+            if isinstance(self._weight_init, str):
+                mod = __import__('torch.nn.init', fromlist=[self._weight_init])
+                self._weight_init = getattr(mod, self._weight_init)
+            assert callable(self._weight_init)
+            weight_initializer = WeightInit(in_place_init_function=self._weight_init)
             self.apply(weight_initializer)
 
     module_types = (LightningBaseModule, nn.Module,)
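
The string-to-callable resolution in init_weights boils down to a getattr on torch.nn.init; a minimal equivalent:

import torch
from torch import nn

init_fn = getattr(torch.nn.init, 'xavier_normal_')   # same effect as the __import__ above
assert callable(init_fn)
layer = nn.Linear(8, 8)
init_fn(layer.weight)                                # in-place Xavier-normal initialization
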
utils/_basedatamodule.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from pytorch_lightning import LightningDataModule
+
+
+# Dataset Options
+from ml_lib.utils.tools import add_argparse_args
+
+DATA_OPTION_test = 'test'
+DATA_OPTION_devel = 'devel'
+DATA_OPTION_train = 'train'
+DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
+
+
+class _BaseDataModule(LightningDataModule):
+
+    @property
+    def shape(self):
+        return self.datasets[DATA_OPTION_train].sample_shape
+
+    @classmethod
+    def add_argparse_args(cls, parent_parser):
+        return add_argparse_args(cls, parent_parser)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.datasets = dict()
+
+    def transfer_batch_to_device(self, batch, device):
+        return batch.to(device)
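
A hypothetical subclass sketch showing how the datasets dict and DATA_OPTIONS are meant to be filled:

class MyDataModule(_BaseDataModule):

    def __init__(self, data_root, batch_size=32):
        super().__init__()
        self.data_root, self.batch_size = data_root, batch_size

    def setup(self, stage=None):
        for option in DATA_OPTIONS:          # 'train', 'devel', 'test'
            self.datasets[option] = ...      # build one dataset per split here
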
@@ -1,19 +1,34 @@
-from abc import ABC
+import inspect
+from argparse import ArgumentParser
 from pathlib import Path
 
+import os
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.neptune import NeptuneLogger
 from neptune.api_exceptions import ProjectNotFound
 # noinspection PyUnresolvedReferences
 
 from pytorch_lightning.loggers.csv_logs import CSVLogger
+from pytorch_lightning.utilities import argparse_utils
 
-from .config import Config
+from ml_lib.utils.tools import add_argparse_args
 
 
-class Logger(LightningLoggerBase, ABC):
+class Logger(LightningLoggerBase):
 
+    @classmethod
+    def from_argparse_args(cls, args, **kwargs):
+        return argparse_utils.from_argparse_args(cls, args, **kwargs)
+
+    @property
+    def name(self) -> str:
+        return self._name
+
     media_dir = 'media'
 
+    @classmethod
+    def add_argparse_args(cls, parent_parser):
+        return add_argparse_args(cls, parent_parser)
+
     @property
     def experiment(self):
         if self.debug:
@@ -25,27 +40,23 @@ class Logger(LightningLoggerBase, ABC):
     def log_dir(self):
         return Path(self.csvlogger.experiment.log_dir)
 
-    @property
-    def name(self):
-        return self.config.name
-
     @property
     def project_name(self):
-        return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"
+        return f"{self.owner}/{self.name.replace('_', '-')}"
 
     @property
     def version(self):
-        return self.config.get('main', 'seed')
+        return self.seed
 
     @property
     def save_dir(self):
         return self.log_dir
 
     @property
     def outpath(self):
-        return Path(self.config.train.outpath) / self.config.model.type
+        return Path(self.root_out) / self.model_name
 
     @property
     def exp_path(self):
         return Path(self.outpath) / self.name
 
-    def __init__(self, config: Config):
+    def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False):
         """
         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
            Parameters are displayed in the experiment’s Parameters section and each key-value pair can be
@@ -59,19 +70,19 @@ class Logger(LightningLoggerBase, ABC):
         """
         super(Logger, self).__init__()
 
-        self.config = config
-        self.debug = self.config.main.debug
-        if self.debug:
-            self.config.add_section('project')
-            self.config.set('project', 'owner', 'testuser')
-            self.config.set('project', 'name', 'test')
-            self.config.set('project', 'neptune_key', 'XXX')
+        self.debug = debug
+        self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test'
+        self.owner = owner if not self.debug else 'testuser'
+        self.neptune_key = neptune_key if not self.debug else 'XXX'
+        self.root_out = outpath if not self.debug else 'debug_out'
+        self.seed = seed
+        self.model_name = model_name
 
         self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
         self._neptune_kwargs = dict(offline_mode=self.debug,
-                                    api_key=self.config.project.neptune_key,
+                                    api_key=self.neptune_key,
                                     experiment_name=self.name,
-                                    project_name=self.project_name,
-                                    params=self.config.model_paramters)
+                                    project_name=self.project_name)
         try:
             self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
         except ProjectNotFound as e:
@@ -79,7 +90,6 @@ class Logger(LightningLoggerBase, ABC):
             print(e)
 
         self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
-        self.log_config_as_ini()
 
     def log_hyperparams(self, params):
         self.neptunelogger.log_hyperparams(params)
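
Constructing the refactored Logger no longer requires a Config object; a usage sketch with placeholder credentials:

logger = Logger(owner='my-neptune-user',         # all values here are placeholders
                neptune_key='ANONYMOUS',
                model_name='Transformer',
                project_name='audio-experiments',
                outpath='output',
                seed=42,
                debug=False)                     # debug=True swaps in offline test settings
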
@@ -95,19 +105,15 @@ class Logger(LightningLoggerBase, ABC):
         self.csvlogger.close()
         self.neptunelogger.close()
 
-    def log_config_as_ini(self):
-        self.config.write(self.log_dir / 'config.ini')
-
-    def log_text(self, name, text, step_nb=0, **_):
+    def log_text(self, name, text, **_):
         # TODO Implement Offline variant.
-        self.neptunelogger.log_text(name, text, step_nb)
+        self.neptunelogger.log_text(name, text)
 
     def log_metric(self, metric_name, metric_value, **kwargs):
         self.csvlogger.log_metrics({metric_name: metric_value})
         self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
 
     def log_image(self, name, image, ext='png', **kwargs):
 
         step = kwargs.get('step', None)
         image_name = f'{step}_{name}' if step is not None else name
         image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
@@ -13,6 +13,10 @@ from torch import nn
 # Hyperparameter Object
 class ModelParameters(Namespace, Mapping):
 
+    @property
+    def as_dict(self):
+        return {key: self.get(key) if key != 'activation' else self.activation_as_string for key in self.keys()}
+
     @property
     def activation_as_string(self):
         return self['activation'].lower()
@@ -50,13 +54,7 @@ class ModelParameters(Namespace, Mapping):
         if name == 'activation':
             return self._activations[self['activation'].lower()]
         else:
-            try:
-                return super(ModelParameters, self).__getattribute__(name)
-            except AttributeError as e:
-                if name == 'stretch':
-                    return False
-                else:
-                    return None
+            return super(ModelParameters, self).__getattribute__(name)
 
     _activations = dict(
         leaky_relu=nn.LeakyReLU,
@@ -88,16 +86,20 @@ class SavedLightningModels(object):
         model = torch.load(models_root_path / 'model_class.obj')
         assert model is not None
 
-        return cls(weights=str(checkpoint_path), model=model)
+        return cls(weights=checkpoint_path, model=model)
 
     def __init__(self, **kwargs):
-        self.weights: str = kwargs.get('weights', '')
+        self.weights: Path = Path(kwargs.get('weights', ''))
+        self.hparams: Path = self.weights.parent / 'hparams.yaml'
 
         self.model = kwargs.get('model', None)
         assert self.model is not None
 
     def restore(self):
-        pretrained_model = self.model.load_from_checkpoint(self.weights)
-
+        pretrained_model = self.model.load_from_checkpoint(self.weights.__str__())
+        # , hparams_file=self.hparams.__str__())
         pretrained_model.eval()
         pretrained_model.freeze()
-        return pretrained_model
+        return pretrained_model
@@ -1,6 +1,9 @@
 import importlib
+import inspect
 import pickle
 import shelve
+from argparse import ArgumentParser
+from ast import literal_eval
 from pathlib import Path, PurePath
 from typing import Union
@@ -9,6 +12,13 @@
 import torch
 import random
 
 
+def auto_cast(a):
+    try:
+        return literal_eval(a)
+    except (ValueError, SyntaxError):
+        return a
+
+
 def to_one_hot(idx_array, max_classes):
     one_hot = np.zeros((idx_array.size, max_classes))
     one_hot[np.arange(idx_array.size), idx_array] = 1
@@ -54,3 +64,20 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
             continue
     raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
 
 
+def add_argparse_args(cls, parent_parser):
+    parser = ArgumentParser(parents=[parent_parser], add_help=False)
+    full_arg_spec = inspect.getfullargspec(cls.__init__)
+    n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
+    for idx, argument in enumerate(full_arg_spec.args):
+        if argument == 'self':
+            continue
+        if idx < n_non_defaults:
+            parser.add_argument(f'--{argument}', type=int)
+        else:
+            default = full_arg_spec.defaults[idx - n_non_defaults]
+            parser.add_argument(f'--{argument}',
+                                type=type(default),   # infer the flag type from the default value
+                                default=default
+                                )
+    return parser