Compare commits
54 Commits
Author | SHA1 | Date | |
---|---|---|---|
ab01006eae | |||
faa27c3cf9 | |||
abe870d106 | |||
1d1b154460 | |||
6816e423ff | |||
d3e7bf7efb | |||
ed260f1c2a | |||
675312537f | |||
43cf0ad00d | |||
479514c9e7 | |||
fff5e6e00a | |||
8e719af554 | |||
10bf376ac3 | |||
fc4617c9d8 | |||
f89f0f8528 | |||
b5e3e5aec1 | |||
a966321576 | |||
010176e80b | |||
f6156c6cde | |||
93103aba01 | |||
62d9eb6e8f | |||
c6fdaa24aa | |||
cfeea05673 | |||
14ed4e0117 | |||
13812b83b5 | |||
f296ba78b9 | |||
5848b528f0 | |||
6bc9447ce1 | |||
a4b6c698c3 | |||
4b089729b2 | |||
c7d17a9898 | |||
7770b29c14 | |||
53aa11521d | |||
aea34de964 | |||
3f8122484b | |||
12d36047ef | |||
76308888e0 | |||
0cff42f951 | |||
ece80ecbed | |||
d3fa32ae7b | |||
2acf91335f | |||
5987efb169 | |||
77ea043907 | |||
4b4051c045 | |||
8cec323286 | |||
235743b225 | |||
28d0034269 | |||
b87a56e8c6 | |||
196b1af7ae | |||
fcd5ee4d29 | |||
f290d5a8d8 | |||
645b7905e8 | |||
206aca10b3 | |||
e423d6fe31 |
7
.gitignore
vendored
7
.gitignore
vendored
@ -1,6 +1 @@
|
||||
/.idea/
|
||||
# my own stuff
|
||||
|
||||
/data
|
||||
/.idea
|
||||
/ml_lib
|
||||
.idea
|
||||
|
10
README.md
10
README.md
@ -9,9 +9,10 @@ Clone it to find a collection of:
|
||||
- Utility Function for Model I/O
|
||||
- DL Modules
|
||||
- A Plotter Object
|
||||
- Audio Related Tools and Funtion
|
||||
- Audio related Tools and Funtion
|
||||
- Librosa
|
||||
- Scipy Signal
|
||||
- PointCloud related Tools and Functions
|
||||
|
||||
###Notes:
|
||||
- Use directory links to link from your main project folder to the ml_lib folder. Pycharm will automatically use
|
||||
@ -19,11 +20,10 @@ Clone it to find a collection of:
|
||||
\
|
||||
\
|
||||
For Windows Users:
|
||||
```
|
||||
``` bash
|
||||
mklink /d "ml_lib" "..\ml_lib""
|
||||
```
|
||||
For Unix User:
|
||||
``` bash
|
||||
ln -s ../ml_lib ml_lib
|
||||
```
|
||||
TBA
|
||||
```
|
||||
- Cheers
|
BIN
__pycache__/__init__.cpython-37.pyc
Normal file
BIN
__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
5
_templates/new_project/.gitignore
vendored
Normal file
5
_templates/new_project/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
# my own stuff
|
||||
|
||||
/data
|
||||
/.idea
|
||||
/ml_lib
|
@ -2,5 +2,16 @@ from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class TemplateDataset(Dataset):
|
||||
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return self[0][0].shape
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TemplateDataset, self).__init__()
|
||||
super(TemplateDataset, self).__init__()
|
||||
|
||||
def __len__(self):
|
||||
pass
|
||||
|
||||
def __getitem__(self, item):
|
||||
return item
|
||||
|
@ -7,10 +7,9 @@ import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
|
||||
from modules.utils import LightningBaseModule
|
||||
from utils.config import Config
|
||||
from utils.logging import Logger
|
||||
from utils.model_io import SavedLightningModels
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.config import Config
|
||||
from ml_lib.utils.loggers import LightningLogger
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
@ -21,7 +20,7 @@ def run_lightning_loop(config_obj):
|
||||
# Logging
|
||||
# ================================================================================
|
||||
# Logger
|
||||
with Logger(config_obj) as logger:
|
||||
with LightningLogger(config_obj) as logger:
|
||||
# Callbacks
|
||||
# =============================================================================
|
||||
# Checkpoint Saving
|
||||
@ -44,11 +43,6 @@ def run_lightning_loop(config_obj):
|
||||
# Init
|
||||
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
|
||||
model.init_weights(torch.nn.init.xavier_normal_)
|
||||
if model.name == 'CNNRouteGeneratorDiscriminated':
|
||||
# ToDo: Make this dependent on the used seed
|
||||
path = logger.outpath / 'classifier_cnn' / 'version_0'
|
||||
disc_model = SavedLightningModels.load_checkpoint(path).restore()
|
||||
model.set_discriminator(disc_model)
|
||||
|
||||
# Trainer
|
||||
# =============================================================================
|
||||
@ -70,8 +64,8 @@ def run_lightning_loop(config_obj):
|
||||
trainer.fit(model)
|
||||
|
||||
# Save the last state & all parameters
|
||||
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
||||
model.save_to_disk(logger.log_dir)
|
||||
trainer.save_checkpoint(config_obj.exp_path.log_dir / 'weights.ckpt')
|
||||
model.save_to_disk(config_obj.exp_path)
|
||||
|
||||
# Evaluate It
|
||||
if config_obj.main.eval:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import warnings
|
||||
|
||||
from utils.config import Config
|
||||
from ml_lib._templates.new_project.utils.project_config import Config
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
@ -8,17 +8,16 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
||||
# Imports
|
||||
# =============================================================================
|
||||
|
||||
from _templates.new_project.main import run_lightning_loop, args
|
||||
from ml_lib._templates.new_project.main import run_lightning_loop, args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Model Settings
|
||||
config = Config().read_namespace(args)
|
||||
# bias, activation, model, norm, max_epochs, filters
|
||||
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
|
||||
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
|
||||
# bias, activation, model, norm, max_epochs, sr, feature_mixed_dim, filters
|
||||
# bias, activation, model, norm, max_epochs
|
||||
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, data_batchsize=512)
|
||||
# bias, activation, model, norm, max_epochs
|
||||
|
||||
for arg_dict in [cnn_classifier]:
|
||||
for seed in range(5):
|
||||
|
@ -11,13 +11,13 @@ from torch.utils.data import DataLoader
|
||||
from torchcontrib.optim import SWA
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from _templates.new_project.datasets.template_dataset import TemplateDataset
|
||||
from ml_lib._templates.new_project.datasets.template_dataset import TemplateDataset
|
||||
|
||||
from audio_toolset.audio_io import NormalizeLocal
|
||||
from modules.utils import LightningBaseModule
|
||||
from utils.transforms import ToTensor
|
||||
from ml_lib.audio_toolset.audio_io import NormalizeLocal
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.transforms import ToTensor
|
||||
|
||||
from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
|
||||
from ml_lib._templates.new_project.utils.project_config import GlobalVar as GlobalVars
|
||||
|
||||
|
||||
class BaseOptimizerMixin:
|
||||
@ -61,10 +61,11 @@ class BaseTrainMixin:
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
keys = list(outputs[0].keys())
|
||||
|
||||
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key})
|
||||
return summary_dict
|
||||
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key}
|
||||
for key in summary_dict.keys():
|
||||
self.log(key, summary_dict[key])
|
||||
|
||||
|
||||
class BaseValMixin:
|
||||
@ -83,16 +84,16 @@ class BaseValMixin:
|
||||
|
||||
def validation_epoch_end(self, outputs, *_, **__):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
summary_dict = dict(log=dict())
|
||||
summary_dict = dict()
|
||||
# In case of Multiple given dataloader this will outputs will be: list[list[dict[]]]
|
||||
# for output_idx, output in enumerate(outputs):
|
||||
# else:list[dict[]]
|
||||
keys = list(outputs.keys())
|
||||
# Add Every Value das has a "loss" in it, by calc. mean over all occurences.
|
||||
summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key}
|
||||
)
|
||||
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key}
|
||||
)
|
||||
"""
|
||||
# Additional Score like the unweighted Average Recall:
|
||||
# UnweightedAverageRecall
|
||||
@ -107,7 +108,8 @@ class BaseValMixin:
|
||||
summary_dict['log'].update({f'uar_score': uar_score})
|
||||
"""
|
||||
|
||||
return summary_dict
|
||||
for key in summary_dict.keys():
|
||||
self.log(key, summary_dict[key])
|
||||
|
||||
|
||||
class BinaryMaskDatasetMixin:
|
||||
|
@ -1,8 +1,5 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from utils.config import Config
|
||||
|
||||
|
||||
class GlobalVar(Namespace):
|
||||
# Labels for classes
|
||||
LEFT = 1
|
||||
@ -21,10 +18,3 @@ class GlobalVar(Namespace):
|
||||
train='train',
|
||||
vali='vali',
|
||||
test='test'
|
||||
|
||||
|
||||
class ThisConfig(Config):
|
||||
|
||||
@property
|
||||
def _model_map(self):
|
||||
return dict()
|
||||
|
0
additions/__init__.py
Normal file
0
additions/__init__.py
Normal file
75
additions/losses.py
Normal file
75
additions/losses.py
Normal file
@ -0,0 +1,75 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FocalLoss(nn.modules.loss._WeightedLoss):
|
||||
def __init__(self, weight=None, gamma=2,reduction='mean'):
|
||||
super(FocalLoss, self).__init__(weight,reduction=reduction)
|
||||
self.gamma = gamma
|
||||
self.weight = weight # weight parameter will act as the alpha parameter to balance class weights
|
||||
|
||||
def forward(self, input, target):
|
||||
|
||||
ce_loss = F.cross_entropy(input, target, reduction=self.reduction, weight=self.weight)
|
||||
pt = torch.exp(-ce_loss)
|
||||
focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
|
||||
return focal_loss
|
||||
|
||||
|
||||
class FocalLossRob(nn.Module):
|
||||
# taken from https://github.com/mathiaszinnen/focal_loss_torch/blob/main/focal_loss/focal_loss.py
|
||||
def __init__(self, alpha=1, gamma=2, reduction: str = 'mean'):
|
||||
super().__init__()
|
||||
if reduction not in ['mean', 'none', 'sum']:
|
||||
raise NotImplementedError('Reduction {} not implemented.'.format(reduction))
|
||||
self.reduction = reduction
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
|
||||
def forward(self, x, target):
|
||||
x = x.clamp(1e-7, 1. - 1e-7) # own addition
|
||||
p_t = torch.where(target == 1, x, 1-x)
|
||||
fl = - 1 * (1 - p_t) ** self.gamma * torch.log(p_t)
|
||||
fl = torch.where(target == 1, fl * self.alpha, fl)
|
||||
return self._reduce(fl)
|
||||
|
||||
def _reduce(self, x):
|
||||
if self.reduction == 'mean':
|
||||
return x.mean()
|
||||
elif self.reduction == 'sum':
|
||||
return x.sum()
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class DQN_MSELoss(object):
|
||||
|
||||
def __init__(self, agent_net, target_net, gamma):
|
||||
self.agent_net = agent_net
|
||||
self.target_net = target_net
|
||||
self.gamma = gamma
|
||||
|
||||
def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
|
||||
"""
|
||||
Calculates the mse loss using a mini batch from the replay buffer
|
||||
Args:
|
||||
batch: current mini batch of replay data
|
||||
Returns:
|
||||
loss
|
||||
"""
|
||||
states, actions, rewards, dones, next_states = batch
|
||||
|
||||
actions = actions.to(torch.int64)
|
||||
state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
|
||||
|
||||
with torch.no_grad():
|
||||
next_state_values = self.target_net(next_states).max(1)[0]
|
||||
next_state_values[dones] = 0.0
|
||||
next_state_values = next_state_values.detach()
|
||||
|
||||
expected_state_action_values = next_state_values * self.gamma + rewards
|
||||
|
||||
return F.mse_loss(state_action_values, expected_state_action_values)
|
@ -1,22 +1,38 @@
|
||||
import librosa
|
||||
try:
|
||||
import librosa
|
||||
except ImportError: # pragma: no-cover
|
||||
raise ImportError('You want to use `librosa` plugins which are not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install librosa`.')
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Speed(object):
|
||||
|
||||
def __init__(self, max_ratio=0.3, speed_factor=1):
|
||||
self.speed_factor = speed_factor
|
||||
self.max_ratio = max_ratio
|
||||
def __init__(self, max_amount=0.3, speed_min=1, speed_max=1):
|
||||
self.speed_max = speed_max if speed_max else 1
|
||||
self.speed_min = speed_min if speed_min else 1
|
||||
# noinspection PyTypeChecker
|
||||
self.max_amount = min(max(0, max_amount), 1)
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, x):
|
||||
if not all([self.speed_factor, self.max_ratio]):
|
||||
if self.speed_min == 1 and self.speed_max == 1:
|
||||
return x
|
||||
start = int(np.random.randint(0, x.shape[-1],1))
|
||||
end = int((np.random.uniform(0, self.max_ratio, 1) * x.shape[-1]) + start)
|
||||
start = int(np.random.randint(low=0, high=x.shape[-1], size=1))
|
||||
width = np.random.uniform(low=0, high=self.max_amount, size=1) * x.shape[-1]
|
||||
end = int(width + start)
|
||||
end = min(end, x.shape[-1])
|
||||
try:
|
||||
speed_factor = float(np.random.uniform(min(self.speed_factor, 1), max(self.speed_factor, 1), 1))
|
||||
aug_data = librosa.effects.time_stretch(x[start:end], speed_factor)
|
||||
return np.concatenate((x[:start], aug_data, x[end:]), axis=0)[:x.shape[-1]]
|
||||
speed_factor = float(np.random.uniform(low=self.speed_min, high=self.speed_max, size=1))
|
||||
aug_data = librosa.effects.time_stretch(y=x[start:end], rate=speed_factor)
|
||||
x_aug = np.concatenate((x[:start], aug_data, x[end:]), axis=0)[:x.shape[-1]]
|
||||
if speed_factor > 1:
|
||||
embedding = np.zeros_like(x)
|
||||
embedding[:x_aug.shape[0]] = x_aug
|
||||
x_aug = embedding
|
||||
return x_aug
|
||||
except ValueError:
|
||||
return x
|
||||
|
@ -1,8 +1,16 @@
|
||||
import librosa
|
||||
from scipy.signal import butter, lfilter
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError: # pragma: no-cover
|
||||
raise ImportError('You want to use `librosa` plugins which are not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install librosa`.')
|
||||
try:
|
||||
from scipy.signal import butter, lfilter
|
||||
except ImportError: # pragma: no-cover
|
||||
raise ImportError('You want to use `scikit` plugins which are not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install scikit-learn`.')
|
||||
|
||||
|
||||
def scale_minmax(x, min_val=0.0, max_val=1.0):
|
||||
x_std = (x - x.min()) / (x.max() - x.min())
|
||||
@ -28,6 +36,9 @@ class MFCC(object):
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, y):
|
||||
mfcc = librosa.feature.mfcc(y, **self.__dict__)
|
||||
return mfcc
|
||||
@ -35,27 +46,38 @@ class MFCC(object):
|
||||
|
||||
class NormalizeLocal(object):
|
||||
def __init__(self):
|
||||
self.cache: np.ndarray
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, x: np.ndarray):
|
||||
|
||||
x[np.isnan(x)] = 0
|
||||
x[np.isinf(x)] = 0
|
||||
|
||||
mean = x.mean()
|
||||
std = x.std() + 0.0001
|
||||
|
||||
# Pytorch Version:
|
||||
# x = x.__sub__(mean).__div__(std)
|
||||
# tensor = tensor.__sub__(mean).__div__(std)
|
||||
# Numpy Version
|
||||
x = (x - mean) / std
|
||||
|
||||
x[np.isnan(x)] = 0
|
||||
x[np.isinf(x)] = 0
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class NormalizeMelband(object):
|
||||
def __init__(self):
|
||||
self.cache: np.ndarray
|
||||
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, x: np.ndarray):
|
||||
mean = x.mean(-1).unsqueeze(-1)
|
||||
std = x.std(-1).unsqueeze(-1)
|
||||
@ -66,10 +88,13 @@ class NormalizeMelband(object):
|
||||
return x
|
||||
|
||||
|
||||
class AudioToMel(object):
|
||||
def __init__(self, amplitude_to_db=False, power_to_db=False, **kwargs):
|
||||
class LibrosaAudioToMel(object):
|
||||
def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
|
||||
assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
|
||||
self.mel_kwargs = kwargs
|
||||
# Mel kwargs are:
|
||||
# sr n_mels n_fft hop_length
|
||||
|
||||
self.mel_kwargs = mel_kwargs
|
||||
self.amplitude_to_db = amplitude_to_db
|
||||
self.power_to_db = power_to_db
|
||||
|
||||
@ -89,6 +114,9 @@ class PowerToDB(object):
|
||||
def __init__(self, running_max=False):
|
||||
self.running_max = 0 if running_max else None
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, x):
|
||||
if self.running_max is not None:
|
||||
self.running_max = max(np.max(x), self.running_max)
|
||||
@ -100,6 +128,9 @@ class LowPass(object):
|
||||
def __init__(self, sr=16000):
|
||||
self.sr = sr
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, x):
|
||||
return butter_lowpass_filter(x, 1000, 1)
|
||||
|
||||
@ -108,12 +139,16 @@ class MelToImage(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, x):
|
||||
# Source to Solution: https://stackoverflow.com/a/57204349
|
||||
mels = np.log(x + 1e-9) # add small number to avoid log(0)
|
||||
|
||||
# min-max scale to fit inside 8-bit range
|
||||
img = scale_minmax(mels, 0, 255).astype(np.uint8)
|
||||
img = np.flip(img, axis=0) # put low frequencies at the bottom in image
|
||||
img = scale_minmax(mels, 0, 255)
|
||||
img = np.flip(img) # put low frequencies at the bottom in image
|
||||
img = 255 - img # invert. make black==more energy
|
||||
img = img.astype(np.float)
|
||||
return img
|
||||
|
71
audio_toolset/audio_to_mel_dataset.py
Normal file
71
audio_toolset/audio_to_mel_dataset.py
Normal file
@ -0,0 +1,71 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pickle
|
||||
from abc import ABC
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage
|
||||
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
|
||||
|
||||
|
||||
import librosa
|
||||
|
||||
|
||||
class LibrosaAudioToMelDataset(Dataset):
|
||||
|
||||
@property
|
||||
def audio_file_duration(self):
|
||||
return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
|
||||
|
||||
@property
|
||||
def sampling_rate(self):
|
||||
return self.mel_kwargs.get('sr', None)
|
||||
|
||||
def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False,
|
||||
audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs):
|
||||
super(LibrosaAudioToMelDataset, self).__init__()
|
||||
|
||||
# audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
|
||||
mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
|
||||
self.mel_kwargs = mel_kwargs
|
||||
self.reset = reset
|
||||
self.audio_path = Path(audio_file_path)
|
||||
|
||||
mel_folder_suffix = self.audio_path.parent.parent.name
|
||||
self.mel_folder = Path(str(self.audio_path)
|
||||
.replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')).parent.parent
|
||||
|
||||
self.mel_file_path = self.mel_folder / f'{self.audio_path.stem}.npy'
|
||||
|
||||
self.audio_augmentations = audio_augmentations
|
||||
|
||||
self.dataset = TorchMelDataset(self.mel_file_path, sample_segment_len, sample_hop_len, label,
|
||||
self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'],
|
||||
mel_kwargs['n_mels'], transform=mel_augmentations)
|
||||
|
||||
self._mel_transform = Compose([LibrosaAudioToMel(power_to_db=False, **mel_kwargs),
|
||||
MelToImage()
|
||||
])
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.dataset[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def build_mel(self):
|
||||
if self.reset:
|
||||
self.mel_file_path.unlink(missing_ok=True)
|
||||
if not self.mel_file_path.exists():
|
||||
self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.audio_path.open(mode='rb') as audio_file:
|
||||
raw_sample, _ = librosa.core.load(audio_file, sr=self.sampling_rate)
|
||||
mel_sample = self._mel_transform(raw_sample)
|
||||
with self.mel_file_path.open('wb') as mel_file:
|
||||
pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
else:
|
||||
pass
|
||||
|
||||
return self.mel_file_path.exists()
|
@ -1,17 +1,20 @@
|
||||
import numpy as np
|
||||
|
||||
from ml_lib.utils.transforms import _BaseTransformation
|
||||
|
||||
class NoiseInjection(object):
|
||||
|
||||
def __init__(self, noise_factor: float, sigma=0.5, mu=0.5):
|
||||
assert noise_factor >= 0, f'max_shift_ratio has to be greater then 0, but was: {noise_factor}.'
|
||||
class NoiseInjection(_BaseTransformation):
|
||||
|
||||
def __init__(self, noise_factor: float, sigma=1, mu=0):
|
||||
super(NoiseInjection, self).__init__()
|
||||
assert noise_factor >= 0, f'noise_factor has to be greater then 0, but was: {noise_factor}.'
|
||||
self.mu = mu
|
||||
self.sigma = sigma
|
||||
self.noise_factor = noise_factor
|
||||
|
||||
def __call__(self, x: np.ndarray):
|
||||
if self.noise_factor:
|
||||
noise = np.random.uniform(0, self.noise_factor, size=x.shape)
|
||||
noise = np.random.normal(self.mu, self.sigma, size=x.shape) * self.noise_factor
|
||||
augmented_data = x + x * noise
|
||||
# Cast back to same data type
|
||||
augmented_data = augmented_data.astype(x.dtype)
|
||||
@ -20,14 +23,15 @@ class NoiseInjection(object):
|
||||
return x
|
||||
|
||||
|
||||
class LoudnessManipulator(object):
|
||||
class LoudnessManipulator(_BaseTransformation):
|
||||
|
||||
def __init__(self, max_factor: float):
|
||||
super(LoudnessManipulator, self).__init__()
|
||||
assert 1 > max_factor >= 0, f'max_shift_ratio has to be between [0,1], but was: {max_factor}.'
|
||||
|
||||
self.max_factor = max_factor
|
||||
|
||||
def __call__(self, x: np.ndarray):
|
||||
def __call__(self, x):
|
||||
if self.max_factor:
|
||||
augmented_data = x + x * (np.random.random() * self.max_factor)
|
||||
# Cast back to same data type
|
||||
@ -37,11 +41,12 @@ class LoudnessManipulator(object):
|
||||
return x
|
||||
|
||||
|
||||
class ShiftTime(object):
|
||||
class ShiftTime(_BaseTransformation):
|
||||
|
||||
valid_shifts = ['right', 'left', 'any']
|
||||
|
||||
def __init__(self, max_shift_ratio: float, shift_direction: str = 'any'):
|
||||
super(ShiftTime, self).__init__()
|
||||
assert 1 > max_shift_ratio >= 0, f'max_shift_ratio has to be between [0,1], but was: {max_shift_ratio}.'
|
||||
assert shift_direction.lower() in self.valid_shifts, f'shift_direction has to be one of: {self.valid_shifts}'
|
||||
self.max_shift_ratio = max_shift_ratio
|
||||
@ -53,26 +58,27 @@ class ShiftTime(object):
|
||||
if self.shift_direction == 'right':
|
||||
shift = -1 * shift
|
||||
elif self.shift_direction == 'any':
|
||||
direction = np.random.choice([1, -1], 1)
|
||||
direction = np.asscalar(np.random.choice([1, -1], 1))
|
||||
shift = direction * shift
|
||||
augmented_data = np.roll(x, shift)
|
||||
# Set to silence for heading/ tailing
|
||||
shift = int(shift)
|
||||
if shift > 0:
|
||||
augmented_data[:shift] = 0
|
||||
augmented_data[:, :shift] = 0
|
||||
else:
|
||||
augmented_data[shift:] = 0
|
||||
augmented_data[:, shift:] = 0
|
||||
return augmented_data
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class MaskAug(object):
|
||||
class MaskAug(_BaseTransformation):
|
||||
|
||||
w_idx = -1
|
||||
h_idx = -2
|
||||
|
||||
def __init__(self, duration_ratio_max=0.3, mask_with_noise=True):
|
||||
super(MaskAug, self).__init__()
|
||||
assertion = f'"duration_ratio" has to be within [0..1], but was: {duration_ratio_max}'
|
||||
if isinstance(duration_ratio_max, (tuple, list)):
|
||||
assert all([0 < max_val < 1 for max_val in duration_ratio_max]), assertion
|
||||
@ -87,9 +93,9 @@ class MaskAug(object):
|
||||
def __call__(self, x):
|
||||
for dim in (self.w_idx, self.h_idx):
|
||||
if self.duration_ratio_max[dim]:
|
||||
start = int(np.random.choice(x.shape[dim], 1))
|
||||
v_max = x.shape[dim] * self.duration_ratio_max[dim]
|
||||
size = int(np.random.randint(0, v_max, 1))
|
||||
start = np.asscalar(np.random.choice(x.shape[dim], 1))
|
||||
v_max = int(x.shape[dim] * self.duration_ratio_max[dim])
|
||||
size = np.asscalar(np.random.randint(0, v_max, 1))
|
||||
end = int(min(start + size, x.shape[dim]))
|
||||
size = end - start
|
||||
if dim == self.w_idx:
|
||||
|
54
audio_toolset/mel_dataset.py
Normal file
54
audio_toolset/mel_dataset.py
Normal file
@ -0,0 +1,54 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pickle
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from ml_lib.modules.util import AutoPadToShape
|
||||
|
||||
|
||||
class TorchMelDataset(Dataset):
|
||||
def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len,
|
||||
sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
|
||||
super(TorchMelDataset, self).__init__()
|
||||
self.sampling_rate = int(sampling_rate)
|
||||
self.audio_file_len = float(audio_file_len)
|
||||
if auto_pad_to_shape and sub_segment_len:
|
||||
self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
|
||||
else:
|
||||
self.padding = None
|
||||
self.path = Path(mel_path)
|
||||
self.sub_segment_len = int(sub_segment_len)
|
||||
self.mel_hop_len = int(mel_hop_len)
|
||||
self.sub_segment_hop_len = int(sub_segment_hop_len)
|
||||
self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)
|
||||
if self.sub_segment_len and self.sub_segment_hop_len and (self.n - self.sub_segment_len) > 0:
|
||||
self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
|
||||
else:
|
||||
self.offsets = [0]
|
||||
if len(self) == 0:
|
||||
print('what happend here')
|
||||
self.label = label
|
||||
self.transform = transform
|
||||
|
||||
def __getitem__(self, item):
|
||||
with self.path.open('rb') as mel_file:
|
||||
mel_spec = pickle.load(mel_file, fix_imports=True)
|
||||
start = self.offsets[item]
|
||||
sub_segments_attributes_set = self.sub_segment_len and self.sub_segment_hop_len
|
||||
sub_segment_length_smaller_then_tot_length = self.sub_segment_len < mel_spec.shape[1]
|
||||
|
||||
if sub_segments_attributes_set and sub_segment_length_smaller_then_tot_length:
|
||||
duration = self.sub_segment_len
|
||||
else:
|
||||
duration = mel_spec.shape[1]
|
||||
|
||||
snippet = mel_spec[:, start: start + duration]
|
||||
if self.transform:
|
||||
snippet = self.transform(snippet)
|
||||
if self.padding:
|
||||
snippet = self.padding(snippet)
|
||||
return self.path.__str__(), snippet, self.label
|
||||
|
||||
def __len__(self):
|
||||
return len(self.offsets)
|
27
audio_toolset/mel_transforms.py
Normal file
27
audio_toolset/mel_transforms.py
Normal file
@ -0,0 +1,27 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Normalize(object):
|
||||
|
||||
def __init__(self, min_db_level: Union[int, float]):
|
||||
self.min_db_level = min_db_level
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, s: np.ndarray) -> np.ndarray:
|
||||
return np.clip((s - self.min_db_level) / -self.min_db_level, 0, 1)
|
||||
|
||||
|
||||
class DeNormalize(object):
|
||||
|
||||
def __init__(self, min_db_level: Union[int, float]):
|
||||
self.min_db_level = min_db_level
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, s: np.ndarray) -> np.ndarray:
|
||||
return (np.clip(s, 0, 1) * -self.min_db_level) + self.min_db_level
|
@ -1,13 +1,21 @@
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError: # pragma: no-cover
|
||||
raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install matplotlib`.')
|
||||
try:
|
||||
from sklearn.metrics import roc_curve, auc, recall_score
|
||||
except ImportError: # pragma: no-cover
|
||||
raise ImportError('You want to use `sklearn` plugins which are not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install scikit-learn`.')
|
||||
|
||||
|
||||
class ROCEvaluation(object):
|
||||
|
||||
linewidth = 2
|
||||
|
||||
def __init__(self, plot_roc=False):
|
||||
self.plot_roc = plot_roc
|
||||
def __init__(self, plot=False):
|
||||
self.plot = plot
|
||||
self.epoch = 0
|
||||
|
||||
def __call__(self, prediction, label):
|
||||
@ -15,7 +23,7 @@ class ROCEvaluation(object):
|
||||
# Compute ROC curve and ROC area
|
||||
fpr, tpr, _ = roc_curve(prediction, label)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
if self.plot_roc:
|
||||
if self.plot:
|
||||
_ = plt.gcf()
|
||||
plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
|
||||
self._prepare_fig()
|
||||
@ -32,3 +40,32 @@ class ROCEvaluation(object):
|
||||
fig.legend(loc="lower right")
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class UAREvaluation(object):
|
||||
|
||||
def __init__(self, labels: list, plot=False):
|
||||
self.labels = labels
|
||||
self.plot_roc = plot
|
||||
self.epoch = 0
|
||||
|
||||
def __call__(self, prediction, label):
|
||||
|
||||
# Compute uar score - UnweightedAverageRecal
|
||||
|
||||
uar_score = recall_score(label, prediction, labels=self.labels, average='macro',
|
||||
sample_weight=None, zero_division='warn')
|
||||
return uar_score
|
||||
|
||||
def _prepare_fig(self):
|
||||
raise NotImplementedError # TODO Implement a nice visualization
|
||||
fig = plt.gcf()
|
||||
ax = plt.gca()
|
||||
plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
|
||||
plt.xlim([0.0, 1.0])
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.xlabel('False Positive Rate')
|
||||
plt.ylabel('True Positive Rate')
|
||||
fig.legend(loc="lower right")
|
||||
|
||||
return fig
|
0
metrics/__init__.py
Normal file
0
metrics/__init__.py
Normal file
13
metrics/_base_score.py
Normal file
13
metrics/_base_score.py
Normal file
@ -0,0 +1,13 @@
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class _BaseScores(ABC):
|
||||
|
||||
def __init__(self, lightning_model):
|
||||
self.model = lightning_model
|
||||
pass
|
||||
|
||||
def __call__(self, outputs):
|
||||
# summary_dict = dict()
|
||||
# return summary_dict
|
||||
raise NotImplementedError
|
47
metrics/attention_rollout.py
Normal file
47
metrics/attention_rollout.py
Normal file
@ -0,0 +1,47 @@
|
||||
import numpy as np
|
||||
|
||||
from einops import reduce
|
||||
|
||||
|
||||
import torch
|
||||
from sklearn.ensemble import IsolationForest
|
||||
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
|
||||
|
||||
from ml_lib.metrics._base_score import _BaseScores
|
||||
|
||||
|
||||
class AttentionRollout(_BaseScores):
|
||||
|
||||
def __init__(self, *args):
|
||||
super(AttentionRollout, self).__init__(*args)
|
||||
pass
|
||||
|
||||
def __call__(self, outputs):
|
||||
summary_dict = dict()
|
||||
#######################################################################################
|
||||
# Additional Score - Histogram Distances - Image Plotting
|
||||
#######################################################################################
|
||||
#
|
||||
# INIT
|
||||
attn_weights = [output['attn_weights'].cpu().numpy() for output in outputs]
|
||||
attn_reduce_heads = [reduce(x, '') for x in attn_weights]
|
||||
|
||||
if self.model.params.use_residual:
|
||||
residual_att = np.eye(att_mat.shape[1])[None, ...]
|
||||
aug_att_mat = att_mat + residual_att
|
||||
aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
|
||||
else:
|
||||
aug_att_mat = att_mat
|
||||
|
||||
joint_attentions = np.zeros(aug_att_mat.shape)
|
||||
|
||||
layers = joint_attentions.shape[0]
|
||||
joint_attentions[0] = aug_att_mat[0]
|
||||
for i in np.arange(1, layers):
|
||||
joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
68
metrics/binary_class_classifictaion.py
Normal file
68
metrics/binary_class_classifictaion.py
Normal file
@ -0,0 +1,68 @@
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from sklearn.ensemble import IsolationForest
|
||||
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
|
||||
|
||||
from ml_lib.metrics._base_score import _BaseScores
|
||||
from ml_lib.utils.tools import to_one_hot
|
||||
|
||||
|
||||
class BinaryScores(_BaseScores):
|
||||
|
||||
def __init__(self, *args):
|
||||
super(BinaryScores, self).__init__(*args)
|
||||
|
||||
def __call__(self, outputs):
|
||||
summary_dict = dict()
|
||||
|
||||
# Additional Score like the unweighted Average Recall:
|
||||
#########################
|
||||
# INIT
|
||||
if isinstance(outputs['batch_y'], torch.Tensor):
|
||||
y_true = outputs['batch_y'].cpu().numpy()
|
||||
else:
|
||||
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||
|
||||
if isinstance(outputs['y'], torch.Tensor):
|
||||
y_pred = outputs['y'].cpu().numpy()
|
||||
else:
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
|
||||
|
||||
# UnweightedAverageRecall
|
||||
# y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||
# y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy()
|
||||
|
||||
# How to apply a threshold manualy
|
||||
# y_pred = (y_pred >= 0.5).astype(np.float32)
|
||||
|
||||
# How to apply a threshold by IF (Isolation Forest)
|
||||
clf = IsolationForest()
|
||||
y_score = clf.fit_predict(y_pred.reshape(-1, 1))
|
||||
y_score = (np.asarray(y_score) == -1).astype(np.float32)
|
||||
|
||||
uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro',
|
||||
sample_weight=None, zero_division='warn')
|
||||
summary_dict.update(dict(uar_score=uar_score))
|
||||
#########################
|
||||
# Precission
|
||||
precision_score = average_precision_score(y_true, y_score)
|
||||
summary_dict.update(dict(precision_score=precision_score))
|
||||
|
||||
#########################
|
||||
# AUC
|
||||
try:
|
||||
auc_score = roc_auc_score(y_true=y_true, y_score=y_score)
|
||||
summary_dict.update(dict(auc_score=auc_score))
|
||||
except ValueError:
|
||||
summary_dict.update(dict(auc_score=-1))
|
||||
|
||||
#########################
|
||||
# pAUC
|
||||
try:
|
||||
pauc = roc_auc_score(y_true=y_true, y_score=y_score, max_fpr=0.15)
|
||||
summary_dict.update(dict(pauc_score=pauc))
|
||||
except ValueError:
|
||||
summary_dict.update(dict(pauc_score=-1))
|
||||
|
||||
return summary_dict
|
68
metrics/generative_task_evaluation.py
Normal file
68
metrics/generative_task_evaluation.py
Normal file
@ -0,0 +1,68 @@
|
||||
from itertools import cycle
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.metrics import roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from ml_lib.metrics._base_score import _BaseScores
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
class GenerativeTaskEval(_BaseScores):
|
||||
|
||||
def __init__(self, *args):
|
||||
super(GenerativeTaskEval, self).__init__(*args)
|
||||
pass
|
||||
|
||||
def __call__(self, outputs):
|
||||
summary_dict = dict()
|
||||
#######################################################################################
|
||||
# Additional Score - Histogram Distances - Image Plotting
|
||||
#######################################################################################
|
||||
#
|
||||
# INIT
|
||||
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
|
||||
|
||||
attn_weights = torch.cat([output['attn_weights'] for output in outputs]).squeeze().cpu().numpy()
|
||||
|
||||
######################################################################################
|
||||
#
|
||||
# Histogram comparission
|
||||
|
||||
y_true_hist = np.histogram(y_true, bins=128)[0] # Todo: Find a better value
|
||||
y_pred_hist = np.histogram(y_pred, bins=128)[0] # Todo: Find a better value
|
||||
|
||||
# L2 norm == euclidean distance
|
||||
hist_euc_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0),
|
||||
metric='euclidean')
|
||||
|
||||
# Manhattan Distance
|
||||
hist_manhattan_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0),
|
||||
metric='cityblock')
|
||||
|
||||
summary_dict.update(hist_manhattan_dist=hist_manhattan_dist, hist_euc_dist=hist_euc_dist)
|
||||
|
||||
#######################################################################################
|
||||
#
|
||||
idx = np.random.choice(np.arange(y_true.shape[0]), 1).item()
|
||||
|
||||
ax = plt.imshow(y_true[idx].squeeze())
|
||||
# Plot using a small number of colors, with unevenly spaced boundaries.
|
||||
ax2 = plt.imshow(attn_weights[idx].sq, interpolation='nearest', aspect='auto', extent=ax.get_extent())
|
||||
self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
|
||||
plt.clf()
|
||||
|
||||
|
||||
#######################################################################################
|
||||
#
|
||||
|
||||
|
||||
#######################################################################################
|
||||
#
|
||||
|
||||
plt.close('all')
|
||||
return summary_dict
|
142
metrics/multi_class_classification.py
Normal file
142
metrics/multi_class_classification.py
Normal file
@ -0,0 +1,142 @@
|
||||
from itertools import cycle
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix, \
|
||||
recall_score
|
||||
|
||||
from ml_lib.metrics._base_score import _BaseScores
|
||||
from ml_lib.utils.tools import to_one_hot
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
class MultiClassScores(_BaseScores):
|
||||
|
||||
def __init__(self, *args):
|
||||
super(MultiClassScores, self).__init__(*args)
|
||||
pass
|
||||
|
||||
def __call__(self, outputs, class_names=None):
|
||||
summary_dict = dict()
|
||||
class_names = class_names or range(self.model.params.n_classes)
|
||||
#######################################################################################
|
||||
# Additional Score - UAR - ROC - Conf. Matrix - F1
|
||||
#######################################################################################
|
||||
#
|
||||
# INIT
|
||||
if isinstance(outputs['batch_y'], torch.Tensor):
|
||||
y_true = outputs['batch_y'].cpu().numpy()
|
||||
else:
|
||||
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||
y_true_one_hot = to_one_hot(y_true, self.model.params.n_classes)
|
||||
|
||||
if isinstance(outputs['y'], torch.Tensor):
|
||||
y_pred = outputs['y'].cpu().numpy()
|
||||
else:
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
|
||||
y_pred_max = np.argmax(y_pred, axis=1)
|
||||
|
||||
class_names = {val: key for val, key in enumerate(class_names)}
|
||||
######################################################################################
|
||||
#
|
||||
# F1 SCORE
|
||||
micro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='micro', sample_weight=None,
|
||||
zero_division=True)
|
||||
macro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='macro', sample_weight=None,
|
||||
zero_division=True)
|
||||
summary_dict.update(dict(micro_f1_score=micro_f1_score, macro_f1_score=macro_f1_score))
|
||||
######################################################################################
|
||||
#
|
||||
# Unweichted Average Recall
|
||||
uar = recall_score(y_true, y_pred_max, labels=[0, 1, 2, 3, 4], average='macro',
|
||||
sample_weight=None, zero_division='warn')
|
||||
summary_dict.update(dict(uar_score=uar))
|
||||
#######################################################################################
|
||||
#
|
||||
# ROC Curve
|
||||
|
||||
# Compute ROC curve and ROC area for each class
|
||||
fpr = dict()
|
||||
tpr = dict()
|
||||
roc_auc = dict()
|
||||
for i in range(self.model.params.n_classes):
|
||||
fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
|
||||
roc_auc[i] = auc(fpr[i], tpr[i])
|
||||
|
||||
# Compute micro-average ROC curve and ROC area
|
||||
fpr["micro"], tpr["micro"], _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel())
|
||||
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
|
||||
|
||||
# First aggregate all false positive rates
|
||||
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.model.params.n_classes)]))
|
||||
|
||||
# Then interpolate all ROC curves at this points
|
||||
mean_tpr = np.zeros_like(all_fpr)
|
||||
for i in range(self.model.params.n_classes):
|
||||
mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
|
||||
|
||||
# Finally average it and compute AUC
|
||||
mean_tpr /= self.model.params.n_classes
|
||||
|
||||
fpr["macro"] = all_fpr
|
||||
tpr["macro"] = mean_tpr
|
||||
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
|
||||
|
||||
# Plot all ROC curves
|
||||
plt.figure()
|
||||
plt.plot(fpr["micro"], tpr["micro"],
|
||||
label=f'micro ROC ({round(roc_auc["micro"], 2)})',
|
||||
color='deeppink', linestyle=':', linewidth=4)
|
||||
|
||||
plt.plot(fpr["macro"], tpr["macro"],
|
||||
label=f'macro ROC({round(roc_auc["macro"], 2)})',
|
||||
color='navy', linestyle=':', linewidth=4)
|
||||
|
||||
colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
|
||||
'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], )
|
||||
|
||||
for i, color in zip(range(self.model.params.n_classes), colors):
|
||||
plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})')
|
||||
|
||||
plt.plot([0, 1], [0, 1], 'k--', lw=2)
|
||||
plt.xlim([0.0, 1.0])
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.xlabel('False Positive Rate')
|
||||
plt.ylabel('True Positive Rate')
|
||||
plt.legend(loc="lower right")
|
||||
|
||||
self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
|
||||
# self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch, ext='pdf')
|
||||
plt.clf()
|
||||
|
||||
#######################################################################################
|
||||
#
|
||||
# ROC AUC SCORE
|
||||
|
||||
try:
|
||||
macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
|
||||
average="macro")
|
||||
summary_dict.update(macro_roc_auc_ovr=macro_roc_auc_ovr)
|
||||
except ValueError:
|
||||
micro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
|
||||
average="micro")
|
||||
summary_dict.update(micro_roc_auc_ovr=micro_roc_auc_ovr)
|
||||
|
||||
#######################################################################################
|
||||
#
|
||||
# Confusion matrix
|
||||
fig1, ax1 = plt.subplots(dpi=96)
|
||||
cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
|
||||
labels=[class_names[key] for key in class_names.keys()],
|
||||
normalize='true')
|
||||
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
|
||||
display_labels=[class_names[i] for i in range(self.model.params.n_classes)]
|
||||
)
|
||||
disp.plot(include_values=True, ax=ax1)
|
||||
|
||||
self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch)
|
||||
# self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')
|
||||
|
||||
plt.close('all')
|
||||
return summary_dict
|
BIN
modules/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
modules/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
modules/__pycache__/geometric_blocks.cpython-37.pyc
Normal file
BIN
modules/__pycache__/geometric_blocks.cpython-37.pyc
Normal file
Binary file not shown.
BIN
modules/__pycache__/util.cpython-37.pyc
Normal file
BIN
modules/__pycache__/util.cpython-37.pyc
Normal file
Binary file not shown.
@ -1,11 +1,17 @@
|
||||
import warnings
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from modules.utils import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
|
||||
import sys
|
||||
sys.path.append(str(Path(__file__).parent))
|
||||
|
||||
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
@ -15,23 +21,24 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
###################
|
||||
class LinearModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, out_features, bias=True, activation=None,
|
||||
norm=False, dropout: Union[int, float] = 0, **kwargs):
|
||||
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
||||
def __init__(self, in_shape, out_features, use_bias=True, activation=None,
|
||||
use_norm=False, dropout: Union[int, float] = 0, **kwargs):
|
||||
if list(kwargs.keys()):
|
||||
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
||||
super(LinearModule, self).__init__()
|
||||
|
||||
self.in_shape = in_shape
|
||||
self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
|
||||
self.dropout = nn.Dropout(dropout) if dropout else F_x(self.flat.shape)
|
||||
self.norm = nn.BatchNorm1d(self.flat.shape) if norm else F_x(self.flat.shape)
|
||||
self.linear = nn.Linear(self.flat.shape, out_features, bias=bias)
|
||||
self.norm = nn.LayerNorm(self.flat.shape) if use_norm else F_x(self.flat.shape)
|
||||
self.linear = nn.Linear(self.flat.shape, out_features, bias=use_bias)
|
||||
self.activation = activation() if activation else F_x(self.linear.out_features)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.flat(x)
|
||||
tensor = self.dropout(tensor)
|
||||
tensor = self.norm(tensor)
|
||||
tensor = self.linear(tensor)
|
||||
tensor = self.linear(tensor.float())
|
||||
tensor = self.activation(tensor)
|
||||
return tensor
|
||||
|
||||
@ -39,14 +46,22 @@ class LinearModule(ShapeMixin, nn.Module):
|
||||
class ConvModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
|
||||
bias=True, norm=False, dropout: Union[int, float] = 0,
|
||||
bias=True, use_norm=False, dropout: Union[int, float] = 0, trainable: bool = True,
|
||||
conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
|
||||
super(ConvModule, self).__init__()
|
||||
assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
|
||||
assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
|
||||
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
||||
if len(kwargs.keys()):
|
||||
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
||||
if use_norm and not trainable:
|
||||
warnings.warn('You set this module to be not trainable but the running norm is active.\n' +
|
||||
'We set it to "eval" mode.\n' +
|
||||
'Keep this in mind if you do a finetunning or retraining step.'
|
||||
)
|
||||
|
||||
# Module Parameters
|
||||
self.in_shape = in_shape
|
||||
self.trainable = trainable
|
||||
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
|
||||
|
||||
# Convolution Parameters
|
||||
@ -56,13 +71,19 @@ class ConvModule(ShapeMixin, nn.Module):
|
||||
self.conv_kernel = conv_kernel
|
||||
|
||||
# Modules
|
||||
self.activation = activation() or F_x(None)
|
||||
self.activation = activation() or nn.Identity()
|
||||
self.norm = nn.LayerNorm(self.in_shape, eps=1e-04) if use_norm else F_x(None)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout else F_x(None)
|
||||
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else F_x(None)
|
||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(None)
|
||||
self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
|
||||
padding=self.padding, stride=self.stride
|
||||
)
|
||||
if not self.trainable:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
self.norm = self.norm.eval()
|
||||
else:
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.norm(x)
|
||||
@ -73,13 +94,49 @@ class ConvModule(ShapeMixin, nn.Module):
|
||||
return tensor
|
||||
|
||||
|
||||
class PreInitializedConvModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, weight_matrix):
|
||||
super(PreInitializedConvModule, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
self.weight_matrix = weight_matrix
|
||||
raise NotImplementedError
|
||||
# ToDo Get the weight_matrix shape and init a conv_module of similar size,
|
||||
# override the weights then.
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.matmul(x, self.weight_matrix) # ToDo: This is an Placeholder
|
||||
return x
|
||||
|
||||
|
||||
class SobelFilter(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape):
|
||||
super(SobelFilter, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
|
||||
self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
|
||||
self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)
|
||||
|
||||
def forward(self, x):
|
||||
# Apply Filters
|
||||
g_x = F.conv2d(x, self.sobel_x)
|
||||
g_y = F.conv2d(x, self.sobel_y)
|
||||
# Calculate the Edge
|
||||
g = torch.add(*[torch.pow(tensor, 2) for tensor in [g_x, g_y]])
|
||||
# Calculate the Gradient
|
||||
g_grad = torch.atan2(g_x, g_y)
|
||||
return g_x, g_y, g, g_grad
|
||||
|
||||
|
||||
class DeConvModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
|
||||
dropout: Union[int, float] = 0, autopad=0,
|
||||
activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
|
||||
bias=True, norm=False):
|
||||
bias=True, use_norm=False, **kwargs):
|
||||
super(DeConvModule, self).__init__()
|
||||
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
|
||||
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
|
||||
self.padding = conv_padding
|
||||
self.conv_kernel = conv_kernel
|
||||
@ -89,8 +146,8 @@ class DeConvModule(ShapeMixin, nn.Module):
|
||||
|
||||
self.autopad = AutoPad() if autopad else lambda x: x
|
||||
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
|
||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
|
||||
self.norm = nn.LayerNorm(in_channels, eps=1e-04) if use_norm else F_x(self.in_shape)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
|
||||
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
|
||||
padding=self.padding, stride=self.stride)
|
||||
|
||||
@ -109,14 +166,13 @@ class DeConvModule(ShapeMixin, nn.Module):
|
||||
|
||||
class ResidualModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, module_class, n, norm=False, **module_parameters):
|
||||
def __init__(self, in_shape, module_class, n, use_norm=False, **module_parameters):
|
||||
assert n >= 1
|
||||
super(ResidualModule, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
module_parameters.update(in_shape=in_shape)
|
||||
if norm:
|
||||
self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
|
||||
self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
|
||||
if use_norm:
|
||||
self.norm = nn.LayerNorm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
|
||||
else:
|
||||
self.norm = F_x(self.in_shape)
|
||||
self.activation = module_parameters.get('activation', None)
|
||||
@ -128,8 +184,9 @@ class ResidualModule(ShapeMixin, nn.Module):
|
||||
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.norm(x)
|
||||
for module in self.residual_block:
|
||||
tensor = module(x)
|
||||
tensor = module(tensor)
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
tensor = tensor + x
|
||||
@ -155,3 +212,106 @@ class RecurrentModule(ShapeMixin, nn.Module):
|
||||
def forward(self, x):
|
||||
tensor = self.rnn(x)
|
||||
return tensor
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0., activation=nn.GELU):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
activation() or F_x(None),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
activation() or F_x(None),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, head_dim=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = head_dim * heads
|
||||
project_out = not (heads == 1 and head_dim == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
from einops import rearrange, repeat
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
|
||||
|
||||
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
|
||||
|
||||
if mask is not None:
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
mask = F.pad(mask.flatten(1), (1, 0), value=True)
|
||||
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
||||
mask = mask[:, None, :] * mask[:, :, None]
|
||||
mask = repeat(mask, 'b n d -> b h n d', h=h) # My addition
|
||||
dots.masked_fill_(~mask, mask_value)
|
||||
# dots.masked_fill_(mask, mask_value) # My addition
|
||||
del mask
|
||||
|
||||
attn = dots.softmax(dim=-1)
|
||||
|
||||
out = torch.einsum('bhij,bhjd->bhid', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
if return_attn_weights:
|
||||
return out, attn
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
class TransformerModule(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, depth, heads, mlp_dim, head_dim=32, dropout=None, use_norm=False,
|
||||
activation=nn.GELU, use_residual=True):
|
||||
super(TransformerModule, self).__init__()
|
||||
|
||||
self.in_shape = in_shape
|
||||
self.use_residual = use_residual
|
||||
|
||||
self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
|
||||
|
||||
self.embedding_dim = self.flat.flat_shape
|
||||
self.norm = nn.LayerNorm(self.embedding_dim) if use_norm else F_x(None)
|
||||
self.attns = nn.ModuleList([Attention(self.embedding_dim, heads=heads, dropout=dropout, head_dim=head_dim)
|
||||
for _ in range(depth)])
|
||||
self.mlps = nn.ModuleList([FeedForward(self.embedding_dim, mlp_dim, dropout=dropout, activation=activation)
|
||||
for _ in range(depth)])
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False, **_):
|
||||
tensor = self.flat(x)
|
||||
attn_weights = list()
|
||||
|
||||
for attn, mlp in zip(self.attns, self.mlps):
|
||||
# Attention
|
||||
attn_tensor = self.norm(tensor)
|
||||
if return_attn_weights:
|
||||
attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights)
|
||||
attn_weights.append(attn_weight)
|
||||
else:
|
||||
attn_tensor = attn(attn_tensor, mask=mask)
|
||||
tensor = tensor + attn_tensor if self.use_residual else attn_tensor
|
||||
|
||||
# MLP
|
||||
mlp_tensor = self.norm(tensor)
|
||||
mlp_tensor = mlp(mlp_tensor)
|
||||
tensor = tensor + mlp_tensor if self.use_residual else mlp_tensor
|
||||
|
||||
return (tensor, attn_weights) if return_attn_weights else tensor
|
||||
|
65
modules/geometric_blocks.py
Normal file
65
modules/geometric_blocks.py
Normal file
@ -0,0 +1,65 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import ReLU
|
||||
try:
|
||||
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
|
||||
except ImportError:
|
||||
print('Install torch-geometric to use this package.')
|
||||
|
||||
|
||||
class SAModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, ratio, r, nn):
|
||||
super(SAModule, self).__init__()
|
||||
self.ratio = ratio
|
||||
self.r = r
|
||||
self.conv = PointConv(nn)
|
||||
|
||||
def forward(self, x, pos, batch):
|
||||
idx = fps(pos, batch, ratio=self.ratio)
|
||||
row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
|
||||
max_num_neighbors=64)
|
||||
edge_index = torch.stack([col, row], dim=0)
|
||||
x = self.conv(x, (pos, pos[idx]), edge_index)
|
||||
pos, batch = pos[idx], batch[idx]
|
||||
return x, pos, batch
|
||||
|
||||
|
||||
class GlobalSAModule(nn.Module):
|
||||
def __init__(self, nn, channels=3):
|
||||
super(GlobalSAModule, self).__init__()
|
||||
self.nn = nn
|
||||
self.channels = channels
|
||||
|
||||
def forward(self, x, pos, batch):
|
||||
x = self.nn(torch.cat([x, pos], dim=1))
|
||||
x = global_max_pool(x, batch)
|
||||
pos = pos.new_zeros((x.size(0), self.channels))
|
||||
batch = torch.arange(x.size(0), device=batch.device)
|
||||
return x, pos, batch
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, channels, norm=True):
|
||||
super(MLP, self).__init__()
|
||||
self.net = nn.Sequential(*[
|
||||
nn.Sequential(nn.Linear(channels[i - 1], channels[i]), ReLU(), nn.BatchNorm1d(channels[i]))
|
||||
for i in range(1, len(channels))
|
||||
]).double()
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class FPModule(torch.nn.Module):
|
||||
def __init__(self, k, nn):
|
||||
super(FPModule, self).__init__()
|
||||
self.k = k
|
||||
self.nn = nn
|
||||
|
||||
def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
|
||||
x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
|
||||
if x_skip is not None:
|
||||
x = torch.cat([x, x_skip], dim=1)
|
||||
x = self.nn(x)
|
||||
return x, pos_skip, batch_skip
|
@ -1,80 +1,144 @@
|
||||
#
|
||||
# Full Model Parts
|
||||
###################
|
||||
from argparse import Namespace
|
||||
from functools import reduce
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
import torch
|
||||
from abc import ABC
|
||||
from operator import mul
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from modules.utils import ShapeMixin
|
||||
from .blocks import ConvModule, DeConvModule, LinearModule
|
||||
|
||||
from .util import ShapeMixin, LightningBaseModule, Flatten
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
@property
|
||||
def shape(self):
|
||||
x = torch.randn(self.lat_dim).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
class AEBaseModule(LightningBaseModule, ABC):
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
|
||||
filters: List[int] = None, activation=nn.ReLU):
|
||||
def generate_random_image(self, dataloader: Union[None, str, DataLoader] = None,
|
||||
lat_min: Union[Tuple, List, None] = None,
|
||||
lat_max: Union[Tuple, List, None] = None):
|
||||
|
||||
assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
|
||||
|
||||
min_max = self._find_min_max(dataloader) if dataloader else [None, None]
|
||||
# assert not any([tensor is None for tensor in min_max])
|
||||
lat_min = torch.as_tensor(lat_min or min_max[0])
|
||||
lat_max = lat_max or min_max[1]
|
||||
|
||||
random_z = torch.rand((1, self.lat_dim))
|
||||
random_z = random_z * (abs(lat_min) + lat_max) - abs(lat_min)
|
||||
|
||||
return self.decoder(random_z).squeeze()
|
||||
|
||||
def encode(self, x):
|
||||
if len(x.shape) == 3:
|
||||
x = x.unsqueeze(0)
|
||||
return self.encoder(x)
|
||||
|
||||
def _find_min_max(self, dataloader):
|
||||
encodings = list()
|
||||
for batch in dataloader:
|
||||
encodings.append(self.encode(batch))
|
||||
encodings = torch.cat(encodings, dim=0)
|
||||
min_lat = encodings.min(dim=1)
|
||||
max_lat = encodings.max(dim=1)
|
||||
return min_lat, max_lat
|
||||
|
||||
def decode_lat_evenly(self, n: int,
|
||||
dataloader: Union[None, str, DataLoader] = None,
|
||||
lat_min: Union[Tuple, List, None] = None,
|
||||
lat_max: Union[Tuple, List, None] = None):
|
||||
assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
|
||||
|
||||
min_max = self._find_min_max(dataloader) if dataloader else [None, None]
|
||||
|
||||
lat_min = lat_min or min_max[0]
|
||||
lat_max = lat_max or min_max[1]
|
||||
|
||||
random_latent_samples = torch.stack([torch.linspace(lat_min[i].item(), lat_max[i].item(), n)
|
||||
for i in range(self.params.lat_dim)], dim=-1).cpu().detach()
|
||||
return self.decode(random_latent_samples).cpu().detach()
|
||||
|
||||
def decode(self, z):
|
||||
try:
|
||||
if len(z.shape) == 1:
|
||||
z = z.unsqueeze(0)
|
||||
except AttributeError:
|
||||
# Does not seem to be a tensor.
|
||||
pass
|
||||
return self.decoder(z).squeeze()
|
||||
|
||||
def encode_and_restore(self, x):
|
||||
x = self.transfer_batch_to_device(x, self.device)
|
||||
if len(x.shape) == 3:
|
||||
x = x.unsqueeze(0)
|
||||
z = self.encode(x)
|
||||
try:
|
||||
z = z.squeeze()
|
||||
except AttributeError:
|
||||
# Does not seem to be a tensor.
|
||||
pass
|
||||
x_hat = self.decode(z)
|
||||
|
||||
return Namespace(main_out=x_hat.squeeze(), latent_out=z)
|
||||
|
||||
|
||||
class Generator(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, out_channels, re_shape, use_norm=False, use_bias=True,
|
||||
dropout: Union[int, float] = 0, interpolations: List[int] = None,
|
||||
filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU,
|
||||
**kwargs):
|
||||
super(Generator, self).__init__()
|
||||
assert filters, '"Filters" has to be a list of int len 3'
|
||||
self.filters = filters
|
||||
self.activation = activation
|
||||
self.inner_activation = activation()
|
||||
assert filters, '"Filters" has to be a list of int.'
|
||||
assert filters, '"Filters" has to be a list of int.'
|
||||
kernels = kernels if kernels else [3] * len(filters)
|
||||
assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
|
||||
|
||||
interpolations = interpolations or [2, 2, 2]
|
||||
|
||||
self.in_shape = in_shape
|
||||
self.activation = activation()
|
||||
self.out_activation = None
|
||||
self.lat_dim = lat_dim
|
||||
self.dropout = dropout
|
||||
self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
|
||||
self.l1 = LinearModule(in_shape, reduce(mul, re_shape), bias=use_bias, activation=activation)
|
||||
# re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
|
||||
|
||||
self.flat = Flatten(to=re_shape)
|
||||
self.flat = Flatten(self.l1.shape, to=re_shape)
|
||||
self.de_conv_list = nn.ModuleList()
|
||||
|
||||
self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
|
||||
conv_kernel=5,
|
||||
conv_padding=2,
|
||||
conv_stride=1,
|
||||
normalize=use_norm,
|
||||
activation=self.activation,
|
||||
interpolation_scale=2,
|
||||
dropout=self.dropout
|
||||
)
|
||||
last_shape = re_shape
|
||||
for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
|
||||
# noinspection PyTypeChecker
|
||||
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
|
||||
conv_kernel=conv_kernel,
|
||||
conv_padding=conv_kernel-2,
|
||||
conv_stride=1,
|
||||
normalize=use_norm,
|
||||
activation=activation,
|
||||
interpolation_scale=interpolation,
|
||||
dropout=self.dropout
|
||||
)
|
||||
)
|
||||
last_shape = self.de_conv_list[-1].shape
|
||||
|
||||
self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
conv_stride=1,
|
||||
normalize=use_norm,
|
||||
activation=self.activation,
|
||||
interpolation_scale=2,
|
||||
dropout=self.dropout
|
||||
)
|
||||
|
||||
self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
conv_stride=1,
|
||||
normalize=use_norm,
|
||||
activation=self.activation,
|
||||
interpolation_scale=2,
|
||||
dropout=self.dropout
|
||||
)
|
||||
|
||||
self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
# normalize=norm,
|
||||
activation=self.out_activation
|
||||
)
|
||||
self.de_conv_out = DeConvModule(self.de_conv_list[-1].shape, conv_filters=out_channels, conv_kernel=3,
|
||||
conv_padding=1, activation=self.out_activation
|
||||
)
|
||||
|
||||
def forward(self, z):
|
||||
tensor = self.l1(z)
|
||||
tensor = self.inner_activation(tensor)
|
||||
tensor = self.activation(tensor)
|
||||
tensor = self.flat(tensor)
|
||||
tensor = self.deconv1(tensor)
|
||||
tensor = self.deconv2(tensor)
|
||||
tensor = self.deconv3(tensor)
|
||||
tensor = self.deconv4(tensor)
|
||||
|
||||
for de_conv in self.de_conv_list:
|
||||
tensor = de_conv(tensor)
|
||||
|
||||
tensor = self.de_conv_out(tensor)
|
||||
return tensor
|
||||
|
||||
def size(self):
|
||||
@ -114,18 +178,17 @@ class UnitGenerator(Generator):
|
||||
return tensor
|
||||
|
||||
|
||||
class BaseEncoder(ShapeMixin, nn.Module):
|
||||
class BaseCNNEncoder(ShapeMixin, nn.Module):
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
|
||||
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
|
||||
filters: List[int] = None):
|
||||
super(BaseEncoder, self).__init__()
|
||||
assert filters, '"Filters" has to be a list of int len 3'
|
||||
|
||||
# Optional Padding for odd image-sizes
|
||||
# Obsolet, already Done by autopadding module on incoming tensors
|
||||
# in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
|
||||
filters: List[int] = None, kernels: Union[List[int], int, None] = None, **kwargs):
|
||||
super(BaseCNNEncoder, self).__init__()
|
||||
assert filters, '"Filters" has to be a list of int'
|
||||
kernels = kernels or [3] * len(filters)
|
||||
kernels = kernels if not isinstance(kernels, int) else [kernels] * len(filters)
|
||||
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
|
||||
|
||||
# Parameters
|
||||
self.lat_dim = lat_dim
|
||||
@ -133,52 +196,39 @@ class BaseEncoder(ShapeMixin, nn.Module):
|
||||
self.use_bias = use_bias
|
||||
self.latent_activation = latent_activation() if latent_activation else None
|
||||
|
||||
self.conv_list = nn.ModuleList()
|
||||
|
||||
# Modules
|
||||
self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
conv_stride=1,
|
||||
pooling_size=2,
|
||||
use_norm=use_norm,
|
||||
dropout=dropout,
|
||||
activation=activation
|
||||
)
|
||||
last_shape = self.in_shape
|
||||
for conv_filter, conv_kernel in zip(filters, kernels):
|
||||
self.conv_list.append(ConvModule(last_shape, conv_filters=conv_filter,
|
||||
conv_kernel=conv_kernel,
|
||||
conv_padding=conv_kernel-2,
|
||||
conv_stride=1,
|
||||
pooling_size=2,
|
||||
use_norm=use_norm,
|
||||
dropout=dropout,
|
||||
activation=activation
|
||||
)
|
||||
)
|
||||
last_shape = self.conv_list[-1].shape
|
||||
self.last_conv_shape = last_shape
|
||||
|
||||
self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
|
||||
conv_kernel=3,
|
||||
conv_padding=1,
|
||||
conv_stride=1,
|
||||
pooling_size=2,
|
||||
use_norm=use_norm,
|
||||
dropout=dropout,
|
||||
activation=activation
|
||||
)
|
||||
|
||||
self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
|
||||
conv_kernel=5,
|
||||
conv_padding=2,
|
||||
conv_stride=1,
|
||||
pooling_size=2,
|
||||
use_norm=use_norm,
|
||||
dropout=dropout,
|
||||
activation=activation
|
||||
)
|
||||
|
||||
self.flat = Flatten()
|
||||
self.flat = Flatten(self.last_conv_shape)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.conv1(x)
|
||||
tensor = self.conv2(tensor)
|
||||
tensor = self.conv3(tensor)
|
||||
tensor = x
|
||||
for conv in self.conv_list:
|
||||
tensor = conv(tensor)
|
||||
tensor = self.flat(tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
class UnitEncoder(BaseEncoder):
|
||||
class UnitCNNEncoder(BaseCNNEncoder):
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.update(use_norm=True)
|
||||
super(UnitEncoder, self).__init__(*args, **kwargs)
|
||||
super(UnitCNNEncoder, self).__init__(*args, **kwargs)
|
||||
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||
|
||||
def forward(self, x):
|
||||
@ -190,10 +240,10 @@ class UnitEncoder(BaseEncoder):
|
||||
return c1, c2, c3, l1
|
||||
|
||||
|
||||
class VariationalEncoder(BaseEncoder):
|
||||
class VariationalCNNEncoder(BaseCNNEncoder):
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(VariationalEncoder, self).__init__(*args, **kwargs)
|
||||
super(VariationalCNNEncoder, self).__init__(*args, **kwargs)
|
||||
|
||||
self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||
self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||
@ -205,22 +255,22 @@ class VariationalEncoder(BaseEncoder):
|
||||
return mu + eps*std
|
||||
|
||||
def forward(self, x):
|
||||
tensor = super(VariationalEncoder, self).forward(x)
|
||||
tensor = super(VariationalCNNEncoder, self).forward(x)
|
||||
mu = self.mu(tensor)
|
||||
logvar = self.logvar(tensor)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
return mu, logvar, z
|
||||
|
||||
|
||||
class Encoder(BaseEncoder):
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Encoder, self).__init__(*args, **kwargs)
|
||||
class CNNEncoder(BaseCNNEncoder):
|
||||
|
||||
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CNNEncoder, self).__init__(*args, **kwargs)
|
||||
|
||||
self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = super(Encoder, self).forward(x)
|
||||
tensor = super(CNNEncoder, self).forward(x)
|
||||
tensor = self.l1(tensor)
|
||||
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
|
||||
return tensor
|
||||
|
420
modules/util.py
Normal file
420
modules/util.py
Normal file
@ -0,0 +1,420 @@
|
||||
from functools import reduce
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from operator import mul
|
||||
from pytorch_lightning.utilities import argparse_utils
|
||||
from torch import nn
|
||||
from torch.nn import functional as F, Unfold
|
||||
|
||||
from sklearn.metrics import ConfusionMatrixDisplay
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
from ..metrics.binary_class_classifictaion import BinaryScores
|
||||
from ..metrics.multi_class_classification import MultiClassScores
|
||||
from ..utils.model_io import ModelParameters
|
||||
from ..utils.tools import add_argparse_args
|
||||
|
||||
try:
|
||||
import pytorch_lightning as pl
|
||||
|
||||
class PLMetrics(pl.metrics.Metric):
|
||||
|
||||
def __init__(self, n_classes, tag=''):
|
||||
super(PLMetrics, self).__init__()
|
||||
|
||||
self.n_classes = n_classes
|
||||
self.tag = tag
|
||||
|
||||
self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False,)
|
||||
self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False,
|
||||
is_multiclass=True)
|
||||
self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False,
|
||||
is_multiclass=True)
|
||||
self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False)
|
||||
# self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
|
||||
# self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
|
||||
# self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False)
|
||||
if self.n_classes > 2:
|
||||
self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
|
||||
self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(((name, metric) for name, metric in self._modules.items()))
|
||||
|
||||
def update(self, preds, target) -> None:
|
||||
for _, metric in self:
|
||||
try:
|
||||
if self.n_classes <= 2:
|
||||
metric.update(preds, target)
|
||||
else:
|
||||
metric.update(preds, target)
|
||||
except ValueError:
|
||||
print(f'error was: {ValueError}')
|
||||
print(f'Metric is: {metric}')
|
||||
print(f'Shape is: preds - {preds.squeeze().shape}, target - {target.shape}')
|
||||
metric.update(preds.squeeze(), target)
|
||||
except AssertionError:
|
||||
print(f'error was: {AssertionError}')
|
||||
print(f'Metric is: {metric}')
|
||||
print(f'Shape is: preds - {preds.shape}, target - {target.unsqueeze(-1).shape}')
|
||||
metric.update(preds, target.unsqueeze(-1))
|
||||
|
||||
def reset(self) -> None:
|
||||
for _, metric in self:
|
||||
metric.reset()
|
||||
|
||||
def compute(self) -> dict:
|
||||
tag = f'{self.tag}_' if self.tag else ''
|
||||
return {f'{tag}{metric_name}_score': metric.compute() for metric_name, metric in self}
|
||||
|
||||
def compute_and_prepare(self):
|
||||
pl_metrics = self.compute()
|
||||
images_from_metrics = dict()
|
||||
for metric_name in list(pl_metrics.keys()):
|
||||
if 'curve' in metric_name:
|
||||
continue
|
||||
roc_curve = pl_metrics.pop(metric_name)
|
||||
print('debug_point')
|
||||
|
||||
elif 'matrix' in metric_name:
|
||||
matrix = pl_metrics.pop(metric_name)
|
||||
fig1, ax1 = plt.subplots(dpi=96)
|
||||
disp = ConfusionMatrixDisplay(confusion_matrix=matrix.cpu().numpy(),
|
||||
display_labels=[i for i in range(self.n_classes)]
|
||||
)
|
||||
disp.plot(include_values=True, ax=ax1)
|
||||
images_from_metrics[metric_name] = fig1
|
||||
|
||||
elif 'ROC' in metric_name:
|
||||
continue
|
||||
roc = pl_metrics.pop(metric_name)
|
||||
print('debug_point')
|
||||
else:
|
||||
pl_metrics[metric_name] = pl_metrics[metric_name].cpu().item()
|
||||
|
||||
return pl_metrics, images_from_metrics
|
||||
|
||||
|
||||
class LightningBaseModule(pl.LightningModule, ABC):
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
return cls.__name__
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
try:
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
@classmethod
|
||||
def from_argparse_args(cls, args, **kwargs):
|
||||
return argparse_utils.from_argparse_args(cls, args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser):
|
||||
return add_argparse_args(cls, parent_parser)
|
||||
|
||||
def __init__(self, model_parameters, weight_init='xavier_normal_'):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
self._weight_init = weight_init
|
||||
self.params = ModelParameters(model_parameters)
|
||||
|
||||
if hasattr(self.params, 'n_classes'):
|
||||
self.metrics = PLMetrics(self.params.n_classes, tag='PL')
|
||||
else:
|
||||
pass
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
|
||||
def save_to_disk(self, model_path):
|
||||
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
||||
if not (model_path / 'model_class.obj').exists():
|
||||
with (model_path / 'model_class.obj').open('wb') as f:
|
||||
torch.save(self.__class__, f)
|
||||
return True
|
||||
|
||||
@property
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
|
||||
@property
|
||||
def n_train_batches(self):
|
||||
return len(self.train_dataloader())
|
||||
|
||||
def configure_optimizers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_weights(self):
|
||||
if isinstance(self._weight_init, str):
|
||||
mod = __import__('torch.nn.init', fromlist=[self._weight_init])
|
||||
self._weight_init = getattr(mod, self._weight_init)
|
||||
assert callable(self._weight_init)
|
||||
weight_initializer = WeightInit(in_place_init_function=self._weight_init)
|
||||
self.apply(weight_initializer)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
if self.params.n_classes > 2:
|
||||
return MultiClassScores(self)(outputs)
|
||||
else:
|
||||
return BinaryScores(self)(outputs)
|
||||
|
||||
module_types = (LightningBaseModule, nn.Module,)
|
||||
|
||||
except ImportError:
|
||||
module_types = (nn.Module,)
|
||||
pl = None
|
||||
pass # Maybe post a hint to install pytorch-lightning.
|
||||
|
||||
|
||||
class ShapeMixin:
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
||||
assert isinstance(self, module_types)
|
||||
|
||||
def get_out_shape(output):
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
|
||||
in_shape = self.in_shape if hasattr(self, 'in_shape') else None
|
||||
if in_shape is not None:
|
||||
try:
|
||||
device = self.device
|
||||
except AttributeError:
|
||||
try:
|
||||
device = next(self.parameters()).device
|
||||
except StopIteration:
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
x = torch.randn(in_shape, device=device)
|
||||
# This is needed for BatchNorm shape checking
|
||||
x = torch.stack((x, x))
|
||||
|
||||
# noinspection PyCallingNonCallable
|
||||
y = self(x)
|
||||
if isinstance(y, tuple):
|
||||
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
|
||||
else:
|
||||
shape = get_out_shape(y)
|
||||
return shape
|
||||
else:
|
||||
return -1
|
||||
|
||||
@property
|
||||
def flat_shape(self):
|
||||
shape = self.shape
|
||||
try:
|
||||
return reduce(mul, shape)
|
||||
except TypeError:
|
||||
return shape
|
||||
|
||||
|
||||
class F_x(ShapeMixin, nn.Identity):
|
||||
def __init__(self, in_shape):
|
||||
super(F_x, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
|
||||
|
||||
class SlidingWindow(ShapeMixin, nn.Module):
|
||||
def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False):
|
||||
super(SlidingWindow, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
|
||||
self.padding = padding
|
||||
self.stride = stride
|
||||
self.keepdim = keepdim
|
||||
self._unfolder = Unfold(self.kernel, dilation=1, padding=self.padding, stride=self.stride)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self._unfolder(x)
|
||||
tensor = tensor.transpose(-1, -2)
|
||||
if self.keepdim:
|
||||
shape = *x.shape[:2], -1, *self.kernel
|
||||
tensor = tensor.reshape(shape)
|
||||
return tensor
|
||||
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
class Flatten(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, to=-1):
|
||||
assert isinstance(to, int) or isinstance(to, tuple)
|
||||
super(Flatten, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
self.to = (to,) if isinstance(to, int) else to
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), *self.to)
|
||||
|
||||
|
||||
class Interpolate(nn.Module):
|
||||
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
|
||||
super(Interpolate, self).__init__()
|
||||
self.interp = nn.functional.interpolate
|
||||
self.size = size
|
||||
self.scale_factor = scale_factor
|
||||
self.align_corners = align_corners
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, x):
|
||||
x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
|
||||
mode=self.mode, align_corners=self.align_corners)
|
||||
return x
|
||||
|
||||
|
||||
class AutoPad(nn.Module):
|
||||
|
||||
def __init__(self, interpolations=3, base=2):
|
||||
super(AutoPad, self).__init__()
|
||||
self.fct = base ** interpolations
|
||||
|
||||
def forward(self, x):
|
||||
# noinspection PyUnresolvedReferences
|
||||
x = F.pad(x,
|
||||
[0,
|
||||
(x.shape[-1] // self.fct + 1) * self.fct - x.shape[-1] if x.shape[-1] % self.fct != 0 else 0,
|
||||
(x.shape[-2] // self.fct + 1) * self.fct - x.shape[-2] if x.shape[-2] % self.fct != 0 else 0,
|
||||
0])
|
||||
return x
|
||||
|
||||
|
||||
class WeightInit:
|
||||
|
||||
def __init__(self, in_place_init_function):
|
||||
self.in_place_init_function = in_place_init_function
|
||||
|
||||
def __call__(self, m):
|
||||
if hasattr(m, 'weight'):
|
||||
if isinstance(m.weight, torch.Tensor):
|
||||
if m.weight.ndim < 2:
|
||||
m.weight.data.fill_(0.01)
|
||||
else:
|
||||
self.in_place_init_function(m.weight)
|
||||
if hasattr(m, 'bias'):
|
||||
if isinstance(m.bias, torch.Tensor):
|
||||
m.bias.data.fill_(0.01)
|
||||
|
||||
|
||||
class Filter(nn.Module, ShapeMixin):
|
||||
|
||||
def __init__(self, in_shape, pos, dim=-1):
|
||||
super(Filter, self).__init__()
|
||||
|
||||
self.in_shape = in_shape
|
||||
self.pos = pos
|
||||
self.dim = dim
|
||||
raise SystemError('Do not use this Module - broken.')
|
||||
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
tensor = x[:, -1]
|
||||
return tensor
|
||||
|
||||
|
||||
class FlipTensor(nn.Module):
|
||||
def __init__(self, dim=-2):
|
||||
super(FlipTensor, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
|
||||
idx = torch.as_tensor(idx).long()
|
||||
inverted_tensor = x.index_select(self.dim, idx)
|
||||
return inverted_tensor
|
||||
|
||||
|
||||
class AutoPadToShape(nn.Module):
|
||||
def __init__(self, target_shape):
|
||||
super(AutoPadToShape, self).__init__()
|
||||
self.target_shape = target_shape
|
||||
|
||||
def forward(self, x):
|
||||
if not torch.is_tensor(x):
|
||||
x = torch.as_tensor(x)
|
||||
if x.shape[-len(self.target_shape):] == self.target_shape or x.shape == self.target_shape:
|
||||
return x
|
||||
|
||||
idx = [0] * (len(self.target_shape) * 2)
|
||||
for i, j in zip(range(-1, -(len(self.target_shape)+1), -1), range(0, len(idx), 2)):
|
||||
idx[j] = self.target_shape[i] - x.shape[i]
|
||||
x = torch.nn.functional.pad(x, idx)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f'AutoPadTransform({self.target_shape})'
|
||||
|
||||
|
||||
class Splitter(nn.Module):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return tuple([self._out_shape] * self.n)
|
||||
|
||||
@property
|
||||
def out_shape(self):
|
||||
return self._out_shape
|
||||
|
||||
def __init__(self, in_shape, n, dim=-1):
|
||||
super(Splitter, self).__init__()
|
||||
|
||||
self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
|
||||
self.n = n
|
||||
self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
|
||||
|
||||
self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
|
||||
self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
|
||||
|
||||
self.autopad = AutoPadToShape(self._out_shape)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
|
||||
x = x.transpose(0, dim)
|
||||
n_blocks = list()
|
||||
for block_idx in range(self.n):
|
||||
start = block_idx * self.new_dim_size
|
||||
end = (block_idx + 1) * self.new_dim_size
|
||||
block = x[start:end].transpose(0, dim)
|
||||
block = self.autopad(block)
|
||||
n_blocks.append(block)
|
||||
return n_blocks
|
||||
|
||||
|
||||
class Merger(nn.Module, ShapeMixin):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)])
|
||||
return y.shape
|
||||
|
||||
def __init__(self, in_shape, n, dim=-1):
|
||||
super(Merger, self).__init__()
|
||||
|
||||
self.n = n
|
||||
self.dim = dim
|
||||
self.in_shape = in_shape
|
||||
|
||||
def forward(self, x):
|
||||
return torch.cat(x, dim=self.dim)
|
256
modules/utils.py
256
modules/utils.py
@ -1,256 +0,0 @@
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import functional as F
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
from utils.model_io import ModelParameters
|
||||
|
||||
|
||||
class ShapeMixin:
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
assert isinstance(self, (LightningBaseModule, nn.Module))
|
||||
if self.in_shape is not None:
|
||||
x = torch.randn(self.in_shape)
|
||||
# This is needed for BatchNorm shape checking
|
||||
x = torch.stack((x, x))
|
||||
output = self(x)
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
||||
class F_x(ShapeMixin, nn.Module):
|
||||
def __init__(self, in_shape):
|
||||
super(F_x, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
class Flatten(ShapeMixin, nn.Module):
|
||||
|
||||
def __init__(self, in_shape, to=-1):
|
||||
assert isinstance(to, int) or isinstance(to, tuple)
|
||||
super(Flatten, self).__init__()
|
||||
self.in_shape = in_shape
|
||||
self.to = (to,) if isinstance(to, int) else to
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), *self.to)
|
||||
|
||||
|
||||
class Interpolate(nn.Module):
|
||||
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
|
||||
super(Interpolate, self).__init__()
|
||||
self.interp = nn.functional.interpolate
|
||||
self.size = size
|
||||
self.scale_factor = scale_factor
|
||||
self.align_corners = align_corners
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, x):
|
||||
x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
|
||||
mode=self.mode, align_corners=self.align_corners)
|
||||
return x
|
||||
|
||||
|
||||
class AutoPad(nn.Module):
|
||||
|
||||
def __init__(self, interpolations=3, base=2):
|
||||
super(AutoPad, self).__init__()
|
||||
self.fct = base ** interpolations
|
||||
|
||||
def forward(self, x):
|
||||
# noinspection PyUnresolvedReferences
|
||||
x = F.pad(x,
|
||||
[0,
|
||||
(x.shape[-1] // self.fct + 1) * self.fct - x.shape[-1] if x.shape[-1] % self.fct != 0 else 0,
|
||||
(x.shape[-2] // self.fct + 1) * self.fct - x.shape[-2] if x.shape[-2] % self.fct != 0 else 0,
|
||||
0])
|
||||
return x
|
||||
|
||||
|
||||
class WeightInit:
|
||||
|
||||
def __init__(self, in_place_init_function):
|
||||
self.in_place_init_function = in_place_init_function
|
||||
|
||||
def __call__(self, m):
|
||||
if hasattr(m, 'weight'):
|
||||
if isinstance(m.weight, torch.Tensor):
|
||||
if m.weight.ndim < 2:
|
||||
m.weight.data.fill_(0.01)
|
||||
else:
|
||||
self.in_place_init_function(m.weight)
|
||||
if hasattr(m, 'bias'):
|
||||
if isinstance(m.bias, torch.Tensor):
|
||||
m.bias.data.fill_(0.01)
|
||||
|
||||
|
||||
class LightningBaseModule(pl.LightningModule, ABC):
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
return cls.__name__
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
try:
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
|
||||
# Set Parameters
|
||||
################################
|
||||
self.hparams = hparams
|
||||
self.params = ModelParameters(hparams)
|
||||
|
||||
# Dataset Loading
|
||||
################################
|
||||
# TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
|
||||
def save_to_disk(self, model_path):
|
||||
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
||||
if not (model_path / 'model_class.obj').exists():
|
||||
with (model_path / 'model_class.obj').open('wb') as f:
|
||||
torch.save(self.__class__, f)
|
||||
return True
|
||||
|
||||
@property
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
|
||||
@property
|
||||
def n_train_batches(self):
|
||||
return len(self.train_dataloader())
|
||||
|
||||
def configure_optimizers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
|
||||
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
||||
self.apply(weight_initializer)
|
||||
|
||||
|
||||
class FilterLayer(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(FilterLayer, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
tensor = x[:, -1]
|
||||
return tensor
|
||||
|
||||
|
||||
class MergingLayer(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MergingLayer, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# ToDo: Which ones to combine?
|
||||
return
|
||||
|
||||
|
||||
class FlipTensor(nn.Module):
|
||||
def __init__(self, dim=-2):
|
||||
super(FlipTensor, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
|
||||
idx = torch.as_tensor(idx).long()
|
||||
inverted_tensor = x.index_select(self.dim, idx)
|
||||
return inverted_tensor
|
||||
|
||||
|
||||
class AutoPadToShape(object):
|
||||
def __init__(self, shape):
|
||||
self.shape = shape
|
||||
|
||||
def __call__(self, x):
|
||||
if not torch.is_tensor(x):
|
||||
x = torch.as_tensor(x)
|
||||
if x.shape[1:] == self.shape:
|
||||
return x
|
||||
embedding = torch.zeros((x.shape[0], *self.shape))
|
||||
embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x
|
||||
return embedding
|
||||
|
||||
def __repr__(self):
|
||||
return f'AutoPadTransform({self.shape})'
|
||||
|
||||
|
||||
class HorizontalSplitter(nn.Module):
|
||||
|
||||
def __init__(self, in_shape, n):
|
||||
super(HorizontalSplitter, self).__init__()
|
||||
assert len(in_shape) == 3
|
||||
self.n = n
|
||||
self.in_shape = in_shape
|
||||
|
||||
self.channel, self.height, self.width = self.in_shape
|
||||
self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)
|
||||
|
||||
self.shape = (self.channel, self.new_height, self.width)
|
||||
self.autopad = AutoPadToShape(self.shape)
|
||||
|
||||
def forward(self, x):
|
||||
n_blocks = list()
|
||||
for block_idx in range(self.n):
|
||||
start = block_idx * self.new_height
|
||||
end = (block_idx + 1) * self.new_height
|
||||
block = self.autopad(x[:, :, start:end, :])
|
||||
n_blocks.append(block)
|
||||
|
||||
return n_blocks
|
||||
|
||||
|
||||
class HorizontalMerger(nn.Module):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2]
|
||||
return merged_shape
|
||||
|
||||
def __init__(self, in_shape, n):
|
||||
super(HorizontalMerger, self).__init__()
|
||||
assert len(in_shape) == 3
|
||||
self.n = n
|
||||
self.in_shape = in_shape
|
||||
|
||||
def forward(self, x):
|
||||
return torch.cat(x, dim=-2)
|
0
point_toolset/__init__.py
Normal file
0
point_toolset/__init__.py
Normal file
BIN
point_toolset/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
point_toolset/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
point_toolset/__pycache__/point_io.cpython-37.pyc
Normal file
BIN
point_toolset/__pycache__/point_io.cpython-37.pyc
Normal file
Binary file not shown.
37
point_toolset/point_io.py
Normal file
37
point_toolset/point_io.py
Normal file
@ -0,0 +1,37 @@
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
|
||||
|
||||
class BatchToData(object):
|
||||
def __init__(self, transforms=None):
|
||||
super(BatchToData, self).__init__()
|
||||
self.transforms = transforms if transforms else lambda x: x
|
||||
|
||||
def __call__(self, batch_dict):
|
||||
# Convert to torch_geometric.data.Data type
|
||||
|
||||
batch_pos = batch_dict['pos']
|
||||
batch_norm = batch_dict.get('norm', None)
|
||||
batch_y = batch_dict.get('y', None)
|
||||
batch_y_c = batch_dict.get('y_c', None)
|
||||
|
||||
batch_size, num_points, _ = batch_pos.shape # (batch_size, num_points, 3)
|
||||
|
||||
batch_size, N, _ = batch_pos.shape # (batch_size, num_points, 3)
|
||||
pos = batch_pos.view(batch_size * N, -1)
|
||||
norm = batch_norm.view(batch_size * N, -1) if batch_norm is not None else batch_norm
|
||||
|
||||
batch_y_l = batch_y.view(batch_size * N, -1) if batch_y is not None else batch_y
|
||||
batch_y_c = batch_y_c.view(batch_size * N, -1) if batch_y_c is not None else batch_y_c
|
||||
|
||||
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
|
||||
for i in range(batch_size):
|
||||
batch[i] = i
|
||||
batch = batch.view(-1)
|
||||
|
||||
data = Data()
|
||||
data.norm, data.pos, data.batch, data.yl, data.yc = norm, pos, batch, batch_y_l, batch_y_c
|
||||
|
||||
data = self.transforms(data)
|
||||
|
||||
return data
|
26
point_toolset/point_transforms.py
Normal file
26
point_toolset/point_transforms.py
Normal file
@ -0,0 +1,26 @@
|
||||
import torch
|
||||
from torch_geometric.transforms import NormalizeScale
|
||||
|
||||
|
||||
class NormalizePositions(NormalizeScale):
|
||||
|
||||
def __init__(self):
|
||||
super(NormalizePositions, self).__init__()
|
||||
|
||||
def __call__(self, data):
|
||||
if torch.isnan(data.pos).any():
|
||||
print('debug')
|
||||
|
||||
data = self.center(data)
|
||||
if torch.isnan(data.pos).any():
|
||||
print('debug')
|
||||
|
||||
scale = (1 / data.pos.abs().max()) * 0.999999
|
||||
if torch.isnan(scale).any() or torch.isinf(scale).any():
|
||||
print('debug')
|
||||
|
||||
data.pos = data.pos * scale
|
||||
if torch.isnan(data.pos).any():
|
||||
print('debug')
|
||||
|
||||
return data
|
51
point_toolset/sampling.py
Normal file
51
point_toolset/sampling.py
Normal file
@ -0,0 +1,51 @@
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class _Sampler(ABC):
|
||||
|
||||
def __init__(self, K, **kwargs):
|
||||
self.k = K
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomSampling(_Sampler):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RandomSampling, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, pts, *args, **kwargs):
|
||||
rnd_indexs = np.random.choice(np.arange(pts.shape[0]), min(self.k, pts.shape[0]), replace=False)
|
||||
return rnd_indexs
|
||||
|
||||
|
||||
class FarthestpointSampling(_Sampler):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FarthestpointSampling, self).__init__(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def calc_distances(p0, points):
|
||||
return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
|
||||
|
||||
def __call__(self, pts, *args, **kwargs):
|
||||
|
||||
if pts.shape[0] < self.k:
|
||||
return pts
|
||||
|
||||
else:
|
||||
farthest_pts = np.zeros((self.k, pts.shape[1]))
|
||||
farthest_pts_idx = np.zeros(self.k, dtype=np.int)
|
||||
farthest_pts[0] = pts[np.random.randint(len(pts))]
|
||||
distances = self.calc_distances(farthest_pts[0], pts)
|
||||
for i in range(1, self.k):
|
||||
farthest_pts_idx[i] = np.argmax(distances)
|
||||
farthest_pts[i] = pts[farthest_pts_idx[i]]
|
||||
|
||||
distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
|
||||
|
||||
return farthest_pts_idx
|
@ -1,92 +1,8 @@
|
||||
absl-py==0.9.0
|
||||
appdirs==1.4.3
|
||||
attrs==19.3.0
|
||||
audioread==2.1.8
|
||||
bravado==10.6.0
|
||||
bravado-core==5.17.0
|
||||
CacheControl==0.12.6
|
||||
cachetools==4.1.0
|
||||
certifi==2019.11.28
|
||||
cffi==1.14.0
|
||||
chardet==3.0.4
|
||||
click==7.1.2
|
||||
colorama==0.4.3
|
||||
contextlib2==0.6.0
|
||||
cycler==0.10.0
|
||||
decorator==4.4.2
|
||||
distlib==0.3.0
|
||||
distro==1.4.0
|
||||
future==0.18.2
|
||||
gitdb==4.0.5
|
||||
GitPython==3.1.2
|
||||
google-auth==1.14.3
|
||||
google-auth-oauthlib==0.4.1
|
||||
grpcio==1.29.0
|
||||
html5lib==1.0.1
|
||||
idna==2.8
|
||||
ipaddr==2.2.0
|
||||
joblib==0.15.1
|
||||
jsonpointer==2.0
|
||||
jsonref==0.2
|
||||
jsonschema==3.2.0
|
||||
kiwisolver==1.2.0
|
||||
librosa==0.7.2
|
||||
llvmlite==0.32.1
|
||||
lockfile==0.12.2
|
||||
Markdown==3.2.2
|
||||
matplotlib==3.2.1
|
||||
monotonic==1.5
|
||||
msgpack==0.6.2
|
||||
msgpack-python==0.5.6
|
||||
natsort==7.0.1
|
||||
neptune-client==0.4.113
|
||||
numba==0.49.1
|
||||
numpy==1.18.4
|
||||
oauthlib==3.1.0
|
||||
packaging==20.3
|
||||
pandas==1.0.3
|
||||
pep517==0.8.2
|
||||
Pillow==7.1.2
|
||||
progress==1.5
|
||||
protobuf==3.12.0
|
||||
py3nvml==0.2.6
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
pycparser==2.20
|
||||
PyJWT==1.7.1
|
||||
pyparsing==2.4.6
|
||||
pyrsistent==0.16.0
|
||||
python-dateutil==2.8.1
|
||||
pytoml==0.1.21
|
||||
neptune-client==0.4.109
|
||||
pytorch-lightning==0.7.6
|
||||
pytz==2020.1
|
||||
PyYAML==5.3.1
|
||||
requests==2.22.0
|
||||
requests-oauthlib==1.3.0
|
||||
resampy==0.2.2
|
||||
retrying==1.3.3
|
||||
rfc3987==1.3.8
|
||||
rsa==4.0
|
||||
scikit-learn==0.23.0
|
||||
scipy==1.4.1
|
||||
simplejson==3.17.0
|
||||
six==1.14.0
|
||||
smmap==3.0.4
|
||||
SoundFile==0.10.3.post1
|
||||
strict-rfc3339==0.7
|
||||
swagger-spec-validator==2.5.0
|
||||
tensorboard==2.2.1
|
||||
tensorboard-plugin-wit==1.6.0.post3
|
||||
threadpoolctl==2.0.0
|
||||
torch==1.5.0+cu101
|
||||
torchvision==0.6.0+cu101
|
||||
tqdm==4.46.0
|
||||
typing-extensions==3.7.4.2
|
||||
urllib3==1.25.8
|
||||
webcolors==1.11.1
|
||||
webencodings==0.5.1
|
||||
websocket-client==0.57.0
|
||||
Werkzeug==1.0.1
|
||||
xmltodict==0.12.0
|
||||
|
||||
torchcontrib~=0.0.2
|
||||
test-tube==0.7.5
|
||||
torch==1.4.0
|
||||
torchcontrib==0.0.2
|
||||
torchvision==0.5.0
|
||||
tqdm==4.45.0
|
||||
|
BIN
utils/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
utils/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/config.cpython-37.pyc
Normal file
BIN
utils/__pycache__/config.cpython-37.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/model_io.cpython-37.pyc
Normal file
BIN
utils/__pycache__/model_io.cpython-37.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/tools.cpython-37.pyc
Normal file
BIN
utils/__pycache__/tools.cpython-37.pyc
Normal file
Binary file not shown.
36
utils/_basedatamodule.py
Normal file
36
utils/_basedatamodule.py
Normal file
@ -0,0 +1,36 @@
|
||||
from pytorch_lightning import LightningDataModule
|
||||
|
||||
|
||||
# Dataset Options
|
||||
from ml_lib.utils.tools import add_argparse_args
|
||||
|
||||
DATA_OPTION_test = 'test'
|
||||
DATA_OPTION_devel = 'devel'
|
||||
DATA_OPTION_train = 'train'
|
||||
DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
|
||||
|
||||
|
||||
class _BaseDataModule(LightningDataModule):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.datasets[DATA_OPTION_train].sample_shape
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser):
|
||||
return add_argparse_args(cls, parent_parser)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.datasets = dict()
|
||||
|
||||
def transfer_batch_to_device(self, batch, device):
|
||||
if isinstance(batch, list):
|
||||
for idx, item in enumerate(batch):
|
||||
try:
|
||||
batch[idx] = item.to(device)
|
||||
except (AttributeError, RuntimeError):
|
||||
continue
|
||||
return batch
|
||||
else:
|
||||
return batch.to(device)
|
28
utils/callbacks.py
Normal file
28
utils/callbacks.py
Normal file
@ -0,0 +1,28 @@
|
||||
import torch
|
||||
from pytorch_lightning import Callback, Trainer, LightningModule
|
||||
|
||||
|
||||
class BestScoresCallback(Callback):
|
||||
|
||||
def __init__(self, *monitors) -> None:
|
||||
super().__init__()
|
||||
self.monitors = list(*monitors)
|
||||
|
||||
self.best_scores = {monitor: 0.0 for monitor in self.monitors}
|
||||
self.best_epoch = {monitor: 0 for monitor in self.monitors}
|
||||
|
||||
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
||||
epoch = pl_module.current_epoch
|
||||
|
||||
for monitor in self.best_scores.keys():
|
||||
current_score = trainer.callback_metrics.get(monitor)
|
||||
if current_score is None:
|
||||
pass
|
||||
elif torch.isinf(current_score):
|
||||
pass
|
||||
elif torch.isnan(current_score):
|
||||
pass
|
||||
else:
|
||||
self.best_scores[monitor] = max(self.best_scores[monitor], current_score)
|
||||
if self.best_scores[monitor] == current_score:
|
||||
self.best_epoch[monitor] = max(self.best_epoch[monitor], epoch)
|
148
utils/config.py
148
utils/config.py
@ -1,13 +1,84 @@
|
||||
import ast
|
||||
import configparser
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Dict
|
||||
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
|
||||
from abc import ABC
|
||||
|
||||
from argparse import Namespace, ArgumentParser
|
||||
from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
from configparser import ConfigParser, DuplicateSectionError
|
||||
import hashlib
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
from ml_lib.utils.loggers import LightningLogger
|
||||
from ml_lib.utils.tools import locate_and_import_class, auto_cast
|
||||
|
||||
|
||||
# Argument Parser and default Values
|
||||
# =============================================================================
|
||||
def parse_comandline_args_add_defaults(filepath, overrides=None):
|
||||
|
||||
# Parse Command Line
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--model_name', type=str)
|
||||
parser.add_argument('--data_name', type=str)
|
||||
parser.add_argument('--seed', type=str)
|
||||
parser.add_argument('--debug', type=strtobool)
|
||||
|
||||
# Load Defaults from _parameters.ini file
|
||||
config = configparser.ConfigParser()
|
||||
config.read(str(filepath))
|
||||
|
||||
new_defaults = dict()
|
||||
for key in ['project', 'train', 'data']:
|
||||
defaults = config[key]
|
||||
new_defaults.update({key: auto_cast(val) for key, val in defaults.items()})
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
overrides = overrides or dict()
|
||||
default_data = overrides.get('data_name', None) or new_defaults['data_name']
|
||||
default_model = overrides.get('model_name', None) or new_defaults['model_name']
|
||||
default_seed = overrides.get('seed', None) or new_defaults['seed']
|
||||
|
||||
data_name = args.__dict__.get('data_name', None) or default_data
|
||||
model_name = args.__dict__.get('model_name', None) or default_model
|
||||
found_seed = args.__dict__.get('seed', None) or default_seed
|
||||
|
||||
new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
|
||||
|
||||
found_data_class = locate_and_import_class(data_name, 'datasets')
|
||||
found_model_class = locate_and_import_class(model_name, 'models')
|
||||
|
||||
for module in [LightningLogger, Trainer, found_data_class, found_model_class]:
|
||||
parser = module.add_argparse_args(parser)
|
||||
|
||||
args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
|
||||
|
||||
args = vars(args)
|
||||
args.update({key: auto_cast(val) for key, val in args.items()})
|
||||
args.update(gpus=[0] if torch.cuda.is_available() and not args['debug'] else None,
|
||||
row_log_interval=1000, # TODO: Better Value / Setting
|
||||
log_save_interval=10000, # TODO: Better Value / Setting
|
||||
weights_summary='top',
|
||||
)
|
||||
|
||||
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
||||
args.update(**overrides)
|
||||
if args['debug']:
|
||||
args.update(
|
||||
# The seems to be the new "fast_dev_run"
|
||||
val_check_interval=1,
|
||||
max_epochs=2,
|
||||
max_steps=2,
|
||||
auto_lr_find=False,
|
||||
check_val_every_n_epoch=1
|
||||
)
|
||||
return args, found_data_class, found_model_class, found_seed
|
||||
|
||||
|
||||
def is_jsonable(x):
|
||||
@ -38,11 +109,32 @@ class Config(ConfigParser, ABC):
|
||||
def fingerprint(self):
|
||||
h = hashlib.md5()
|
||||
params = deepcopy(self.as_dict)
|
||||
del params['model']['type']
|
||||
del params['model']['secondary_type']
|
||||
del params['data']['worker']
|
||||
del params['main']
|
||||
h.update(str(params).encode())
|
||||
try:
|
||||
del params['model']['type']
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del params['data']['worker']
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del params['data']['refresh']
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del params['main']
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del params['project']
|
||||
except KeyError:
|
||||
pass
|
||||
# Flatten the dict of dicts
|
||||
for section in list(params.keys()):
|
||||
params.update({f'{section}_{key}': val for key, val in params[section].items()})
|
||||
del params[section]
|
||||
_, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
|
||||
h.update(str(vals).encode())
|
||||
fingerprint = h.hexdigest()
|
||||
return fingerprint
|
||||
|
||||
@ -53,6 +145,7 @@ class Config(ConfigParser, ABC):
|
||||
|
||||
@property
|
||||
def _model_map(self):
|
||||
|
||||
"""
|
||||
This is function is supposed to return a dict, which holds a mapping from string model names to model classes
|
||||
|
||||
@ -62,21 +155,31 @@ class Config(ConfigParser, ABC):
|
||||
)
|
||||
:return:
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
try:
|
||||
return self._model_map[self.model.type]
|
||||
except KeyError:
|
||||
raise KeyError(f'The model alias you provided ("{self.get("model", "type")}")' +
|
||||
'does not exist! Try one of these: {list(self._model_map.keys())}')
|
||||
return locate_and_import_class(self.model.type, folder_path='models')
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}") ' +
|
||||
f'was not found!\n' +
|
||||
f'{e}')
|
||||
|
||||
@property
|
||||
def data_class(self):
|
||||
try:
|
||||
return locate_and_import_class(self.data.class_name, folder_path='datasets')
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f'The dataset alias you provided ("{self.get("data", "class_name")}") ' +
|
||||
f'was not found!\n' +
|
||||
f'{e}')
|
||||
|
||||
# --------------------------------------------------
|
||||
# TODO: Do this programmatically; This did not work:
|
||||
# Initialize Default Sections as Property
|
||||
# for section in self.default_sections:
|
||||
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
|
||||
# self.__setattr__(section, property(lambda tensor :tensor._get_namespace_for_section(section))
|
||||
|
||||
@property
|
||||
def main(self):
|
||||
@ -195,3 +298,22 @@ class Config(ConfigParser, ABC):
|
||||
with path.open('w') as configfile:
|
||||
super().write(configfile)
|
||||
return True
|
||||
|
||||
def _write_section(self, fp, section_name, section_items, delimiter):
|
||||
if section_name == 'project':
|
||||
return
|
||||
else:
|
||||
super(Config, self)._write_section(fp, section_name, section_items, delimiter)
|
||||
|
||||
def add_section(self, section: str) -> None:
|
||||
try:
|
||||
super(Config, self).add_section(section)
|
||||
except DuplicateSectionError:
|
||||
pass
|
||||
|
||||
|
||||
class DataClass(Namespace):
|
||||
|
||||
@property
|
||||
def __dict__(self):
|
||||
return [x for x in dir(self) if not x.startswith('_')]
|
||||
|
41
utils/data_util.py
Normal file
41
utils/data_util.py
Normal file
@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def chunks(l, n):
|
||||
"""Yield successive n-sized chunks from l."""
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i:i + n]
|
||||
|
||||
|
||||
class ReMapDataset(Dataset):
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return list(self[0][0].shape)
|
||||
|
||||
def __init__(self, ds, mapping):
|
||||
super(ReMapDataset, self).__init__()
|
||||
# here is a mapping from this index to the mother ds index
|
||||
self.mapping = mapping
|
||||
self.ds = ds
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.ds[self.mapping[index]]
|
||||
|
||||
def __len__(self):
|
||||
return self.mapping.shape[0]
|
||||
|
||||
@classmethod
|
||||
def do_train_vali_split(cls, ds, split_fold=0.1):
|
||||
|
||||
indices = torch.randperm(len(ds))
|
||||
|
||||
valid_size = int(len(ds) * split_fold)
|
||||
|
||||
train_mapping = indices[valid_size:]
|
||||
valid_mapping = indices[:valid_size]
|
||||
|
||||
train = cls(ds, train_mapping)
|
||||
valid = cls(ds, valid_mapping)
|
||||
|
||||
return train, valid
|
30
utils/equal_sampler.py
Normal file
30
utils/equal_sampler.py
Normal file
@ -0,0 +1,30 @@
|
||||
|
||||
import random
|
||||
from typing import Iterator, Sequence
|
||||
|
||||
from torch.utils.data import Sampler
|
||||
from torch.utils.data.sampler import T_co
|
||||
|
||||
|
||||
# noinspection PyMissingConstructor
|
||||
class EqualSampler(Sampler):
|
||||
|
||||
def __init__(self, idxs_per_class: Sequence[Sequence[float]], replacement: bool = True) -> None:
|
||||
|
||||
self.replacement = replacement
|
||||
self.idxs_per_class = idxs_per_class
|
||||
self.len_largest_class = max([len(x) for x in self.idxs_per_class])
|
||||
|
||||
def __iter__(self) -> Iterator[T_co]:
|
||||
return iter(random.choice(self.idxs_per_class[random.randint(0, len(self.idxs_per_class)-1)])
|
||||
for _ in range(len(self)))
|
||||
|
||||
def __len__(self):
|
||||
return self.len_largest_class * len(self.idxs_per_class)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
es = EqualSampler([list(range(5)), list(range(5, 10)), list(range(10, 12))])
|
||||
for i in es:
|
||||
print(i)
|
||||
pass
|
185
utils/loggers.py
Normal file
185
utils/loggers.py
Normal file
@ -0,0 +1,185 @@
|
||||
import inspect
|
||||
from argparse import ArgumentParser
|
||||
from copy import deepcopy
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||
from neptune.api_exceptions import ProjectNotFound
|
||||
|
||||
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
||||
|
||||
from pytorch_lightning.loggers.csv_logs import CSVLogger
|
||||
from pytorch_lightning.utilities import argparse_utils
|
||||
|
||||
from ml_lib.utils.tools import add_argparse_args
|
||||
|
||||
|
||||
class LightningLogger(LightningLoggerBase):
|
||||
|
||||
@classmethod
|
||||
def from_argparse_args(cls, args, **kwargs):
|
||||
cleaned_args = deepcopy(args.__dict__)
|
||||
|
||||
# Clean Seed and other attributes
|
||||
# TODO: Find a better way in cleaning this
|
||||
for attr in ['seed', 'num_worker', 'debug', 'eval', 'owner', 'data_root', 'check_val_every_n_epoch',
|
||||
'reset', 'outpath', 'version', 'gpus', 'neptune_key', 'num_sanity_val_steps', 'tpu_cores',
|
||||
'progress_bar_refresh_rate', 'log_save_interval', 'row_log_interval']:
|
||||
|
||||
try:
|
||||
del cleaned_args[attr]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
kwargs.update(params=cleaned_args)
|
||||
new_logger = argparse_utils.from_argparse_args(cls, args, **kwargs)
|
||||
return new_logger
|
||||
|
||||
@property
|
||||
def fingerprint(self):
|
||||
h = hashlib.md5()
|
||||
h.update(self._finger_print_string.encode())
|
||||
fingerprint = h.hexdigest()
|
||||
return fingerprint
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
short_name = "".join(c for c in self.model_name if c.isupper())
|
||||
return f'{short_name}_{self.fingerprint}'
|
||||
|
||||
media_dir = 'media'
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser):
|
||||
return add_argparse_args(cls, parent_parser)
|
||||
|
||||
@property
|
||||
def experiment(self):
|
||||
if self.debug:
|
||||
return self.csvlogger.experiment
|
||||
else:
|
||||
return self.neptunelogger.experiment
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
return Path(self.csvlogger.experiment.log_dir)
|
||||
|
||||
@property
|
||||
def project_name(self):
|
||||
return f"{self.owner}/{self.projeect_root.replace('_', '-')}"
|
||||
|
||||
@property
|
||||
def projeect_root(self):
|
||||
root_path = Path(os.getcwd()).name if not self.debug else 'test'
|
||||
return root_path
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return self.seed
|
||||
|
||||
@property
|
||||
def save_dir(self):
|
||||
return self.log_dir
|
||||
|
||||
@property
|
||||
def outpath(self):
|
||||
return Path(self.root_out) / self.model_name
|
||||
|
||||
def __init__(self, owner, neptune_key, model_name, outpath='output', seed=69, debug=False, params=None):
|
||||
"""
|
||||
params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
|
||||
Parameters are displayed in the experiment’s Parameters section and each key-value pair can be
|
||||
viewed in experiments view as a column.
|
||||
properties (dict|None): Optional default is {}. Properties of the experiment.
|
||||
They are editable after experiment is created. Properties are displayed in the experiment’s Details and
|
||||
each key-value pair can be viewed in experiments view as a column.
|
||||
tags (list|None): Optional default []. Must be list of str. Tags of the experiment.
|
||||
They are editable after experiment is created (see: append_tag() and remove_tag()).
|
||||
Tags are displayed in the experiment’s Details and can be viewed in experiments view as a column.
|
||||
"""
|
||||
super(LightningLogger, self).__init__()
|
||||
|
||||
self.debug = debug
|
||||
self.owner = owner if not self.debug else 'testuser'
|
||||
self.neptune_key = neptune_key if not self.debug else 'XXX'
|
||||
self.root_out = outpath if not self.debug else 'debug_out'
|
||||
|
||||
self.params = params
|
||||
|
||||
self.seed = seed
|
||||
self.model_name = model_name
|
||||
|
||||
if self.params:
|
||||
_, fingerprint_tuple = zip(*sorted(self.params.items(), key=lambda tup: tup[0]))
|
||||
self._finger_print_string = str(fingerprint_tuple)
|
||||
else:
|
||||
self._finger_print_string = str((self.owner, self.root_out, self.seed, self.model_name, self.debug))
|
||||
self.params.update(fingerprint=self.fingerprint)
|
||||
|
||||
self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||
self._neptune_kwargs = dict(offline_mode=self.debug,
|
||||
params=self.params,
|
||||
api_key=self.neptune_key,
|
||||
experiment_name=self.name,
|
||||
# tags=?,
|
||||
project_name=self.project_name)
|
||||
try:
|
||||
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
|
||||
except ProjectNotFound as e:
|
||||
print(f'The project "{self.project_name}" does not exist! Create it or check your spelling.')
|
||||
print(e)
|
||||
|
||||
self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
|
||||
if self.params:
|
||||
self.log_hyperparams(self.params)
|
||||
|
||||
def close(self):
|
||||
self.csvlogger.close()
|
||||
self.neptunelogger.close()
|
||||
|
||||
def set_fingerprint_string(self, fingerprint_str):
|
||||
self._finger_print_string = fingerprint_str
|
||||
|
||||
def log_text(self, name, text, **_):
|
||||
# TODO Implement Offline variant.
|
||||
self.neptunelogger.log_text(name, text)
|
||||
|
||||
def log_hyperparams(self, params):
|
||||
self.neptunelogger.log_hyperparams(params)
|
||||
self.csvlogger.log_hyperparams(params)
|
||||
pass
|
||||
|
||||
def log_metric(self, metric_name, metric_value, step=None, **kwargs):
|
||||
self.csvlogger.log_metrics(dict(metric_name=metric_value, **kwargs), step=step, **kwargs)
|
||||
self.neptunelogger.log_metric(metric_name, metric_value, step=step, **kwargs)
|
||||
pass
|
||||
|
||||
def log_metrics(self, metrics, step=None):
|
||||
self.neptunelogger.log_metrics(metrics, step=step)
|
||||
self.csvlogger.log_metrics(metrics, step=step)
|
||||
pass
|
||||
|
||||
def log_image(self, name, image, ext='png', step=None, **kwargs):
|
||||
image_name = f'{"0" * (4 - len(str(step)))}{step}_{name}' if step is not None else name
|
||||
image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
|
||||
(self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
|
||||
image.savefig(image_path, bbox_inches='tight', pad_inches=0)
|
||||
self.neptunelogger.log_image(name, str(image_path), **kwargs)
|
||||
|
||||
def save(self):
|
||||
self.csvlogger.save()
|
||||
self.neptunelogger.save()
|
||||
|
||||
def finalize(self, status):
|
||||
self.csvlogger.finalize(status)
|
||||
self.neptunelogger.finalize(status)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.finalize('success')
|
||||
pass
|
116
utils/logging.py
116
utils/logging.py
@ -1,116 +0,0 @@
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
||||
from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
||||
|
||||
from utils.config import Config
|
||||
|
||||
|
||||
class Logger(LightningLoggerBase, ABC):
|
||||
|
||||
media_dir = 'media'
|
||||
|
||||
@property
|
||||
def experiment(self):
|
||||
if self.debug:
|
||||
return self.testtubelogger.experiment
|
||||
else:
|
||||
return self.neptunelogger.experiment
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
return Path(self.testtubelogger.experiment.get_logdir()).parent
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.config.name
|
||||
|
||||
@property
|
||||
def project_name(self):
|
||||
return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return self.config.get('main', 'seed')
|
||||
|
||||
@property
|
||||
def outpath(self):
|
||||
return Path(self.config.train.outpath) / self.config.model.type
|
||||
|
||||
@property
|
||||
def exp_path(self):
|
||||
return Path(self.outpath) / self.name
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""
|
||||
params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
|
||||
Parameters are displayed in the experiment’s Parameters section and each key-value pair can be
|
||||
viewed in experiments view as a column.
|
||||
properties (dict|None): Optional default is {}. Properties of the experiment.
|
||||
They are editable after experiment is created. Properties are displayed in the experiment’s Details and
|
||||
each key-value pair can be viewed in experiments view as a column.
|
||||
tags (list|None): Optional default []. Must be list of str. Tags of the experiment.
|
||||
They are editable after experiment is created (see: append_tag() and remove_tag()).
|
||||
Tags are displayed in the experiment’s Details and can be viewed in experiments view as a column.
|
||||
"""
|
||||
super(Logger, self).__init__()
|
||||
|
||||
self.config = config
|
||||
self.debug = self.config.main.debug
|
||||
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||
self._neptune_kwargs = dict(offline_mode=self.debug,
|
||||
api_key=self.config.project.neptune_key,
|
||||
experiment_name=self.name,
|
||||
project_name=self.project_name,
|
||||
params=self.config.model_paramters)
|
||||
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
|
||||
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
|
||||
self.log_config_as_ini()
|
||||
|
||||
def log_hyperparams(self, params):
|
||||
self.neptunelogger.log_hyperparams(params)
|
||||
self.testtubelogger.log_hyperparams(params)
|
||||
pass
|
||||
|
||||
def log_metrics(self, metrics, step=None):
|
||||
self.neptunelogger.log_metrics(metrics, step=step)
|
||||
self.testtubelogger.log_metrics(metrics, step=step)
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self.testtubelogger.close()
|
||||
self.neptunelogger.close()
|
||||
|
||||
def log_config_as_ini(self):
|
||||
self.config.write(self.log_dir / 'config.ini')
|
||||
|
||||
def log_text(self, name, text, step_nb=0, **kwargs):
|
||||
# TODO Implement Offline variant.
|
||||
self.neptunelogger.log_text(name, text, step_nb)
|
||||
|
||||
def log_metric(self, metric_name, metric_value, **kwargs):
|
||||
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
|
||||
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
|
||||
|
||||
def log_image(self, name, image, **kwargs):
|
||||
self.neptunelogger.log_image(name, image, **kwargs)
|
||||
step = kwargs.get('step', None)
|
||||
name = f'{step}_{name}' if step is not None else name
|
||||
image.savefig(self.log_dir / self.media_dir / name)
|
||||
|
||||
def save(self):
|
||||
self.testtubelogger.save()
|
||||
self.neptunelogger.save()
|
||||
|
||||
def finalize(self, status):
|
||||
self.testtubelogger.finalize(status)
|
||||
self.neptunelogger.finalize(status)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.finalize('success')
|
||||
pass
|
@ -1,5 +1,7 @@
|
||||
from argparse import Namespace
|
||||
from collections import Mapping
|
||||
from typing import Union
|
||||
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
@ -11,6 +13,14 @@ from torch import nn
|
||||
# Hyperparamter Object
|
||||
class ModelParameters(Namespace, Mapping):
|
||||
|
||||
@property
|
||||
def as_dict(self):
|
||||
return {key: self.get(key) if key != 'activation' else self.activation_as_string for key in self.keys()}
|
||||
|
||||
@property
|
||||
def activation_as_string(self):
|
||||
return self['activation'].lower()
|
||||
|
||||
@property
|
||||
def module_kwargs(self):
|
||||
|
||||
@ -18,9 +28,11 @@ class ModelParameters(Namespace, Mapping):
|
||||
|
||||
paramter_mapping.update(
|
||||
dict(
|
||||
activation=self._activations[self['activation']]
|
||||
activation=self.__getattribute__('activation')
|
||||
)
|
||||
)
|
||||
# Get rid of paramters that
|
||||
paramter_mapping.__delitem__('in_shape')
|
||||
|
||||
return paramter_mapping
|
||||
|
||||
@ -42,49 +54,54 @@ class ModelParameters(Namespace, Mapping):
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name == 'activation':
|
||||
return self._activations[self['activation']]
|
||||
return self._activations[self['activation'].lower()]
|
||||
else:
|
||||
try:
|
||||
return super(ModelParameters, self).__getattribute__(name)
|
||||
except AttributeError as e:
|
||||
if name == 'stretch':
|
||||
return False
|
||||
else:
|
||||
raise AttributeError(e)
|
||||
return super(ModelParameters, self).__getattribute__(name)
|
||||
|
||||
_activations = dict(
|
||||
leaky_relu=nn.LeakyReLU,
|
||||
gelu=nn.GELU,
|
||||
elu=nn.ELU,
|
||||
relu=nn.ReLU,
|
||||
sigmoid=nn.Sigmoid,
|
||||
tanh=nn.Tanh
|
||||
)
|
||||
|
||||
def __init__(self, parameter_mapping):
|
||||
if isinstance(parameter_mapping, Namespace):
|
||||
parameter_mapping = parameter_mapping.__dict__
|
||||
super(ModelParameters, self).__init__(**parameter_mapping)
|
||||
|
||||
|
||||
class SavedLightningModels(object):
|
||||
|
||||
@classmethod
|
||||
def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
|
||||
def load_checkpoint(cls, models_root_path, model=None, n=-1, checkpoint: Union[None, str] = None):
|
||||
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
|
||||
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
|
||||
|
||||
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
|
||||
if checkpoint is not None:
|
||||
checkpoint_path = Path(checkpoint)
|
||||
assert checkpoint_path.exists(), f'The path ({checkpoint_path} does not exist).'
|
||||
else:
|
||||
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
|
||||
checkpoint_path = natsorted(found_checkpoints, key=lambda y: y.name)[n]
|
||||
if model is None:
|
||||
model = torch.load(models_root_path / 'model_class.obj')
|
||||
assert model is not None
|
||||
|
||||
return cls(weights=found_checkpoints[n], model=model)
|
||||
return cls(weights=checkpoint_path, model=model)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.weights: str = kwargs.get('weights', '')
|
||||
self.weights: Path = Path(kwargs.get('weights', ''))
|
||||
self.hparams: Path = self.weights.parent / 'hparams.yaml'
|
||||
|
||||
self.model = kwargs.get('model', None)
|
||||
assert self.model is not None
|
||||
|
||||
def restore(self):
|
||||
pretrained_model = self.model.load_from_checkpoint(self.weights)
|
||||
|
||||
pretrained_model = self.model.load_from_checkpoint(self.weights.__str__())
|
||||
# , hparams_file=self.hparams.__str__())
|
||||
pretrained_model.eval()
|
||||
pretrained_model.freeze()
|
||||
return pretrained_model
|
||||
return pretrained_model
|
||||
|
||||
|
@ -23,3 +23,5 @@ def run_n_in_parallel(f, n, processes=0, **kwargs):
|
||||
p.join()
|
||||
|
||||
return results
|
||||
|
||||
raise NotImplementedError()
|
||||
|
@ -1,6 +1,35 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import pickle
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser, ArgumentError
|
||||
from ast import literal_eval
|
||||
from pathlib import Path, PurePath
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
|
||||
|
||||
def auto_cast(a):
|
||||
try:
|
||||
return literal_eval(a)
|
||||
except:
|
||||
return a
|
||||
|
||||
|
||||
def to_one_hot(idx_array, max_classes):
|
||||
one_hot = np.zeros((idx_array.size, max_classes))
|
||||
one_hot[np.arange(idx_array.size), idx_array] = 1
|
||||
return one_hot
|
||||
|
||||
|
||||
def fix_all_random_seeds(seed):
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
print(f'Seed is now fixed: "{seed}".')
|
||||
|
||||
|
||||
def write_to_shelve(file_path, value):
|
||||
@ -20,4 +49,43 @@ def load_from_shelve(file_path, key):
|
||||
|
||||
def check_path(file_path):
|
||||
assert isinstance(file_path, Path)
|
||||
assert str(file_path).endswith('.pik')
|
||||
assert str(file_path).endswith('.pik')
|
||||
|
||||
|
||||
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
"""Locate an object by name or dotted path, importing as necessary."""
|
||||
import sys
|
||||
sys.path.append("..")
|
||||
folder_path = Path(folder_path)
|
||||
module_paths = [x for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||
# possible_package_path = folder_path / '__init__.py'
|
||||
# package = str(possible_package_path) if possible_package_path.exists() else None
|
||||
for module_path in module_paths:
|
||||
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
|
||||
try:
|
||||
model_class = mod.__getattribute__(class_name)
|
||||
return model_class
|
||||
except AttributeError:
|
||||
continue
|
||||
raise AttributeError(f'Check the {folder_path.name} name. Possible files are:\n{[x.name for x in module_paths]}')
|
||||
|
||||
|
||||
def add_argparse_args(cls, parent_parser):
|
||||
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
||||
full_arg_spec = inspect.getfullargspec(cls.__init__)
|
||||
n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
|
||||
for idx, argument in enumerate(full_arg_spec.args):
|
||||
try:
|
||||
if argument == 'self':
|
||||
continue
|
||||
if idx < n_non_defaults:
|
||||
parser.add_argument(f'--{argument}', type=int)
|
||||
else:
|
||||
argument_type = type(argument)
|
||||
parser.add_argument(f'--{argument}',
|
||||
type=argument_type,
|
||||
default=full_arg_spec.defaults[idx - n_non_defaults]
|
||||
)
|
||||
except ArgumentError:
|
||||
continue
|
||||
return parser
|
||||
|
@ -1,8 +1,22 @@
|
||||
from abc import ABC
|
||||
from torchvision.transforms import ToTensor as TorchVisionToTensor
|
||||
|
||||
|
||||
class _BaseTransformation(ABC):
|
||||
|
||||
def __init__(self, *args):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ToTensor(TorchVisionToTensor):
|
||||
|
||||
def __call__(self, pic):
|
||||
# Make it float .float() == 32bit
|
||||
tensor = super(ToTensor, self).__call__(pic).float()
|
||||
return tensor
|
||||
|
@ -1,31 +1,56 @@
|
||||
try:
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
||||
from matplotlib import pyplot as plt
|
||||
except ImportError: # pragma: no-cover
|
||||
raise ImportError('You want to use `matplotlib` which is not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install matplotlib`.')
|
||||
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def prettyfy_sns():
|
||||
plt.style.use('default')
|
||||
try:
|
||||
import seaborn as sns
|
||||
except ImportError:
|
||||
raise ImportError('You want to use `seaborn` which is not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install seaborn`.')
|
||||
sns.set_palette('Dark2')
|
||||
|
||||
tex_fonts = {
|
||||
# Use LaTeX to write all text
|
||||
"text.usetex": True,
|
||||
"font.family": "serif",
|
||||
# Use 10pt font in plots, to match 10pt font in document
|
||||
"axes.labelsize": 10,
|
||||
"font.size": 10,
|
||||
# Make the legend/label fonts a little smaller
|
||||
"legend.fontsize": 8,
|
||||
"xtick.labelsize": 8,
|
||||
"ytick.labelsize": 8
|
||||
}
|
||||
|
||||
plt.rcParams.update(tex_fonts)
|
||||
|
||||
|
||||
class Plotter(object):
|
||||
def __init__(self, root_path=''):
|
||||
self.root_path = Path(root_path)
|
||||
|
||||
def save_current_figure(self, path, extention='.png', naked=True):
|
||||
fig, _ = plt.gcf(), plt.gca()
|
||||
def __init__(self, root_path=''):
|
||||
if not root_path:
|
||||
self.root_path = Path(root_path)
|
||||
|
||||
def save_figure(self, figure, title, extention='.png', naked=False):
|
||||
canvas = FigureCanvas(figure)
|
||||
# Prepare save location and check img file extention
|
||||
path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
|
||||
path = self.root_path / f'{title}{extention}'
|
||||
path.parent.mkdir(exist_ok=True, parents=True)
|
||||
if naked:
|
||||
plt.axis('off')
|
||||
fig.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
|
||||
fig.clf()
|
||||
figure.axis('off)')
|
||||
figure.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
|
||||
canvas.print_figure(path)
|
||||
else:
|
||||
fig.savefig(path)
|
||||
fig.clf()
|
||||
|
||||
def show_current_figure(self):
|
||||
fig, _ = plt.gcf(), plt.gca()
|
||||
fig.show()
|
||||
fig.clf()
|
||||
canvas.print_figure(path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
output_root = Path('..') / 'output'
|
||||
p = Plotter(output_root)
|
||||
p.save_current_figure('test.png')
|
||||
raise PermissionError('Get out of here.')
|
||||
|
Reference in New Issue
Block a user