Compare commits

54 Commits

Author SHA1 Message Date
ab01006eae Code Comments, Getting Dirty Env, Naming 2021-05-11 10:31:34 +02:00
faa27c3cf9 paper preperations and notebooks, optuna callbacks 2021-04-02 08:45:11 +02:00
abe870d106 bugs fixed, binary datasets working 2021-03-27 18:23:51 +01:00
1d1b154460 bug in metric calculation 2021-03-27 16:39:07 +01:00
6816e423ff CCS intergration training running
notebooks
2021-03-24 08:03:11 +01:00
d3e7bf7efb adjustment fot CCS, notebook folder 2021-03-22 16:43:18 +01:00
ed260f1c2a Merge remote-tracking branch 'origin/master' 2021-03-19 21:12:26 +01:00
675312537f CCS intergration dataloader 2021-03-19 18:05:17 +01:00
43cf0ad00d CCS intergration dataloader 2021-03-19 17:17:16 +01:00
479514c9e7 Merge remote-tracking branch 'origin/master'
# Conflicts:
#	utils/tools.py
2021-03-18 21:44:18 +01:00
fff5e6e00a Final Train Runs 2021-03-18 21:43:26 +01:00
8e719af554 variable mask size, beter image shapes 2021-03-18 21:34:51 +01:00
10bf376ac3 Small bugfixes 2021-03-18 12:12:43 +01:00
fc4617c9d8 Final Train Runs 2021-03-18 07:45:06 +01:00
f89f0f8528 Transformer running 2021-03-04 12:01:08 +01:00
b5e3e5aec1 Dataset rdy 2021-02-16 10:18:03 +01:00
a966321576 bringing brances up to date 2021-02-15 11:39:54 +01:00
010176e80b transition 2021-02-01 10:23:22 +01:00
f6156c6cde Urban 8k Train running with newest Lightning and pytorch 2021-01-04 11:22:34 +01:00
93103aba01 Repair of ML Lib -> Transformations back to np from torch 2020-12-17 11:00:42 +01:00
62d9eb6e8f torchaudio testing 2020-12-17 08:02:28 +01:00
c6fdaa24aa Audio Dataset 2020-12-01 16:37:15 +01:00
cfeea05673 New Model, Many Changes 2020-11-22 16:23:59 +01:00
14ed4e0117 New Model, Many Changes 2020-11-21 09:28:25 +01:00
13812b83b5 Transformer Implementation 2020-10-29 16:40:43 +01:00
f296ba78b9 Al Lot 2020-10-07 15:21:45 +02:00
5848b528f0 SubSpectral and Lightning 0.9 Update 2020-09-25 15:35:15 +02:00
6bc9447ce1 Model Loading by string. Within Debugging 2020-08-15 12:42:57 +02:00
a4b6c698c3 InterSpeech Camera Ready Reporting 2020-08-06 08:12:07 +02:00
4b089729b2 intitial thoughts 2020-08-04 09:04:04 +02:00
c7d17a9898 6D prediction files now working 2020-06-26 08:33:58 +02:00
7770b29c14 6D prediction files now working 2020-06-25 12:03:07 +02:00
53aa11521d New Model running 2020-06-23 14:37:33 +02:00
aea34de964 dataset fixing 2020-06-19 15:37:43 +02:00
3f8122484b explicit model argument 2020-06-19 13:35:35 +02:00
12d36047ef Dataset Redone 2020-06-19 08:17:35 +02:00
76308888e0 Normalization and transforms for batch_to_data class 2020-06-15 15:14:07 +02:00
0cff42f951 ensembles 2020-06-14 20:50:53 +02:00
ece80ecbed Model IO 2020-06-09 17:06:33 +02:00
d3fa32ae7b New Dataset for per spatial cluster training 2020-06-09 14:08:34 +02:00
2acf91335f Grid Clusters. 2020-06-07 16:47:51 +02:00
5987efb169 eval running - offline logger implemented -> Test it! 2020-05-30 18:12:41 +02:00
77ea043907 pointnet2 working - TODO: Eval! 2020-05-26 21:44:56 +02:00
4b4051c045 speed aug fixed 2020-05-21 14:44:05 +02:00
8cec323286 speed aug fixed 2020-05-21 14:42:35 +02:00
235743b225 Merge remote-tracking branch 'origin/master' 2020-05-21 14:14:14 +02:00
28d0034269 requirements.txt updated @torch1.4
speed augmentation updated
paramters updated
2020-05-21 14:12:54 +02:00
b87a56e8c6 fingerprinted now should work correctly 2020-05-20 13:29:16 +02:00
196b1af7ae Dataset for whole pointclouds with farthest point sampling _incomplete_ 2020-05-19 17:15:01 +02:00
fcd5ee4d29 requirements updated to ubuntu 20.04 python3.8 torch 1.5 2020-05-19 10:06:13 +02:00
f290d5a8d8 Save imports 2020-05-19 10:03:35 +02:00
645b7905e8 initial commit - just template files 2020-05-19 09:20:53 +02:00
206aca10b3 fingerprinted now should work correctly 2020-05-19 08:33:04 +02:00
e423d6fe31 Python 3.8 branch merged
some small template fixes
2020-05-17 22:11:21 +02:00
56 changed files with 2313 additions and 737 deletions

7
.gitignore vendored
View File

@ -1,6 +1 @@
/.idea/
# my own stuff
/data
/.idea
/ml_lib
.idea

View File

@ -9,9 +9,10 @@ Clone it to find a collection of:
- Utility Function for Model I/O
- DL Modules
- A Plotter Object
- Audio Related Tools and Funtion
- Audio related Tools and Functions
- Librosa
- Scipy Signal
- PointCloud related Tools and Functions
###Notes:
- Use directory links to link from your main project folder to the ml_lib folder. Pycharm will automatically use
@ -19,11 +20,10 @@ Clone it to find a collection of:
\
\
For Windows Users:
```
``` bash
mklink /d "ml_lib" "..\ml_lib"
```
For Unix Users:
``` bash
ln -s ../ml_lib ml_lib
```
TBA
```
- Cheers

Binary file not shown.

5
_templates/new_project/.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
# my own stuff
/data
/.idea
/ml_lib

View File

@ -2,5 +2,16 @@ from torch.utils.data import Dataset
class TemplateDataset(Dataset):
    """Skeleton dataset meant to be copied and filled in by a new project."""

    @property
    def sample_shape(self):
        """Shape of the first element of the first sample."""
        first_sample = self[0]
        return first_sample[0].shape

    def __init__(self, *args, **kwargs):
        # The template accepts arbitrary arguments but does not use them.
        super().__init__()

    def __len__(self):
        # Placeholder: no length is defined until a real dataset implements it.
        return None

    def __getitem__(self, item):
        # Placeholder: hands the requested index straight back.
        return item

View File

@ -7,10 +7,9 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from modules.utils import LightningBaseModule
from utils.config import Config
from utils.logging import Logger
from utils.model_io import SavedLightningModels
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.config import Config
from ml_lib.utils.loggers import LightningLogger
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@ -21,7 +20,7 @@ def run_lightning_loop(config_obj):
# Logging
# ================================================================================
# Logger
with Logger(config_obj) as logger:
with LightningLogger(config_obj) as logger:
# Callbacks
# =============================================================================
# Checkpoint Saving
@ -44,11 +43,6 @@ def run_lightning_loop(config_obj):
# Init
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights(torch.nn.init.xavier_normal_)
if model.name == 'CNNRouteGeneratorDiscriminated':
# ToDo: Make this dependent on the used seed
path = logger.outpath / 'classifier_cnn' / 'version_0'
disc_model = SavedLightningModels.load_checkpoint(path).restore()
model.set_discriminator(disc_model)
# Trainer
# =============================================================================
@ -70,8 +64,8 @@ def run_lightning_loop(config_obj):
trainer.fit(model)
# Save the last state & all parameters
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
model.save_to_disk(logger.log_dir)
trainer.save_checkpoint(config_obj.exp_path.log_dir / 'weights.ckpt')
model.save_to_disk(config_obj.exp_path)
# Evaluate It
if config_obj.main.eval:

View File

@ -1,6 +1,6 @@
import warnings
from utils.config import Config
from ml_lib._templates.new_project.utils.project_config import Config
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@ -8,17 +8,16 @@ warnings.filterwarnings('ignore', category=UserWarning)
# Imports
# =============================================================================
from _templates.new_project.main import run_lightning_loop, args
from ml_lib._templates.new_project.main import run_lightning_loop, args
if __name__ == '__main__':
# Model Settings
config = Config().read_namespace(args)
# bias, activation, model, norm, max_epochs, filters
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
# bias, activation, model, norm, max_epochs, sr, feature_mixed_dim, filters
# bias, activation, model, norm, max_epochs
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, data_batchsize=512)
# bias, activation, model, norm, max_epochs
for arg_dict in [cnn_classifier]:
for seed in range(5):

View File

@ -11,13 +11,13 @@ from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose
from _templates.new_project.datasets.template_dataset import TemplateDataset
from ml_lib._templates.new_project.datasets.template_dataset import TemplateDataset
from audio_toolset.audio_io import NormalizeLocal
from modules.utils import LightningBaseModule
from utils.transforms import ToTensor
from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.transforms import ToTensor
from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
from ml_lib._templates.new_project.utils.project_config import GlobalVar as GlobalVars
class BaseOptimizerMixin:
@ -61,10 +61,11 @@ class BaseTrainMixin:
assert isinstance(self, LightningBaseModule)
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
return summary_dict
for key in keys if 'loss' in key}
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BaseValMixin:
@ -83,13 +84,13 @@ class BaseValMixin:
def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
summary_dict = dict(log=dict())
summary_dict = dict()
# In case of Multiple given dataloader this will outputs will be: list[list[dict[]]]
# for output_idx, output in enumerate(outputs):
# else:list[dict[]]
keys = list(outputs.keys())
# Add Every Value das has a "loss" in it, by calc. mean over all occurences.
summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
)
@ -107,7 +108,8 @@ class BaseValMixin:
summary_dict['log'].update({f'uar_score': uar_score})
"""
return summary_dict
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BinaryMaskDatasetMixin:

View File

@ -1,8 +1,5 @@
from argparse import Namespace
from utils.config import Config
class GlobalVar(Namespace):
# Labels for classes
LEFT = 1
@ -21,10 +18,3 @@ class GlobalVar(Namespace):
train='train',
vali='vali',
test='test'
class ThisConfig(Config):
@property
def _model_map(self):
return dict()

0
additions/__init__.py Normal file
View File

75
additions/losses.py Normal file
View File

@ -0,0 +1,75 @@
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss computed from raw logits.

    Each sample's cross-entropy is scaled by ``(1 - p_t) ** gamma`` where
    ``p_t`` is the predicted probability of the true class, so that easy,
    well-classified samples contribute less to the loss.

    Args:
        weight: optional per-class weight tensor; acts as the alpha term that
            balances classes.
        gamma: focusing exponent; ``gamma=0`` reduces to plain cross-entropy.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-sample losses).
    """

    def __init__(self, weight=None, gamma=2, reduction='mean'):
        # _WeightedLoss registers `weight` as a buffer, so `self.weight` is available afterwards.
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma

    def forward(self, input, target):
        """Return the focal loss of ``input`` (logits) against ``target`` (class indices)."""
        # The modulating factor must be computed per sample; applying it to an
        # already averaged cross-entropy would scale the whole batch by one factor.
        plain_ce = F.cross_entropy(input, target, reduction='none')
        # p_t is the true-class probability, so it is taken from the *unweighted* CE.
        pt = torch.exp(-plain_ce)
        if self.weight is None:
            ce_loss = plain_ce
        else:
            ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
class FocalLossRob(nn.Module):
    """Binary focal loss operating on probabilities (not logits).

    Adapted from
    https://github.com/mathiaszinnen/focal_loss_torch/blob/main/focal_loss/focal_loss.py
    """

    def __init__(self, alpha=1, gamma=2, reduction: str = 'mean'):
        super().__init__()
        supported = ['mean', 'none', 'sum']
        if reduction not in supported:
            raise NotImplementedError('Reduction {} not implemented.'.format(reduction))
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, x, target):
        """Return the (reduced) focal loss of probabilities ``x`` against binary ``target``."""
        is_positive = target == 1
        # Own addition: keep probabilities away from 0 and 1 so the log stays finite.
        probs = x.clamp(1e-7, 1. - 1e-7)
        p_t = torch.where(is_positive, probs, 1 - probs)
        loss = - 1 * (1 - p_t) ** self.gamma * torch.log(p_t)
        # alpha only re-weights the positive class.
        loss = torch.where(is_positive, loss * self.alpha, loss)
        return self._reduce(loss)

    def _reduce(self, x):
        """Apply the configured reduction; anything but 'mean'/'sum' leaves ``x`` untouched."""
        if self.reduction == 'mean':
            return x.mean()
        return x.sum() if self.reduction == 'sum' else x
