Compare commits
54 Commits
ab01006eae
faa27c3cf9
abe870d106
1d1b154460
6816e423ff
d3e7bf7efb
ed260f1c2a
675312537f
43cf0ad00d
479514c9e7
fff5e6e00a
8e719af554
10bf376ac3
fc4617c9d8
f89f0f8528
b5e3e5aec1
a966321576
010176e80b
f6156c6cde
93103aba01
62d9eb6e8f
c6fdaa24aa
cfeea05673
14ed4e0117
13812b83b5
f296ba78b9
5848b528f0
6bc9447ce1
a4b6c698c3
4b089729b2
c7d17a9898
7770b29c14
53aa11521d
aea34de964
3f8122484b
12d36047ef
76308888e0
0cff42f951
ece80ecbed
d3fa32ae7b
2acf91335f
5987efb169
77ea043907
4b4051c045
8cec323286
235743b225
28d0034269
b87a56e8c6
196b1af7ae
fcd5ee4d29
f290d5a8d8
645b7905e8
206aca10b3
e423d6fe31
7  .gitignore  vendored

@@ -1,6 +1 @@
-/.idea/
-# my own stuff
-
-/data
-/.idea
-/ml_lib
+.idea
10  README.md

@@ -9,9 +9,10 @@ Clone it to find a collection of:
 - Utility Function for Model I/O
 - DL Modules
 - A Plotter Object
-- Audio Related Tools and Funtion
+- Audio related Tools and Funtion
 - Librosa
 - Scipy Signal
+- PointCloud related Tools and Functions

 ###Notes:
 - Use directory links to link from your main project folder to the ml_lib folder. Pycharm will automatically use
@@ -19,11 +20,10 @@ Clone it to find a collection of:
 \
 \
 For Windows Users:
-```
+``` bash
 mklink /d "ml_lib" "..\ml_lib""
 ```
 For Unix User:
+``` bash
+ln -s ../ml_lib ml_lib
 ```
-TBA
-```
-- Cheers
BIN  __pycache__/__init__.cpython-37.pyc  Normal file
Binary file not shown.
5  _templates/new_project/.gitignore  vendored  Normal file

@@ -0,0 +1,5 @@
+# my own stuff
+
+/data
+/.idea
+/ml_lib
@@ -2,5 +2,16 @@ from torch.utils.data import Dataset


 class TemplateDataset(Dataset):
+
+    @property
+    def sample_shape(self):
+        return self[0][0].shape

     def __init__(self, *args, **kwargs):
         super(TemplateDataset, self).__init__()
+
+    def __len__(self):
+        pass
+
+    def __getitem__(self, item):
+        return item
@@ -7,10 +7,9 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

-from modules.utils import LightningBaseModule
-from utils.config import Config
-from utils.logging import Logger
-from utils.model_io import SavedLightningModels
+from ml_lib.modules.util import LightningBaseModule
+from ml_lib.utils.config import Config
+from ml_lib.utils.loggers import LightningLogger

 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -21,7 +20,7 @@ def run_lightning_loop(config_obj):
     # Logging
     # ================================================================================
     # Logger
-    with Logger(config_obj) as logger:
+    with LightningLogger(config_obj) as logger:
         # Callbacks
         # =============================================================================
         # Checkpoint Saving
@@ -44,11 +43,6 @@ def run_lightning_loop(config_obj):
         # Init
         model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
         model.init_weights(torch.nn.init.xavier_normal_)
-        if model.name == 'CNNRouteGeneratorDiscriminated':
-            # ToDo: Make this dependent on the used seed
-            path = logger.outpath / 'classifier_cnn' / 'version_0'
-            disc_model = SavedLightningModels.load_checkpoint(path).restore()
-            model.set_discriminator(disc_model)

         # Trainer
         # =============================================================================
@@ -70,8 +64,8 @@ def run_lightning_loop(config_obj):
         trainer.fit(model)

         # Save the last state & all parameters
-        trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
-        model.save_to_disk(logger.log_dir)
+        trainer.save_checkpoint(config_obj.exp_path.log_dir / 'weights.ckpt')
+        model.save_to_disk(config_obj.exp_path)

         # Evaluate It
         if config_obj.main.eval:
@@ -1,6 +1,6 @@
 import warnings

-from utils.config import Config
+from ml_lib._templates.new_project.utils.project_config import Config

 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -8,17 +8,16 @@ warnings.filterwarnings('ignore', category=UserWarning)
 # Imports
 # =============================================================================

-from _templates.new_project.main import run_lightning_loop, args
+from ml_lib._templates.new_project.main import run_lightning_loop, args


 if __name__ == '__main__':

     # Model Settings
     config = Config().read_namespace(args)
-    # bias, activation, model, norm, max_epochs, filters
-    cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
-                          model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
-    # bias, activation, model, norm, max_epochs, sr, feature_mixed_dim, filters
+    # bias, activation, model, norm, max_epochs
+    cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, data_batchsize=512)
+    # bias, activation, model, norm, max_epochs

     for arg_dict in [cnn_classifier]:
         for seed in range(5):
@@ -11,13 +11,13 @@ from torch.utils.data import DataLoader
 from torchcontrib.optim import SWA
 from torchvision.transforms import Compose

-from _templates.new_project.datasets.template_dataset import TemplateDataset
+from ml_lib._templates.new_project.datasets.template_dataset import TemplateDataset

-from audio_toolset.audio_io import NormalizeLocal
-from modules.utils import LightningBaseModule
-from utils.transforms import ToTensor
+from ml_lib.audio_toolset.audio_io import NormalizeLocal
+from ml_lib.modules.util import LightningBaseModule
+from ml_lib.utils.transforms import ToTensor

-from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
+from ml_lib._templates.new_project.utils.project_config import GlobalVar as GlobalVars


 class BaseOptimizerMixin:
@@ -61,10 +61,11 @@ class BaseTrainMixin:
         assert isinstance(self, LightningBaseModule)
         keys = list(outputs[0].keys())

-        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
+        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                         for output in outputs]))
-                            for key in keys if 'loss' in key})
-        return summary_dict
+                        for key in keys if 'loss' in key}
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])


 class BaseValMixin:
@@ -83,16 +84,16 @@ class BaseValMixin:

     def validation_epoch_end(self, outputs, *_, **__):
         assert isinstance(self, LightningBaseModule)
-        summary_dict = dict(log=dict())
+        summary_dict = dict()
         # In case of Multiple given dataloader this will outputs will be: list[list[dict[]]]
         # for output_idx, output in enumerate(outputs):
         # else:list[dict[]]
         keys = list(outputs.keys())
         # Add Every Value das has a "loss" in it, by calc. mean over all occurences.
-        summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
+        summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                            for output in outputs]))
                                     for key in keys if 'loss' in key}
                                    )
         """
         # Additional Score like the unweighted Average Recall:
         # UnweightedAverageRecall
@@ -107,7 +108,8 @@ class BaseValMixin:
         summary_dict['log'].update({f'uar_score': uar_score})
         """

-        return summary_dict
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])


 class BinaryMaskDatasetMixin:
@@ -1,8 +1,5 @@
 from argparse import Namespace

-from utils.config import Config


 class GlobalVar(Namespace):
     # Labels for classes
     LEFT = 1
@@ -21,10 +18,3 @@ class GlobalVar(Namespace):
         train='train',
         vali='vali',
         test='test'
-
-
-class ThisConfig(Config):
-
-    @property
-    def _model_map(self):
-        return dict()
0  additions/__init__.py  Normal file
75  additions/losses.py  Normal file

@@ -0,0 +1,75 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FocalLoss(nn.modules.loss._WeightedLoss):
+    def __init__(self, weight=None, gamma=2, reduction='mean'):
+        super(FocalLoss, self).__init__(weight, reduction=reduction)
+        self.gamma = gamma
+        self.weight = weight  # weight parameter will act as the alpha parameter to balance class weights
+
+    def forward(self, input, target):
+        ce_loss = F.cross_entropy(input, target, reduction=self.reduction, weight=self.weight)
+        pt = torch.exp(-ce_loss)
+        focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
+        return focal_loss
+
+
+class FocalLossRob(nn.Module):
+    # taken from https://github.com/mathiaszinnen/focal_loss_torch/blob/main/focal_loss/focal_loss.py
+    def __init__(self, alpha=1, gamma=2, reduction: str = 'mean'):
+        super().__init__()
+        if reduction not in ['mean', 'none', 'sum']:
+            raise NotImplementedError('Reduction {} not implemented.'.format(reduction))
+        self.reduction = reduction
+        self.alpha = alpha
+        self.gamma = gamma
+
+    def forward(self, x, target):
+        x = x.clamp(1e-7, 1. - 1e-7)  # own addition
+        p_t = torch.where(target == 1, x, 1 - x)
+        fl = - 1 * (1 - p_t) ** self.gamma * torch.log(p_t)
+        fl = torch.where(target == 1, fl * self.alpha, fl)
+        return self._reduce(fl)
+
+    def _reduce(self, x):
+        if self.reduction == 'mean':
+            return x.mean()
+        elif self.reduction == 'sum':
+            return x.sum()
+        else:
+            return x
+
+
+class DQN_MSELoss(object):
+
+    def __init__(self, agent_net, target_net, gamma):
+        self.agent_net = agent_net
+        self.target_net = target_net
+        self.gamma = gamma
+
+    def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
+        """
+        Calculates the mse loss using a mini batch from the replay buffer
+        Args:
+            batch: current mini batch of replay data
+        Returns:
+            loss
+        """
+        states, actions, rewards, dones, next_states = batch
+
+        actions = actions.to(torch.int64)
+        state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
+
+        with torch.no_grad():
+            next_state_values = self.target_net(next_states).max(1)[0]
+            next_state_values[dones] = 0.0
+            next_state_values = next_state_values.detach()
+
+        expected_state_action_values = next_state_values * self.gamma + rewards
+
+        return F.mse_loss(state_action_values, expected_state_action_values)
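The two focal-loss variants added here expect different inputs: FocalLoss wraps F.cross_entropy and therefore takes raw logits plus integer class targets, while FocalLossRob operates on probabilities in (0, 1) with binary targets. A minimal usage sketch, assuming the repository root is importable; shapes and values are illustrative, not from the commit:

``` python
import torch

from additions.losses import FocalLoss, FocalLossRob

logits = torch.randn(8, 5)                  # batch of 8, 5 classes, raw scores
classes = torch.randint(0, 5, (8,))         # integer class labels
print(FocalLoss(gamma=2)(logits, classes))  # scalar: mean focal-weighted CE

probs = torch.sigmoid(torch.randn(8))       # probabilities in (0, 1)
binary = torch.randint(0, 2, (8,)).float()  # binary {0, 1} targets
print(FocalLossRob(alpha=1, gamma=2)(probs, binary))
```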
@@ -1,22 +1,38 @@
-import librosa
+try:
+    import librosa
+except ImportError:  # pragma: no-cover
+    raise ImportError('You want to use `librosa` plugins which are not installed yet,'  # pragma: no-cover
+                      ' install it with `pip install librosa`.')
+
 import numpy as np


 class Speed(object):

-    def __init__(self, max_ratio=0.3, speed_factor=1):
-        self.speed_factor = speed_factor
-        self.max_ratio = max_ratio
+    def __init__(self, max_amount=0.3, speed_min=1, speed_max=1):
+        self.speed_max = speed_max if speed_max else 1
+        self.speed_min = speed_min if speed_min else 1
+        # noinspection PyTypeChecker
+        self.max_amount = min(max(0, max_amount), 1)
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'

     def __call__(self, x):
-        if not all([self.speed_factor, self.max_ratio]):
+        if self.speed_min == 1 and self.speed_max == 1:
             return x
-        start = int(np.random.randint(0, x.shape[-1],1))
-        end = int((np.random.uniform(0, self.max_ratio, 1) * x.shape[-1]) + start)
+        start = int(np.random.randint(low=0, high=x.shape[-1], size=1))
+        width = np.random.uniform(low=0, high=self.max_amount, size=1) * x.shape[-1]
+        end = int(width + start)
         end = min(end, x.shape[-1])
         try:
-            speed_factor = float(np.random.uniform(min(self.speed_factor, 1), max(self.speed_factor, 1), 1))
-            aug_data = librosa.effects.time_stretch(x[start:end], speed_factor)
-            return np.concatenate((x[:start], aug_data, x[end:]), axis=0)[:x.shape[-1]]
+            speed_factor = float(np.random.uniform(low=self.speed_min, high=self.speed_max, size=1))
+            aug_data = librosa.effects.time_stretch(y=x[start:end], rate=speed_factor)
+            x_aug = np.concatenate((x[:start], aug_data, x[end:]), axis=0)[:x.shape[-1]]
+            if speed_factor > 1:
+                embedding = np.zeros_like(x)
+                embedding[:x_aug.shape[0]] = x_aug
+                x_aug = embedding
+            return x_aug
         except ValueError:
             return x
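In words, the reworked Speed transform picks a random window covering at most max_amount of the signal, time-stretches only that window by a factor drawn from [speed_min, speed_max], and zero-pads back to the original length when the stretch shortened the signal. A sketch of applying it, with a hypothetical import path since this hunk does not show the file name:

``` python
import numpy as np

# hypothetical import path; the module name is not visible in this diff
from ml_lib.audio_toolset.audio_augmentation import Speed

x = np.random.randn(16000).astype(np.float32)   # illustrative 1 s mono signal
speed = Speed(max_amount=0.3, speed_min=0.7, speed_max=1.3)
x_aug = speed(x)
assert x_aug.shape == x.shape  # length kept by truncation or zero-padding
```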
@@ -1,8 +1,16 @@
-import librosa
-from scipy.signal import butter, lfilter
-
 import numpy as np

+try:
+    import librosa
+except ImportError:  # pragma: no-cover
+    raise ImportError('You want to use `librosa` plugins which are not installed yet,'  # pragma: no-cover
+                      ' install it with `pip install librosa`.')
+try:
+    from scipy.signal import butter, lfilter
+except ImportError:  # pragma: no-cover
+    raise ImportError('You want to use `scikit` plugins which are not installed yet,'  # pragma: no-cover
+                      ' install it with `pip install scikit-learn`.')


 def scale_minmax(x, min_val=0.0, max_val=1.0):
     x_std = (x - x.min()) / (x.max() - x.min())
@@ -28,6 +36,9 @@ class MFCC(object):
     def __init__(self, **kwargs):
         self.__dict__.update(kwargs)

+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'
+
     def __call__(self, y):
         mfcc = librosa.feature.mfcc(y, **self.__dict__)
         return mfcc
@@ -35,27 +46,38 @@ class MFCC(object):

 class NormalizeLocal(object):
     def __init__(self):
-        self.cache: np.ndarray
         pass

+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'
+
     def __call__(self, x: np.ndarray):
+        x[np.isnan(x)] = 0
+        x[np.isinf(x)] = 0
+
         mean = x.mean()
         std = x.std() + 0.0001
+
         # Pytorch Version:
-        # x = x.__sub__(mean).__div__(std)
+        # tensor = tensor.__sub__(mean).__div__(std)
         # Numpy Version
         x = (x - mean) / std
+
         x[np.isnan(x)] = 0
         x[np.isinf(x)] = 0
+
         return x


 class NormalizeMelband(object):
     def __init__(self):
-        self.cache: np.ndarray
         pass

+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'
+
     def __call__(self, x: np.ndarray):
         mean = x.mean(-1).unsqueeze(-1)
         std = x.std(-1).unsqueeze(-1)
@@ -66,10 +88,13 @@ class NormalizeMelband(object):
         return x


-class AudioToMel(object):
-    def __init__(self, amplitude_to_db=False, power_to_db=False, **kwargs):
+class LibrosaAudioToMel(object):
+    def __init__(self, amplitude_to_db=False, power_to_db=False, **mel_kwargs):
         assert not all([amplitude_to_db, power_to_db]), "Choose amplitude_to_db or power_to_db, not both!"
-        self.mel_kwargs = kwargs
+        # Mel kwargs are:
+        # sr  n_mels  n_fft  hop_length
+
+        self.mel_kwargs = mel_kwargs
         self.amplitude_to_db = amplitude_to_db
         self.power_to_db = power_to_db

@@ -89,6 +114,9 @@ class PowerToDB(object):
     def __init__(self, running_max=False):
         self.running_max = 0 if running_max else None

+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'
+
     def __call__(self, x):
         if self.running_max is not None:
             self.running_max = max(np.max(x), self.running_max)
@@ -100,6 +128,9 @@ class LowPass(object):
     def __init__(self, sr=16000):
         self.sr = sr

+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'
+
     def __call__(self, x):
         return butter_lowpass_filter(x, 1000, 1)

@@ -108,12 +139,16 @@ class MelToImage(object):
     def __init__(self):
         pass

+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'
+
     def __call__(self, x):
         # Source to Solution: https://stackoverflow.com/a/57204349
         mels = np.log(x + 1e-9)  # add small number to avoid log(0)

         # min-max scale to fit inside 8-bit range
-        img = scale_minmax(mels, 0, 255).astype(np.uint8)
-        img = np.flip(img, axis=0)  # put low frequencies at the bottom in image
+        img = scale_minmax(mels, 0, 255)
+        img = np.flip(img)  # put low frequencies at the bottom in image
         img = 255 - img  # invert. make black==more energy
+        img = img.astype(np.float)
         return img
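The renamed LibrosaAudioToMel composes with MelToImage into the usual waveform-to-mel-image pipeline, exactly as the new audio_to_mel_dataset module wires it up; note the image now stays float instead of being cast to uint8. A sketch with illustrative parameter values:

``` python
import numpy as np
from torchvision.transforms import Compose

from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage

# mel kwargs as hinted by the new comment: sr, n_mels, n_fft, hop_length
mel_kwargs = dict(sr=16000, n_mels=64, n_fft=512, hop_length=256)
to_mel_image = Compose([LibrosaAudioToMel(power_to_db=False, **mel_kwargs),
                        MelToImage()])

y = np.random.randn(16000).astype(np.float32)  # dummy 1 s waveform
img = to_mel_image(y)                          # float array, shape (n_mels, frames)
```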
71  audio_toolset/audio_to_mel_dataset.py  Normal file

@@ -0,0 +1,71 @@
+import sys
+from pathlib import Path
+
+import pickle
+from abc import ABC
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose
+
+from ml_lib.audio_toolset.audio_io import LibrosaAudioToMel, MelToImage
+from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
+
+
+import librosa
+
+
+class LibrosaAudioToMelDataset(Dataset):
+
+    @property
+    def audio_file_duration(self):
+        return librosa.get_duration(sr=self.mel_kwargs.get('sr', None), filename=self.audio_path)
+
+    @property
+    def sampling_rate(self):
+        return self.mel_kwargs.get('sr', None)
+
+    def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False,
+                 audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs):
+        super(LibrosaAudioToMelDataset, self).__init__()
+
+        # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate)
+        mel_kwargs.update(sr=mel_kwargs.get('sr', None) or librosa.get_samplerate(audio_file_path))
+        self.mel_kwargs = mel_kwargs
+        self.reset = reset
+        self.audio_path = Path(audio_file_path)
+
+        mel_folder_suffix = self.audio_path.parent.parent.name
+        self.mel_folder = Path(str(self.audio_path)
+                               .replace(mel_folder_suffix, f'{mel_folder_suffix}_mel_folder')).parent.parent
+
+        self.mel_file_path = self.mel_folder / f'{self.audio_path.stem}.npy'
+
+        self.audio_augmentations = audio_augmentations
+
+        self.dataset = TorchMelDataset(self.mel_file_path, sample_segment_len, sample_hop_len, label,
+                                       self.audio_file_duration, mel_kwargs['sr'], mel_kwargs['hop_length'],
+                                       mel_kwargs['n_mels'], transform=mel_augmentations)
+
+        self._mel_transform = Compose([LibrosaAudioToMel(power_to_db=False, **mel_kwargs),
+                                       MelToImage()
+                                       ])
+
+    def __getitem__(self, item):
+        return self.dataset[item]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def build_mel(self):
+        if self.reset:
+            self.mel_file_path.unlink(missing_ok=True)
+        if not self.mel_file_path.exists():
+            self.mel_file_path.parent.mkdir(parents=True, exist_ok=True)
+            with self.audio_path.open(mode='rb') as audio_file:
+                raw_sample, _ = librosa.core.load(audio_file, sr=self.sampling_rate)
+            mel_sample = self._mel_transform(raw_sample)
+            with self.mel_file_path.open('wb') as mel_file:
+                pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL)
+        else:
+            pass
+
+        return self.mel_file_path.exists()
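Typical flow for the new dataset: construct it, call build_mel() once to materialize the pickled mel file, then index it to obtain (path, snippet, label) tuples from the underlying TorchMelDataset. A sketch; the wav path and parameter values are illustrative, not from the repository:

``` python
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset

dataset = LibrosaAudioToMelDataset('data/train/speaker_a/sample.wav', label=0,
                                   sample_segment_len=40, sample_hop_len=20,
                                   mel_kwargs=dict(sr=16000, n_mels=64,
                                                   n_fft=512, hop_length=256))
dataset.build_mel()                     # writes <stem>.npy into the *_mel_folder mirror
file_path, snippet, label = dataset[0]  # snippet: (n_mels, sample_segment_len) segment
```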
@@ -1,17 +1,20 @@
 import numpy as np

+from ml_lib.utils.transforms import _BaseTransformation

-class NoiseInjection(object):

-    def __init__(self, noise_factor: float, sigma=0.5, mu=0.5):
-        assert noise_factor >= 0, f'max_shift_ratio has to be greater then 0, but was: {noise_factor}.'
+class NoiseInjection(_BaseTransformation):
+
+    def __init__(self, noise_factor: float, sigma=1, mu=0):
+        super(NoiseInjection, self).__init__()
+        assert noise_factor >= 0, f'noise_factor has to be greater then 0, but was: {noise_factor}.'
         self.mu = mu
         self.sigma = sigma
         self.noise_factor = noise_factor

     def __call__(self, x: np.ndarray):
         if self.noise_factor:
-            noise = np.random.uniform(0, self.noise_factor, size=x.shape)
+            noise = np.random.normal(self.mu, self.sigma, size=x.shape) * self.noise_factor
             augmented_data = x + x * noise
             # Cast back to same data type
             augmented_data = augmented_data.astype(x.dtype)
@@ -20,14 +23,15 @@ class NoiseInjection(object):
             return x


-class LoudnessManipulator(object):
+class LoudnessManipulator(_BaseTransformation):

     def __init__(self, max_factor: float):
+        super(LoudnessManipulator, self).__init__()
         assert 1 > max_factor >= 0, f'max_shift_ratio has to be between [0,1], but was: {max_factor}.'

         self.max_factor = max_factor

-    def __call__(self, x: np.ndarray):
+    def __call__(self, x):
         if self.max_factor:
             augmented_data = x + x * (np.random.random() * self.max_factor)
             # Cast back to same data type
@@ -37,11 +41,12 @@ class LoudnessManipulator(object):
             return x


-class ShiftTime(object):
+class ShiftTime(_BaseTransformation):

     valid_shifts = ['right', 'left', 'any']

     def __init__(self, max_shift_ratio: float, shift_direction: str = 'any'):
+        super(ShiftTime, self).__init__()
         assert 1 > max_shift_ratio >= 0, f'max_shift_ratio has to be between [0,1], but was: {max_shift_ratio}.'
         assert shift_direction.lower() in self.valid_shifts, f'shift_direction has to be one of: {self.valid_shifts}'
         self.max_shift_ratio = max_shift_ratio
@@ -53,26 +58,27 @@ class ShiftTime(object):
             if self.shift_direction == 'right':
                 shift = -1 * shift
             elif self.shift_direction == 'any':
-                direction = np.random.choice([1, -1], 1)
+                direction = np.asscalar(np.random.choice([1, -1], 1))
                 shift = direction * shift
             augmented_data = np.roll(x, shift)
             # Set to silence for heading/ tailing
             shift = int(shift)
             if shift > 0:
-                augmented_data[:shift] = 0
+                augmented_data[:, :shift] = 0
             else:
-                augmented_data[shift:] = 0
+                augmented_data[:, shift:] = 0
             return augmented_data
         else:
             return x


-class MaskAug(object):
+class MaskAug(_BaseTransformation):

     w_idx = -1
     h_idx = -2

     def __init__(self, duration_ratio_max=0.3, mask_with_noise=True):
+        super(MaskAug, self).__init__()
         assertion = f'"duration_ratio" has to be within [0..1], but was: {duration_ratio_max}'
         if isinstance(duration_ratio_max, (tuple, list)):
             assert all([0 < max_val < 1 for max_val in duration_ratio_max]), assertion
@@ -87,9 +93,9 @@ class MaskAug(object):
     def __call__(self, x):
         for dim in (self.w_idx, self.h_idx):
             if self.duration_ratio_max[dim]:
-                start = int(np.random.choice(x.shape[dim], 1))
-                v_max = x.shape[dim] * self.duration_ratio_max[dim]
-                size = int(np.random.randint(0, v_max, 1))
+                start = np.asscalar(np.random.choice(x.shape[dim], 1))
+                v_max = int(x.shape[dim] * self.duration_ratio_max[dim])
+                size = np.asscalar(np.random.randint(0, v_max, 1))
                 end = int(min(start + size, x.shape[dim]))
                 size = end - start
                 if dim == self.w_idx:
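Since the four transforms now share _BaseTransformation and keep the array shape, they chain naturally; note that ShiftTime and MaskAug now slice as [:, shift], i.e. they expect a 2D array such as mel bands by frames. A composition sketch with an assumed import path and illustrative parameters:

``` python
import numpy as np
from torchvision.transforms import Compose

# hypothetical import path; the module name is not visible in this diff
from ml_lib.audio_toolset.audio_augmentation import (NoiseInjection, LoudnessManipulator,
                                                     ShiftTime, MaskAug)

augment = Compose([NoiseInjection(noise_factor=0.4),
                   LoudnessManipulator(max_factor=0.5),
                   ShiftTime(max_shift_ratio=0.3, shift_direction='any'),
                   MaskAug(duration_ratio_max=0.3, mask_with_noise=True)])

spec = np.random.rand(64, 400).astype(np.float32)  # (n_mels, frames) spectrogram
assert augment(spec).shape == spec.shape           # every step is shape-preserving
```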
54  audio_toolset/mel_dataset.py  Normal file

@@ -0,0 +1,54 @@
+import time
+from pathlib import Path
+
+import pickle
+from torch.utils.data import Dataset
+
+from ml_lib.modules.util import AutoPadToShape
+
+
+class TorchMelDataset(Dataset):
+    def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len,
+                 sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
+        super(TorchMelDataset, self).__init__()
+        self.sampling_rate = int(sampling_rate)
+        self.audio_file_len = float(audio_file_len)
+        if auto_pad_to_shape and sub_segment_len:
+            self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
+        else:
+            self.padding = None
+        self.path = Path(mel_path)
+        self.sub_segment_len = int(sub_segment_len)
+        self.mel_hop_len = int(mel_hop_len)
+        self.sub_segment_hop_len = int(sub_segment_hop_len)
+        self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)
+        if self.sub_segment_len and self.sub_segment_hop_len and (self.n - self.sub_segment_len) > 0:
+            self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
+        else:
+            self.offsets = [0]
+        if len(self) == 0:
+            print('what happend here')
+        self.label = label
+        self.transform = transform
+
+    def __getitem__(self, item):
+        with self.path.open('rb') as mel_file:
+            mel_spec = pickle.load(mel_file, fix_imports=True)
+        start = self.offsets[item]
+        sub_segments_attributes_set = self.sub_segment_len and self.sub_segment_hop_len
+        sub_segment_length_smaller_then_tot_length = self.sub_segment_len < mel_spec.shape[1]
+
+        if sub_segments_attributes_set and sub_segment_length_smaller_then_tot_length:
+            duration = self.sub_segment_len
+        else:
+            duration = mel_spec.shape[1]
+
+        snippet = mel_spec[:, start: start + duration]
+        if self.transform:
+            snippet = self.transform(snippet)
+        if self.padding:
+            snippet = self.padding(snippet)
+        return self.path.__str__(), snippet, self.label
+
+    def __len__(self):
+        return len(self.offsets)
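The frame count n is estimated from the audio length as sr / mel_hop * seconds + 1, and segment start offsets are then strided by sub_segment_hop_len. The arithmetic, worked through with illustrative numbers:

``` python
sampling_rate, mel_hop_len, audio_file_len = 16000, 256, 10.0
n = int((sampling_rate / mel_hop_len) * audio_file_len + 1)         # 626 frames
sub_segment_len, sub_segment_hop_len = 40, 20
offsets = list(range(0, n - sub_segment_len, sub_segment_hop_len))  # 0, 20, ..., 580
print(n, len(offsets))  # 626 frames -> 30 sub-segments of 40 frames each
```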
27  audio_toolset/mel_transforms.py  Normal file

@@ -0,0 +1,27 @@
+from typing import Union
+
+import numpy as np
+
+
+class Normalize(object):
+
+    def __init__(self, min_db_level: Union[int, float]):
+        self.min_db_level = min_db_level
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'
+
+    def __call__(self, s: np.ndarray) -> np.ndarray:
+        return np.clip((s - self.min_db_level) / -self.min_db_level, 0, 1)
+
+
+class DeNormalize(object):
+
+    def __init__(self, min_db_level: Union[int, float]):
+        self.min_db_level = min_db_level
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'
+
+    def __call__(self, s: np.ndarray) -> np.ndarray:
+        return (np.clip(s, 0, 1) * -self.min_db_level) + self.min_db_level
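DeNormalize inverts Normalize exactly for values inside the clipped range: Normalize maps [min_db_level, 0] dB onto [0, 1] and DeNormalize maps back. A quick round-trip check, assuming ml_lib is importable and with illustrative values:

``` python
import numpy as np

from ml_lib.audio_toolset.mel_transforms import Normalize, DeNormalize

norm, denorm = Normalize(min_db_level=-100), DeNormalize(min_db_level=-100)
s = np.array([-100.0, -50.0, -10.0, 0.0])  # dB values within range
print(norm(s))                             # [0.   0.5  0.9  1. ]
assert np.allclose(denorm(norm(s)), s)     # round trip recovers the input
```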
@@ -1,13 +1,21 @@
-import matplotlib.pyplot as plt
-from sklearn.metrics import roc_curve, auc
+try:
+    import matplotlib.pyplot as plt
+except ImportError:  # pragma: no-cover
+    raise ImportError('You want to use `matplotlib` plugins which are not installed yet,'  # pragma: no-cover
+                      ' install it with `pip install matplotlib`.')
+try:
+    from sklearn.metrics import roc_curve, auc, recall_score
+except ImportError:  # pragma: no-cover
+    raise ImportError('You want to use `sklearn` plugins which are not installed yet,'  # pragma: no-cover
+                      ' install it with `pip install scikit-learn`.')


 class ROCEvaluation(object):

     linewidth = 2

-    def __init__(self, plot_roc=False):
-        self.plot_roc = plot_roc
+    def __init__(self, plot=False):
+        self.plot = plot
         self.epoch = 0

     def __call__(self, prediction, label):
@@ -15,7 +23,7 @@ class ROCEvaluation(object):
         # Compute ROC curve and ROC area
         fpr, tpr, _ = roc_curve(prediction, label)
         roc_auc = auc(fpr, tpr)
-        if self.plot_roc:
+        if self.plot:
             _ = plt.gcf()
             plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
             self._prepare_fig()
@@ -32,3 +40,32 @@ class ROCEvaluation(object):
         fig.legend(loc="lower right")

         return fig
+
+
+class UAREvaluation(object):
+
+    def __init__(self, labels: list, plot=False):
+        self.labels = labels
+        self.plot_roc = plot
+        self.epoch = 0
+
+    def __call__(self, prediction, label):
+
+        # Compute uar score - UnweightedAverageRecal
+
+        uar_score = recall_score(label, prediction, labels=self.labels, average='macro',
+                                 sample_weight=None, zero_division='warn')
+        return uar_score
+
+    def _prepare_fig(self):
+        raise NotImplementedError  # TODO Implement a nice visualization
+        fig = plt.gcf()
+        ax = plt.gca()
+        plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
+        plt.xlim([0.0, 1.0])
+        plt.ylim([0.0, 1.05])
+        plt.xlabel('False Positive Rate')
+        plt.ylabel('True Positive Rate')
+        fig.legend(loc="lower right")
+
+        return fig
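The unweighted average recall that UAREvaluation wraps is just the macro-averaged per-class recall, so every class counts equally regardless of its support. A self-contained sketch of the computation with illustrative labels:

``` python
from sklearn.metrics import recall_score

label      = [0, 0, 0, 0, 1, 1]  # imbalanced ground truth
prediction = [0, 0, 0, 0, 1, 0]  # one minority-class miss
# per-class recall: class 0 -> 4/4 = 1.0, class 1 -> 1/2 = 0.5
uar = recall_score(label, prediction, labels=[0, 1], average='macro')
print(uar)  # (1.0 + 0.5) / 2 = 0.75
```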
0  metrics/__init__.py  Normal file
13  metrics/_base_score.py  Normal file

@@ -0,0 +1,13 @@
+from abc import ABC
+
+
+class _BaseScores(ABC):
+
+    def __init__(self, lightning_model):
+        self.model = lightning_model
+        pass
+
+    def __call__(self, outputs):
+        # summary_dict = dict()
+        # return summary_dict
+        raise NotImplementedError
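The subclass contract is small: keep the lightning module handed to the constructor and map collected step outputs to a summary dict in __call__. A minimal conforming subclass, hypothetical and not part of this commit:

``` python
import torch

from ml_lib.metrics._base_score import _BaseScores


class MeanLossScore(_BaseScores):
    """Toy metric: mean over the per-step 'loss' entries."""

    def __call__(self, outputs):
        losses = torch.stack([output['loss'] for output in outputs])
        return dict(mean_loss=losses.mean())
```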
47  metrics/attention_rollout.py  Normal file

@@ -0,0 +1,47 @@
+import numpy as np
+
+from einops import reduce
+
+
+import torch
+from sklearn.ensemble import IsolationForest
+from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
+
+from ml_lib.metrics._base_score import _BaseScores
+
+
+class AttentionRollout(_BaseScores):
+
+    def __init__(self, *args):
+        super(AttentionRollout, self).__init__(*args)
+        pass
+
+    def __call__(self, outputs):
+        summary_dict = dict()
+        #######################################################################################
+        # Additional Score - Histogram Distances - Image Plotting
+        #######################################################################################
+        #
+        # INIT
+        attn_weights = [output['attn_weights'].cpu().numpy() for output in outputs]
+        attn_reduce_heads = [reduce(x, '') for x in attn_weights]
+
+        if self.model.params.use_residual:
+            residual_att = np.eye(att_mat.shape[1])[None, ...]
+            aug_att_mat = att_mat + residual_att
+            aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
+        else:
+            aug_att_mat = att_mat
+
+        joint_attentions = np.zeros(aug_att_mat.shape)
+
+        layers = joint_attentions.shape[0]
+        joint_attentions[0] = aug_att_mat[0]
+        for i in np.arange(1, layers):
+            joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])
68  metrics/binary_class_classifictaion.py  Normal file

@@ -0,0 +1,68 @@
+import numpy as np
+
+import torch
+from sklearn.ensemble import IsolationForest
+from sklearn.metrics import recall_score, roc_auc_score, average_precision_score
+
+from ml_lib.metrics._base_score import _BaseScores
+from ml_lib.utils.tools import to_one_hot
+
+
+class BinaryScores(_BaseScores):
+
+    def __init__(self, *args):
+        super(BinaryScores, self).__init__(*args)
+
+    def __call__(self, outputs):
+        summary_dict = dict()
+
+        # Additional Score like the unweighted Average Recall:
+        #########################
+        # INIT
+        if isinstance(outputs['batch_y'], torch.Tensor):
+            y_true = outputs['batch_y'].cpu().numpy()
+        else:
+            y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
+
+        if isinstance(outputs['y'], torch.Tensor):
+            y_pred = outputs['y'].cpu().numpy()
+        else:
+            y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
+
+        # UnweightedAverageRecall
+        # y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
+        # y_pred = torch.cat([output['element_wise_recon_error'] for output in outputs]).squeeze().cpu().numpy()
+
+        # How to apply a threshold manualy
+        # y_pred = (y_pred >= 0.5).astype(np.float32)
+
+        # How to apply a threshold by IF (Isolation Forest)
+        clf = IsolationForest()
+        y_score = clf.fit_predict(y_pred.reshape(-1, 1))
+        y_score = (np.asarray(y_score) == -1).astype(np.float32)
+
+        uar_score = recall_score(y_true, y_score, labels=[0, 1], average='macro',
+                                 sample_weight=None, zero_division='warn')
+        summary_dict.update(dict(uar_score=uar_score))
+        #########################
+        # Precission
+        precision_score = average_precision_score(y_true, y_score)
+        summary_dict.update(dict(precision_score=precision_score))
+
+        #########################
+        # AUC
+        try:
+            auc_score = roc_auc_score(y_true=y_true, y_score=y_score)
+            summary_dict.update(dict(auc_score=auc_score))
+        except ValueError:
+            summary_dict.update(dict(auc_score=-1))
+
+        #########################
+        # pAUC
+        try:
+            pauc = roc_auc_score(y_true=y_true, y_score=y_score, max_fpr=0.15)
+            summary_dict.update(dict(pauc_score=pauc))
+        except ValueError:
+            summary_dict.update(dict(pauc_score=-1))
+
+        return summary_dict
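The thresholding trick in BinaryScores: rather than a fixed cut-off on the raw scores, an IsolationForest is fit on the 1-D predictions and its outlier flag (-1) becomes the positive class. That step in isolation, with illustrative scores:

``` python
import numpy as np
from sklearn.ensemble import IsolationForest

y_pred = np.concatenate([np.random.normal(0.1, 0.05, 95),   # bulk of inliers
                         np.random.normal(0.9, 0.05, 5)])   # few high outliers
clf = IsolationForest()
y_score = clf.fit_predict(y_pred.reshape(-1, 1))            # +1 inlier, -1 outlier
y_score = (np.asarray(y_score) == -1).astype(np.float32)    # outliers -> class 1
print(int(y_score.sum()))  # roughly the number of detected outliers
```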
68  metrics/generative_task_evaluation.py  Normal file

@@ -0,0 +1,68 @@
+from itertools import cycle
+
+import numpy as np
+import torch
+from sklearn.metrics import roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix
+from scipy.spatial.distance import cdist
+
+from ml_lib.metrics._base_score import _BaseScores
+
+from matplotlib import pyplot as plt
+
+
+class GenerativeTaskEval(_BaseScores):
+
+    def __init__(self, *args):
+        super(GenerativeTaskEval, self).__init__(*args)
+        pass
+
+    def __call__(self, outputs):
+        summary_dict = dict()
+        #######################################################################################
+        # Additional Score - Histogram Distances - Image Plotting
+        #######################################################################################
+        #
+        # INIT
+        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
+
+        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
+
+        attn_weights = torch.cat([output['attn_weights'] for output in outputs]).squeeze().cpu().numpy()
+
+        ######################################################################################
+        #
+        # Histogram comparission
+
+        y_true_hist = np.histogram(y_true, bins=128)[0]  # Todo: Find a better value
+        y_pred_hist = np.histogram(y_pred, bins=128)[0]  # Todo: Find a better value
+
+        # L2 norm == euclidean distance
+        hist_euc_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0),
+                              metric='euclidean')
+
+        # Manhattan Distance
+        hist_manhattan_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0),
+                                    metric='cityblock')
+
+        summary_dict.update(hist_manhattan_dist=hist_manhattan_dist, hist_euc_dist=hist_euc_dist)
+
+        #######################################################################################
+        #
+        idx = np.random.choice(np.arange(y_true.shape[0]), 1).item()
+
+        ax = plt.imshow(y_true[idx].squeeze())
+        # Plot using a small number of colors, with unevenly spaced boundaries.
+        ax2 = plt.imshow(attn_weights[idx].sq, interpolation='nearest', aspect='auto', extent=ax.get_extent())
+        self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
+        plt.clf()
+
+
+        #######################################################################################
+        #
+
+
+        #######################################################################################
+        #
+
+        plt.close('all')
+        return summary_dict
142  metrics/multi_class_classification.py  Normal file

@@ -0,0 +1,142 @@
+from itertools import cycle
+
+import numpy as np
+import torch
+from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix, \
+    recall_score
+
+from ml_lib.metrics._base_score import _BaseScores
+from ml_lib.utils.tools import to_one_hot
+
+from matplotlib import pyplot as plt
+
+
+class MultiClassScores(_BaseScores):
+
+    def __init__(self, *args):
+        super(MultiClassScores, self).__init__(*args)
+        pass
+
+    def __call__(self, outputs, class_names=None):
+        summary_dict = dict()
+        class_names = class_names or range(self.model.params.n_classes)
+        #######################################################################################
+        # Additional Score - UAR - ROC - Conf. Matrix - F1
+        #######################################################################################
+        #
+        # INIT
+        if isinstance(outputs['batch_y'], torch.Tensor):
+            y_true = outputs['batch_y'].cpu().numpy()
+        else:
+            y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
+        y_true_one_hot = to_one_hot(y_true, self.model.params.n_classes)
+
+        if isinstance(outputs['y'], torch.Tensor):
+            y_pred = outputs['y'].cpu().numpy()
+        else:
+            y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
+        y_pred_max = np.argmax(y_pred, axis=1)
+
+        class_names = {val: key for val, key in enumerate(class_names)}
+        ######################################################################################
+        #
+        # F1 SCORE
+        micro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='micro', sample_weight=None,
+                                  zero_division=True)
+        macro_f1_score = f1_score(y_true, y_pred_max, labels=None, pos_label=1, average='macro', sample_weight=None,
+                                  zero_division=True)
+        summary_dict.update(dict(micro_f1_score=micro_f1_score, macro_f1_score=macro_f1_score))
+        ######################################################################################
+        #
+        # Unweichted Average Recall
+        uar = recall_score(y_true, y_pred_max, labels=[0, 1, 2, 3, 4], average='macro',
+                           sample_weight=None, zero_division='warn')
+        summary_dict.update(dict(uar_score=uar))
+        #######################################################################################
+        #
+        # ROC Curve
+
+        # Compute ROC curve and ROC area for each class
+        fpr = dict()
+        tpr = dict()
+        roc_auc = dict()
+        for i in range(self.model.params.n_classes):
+            fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
+            roc_auc[i] = auc(fpr[i], tpr[i])
+
+        # Compute micro-average ROC curve and ROC area
+        fpr["micro"], tpr["micro"], _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel())
+        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
+
+        # First aggregate all false positive rates
+        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.model.params.n_classes)]))
+
+        # Then interpolate all ROC curves at this points
+        mean_tpr = np.zeros_like(all_fpr)
+        for i in range(self.model.params.n_classes):
+            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
+
+        # Finally average it and compute AUC
+        mean_tpr /= self.model.params.n_classes
+
+        fpr["macro"] = all_fpr
+        tpr["macro"] = mean_tpr
+        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
+
+        # Plot all ROC curves
+        plt.figure()
+        plt.plot(fpr["micro"], tpr["micro"],
+                 label=f'micro ROC ({round(roc_auc["micro"], 2)})',
+                 color='deeppink', linestyle=':', linewidth=4)
+
+        plt.plot(fpr["macro"], tpr["macro"],
+                 label=f'macro ROC({round(roc_auc["macro"], 2)})',
+                 color='navy', linestyle=':', linewidth=4)
+
+        colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
+                        'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], )
+
+        for i, color in zip(range(self.model.params.n_classes), colors):
+            plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})')
+
+        plt.plot([0, 1], [0, 1], 'k--', lw=2)
+        plt.xlim([0.0, 1.0])
+        plt.ylim([0.0, 1.05])
+        plt.xlabel('False Positive Rate')
+        plt.ylabel('True Positive Rate')
+        plt.legend(loc="lower right")
+
+        self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch)
+        # self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch, ext='pdf')
+        plt.clf()
+
+        #######################################################################################
+        #
+        # ROC AUC SCORE
+
+        try:
+            macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
+                                              average="macro")
+            summary_dict.update(macro_roc_auc_ovr=macro_roc_auc_ovr)
+        except ValueError:
+            micro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
+                                              average="micro")
+            summary_dict.update(micro_roc_auc_ovr=micro_roc_auc_ovr)
+
+        #######################################################################################
+        #
+        # Confusion matrix
+        fig1, ax1 = plt.subplots(dpi=96)
+        cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
+                              labels=[class_names[key] for key in class_names.keys()],
+                              normalize='true')
+        disp = ConfusionMatrixDisplay(confusion_matrix=cm,
+                                      display_labels=[class_names[i] for i in range(self.model.params.n_classes)]
+                                      )
+        disp.plot(include_values=True, ax=ax1)
+
+        self.model.logger.log_image('Confusion_Matrix', image=fig1, step=self.model.current_epoch)
+        # self.model.logger.log_image('Confusion_Matrix', image=disp.figure_, step=self.model.current_epoch, ext='pdf')
+
+        plt.close('all')
+        return summary_dict
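The micro-averaged ROC in MultiClassScores treats every (sample, class) cell of the one-hot matrix as a single binary decision, which is why both arrays are simply raveled before roc_curve. That step on its own, with illustrative random data:

``` python
import numpy as np
from sklearn.metrics import auc, roc_curve

# illustrative: 3-class one-hot targets and predicted class probabilities
y_true_one_hot = np.eye(3)[np.random.randint(0, 3, 100)]
y_pred = np.random.dirichlet(np.ones(3), size=100)

fpr, tpr, _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel())
print(auc(fpr, tpr))  # near 0.5 for random predictions
```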
BIN  modules/__pycache__/__init__.cpython-37.pyc  Normal file
Binary file not shown.
BIN  modules/__pycache__/geometric_blocks.cpython-37.pyc  Normal file
Binary file not shown.
BIN  modules/__pycache__/util.cpython-37.pyc  Normal file
Binary file not shown.
@ -1,11 +1,17 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import warnings
|
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from modules.utils import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
|
import sys
|
||||||
|
sys.path.append(str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten
|
||||||
|
|
||||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
@@ -15,23 +21,24 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 ###################
 class LinearModule(ShapeMixin, nn.Module):

-    def __init__(self, in_shape, out_features, bias=True, activation=None,
-                 norm=False, dropout: Union[int, float] = 0, **kwargs):
-        warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
+    def __init__(self, in_shape, out_features, use_bias=True, activation=None,
+                 use_norm=False, dropout: Union[int, float] = 0, **kwargs):
+        if list(kwargs.keys()):
+            warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
         super(LinearModule, self).__init__()

         self.in_shape = in_shape
         self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
         self.dropout = nn.Dropout(dropout) if dropout else F_x(self.flat.shape)
-        self.norm = nn.BatchNorm1d(self.flat.shape) if norm else F_x(self.flat.shape)
-        self.linear = nn.Linear(self.flat.shape, out_features, bias=bias)
+        self.norm = nn.LayerNorm(self.flat.shape) if use_norm else F_x(self.flat.shape)
+        self.linear = nn.Linear(self.flat.shape, out_features, bias=use_bias)
         self.activation = activation() if activation else F_x(self.linear.out_features)

     def forward(self, x):
         tensor = self.flat(x)
         tensor = self.dropout(tensor)
         tensor = self.norm(tensor)
-        tensor = self.linear(tensor)
+        tensor = self.linear(tensor.float())
         tensor = self.activation(tensor)
         return tensor

@@ -39,14 +46,22 @@ class LinearModule(ShapeMixin, nn.Module):
 class ConvModule(ShapeMixin, nn.Module):

     def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
-                 bias=True, norm=False, dropout: Union[int, float] = 0,
+                 bias=True, use_norm=False, dropout: Union[int, float] = 0, trainable: bool = True,
                  conv_class=nn.Conv2d, conv_stride=1, conv_padding=0, **kwargs):
         super(ConvModule, self).__init__()
         assert isinstance(in_shape, (tuple, list)), f'"in_shape" should be a [list, tuple], but was {type(in_shape)}'
         assert len(in_shape) == 3, f'Length should be 3, but was {len(in_shape)}'
-        warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
+        if len(kwargs.keys()):
+            warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
+
+        if use_norm and not trainable:
+            warnings.warn('You set this module to be not trainable but the running norm is active.\n' +
+                          'We set it to "eval" mode.\n' +
+                          'Keep this in mind if you do a finetunning or retraining step.'
+                          )

         # Module Parameters
         self.in_shape = in_shape
+        self.trainable = trainable
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]

         # Convolution Parameters
@@ -56,13 +71,19 @@ class ConvModule(ShapeMixin, nn.Module):
         self.conv_kernel = conv_kernel

         # Modules
-        self.activation = activation() or F_x(None)
+        self.activation = activation() or nn.Identity()
+        self.norm = nn.LayerNorm(self.in_shape, eps=1e-04) if use_norm else F_x(None)
         self.dropout = nn.Dropout2d(dropout) if dropout else F_x(None)
         self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else F_x(None)
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else F_x(None)
         self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                padding=self.padding, stride=self.stride
                                )
+        if not self.trainable:
+            for param in self.parameters():
+                param.requires_grad = False
+            self.norm = self.norm.eval()
+        else:
+            pass

     def forward(self, x):
         tensor = self.norm(x)
@@ -73,13 +94,49 @@ class ConvModule(ShapeMixin, nn.Module):
         return tensor


+class PreInitializedConvModule(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape, weight_matrix):
+        super(PreInitializedConvModule, self).__init__()
+        self.in_shape = in_shape
+        self.weight_matrix = weight_matrix
+        raise NotImplementedError
+        # ToDo Get the weight_matrix shape and init a conv_module of similar size,
+        #      override the weights then.
+
+    def forward(self, x):
+        x = torch.matmul(x, self.weight_matrix)  # ToDo: This is an Placeholder
+        return x
+
+
+class SobelFilter(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape):
+        super(SobelFilter, self).__init__()
+        self.in_shape = in_shape
+
+        self.sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3)
+        self.sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, 2, -1]]).view(1, 1, 3, 3)
+
+    def forward(self, x):
+        # Apply Filters
+        g_x = F.conv2d(x, self.sobel_x)
+        g_y = F.conv2d(x, self.sobel_y)
+        # Calculate the Edge
+        g = torch.add(*[torch.pow(tensor, 2) for tensor in [g_x, g_y]])
+        # Calculate the Gradient
+        g_grad = torch.atan2(g_x, g_y)
+        return g_x, g_y, g, g_grad
+
+
 class DeConvModule(ShapeMixin, nn.Module):

     def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
                  dropout: Union[int, float] = 0, autopad=0,
                  activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
-                 bias=True, norm=False):
+                 bias=True, use_norm=False, **kwargs):
         super(DeConvModule, self).__init__()
+        warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
         self.padding = conv_padding
         self.conv_kernel = conv_kernel
@@ -89,8 +146,8 @@ class DeConvModule(ShapeMixin, nn.Module):

         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if norm else lambda x: x
-        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
+        self.norm = nn.LayerNorm(in_channels, eps=1e-04) if use_norm else F_x(self.in_shape)
+        self.dropout = nn.Dropout2d(dropout) if dropout else F_x(self.in_shape)
         self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=bias,
                                           padding=self.padding, stride=self.stride)
@@ -109,14 +166,13 @@ class DeConvModule(ShapeMixin, nn.Module):

 class ResidualModule(ShapeMixin, nn.Module):

-    def __init__(self, in_shape, module_class, n, norm=False, **module_parameters):
+    def __init__(self, in_shape, module_class, n, use_norm=False, **module_parameters):
         assert n >= 1
         super(ResidualModule, self).__init__()
         self.in_shape = in_shape
         module_parameters.update(in_shape=in_shape)
-        if norm:
-            self.norm = nn.BatchNorm1d if len(self.in_shape) <= 2 else nn.BatchNorm2d
-            self.norm = self.norm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
+        if use_norm:
+            self.norm = nn.LayerNorm(self.in_shape if isinstance(self.in_shape, int) else self.in_shape[0])
         else:
             self.norm = F_x(self.in_shape)
         self.activation = module_parameters.get('activation', None)
@@ -128,8 +184,9 @@ class ResidualModule(ShapeMixin, nn.Module):
         assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'

     def forward(self, x):
+        tensor = self.norm(x)
         for module in self.residual_block:
-            tensor = module(x)
+            tensor = module(tensor)

         # noinspection PyUnboundLocalVariable
         tensor = tensor + x
@@ -155,3 +212,106 @@ class RecurrentModule(ShapeMixin, nn.Module):
     def forward(self, x):
         tensor = self.rnn(x)
         return tensor
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout=0., activation=nn.GELU):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            activation() or F_x(None),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            activation() or F_x(None),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=8, head_dim=64, dropout=0.):
+        super().__init__()
+        inner_dim = head_dim * heads
+        project_out = not (heads == 1 and head_dim == dim)
+
+        self.heads = heads
+        self.scale = head_dim ** -0.5
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        ) if project_out else nn.Identity()
+
+    def forward(self, x, mask=None, return_attn_weights=False):
+        from einops import rearrange, repeat
+        # noinspection PyTupleAssignmentBalance
+        b, n, _, h = *x.shape, self.heads
+
+        qkv = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
+
+        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+
+        if mask is not None:
+            mask_value = -torch.finfo(dots.dtype).max
+            mask = F.pad(mask.flatten(1), (1, 0), value=True)
+            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
+            mask = mask[:, None, :] * mask[:, :, None]
+            mask = repeat(mask, 'b n d -> b h n d', h=h)  # My addition
+            dots.masked_fill_(~mask, mask_value)
+            # dots.masked_fill_(mask, mask_value)  # My addition
+            del mask
+
+        attn = dots.softmax(dim=-1)
+
+        out = torch.einsum('bhij,bhjd->bhid', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = self.to_out(out)
+        if return_attn_weights:
+            return out, attn
+        else:
+            return out
+
+
+class TransformerModule(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape, depth, heads, mlp_dim, head_dim=32, dropout=None, use_norm=False,
+                 activation=nn.GELU, use_residual=True):
+        super(TransformerModule, self).__init__()
+
+        self.in_shape = in_shape
+        self.use_residual = use_residual
+
+        self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape)
+
+        self.embedding_dim = self.flat.flat_shape
+        self.norm = nn.LayerNorm(self.embedding_dim) if use_norm else F_x(None)
+        self.attns = nn.ModuleList([Attention(self.embedding_dim, heads=heads, dropout=dropout, head_dim=head_dim)
+                                    for _ in range(depth)])
+        self.mlps = nn.ModuleList([FeedForward(self.embedding_dim, mlp_dim, dropout=dropout, activation=activation)
+                                   for _ in range(depth)])
+
+    def forward(self, x, mask=None, return_attn_weights=False, **_):
+        tensor = self.flat(x)
+        attn_weights = list()
+
+        for attn, mlp in zip(self.attns, self.mlps):
+            # Attention
+            attn_tensor = self.norm(tensor)
+            if return_attn_weights:
+                attn_tensor, attn_weight = attn(attn_tensor, mask=mask, return_attn_weights=return_attn_weights)
+                attn_weights.append(attn_weight)
+            else:
+                attn_tensor = attn(attn_tensor, mask=mask)
+            tensor = tensor + attn_tensor if self.use_residual else attn_tensor

+            # MLP
+            mlp_tensor = self.norm(tensor)
+            mlp_tensor = mlp(mlp_tensor)
+            tensor = tensor + mlp_tensor if self.use_residual else mlp_tensor
+
+        return (tensor, attn_weights) if return_attn_weights else tensor
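A quick shape check for the `Attention` block added in this hunk; the sizes below are illustrative assumptions, and `einops` must be installed since `Attention.forward` imports it lazily:

```python
import torch

attn = Attention(dim=64, heads=8, head_dim=16, dropout=0.1)
x = torch.randn(2, 10, 64)               # (batch, tokens, embedding)
out, weights = attn(x, return_attn_weights=True)
assert out.shape == (2, 10, 64)          # projected back to the input dim
assert weights.shape == (2, 8, 10, 10)   # one (tokens x tokens) map per head
```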
65
modules/geometric_blocks.py
Normal file
@@ -0,0 +1,65 @@
+import torch
+from torch import nn
+from torch.nn import ReLU
+try:
+    from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
+except ImportError:
+    print('Install torch-geometric to use this package.')
+
+
+class SAModule(torch.nn.Module):
+
+    def __init__(self, ratio, r, nn):
+        super(SAModule, self).__init__()
+        self.ratio = ratio
+        self.r = r
+        self.conv = PointConv(nn)
+
+    def forward(self, x, pos, batch):
+        idx = fps(pos, batch, ratio=self.ratio)
+        row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
+                          max_num_neighbors=64)
+        edge_index = torch.stack([col, row], dim=0)
+        x = self.conv(x, (pos, pos[idx]), edge_index)
+        pos, batch = pos[idx], batch[idx]
+        return x, pos, batch
+
+
+class GlobalSAModule(nn.Module):
+    def __init__(self, nn, channels=3):
+        super(GlobalSAModule, self).__init__()
+        self.nn = nn
+        self.channels = channels
+
+    def forward(self, x, pos, batch):
+        x = self.nn(torch.cat([x, pos], dim=1))
+        x = global_max_pool(x, batch)
+        pos = pos.new_zeros((x.size(0), self.channels))
+        batch = torch.arange(x.size(0), device=batch.device)
+        return x, pos, batch
+
+
+class MLP(nn.Module):
+    def __init__(self, channels, norm=True):
+        super(MLP, self).__init__()
+        self.net = nn.Sequential(*[
+            nn.Sequential(nn.Linear(channels[i - 1], channels[i]), ReLU(), nn.BatchNorm1d(channels[i]))
+            for i in range(1, len(channels))
+        ]).double()
+
+    def forward(self, x, *args, **kwargs):
+        return self.net(x)
+
+
+class FPModule(torch.nn.Module):
+    def __init__(self, k, nn):
+        super(FPModule, self).__init__()
+        self.k = k
+        self.nn = nn
+
+    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
+        x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
+        if x_skip is not None:
+            x = torch.cat([x, x_skip], dim=1)
+        x = self.nn(x)
+        return x, pos_skip, batch_skip
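The new `SAModule`/`GlobalSAModule`/`FPModule` trio mirrors the PointNet++ set-abstraction and feature-propagation pattern from the torch-geometric examples. A hedged wiring sketch (ratios, radii and MLP widths are assumptions; note that this `MLP` casts itself to double, so point features must be double tensors as well):

```python
# Illustrative PointNet++-style encoder built from the blocks above.
sa1_module = SAModule(ratio=0.5, r=0.2, nn=MLP([3 + 3, 64, 64, 128]))       # xyz + normals
sa2_module = SAModule(ratio=0.25, r=0.4, nn=MLP([128 + 3, 128, 128, 256]))
sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]), channels=3)
# Each module maps (x, pos, batch) -> (x, pos, batch) over fewer points,
# until GlobalSAModule pools to one feature vector per cloud.
```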
@@ -1,80 +1,144 @@
 #
 # Full Model Parts
 ###################
+from argparse import Namespace
+from functools import reduce
+from typing import Union, List, Tuple
+
 import torch
+from abc import ABC
+from operator import mul
 from torch import nn
+from torch.utils.data import DataLoader
+
-from modules.utils import ShapeMixin
+from .blocks import ConvModule, DeConvModule, LinearModule
+
+from .util import ShapeMixin, LightningBaseModule, Flatten


-class Generator(nn.Module):
-    @property
-    def shape(self):
-        x = torch.randn(self.lat_dim).unsqueeze(0)
-        output = self(x)
-        return output.shape[1:]
-
-    # noinspection PyUnresolvedReferences
-    def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
-                 filters: List[int] = None, activation=nn.ReLU):
+class AEBaseModule(LightningBaseModule, ABC):
+
+    def generate_random_image(self, dataloader: Union[None, str, DataLoader] = None,
+                              lat_min: Union[Tuple, List, None] = None,
+                              lat_max: Union[Tuple, List, None] = None):
+
+        assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
+
+        min_max = self._find_min_max(dataloader) if dataloader else [None, None]
+        # assert not any([tensor is None for tensor in min_max])
+        lat_min = torch.as_tensor(lat_min or min_max[0])
+        lat_max = lat_max or min_max[1]
+
+        random_z = torch.rand((1, self.lat_dim))
+        random_z = random_z * (abs(lat_min) + lat_max) - abs(lat_min)
+
+        return self.decoder(random_z).squeeze()
+
+    def encode(self, x):
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
+        return self.encoder(x)
+
+    def _find_min_max(self, dataloader):
+        encodings = list()
+        for batch in dataloader:
+            encodings.append(self.encode(batch))
+        encodings = torch.cat(encodings, dim=0)
+        min_lat = encodings.min(dim=1)
+        max_lat = encodings.max(dim=1)
+        return min_lat, max_lat
+
+    def decode_lat_evenly(self, n: int,
+                          dataloader: Union[None, str, DataLoader] = None,
+                          lat_min: Union[Tuple, List, None] = None,
+                          lat_max: Union[Tuple, List, None] = None):
+        assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
+
+        min_max = self._find_min_max(dataloader) if dataloader else [None, None]
+
+        lat_min = lat_min or min_max[0]
+        lat_max = lat_max or min_max[1]
+
+        random_latent_samples = torch.stack([torch.linspace(lat_min[i].item(), lat_max[i].item(), n)
+                                             for i in range(self.params.lat_dim)], dim=-1).cpu().detach()
+        return self.decode(random_latent_samples).cpu().detach()
+
+    def decode(self, z):
+        try:
+            if len(z.shape) == 1:
+                z = z.unsqueeze(0)
+        except AttributeError:
+            # Does not seem to be a tensor.
+            pass
+        return self.decoder(z).squeeze()
+
+    def encode_and_restore(self, x):
+        x = self.transfer_batch_to_device(x, self.device)
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
+        z = self.encode(x)
+        try:
+            z = z.squeeze()
+        except AttributeError:
+            # Does not seem to be a tensor.
+            pass
+        x_hat = self.decode(z)
+
+        return Namespace(main_out=x_hat.squeeze(), latent_out=z)
+
+
+class Generator(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape, out_channels, re_shape, use_norm=False, use_bias=True,
+                 dropout: Union[int, float] = 0, interpolations: List[int] = None,
+                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU,
+                 **kwargs):
         super(Generator, self).__init__()
-        assert filters, '"Filters" has to be a list of int len 3'
-        self.filters = filters
-        self.activation = activation
-        self.inner_activation = activation()
+        assert filters, '"Filters" has to be a list of int.'
+        assert filters, '"Filters" has to be a list of int.'
+        kernels = kernels if kernels else [3] * len(filters)
+        assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
+
+        interpolations = interpolations or [2, 2, 2]
+
+        self.in_shape = in_shape
+        self.activation = activation()
         self.out_activation = None
-        self.lat_dim = lat_dim
         self.dropout = dropout
-        self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
+        self.l1 = LinearModule(in_shape, reduce(mul, re_shape), bias=use_bias, activation=activation)
         # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])

-        self.flat = Flatten(to=re_shape)
+        self.flat = Flatten(self.l1.shape, to=re_shape)
+        self.de_conv_list = nn.ModuleList()

-        self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
-                                    conv_kernel=5,
-                                    conv_padding=2,
-                                    conv_stride=1,
-                                    normalize=use_norm,
-                                    activation=self.activation,
-                                    interpolation_scale=2,
-                                    dropout=self.dropout
-                                    )
-
-        self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
-                                    conv_kernel=3,
-                                    conv_padding=1,
-                                    conv_stride=1,
-                                    normalize=use_norm,
-                                    activation=self.activation,
-                                    interpolation_scale=2,
-                                    dropout=self.dropout
-                                    )
-
-        self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
-                                    conv_kernel=3,
-                                    conv_padding=1,
-                                    conv_stride=1,
-                                    normalize=use_norm,
-                                    activation=self.activation,
-                                    interpolation_scale=2,
-                                    dropout=self.dropout
-                                    )
-
-        self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
-                                    conv_kernel=3,
-                                    conv_padding=1,
-                                    # normalize=norm,
-                                    activation=self.out_activation
-                                    )
+        last_shape = re_shape
+        for conv_filter, conv_kernel, interpolation in zip(reversed(filters), kernels, interpolations):
+            # noinspection PyTypeChecker
+            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=conv_filter,
+                                                  conv_kernel=conv_kernel,
+                                                  conv_padding=conv_kernel-2,
+                                                  conv_stride=1,
+                                                  normalize=use_norm,
+                                                  activation=activation,
+                                                  interpolation_scale=interpolation,
+                                                  dropout=self.dropout
+                                                  )
+                                     )
+            last_shape = self.de_conv_list[-1].shape
+
+        self.de_conv_out = DeConvModule(self.de_conv_list[-1].shape, conv_filters=out_channels, conv_kernel=3,
+                                        conv_padding=1, activation=self.out_activation
+                                        )

     def forward(self, z):
         tensor = self.l1(z)
-        tensor = self.inner_activation(tensor)
+        tensor = self.activation(tensor)
         tensor = self.flat(tensor)
-        tensor = self.deconv1(tensor)
-        tensor = self.deconv2(tensor)
-        tensor = self.deconv3(tensor)
-        tensor = self.deconv4(tensor)
+
+        for de_conv in self.de_conv_list:
+            tensor = de_conv(tensor)
+
+        tensor = self.de_conv_out(tensor)
         return tensor

     def size(self):
@@ -114,18 +178,17 @@ class UnitGenerator(Generator):
         return tensor


-class BaseEncoder(ShapeMixin, nn.Module):
+class BaseCNNEncoder(ShapeMixin, nn.Module):

     # noinspection PyUnresolvedReferences
     def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
                  latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
-                 filters: List[int] = None):
-        super(BaseEncoder, self).__init__()
-        assert filters, '"Filters" has to be a list of int len 3'
-
-        # Optional Padding for odd image-sizes
-        # Obsolet, already Done by autopadding module on incoming tensors
-        # in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
+                 filters: List[int] = None, kernels: Union[List[int], int, None] = None, **kwargs):
+        super(BaseCNNEncoder, self).__init__()
+        assert filters, '"Filters" has to be a list of int'
+        kernels = kernels or [3] * len(filters)
+        kernels = kernels if not isinstance(kernels, int) else [kernels] * len(filters)
+        assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'

         # Parameters
         self.lat_dim = lat_dim
@@ -133,52 +196,39 @@ class BaseEncoder(ShapeMixin, nn.Module):
         self.use_bias = use_bias
         self.latent_activation = latent_activation() if latent_activation else None

+        self.conv_list = nn.ModuleList()
+
         # Modules
-        self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
-                                conv_kernel=3,
-                                conv_padding=1,
-                                conv_stride=1,
-                                pooling_size=2,
-                                use_norm=use_norm,
-                                dropout=dropout,
-                                activation=activation
-                                )
-
-        self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
-                                conv_kernel=3,
-                                conv_padding=1,
-                                conv_stride=1,
-                                pooling_size=2,
-                                use_norm=use_norm,
-                                dropout=dropout,
-                                activation=activation
-                                )
-
-        self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
-                                conv_kernel=5,
-                                conv_padding=2,
-                                conv_stride=1,
-                                pooling_size=2,
-                                use_norm=use_norm,
-                                dropout=dropout,
-                                activation=activation
-                                )
-
-        self.flat = Flatten()
+        last_shape = self.in_shape
+        for conv_filter, conv_kernel in zip(filters, kernels):
+            self.conv_list.append(ConvModule(last_shape, conv_filters=conv_filter,
+                                             conv_kernel=conv_kernel,
+                                             conv_padding=conv_kernel-2,
+                                             conv_stride=1,
+                                             pooling_size=2,
+                                             use_norm=use_norm,
+                                             dropout=dropout,
+                                             activation=activation
+                                             )
+                                  )
+            last_shape = self.conv_list[-1].shape
+        self.last_conv_shape = last_shape
+
+        self.flat = Flatten(self.last_conv_shape)

     def forward(self, x):
-        tensor = self.conv1(x)
-        tensor = self.conv2(tensor)
-        tensor = self.conv3(tensor)
+        tensor = x
+        for conv in self.conv_list:
+            tensor = conv(tensor)
         tensor = self.flat(tensor)
         return tensor


-class UnitEncoder(BaseEncoder):
+class UnitCNNEncoder(BaseCNNEncoder):
     # noinspection PyUnresolvedReferences
     def __init__(self, *args, **kwargs):
         kwargs.update(use_norm=True)
-        super(UnitEncoder, self).__init__(*args, **kwargs)
+        super(UnitCNNEncoder, self).__init__(*args, **kwargs)
         self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)

     def forward(self, x):
@@ -190,10 +240,10 @@ class UnitEncoder(BaseEncoder):
         return c1, c2, c3, l1


-class VariationalEncoder(BaseEncoder):
+class VariationalCNNEncoder(BaseCNNEncoder):
     # noinspection PyUnresolvedReferences
     def __init__(self, *args, **kwargs):
-        super(VariationalEncoder, self).__init__(*args, **kwargs)
+        super(VariationalCNNEncoder, self).__init__(*args, **kwargs)

         self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
         self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
@@ -205,22 +255,22 @@ class VariationalEncoder(BaseEncoder):
         return mu + eps*std

     def forward(self, x):
-        tensor = super(VariationalEncoder, self).forward(x)
+        tensor = super(VariationalCNNEncoder, self).forward(x)
         mu = self.mu(tensor)
         logvar = self.logvar(tensor)
         z = self.reparameterize(mu, logvar)
         return mu, logvar, z


-class Encoder(BaseEncoder):
-    # noinspection PyUnresolvedReferences
-    def __init__(self, *args, **kwargs):
-        super(Encoder, self).__init__(*args, **kwargs)
-
-        self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
+class CNNEncoder(BaseCNNEncoder):
+
+    def __init__(self, *args, **kwargs):
+        super(CNNEncoder, self).__init__(*args, **kwargs)
+
+        self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)

     def forward(self, x):
-        tensor = super(Encoder, self).forward(x)
+        tensor = super(CNNEncoder, self).forward(x)
         tensor = self.l1(tensor)
         tensor = self.latent_activation(tensor) if self.latent_activation else tensor
         return tensor
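With this refactor, the encoder depth is driven entirely by the `filters`/`kernels` lists instead of three hard-coded ConvModules. A hedged instantiation sketch (shapes and filter counts are illustrative assumptions):

```python
import torch
from torch import nn

encoder = CNNEncoder(in_shape=(3, 64, 64), lat_dim=128,
                     filters=[16, 32, 64], kernels=3,   # int kernels are broadcast per layer
                     activation=nn.ELU, use_norm=False)
z = encoder(torch.randn(8, 3, 64, 64))  # -> (8, 128) latent codes
```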
420
modules/util.py
Normal file
@@ -0,0 +1,420 @@
+from functools import reduce
+from matplotlib import pyplot as plt
+
+from abc import ABC
+from pathlib import Path
+
+import torch
+from operator import mul
+from pytorch_lightning.utilities import argparse_utils
+from torch import nn
+from torch.nn import functional as F, Unfold
+
+from sklearn.metrics import ConfusionMatrixDisplay
+
+# Utility - Modules
+###################
+from ..metrics.binary_class_classifictaion import BinaryScores
+from ..metrics.multi_class_classification import MultiClassScores
+from ..utils.model_io import ModelParameters
+from ..utils.tools import add_argparse_args
+
+try:
+    import pytorch_lightning as pl
+
+    class PLMetrics(pl.metrics.Metric):
+
+        def __init__(self, n_classes, tag=''):
+            super(PLMetrics, self).__init__()
+
+            self.n_classes = n_classes
+            self.tag = tag
+
+            self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False,)
+            self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False,
+                                                  is_multiclass=True)
+            self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False,
+                                            is_multiclass=True)
+            self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False)
+            # self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
+            # self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
+            # self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False)
+            if self.n_classes > 2:
+                self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
+                self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)
+
+        def __iter__(self):
+            return iter(((name, metric) for name, metric in self._modules.items()))
+
+        def update(self, preds, target) -> None:
+            for _, metric in self:
+                try:
+                    if self.n_classes <= 2:
+                        metric.update(preds, target)
+                    else:
+                        metric.update(preds, target)
+                except ValueError:
+                    print(f'error was: {ValueError}')
+                    print(f'Metric is: {metric}')
+                    print(f'Shape is: preds - {preds.squeeze().shape}, target - {target.shape}')
+                    metric.update(preds.squeeze(), target)
+                except AssertionError:
+                    print(f'error was: {AssertionError}')
+                    print(f'Metric is: {metric}')
+                    print(f'Shape is: preds - {preds.shape}, target - {target.unsqueeze(-1).shape}')
+                    metric.update(preds, target.unsqueeze(-1))
+
+        def reset(self) -> None:
+            for _, metric in self:
+                metric.reset()
+
+        def compute(self) -> dict:
+            tag = f'{self.tag}_' if self.tag else ''
+            return {f'{tag}{metric_name}_score': metric.compute() for metric_name, metric in self}
+
+        def compute_and_prepare(self):
+            pl_metrics = self.compute()
+            images_from_metrics = dict()
+            for metric_name in list(pl_metrics.keys()):
+                if 'curve' in metric_name:
+                    continue
+                    roc_curve = pl_metrics.pop(metric_name)
+                    print('debug_point')
+
+                elif 'matrix' in metric_name:
+                    matrix = pl_metrics.pop(metric_name)
+                    fig1, ax1 = plt.subplots(dpi=96)
+                    disp = ConfusionMatrixDisplay(confusion_matrix=matrix.cpu().numpy(),
+                                                  display_labels=[i for i in range(self.n_classes)]
+                                                  )
+                    disp.plot(include_values=True, ax=ax1)
+                    images_from_metrics[metric_name] = fig1
+
+                elif 'ROC' in metric_name:
+                    continue
+                    roc = pl_metrics.pop(metric_name)
+                    print('debug_point')
+                else:
+                    pl_metrics[metric_name] = pl_metrics[metric_name].cpu().item()
+
+            return pl_metrics, images_from_metrics
+
+    class LightningBaseModule(pl.LightningModule, ABC):
+
+        @classmethod
+        def name(cls):
+            return cls.__name__
+
+        @property
+        def shape(self):
+            try:
+                x = torch.randn(self.in_shape).unsqueeze(0)
+                output = self(x)
+                return output.shape[1:]
+            except Exception as e:
+                print(e)
+                return -1
+
+        @classmethod
+        def from_argparse_args(cls, args, **kwargs):
+            return argparse_utils.from_argparse_args(cls, args, **kwargs)
+
+        @classmethod
+        def add_argparse_args(cls, parent_parser):
+            return add_argparse_args(cls, parent_parser)
+
+        def __init__(self, model_parameters, weight_init='xavier_normal_'):
+            super(LightningBaseModule, self).__init__()
+            self._weight_init = weight_init
+            self.params = ModelParameters(model_parameters)
+
+            if hasattr(self.params, 'n_classes'):
+                self.metrics = PLMetrics(self.params.n_classes, tag='PL')
+            else:
+                pass
+
+        def size(self):
+            return self.shape
+
+        def save_to_disk(self, model_path):
+            Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
+            if not (model_path / 'model_class.obj').exists():
+                with (model_path / 'model_class.obj').open('wb') as f:
+                    torch.save(self.__class__, f)
+            return True
+
+        @property
+        def data_len(self):
+            return len(self.dataset.train_dataset)
+
+        @property
+        def n_train_batches(self):
+            return len(self.train_dataloader())
+
+        def configure_optimizers(self):
+            raise NotImplementedError
+
+        def forward(self, *args, **kwargs):
+            raise NotImplementedError
+
+        def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+            raise NotImplementedError
+
+        def test_step(self, *args, **kwargs):
+            raise NotImplementedError
+
+        def test_epoch_end(self, outputs):
+            raise NotImplementedError
+
+        def init_weights(self):
+            if isinstance(self._weight_init, str):
+                mod = __import__('torch.nn.init', fromlist=[self._weight_init])
+                self._weight_init = getattr(mod, self._weight_init)
+            assert callable(self._weight_init)
+            weight_initializer = WeightInit(in_place_init_function=self._weight_init)
+            self.apply(weight_initializer)
+
+        def additional_scores(self, outputs):
+            if self.params.n_classes > 2:
+                return MultiClassScores(self)(outputs)
+            else:
+                return BinaryScores(self)(outputs)
+
+    module_types = (LightningBaseModule, nn.Module,)
+
+except ImportError:
+    module_types = (nn.Module,)
+    pl = None
+    pass  # Maybe post a hint to install pytorch-lightning.
+
+
+class ShapeMixin:
+
+    @property
+    def shape(self):
+
+        assert isinstance(self, module_types)
+
+        def get_out_shape(output):
+            return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
+
+        in_shape = self.in_shape if hasattr(self, 'in_shape') else None
+        if in_shape is not None:
+            try:
+                device = self.device
+            except AttributeError:
+                try:
+                    device = next(self.parameters()).device
+                except StopIteration:
+                    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            x = torch.randn(in_shape, device=device)
+            # This is needed for BatchNorm shape checking
+            x = torch.stack((x, x))
+
+            # noinspection PyCallingNonCallable
+            y = self(x)
+            if isinstance(y, tuple):
+                shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
+            else:
+                shape = get_out_shape(y)
+            return shape
+        else:
+            return -1
+
+    @property
+    def flat_shape(self):
+        shape = self.shape
+        try:
+            return reduce(mul, shape)
+        except TypeError:
+            return shape
+
+
+class F_x(ShapeMixin, nn.Identity):
+    def __init__(self, in_shape):
+        super(F_x, self).__init__()
+        self.in_shape = in_shape
+
+
+class SlidingWindow(ShapeMixin, nn.Module):
+    def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False):
+        super(SlidingWindow, self).__init__()
+        self.in_shape = in_shape
+        self.kernel = kernel if not isinstance(kernel, int) else (kernel, kernel)
+        self.padding = padding
+        self.stride = stride
+        self.keepdim = keepdim
+        self._unfolder = Unfold(self.kernel, dilation=1, padding=self.padding, stride=self.stride)
+
+    def forward(self, x):
+        tensor = self._unfolder(x)
+        tensor = tensor.transpose(-1, -2)
+        if self.keepdim:
+            shape = *x.shape[:2], -1, *self.kernel
+            tensor = tensor.reshape(shape)
+        return tensor
+
+
+# Utility - Modules
+###################
+class Flatten(ShapeMixin, nn.Module):
+
+    def __init__(self, in_shape, to=-1):
+        assert isinstance(to, int) or isinstance(to, tuple)
+        super(Flatten, self).__init__()
+        self.in_shape = in_shape
+        self.to = (to,) if isinstance(to, int) else to
+
+    def forward(self, x):
+        return x.view(x.size(0), *self.to)
+
+
+class Interpolate(nn.Module):
+    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
+        super(Interpolate, self).__init__()
+        self.interp = nn.functional.interpolate
+        self.size = size
+        self.scale_factor = scale_factor
+        self.align_corners = align_corners
+        self.mode = mode
+
+    def forward(self, x):
+        x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
+                        mode=self.mode, align_corners=self.align_corners)
+        return x
+
+
+class AutoPad(nn.Module):
+
+    def __init__(self, interpolations=3, base=2):
+        super(AutoPad, self).__init__()
+        self.fct = base ** interpolations
+
+    def forward(self, x):
+        # noinspection PyUnresolvedReferences
+        x = F.pad(x,
+                  [0,
+                   (x.shape[-1] // self.fct + 1) * self.fct - x.shape[-1] if x.shape[-1] % self.fct != 0 else 0,
+                   (x.shape[-2] // self.fct + 1) * self.fct - x.shape[-2] if x.shape[-2] % self.fct != 0 else 0,
+                   0])
+        return x
+
+
+class WeightInit:
+
+    def __init__(self, in_place_init_function):
+        self.in_place_init_function = in_place_init_function
+
+    def __call__(self, m):
+        if hasattr(m, 'weight'):
+            if isinstance(m.weight, torch.Tensor):
+                if m.weight.ndim < 2:
+                    m.weight.data.fill_(0.01)
+                else:
+                    self.in_place_init_function(m.weight)
+        if hasattr(m, 'bias'):
+            if isinstance(m.bias, torch.Tensor):
+                m.bias.data.fill_(0.01)
+
+
+class Filter(nn.Module, ShapeMixin):
+
+    def __init__(self, in_shape, pos, dim=-1):
+        super(Filter, self).__init__()
+
+        self.in_shape = in_shape
+        self.pos = pos
+        self.dim = dim
+        raise SystemError('Do not use this Module - broken.')
+
+    @staticmethod
+    def forward(x):
+        tensor = x[:, -1]
+        return tensor
+
+
+class FlipTensor(nn.Module):
+    def __init__(self, dim=-2):
+        super(FlipTensor, self).__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
+        idx = torch.as_tensor(idx).long()
+        inverted_tensor = x.index_select(self.dim, idx)
+        return inverted_tensor
+
+
+class AutoPadToShape(nn.Module):
+    def __init__(self, target_shape):
+        super(AutoPadToShape, self).__init__()
+        self.target_shape = target_shape
+
+    def forward(self, x):
+        if not torch.is_tensor(x):
+            x = torch.as_tensor(x)
+        if x.shape[-len(self.target_shape):] == self.target_shape or x.shape == self.target_shape:
+            return x
+
+        idx = [0] * (len(self.target_shape) * 2)
+        for i, j in zip(range(-1, -(len(self.target_shape)+1), -1), range(0, len(idx), 2)):
+            idx[j] = self.target_shape[i] - x.shape[i]
+        x = torch.nn.functional.pad(x, idx)
+        return x
+
+    def __repr__(self):
+        return f'AutoPadTransform({self.target_shape})'
+
+
+class Splitter(nn.Module):
+
+    @property
+    def shape(self):
+        return tuple([self._out_shape] * self.n)
+
+    @property
+    def out_shape(self):
+        return self._out_shape
+
+    def __init__(self, in_shape, n, dim=-1):
+        super(Splitter, self).__init__()
+
+        self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
+        self.n = n
+        self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
+
+        self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
+        self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
+
+        self.autopad = AutoPadToShape(self._out_shape)
+
+    def forward(self, x: torch.Tensor):
+        dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
+        x = x.transpose(0, dim)
+        n_blocks = list()
+        for block_idx in range(self.n):
+            start = block_idx * self.new_dim_size
+            end = (block_idx + 1) * self.new_dim_size
+            block = x[start:end].transpose(0, dim)
+            block = self.autopad(block)
+            n_blocks.append(block)
+        return n_blocks
+
+
+class Merger(nn.Module, ShapeMixin):
+
+    @property
+    def shape(self):
+        y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)])
+        return y.shape
+
+    def __init__(self, in_shape, n, dim=-1):
+        super(Merger, self).__init__()
+
+        self.n = n
+        self.dim = dim
+        self.in_shape = in_shape
+
+    def forward(self, x):
+        return torch.cat(x, dim=self.dim)
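Of the utilities added here, `Splitter` and `Merger` are designed as a round-trip pair: `Splitter` cuts one dimension into `n` auto-padded blocks, and `Merger` concatenates them back. A hedged round-trip sketch (sizes are illustrative assumptions):

```python
import torch

splitter = Splitter(in_shape=(16, 9), n=3, dim=-1)   # 9 columns -> 3 blocks of 3
merger = Merger(in_shape=splitter.out_shape, n=3, dim=-1)

x = torch.randn(4, 16, 9)
blocks = splitter(x)    # list of 3 tensors, each (4, 16, 3)
y = merger(blocks)      # concatenated back to (4, 16, 9)
assert y.shape == x.shape
```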
256
modules/utils.py
@@ -1,256 +0,0 @@
-from abc import ABC
-from pathlib import Path
-
-import torch
-from torch import nn
-from torch import functional as F
-
-import pytorch_lightning as pl
-
-
-# Utility - Modules
-###################
-from utils.model_io import ModelParameters
-
-
-class ShapeMixin:
-
-    @property
-    def shape(self):
-        assert isinstance(self, (LightningBaseModule, nn.Module))
-        if self.in_shape is not None:
-            x = torch.randn(self.in_shape)
-            # This is needed for BatchNorm shape checking
-            x = torch.stack((x, x))
-            output = self(x)
-            return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
-        else:
-            return -1
-
-
-class F_x(ShapeMixin, nn.Module):
-    def __init__(self, in_shape):
-        super(F_x, self).__init__()
-        self.in_shape = in_shape
-
-    def forward(self, x):
-        return x
-
-
-# Utility - Modules
-###################
-class Flatten(ShapeMixin, nn.Module):
-
-    def __init__(self, in_shape, to=-1):
-        assert isinstance(to, int) or isinstance(to, tuple)
-        super(Flatten, self).__init__()
-        self.in_shape = in_shape
-        self.to = (to,) if isinstance(to, int) else to
-
-    def forward(self, x):
-        return x.view(x.size(0), *self.to)
-
-
-class Interpolate(nn.Module):
-    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
-        super(Interpolate, self).__init__()
-        self.interp = nn.functional.interpolate
-        self.size = size
-        self.scale_factor = scale_factor
-        self.align_corners = align_corners
-        self.mode = mode
-
-    def forward(self, x):
-        x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
-                        mode=self.mode, align_corners=self.align_corners)
-        return x
-
-
-class AutoPad(nn.Module):
-
-    def __init__(self, interpolations=3, base=2):
-        super(AutoPad, self).__init__()
-        self.fct = base ** interpolations
-
-    def forward(self, x):
-        # noinspection PyUnresolvedReferences
-        x = F.pad(x,
-                  [0,
-                   (x.shape[-1] // self.fct + 1) * self.fct - x.shape[-1] if x.shape[-1] % self.fct != 0 else 0,
-                   (x.shape[-2] // self.fct + 1) * self.fct - x.shape[-2] if x.shape[-2] % self.fct != 0 else 0,
-                   0])
-        return x
-
-
-class WeightInit:
-
-    def __init__(self, in_place_init_function):
-        self.in_place_init_function = in_place_init_function
-
-    def __call__(self, m):
-        if hasattr(m, 'weight'):
-            if isinstance(m.weight, torch.Tensor):
-                if m.weight.ndim < 2:
-                    m.weight.data.fill_(0.01)
-                else:
-                    self.in_place_init_function(m.weight)
-        if hasattr(m, 'bias'):
-            if isinstance(m.bias, torch.Tensor):
-                m.bias.data.fill_(0.01)
-
-
-class LightningBaseModule(pl.LightningModule, ABC):
-
-    @classmethod
-    def name(cls):
-        return cls.__name__
-
-    @property
-    def shape(self):
-        try:
-            x = torch.randn(self.in_shape).unsqueeze(0)
-            output = self(x)
-            return output.shape[1:]
-        except Exception as e:
-            print(e)
-            return -1
-
-    def __init__(self, hparams):
-        super(LightningBaseModule, self).__init__()
-
-        # Set Parameters
-        ################################
-        self.hparams = hparams
-        self.params = ModelParameters(hparams)
-
-        # Dataset Loading
-        ################################
-        # TODO: Find a way to push Class Name, library path and parameters (sometimes thiose are objects) in here
-
-    def size(self):
-        return self.shape
-
-    def save_to_disk(self, model_path):
-        Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
-        if not (model_path / 'model_class.obj').exists():
-            with (model_path / 'model_class.obj').open('wb') as f:
-                torch.save(self.__class__, f)
-        return True
-
-    @property
-    def data_len(self):
-        return len(self.dataset.train_dataset)
-
-    @property
-    def n_train_batches(self):
-        return len(self.train_dataloader())
-
-    def configure_optimizers(self):
-        raise NotImplementedError
-
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        raise NotImplementedError
-
-    def test_step(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def test_epoch_end(self, outputs):
-        raise NotImplementedError
-
-    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
-        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
-        self.apply(weight_initializer)
-
-
-class FilterLayer(nn.Module):
-
-    def __init__(self):
-        super(FilterLayer, self).__init__()
-
-    def forward(self, x):
-        tensor = x[:, -1]
-        return tensor
-
-
-class MergingLayer(nn.Module):
-
-    def __init__(self):
-        super(MergingLayer, self).__init__()
-
-    def forward(self, x):
-        # ToDo: Which ones to combine?
-        return
-
-
-class FlipTensor(nn.Module):
-    def __init__(self, dim=-2):
-        super(FlipTensor, self).__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
-        idx = torch.as_tensor(idx).long()
-        inverted_tensor = x.index_select(self.dim, idx)
-        return inverted_tensor
-
-
-class AutoPadToShape(object):
-    def __init__(self, shape):
-        self.shape = shape
-
-    def __call__(self, x):
-        if not torch.is_tensor(x):
-            x = torch.as_tensor(x)
-        if x.shape[1:] == self.shape:
-            return x
-        embedding = torch.zeros((x.shape[0], *self.shape))
-        embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x
-        return embedding
-
-    def __repr__(self):
-        return f'AutoPadTransform({self.shape})'
-
-
-class HorizontalSplitter(nn.Module):
-
-    def __init__(self, in_shape, n):
-        super(HorizontalSplitter, self).__init__()
-        assert len(in_shape) == 3
-        self.n = n
-        self.in_shape = in_shape
-
-        self.channel, self.height, self.width = self.in_shape
-        self.new_height = (self.height // self.n) + (1 if self.height % self.n != 0 else 0)
-
-        self.shape = (self.channel, self.new_height, self.width)
-        self.autopad = AutoPadToShape(self.shape)
-
-    def forward(self, x):
-        n_blocks = list()
-        for block_idx in range(self.n):
-            start = block_idx * self.new_height
-            end = (block_idx + 1) * self.new_height
-            block = self.autopad(x[:, :, start:end, :])
-            n_blocks.append(block)
-
-        return n_blocks
-
-
-class HorizontalMerger(nn.Module):
-
-    @property
-    def shape(self):
-        merged_shape = self.in_shape[0], self.in_shape[1] * self.n, self.in_shape[2]
-        return merged_shape
-
-    def __init__(self, in_shape, n):
-        super(HorizontalMerger, self).__init__()
-        assert len(in_shape) == 3
-        self.n = n
-        self.in_shape = in_shape
-
-    def forward(self, x):
-        return torch.cat(x, dim=-2)
0
point_toolset/__init__.py
Normal file
BIN
point_toolset/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
point_toolset/__pycache__/point_io.cpython-37.pyc
Normal file
Binary file not shown.
37
point_toolset/point_io.py
Normal file
@ -0,0 +1,37 @@
```python
import torch
from torch_geometric.data import Data


class BatchToData(object):
    def __init__(self, transforms=None):
        super(BatchToData, self).__init__()
        self.transforms = transforms if transforms else lambda x: x

    def __call__(self, batch_dict):
        # Convert to torch_geometric.data.Data type

        batch_pos = batch_dict['pos']
        batch_norm = batch_dict.get('norm', None)
        batch_y = batch_dict.get('y', None)
        batch_y_c = batch_dict.get('y_c', None)

        batch_size, num_points, _ = batch_pos.shape  # (batch_size, num_points, 3)

        pos = batch_pos.view(batch_size * num_points, -1)
        norm = batch_norm.view(batch_size * num_points, -1) if batch_norm is not None else batch_norm

        batch_y_l = batch_y.view(batch_size * num_points, -1) if batch_y is not None else batch_y
        batch_y_c = batch_y_c.view(batch_size * num_points, -1) if batch_y_c is not None else batch_y_c

        batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
        for i in range(batch_size):
            batch[i] = i
        batch = batch.view(-1)

        data = Data()
        data.norm, data.pos, data.batch, data.yl, data.yc = norm, pos, batch, batch_y_l, batch_y_c

        data = self.transforms(data)

        return data
```
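For orientation, an illustrative call; the dict keys match what `__call__` reads above, while the shapes and values are made up:

```python
# BatchToData flattens a dense (batch, points, 3) batch into PyG's node-major layout.
import torch

batch_dict = {
    'pos':  torch.rand(4, 1024, 3),                 # (batch_size, num_points, 3)
    'norm': torch.rand(4, 1024, 3),
    'y':    torch.zeros(4, 1024, dtype=torch.long),
}

data = BatchToData()(batch_dict)
print(data.pos.shape)    # torch.Size([4096, 3])
print(data.batch.shape)  # torch.Size([4096]); sample index per point
```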
26 point_toolset/point_transforms.py Normal file
@@ -0,0 +1,26 @@
```python
import torch
from torch_geometric.transforms import NormalizeScale


class NormalizePositions(NormalizeScale):

    def __init__(self):
        super(NormalizePositions, self).__init__()

    def __call__(self, data):
        if torch.isnan(data.pos).any():
            print('debug')

        data = self.center(data)
        if torch.isnan(data.pos).any():
            print('debug')

        scale = (1 / data.pos.abs().max()) * 0.999999
        if torch.isnan(scale).any() or torch.isinf(scale).any():
            print('debug')

        data.pos = data.pos * scale
        if torch.isnan(data.pos).any():
            print('debug')

        return data
```
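A quick sketch of the effect, with made-up coordinates: the transform centers the cloud and scales it just inside the unit cube (the `0.999999` factor keeps the maximum strictly below 1).

```python
import torch
from torch_geometric.data import Data

data = Data(pos=torch.tensor([[0.0, 0.0, 0.0], [10.0, 5.0, 2.5]]))
data = NormalizePositions()(data)
print(data.pos.abs().max())  # just below 1.0
```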
51 point_toolset/sampling.py Normal file
@@ -0,0 +1,51 @@
```python
from abc import ABC

import numpy as np


class _Sampler(ABC):

    def __init__(self, K, **kwargs):
        self.k = K
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        raise NotImplementedError


class RandomSampling(_Sampler):

    def __init__(self, *args, **kwargs):
        super(RandomSampling, self).__init__(*args, **kwargs)

    def __call__(self, pts, *args, **kwargs):
        rnd_indexs = np.random.choice(np.arange(pts.shape[0]), min(self.k, pts.shape[0]), replace=False)
        return rnd_indexs


class FarthestpointSampling(_Sampler):

    def __init__(self, *args, **kwargs):
        super(FarthestpointSampling, self).__init__(*args, **kwargs)

    @staticmethod
    def calc_distances(p0, points):
        return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)

    def __call__(self, pts, *args, **kwargs):

        if pts.shape[0] < self.k:
            # Not enough points to subsample: return all indices, so callers
            # always receive indices (consistent with the branch below).
            return np.arange(pts.shape[0])

        else:
            farthest_pts = np.zeros((self.k, pts.shape[1]))
            farthest_pts_idx = np.zeros(self.k, dtype=int)  # plain int; np.int is deprecated
            farthest_pts[0] = pts[np.random.randint(len(pts))]
            distances = self.calc_distances(farthest_pts[0], pts)
            for i in range(1, self.k):
                farthest_pts_idx[i] = np.argmax(distances)
                farthest_pts[i] = pts[farthest_pts_idx[i]]

                distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))

            return farthest_pts_idx
```
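Both samplers return index arrays, so subsampling is a single fancy-indexing step. A usage sketch on random toy data:

```python
import numpy as np

pts = np.random.rand(2048, 6)        # e.g. xyz + normals per point

fps = FarthestpointSampling(K=1024)
sub_cloud = pts[fps(pts)]            # 1024 points, greedily spread over the cloud

rnd = RandomSampling(K=1024)
sub_cloud_rnd = pts[rnd(pts)]        # 1024 uniformly drawn points
```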
```diff
@@ -1,92 +1,8 @@
-absl-py==0.9.0
-appdirs==1.4.3
-attrs==19.3.0
-audioread==2.1.8
-bravado==10.6.0
-bravado-core==5.17.0
-CacheControl==0.12.6
-cachetools==4.1.0
-certifi==2019.11.28
-cffi==1.14.0
-chardet==3.0.4
-click==7.1.2
-colorama==0.4.3
-contextlib2==0.6.0
-cycler==0.10.0
-decorator==4.4.2
-distlib==0.3.0
-distro==1.4.0
-future==0.18.2
-gitdb==4.0.5
-GitPython==3.1.2
-google-auth==1.14.3
-google-auth-oauthlib==0.4.1
-grpcio==1.29.0
-html5lib==1.0.1
-idna==2.8
-ipaddr==2.2.0
-joblib==0.15.1
-jsonpointer==2.0
-jsonref==0.2
-jsonschema==3.2.0
-kiwisolver==1.2.0
-librosa==0.7.2
-llvmlite==0.32.1
-lockfile==0.12.2
-Markdown==3.2.2
-matplotlib==3.2.1
-monotonic==1.5
-msgpack==0.6.2
-msgpack-python==0.5.6
 natsort==7.0.1
-neptune-client==0.4.113
+neptune-client==0.4.109
-numba==0.49.1
-numpy==1.18.4
-oauthlib==3.1.0
-packaging==20.3
-pandas==1.0.3
-pep517==0.8.2
-Pillow==7.1.2
-progress==1.5
-protobuf==3.12.0
-py3nvml==0.2.6
-pyasn1==0.4.8
-pyasn1-modules==0.2.8
-pycparser==2.20
-PyJWT==1.7.1
-pyparsing==2.4.6
-pyrsistent==0.16.0
-python-dateutil==2.8.1
-pytoml==0.1.21
 pytorch-lightning==0.7.6
-pytz==2020.1
-PyYAML==5.3.1
-requests==2.22.0
-requests-oauthlib==1.3.0
-resampy==0.2.2
-retrying==1.3.3
-rfc3987==1.3.8
-rsa==4.0
-scikit-learn==0.23.0
-scipy==1.4.1
-simplejson==3.17.0
-six==1.14.0
-smmap==3.0.4
-SoundFile==0.10.3.post1
-strict-rfc3339==0.7
-swagger-spec-validator==2.5.0
-tensorboard==2.2.1
-tensorboard-plugin-wit==1.6.0.post3
-threadpoolctl==2.0.0
-torch==1.5.0+cu101
-torchvision==0.6.0+cu101
-tqdm==4.46.0
-typing-extensions==3.7.4.2
-urllib3==1.25.8
-webcolors==1.11.1
-webencodings==0.5.1
-websocket-client==0.57.0
-Werkzeug==1.0.1
-xmltodict==0.12.0
-torchcontrib~=0.0.2
+test-tube==0.7.5
+torch==1.4.0
+torchcontrib==0.0.2
+torchvision==0.5.0
+tqdm==4.45.0
```
BIN utils/__pycache__/__init__.cpython-37.pyc Normal file
Binary file not shown.
BIN utils/__pycache__/config.cpython-37.pyc Normal file
Binary file not shown.
BIN utils/__pycache__/model_io.cpython-37.pyc Normal file
Binary file not shown.
BIN utils/__pycache__/tools.cpython-37.pyc Normal file
Binary file not shown.
36 utils/_basedatamodule.py Normal file
@@ -0,0 +1,36 @@
```python
from pytorch_lightning import LightningDataModule

# Dataset Options
from ml_lib.utils.tools import add_argparse_args

DATA_OPTION_test = 'test'
DATA_OPTION_devel = 'devel'
DATA_OPTION_train = 'train'
DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]


class _BaseDataModule(LightningDataModule):

    @property
    def shape(self):
        return self.datasets[DATA_OPTION_train].sample_shape

    @classmethod
    def add_argparse_args(cls, parent_parser):
        return add_argparse_args(cls, parent_parser)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.datasets = dict()

    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, list):
            for idx, item in enumerate(batch):
                try:
                    batch[idx] = item.to(device)
                except (AttributeError, RuntimeError):
                    continue
            return batch
        else:
            return batch.to(device)
```
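Subclasses are expected to populate `self.datasets` per split so that `shape` can read the train entry. A hypothetical subclass shows the wiring; `MyDataModule` and the dataset construction are placeholders, not part of ml_lib:

```python
class MyDataModule(_BaseDataModule):

    def __init__(self, my_dataset_cls, *args, **kwargs):
        super().__init__()
        # one dataset per split; each must expose a `sample_shape` property
        for option in DATA_OPTIONS:
            self.datasets[option] = my_dataset_cls(option)
```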
28 utils/callbacks.py Normal file
@@ -0,0 +1,28 @@
```python
import torch
from pytorch_lightning import Callback, Trainer, LightningModule


class BestScoresCallback(Callback):

    def __init__(self, *monitors) -> None:
        super().__init__()
        self.monitors = list(*monitors)

        self.best_scores = {monitor: 0.0 for monitor in self.monitors}
        self.best_epoch = {monitor: 0 for monitor in self.monitors}

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        epoch = pl_module.current_epoch

        for monitor in self.best_scores.keys():
            current_score = trainer.callback_metrics.get(monitor)
            if current_score is None:
                pass
            elif torch.isinf(current_score):
                pass
            elif torch.isnan(current_score):
                pass
            else:
                self.best_scores[monitor] = max(self.best_scores[monitor], current_score)
                if self.best_scores[monitor] == current_score:
                    self.best_epoch[monitor] = max(self.best_epoch[monitor], epoch)
```
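A sketch of attaching the callback and reading the tracked maxima afterwards; the monitor names are examples and must match keys that actually appear in `trainer.callback_metrics`:

```python
from pytorch_lightning import Trainer

best_cb = BestScoresCallback(['val_accuracy', 'val_f1'])
trainer = Trainer(max_epochs=10, callbacks=[best_cb])
# after trainer.fit(model):
#   best_cb.best_scores['val_accuracy']  -> best value seen at any validation end
#   best_cb.best_epoch['val_accuracy']   -> epoch in which it occurred
```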
148 utils/config.py
```diff
@@ -1,13 +1,84 @@
 import ast
+import configparser
+from distutils.util import strtobool
+from pathlib import Path
+from typing import Mapping, Dict
+
+import torch
 from copy import deepcopy
 
 from abc import ABC
 
 from argparse import Namespace, ArgumentParser
 from collections import defaultdict
-from configparser import ConfigParser
-from pathlib import Path
+from configparser import ConfigParser, DuplicateSectionError
 import hashlib
+
+from pytorch_lightning import Trainer
+
+from ml_lib.utils.loggers import LightningLogger
+from ml_lib.utils.tools import locate_and_import_class, auto_cast
+
+
+# Argument Parser and default Values
+# =============================================================================
+def parse_comandline_args_add_defaults(filepath, overrides=None):
+
+    # Parse Command Line
+    parser = ArgumentParser()
+    parser.add_argument('--model_name', type=str)
+    parser.add_argument('--data_name', type=str)
+    parser.add_argument('--seed', type=str)
+    parser.add_argument('--debug', type=strtobool)
+
+    # Load Defaults from _parameters.ini file
+    config = configparser.ConfigParser()
+    config.read(str(filepath))
+
+    new_defaults = dict()
+    for key in ['project', 'train', 'data']:
+        defaults = config[key]
+        new_defaults.update({key: auto_cast(val) for key, val in defaults.items()})
+
+    args, _ = parser.parse_known_args()
+    overrides = overrides or dict()
+    default_data = overrides.get('data_name', None) or new_defaults['data_name']
+    default_model = overrides.get('model_name', None) or new_defaults['model_name']
+    default_seed = overrides.get('seed', None) or new_defaults['seed']
+
+    data_name = args.__dict__.get('data_name', None) or default_data
+    model_name = args.__dict__.get('model_name', None) or default_model
+    found_seed = args.__dict__.get('seed', None) or default_seed
+
+    new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
+
+    found_data_class = locate_and_import_class(data_name, 'datasets')
+    found_model_class = locate_and_import_class(model_name, 'models')
+
+    for module in [LightningLogger, Trainer, found_data_class, found_model_class]:
+        parser = module.add_argparse_args(parser)
+
+    args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
+
+    args = vars(args)
+    args.update({key: auto_cast(val) for key, val in args.items()})
+    args.update(gpus=[0] if torch.cuda.is_available() and not args['debug'] else None,
+                row_log_interval=1000,  # TODO: Better Value / Setting
+                log_save_interval=10000,  # TODO: Better Value / Setting
+                weights_summary='top',
+                )
+
+    if overrides is not None and isinstance(overrides, (Mapping, Dict)):
+        args.update(**overrides)
+    if args['debug']:
+        args.update(
+            # This seems to be the new "fast_dev_run"
+            val_check_interval=1,
+            max_epochs=2,
+            max_steps=2,
+            auto_lr_find=False,
+            check_val_every_n_epoch=1
+        )
+    return args, found_data_class, found_model_class, found_seed
 
 
 def is_jsonable(x):
@@ -38,11 +109,32 @@ class Config(ConfigParser, ABC):
     def fingerprint(self):
         h = hashlib.md5()
         params = deepcopy(self.as_dict)
-        del params['model']['type']
-        del params['model']['secondary_type']
-        del params['data']['worker']
-        del params['main']
-        h.update(str(params).encode())
+        try:
+            del params['model']['type']
+        except KeyError:
+            pass
+        try:
+            del params['data']['worker']
+        except KeyError:
+            pass
+        try:
+            del params['data']['refresh']
+        except KeyError:
+            pass
+        try:
+            del params['main']
+        except KeyError:
+            pass
+        try:
+            del params['project']
+        except KeyError:
+            pass
+        # Flatten the dict of dicts
+        for section in list(params.keys()):
+            params.update({f'{section}_{key}': val for key, val in params[section].items()})
+            del params[section]
+        _, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
+        h.update(str(vals).encode())
         fingerprint = h.hexdigest()
         return fingerprint
 
@@ -53,6 +145,7 @@ class Config(ConfigParser, ABC):
 
     @property
     def _model_map(self):
+
         """
         This is function is supposed to return a dict, which holds a mapping from string model names to model classes
 
@@ -62,21 +155,31 @@ class Config(ConfigParser, ABC):
         )
         :return:
         """
+
         raise NotImplementedError
 
     @property
     def model_class(self):
         try:
-            return self._model_map[self.model.type]
-        except KeyError:
-            raise KeyError(f'The model alias you provided ("{self.get("model", "type")}")' +
-                           'does not exist! Try one of these: {list(self._model_map.keys())}')
+            return locate_and_import_class(self.model.type, folder_path='models')
+        except AttributeError as e:
+            raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}") ' +
+                                 f'was not found!\n' +
+                                 f'{e}')
+
+    @property
+    def data_class(self):
+        try:
+            return locate_and_import_class(self.data.class_name, folder_path='datasets')
+        except AttributeError as e:
+            raise AttributeError(f'The dataset alias you provided ("{self.get("data", "class_name")}") ' +
+                                 f'was not found!\n' +
+                                 f'{e}')
 
+    # --------------------------------------------------
     # TODO: Do this programmatically; This did not work:
     # Initialize Default Sections as Property
     # for section in self.default_sections:
-    #     self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
+    #     self.__setattr__(section, property(lambda tensor :tensor._get_namespace_for_section(section))
 
     @property
     def main(self):
@@ -195,3 +298,22 @@ class Config(ConfigParser, ABC):
         with path.open('w') as configfile:
             super().write(configfile)
         return True
+
+    def _write_section(self, fp, section_name, section_items, delimiter):
+        if section_name == 'project':
+            return
+        else:
+            super(Config, self)._write_section(fp, section_name, section_items, delimiter)
+
+    def add_section(self, section: str) -> None:
+        try:
+            super(Config, self).add_section(section)
+        except DuplicateSectionError:
+            pass
+
+
+class DataClass(Namespace):
+
+    @property
+    def __dict__(self):
+        return [x for x in dir(self) if not x.startswith('_')]
```
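One plausible way to consume the new entry point, sketched under assumptions: the `_parameters.ini` filename and the idea that the located classes accept the flat argument dict are illustrative, not fixed by this function.

```python
from argparse import Namespace
from pytorch_lightning import Trainer

args, data_class, model_class, seed = parse_comandline_args_add_defaults(
    '_parameters.ini', overrides=dict(debug=True))

fix_all_random_seeds(seed)                        # helper from ml_lib.utils.tools, shown further below
model = model_class(**args)                       # assumption: class takes the flat arg dict
trainer = Trainer.from_argparse_args(Namespace(**args))
```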
41 utils/data_util.py Normal file
@@ -0,0 +1,41 @@
```python
import torch
from torch.utils.data import Dataset


def chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield l[i:i + n]


class ReMapDataset(Dataset):
    @property
    def sample_shape(self):
        return list(self[0][0].shape)

    def __init__(self, ds, mapping):
        super(ReMapDataset, self).__init__()
        # here is a mapping from this index to the mother ds index
        self.mapping = mapping
        self.ds = ds

    def __getitem__(self, index):
        return self.ds[self.mapping[index]]

    def __len__(self):
        return self.mapping.shape[0]

    @classmethod
    def do_train_vali_split(cls, ds, split_fold=0.1):

        indices = torch.randperm(len(ds))

        valid_size = int(len(ds) * split_fold)

        train_mapping = indices[valid_size:]
        valid_mapping = indices[:valid_size]

        train = cls(ds, train_mapping)
        valid = cls(ds, valid_mapping)

        return train, valid
```
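A sketch of the split helper on toy data; `TensorDataset` stands in for any map-style dataset:

```python
import torch
from torch.utils.data import TensorDataset

ds = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
train, valid = ReMapDataset.do_train_vali_split(ds, split_fold=0.1)
print(len(train), len(valid))  # 90 10; two disjoint views over the same underlying data
```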
30 utils/equal_sampler.py Normal file
@@ -0,0 +1,30 @@
```python
import random
from typing import Iterator, Sequence

from torch.utils.data import Sampler
from torch.utils.data.sampler import T_co


# noinspection PyMissingConstructor
class EqualSampler(Sampler):

    def __init__(self, idxs_per_class: Sequence[Sequence[float]], replacement: bool = True) -> None:

        self.replacement = replacement
        self.idxs_per_class = idxs_per_class
        self.len_largest_class = max([len(x) for x in self.idxs_per_class])

    def __iter__(self) -> Iterator[T_co]:
        return iter(random.choice(self.idxs_per_class[random.randint(0, len(self.idxs_per_class) - 1)])
                    for _ in range(len(self)))

    def __len__(self):
        return self.len_largest_class * len(self.idxs_per_class)


if __name__ == '__main__':
    es = EqualSampler([list(range(5)), list(range(5, 10)), list(range(10, 12))])
    for i in es:
        print(i)
    pass
```
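To plug the sampler into a `DataLoader`, the index lists are built per label; the tiny dataset below is hard-coded for brevity. Each draw picks a class uniformly and then an index within it, so small classes get oversampled to roughly match the largest one.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(9, 3), torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2]))
idxs_per_class = [[0, 1, 2], [3, 4], [5, 6, 7, 8]]   # dataset indices grouped by label
loader = DataLoader(ds, sampler=EqualSampler(idxs_per_class), batch_size=4)
```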
185 utils/loggers.py Normal file
@@ -0,0 +1,185 @@
```python
import inspect
from argparse import ArgumentParser
from copy import deepcopy

import hashlib
from pathlib import Path

import os
from pytorch_lightning.loggers.base import LightningLoggerBase
from neptune.api_exceptions import ProjectNotFound

from pytorch_lightning.loggers.neptune import NeptuneLogger

from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.utilities import argparse_utils

from ml_lib.utils.tools import add_argparse_args


class LightningLogger(LightningLoggerBase):

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        cleaned_args = deepcopy(args.__dict__)

        # Clean seed and other attributes
        # TODO: Find a better way of cleaning this
        for attr in ['seed', 'num_worker', 'debug', 'eval', 'owner', 'data_root', 'check_val_every_n_epoch',
                     'reset', 'outpath', 'version', 'gpus', 'neptune_key', 'num_sanity_val_steps', 'tpu_cores',
                     'progress_bar_refresh_rate', 'log_save_interval', 'row_log_interval']:
            try:
                del cleaned_args[attr]
            except KeyError:
                pass

        kwargs.update(params=cleaned_args)
        new_logger = argparse_utils.from_argparse_args(cls, args, **kwargs)
        return new_logger

    @property
    def fingerprint(self):
        h = hashlib.md5()
        h.update(self._finger_print_string.encode())
        fingerprint = h.hexdigest()
        return fingerprint

    @property
    def name(self):
        short_name = "".join(c for c in self.model_name if c.isupper())
        return f'{short_name}_{self.fingerprint}'

    media_dir = 'media'

    @classmethod
    def add_argparse_args(cls, parent_parser):
        return add_argparse_args(cls, parent_parser)

    @property
    def experiment(self):
        if self.debug:
            return self.csvlogger.experiment
        else:
            return self.neptunelogger.experiment

    @property
    def log_dir(self):
        return Path(self.csvlogger.experiment.log_dir)

    @property
    def project_name(self):
        return f"{self.owner}/{self.project_root.replace('_', '-')}"

    @property
    def project_root(self):
        root_path = Path(os.getcwd()).name if not self.debug else 'test'
        return root_path

    @property
    def version(self):
        return self.seed

    @property
    def save_dir(self):
        return self.log_dir

    @property
    def outpath(self):
        return Path(self.root_out) / self.model_name

    def __init__(self, owner, neptune_key, model_name, outpath='output', seed=69, debug=False, params=None):
        """
        params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
            Parameters are displayed in the experiment's Parameters section and each key-value pair can be
            viewed in experiments view as a column.
        properties (dict|None): Optional, default is {}. Properties of the experiment.
            They are editable after experiment is created. Properties are displayed in the experiment's Details and
            each key-value pair can be viewed in experiments view as a column.
        tags (list|None): Optional, default []. Must be list of str. Tags of the experiment.
            They are editable after experiment is created (see: append_tag() and remove_tag()).
            Tags are displayed in the experiment's Details and can be viewed in experiments view as a column.
        """
        super(LightningLogger, self).__init__()

        self.debug = debug
        self.owner = owner if not self.debug else 'testuser'
        self.neptune_key = neptune_key if not self.debug else 'XXX'
        self.root_out = outpath if not self.debug else 'debug_out'

        self.params = params if params else dict()  # guard: params may be passed as None

        self.seed = seed
        self.model_name = model_name

        if self.params:
            _, fingerprint_tuple = zip(*sorted(self.params.items(), key=lambda tup: tup[0]))
            self._finger_print_string = str(fingerprint_tuple)
        else:
            self._finger_print_string = str((self.owner, self.root_out, self.seed, self.model_name, self.debug))
        self.params.update(fingerprint=self.fingerprint)

        self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
        self._neptune_kwargs = dict(offline_mode=self.debug,
                                    params=self.params,
                                    api_key=self.neptune_key,
                                    experiment_name=self.name,
                                    # tags=?,
                                    project_name=self.project_name)
        try:
            self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
        except ProjectNotFound as e:
            print(f'The project "{self.project_name}" does not exist! Create it or check your spelling.')
            print(e)

        self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
        if self.params:
            self.log_hyperparams(self.params)

    def close(self):
        self.csvlogger.close()
        self.neptunelogger.close()

    def set_fingerprint_string(self, fingerprint_str):
        self._finger_print_string = fingerprint_str

    def log_text(self, name, text, **_):
        # TODO: Implement offline variant.
        self.neptunelogger.log_text(name, text)

    def log_hyperparams(self, params):
        self.neptunelogger.log_hyperparams(params)
        self.csvlogger.log_hyperparams(params)

    def log_metric(self, metric_name, metric_value, step=None, **kwargs):
        # use the actual metric name as the dict key (not the literal 'metric_name')
        self.csvlogger.log_metrics({metric_name: metric_value}, step=step)
        self.neptunelogger.log_metric(metric_name, metric_value, step=step, **kwargs)

    def log_metrics(self, metrics, step=None):
        self.neptunelogger.log_metrics(metrics, step=step)
        self.csvlogger.log_metrics(metrics, step=step)

    def log_image(self, name, image, ext='png', step=None, **kwargs):
        image_name = f'{"0" * (4 - len(str(step)))}{step}_{name}' if step is not None else name
        image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
        (self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
        image.savefig(image_path, bbox_inches='tight', pad_inches=0)
        self.neptunelogger.log_image(name, str(image_path), **kwargs)

    def save(self):
        self.csvlogger.save()
        self.neptunelogger.save()

    def finalize(self, status):
        self.csvlogger.finalize(status)
        self.neptunelogger.finalize(status)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.finalize('success')
```
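An illustrative construction; the owner, key, and parameter values are placeholders. With `debug=True` the logger routes everything to the CSV logger and Neptune's offline mode, and `logger.log_dir` points at the CSVLogger run directory.

```python
logger = LightningLogger(owner='myuser', neptune_key='ANONYMOUS',
                         model_name='CNNBaseline', outpath='output',
                         seed=69, debug=True,
                         params=dict(lr=1e-3, batch_size=32))
```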
116 utils/logging.py (deleted)
@@ -1,116 +0,0 @@
```python
from abc import ABC
from pathlib import Path

from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger

from utils.config import Config


class Logger(LightningLoggerBase, ABC):

    media_dir = 'media'

    @property
    def experiment(self):
        if self.debug:
            return self.testtubelogger.experiment
        else:
            return self.neptunelogger.experiment

    @property
    def log_dir(self):
        return Path(self.testtubelogger.experiment.get_logdir()).parent

    @property
    def name(self):
        return self.config.name

    @property
    def project_name(self):
        return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"

    @property
    def version(self):
        return self.config.get('main', 'seed')

    @property
    def outpath(self):
        return Path(self.config.train.outpath) / self.config.model.type

    @property
    def exp_path(self):
        return Path(self.outpath) / self.name

    def __init__(self, config: Config):
        """
        params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
            Parameters are displayed in the experiment's Parameters section and each key-value pair can be
            viewed in experiments view as a column.
        properties (dict|None): Optional, default is {}. Properties of the experiment.
            They are editable after experiment is created. Properties are displayed in the experiment's Details and
            each key-value pair can be viewed in experiments view as a column.
        tags (list|None): Optional, default []. Must be list of str. Tags of the experiment.
            They are editable after experiment is created (see: append_tag() and remove_tag()).
            Tags are displayed in the experiment's Details and can be viewed in experiments view as a column.
        """
        super(Logger, self).__init__()

        self.config = config
        self.debug = self.config.main.debug
        self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
        self._neptune_kwargs = dict(offline_mode=self.debug,
                                    api_key=self.config.project.neptune_key,
                                    experiment_name=self.name,
                                    project_name=self.project_name,
                                    params=self.config.model_paramters)
        self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
        self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
        self.log_config_as_ini()

    def log_hyperparams(self, params):
        self.neptunelogger.log_hyperparams(params)
        self.testtubelogger.log_hyperparams(params)

    def log_metrics(self, metrics, step=None):
        self.neptunelogger.log_metrics(metrics, step=step)
        self.testtubelogger.log_metrics(metrics, step=step)

    def close(self):
        self.testtubelogger.close()
        self.neptunelogger.close()

    def log_config_as_ini(self):
        self.config.write(self.log_dir / 'config.ini')

    def log_text(self, name, text, step_nb=0, **kwargs):
        # TODO Implement Offline variant.
        self.neptunelogger.log_text(name, text, step_nb)

    def log_metric(self, metric_name, metric_value, **kwargs):
        self.testtubelogger.log_metrics(dict(metric_name=metric_value))
        self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)

    def log_image(self, name, image, **kwargs):
        self.neptunelogger.log_image(name, image, **kwargs)
        step = kwargs.get('step', None)
        name = f'{step}_{name}' if step is not None else name
        image.savefig(self.log_dir / self.media_dir / name)

    def save(self):
        self.testtubelogger.save()
        self.neptunelogger.save()

    def finalize(self, status):
        self.testtubelogger.finalize(status)
        self.neptunelogger.finalize(status)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.finalize('success')
```
```diff
@@ -1,5 +1,7 @@
 from argparse import Namespace
 from collections import Mapping
+from typing import Union
+
 from copy import deepcopy
 from pathlib import Path
@@ -11,6 +13,14 @@ from torch import nn
 # Hyperparamter Object
 class ModelParameters(Namespace, Mapping):
 
+    @property
+    def as_dict(self):
+        return {key: self.get(key) if key != 'activation' else self.activation_as_string for key in self.keys()}
+
+    @property
+    def activation_as_string(self):
+        return self['activation'].lower()
+
     @property
     def module_kwargs(self):
@@ -18,9 +28,11 @@ class ModelParameters(Namespace, Mapping):
 
         paramter_mapping.update(
             dict(
-                activation=self._activations[self['activation']]
+                activation=self.__getattribute__('activation')
             )
         )
+        # Get rid of paramters that
+        paramter_mapping.__delitem__('in_shape')
+
         return paramter_mapping
@@ -42,49 +54,54 @@ class ModelParameters(Namespace, Mapping):
 
     def __getattribute__(self, name):
         if name == 'activation':
-            return self._activations[self['activation']]
+            return self._activations[self['activation'].lower()]
         else:
-            return super(ModelParameters, self).__getattribute__(name)
+            try:
+                return super(ModelParameters, self).__getattribute__(name)
+            except AttributeError as e:
+                if name == 'stretch':
+                    return False
+                else:
+                    raise AttributeError(e)
 
     _activations = dict(
         leaky_relu=nn.LeakyReLU,
+        gelu=nn.GELU,
+        elu=nn.ELU,
         relu=nn.ReLU,
         sigmoid=nn.Sigmoid,
         tanh=nn.Tanh
     )
 
     def __init__(self, parameter_mapping):
+        if isinstance(parameter_mapping, Namespace):
+            parameter_mapping = parameter_mapping.__dict__
         super(ModelParameters, self).__init__(**parameter_mapping)
 
 
 class SavedLightningModels(object):
 
     @classmethod
-    def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
+    def load_checkpoint(cls, models_root_path, model=None, n=-1, checkpoint: Union[None, str] = None):
         assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
-        found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
-        found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
+        if checkpoint is not None:
+            checkpoint_path = Path(checkpoint)
+            assert checkpoint_path.exists(), f'The path ({checkpoint_path} does not exist).'
+        else:
+            found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
+            checkpoint_path = natsorted(found_checkpoints, key=lambda y: y.name)[n]
         if model is None:
             model = torch.load(models_root_path / 'model_class.obj')
         assert model is not None
 
-        return cls(weights=found_checkpoints[n], model=model)
+        return cls(weights=checkpoint_path, model=model)
 
     def __init__(self, **kwargs):
-        self.weights: str = kwargs.get('weights', '')
+        self.weights: Path = Path(kwargs.get('weights', ''))
+        self.hparams: Path = self.weights.parent / 'hparams.yaml'
+
         self.model = kwargs.get('model', None)
         assert self.model is not None
 
     def restore(self):
-        pretrained_model = self.model.load_from_checkpoint(self.weights)
+        pretrained_model = self.model.load_from_checkpoint(self.weights.__str__())
+        # , hparams_file=self.hparams.__str__())
         pretrained_model.eval()
         pretrained_model.freeze()
         return pretrained_model
```
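A sketch of the checkpoint-restore path; the directory name is hypothetical, while the layout (a `model_class.obj` next to the `*.ckpt` files) is the convention this class expects:

```python
from pathlib import Path

saved = SavedLightningModels.load_checkpoint(Path('output/CNNBaseline'))
model = saved.restore()   # eval() and freeze() already applied, ready for inference
```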
```diff
@@ -23,3 +23,5 @@ def run_n_in_parallel(f, n, processes=0, **kwargs):
         p.join()
 
     return results
+
+raise NotImplementedError()
```
```diff
@@ -1,6 +1,35 @@
+import importlib
+import inspect
 import pickle
 import shelve
-from pathlib import Path
+from argparse import ArgumentParser, ArgumentError
+from ast import literal_eval
+from pathlib import Path, PurePath
+from typing import Union
+
+import numpy as np
+import torch
+import random
+
+
+def auto_cast(a):
+    try:
+        return literal_eval(a)
+    except (ValueError, SyntaxError):
+        return a
+
+
+def to_one_hot(idx_array, max_classes):
+    one_hot = np.zeros((idx_array.size, max_classes))
+    one_hot[np.arange(idx_array.size), idx_array] = 1
+    return one_hot
+
+
+def fix_all_random_seeds(seed):
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    random.seed(seed)
+    print(f'Seed is now fixed: "{seed}".')
+
 
 def write_to_shelve(file_path, value):
@@ -20,4 +49,43 @@ def load_from_shelve(file_path, key):
 
 def check_path(file_path):
     assert isinstance(file_path, Path)
     assert str(file_path).endswith('.pik')
+
+
+def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
+    """Locate an object by name or dotted path, importing as necessary."""
+    import sys
+    sys.path.append("..")
+    folder_path = Path(folder_path)
+    module_paths = [x for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
+    # possible_package_path = folder_path / '__init__.py'
+    # package = str(possible_package_path) if possible_package_path.exists() else None
+    for module_path in module_paths:
+        mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
+        try:
+            model_class = mod.__getattribute__(class_name)
+            return model_class
+        except AttributeError:
+            continue
+    raise AttributeError(f'Check the {folder_path.name} name. Possible files are:\n{[x.name for x in module_paths]}')
+
+
+def add_argparse_args(cls, parent_parser):
+    parser = ArgumentParser(parents=[parent_parser], add_help=False)
+    full_arg_spec = inspect.getfullargspec(cls.__init__)
+    n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
+    for idx, argument in enumerate(full_arg_spec.args):
+        try:
+            if argument == 'self':
+                continue
+            if idx < n_non_defaults:
+                parser.add_argument(f'--{argument}', type=int)
+            else:
+                argument_type = type(argument)
+                parser.add_argument(f'--{argument}',
+                                    type=argument_type,
+                                    default=full_arg_spec.defaults[idx - n_non_defaults]
+                                    )
+        except ArgumentError:
+            continue
+    return parser
```
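Quick sketches for the new helpers; the class and folder names in the last line are illustrative:

```python
import numpy as np

auto_cast('0.1')            # -> 0.1 (float, via literal_eval)
auto_cast('[1, 2]')         # -> [1, 2]
auto_cast('not-a-literal')  # -> the unchanged string

to_one_hot(np.array([0, 2, 1]), max_classes=3)
# array([[1., 0., 0.],
#        [0., 0., 1.],
#        [0., 1., 0.]])

# locate_and_import_class('CNNBaseline', 'models') would scan models/*.py for a
# class of that name and return it; this is what Config.model_class relies on.
```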
```diff
@@ -1,8 +1,22 @@
+from abc import ABC
 from torchvision.transforms import ToTensor as TorchVisionToTensor
 
 
+class _BaseTransformation(ABC):
+
+    def __init__(self, *args):
+        pass
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__})'
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError
+
+
 class ToTensor(TorchVisionToTensor):
 
     def __call__(self, pic):
+        # Make it float .float() == 32bit
         tensor = super(ToTensor, self).__call__(pic).float()
         return tensor
```
```diff
@@ -1,31 +1,56 @@
+try:
+    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+    from matplotlib import pyplot as plt
+except ImportError:  # pragma: no-cover
+    raise ImportError('You want to use `matplotlib` which is not installed yet,'  # pragma: no-cover
+                      ' install it with `pip install matplotlib`.')
+
 from pathlib import Path
-import matplotlib.pyplot as plt
+
+
+def prettyfy_sns():
+    plt.style.use('default')
+    try:
+        import seaborn as sns
+    except ImportError:
+        raise ImportError('You want to use `seaborn` which is not installed yet,'  # pragma: no-cover
+                          ' install it with `pip install seaborn`.')
+    sns.set_palette('Dark2')
+
+    tex_fonts = {
+        # Use LaTeX to write all text
+        "text.usetex": True,
+        "font.family": "serif",
+        # Use 10pt font in plots, to match 10pt font in document
+        "axes.labelsize": 10,
+        "font.size": 10,
+        # Make the legend/label fonts a little smaller
+        "legend.fontsize": 8,
+        "xtick.labelsize": 8,
+        "ytick.labelsize": 8
+    }
+
+    plt.rcParams.update(tex_fonts)
 
 
 class Plotter(object):
-    def __init__(self, root_path=''):
-        self.root_path = Path(root_path)
 
-    def save_current_figure(self, path, extention='.png', naked=True):
-        fig, _ = plt.gcf(), plt.gca()
+    def __init__(self, root_path=''):
+        self.root_path = Path(root_path)
+
+    def save_figure(self, figure, title, extention='.png', naked=False):
+        canvas = FigureCanvas(figure)
         # Prepare save location and check img file extention
-        path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
+        path = self.root_path / f'{title}{extention}'
         path.parent.mkdir(exist_ok=True, parents=True)
         if naked:
-            plt.axis('off')
-            fig.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
-            fig.clf()
+            for ax in figure.axes:  # strip axes; Figure itself has no axis()
+                ax.axis('off')
+            figure.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
         else:
-            fig.savefig(path)
-            fig.clf()
-
-    def show_current_figure(self):
-        fig, _ = plt.gcf(), plt.gca()
-        fig.show()
-        fig.clf()
+            canvas.print_figure(path)
 
 
 if __name__ == '__main__':
-    output_root = Path('..') / 'output'
-    p = Plotter(output_root)
-    p.save_current_figure('test.png')
+    raise PermissionError('Get out of here.')
```