class DQN_MSELoss(object):
    """Mean-squared TD error between an agent network and a target network."""

    def __init__(self, agent_net, target_net, gamma):
        self.agent_net = agent_net
        self.target_net = target_net
        self.gamma = gamma

    def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """
        Calculates the mse loss using a mini batch from the replay buffer
        Args:
            batch: current mini batch of replay data, unpacked as
                (states, actions, rewards, dones, next_states)
        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch
        actions = actions.to(torch.int64)
        # Q-value the agent assigns to the action that was actually taken.
        state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            # Force a boolean mask: an integer `dones` tensor would otherwise be
            # interpreted as *indices* and zero out the wrong entries.
            done_mask = torch.as_tensor(dones, dtype=torch.bool)
            next_state_values[done_mask] = 0.0

        expected_state_action_values = next_state_values * self.gamma + rewards
        return F.mse_loss(state_action_values, expected_state_action_values)

View File

@ -1,22 +1,38 @@
import librosa
try:
import librosa
except ImportError: # pragma: no-cover
raise ImportError('You want to use `librosa` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install librosa`.')
import numpy as np
class Speed(object):
def __init__(self, max_ratio=0.3, speed_factor=1):
self.speed_factor = speed_factor
self.max_ratio = max_ratio
def __init__(self, max_amount=0.3, speed_min=1, speed_max=1):
self.speed_max = speed_max if speed_max else 1
self.speed_min = speed_min if speed_min else 1
# noinspection PyTypeChecker
self.max_amount = min(max(0, max_amount), 1)
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__})'
def __call__(self, x):
if not all([self.speed_factor, self.max_ratio]):
if self.speed_min == 1 and self.speed_max == 1:
return x
start = int(np.random.randint(0, x.shape[-1],1))
end = int((np.random.uniform(0, self.max_ratio, 1) * x.shape[-1]) + start)
start = int(np.random.randint(low=0, high=x.shape[-1], size=1))
width = np.random.uniform(low=0, high=self.max_amount, size=1) * x.shape[-1]
end = int(width + start)
end = min(end, x.shape[-1])
try:
speed_factor = float(np.random.uniform(min(self.speed_factor, 1), max(self.speed_factor, 1), 1))
aug_data = librosa.effects.time_stretch(x[start:end], speed_factor)
return np.concatenate((x[:start], aug_data, x[end:]), axis=0)[:x.shape[-1]]
speed_factor = float(np.random.uniform(low=self.speed_min, high=self.speed_max, size=1))
aug_data = librosa.effects.time_stretch(y=x[start:end], rate=speed_factor)
x_aug = np.concatenate((x[:start], aug_data, x[end:]), axis=0)[:x.shape[-1]]
if speed_factor > 1:
embedding = np.zeros_like(x)
embedding[:x_aug.shape[0]] = x_aug
x_aug = embedding
return x_aug
except ValueError:
return x

View File

@ -1,8 +1,16 @@
import librosa
from scipy.signal import butter, lfilter
import numpy as np
try:
import librosa
except ImportError: # pragma: no-cover
raise ImportError('You want to use `librosa` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install librosa`.')
try:
from scipy.signal import butter, lfilter
except ImportError: # pragma: no-cover
raise ImportError('You want to use `scikit` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install scikit-learn`.')
def scale_minmax(x, min_val=0.0, max_val=1.0):
x_std = (x - x.min()) / (x.max() - x.min())
@ -28,6 +36,9 @@ class MFCC(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__})'
def __call__(self, y):
mfcc = librosa.feature.mfcc(y, **self.__dict__)
return mfcc
@ -35,27 +46,38 @@ class MFCC(object):
class NormalizeLocal(object):
def __init__(self):
self.cache: np.ndarray
pass
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__})'
def __call__(self, x: np.ndarray):
x[np.isnan(x)] = 0
x[np.isinf(x)] = 0
mean = x.mean()
std = x.std() + 0.0001
# Pytorch Version:
# x = x.__sub__(mean).__div__(std)
# tensor = tensor.__sub__(mean).__div__(std)
# Numpy Version
x = (x - mean) / std
x[np.isnan(x)] = 0
x[np.isinf(x)] = 0
return x
class NormalizeMelband(object):
def __init__(self):
self.cache: np.ndarray
pass
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__})'
def __call__(self, x: np.ndarray):
mean = x.mean(-1).unsqueeze(-1)
std = x.std(-1).unsqueeze(-1)
@ -66,10 +88,13 @@ class NormalizeMelband(object):
return x
class AudioToMel(object):
def __init__(self, amplitude_to_db=False, power_to_db=False, **kwargs):
class LibrosaAudioToMel(object):
def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
self.mel_kwargs = kwargs
# Mel kwargs are:
# sr n_mels n_fft hop_length
self.mel_kwargs = mel_kwargs
self.amplitude_to_db = amplitude_to_db
self.power_to_db = power_to_db
@ -89,6 +114,9 @@ class PowerToDB(object):
def __init__(self, running_max=False):
self.running_max = 0 if running_max else None
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__})'
def __call__(self, x):
if self.running_max is not None:
self.running_max = max(np.max(x), self.running_max)
@ -100,6 +128,9 @@ class LowPass(object):
def __init__(self, sr=16000):
self.sr = sr
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__})'
def __call__(self, x):
return butter_lowpass_filter(x, 1000, 1)
@ -108,12 +139,16 @@ class MelToImage(object):
def __init__(self):
pass
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__})'
def __call__(self, x):
# Source to Solution: https://stackoverflow.com/a/57204349
mels = np.log(x + 1e-9) # add small number to avoid log(0)
# min-max scale to fit inside 8-bit range
img = scale_minmax(mels, 0, 255).astype(np.uint8)
img = np.flip(img, axis=0) # put low frequencies at the bottom in image
img = scale_minmax(mels, 0, 255)
img = np.flip(img) # put low frequencies at the bottom in image
img = 255 - img # invert. make black==more energy
img = img.astype(np.float)
return img

View File

@ -0,0 +1,71 @@
import sys
from pathlib import Path
import pickle
from abc import ABC
from torch.utils.data import Dataset
from torchvision.transforms import Compose
from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
import librosa
class LibrosaAudioToMelDataset(Dataset):
    """Dataset for one audio file, served as sub-segments of its cached mel spectrogram.

    Item access and length are delegated to an internal ``TorchMelDataset``;
    ``build_mel`` creates the cached mel file that dataset reads from.
    """

    @property
    def audio_file_duration(self):
        """Duration of the audio file as reported by ``librosa.get_duration``."""
        return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)

    @property
    def sampling_rate(self):
        """Sampling rate stored in the mel settings (``None`` if absent)."""
        return self.mel_kwargs.get('sr', None)

    def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False,
                 audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs):
        """
        Args:
            audio_file_path: path of the audio file.
            label: label handed to the underlying ``TorchMelDataset``.
            sample_segment_len: sub-segment length passed to ``TorchMelDataset``.
            sample_hop_len: sub-segment hop length passed to ``TorchMelDataset``.
            reset: if true, ``build_mel`` deletes an existing mel file first.
            audio_augmentations: stored on the instance; not applied in this class.
            mel_augmentations: passed to ``TorchMelDataset`` as ``transform``.
            mel_kwargs: mel settings; must contain ``'hop_length'`` and ``'n_mels'``.
                A missing or falsy ``'sr'`` is read from the audio file.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``mel_kwargs`` is not given.
        """
        super(LibrosaAudioToMelDataset, self).__init__()
        if mel_kwargs is None:
            raise ValueError("mel_kwargs is required and must contain 'hop_length' and 'n_mels'.")
        # Work on a copy: the sampling rate detected for this file must not leak
        # into a dict the caller may reuse for other files.
        mel_kwargs = dict(mel_kwargs)
        mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
        self.mel_kwargs = mel_kwargs
        self.reset = reset
        self.audio_path = Path(audio_file_path)

        # The mel file goes into a sibling of the audio file's grandparent folder,
        # named '<grandparent>_mel_folder'.
        # NOTE(review): str.replace swaps every occurrence of the folder name in the path - confirm
        # this is safe for the directory layouts in use.
        mel_folder_suffix = self.audio_path.parent.parent.name
        self.mel_folder = Path(str(self.audio_path)
                               .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')).parent.parent
        self.mel_file_path = self.mel_folder / f'{self.audio_path.stem}.npy'

        self.audio_augmentations = audio_augmentations

        self.dataset = TorchMelDataset(self.mel_file_path, sample_segment_len, sample_hop_len, label,
                                       self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'],
                                       mel_kwargs['n_mels'], transform=mel_augmentations)
        self._mel_transform = Compose([LibrosaAudioToMel(power_to_db=False, **mel_kwargs),
                                       MelToImage()
                                       ])

    def __getitem__(self, item):
        return self.dataset[item]

    def __len__(self):
        return len(self.dataset)

    def build_mel(self):
        """Create the cached mel file if it is missing (or ``reset`` is set).

        Returns:
            bool: whether the mel file exists afterwards.
        """
        if self.reset:
            self.mel_file_path.unlink(missing_ok=True)
        if not self.mel_file_path.exists():
            self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
            with self.audio_path.open(mode='rb') as audio_file:
                raw_sample, _ = librosa.core.load(audio_file, sr=self.sampling_rate)
            mel_sample = self._mel_transform(raw_sample)
            # Stored with pickle despite the '.npy' suffix.
            with self.mel_file_path.open('wb') as mel_file:
                pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
        return self.mel_file_path.exists()

View File

@ -1,17 +1,20 @@
import numpy as np
from ml_lib.utils.transforms import _BaseTransformation
class NoiseInjection(object):
def __init__(self, noise_factor: float, sigma=0.5, mu=0.5):
assert noise_factor >= 0, f'max_shift_ratio has to be greater then 0, but was: {noise_factor}.'
class NoiseInjection(_BaseTransformation):
def __init__(self, noise_factor: float, sigma=1, mu=0):
super(NoiseInjection, self).__init__()
assert noise_factor >= 0, f'noise_factor has to be greater then 0, but was: {noise_factor}.'
self.mu = mu
self.sigma = sigma
self.noise_factor = noise_factor
def __call__(self, x: np.ndarray):
if self.noise_factor:
noise = np.random.uniform(0, self.noise_factor, size=x.shape)
noise = np.random.normal(self.mu, self.sigma, size=x.shape) * self.noise_factor
augmented_data = x + x * noise
# Cast back to same data type
augmented_data = augmented_data.astype(x.dtype)
@ -20,14 +23,15 @@ class NoiseInjection(object):
return x
class LoudnessManipulator(object):
class LoudnessManipulator(_BaseTransformation):
def __init__(self, max_factor: float):
super(LoudnessManipulator, self).__init__()
assert 1 > max_factor >= 0, f'max_shift_ratio has to be between [0,1], but was: {max_factor}.'
self.max_factor = max_factor
def __call__(self, x: np.ndarray):
def __call__(self, x):
if self.max_factor:
augmented_data = x + x * (np.random.random() * self.max_factor)
# Cast back to same data type
@ -37,11 +41,12 @@ class LoudnessManipulator(object):
return x
class ShiftTime(object):
class ShiftTime(_BaseTransformation):
valid_shifts = ['right', 'left', 'any']
def __init__(self, max_shift_ratio: float, shift_direction: str = 'any'):
super(ShiftTime, self).__init__()
assert 1 > max_shift_ratio >= 0, f'max_shift_ratio has to be between [0,1], but was: {max_shift_ratio}.'
assert shift_direction.lower() in self.valid_shifts, f'shift_direction has to be one of: {self.valid_shifts}'
self.max_shift_ratio = max_shift_ratio
@ -53,26 +58,27 @@ class ShiftTime(object):
if self.shift_direction == 'right':
shift = -1 * shift
elif self.shift_direction == 'any':
direction = np.random.choice([1, -1], 1)
direction = np.asscalar(np.random.choice([1, -1], 1))
shift = direction * shift
augmented_data = np.roll(x, shift)
# Set to silence for heading/ tailing
shift = int(shift)
if shift > 0:
augmented_data[:shift] = 0
augmented_data[:, :shift] = 0
else:
augmented_data[shift:] = 0
augmented_data[:, shift:] = 0
return augmented_data
else:
return x
class MaskAug(object):
class MaskAug(_BaseTransformation):
w_idx = -1
h_idx = -2
def __init__(self, duration_ratio_max=0.3, mask_with_noise=True):
super(MaskAug, self).__init__()
assertion = f'"duration_ratio" has to be within [0..1], but was: {duration_ratio_max}'
if isinstance(duration_ratio_max, (tuple, list)):
assert all([0 < max_val < 1 for max_val in duration_ratio_max]), assertion
@ -87,9 +93,9 @@ class MaskAug(object):
def __call__(self, x):
for dim in (self.w_idx, self.h_idx):
if self.duration_ratio_max[dim]:
start = int(np.random.choice(x.shape[dim], 1))
v_max = x.shape[dim] * self.duration_ratio_max[dim]
size = int(np.random.randint(0, v_max, 1))
start = np.asscalar(np.random.choice(x.shape[dim], 1))
v_max = int(x.shape[dim] * self.duration_ratio_max[dim])
size = np.asscalar(np.random.randint(0, v_max, 1))
end = int(min(start + size, x.shape[dim]))
size = end - start
if dim == self.w_idx:

View File

@ -0,0 +1,54 @@
import time
from pathlib import Path
import pickle
from torch.utils.data import Dataset
from ml_lib.modules.util import AutoPadToShape
class TorchMelDataset(Dataset):
    """Serves (optionally overlapping) time sub-segments of one pickled mel spectrogram.

    Each item is a tuple ``(path_as_str, snippet, label)``.
    """

    def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len,
                 sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
        """
        Args:
            mel_path: path of the pickled mel spectrogram.
            sub_segment_len: number of mel frames per item; falsy means one item per file.
            sub_segment_hop_len: frame distance between consecutive item offsets.
            label: label returned with every item.
            audio_file_len: audio length (multiplied with ``sampling_rate / mel_hop_len``
                to estimate the number of mel frames).
            sampling_rate: audio sampling rate.
            mel_hop_len: hop length used when the mel spectrogram was computed.
            n_mels: number of mel bands (only used for the padding target shape).
            transform: optional callable applied to every snippet.
            auto_pad_to_shape: pad snippets to ``(n_mels, sub_segment_len)`` if a
                segment length is given.
        """
        super(TorchMelDataset, self).__init__()
        self.sampling_rate = int(sampling_rate)
        self.audio_file_len = float(audio_file_len)
        if auto_pad_to_shape and sub_segment_len:
            self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
        else:
            self.padding = None
        self.path = Path(mel_path)
        self.sub_segment_len = int(sub_segment_len)
        self.mel_hop_len = int(mel_hop_len)
        self.sub_segment_hop_len = int(sub_segment_hop_len)
        # Estimated number of mel frames of the whole file.
        self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)
        if self.sub_segment_len and self.sub_segment_hop_len and (self.n - self.sub_segment_len) > 0:
            self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
        else:
            self.offsets = [0]
        if len(self) == 0:
            # Happens e.g. for a negative hop length, which makes the range above empty.
            import warnings
            warnings.warn(f'{self.__class__.__name__} for "{self.path}" contains no segments '
                          f'(sub_segment_len={self.sub_segment_len}, '
                          f'sub_segment_hop_len={self.sub_segment_hop_len}, n={self.n}).')
        self.label = label
        self.transform = transform

    def __getitem__(self, item):
        # NOTE: pickle must only be used on trusted files; loading a crafted
        # file can execute arbitrary code.
        with self.path.open('rb') as mel_file:
            mel_spec = pickle.load(mel_file, fix_imports=True)
        start = self.offsets[item]
        sub_segments_attributes_set = self.sub_segment_len and self.sub_segment_hop_len
        sub_segment_length_smaller_then_tot_length = self.sub_segment_len < mel_spec.shape[1]

        # Fall back to the full spectrogram width when no (valid) segmenting is configured.
        if sub_segments_attributes_set and sub_segment_length_smaller_then_tot_length:
            duration = self.sub_segment_len
        else:
            duration = mel_spec.shape[1]

        snippet = mel_spec[:, start: start + duration]
        if self.transform:
            snippet = self.transform(snippet)
        if self.padding:
            snippet = self.padding(snippet)
        return str(self.path), snippet, self.label

    def __len__(self):
        return len(self.offsets)

View File

@ -0,0 +1,27 @@
from typing import Union
import numpy as np
class Normalize(object):
    """Map values from ``[min_db_level, 0]`` linearly onto ``[0, 1]``, clipping the rest."""

    def __init__(self, min_db_level: Union[int, float]):
        self.min_db_level = min_db_level

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}({self.__dict__})'

    def __call__(self, s: np.ndarray) -> np.ndarray:
        floor = self.min_db_level
        # Shift so that the floor becomes 0, then scale by the (positive) range -floor.
        scaled = (s - floor) / -floor
        return np.clip(scaled, 0, 1)
class DeNormalize(object):
    """Map values from ``[0, 1]`` (clipped) linearly back onto ``[min_db_level, 0]``."""

    def __init__(self, min_db_level: Union[int, float]):
        self.min_db_level = min_db_level

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}({self.__dict__})'

    def __call__(self, s: np.ndarray) -> np.ndarray:
        floor = self.min_db_level
        unit = np.clip(s, 0, 1)
        # Scale by the (positive) range -floor, then shift back down to the floor.
        return (unit * -floor) + floor

View File

@ -1,13 +1,21 @@
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
try:
import matplotlib.pyplot as plt
except ImportError: # pragma: no-cover
raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install matplotlib`.')
try:
from sklearn.metrics import roc_curve, auc, recall_score
except ImportError: # pragma: no-cover
raise ImportError('You want to use `sklearn` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install scikit-learn`.')
class ROCEvaluation(object):
linewidth = 2
def __init__(self, plot_roc=False):
self.plot_roc = plot_roc
def __init__(self, plot=False):
self.plot = plot
self.epoch = 0
def __call__(self, prediction, label):
@ -15,7 +23,7 @@ class ROCEvaluation(object):
# Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr)
if self.plot_roc:
if self.plot:
_ = plt.gcf()
plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
self._prepare_fig()
@ -32,3 +40,32 @@ class ROCEvaluation(object):
fig.legend(loc="lower right")
return fig
class UAREvaluation(object):
    """Computes the unweighted average recall (macro-averaged recall) of predictions."""

    def __init__(self, labels: list, plot=False):
        """
        Args:
            labels: label set forwarded to ``recall_score``.
            plot: stored as ``plot_roc``; no plotting is implemented yet.
        """
        self.labels = labels
        self.plot_roc = plot
        self.epoch = 0

    def __call__(self, prediction, label):
        """Return the macro-averaged recall of ``prediction`` against ``label``."""
        # Compute uar score - UnweightedAverageRecal
        uar_score = recall_score(label, prediction, labels=self.labels, average='macro',
                                 sample_weight=None, zero_division='warn')
        return uar_score

    def _prepare_fig(self):
        # TODO Implement a nice visualization
        # The former body after this raise was unreachable and relied on a
        # `linewidth` attribute this class does not define, so it was removed.
        raise NotImplementedError

0
metrics/__init__.py Normal file
View File

13
metrics/_base_score.py Normal file
View File

@ -0,0 +1,13 @@
from abc import ABC
class _BaseScores(ABC):
def __init__(self, lightning_model):
self.model = lightning_model
pass
def __call__(self, outputs):
# summary_dict = dict()
# return summary_dict
raise NotImplementedError

View File

@ -0,0 +1,47 @@
import numpy as np
from einops import reduce
import torch
from sklearn.ensemble import IsolationForest
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
from ml_lib.metrics._base_score import _BaseScores
class AttentionRollout(_BaseScores):
    # Work-in-progress score that accumulates attention matrices across layers.
    # NOTE(review): as visible here the method cannot run to completion - see the notes below.

    def __init__(self, *args):
        super(AttentionRollout, self).__init__(*args)
        pass

    def __call__(self, outputs):
        # NOTE(review): `summary_dict` is created but no return statement is visible - confirm
        # whether the method is supposed to return it.
        summary_dict = dict()
        #######################################################################################
        # Additional Score - Histogram Distances - Image Plotting
        #######################################################################################
        #
        # INIT
        # One numpy array of attention weights per step output.
        attn_weights = [output['attn_weights'].cpu().numpy() for output in outputs]
        # NOTE(review): einops.reduce expects (tensor, pattern, reduction); the pattern is empty and
        # the reduction is missing, so this call presumably fails - TODO supply both.
        attn_reduce_heads = [reduce(x, '') for x in attn_weights]

        # NOTE(review): `att_mat` is not assigned anywhere in the visible code (NameError unless it
        # is defined elsewhere); presumably it should be derived from `attn_reduce_heads`.
        if self.model.params.use_residual:
            # Add the identity to the attention and re-normalize along the last axis.
            residual_att = np.eye(att_mat.shape[1])[None, ...]
            aug_att_mat = att_mat + residual_att
            aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
        else:
            aug_att_mat = att_mat

        # Chain the matrices along the first axis: entry i is matrix i multiplied with entry i - 1.
        joint_attentions = np.zeros(aug_att_mat.shape)
        layers = joint_attentions.shape[0]
        joint_attentions[0] = aug_att_mat[0]
        for i in np.arange(1, layers):
            joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])

View File

@ -0,0 +1,68 @@
import numpy as np
import torch
from sklearn.ensemble import IsolationForest
from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
from ml_lib.metrics._base_score import _BaseScores
from ml_lib.utils.tools import to_one_hot
class BinaryScores(_BaseScores):
    """Binary evaluation scores (UAR, average precision, AUC, pAUC) from step outputs."""

    def __init__(self, *args):
        super(BinaryScores, self).__init__(*args)

    def __call__(self, outputs):
        """Compute the scores.

        Args:
            outputs: either one dict holding the tensors ``'batch_y'`` and ``'y'``,
                or a sequence of such dicts (one per step) that are concatenated.

        Returns:
            dict with ``uar_score``, ``precision_score``, ``auc_score`` and
            ``pauc_score``; the AUC entries are ``-1`` when sklearn raises a
            ``ValueError`` for them.
        """
        summary_dict = dict()
        # Additional Score like the unweighted Average Recall:
        #########################
        # INIT
        # Decide on the container type first: indexing a list of step outputs
        # with a string key would raise before any per-step handling could run.
        if isinstance(outputs, dict):
            y_true = outputs['batch_y'].cpu().numpy()
            y_pred = outputs['y'].cpu().numpy()
        else:
            y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
            y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()

        # UnweightedAverageRecall
        # A manual threshold would be: y_pred = (y_pred >= 0.5).astype(np.float32)
        # Here the threshold is found by an Isolation Forest; its outliers (-1) become class 1.
        clf = IsolationForest()
        y_score = clf.fit_predict(y_pred.reshape(-1, 1))
        y_score = (np.asarray(y_score) == -1).astype(np.float32)
        uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro',
                                 sample_weight=None, zero_division='warn')
        summary_dict.update(dict(uar_score=uar_score))

        #########################
        # Precision
        precision_score = average_precision_score(y_true, y_score)
        summary_dict.update(dict(precision_score=precision_score))

        #########################
        # AUC
        try:
            auc_score = roc_auc_score(y_true=y_true, y_score=y_score)
            summary_dict.update(dict(auc_score=auc_score))
        except ValueError:
            summary_dict.update(dict(auc_score=-1))

        #########################
        # pAUC
        try:
            pauc = roc_auc_score(y_true=y_true, y_score=y_score, max_fpr=0.15)
            summary_dict.update(dict(pauc_score=pauc))
        except ValueError:
            summary_dict.update(dict(pauc_score=-1))

        return summary_dict

View File

@ -0,0 +1,68 @@
from itertools import cycle
import numpy as np
import torch
from sklearn.metrics import roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix
from scipy.spatial.distance import cdist
from ml_lib.metrics._base_score import _BaseScores
from matplotlib import pyplot as plt
class GenerativeTaskEval(_BaseScores):
    """Histogram distances between targets and predictions, plus an attention overlay image."""

    def __init__(self, *args):
        super(GenerativeTaskEval, self).__init__(*args)

    def __call__(self, outputs):
        """Compute histogram distances and log one overlay image.

        Args:
            outputs: sequence of step dicts with the tensors ``'batch_y'``, ``'y'``
                and ``'attn_weights'``.

        Returns:
            dict with ``hist_manhattan_dist`` and ``hist_euc_dist`` (the arrays
            returned by ``cdist``).
        """
        summary_dict = dict()
        #######################################################################################
        # Additional Score - Histogram Distances - Image Plotting
        #######################################################################################
        #
        # INIT
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
        attn_weights = torch.cat([output['attn_weights'] for output in outputs]).squeeze().cpu().numpy()

        ######################################################################################
        #
        # Histogram comparison
        y_true_hist = np.histogram(y_true, bins=128)[0]  # Todo: Find a better value
        y_pred_hist = np.histogram(y_pred, bins=128)[0]  # Todo: Find a better value

        # L2 norm == euclidean distance
        hist_euc_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0),
                              metric='euclidean')

        # Manhattan Distance
        hist_manhattan_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0),
                                    metric='cityblock')

        summary_dict.update(hist_manhattan_dist=hist_manhattan_dist, hist_euc_dist=hist_euc_dist)

        #######################################################################################
        #
        # Draw one randomly chosen target and its attention weights into the same figure.
        idx = np.random.choice(np.arange(y_true.shape[0]), 1).item()

        ax = plt.imshow(y_true[idx].squeeze())

        # The attention map is stretched to the extent of the target image.
        plt.imshow(attn_weights[idx].squeeze(), interpolation='nearest', aspect='auto', extent=ax.get_extent())
        self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
        plt.clf()

        plt.close('all')
        return summary_dict

View File

@ -0,0 +1,142 @@
from itertools import cycle
import numpy as np
import torch
from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix, \
recall_score
from ml_lib.metrics._base_score import _BaseScores
from ml_lib.utils.tools import to_one_hot
from matplotlib import pyplot as plt
class MultiClassScores(_BaseScores):
    """Multi-class evaluation: F1, unweighted average recall, ROC/AUC and a confusion matrix.

    NOTE(review): ``self.model`` (with ``params.n_classes``, ``logger`` and
    ``current_epoch``) is presumably provided by ``_BaseScores`` - confirm.
    """

    def __init__(self, *args):
        super(MultiClassScores, self).__init__(*args)

    @staticmethod
    def _aggregated_tensor(outputs, key):
        """Return ``outputs[key]`` when that lookup works and yields a tensor, else ``None``.

        A list of per-step dicts cannot be indexed by a string; the original code did
        exactly that before its type check and therefore raised ``TypeError`` instead
        of reaching the concatenation branch.
        """
        try:
            value = outputs[key]
        except (TypeError, KeyError, IndexError):
            return None
        return value if isinstance(value, torch.Tensor) else None

    def __call__(self, outputs, class_names=None):
        """Compute the scores, log the ROC and confusion-matrix figures, return the scalars.

        Args:
            outputs: either a mapping holding already aggregated tensors under
                ``'batch_y'`` and ``'y'``, or an iterable of per-step dicts with those keys.
            class_names: optional sequence of display names, indexed by class id;
                defaults to ``range(n_classes)``.

        Returns:
            dict with ``micro_f1_score``, ``macro_f1_score``, ``uar_score`` and either
            ``macro_roc_auc_ovr`` or (if that raised ``ValueError``) ``micro_roc_auc_ovr``.
        """
        summary_dict = dict()
        n_classes = self.model.params.n_classes
        class_names = class_names or range(n_classes)
        #######################################################################################
        # Additional Score  -  UAR  -  ROC  -  Conf. Matrix  -  F1
        #######################################################################################
        #
        # INIT
        aggregated_true = self._aggregated_tensor(outputs, 'batch_y')
        if aggregated_true is not None:
            y_true = aggregated_true.cpu().numpy()
        else:
            y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        y_true_one_hot = to_one_hot(y_true, n_classes)

        aggregated_pred = self._aggregated_tensor(outputs, 'y')
        if aggregated_pred is not None:
            y_pred = aggregated_pred.cpu().numpy()
        else:
            y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
        y_pred_max = np.argmax(y_pred, axis=1)

        # Map class index -> display name.
        class_names = {val: key for val, key in enumerate(class_names)}
        ######################################################################################
        #
        # F1 SCORE
        micro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='micro',
                                  sample_weight=None, zero_division=True)
        macro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='macro',
                                  sample_weight=None, zero_division=True)
        summary_dict.update(dict(micro_f1_score=micro_f1_score, macro_f1_score=macro_f1_score))
        ######################################################################################
        #
        # Unweighted Average Recall
        # Fix: labels follow n_classes instead of the hard-coded [0, 1, 2, 3, 4].
        uar = recall_score(y_true, y_pred_max, labels=list(range(n_classes)), average='macro',
                           sample_weight=None, zero_division='warn')
        summary_dict.update(dict(uar_score=uar))
        #######################################################################################
        #
        # ROC Curve

        # Compute ROC curve and ROC area for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        # Compute micro-average ROC curve and ROC area
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # First aggregate all false positive rates
        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

        # Then interpolate all ROC curves at these points
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(n_classes):
            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

        # Finally average it and compute AUC
        mean_tpr /= n_classes

        fpr["macro"] = all_fpr
        tpr["macro"] = mean_tpr
        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

        # Plot all ROC curves
        plt.figure()
        plt.plot(fpr["micro"], tpr["micro"],
                 label=f'micro ROC ({round(roc_auc["micro"], 2)})',
                 color='deeppink', linestyle=':', linewidth=4)

        plt.plot(fpr["macro"], tpr["macro"],
                 label=f'macro ROC({round(roc_auc["macro"], 2)})',
                 color='navy', linestyle=':', linewidth=4)

        colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
                        'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], )

        for i, color in zip(range(n_classes), colors):
            plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})')

        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend(loc="lower right")

        self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
        # self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch, ext='pdf')
        plt.clf()
        #######################################################################################
        #
        # ROC AUC SCORE
        try:
            macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
                                              average="macro")
            summary_dict.update(macro_roc_auc_ovr=macro_roc_auc_ovr)
        except ValueError:
            # Fallback kept from the original: retry with micro averaging.
            micro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
                                              average="micro")
            summary_dict.update(micro_roc_auc_ovr=micro_roc_auc_ovr)
        #######################################################################################
        #
        # Confusion matrix
        fig1, ax1 = plt.subplots(dpi=96)
        cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
                              labels=[class_names[key] for key in class_names.keys()],
                              normalize='true')
        disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                      display_labels=[class_names[i] for i in range(n_classes)]
                                      )
        disp.plot(include_values=True, ax=ax1)

        self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch)
        # self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')

        plt.close('all')
        return summary_dict

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,11 +1,17 @@
import warnings
from pathlib import Path
from typing import Union
import torch
import warnings
from torch import nn
from torch.nn import functional as F
from modules.utils import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
import sys
sys.path.append(str(Path(__file__).parent))
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@ -15,23 +21,24 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
###################
class LinearModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, out_features, bias=True, activation=None,
norm=False, dropout: Union[int, float] = 0, **kwargs):
def __init__(self, in_shape, out_features, use_bias=True, activation=None,
use_norm=False, dropout: Union[int, float] = 0, **kwargs):
if list(kwargs.keys()):
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
super(LinearModule, self).__init__()
self.in_shape = in_shape
self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
self.dropout = nn.Dropout(dropout) if dropout else F_x(self.flat.shape)
self.norm = nn.BatchNorm1d(self.flat.shape) if norm else F_x(self.flat.shape)
self.linear = nn.Linear(self.flat.shape, out_features, bias=bias)
self.norm = nn.LayerNorm(self.flat.shape) if use_norm else F_x(self.flat.shape)
self.linear = nn.Linear(self.flat.shape, out_features, bias=use_bias)
self.activation = activation() if activation else F_x(self.linear.out_features)
def forward(self, x):
tensor = self.flat(x)
tensor = self.dropout(tensor)
tensor = self.norm(tensor)
tensor = self.linear(tensor)
tensor = self.linear(tensor.float())
tensor = self.activation(tensor)
return tensor
@ -39,14 +46,22 @@ class LinearModule(ShapeMixin, nn.Module):
class ConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
bias=True, norm=False, dropout: Union[int, float] = 0,
bias=True, use_norm=False, dropout: Union[int, float] = 0, trainable: bool = True,
conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
super(ConvModule, self).__init__()
assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
if len(kwargs.keys()):
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
if use_norm and not trainable:
warnings.warn('You set this module to be not trainable but the running norm is active.\n' +
'We set it to "eval" mode.\n' +
'Keep this in mind if you do a finetunning or retraining step.'
)
# Module Parameters
self.in_shape = in_shape
self.trainable = trainable
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
# Convolution Parameters
@ -56,13 +71,19 @@ class ConvModule(ShapeMixin, nn.Module):
self.conv_kernel = conv_kernel
# Modules
self.activation = activation() or F_x(None)
self.activation = activation() or nn.Identity()
self.norm = nn.LayerNorm(self.in_shape, eps=1e-04) if use_norm else F_x(None)
self.dropout = nn.Dropout2d(dropout) if dropout else F_x(None)
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else F_x(None)
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(None)
self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
padding=self.padding, stride=self.stride
)
if not self.trainable:
for param in self.parameters():
param.requires_grad = False
self.norm = self.norm.eval()
else:
pass
def forward(self, x):
tensor = self.norm(x)
@ -73,13 +94,49 @@ class ConvModule(ShapeMixin, nn.Module):
return tensor
class PreInitializedConvModule(ShapeMixin, nn.Module):
    """Placeholder for a convolution whose weights come from a given matrix.

    Construction always ends in ``NotImplementedError``; ``forward`` is a stand-in
    that multiplies the input with the stored matrix.
    """

    def __init__(self, in_shape, weight_matrix):
        super().__init__()
        self.in_shape, self.weight_matrix = in_shape, weight_matrix
        # ToDo Get the weight_matrix shape and init a conv_module of similar size,
        #  override the weights then.
        raise NotImplementedError

    def forward(self, x):
        # ToDo: This is a placeholder
        return torch.matmul(x, self.weight_matrix)
class SobelFilter(ShapeMixin, nn.Module):
    """Apply fixed 3x3 Sobel kernels and derive edge strength and direction maps.

    The kernels have shape ``(1, 1, 3, 3)``, so ``conv2d`` requires an input with
    exactly one channel.
    """

    def __init__(self, in_shape):
        super(SobelFilter, self).__init__()
        self.in_shape = in_shape

        # Float kernels: conv2d rejects integer weights on floating point input.
        sobel_x = torch.tensor([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]]).view(1, 1, 3, 3)
        # Fix: the bottom row was [-1, 2, -1]; the Sobel kernel mirrors the top row with
        # flipped sign.
        sobel_y = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]]).view(1, 1, 3, 3)

        # Buffers follow ``module.to(device)``; non-persistent so the state_dict keys
        # stay exactly as before.
        self.register_buffer('sobel_x', sobel_x, persistent=False)
        self.register_buffer('sobel_y', sobel_y, persistent=False)

    def forward(self, x):
        """Return ``(g_x, g_y, g, g_grad)`` for input ``x``.

        ``g`` is the sum of the squared responses (no square root is taken) and
        ``g_grad`` is ``atan2(g_x, g_y)``.
        """
        # Match the input so half / double tensors and other devices work too.
        kernel_x = self.sobel_x.to(dtype=x.dtype, device=x.device)
        kernel_y = self.sobel_y.to(dtype=x.dtype, device=x.device)

        # Apply Filters
        g_x = F.conv2d(x, kernel_x)
        g_y = F.conv2d(x, kernel_y)

        # Calculate the Edge
        g = torch.pow(g_x, 2) + torch.pow(g_y, 2)
        # Calculate the Gradient
        g_grad = torch.atan2(g_x, g_y)
        return g_x, g_y, g, g_grad
class DeConvModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
dropout: Union[int, float] = 0, autopad=0,
activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
bias=True, norm=False):
bias=True, use_norm=False, **kwargs):
super(DeConvModule, self).__init__()
warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.padding = conv_padding
self.conv_kernel = conv_kernel
@ -89,8 +146,8 @@ class DeConvModule(ShapeMixin, nn.Module):
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.norm = nn.LayerNorm(in_channels, eps=1e-04) if use_norm else F_x(self.in_shape)
self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
padding=self.padding, stride=self.stride)
@ -109,14 +166,13 @@ class DeConvModule(ShapeMixin, nn.Module):
class ResidualModule(ShapeMixin, nn.Module):
def __init__(self, in_shape, module_class, n, norm=False, **module_parameters):
def __init__(self, in_shape, module_class, n, use_norm=False, **module_parameters):
assert n >= 1
super(ResidualModule, self).__init__()
self.in_shape = in_shape
module_parameters.update(in_shape=in_shape)
if norm:
self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
if use_norm:
self.norm = nn.LayerNorm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
else:
self.norm = F_x(self.in_shape)
self.activation = module_parameters.get('activation', None)
@ -128,8 +184,9 @@ class ResidualModule(ShapeMixin, nn.Module):
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x):
tensor = self.norm(x)
for module in self.residual_block:
tensor = module(x)
tensor = module(tensor)
# noinspection PyUnboundLocalVariable
tensor = tensor + x
@ -155,3 +212,106 @@ class RecurrentModule(ShapeMixin, nn.Module):
def forward(self, x):
tensor = self.rnn(x)
return tensor
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> activation -> Dropout -> Linear -> activation -> Dropout.

    Args:
        dim: size of the last input dimension; the output has the same size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability; ``None`` is treated as ``0``.
        activation: module class instantiated once per layer; ``None`` disables the
            activation (an ``nn.Identity`` is used instead).
    """

    def __init__(self, dim, hidden_dim, dropout=0., activation=nn.GELU):
        super().__init__()
        # nn.Dropout(None) raises, so a missing probability means "no dropout".
        dropout = 0. if dropout is None else dropout

        def make_activation():
            # Fix: the original ``activation() or ...`` called ``None()`` before its
            # fallback could ever apply.
            return activation() if activation else nn.Identity()

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            make_activation(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            make_activation(),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: size of the last input dimension.
        heads: number of attention heads.
        head_dim: width of each head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads=8, head_dim=64, dropout=0.):
        super().__init__()
        inner_dim = heads * head_dim

        self.heads = heads
        self.scale = head_dim ** -0.5

        # One projection yields queries, keys and values; it is chunked in forward().
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)

        # A single head that already has the input width needs no output projection.
        if heads == 1 and head_dim == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x, mask=None, return_attn_weights=False):
        """Attend over ``x`` (3-D: unpacked as batch, tokens, features).

        Args:
            x: input tensor.
            mask: optional mask; it is flattened from dim 1 on and one leading ``True``
                entry is prepended before it is applied, so it must be one entry
                shorter than the token axis. NOTE(review): ``~mask`` requires a
                boolean tensor - confirm callers pass one.
            return_attn_weights: when true, the softmax weights are returned as well.

        Returns:
            ``out`` or ``(out, attn)``.
        """
        from einops import rearrange, repeat

        # Unpacking doubles as a rank check: anything but a 3-D input raises here.
        # noinspection PyTupleAssignmentBalance
        _batch, _tokens, _features = x.shape
        n_heads = self.heads

        q, k, v = (rearrange(part, 'b n (h d) -> b h n d', h=n_heads)
                   for part in self.to_qkv(x).chunk(3, dim=-1))

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            fill_value = -torch.finfo(dots.dtype).max
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # Outer product: a pair is kept only if both of its positions are kept.
            mask = mask[:, None, :] * mask[:, :, None]
            mask = repeat(mask, 'b n d -> b h n d', h=n_heads)  # My addition
            dots.masked_fill_(~mask, fill_value)
            # dots.masked_fill_(mask, mask_value)  # My addition
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
        return (out, attn) if return_attn_weights else out
class TransformerModule(ShapeMixin, nn.Module):
    """Stack of ``depth`` attention + feed-forward layers with optional residuals.

    Args:
        in_shape: input shape; tuples / lists are flattened via ``Flatten``.
        depth: number of (attention, MLP) pairs.
        heads: attention heads per layer.
        mlp_dim: hidden width of each feed-forward block.
        head_dim: width of each attention head.
        dropout: dropout probability; ``None`` (the default) means no dropout.
        use_norm: apply one shared ``nn.LayerNorm`` before every attention and MLP step.
        activation: activation class handed to the feed-forward blocks.
        use_residual: add each sub-layer's input to its output.
    """

    def __init__(self, in_shape, depth, heads, mlp_dim, head_dim=32, dropout=None, use_norm=False,
                 activation=nn.GELU, use_residual=True):
        super(TransformerModule, self).__init__()

        # Fix: the default ``None`` used to reach nn.Dropout(None), which raises.
        dropout = 0.0 if dropout is None else dropout

        self.in_shape = in_shape
        self.use_residual = use_residual

        self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)

        # NOTE(review): relies on the flatten helper exposing ``flat_shape`` - confirm
        # that ``F_x`` provides it as well.
        self.embedding_dim = self.flat.flat_shape
        # A single norm instance is shared by all layers.
        self.norm = nn.LayerNorm(self.embedding_dim) if use_norm else F_x(None)
        self.attns = nn.ModuleList([Attention(self.embedding_dim, heads=heads, dropout=dropout, head_dim=head_dim)
                                    for _ in range(depth)])
        self.mlps = nn.ModuleList([FeedForward(self.embedding_dim, mlp_dim, dropout=dropout, activation=activation)
                                   for _ in range(depth)])

    def forward(self, x, mask=None, return_attn_weights=False, **_):
        """Run all layers; returns ``tensor`` or ``(tensor, [attn per layer])``.

        Extra keyword arguments are accepted and ignored.
        """
        tensor = self.flat(x)
        attn_weights = list()

        for attn, mlp in zip(self.attns, self.mlps):
            # Attention
            attn_tensor = self.norm(tensor)
            if return_attn_weights:
                attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights)
                attn_weights.append(attn_weight)
            else:
                attn_tensor = attn(attn_tensor, mask=mask)
            tensor = tensor + attn_tensor if self.use_residual else attn_tensor

            # MLP
            mlp_tensor = self.norm(tensor)
            mlp_tensor = mlp(mlp_tensor)
            tensor = tensor + mlp_tensor if self.use_residual else mlp_tensor

        return (tensor, attn_weights) if return_attn_weights else tensor

View File

@ -0,0 +1,65 @@
import torch
from torch import nn
from torch.nn import ReLU
try:
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
except ImportError:
print('Install torch-geometric to use this package.')
class SAModule(torch.nn.Module):
    """Sample a subset of points, connect them to neighbours within ``r``, apply ``PointConv``.

    Args:
        ratio: sampling ratio handed to ``fps``.
        r: radius handed to ``radius``.
        nn: network wrapped by ``PointConv``.
    """

    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio, self.r = ratio, r
        self.conv = PointConv(nn)

    def forward(self, x, pos, batch):
        """Return ``(features, positions, batch)`` restricted to the sampled points."""
        idx = fps(pos, batch, ratio=self.ratio)
        sampled_pos, sampled_batch = pos[idx], batch[idx]

        row, col = radius(pos, sampled_pos, self.r, batch, sampled_batch,
                          max_num_neighbors=64)
        # Edge index is stacked as (col, row).
        edge_index = torch.stack((col, row), dim=0)

        features = self.conv(x, (pos, sampled_pos), edge_index)
        return features, sampled_pos, sampled_batch
class GlobalSAModule(nn.Module):
    """Run ``nn`` on concatenated features and positions, then max-pool per batch entry.

    Args:
        nn: network applied to ``cat([x, pos], dim=1)``.
        channels: width of the zero position tensor returned by ``forward``.
    """

    def __init__(self, nn, channels=3):
        super().__init__()
        self.channels = channels
        self.nn = nn

    def forward(self, x, pos, batch):
        """Return pooled features, zero positions and a fresh ``arange`` batch index."""
        pooled = global_max_pool(self.nn(torch.cat((x, pos), dim=1)), batch)
        n_groups = pooled.size(0)

        zero_pos = pos.new_zeros((n_groups, self.channels))
        new_batch = torch.arange(n_groups, device=batch.device)
        return pooled, zero_pos, new_batch
class MLP(nn.Module):
    """Chain of ``Linear -> ReLU [-> BatchNorm1d]`` blocks in double precision.

    Args:
        channels: layer widths; consecutive pairs define one block each.
        norm: append ``BatchNorm1d`` to every block. The flag used to be ignored
            (batch norm was always added); the default keeps that behaviour.
    """

    def __init__(self, channels, norm=True):
        super(MLP, self).__init__()
        blocks = []
        for i in range(1, len(channels)):
            layers = [nn.Linear(channels[i - 1], channels[i]), ReLU()]
            if norm:
                layers.append(nn.BatchNorm1d(channels[i]))
            blocks.append(nn.Sequential(*layers))
        # The whole network is cast to float64.
        self.net = nn.Sequential(*blocks).double()

    def forward(self, x, *args, **kwargs):
        # Extra positional / keyword arguments are accepted and ignored.
        return self.net(x)
class FPModule(torch.nn.Module):
    """Interpolate features onto the skip positions, append skip features, apply ``nn``.

    Args:
        k: neighbour count handed to ``knn_interpolate``.
        nn: network applied to the (optionally concatenated) features.
    """

    def __init__(self, k, nn):
        super().__init__()
        self.k, self.nn = k, nn

    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        """Return ``(features, pos_skip, batch_skip)``."""
        features = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
        if x_skip is not None:
            features = torch.cat((features, x_skip), dim=1)
        return self.nn(features), pos_skip, batch_skip

View File

@ -1,80 +1,144 @@
#
# Full Model Parts
###################
from argparse import Namespace
from functools import reduce
from typing import Union, List, Tuple
import torch
from abc import ABC
from operator import mul
from torch import nn
from torch.utils.data import DataLoader
from modules.utils import ShapeMixin
from .blocks import ConvModule, DeConvModule, LinearModule
from .util import ShapeMixin, LightningBaseModule, Flatten
class Generator(nn.Module):
@property
def shape(self):
x = torch.randn(self.lat_dim).unsqueeze(0)
output = self(x)
return output.shape[1:]
class AEBaseModule(LightningBaseModule, ABC):
# noinspection PyUnresolvedReferences
def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
filters: List[int] = None, activation=nn.ReLU):
def generate_random_image(self, dataloader: Union[None, str, DataLoader] = None,
                          lat_min: Union[Tuple, List, None] = None,
                          lat_max: Union[Tuple, List, None] = None):
    """Decode one latent vector drawn uniformly between ``lat_min`` and ``lat_max``.

    Exactly one of ``dataloader`` or the pair ``lat_min`` / ``lat_max`` must be given;
    with a dataloader the bounds come from ``self._find_min_max``.

    Returns:
        the squeezed output of ``self.decoder`` for a ``(1, self.lat_dim)`` sample.
    """
    assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'

    min_max = self._find_min_max(dataloader) if dataloader else [None, None]
    # assert not any([tensor is None for tensor in min_max])
    lat_min = torch.as_tensor(lat_min or min_max[0])
    # Fix: the upper bound was left as a tuple / list, which cannot be added to a tensor.
    lat_max = torch.as_tensor(lat_max or min_max[1])

    random_z = torch.rand((1, self.lat_dim))
    # Affine map of [0, 1) onto [lat_min, lat_max). The former
    # ``z * (|min| + max) - |min|`` is the same only while lat_min <= 0.
    random_z = lat_min + random_z * (lat_max - lat_min)

    return self.decoder(random_z).squeeze()
def encode(self, x):
    """Run ``self.encoder`` on ``x``; a 3-D input first gets a leading axis."""
    batched = x.unsqueeze(0) if len(x.shape) == 3 else x
    return self.encoder(batched)
def _find_min_max(self, dataloader):
    """Encode every batch and return the per-latent-dimension minimum and maximum.

    Returns:
        ``(min_lat, max_lat)`` tensors, reduced over dim 0 of the concatenated
        encodings. assumes encodings are (n_samples, lat_dim) - TODO confirm.
    """
    encodings = torch.cat([self.encode(batch) for batch in dataloader], dim=0)
    # Fix: reducing over dim=1 returned (values, indices) tuples with one entry per
    # sample, while callers index the result once per latent dimension.
    min_lat = encodings.min(dim=0).values
    max_lat = encodings.max(dim=0).values
    return min_lat, max_lat
def decode_lat_evenly(self, n: int,
                      dataloader: Union[None, str, DataLoader] = None,
                      lat_min: Union[Tuple, List, None] = None,
                      lat_max: Union[Tuple, List, None] = None):
    """Decode ``n`` latent vectors spaced evenly between ``lat_min`` and ``lat_max``.

    Exactly one of ``dataloader`` or the pair ``lat_min`` / ``lat_max`` must be given;
    with a dataloader the bounds come from ``self._find_min_max``.

    Returns:
        the detached CPU output of ``self.decode`` for an ``(n, lat_dim)`` grid.
    """
    assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'

    min_max = self._find_min_max(dataloader) if dataloader else [None, None]

    lat_min = lat_min or min_max[0]
    lat_max = lat_max or min_max[1]

    # Fix: ``.item()`` only exists on tensor / array elements; ``float()`` also accepts
    # the plain numbers of the annotated tuple / list inputs.
    latent_grid = torch.stack([torch.linspace(float(lat_min[i]), float(lat_max[i]), n)
                               for i in range(self.params.lat_dim)], dim=-1).cpu().detach()
    return self.decode(latent_grid).cpu().detach()
def decode(self, z):
    """Run ``self.decoder`` on ``z`` and squeeze the result.

    A 1-D ``z`` gets a leading axis first; objects lacking ``shape`` / ``unsqueeze``
    are passed through untouched.
    """
    try:
        z = z.unsqueeze(0) if len(z.shape) == 1 else z
    except AttributeError:
        # Does not seem to be a tensor.
        pass
    decoded = self.decoder(z)
    return decoded.squeeze()
def encode_and_restore(self, x):
    """Encode ``x`` and decode it again.

    Returns:
        ``Namespace(main_out=<squeezed reconstruction>, latent_out=<latent code>)``.
    """
    x = self.transfer_batch_to_device(x, self.device)
    x = x.unsqueeze(0) if len(x.shape) == 3 else x

    z = self.encode(x)
    try:
        z = z.squeeze()
    except AttributeError:
        # Does not seem to be a tensor.
        pass

    reconstruction = self.decode(z).squeeze()
    return Namespace(main_out=reconstruction, latent_out=z)
class Generator(ShapeMixin, nn.Module):
def __init__(self, in_shape, out_channels, re_shape, use_norm=False, use_bias=True,
dropout: Union[int, float] = 0, interpolations: List[int] = None,
filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU,
**kwargs):
super(Generator, self).__init__()
assert filters, '"Filters" has to be a list of int len 3'
self.filters = filters
self.activation = activation
self.inner_activation = activation()
assert filters, '"Filters" has to be a list of int.'
assert filters, '"Filters" has to be a list of int.'
kernels = kernels if kernels else [3] * len(filters)
assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
interpolations = interpolations or [2, 2, 2]
self.in_shape = in_shape
self.activation = activation()
self.out_activation = None
self.lat_dim = lat_dim
self.dropout = dropout
self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
self.l1 = LinearModule(in_shape, reduce(mul, re_shape), bias=use_bias, activation=activation)
# re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
self.flat = Flatten(to=re_shape)
self.flat = Flatten(self.l1.shape, to=re_shape)
self.de_conv_list = nn.ModuleList()
self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
conv_kernel=5,
conv_padding=2,
last_shape = re_shape
for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
# noinspection PyTypeChecker
self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
conv_kernel=conv_kernel,
conv_padding=conv_kernel-2,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
activation=activation,
interpolation_scale=interpolation,
dropout=self.dropout
)
self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
last_shape = self.de_conv_list[-1].shape
self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
normalize=use_norm,
activation=self.activation,
interpolation_scale=2,
dropout=self.dropout
)
self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
conv_kernel=3,
conv_padding=1,
# normalize=norm,
activation=self.out_activation
self.de_conv_out = DeConvModule(self.de_conv_list[-1].shape, conv_filters=out_channels, conv_kernel=3,
conv_padding=1, activation=self.out_activation
)
def forward(self, z):
tensor = self.l1(z)
tensor = self.inner_activation(tensor)
tensor = self.activation(tensor)
tensor = self.flat(tensor)
tensor = self.deconv1(tensor)
tensor = self.deconv2(tensor)
tensor = self.deconv3(tensor)
tensor = self.deconv4(tensor)
for de_conv in self.de_conv_list:
tensor = de_conv(tensor)
tensor = self.de_conv_out(tensor)
return tensor
def size(self):
@ -114,18 +178,17 @@ class UnitGenerator(Generator):
return tensor
class BaseEncoder(ShapeMixin, nn.Module):
class BaseCNNEncoder(ShapeMixin, nn.Module):
# noinspection PyUnresolvedReferences
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
filters: List[int] = None):
super(BaseEncoder, self).__init__()
assert filters, '"Filters" has to be a list of int len 3'
# Optional Padding for odd image-sizes
# Obsolet, already Done by autopadding module on incoming tensors
# in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
filters: List[int] = None, kernels: Union[List[int], int, None] = None, **kwargs):
super(BaseCNNEncoder, self).__init__()
assert filters, '"Filters" has to be a list of int'
kernels = kernels or [3] * len(filters)
kernels = kernels if not isinstance(kernels, int) else [kernels] * len(filters)
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
# Parameters
self.lat_dim = lat_dim
@ -133,52 +196,39 @@ class BaseEncoder(ShapeMixin, nn.Module):
self.use_bias = use_bias
self.latent_activation = latent_activation() if latent_activation else None
self.conv_list = nn.ModuleList()
# Modules
self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
conv_kernel=3,
conv_padding=1,
last_shape = self.in_shape
for conv_filter, conv_kernel in zip(filters, kernels):
self.conv_list.append(ConvModule(last_shape, conv_filters=conv_filter,
conv_kernel=conv_kernel,
conv_padding=conv_kernel-2,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
conv_kernel=3,
conv_padding=1,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
last_shape = self.conv_list[-1].shape
self.last_conv_shape = last_shape
self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
conv_kernel=5,
conv_padding=2,
conv_stride=1,
pooling_size=2,
use_norm=use_norm,
dropout=dropout,
activation=activation
)
self.flat = Flatten()
self.flat = Flatten(self.last_conv_shape)
def forward(self, x):
tensor = self.conv1(x)
tensor = self.conv2(tensor)
tensor = self.conv3(tensor)
tensor = x
for conv in self.conv_list:
tensor = conv(tensor)
tensor = self.flat(tensor)
return tensor
class UnitEncoder(BaseEncoder):
class UnitCNNEncoder(BaseCNNEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
kwargs.update(use_norm=True)
super(UnitEncoder, self).__init__(*args, **kwargs)
super(UnitCNNEncoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
def forward(self, x):
@ -190,10 +240,10 @@ class UnitEncoder(BaseEncoder):
return c1, c2, c3, l1
class VariationalEncoder(BaseEncoder):
class VariationalCNNEncoder(BaseCNNEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
super(VariationalEncoder, self).__init__(*args, **kwargs)
super(VariationalCNNEncoder, self).__init__(*args, **kwargs)
self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
@ -205,22 +255,22 @@ class VariationalEncoder(BaseEncoder):
return mu + eps*std
def forward(self, x):
tensor = super(VariationalEncoder, self).forward(x)
tensor = super(VariationalCNNEncoder, self).forward(x)
mu = self.mu(tensor)
logvar = self.logvar(tensor)
z = self.reparameterize(mu, logvar)
return mu, logvar, z
class Encoder(BaseEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
super(Encoder, self).__init__(*args, **kwargs)
class CNNEncoder(BaseCNNEncoder):
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
def __init__(self, *args, **kwargs):
super(CNNEncoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
def forward(self, x):
tensor = super(Encoder, self).forward(x)
tensor = super(CNNEncoder, self).forward(x)
tensor = self.l1(tensor)
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
return tensor

420
modules/util.py Normal file
View File

@ -0,0 +1,420 @@
from functools import reduce
from matplotlib import pyplot as plt
from abc import ABC
from pathlib import Path
import torch
from operator import mul
from pytorch_lightning.utilities import argparse_utils
from torch import nn
from torch.nn import functional as F, Unfold
from sklearn.metrics import ConfusionMatrixDisplay
# Utility - Modules
###################
from ..metrics.binary_class_classifictaion import BinaryScores
from ..metrics.multi_class_classification import MultiClassScores
from ..utils.model_io import ModelParameters
from ..utils.tools import add_argparse_args
try:
import pytorch_lightning as pl
class PLMetrics(pl.metrics.Metric):
    """Bundle of Lightning metrics that are updated, reset and computed together.

    Every sub-metric is stored as an attribute (and therefore in ``self._modules``);
    iterating the instance yields ``(name, metric)`` pairs.
    """

    def __init__(self, n_classes, tag=''):
        super(PLMetrics, self).__init__()

        self.n_classes = n_classes
        self.tag = tag

        self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False,)
        self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False,
                                              is_multiclass=True)
        self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False,
                                        is_multiclass=True)
        self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False)
        # self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
        # self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
        # self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False)
        if self.n_classes > 2:
            self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
            self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)

    def __iter__(self):
        return iter(((name, metric) for name, metric in self._modules.items()))

    def update(self, preds, target) -> None:
        """Feed ``preds`` / ``target`` to every sub-metric.

        On ``ValueError`` the update is retried with squeezed predictions, on
        ``AssertionError`` with a trailing axis added to the target.
        """
        for _, metric in self:
            try:
                # The original branched on n_classes but both branches were identical.
                metric.update(preds, target)
            except ValueError as error:
                # Fix: report the caught exception, not the exception class.
                print(f'error was: {error}')
                print(f'Metric is: {metric}')
                print(f'Shape is: preds - {preds.squeeze().shape}, target - {target.shape}')
                metric.update(preds.squeeze(), target)
            except AssertionError as error:
                print(f'error was: {error}')
                print(f'Metric is: {metric}')
                print(f'Shape is: preds - {preds.shape}, target - {target.unsqueeze(-1).shape}')
                metric.update(preds, target.unsqueeze(-1))

    def reset(self) -> None:
        for _, metric in self:
            metric.reset()

    def compute(self) -> dict:
        """Return ``{'<tag>_<name>_score': value}`` for every sub-metric (no prefix without tag)."""
        tag = f'{self.tag}_' if self.tag else ''
        return {f'{tag}{metric_name}_score': metric.compute() for metric_name, metric in self}

    def compute_and_prepare(self):
        """Compute all metrics and split them into scalars and figures.

        Returns:
            ``(pl_metrics, images_from_metrics)``: entries whose name contains
            ``'matrix'`` are moved into the second dict as confusion-matrix figures;
            names containing ``'curve'`` or ``'ROC'`` are left untouched; every other
            entry is converted to a Python number.
        """
        pl_metrics = self.compute()
        images_from_metrics = dict()
        for metric_name in list(pl_metrics.keys()):
            if 'curve' in metric_name:
                # Not handled yet; the value stays in pl_metrics as computed.
                continue
            elif 'matrix' in metric_name:
                matrix = pl_metrics.pop(metric_name)
                fig1, ax1 = plt.subplots(dpi=96)
                disp = ConfusionMatrixDisplay(confusion_matrix=matrix.cpu().numpy(),
                                              display_labels=[i for i in range(self.n_classes)]
                                              )
                disp.plot(include_values=True, ax=ax1)
                images_from_metrics[metric_name] = fig1
            elif 'ROC' in metric_name:
                # Not handled yet; the value stays in pl_metrics as computed.
                continue
            else:
                pl_metrics[metric_name] = pl_metrics[metric_name].cpu().item()

        return pl_metrics, images_from_metrics
class LightningBaseModule(pl.LightningModule, ABC):
    """Abstract base class for the project's Lightning models.

    Wraps the given parameters in ``ModelParameters``, attaches a ``PLMetrics``
    bundle when ``n_classes`` is configured and offers helpers for weight
    initialisation, shape probing and persisting the model class. The optimizer,
    forward and step hooks raise ``NotImplementedError`` and must be overridden.

    Args:
        model_parameters: Parameters handed to ``ModelParameters``.
        weight_init: Name of a function in ``torch.nn.init`` or a callable that
            initialises a weight tensor in place.
    """

    @classmethod
    def name(cls):
        """Return the class name."""
        return cls.__name__

    @property
    def shape(self):
        """Output shape (without batch axis) of a forward pass on random input, or -1 on failure."""
        try:
            x = torch.randn(self.in_shape).unsqueeze(0)
            output = self(x)
            return output.shape[1:]
        except Exception as e:
            # Best effort: a failing probe pass is reported, not raised.
            print(e)
            return -1

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        return argparse_utils.from_argparse_args(cls, args, **kwargs)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        return add_argparse_args(cls, parent_parser)

    def __init__(self, model_parameters, weight_init='xavier_normal_'):
        super(LightningBaseModule, self).__init__()
        self._weight_init = weight_init
        self.params = ModelParameters(model_parameters)

        if hasattr(self.params, 'n_classes'):
            self.metrics = PLMetrics(self.params.n_classes, tag='PL')

    def size(self):
        return self.shape

    def save_to_disk(self, model_path):
        """Create ``model_path`` and store the model class there once.

        Args:
            model_path: Target directory (``str`` or ``Path``).

        Returns:
            bool: Always ``True``.
        """
        # ``exist_ok`` belongs to ``mkdir``, not to the ``Path`` constructor; converting
        # first also makes the ``/`` operator below work for plain strings.
        model_path = Path(model_path)
        model_path.mkdir(parents=True, exist_ok=True)
        class_file = model_path / 'model_class.obj'
        if not class_file.exists():
            with class_file.open('wb') as f:
                torch.save(self.__class__, f)
        return True

    @property
    def data_len(self):
        return len(self.dataset.train_dataset)

    @property
    def n_train_batches(self):
        return len(self.train_dataloader())

    def configure_optimizers(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        raise NotImplementedError

    def test_step(self, *args, **kwargs):
        raise NotImplementedError

    def test_epoch_end(self, outputs):
        raise NotImplementedError

    def init_weights(self):
        """Apply the configured initialiser to all sub-modules via ``WeightInit``."""
        if isinstance(self._weight_init, str):
            # Resolve the initialiser by name from ``torch.nn.init``.
            mod = __import__('torch.nn.init', fromlist=[self._weight_init])
            self._weight_init = getattr(mod, self._weight_init)
        assert callable(self._weight_init)
        weight_initializer = WeightInit(in_place_init_function=self._weight_init)
        self.apply(weight_initializer)

    def additional_scores(self, outputs):
        """Compute multi-class scores for more than two classes, binary scores otherwise."""
        if self.params.n_classes > 2:
            return MultiClassScores(self)(outputs)
        else:
            return BinaryScores(self)(outputs)
module_types = (LightningBaseModule, nn.Module,)
except ImportError:
module_types = (nn.Module,)
pl = None
pass # Maybe post a hint to install pytorch-lightning.
class ShapeMixin:
    """Mixin that derives a module's output shape from a probe forward pass."""

    @property
    def shape(self):
        """Output shape for an input of ``self.in_shape``, or -1 without ``in_shape``.

        A random batch of two samples is pushed through the module. For each
        output the batch axis is dropped; outputs with a single remaining axis are
        reported as a plain size. Tuple outputs yield a tuple of shapes.
        """
        assert isinstance(self, module_types)

        in_shape = getattr(self, 'in_shape', None)
        if in_shape is None:
            return -1

        # Find the device to probe on: own attribute, first parameter, or best guess.
        try:
            device = self.device
        except AttributeError:
            try:
                device = next(self.parameters()).device
            except StopIteration:
                device = 'cuda' if torch.cuda.is_available() else 'cpu'

        sample = torch.randn(in_shape, device=device)
        # This is needed for BatchNorm shape checking
        probe_batch = torch.stack((sample, sample))

        def strip_batch_axis(output):
            remaining = output.shape[1:]
            return remaining if len(remaining) > 1 else output.shape[-1]

        # noinspection PyCallingNonCallable
        result = self(probe_batch)
        if isinstance(result, tuple):
            return tuple(strip_batch_axis(single) for single in result)
        return strip_batch_axis(result)

    @property
    def flat_shape(self):
        """Product of all entries of ``shape``; ``shape`` itself if it is not iterable."""
        full_shape = self.shape
        try:
            return reduce(mul, full_shape)
        except TypeError:
            return full_shape
class F_x(ShapeMixin, nn.Identity):
    """Identity module that additionally carries an ``in_shape`` for shape probing."""

    def __init__(self, in_shape):
        super().__init__()
        self.in_shape = in_shape
class SlidingWindow(ShapeMixin, nn.Module):
    """Extract sliding windows from an input with ``Unfold``.

    Args:
        in_shape: Shape of a single input sample.
        kernel: Window size; an ``int`` is expanded to a square window.
        stride: Stride handed to ``Unfold``.
        padding: Padding handed to ``Unfold``.
        keepdim: If true, the result is reshaped to
            ``(*x.shape[:2], -1, *kernel)`` instead of staying flat.
    """

    def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False):
        super().__init__()
        self.in_shape = in_shape
        self.kernel = (kernel, kernel) if isinstance(kernel, int) else kernel
        self.padding = padding
        self.stride = stride
        self.keepdim = keepdim
        self._unfolder = Unfold(self.kernel, dilation=1, padding=self.padding, stride=self.stride)

    def forward(self, x):
        # Unfold output has the windows on the last axis; move them in front of the values.
        windows = self._unfolder(x).transpose(-1, -2)
        if not self.keepdim:
            return windows
        target_shape = (*x.shape[:2], -1, *self.kernel)
        return windows.reshape(target_shape)
# Utility - Modules
###################
class Flatten(ShapeMixin, nn.Module):
    """Reshape every sample of a batch to ``to`` while keeping the batch axis.

    Args:
        in_shape: Shape of a single input sample.
        to: Target shape per sample, an ``int`` or a tuple (default ``-1``: fully flat).
    """

    def __init__(self, in_shape, to=-1):
        assert isinstance(to, (int, tuple))
        super().__init__()
        self.in_shape = in_shape
        if isinstance(to, int):
            to = (to,)
        self.to = to

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, *self.to)
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with fixed settings."""

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.mode = mode

    def forward(self, x):
        # All settings are read at call time, so attribute changes take effect.
        settings = dict(size=self.size, scale_factor=self.scale_factor,
                        mode=self.mode, align_corners=self.align_corners)
        return self.interp(x, **settings)
class AutoPad(nn.Module):
    """Zero-pad the last two axes up to the next multiple of ``base ** interpolations``.

    The last axis is padded at its end, the second to last axis at its start.
    Axes that are already a multiple stay untouched.
    """

    def __init__(self, interpolations=3, base=2):
        super().__init__()
        self.fct = base ** interpolations

    def forward(self, x):
        last = x.shape[-1]
        second_last = x.shape[-2]
        # ``-n % m`` is the distance from n up to the next multiple of m (0 if aligned).
        pad_last = -last % self.fct
        pad_second_last = -second_last % self.fct
        # noinspection PyUnresolvedReferences
        return F.pad(x, [0, pad_last, pad_second_last, 0])
class WeightInit:
    """Callable for ``Module.apply`` that initialises weights and biases in place.

    Tensor weights with at least two dimensions are handed to
    ``in_place_init_function``; weights with fewer dimensions and all tensor
    biases are filled with ``0.01``.
    """

    def __init__(self, in_place_init_function):
        self.in_place_init_function = in_place_init_function

    def __call__(self, m):
        weight = getattr(m, 'weight', None)
        if isinstance(weight, torch.Tensor):
            if weight.ndim >= 2:
                self.in_place_init_function(weight)
            else:
                weight.data.fill_(0.01)

        bias = getattr(m, 'bias', None)
        if isinstance(bias, torch.Tensor):
            bias.data.fill_(0.01)
class Filter(nn.Module, ShapeMixin):
    """Disabled module: the constructor always raises ``SystemError``."""

    def __init__(self, in_shape, pos, dim=-1):
        super().__init__()
        self.in_shape, self.pos, self.dim = in_shape, pos, dim
        raise SystemError('Do not use this Module - broken.')

    @staticmethod
    def forward(x):
        # Selects the last entry along axis 1; ``pos`` and ``dim`` are not used here.
        return x[:, -1]
class FlipTensor(nn.Module):
    """Reverse the order of the entries along one axis.

    Args:
        dim: Axis to flip (default: second to last).
    """

    def __init__(self, dim=-2):
        super(FlipTensor, self).__init__()
        self.dim = dim

    def forward(self, x):
        # ``torch.flip`` runs on the device of ``x``; the former hand-built CPU index
        # tensor made ``index_select`` fail for inputs living on another device.
        return torch.flip(x, dims=(self.dim,))
class AutoPadToShape(nn.Module):
    """Pad the trailing axes of a tensor so they match ``target_shape``.

    The missing amount per axis is added at the start of that axis. Inputs whose
    trailing (or full) shape already equals ``target_shape`` are returned as is;
    non-tensor inputs are converted with ``torch.as_tensor`` first.
    """

    def __init__(self, target_shape):
        super().__init__()
        self.target_shape = target_shape

    def forward(self, x):
        if not torch.is_tensor(x):
            x = torch.as_tensor(x)

        n_axes = len(self.target_shape)
        if x.shape[-n_axes:] == self.target_shape or x.shape == self.target_shape:
            return x

        # ``pad`` expects (start, end) pairs beginning with the last axis.
        padding = []
        for axis in range(-1, -(n_axes + 1), -1):
            padding.extend((self.target_shape[axis] - x.shape[axis], 0))
        return torch.nn.functional.pad(x, padding)

    def __repr__(self):
        return f'AutoPadTransform({self.target_shape})'
class Splitter(nn.Module):
    """Split a tensor into ``n`` equally sized blocks along one axis.

    The block size is the axis length divided by ``n``, rounded up; every block
    is passed through ``AutoPadToShape`` so that a shorter final block is padded
    to the common block shape.

    Args:
        in_shape: Shape of a single sample (an ``int`` is treated as 1-D).
        n: Number of blocks.
        dim: Axis of ``in_shape`` to split along; negative values count from the end.
    """

    @property
    def shape(self):
        return tuple([self._out_shape] * self.n)

    @property
    def out_shape(self):
        return self._out_shape

    def __init__(self, in_shape, n, dim=-1):
        super(Splitter, self).__init__()

        self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
        self.n = n
        # Normalise to a non-negative axis index. ``dim=0`` used to be mapped to
        # ``len(in_shape)`` and raised an IndexError right below.
        self.dim = dim if dim >= 0 else len(self.in_shape) + dim

        axis_len = self.in_shape[self.dim]
        self.new_dim_size = (axis_len // self.n) + (1 if axis_len % self.n != 0 else 0)

        self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])

        self.autopad = AutoPadToShape(self._out_shape)

    def forward(self, x: torch.Tensor):
        """Return the list of ``n`` blocks of ``x``."""
        # Shift the axis by one when ``x`` carries an extra leading (batch) axis.
        dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
        # Bring the split axis to the front so plain slicing cuts along it.
        x = x.transpose(0, dim)
        n_blocks = list()
        for block_idx in range(self.n):
            start = block_idx * self.new_dim_size
            end = (block_idx + 1) * self.new_dim_size
            block = x[start:end].transpose(0, dim)
            block = self.autopad(block)
            n_blocks.append(block)
        return n_blocks
class Merger(nn.Module, ShapeMixin):
    """Concatenate a sequence of tensors along ``dim``."""

    @property
    def shape(self):
        """Shape of the result of merging ``n`` random tensors of ``in_shape``."""
        dummy_blocks = [torch.randn(self.in_shape) for _ in range(self.n)]
        merged = self.forward(dummy_blocks)
        return merged.shape

    def __init__(self, in_shape, n, dim=-1):
        super().__init__()
        self.n = n
        self.dim = dim
        self.in_shape = in_shape

    def forward(self, x):
        return torch.cat(x, dim=self.dim)

View File

@ -1,256 +0,0 @@
from abc import ABC
from pathlib import Path
import torch
from torch import nn
from torch import functional as F
import pytorch_lightning as pl
# Utility - Modules
###################
from utils.model_io import ModelParameters
class ShapeMixin:
    """Mixin that derives a module's output shape from a probe forward pass."""

    @property
    def shape(self):
        """Output shape without batch axis, a plain size for 1-D outputs, or -1 without ``in_shape``."""
        assert isinstance(self, (LightningBaseModule, nn.Module))
        if self.in_shape is None:
            return -1

        sample = torch.randn(self.in_shape)
        # This is needed for BatchNorm shape checking
        probe_batch = torch.stack((sample, sample))
        out_shape = self(probe_batch).shape
        per_sample = out_shape[1:]
        if len(per_sample) > 1:
            return per_sample
        return out_shape[-1]
class F_x(ShapeMixin, nn.Module):
    """Identity module that carries an ``in_shape`` for shape probing."""

    def __init__(self, in_shape):
        super().__init__()
        self.in_shape = in_shape

    def forward(self, x):
        # Pass the input through unchanged.
        return x
# Utility - Modules
###################
class Flatten(ShapeMixin, nn.Module):
    """Reshape every sample of a batch to ``to`` while keeping the batch axis.

    Args:
        in_shape: Shape of a single input sample.
        to: Target shape per sample, an ``int`` or a tuple (default ``-1``: fully flat).
    """

    def __init__(self, in_shape, to=-1):
        assert isinstance(to, (int, tuple))
        super().__init__()
        self.in_shape = in_shape
        if isinstance(to, int):
            to = (to,)
        self.to = to

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, *self.to)
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with fixed settings."""

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.mode = mode

    def forward(self, x):
        # All settings are read at call time, so attribute changes take effect.
        settings = dict(size=self.size, scale_factor=self.scale_factor,
                        mode=self.mode, align_corners=self.align_corners)
        return self.interp(x, **settings)
class AutoPad(nn.Module):
    """Zero-pad the last two axes up to the next multiple of ``base ** interpolations``.

    The last axis is padded at its end, the second to last axis at its start.
    Axes that are already a multiple stay untouched.
    """

    def __init__(self, interpolations=3, base=2):
        super(AutoPad, self).__init__()
        self.fct = base ** interpolations

    def forward(self, x):
        last = x.shape[-1]
        second_last = x.shape[-2]
        # ``-n % m`` is the distance from n up to the next multiple of m (0 if aligned).
        pad_last = -last % self.fct
        pad_second_last = -second_last % self.fct
        # ``pad`` lives in ``torch.nn.functional``; the file-level alias ``F`` points to
        # ``torch.functional``, which does not provide it.
        return nn.functional.pad(x, [0, pad_last, pad_second_last, 0])
class WeightInit:
    """Callable for ``Module.apply`` that initialises weights and biases in place.

    Tensor weights with at least two dimensions are handed to
    ``in_place_init_function``; weights with fewer dimensions and all tensor
    biases are filled with ``0.01``.
    """

    def __init__(self, in_place_init_function):
        self.in_place_init_function = in_place_init_function

    def __call__(self, m):
        weight = getattr(m, 'weight', None)
        if isinstance(weight, torch.Tensor):
            if weight.ndim >= 2:
                self.in_place_init_function(weight)
            else:
                weight.data.fill_(0.01)

        bias = getattr(m, 'bias', None)
        if isinstance(bias, torch.Tensor):
            bias.data.fill_(0.01)
class LightningBaseModule(pl.LightningModule, ABC):
    """Abstract base class for the project's Lightning models.

    Stores the hyper-parameters, wraps them in ``ModelParameters`` and offers
    helpers for weight initialisation, shape probing and persisting the model
    class. The optimizer, forward and step hooks raise ``NotImplementedError``
    and must be overridden.
    """

    @classmethod
    def name(cls):
        """Return the class name."""
        return cls.__name__

    @property
    def shape(self):
        """Output shape (without batch axis) of a forward pass on random input, or -1 on failure."""
        try:
            x = torch.randn(self.in_shape).unsqueeze(0)
            output = self(x)
            return output.shape[1:]
        except Exception as e:
            # Best effort: a failing probe pass is reported, not raised.
            print(e)
            return -1

    def __init__(self, hparams):
        super(LightningBaseModule, self).__init__()

        # Set Parameters
        ################################
        self.hparams = hparams
        self.params = ModelParameters(hparams)

        # Dataset Loading
        ################################
        # TODO: Find a way to push class name, library path and parameters (sometimes those are objects) in here

    def size(self):
        return self.shape

    def save_to_disk(self, model_path):
        """Create ``model_path`` and store the model class there once.

        Args:
            model_path: Target directory (``str`` or ``Path``).

        Returns:
            bool: Always ``True``.
        """
        # ``exist_ok`` belongs to ``mkdir``, not to the ``Path`` constructor; converting
        # first also makes the ``/`` operator below work for plain strings.
        model_path = Path(model_path)
        model_path.mkdir(parents=True, exist_ok=True)
        class_file = model_path / 'model_class.obj'
        if not class_file.exists():
            with class_file.open('wb') as f:
                torch.save(self.__class__, f)
        return True

    @property
    def data_len(self):
        return len(self.dataset.train_dataset)

    @property
    def n_train_batches(self):
        return len(self.train_dataloader())

    def configure_optimizers(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        raise NotImplementedError

    def test_step(self, *args, **kwargs):
        raise NotImplementedError

    def test_epoch_end(self, outputs):
        raise NotImplementedError

    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
        """Apply ``in_place_init_func_`` to all sub-modules via ``WeightInit``."""
        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
        self.apply(weight_initializer)
class FilterLayer(nn.Module):
    """Select the last entry along axis 1 of the input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x[:, -1]
class MergingLayer(nn.Module):
    """Placeholder layer: ``forward`` is not implemented yet and returns ``None``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # ToDo: Which ones to combine?
        return None
class FlipTensor(nn.Module):
    """Reverse the order of the entries along one axis.

    Args:
        dim: Axis to flip (default: second to last).
    """

    def __init__(self, dim=-2):
        super(FlipTensor, self).__init__()
        self.dim = dim

    def forward(self, x):
        # ``torch.flip`` runs on the device of ``x``; the former hand-built CPU index
        # tensor made ``index_select`` fail for inputs living on another device.
        return torch.flip(x, dims=(self.dim,))
class AutoPadToShape(object):
    """Embed a batch into a zero tensor of per-sample shape ``shape``.

    The input is copied to the start of every axis; the remainder stays zero.
    Inputs whose per-sample shape already matches are returned unchanged,
    non-tensor inputs are converted with ``torch.as_tensor`` first. The zero
    buffer uses torch's default dtype.

    Args:
        shape: Target shape of a single sample (without batch axis).
    """

    def __init__(self, shape):
        self.shape = shape

    def __call__(self, x):
        if not torch.is_tensor(x):
            x = torch.as_tensor(x)
        # Compare as tuples so that a list-valued target shape matches as well.
        if tuple(x.shape[1:]) == tuple(self.shape):
            return x

        # Allocate next to the input to avoid a device mismatch on assignment.
        embedding = torch.zeros((x.shape[0], *self.shape), device=x.device)
        # One slice per sample axis, so inputs of any rank are supported.
        region = (slice(None), *(slice(0, size) for size in x.shape[1:]))
        embedding[region] = x
        return embedding

    def __repr__(self):
        return f'AutoPadTransform({self.shape})'
class HorizontalSplitter(nn.Module):
    """Cut a batch of 3-axis samples into ``n`` blocks along the second to last axis.

    The block height is the sample height divided by ``n``, rounded up. Every
    block is passed through ``AutoPadToShape``.

    Args:
        in_shape: ``(channel, height, width)`` of a single sample.
        n: Number of blocks.
    """

    def __init__(self, in_shape, n):
        super().__init__()
        assert len(in_shape) == 3
        self.n = n
        self.in_shape = in_shape

        self.channel, self.height, self.width = self.in_shape
        full_rows, leftover = divmod(self.height, self.n)
        self.new_height = full_rows + (1 if leftover != 0 else 0)

        self.shape = (self.channel, self.new_height, self.width)
        self.autopad = AutoPadToShape(self.shape)

    def forward(self, x):
        step = self.new_height
        return [self.autopad(x[:, :, block_idx * step:(block_idx + 1) * step, :])
                for block_idx in range(self.n)]
class HorizontalMerger(nn.Module):
    """Concatenate a sequence of tensors along the second to last axis.

    Args:
        in_shape: 3-entry shape of a single block.
        n: Number of blocks that are merged.
    """

    @property
    def shape(self):
        """``in_shape`` with its middle entry multiplied by ``n``."""
        first, middle, last = self.in_shape[0], self.in_shape[1], self.in_shape[2]
        return first, middle * self.n, last

    def __init__(self, in_shape, n):
        super().__init__()
        assert len(in_shape) == 3
        self.n = n
        self.in_shape = in_shape

    def forward(self, x):
        return torch.cat(x, dim=-2)

View File

Binary file not shown.

Binary file not shown.

37
point_toolset/point_io.py Normal file
View File

@ -0,0 +1,37 @@
import torch
from torch_geometric.data import Data
class BatchToData(object):
    """Convert a batched dict of point tensors into a flat ``torch_geometric`` ``Data`` object.

    Args:
        transforms: Optional callable applied to the resulting ``Data``;
            defaults to the identity.
    """

    def __init__(self, transforms=None):
        super(BatchToData, self).__init__()
        self.transforms = transforms if transforms else lambda x: x

    def __call__(self, batch_dict):
        """Flatten the batch axis and attach a per-point sample index.

        Args:
            batch_dict: Mapping with the mandatory key ``'pos'`` and the optional
                keys ``'norm'``, ``'y'`` and ``'y_c'``. ``'pos'`` must have three
                axes; the first two are merged.

        Returns:
            The transformed ``Data`` with ``pos``, ``norm``, ``batch``, ``yl`` and ``yc`` set
            (missing optional entries stay ``None``).
        """
        batch_pos = batch_dict['pos']
        batch_size, num_points, _ = batch_pos.shape  # (batch_size, num_points, 3)
        n_total = batch_size * num_points

        def flatten(tensor):
            # Merge the first two axes; absent entries are passed through as None.
            return tensor.view(n_total, -1) if tensor is not None else None

        pos = flatten(batch_pos)
        norm = flatten(batch_dict.get('norm', None))
        batch_y_l = flatten(batch_dict.get('y', None))
        batch_y_c = flatten(batch_dict.get('y_c', None))

        # Sample index of every point: 0,..,0,1,..,1,... with ``num_points`` repeats each.
        batch = torch.arange(batch_size, device=pos.device, dtype=torch.long).repeat_interleave(num_points)

        data = Data()
        data.norm, data.pos, data.batch, data.yl, data.yc = norm, pos, batch, batch_y_l, batch_y_c

        data = self.transforms(data)
        return data

View File

@ -0,0 +1,26 @@
import torch
from torch_geometric.transforms import NormalizeScale
class NormalizePositions(NormalizeScale):
    """Center ``data.pos`` and scale it by ``0.999999 / max(|pos|)``.

    A ``'debug'`` line is printed whenever NaN values show up in the positions
    (before and after each step) or the scale is NaN or infinite.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, data):
        def contains_nan(tensor):
            return torch.isnan(tensor).any()

        if contains_nan(data.pos):
            print('debug')

        data = self.center(data)
        if contains_nan(data.pos):
            print('debug')

        largest_abs = data.pos.abs().max()
        scale = (1 / largest_abs) * 0.999999
        if contains_nan(scale) or torch.isinf(scale).any():
            print('debug')

        data.pos = data.pos * scale
        if contains_nan(data.pos):
            print('debug')

        return data

51
point_toolset/sampling.py Normal file
View File

@ -0,0 +1,51 @@
from abc import ABC
import numpy as np
class _Sampler(ABC):
    """Base class for point samplers that select at most ``K`` points.

    Args:
        K: Maximum number of points to select.
        **kwargs: Stored unchanged in ``self.kwargs``.
    """

    def __init__(self, K, **kwargs):
        self.k = K
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        raise NotImplementedError


class RandomSampling(_Sampler):
    """Select up to ``K`` distinct row indices uniformly at random."""

    def __init__(self, *args, **kwargs):
        super(RandomSampling, self).__init__(*args, **kwargs)

    def __call__(self, pts, *args, **kwargs):
        """Return ``min(K, len(pts))`` distinct row indices of ``pts``."""
        n_points = pts.shape[0]
        return np.random.choice(np.arange(n_points), min(self.k, n_points), replace=False)


class FarthestpointSampling(_Sampler):
    """Select ``K`` row indices by iterative farthest point sampling."""

    def __init__(self, *args, **kwargs):
        super(FarthestpointSampling, self).__init__(*args, **kwargs)

    @staticmethod
    def calc_distances(p0, points):
        """Squared Euclidean distance of ``p0`` to every row, using the first three columns."""
        return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)

    def __call__(self, pts, *args, **kwargs):
        """Return the indices of the sampled rows of ``pts``.

        With fewer than ``K`` rows all row indices are returned, in line with
        ``RandomSampling``. Otherwise a random start row is chosen and the row
        with the largest distance to all rows chosen so far is added repeatedly.
        """
        n_points = pts.shape[0]
        if n_points < self.k:
            # Return indices here as well (this branch used to return the points).
            return np.arange(n_points)
        if self.k <= 0:
            return np.zeros(0, dtype=np.int64)

        # ``np.int`` has been removed from NumPy; use an explicit integer dtype.
        farthest_pts_idx = np.zeros(self.k, dtype=np.int64)
        # Record the random start row (it used to be reported as index 0).
        farthest_pts_idx[0] = np.random.randint(n_points)
        distances = self.calc_distances(np.asarray(pts[farthest_pts_idx[0]], dtype=float), pts)
        for i in range(1, self.k):
            farthest_pts_idx[i] = np.argmax(distances)
            newest = np.asarray(pts[farthest_pts_idx[i]], dtype=float)
            # Keep, per row, the distance to its closest already selected row.
            distances = np.minimum(distances, self.calc_distances(newest, pts))
        return farthest_pts_idx

View File

@ -1,92 +1,8 @@
absl-py==0.9.0
appdirs==1.4.3
attrs==19.3.0
audioread==2.1.8
bravado==10.6.0
bravado-core==5.17.0
CacheControl==0.12.6
cachetools==4.1.0
certifi==2019.11.28
cffi==1.14.0
chardet==3.0.4
click==7.1.2
colorama==0.4.3
contextlib2==0.6.0
cycler==0.10.0
decorator==4.4.2
distlib==0.3.0
distro==1.4.0
future==0.18.2
gitdb==4.0.5
GitPython==3.1.2
google-auth==1.14.3
google-auth-oauthlib==0.4.1
grpcio==1.29.0
html5lib==1.0.1
idna==2.8
ipaddr==2.2.0
joblib==0.15.1
jsonpointer==2.0
jsonref==0.2
jsonschema==3.2.0
kiwisolver==1.2.0
librosa==0.7.2
llvmlite==0.32.1
lockfile==0.12.2
Markdown==3.2.2
matplotlib==3.2.1
monotonic==1.5
msgpack==0.6.2
msgpack-python==0.5.6
natsort==7.0.1
neptune-client==0.4.113
numba==0.49.1
numpy==1.18.4
oauthlib==3.1.0
packaging==20.3
pandas==1.0.3
pep517==0.8.2
Pillow==7.1.2
progress==1.5
protobuf==3.12.0
py3nvml==0.2.6
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
PyJWT==1.7.1
pyparsing==2.4.6
pyrsistent==0.16.0
python-dateutil==2.8.1
pytoml==0.1.21
neptune-client==0.4.109
pytorch-lightning==0.7.6
pytz==2020.1
PyYAML==5.3.1
requests==2.22.0
requests-oauthlib==1.3.0
resampy==0.2.2
retrying==1.3.3
rfc3987==1.3.8
rsa==4.0
scikit-learn==0.23.0
scipy==1.4.1
simplejson==3.17.0
six==1.14.0
smmap==3.0.4
SoundFile==0.10.3.post1
strict-rfc3339==0.7
swagger-spec-validator==2.5.0
tensorboard==2.2.1
tensorboard-plugin-wit==1.6.0.post3
threadpoolctl==2.0.0
torch==1.5.0+cu101
torchvision==0.6.0+cu101
tqdm==4.46.0
typing-extensions==3.7.4.2
urllib3==1.25.8
webcolors==1.11.1
webencodings==0.5.1
websocket-client==0.57.0
Werkzeug==1.0.1
xmltodict==0.12.0
torchcontrib~=0.0.2
test-tube==0.7.5
torch==1.4.0
torchcontrib==0.0.2
torchvision==0.5.0
tqdm==4.45.0

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

36
utils/_basedatamodule.py Normal file
View File

@ -0,0 +1,36 @@
from pytorch_lightning import LightningDataModule
# Dataset Options
from ml_lib.utils.tools import add_argparse_args
DATA_OPTION_test = 'test'
DATA_OPTION_devel = 'devel'
DATA_OPTION_train = 'train'
DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
class _BaseDataModule(LightningDataModule):
    """Base data module that keeps its datasets in the ``datasets`` dict."""

    @property
    def shape(self):
        """``sample_shape`` of the training dataset."""
        train_set = self.datasets[DATA_OPTION_train]
        return train_set.sample_shape

    @classmethod
    def add_argparse_args(cls, parent_parser):
        return add_argparse_args(cls, parent_parser)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.datasets = dict()

    def transfer_batch_to_device(self, batch, device):
        """Move a batch to ``device``.

        For a list, every item is moved in place; items whose ``.to`` call raises
        ``AttributeError`` or ``RuntimeError`` are left as they are. Any other
        batch is moved with a single ``.to`` call.
        """
        if not isinstance(batch, list):
            return batch.to(device)

        for position, element in enumerate(batch):
            try:
                batch[position] = element.to(device)
            except (AttributeError, RuntimeError):
                continue
        return batch

28
utils/callbacks.py Normal file
View File

@ -0,0 +1,28 @@
import torch
from pytorch_lightning import Callback, Trainer, LightningModule
class BestScoresCallback(Callback):
    """Track the best (largest) value and its epoch for a set of monitored metrics.

    Args:
        *monitors: Metric names, given either as separate arguments or as one
            iterable of names. Scores start at ``0.0``, so larger values are
            treated as better.
    """

    def __init__(self, *monitors) -> None:
        super().__init__()
        # ``list(*monitors)`` raised for two or more names and split a single string
        # into its characters; accept both calling conventions explicitly.
        if len(monitors) == 1 and not isinstance(monitors[0], str):
            self.monitors = list(monitors[0])
        else:
            self.monitors = list(monitors)

        self.best_scores = {monitor: 0.0 for monitor in self.monitors}
        self.best_epoch = {monitor: 0 for monitor in self.monitors}

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Update best score and epoch from ``trainer.callback_metrics``.

        Missing, infinite and NaN scores are ignored.
        """
        epoch = pl_module.current_epoch

        for monitor in self.best_scores:
            current_score = trainer.callback_metrics.get(monitor)
            if current_score is None or torch.isinf(current_score) or torch.isnan(current_score):
                continue
            self.best_scores[monitor] = max(self.best_scores[monitor], current_score)
            if self.best_scores[monitor] == current_score:
                self.best_epoch[monitor] = max(self.best_epoch[monitor], epoch)

View File

@ -1,13 +1,84 @@
import ast
import configparser
from distutils.util import strtobool
from pathlib import Path
from typing import Mapping, Dict
import torch
from copy import deepcopy
from abc import ABC
from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
from configparser import ConfigParser, DuplicateSectionError
import hashlib
from pytorch_lightning import Trainer
from ml_lib.utils.loggers import LightningLogger
from ml_lib.utils.tools import locate_and_import_class, auto_cast
# Argument Parser and default Values
# =============================================================================
def parse_comandline_args_add_defaults(filepath, overrides=None):
    """Merge ini defaults, command line arguments and overrides into one argument dict.

    Model, dataset and seed are resolved with the priority command line >
    ``overrides`` > ini file. The chosen classes then contribute their own
    arguments to the parser before everything is parsed again on top of the
    ini defaults.

    Args:
        filepath: Path of the ini file with the sections ``project``, ``train``,
            ``data`` and one section named after the model.
        overrides: Optional mapping applied on top of the parsed arguments.

    Returns:
        tuple: ``(args dict, dataset class, model class, seed)``.
    """
    # Parse Command Line
    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str)
    parser.add_argument('--data_name', type=str)
    parser.add_argument('--seed', type=str)
    parser.add_argument('--debug', type=strtobool)

    # Load Defaults from _parameters.ini file
    config = configparser.ConfigParser()
    config.read(str(filepath))

    new_defaults = dict()
    for section in ('project', 'train', 'data'):
        for option, raw_value in config[section].items():
            new_defaults[option] = auto_cast(raw_value)

    args, _ = parser.parse_known_args()
    overrides = overrides or dict()
    cli_values = vars(args)

    def resolve(option):
        # The fallback is looked up first so that a missing default always surfaces.
        fallback = overrides.get(option, None) or new_defaults[option]
        return cli_values.get(option, None) or fallback

    data_name = resolve('data_name')
    model_name = resolve('model_name')
    found_seed = resolve('seed')

    # The model specific section extends / replaces the general defaults.
    for option, raw_value in config[model_name].items():
        new_defaults[option] = auto_cast(raw_value)

    found_data_class = locate_and_import_class(data_name, 'datasets')
    found_model_class = locate_and_import_class(model_name, 'models')

    for module in (LightningLogger, Trainer, found_data_class, found_model_class):
        parser = module.add_argparse_args(parser)

    parsed, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
    args = {option: auto_cast(value) for option, value in vars(parsed).items()}

    use_gpu = torch.cuda.is_available() and not args['debug']
    args.update(gpus=[0] if use_gpu else None,
                row_log_interval=1000,  # TODO: Better Value / Setting
                log_save_interval=10000,  # TODO: Better Value / Setting
                weights_summary='top',
                )

    if isinstance(overrides, (Mapping, Dict)):
        args.update(**overrides)

    if args['debug']:
        # This seems to be the new "fast_dev_run"
        args.update(
            val_check_interval=1,
            max_epochs=2,
            max_steps=2,
            auto_lr_find=False,
            check_val_every_n_epoch=1
        )

    return args, found_data_class, found_model_class, found_seed
def is_jsonable(x):
@ -38,11 +109,32 @@ class Config(ConfigParser, ABC):
def fingerprint(self):
h = hashlib.md5()
params = deepcopy(self.as_dict)
try:
del params['model']['type']
del params['model']['secondary_type']
except KeyError:
pass
try:
del params['data']['worker']
except KeyError:
pass
try:
del params['data']['refresh']
except KeyError:
pass
try:
del params['main']
h.update(str(params).encode())
except KeyError:
pass
try:
del params['project']
except KeyError:
pass
# Flatten the dict of dicts
for section in list(params.keys()):
params.update({f'{section}_{key}': val for key, val in params[section].items()})
del params[section]
_, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
h.update(str(vals).encode())
fingerprint = h.hexdigest()
return fingerprint
@ -53,6 +145,7 @@ class Config(ConfigParser, ABC):
@property
def _model_map(self):
"""
This is function is supposed to return a dict, which holds a mapping from string model names to model classes
@ -62,21 +155,31 @@ class Config(ConfigParser, ABC):
)
:return:
"""
raise NotImplementedError
@property
def model_class(self):
try:
return self._model_map[self.model.type]
except KeyError:
raise KeyError(f'The model alias you provided ("{self.get("model", "type")}")' +
'does not exist! Try one of these: {list(self._model_map.keys())}')
return locate_and_import_class(self.model.type, folder_path='models')
except AttributeError as e:
raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}") ' +
f'was not found!\n' +
f'{e}')
@property
def data_class(self):
    """Dataset class named by ``data.class_name``, imported from the ``datasets`` folder.

    Raises:
        AttributeError: If locating the class fails with an ``AttributeError``;
            the message names the configured alias and the original error.
    """
    try:
        found_class = locate_and_import_class(self.data.class_name, folder_path='datasets')
    except AttributeError as e:
        alias = self.get("data", "class_name")
        message = f'The dataset alias you provided ("{alias}") ' + f'was not found!\n' + f'{e}'
        raise AttributeError(message)
    return found_class
# --------------------------------------------------
# TODO: Do this programmatically; This did not work:
# Initialize Default Sections as Property
# for section in self.default_sections:
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
# self.__setattr__(section, property(lambda tensor :tensor._get_namespace_for_section(section))
@property
def main(self):
@ -195,3 +298,22 @@ class Config(ConfigParser, ABC):
with path.open('w') as configfile:
super().write(configfile)
return True
def _write_section(self, fp, section_name, section_items, delimiter):
    """Write a section via the parent implementation, except for the ``project`` section."""
    if section_name != 'project':
        super(Config, self)._write_section(fp, section_name, section_items, delimiter)
def add_section(self, section: str) -> None:
    """Add ``section``; an already existing section is silently kept."""
    try:
        super(Config, self).add_section(section)
    except DuplicateSectionError:
        # The section is already there, nothing to do.
        return
# Namespace subclass whose ``__dict__`` is replaced by a read-only property.
class DataClass(Namespace):
@property
def __dict__(self):
# Returns a list (not a dict) of all names reported by ``dir(self)`` that do not
# start with an underscore.
# NOTE(review): ``dir()`` may itself look up ``__dict__`` on the instance, which
# would re-enter this property - confirm this does not recurse.
return [x for x in dir(self) if not x.startswith('_')]

41
utils/data_util.py Normal file
View File

@ -0,0 +1,41 @@
import torch
from torch.utils.data import Dataset
def chunks(l, n):
    """Yield successive n-sized chunks from l.

    The last chunk is shorter when ``len(l)`` is not a multiple of ``n``.
    """
    yield from (l[start:start + n] for start in range(0, len(l), n))
class ReMapDataset(Dataset):
    """View on another dataset that redirects every index through ``mapping``.

    Args:
        ds: The underlying (mother) dataset.
        mapping: Index tensor; entry ``i`` is the index into ``ds`` served at position ``i``.
    """

    @property
    def sample_shape(self):
        """Shape, as a list, of the first element of the first sample."""
        first_sample = self[0]
        return list(first_sample[0].shape)

    def __init__(self, ds, mapping):
        super().__init__()
        # here is a mapping from this index to the mother ds index
        self.mapping = mapping
        self.ds = ds

    def __getitem__(self, index):
        mother_index = self.mapping[index]
        return self.ds[mother_index]

    def __len__(self):
        return self.mapping.shape[0]

    @classmethod
    def do_train_vali_split(cls, ds, split_fold=0.1):
        """Randomly split ``ds`` into a train and a validation view.

        Args:
            ds: Dataset to split.
            split_fold: Fraction of samples (rounded down) used for validation.

        Returns:
            tuple: ``(train, valid)``.
        """
        n_samples = len(ds)
        shuffled = torch.randperm(n_samples)
        n_valid = int(n_samples * split_fold)

        train = cls(ds, shuffled[n_valid:])
        valid = cls(ds, shuffled[:n_valid])
        return train, valid

30
utils/equal_sampler.py Normal file
View File

@ -0,0 +1,30 @@
import random
from typing import Iterator, Sequence
from torch.utils.data import Sampler
from torch.utils.data.sampler import T_co
# noinspection PyMissingConstructor
class EqualSampler(Sampler):
    """Sampler that draws every class with the same probability.

    For each draw a class is picked uniformly at random, then one of its
    indices is picked uniformly at random (always with replacement).

    Args:
        idxs_per_class: One sequence of dataset indices per class.
        replacement: Stored for reference only; it does not influence sampling.
    """

    def __init__(self, idxs_per_class: Sequence[Sequence[int]], replacement: bool = True) -> None:
        self.replacement = replacement
        self.idxs_per_class = idxs_per_class
        self.len_largest_class = max(len(class_idxs) for class_idxs in self.idxs_per_class)

    def __iter__(self) -> Iterator[int]:
        n_classes = len(self.idxs_per_class)
        return iter(random.choice(self.idxs_per_class[random.randint(0, n_classes - 1)])
                    for _ in range(len(self)))

    def __len__(self) -> int:
        """Size of the largest class times the number of classes."""
        return self.len_largest_class * len(self.idxs_per_class)
if __name__ == '__main__':
    # Small demo: print one full pass over three classes of different size.
    es = EqualSampler([list(range(5)), list(range(5, 10)), list(range(10, 12))])
    for i in iter(es):
        print(i)

185
utils/loggers.py Normal file
View File

@ -0,0 +1,185 @@
import inspect
from argparse import ArgumentParser
from copy import deepcopy
import hashlib
from pathlib import Path
import os
from pytorch_lightning.loggers.base import LightningLoggerBase
from neptune.api_exceptions import ProjectNotFound
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.utilities import argparse_utils
from ml_lib.utils.tools import add_argparse_args
class LightningLogger(LightningLoggerBase):
@classmethod
def from_argparse_args(cls, args, **kwargs):
cleaned_args = deepcopy(args.__dict__)
# Clean Seed and other attributes
# TODO: Find a better way in cleaning this
for attr in ['seed', 'num_worker', 'debug', 'eval', 'owner', 'data_root', 'check_val_every_n_epoch',
'reset', 'outpath', 'version', 'gpus', 'neptune_key', 'num_sanity_val_steps', 'tpu_cores',
'progress_bar_refresh_rate', 'log_save_interval', 'row_log_interval']:
try:
del cleaned_args[attr]
except KeyError:
pass
kwargs.update(params=cleaned_args)
new_logger = argparse_utils.from_argparse_args(cls, args, **kwargs)
return new_logger
@property
def fingerprint(self):
h = hashlib.md5()
h.update(self._finger_print_string.encode())
fingerprint = h.hexdigest()
return fingerprint
@property
def name(self):
short_name = "".join(c for c in self.model_name if c.isupper())
return f'{short_name}_{self.fingerprint}'
media_dir = 'media'
@classmethod
def add_argparse_args(cls, parent_parser):
return add_argparse_args(cls, parent_parser)
@property
def experiment(self):
if self.debug:
return self.csvlogger.experiment
else:
return self.neptunelogger.experiment
@property
def log_dir(self):
return Path(self.csvlogger.experiment.log_dir)
@property
def project_name(self):
return f"{self.owner}/{self.projeect_root.replace('_', '-')}"
@property
def projeect_root(self):
root_path = Path(os.getcwd()).name if not self.debug else 'test'
return root_path
@property
def version(self):
return self.seed
@property
def save_dir(self):
return self.log_dir
@property
def outpath(self):
return Path(self.root_out) / self.model_name
def __init__(self, owner, neptune_key, model_name, outpath='output', seed=69, debug=False, params=None):
"""
params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
Parameters are displayed in the experiments Parameters section and each key-value pair can be
viewed in experiments view as a column.
properties (dict|None): Optional default is {}. Properties of the experiment.
They are editable after experiment is created. Properties are displayed in the experiments Details and
each key-value pair can be viewed in experiments view as a column.
tags (list|None): Optional default []. Must be list of str. Tags of the experiment.
They are editable after experiment is created (see: append_tag() and remove_tag()).
Tags are displayed in the experiments Details and can be viewed in experiments view as a column.
"""
super(LightningLogger, self).__init__()
self.debug = debug
self.owner = owner if not self.debug else 'testuser'
self.neptune_key = neptune_key if not self.debug else 'XXX'
self.root_out = outpath if not self.debug else 'debug_out'
self.params = params
self.seed = seed
self.model_name = model_name
if self.params:
_, fingerprint_tuple = zip(*sorted(self.params.items(), key=lambda tup: tup[0]))
self._finger_print_string = str(fingerprint_tuple)
else:
self._finger_print_string = str((self.owner, self.root_out, self.seed, self.model_name, self.debug))
self.params.update(fingerprint=self.fingerprint)
self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._neptune_kwargs = dict(offline_mode=self.debug,
params=self.params,
api_key=self.neptune_key,
experiment_name=self.name,
# tags=?,
project_name=self.project_name)
try:
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
except ProjectNotFound as e:
print(f'The project "{self.project_name}" does not exist! Create it or check your spelling.')
print(e)
self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
if self.params:
self.log_hyperparams(self.params)
def close(self):
    """Close the CSV backend, then the Neptune backend."""
    for backend_name in ('csvlogger', 'neptunelogger'):
        getattr(self, backend_name).close()
def set_fingerprint_string(self, fingerprint_str):
    """Replace the string stored in `_finger_print_string` with `fingerprint_str`."""
    setattr(self, '_finger_print_string', fingerprint_str)
def log_text(self, name, text, **_):
    """Send `text` under `name` to the Neptune backend; extra keyword arguments are ignored."""
    # TODO Implement Offline variant.
    neptune_backend = self.neptunelogger
    neptune_backend.log_text(name, text)
def log_hyperparams(self, params):
    """Forward `params` to the Neptune backend, then to the CSV backend."""
    for backend_name in ('neptunelogger', 'csvlogger'):
        getattr(self, backend_name).log_hyperparams(params)
def log_metric(self, metric_name, metric_value, step=None, **kwargs):
    """
    Log one scalar to both backends.

    Args:
        metric_name: Name under which the value is recorded.
        metric_value: Value to record.
        step: Optional step index, passed to both backends.
        **kwargs: Extra keyword arguments, forwarded to the Neptune backend only.
    """
    # Key the CSV entry by the actual metric name: `dict(metric_name=...)` would file every
    # value under the literal key "metric_name". Extra kwargs are not passed on here because
    # the CSV backend is called with a metrics mapping and a step only.
    self.csvlogger.log_metrics({metric_name: metric_value}, step=step)
    self.neptunelogger.log_metric(metric_name, metric_value, step=step, **kwargs)
def log_metrics(self, metrics, step=None):
    """Forward `metrics` and `step` to the Neptune backend, then to the CSV backend."""
    for backend_name in ('neptunelogger', 'csvlogger'):
        getattr(self, backend_name).log_metrics(metrics, step=step)
def log_image(self, name, image, ext='png', step=None, **kwargs):
    """
    Save `image` below the media directory and hand the saved file to Neptune.

    Args:
        name: Image name; also the name reported to Neptune.
        image: Object exposing `savefig(path, ...)`.
        ext: File extension, with or without a leading dot.
        step: Optional step; when given it is left-padded with zeros to four characters
            and prefixed to the file name.
        **kwargs: Forwarded to the Neptune backend's `log_image`.
    """
    if step is None:
        image_name = name
    else:
        image_name = f'{str(step).rjust(4, "0")}_{name}'
    suffix = ext[1:] if ext.startswith('.') else ext
    image_path = self.log_dir / self.media_dir / f'{image_name}.{suffix}'
    media_root = self.log_dir / self.media_dir
    media_root.mkdir(parents=True, exist_ok=True)
    image.savefig(image_path, bbox_inches='tight', pad_inches=0)
    self.neptunelogger.log_image(name, str(image_path), **kwargs)
def save(self):
    """Call `save()` on the CSV backend, then on the Neptune backend."""
    for backend_name in ('csvlogger', 'neptunelogger'):
        getattr(self, backend_name).save()
def finalize(self, status):
    """Pass `status` to `finalize` of the CSV backend, then of the Neptune backend."""
    for backend_name in ('csvlogger', 'neptunelogger'):
        getattr(self, backend_name).finalize(status)
def __enter__(self):
    """Enter the `with` block; the logger itself is the context value."""
    logger = self
    return logger
def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Finalize the logger when the `with` block ends.

    Reports 'success' only when the block finished without an exception and 'failed'
    otherwise. Returns None, so an exception raised inside the block still propagates.
    """
    self.finalize('success' if exc_type is None else 'failed')

View File

@ -1,116 +0,0 @@
from abc import ABC
from pathlib import Path
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from utils.config import Config
class Logger(LightningLoggerBase, ABC):
    """Logger that writes to a Neptune backend and a TestTube backend side by side."""

    media_dir = 'media'

    @property
    def experiment(self):
        """Experiment object of the TestTube backend in debug mode, of Neptune otherwise."""
        if self.debug:
            return self.testtubelogger.experiment
        else:
            return self.neptunelogger.experiment

    @property
    def log_dir(self):
        """Parent of the directory reported by the TestTube experiment's `get_logdir()`."""
        return Path(self.testtubelogger.experiment.get_logdir()).parent

    @property
    def name(self):
        """Experiment name as given by `config.name`."""
        return self.config.name

    @property
    def project_name(self):
        """'<owner>/<project>' with underscores in the project name replaced by dashes."""
        return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"

    @property
    def version(self):
        """Value of the 'seed' entry in the 'main' section of the config."""
        return self.config.get('main', 'seed')

    @property
    def outpath(self):
        """`config.train.outpath` joined with `config.model.type`."""
        return Path(self.config.train.outpath) / self.config.model.type

    @property
    def exp_path(self):
        """`outpath` joined with the experiment name."""
        return Path(self.outpath) / self.name

    def __init__(self, config: Config):
        """
        Build both backends from `config` and write the config next to the logs.

        Args:
            config: Project configuration. Read here and in the properties above:
                `main.debug`, the 'main'/'seed' entry, `name`, `project.owner`, `project.name`,
                `project.neptune_key`, `train.outpath`, `model.type` and `model_paramters`
                (handed to Neptune as the experiment `params`).
        """
        super(Logger, self).__init__()

        self.config = config
        self.debug = self.config.main.debug
        self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
        self._neptune_kwargs = dict(offline_mode=self.debug,
                                    api_key=self.config.project.neptune_key,
                                    experiment_name=self.name,
                                    project_name=self.project_name,
                                    params=self.config.model_paramters)
        self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
        self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
        self.log_config_as_ini()

    def log_hyperparams(self, params):
        """Forward `params` to Neptune, then to TestTube."""
        self.neptunelogger.log_hyperparams(params)
        self.testtubelogger.log_hyperparams(params)

    def log_metrics(self, metrics, step=None):
        """Forward `metrics` and `step` to Neptune, then to TestTube."""
        self.neptunelogger.log_metrics(metrics, step=step)
        self.testtubelogger.log_metrics(metrics, step=step)

    def close(self):
        """Close the TestTube backend, then the Neptune backend."""
        self.testtubelogger.close()
        self.neptunelogger.close()

    def log_config_as_ini(self):
        """Write the config to 'config.ini' inside `log_dir`."""
        self.config.write(self.log_dir / 'config.ini')

    def log_text(self, name, text, step_nb=0, **kwargs):
        """Send `text` under `name` to Neptune together with `step_nb`; other kwargs are ignored."""
        # TODO Implement Offline variant.
        self.neptunelogger.log_text(name, text, step_nb)

    def log_metric(self, metric_name, metric_value, **kwargs):
        """Log one scalar to both backends; `kwargs` are forwarded to Neptune only."""
        # Key the entry by the actual metric name: `dict(metric_name=...)` would file every
        # value under the literal key "metric_name".
        self.testtubelogger.log_metrics({metric_name: metric_value})
        self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)

    def log_image(self, name, image, **kwargs):
        """
        Send `image` to Neptune and save it below `log_dir / media_dir`.

        A `step` keyword, when present, is prefixed to the file name.
        """
        self.neptunelogger.log_image(name, image, **kwargs)
        step = kwargs.get('step', None)
        name = f'{step}_{name}' if step is not None else name
        media_root = self.log_dir / self.media_dir
        # Nothing else in this class creates the media directory, so create it on demand.
        media_root.mkdir(parents=True, exist_ok=True)
        image.savefig(media_root / name)

    def save(self):
        """Call `save()` on TestTube, then on Neptune."""
        self.testtubelogger.save()
        self.neptunelogger.save()

    def finalize(self, status):
        """Pass `status` to `finalize` of TestTube, then of Neptune."""
        self.testtubelogger.finalize(status)
        self.neptunelogger.finalize(status)

    def __enter__(self):
        """Enter the `with` block; the logger itself is the context value."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finalize with 'success', or with 'failed' when the `with` block raised."""
        self.finalize('success' if exc_type is None else 'failed')

View File

@ -1,5 +1,7 @@
from argparse import Namespace
from collections import Mapping
from typing import Union
from copy import deepcopy
from pathlib import Path
@ -11,6 +13,14 @@ from torch import nn
# Hyperparameter Object
class ModelParameters(Namespace, Mapping):
@property
def as_dict(self):
    """Plain dict of all entries, with 'activation' given as its lower-cased name."""
    plain = {}
    for key in self.keys():
        if key == 'activation':
            plain[key] = self.activation_as_string
        else:
            plain[key] = self.get(key)
    return plain
@property
def activation_as_string(self):
    """Lower-cased value stored under the 'activation' key."""
    activation_name = self['activation']
    return activation_name.lower()
@property
def module_kwargs(self):
@ -18,9 +28,11 @@ class ModelParameters(Namespace, Mapping):
paramter_mapping.update(
dict(
activation=self._activations[self['activation']]
activation=self.__getattribute__('activation')
)
)
# Get rid of paramters that
paramter_mapping.__delitem__('in_shape')
return paramter_mapping
@ -42,49 +54,54 @@ class ModelParameters(Namespace, Mapping):
# Attribute lookup with two special cases: 'activation' is resolved through the
# `_activations` table, and a missing 'stretch' attribute yields False instead of raising.
# NOTE(review): this span is rendered diff output without +/- markers. The two consecutive
# `return self._activations[...]` lines look like the pre- and post-change variants
# (the second lower-cases the key first) - confirm against the repository.
def __getattribute__(self, name):
if name == 'activation':
return self._activations[self['activation']]
return self._activations[self['activation'].lower()]
else:
try:
return super(ModelParameters, self).__getattribute__(name)
except AttributeError as e:
# 'stretch' is treated as optional; every other missing attribute is re-raised.
if name == 'stretch':
return False
else:
raise AttributeError(e)
# Lookup table from lower-case activation name to the corresponding `torch.nn` module class.
_activations = dict(
leaky_relu=nn.LeakyReLU,
gelu=nn.GELU,
elu=nn.ELU,
relu=nn.ReLU,
sigmoid=nn.Sigmoid,
tanh=nn.Tanh
)
def __init__(self, parameter_mapping):
    """Initialise from a mapping, or from a `Namespace` whose attribute dict is used instead."""
    if isinstance(parameter_mapping, Namespace):
        # A Namespace keeps its fields in its attribute dict; unpack that.
        parameter_mapping = vars(parameter_mapping)
    super(ModelParameters, self).__init__(**parameter_mapping)
# Picks a '*.ckpt' checkpoint for a stored model and restores it in eval mode with
# frozen weights.
# NOTE(review): this span is rendered diff output without +/- markers, so pre-change and
# post-change lines are interleaved (two `def load_checkpoint` signatures, two
# `self.weights` assignments, two `return cls(...)` lines, two `load_from_checkpoint`
# calls). Presumably the later line of each pair is the current version - confirm against
# the repository before relying on it.
class SavedLightningModels(object):
@classmethod
def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
def load_checkpoint(cls, models_root_path, model=None, n=-1, checkpoint: Union[None, str] = None):
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
# An explicitly given checkpoint path is used directly, provided it exists ...
if checkpoint is not None:
checkpoint_path = Path(checkpoint)
assert checkpoint_path.exists(), f'The path ({checkpoint_path} does not exist).'
else:
# ... otherwise all '*.ckpt' files below the root are natural-sorted by file name
# and the n-th one is taken (n=-1 selects the last).
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
checkpoint_path = natsorted(found_checkpoints, key=lambda y: y.name)[n]
# Without an explicit model, the object stored as 'model_class.obj' in the root is loaded.
if model is None:
model = torch.load(models_root_path / 'model_class.obj')
assert model is not None
return cls(weights=found_checkpoints[n], model=model)
return cls(weights=checkpoint_path, model=model)
def __init__(self, **kwargs):
self.weights: str = kwargs.get('weights', '')
self.weights: Path = Path(kwargs.get('weights', ''))
# 'hparams.yaml' is expected next to the checkpoint file.
self.hparams: Path = self.weights.parent / 'hparams.yaml'
self.model = kwargs.get('model', None)
assert self.model is not None
def restore(self):
pretrained_model = self.model.load_from_checkpoint(self.weights)
pretrained_model = self.model.load_from_checkpoint(self.weights.__str__())
# , hparams_file=self.hparams.__str__())
# Switch to eval mode and freeze before handing the model back.
pretrained_model.eval()
pretrained_model.freeze()
return pretrained_model

View File

@ -23,3 +23,5 @@ def run_n_in_parallel(f, n, processes=0, **kwargs):
p.join()
return results
raise NotImplementedError()

View File

@ -1,6 +1,35 @@
import importlib
import inspect
import pickle
import shelve
from pathlib import Path
from argparse import ArgumentParser, ArgumentError
from ast import literal_eval
from pathlib import Path, PurePath
from typing import Union
import numpy as np
import torch
import random
def auto_cast(a):
    """
    Interpret `a` as a Python literal where possible.

    Args:
        a: Usually a string such as '1', '0.5', '[1, 2]' or 'None'.

    Returns:
        The value produced by `ast.literal_eval`, or `a` unchanged when it is not a
        valid literal.
    """
    try:
        return literal_eval(a)
    # Only the errors literal_eval raises for malformed input are treated as "not a
    # literal"; a bare `except` would also swallow KeyboardInterrupt and SystemExit.
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return a
def to_one_hot(idx_array, max_classes):
    """
    Build a one-hot matrix with one row per entry of `idx_array`.

    Args:
        idx_array: NumPy array of class indices.
        max_classes: Number of columns of the result.

    Returns:
        Array of shape `(idx_array.size, max_classes)` filled with zeros, with a 1 in
        column `idx_array[i]` of row `i`.
    """
    n_rows = idx_array.size
    encoded = np.zeros((n_rows, max_classes))
    row_index = np.arange(n_rows)
    encoded[row_index, idx_array] = 1
    return encoded
def fix_all_random_seeds(seed):
    """Seed NumPy, PyTorch and the stdlib `random` module with `seed`, then print it."""
    for seed_backend in (np.random.seed, torch.manual_seed, random.seed):
        seed_backend(seed)
    print(f'Seed is now fixed: "{seed}".')
def write_to_shelve(file_path, value):
@ -21,3 +50,42 @@ def load_from_shelve(file_path, key):
def check_path(file_path):
    """Assert that `file_path` is a `pathlib.Path` whose string form ends in '.pik'."""
    is_path_object = isinstance(file_path, Path)
    assert is_path_object
    has_pik_ending = str(file_path).endswith('.pik')
    assert has_pik_ending
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
    """
    Import the modules found below `folder_path` until one defines `class_name`.

    Every '*.py' file below `folder_path` (files with '__init__' in their name excluded)
    is imported under the dotted name built from its path parts, so `folder_path` must be
    given relative to a directory on `sys.path`.

    Args:
        class_name: Name of the attribute to look up in each imported module.
        folder_path: Folder that is searched recursively for modules.

    Returns:
        The attribute called `class_name` from the first module that defines it.

    Raises:
        AttributeError: If none of the modules defines `class_name`.
    """
    import sys
    # Register the parent directory once; appending on every call grows sys.path unboundedly.
    if ".." not in sys.path:
        sys.path.append("..")
    folder_path = Path(folder_path)
    module_paths = [x for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
    # possible_package_path = folder_path / '__init__.py'
    # package = str(possible_package_path) if possible_package_path.exists() else None
    for module_path in module_paths:
        # Drop only the file suffix to obtain the dotted module name.
        dotted_name = '.'.join(module_path.with_suffix('').parts)
        mod = importlib.import_module(dotted_name)
        try:
            return getattr(mod, class_name)
        except AttributeError:
            continue
    raise AttributeError(f'Check the {folder_path.name} name. Possible files are:\n{[x.name for x in module_paths]}')
def _str_to_bool(value):
    """Parse a command-line token such as 'true', '0' or 'no' into a bool."""
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise ValueError(f'Cannot interpret {value!r} as a boolean.')


def add_argparse_args(cls, parent_parser):
    """
    Derive command-line options from the signature of `cls.__init__`.

    Every positional-or-keyword parameter except `self` becomes a `--<name>` option.
    Parameters without a default are parsed as `int`. Parameters with a default keep
    that default and are parsed with the type of the default value (`int`, `float`,
    `str`, or a boolean parser for `bool`); any other default falls back to `str`.
    Options that clash with one already defined by `parent_parser` are skipped.

    Args:
        cls: Class whose `__init__` signature is inspected.
        parent_parser: Parser whose arguments are inherited.

    Returns:
        A new `ArgumentParser` (without `-h`) containing the inherited and derived options.
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    full_arg_spec = inspect.getfullargspec(cls.__init__)
    defaults = full_arg_spec.defaults or ()
    n_non_defaults = len(full_arg_spec.args) - len(defaults)

    for idx, argument in enumerate(full_arg_spec.args):
        if argument == 'self':
            continue
        if idx < n_non_defaults:
            options = dict(type=int)
        else:
            default = defaults[idx - n_non_defaults]
            # Infer the parser from the default VALUE. `type(argument)` would be the type of
            # the parameter name, i.e. always `str`.
            if isinstance(default, bool):
                # bool('False') is True, so booleans need an explicit parser.
                argument_type = _str_to_bool
            elif isinstance(default, (int, float, str)):
                argument_type = type(default)
            else:
                argument_type = str
            options = dict(type=argument_type, default=default)
        try:
            parser.add_argument(f'--{argument}', **options)
        except ArgumentError:
            # Option already provided by the parent parser; keep the parent's definition.
            continue
    return parser

View File

@ -1,8 +1,22 @@
from abc import ABC
from torchvision.transforms import ToTensor as TorchVisionToTensor
class _BaseTransformation(ABC):
    """Base class for transformations; subclasses supply `__call__`."""

    def __init__(self, *args):
        """Accept and ignore any positional arguments."""

    def __repr__(self):
        """Class name followed by the instance attribute dict in parentheses."""
        class_name = type(self).__name__
        attributes = self.__dict__
        return f'{class_name}({attributes})'

    def __call__(self, *args, **kwargs):
        """Not implemented here; always raises `NotImplementedError`."""
        raise NotImplementedError
class ToTensor(TorchVisionToTensor):
    """torchvision `ToTensor` whose result is additionally cast with `.float()`."""

    def __call__(self, pic):
        """Convert `pic` with the parent transform and return it as a float tensor."""
        # Make it float .float() == 32bit
        converted = super(ToTensor, self).__call__(pic)
        return converted.float()

View File

@ -1,31 +1,56 @@
try:
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib import pyplot as plt
except ImportError: # pragma: no-cover
raise ImportError('You want to use `matplotlib` which is not installed yet,' # pragma: no-cover
' install it with `pip install matplotlib`.')
from pathlib import Path
import matplotlib.pyplot as plt
def prettyfy_sns():
    """
    Apply the plot styling used here: matplotlib's 'default' style, seaborn's 'Dark2'
    palette, and LaTeX text rendering with a serif font and fixed font sizes.

    Raises:
        ImportError: If `seaborn` is not installed.
    """
    plt.style.use('default')
    try:
        import seaborn
    except ImportError:
        raise ImportError('You want to use `seaborn` which is not installed yet,'  # pragma: no-cover
                          ' install it with `pip install seaborn`.')
    seaborn.set_palette('Dark2')

    plt.rcParams.update({
        # Use LaTeX to write all text
        "text.usetex": True,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
    })
# Helper that writes matplotlib figures to image files below a root directory.
# NOTE(review): this span is rendered diff output without +/- markers, so lines of the
# `plt`-based `save_current_figure` / `show_current_figure` methods and of the
# canvas-based `save_figure` method are interleaved. Presumably `save_figure` is the
# current version - confirm against the repository before relying on it.
class Plotter(object):
def __init__(self, root_path=''):
# NOTE(review): `root_path` is only stored when it is empty/falsy; with a non-empty
# argument `self.root_path` is never assigned here. Looks inverted - confirm intent.
if not root_path:
self.root_path = Path(root_path)
def save_current_figure(self, path, extention='.png', naked=True):
fig, _ = plt.gcf(), plt.gca()
def save_figure(self, figure, title, extention='.png', naked=False):
canvas = FigureCanvas(figure)
# Prepare save location and check img file extention
path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
path = self.root_path / f'{title}{extention}'
# Make sure the target directory exists before writing.
path.parent.mkdir(exist_ok=True, parents=True)
if naked:
plt.axis('off')
fig.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
fig.clf()
# NOTE(review): the literal 'off)' below looks like a typo for 'off' - confirm.
figure.axis('off)')
figure.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
canvas.print_figure(path)
else:
fig.savefig(path)
fig.clf()
def show_current_figure(self):
fig, _ = plt.gcf(), plt.gca()
fig.show()
fig.clf()
canvas.print_figure(path)
# Script entry point.
# NOTE(review): rendered diff output without +/- markers; the `Plotter` lines and the
# `raise` appear to be the pre- and post-change bodies of this guard - confirm which applies.
if __name__ == '__main__':
output_root = Path('..') / 'output'
p = Plotter(output_root)
p.save_current_figure('test.png')
raise PermissionError('Get out of here.')