initial commit

parent c15ee64688
commit f0262e1895
.gitignore vendored Normal file (+1)

@@ -0,0 +1 @@
/.idea/
__init__.py Normal file (+0)
evaluation/__init__.py Normal file (+0)
evaluation/classification.py Normal file (+34)
@@ -0,0 +1,34 @@
import matplotlib.pyplot as plt

from sklearn.metrics import roc_curve, auc


class ROCEvaluation(object):

    linewidth = 2

    def __init__(self, plot_roc=False):
        self.plot_roc = plot_roc
        self.epoch = 0

    def __call__(self, prediction, label):
        # Compute ROC curve and ROC area
        # (sklearn's roc_curve expects its arguments in (y_true, y_score) order)
        fpr, tpr, _ = roc_curve(label, prediction)
        roc_auc = auc(fpr, tpr)
        if self.plot_roc:
            _ = plt.gcf()
            plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
            self._prepare_fig()
        return roc_auc, fpr, tpr

    def _prepare_fig(self):
        fig = plt.gcf()
        ax = plt.gca()
        plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        fig.legend(loc="lower right")

        return fig
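A quick usage sketch (the labels and scores below are made-up toy data; with plot_roc=True the call also draws the curve):

    evaluation = ROCEvaluation(plot_roc=False)
    labels = [0, 0, 1, 1]             # hypothetical ground truth
    scores = [0.1, 0.4, 0.35, 0.8]    # hypothetical predicted scores
    roc_auc, fpr, tpr = evaluation(scores, labels)
    print(roc_auc)                    # 0.75 on this toy data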
examples/__init__.py Normal file (+0)
modules/__init__.py Normal file (+0)
modules/blocks.py Normal file (+144)
@@ -0,0 +1,144 @@
from typing import Union

import torch
from torch import nn
from ml_lib.modules.utils import AutoPad, Interpolate


#
# Sub - Modules
###################

class ConvModule(nn.Module):

    @property
    def shape(self):
        x = torch.randn(self.in_shape).unsqueeze(0)
        output = self(x)
        return output.shape[1:]

    def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
                 use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
                 conv_class=nn.Conv2d, conv_stride=1, conv_padding=0):
        super(ConvModule, self).__init__()

        # Module Parameters
        self.in_shape = in_shape
        in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
        self.activation = activation()

        # Convolution Parameters
        self.padding = conv_padding
        self.stride = conv_stride
        self.conv_filters = conv_filters
        self.conv_kernel = conv_kernel

        # Modules
        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
        self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                               padding=self.padding, stride=self.stride
                               )

    def forward(self, x):
        x = self.norm(x)

        tensor = self.conv(x)
        tensor = self.dropout(tensor)
        tensor = self.pooling(tensor)
        tensor = self.activation(tensor)
        return tensor


class DeConvModule(nn.Module):

    @property
    def shape(self):
        x = torch.randn(self.in_shape).unsqueeze(0)
        output = self(x)
        return output.shape[1:]

    def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
                 dropout: Union[int, float] = 0, autopad=0,
                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
                 use_bias=True, use_norm=False):
        super(DeConvModule, self).__init__()
        in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
        self.padding = conv_padding
        self.conv_kernel = conv_kernel
        self.stride = conv_stride
        self.in_shape = in_shape
        self.conv_filters = conv_filters

        self.autopad = AutoPad() if autopad else lambda x: x
        self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                                          padding=self.padding, stride=self.stride)

        self.activation = activation() if activation else lambda x: x

    def forward(self, x):
        x = self.norm(x)
        x = self.dropout(x)
        x = self.autopad(x)
        x = self.interpolation(x)

        tensor = self.de_conv(x)
        tensor = self.activation(tensor)
        return tensor


class ResidualModule(nn.Module):

    @property
    def shape(self):
        x = torch.randn(self.in_shape).unsqueeze(0)
        output = self(x)
        return output.shape[1:]

    def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
        assert n >= 1
        super(ResidualModule, self).__init__()
        self.in_shape = in_shape
        module_parameters.update(in_shape=in_shape)
        self.activation = activation() if activation else lambda x: x
        self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
        assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'

    def forward(self, x):
        # Chain the stacked modules; each maps the shape onto itself.
        tensor = x
        for module in self.residual_block:
            tensor = module(tensor)

        tensor = tensor + x
        tensor = self.activation(tensor)
        return tensor


class RecurrentModule(nn.Module):

    @property
    def shape(self):
        x = torch.randn(self.in_shape).unsqueeze(0)
        output = self(x)
        return output.shape[1:]

    def __init__(self, in_shape, hidden_size, num_layers=1, cell_type=nn.GRU, use_bias=True, dropout=0):
        super(RecurrentModule, self).__init__()
        self.use_bias = use_bias
        self.num_layers = num_layers
        self.in_shape = in_shape
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.rnn = cell_type(self.in_shape[-1] * self.in_shape[-2], hidden_size,
                             num_layers=num_layers,
                             bias=self.use_bias,
                             batch_first=True,
                             dropout=self.dropout)

    def forward(self, x):
        tensor = self.rnn(x)
        return tensor
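A minimal sketch of the shape-driven chaining these blocks enable (the 3x32x32 input shape and filter count are made up; assumes the library is importable as ml_lib, matching the in-repo imports):

    from ml_lib.modules.blocks import ConvModule, ResidualModule

    conv = ConvModule(in_shape=(3, 32, 32), conv_filters=16, conv_kernel=3, conv_padding=1)
    print(conv.shape)  # output shape without the batch dim, here torch.Size([16, 32, 32])
    # ResidualModule stacks n shape-preserving ConvModules and adds the skip connection:
    res = ResidualModule(conv.shape, ConvModule, n=2, conv_filters=16, conv_kernel=3, conv_padding=1)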
modules/losses.py Normal file (+23)
@@ -0,0 +1,23 @@
from typing import List

import torch
from torch import nn

from ml_lib.modules.utils import FlipTensor
from ml_lib.objects.map import MapStorage, Map
from ml_lib.objects.trajectory import Trajectory


class BinaryHomotopicLoss(nn.Module):
    def __init__(self, map_storage: MapStorage):
        super(BinaryHomotopicLoss, self).__init__()
        self.map_storage = map_storage
        self.flipper = FlipTensor()

    def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: List[str]):
        maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
        for basemap in maps:
            basemap = basemap.as_2d_array
            # ToDo: the actual loss computation is not implemented yet in this commit.
modules/model_parts.py Normal file (+229)
@@ -0,0 +1,229 @@
#
# Full Model Parts
###################
from functools import reduce
from operator import mul
from typing import List, Union

import torch
from torch import nn

from ml_lib.modules.blocks import ConvModule, DeConvModule
from ml_lib.modules.utils import Flatten


class Generator(nn.Module):
    @property
    def shape(self):
        x = torch.randn(self.lat_dim).unsqueeze(0)
        output = self(x)
        return output.shape[1:]

    # noinspection PyUnresolvedReferences
    def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
                 filters: List[int] = None, activation=nn.ReLU):
        super(Generator, self).__init__()
        assert filters, '"Filters" has to be a list of int len 3'
        self.filters = filters
        self.activation = activation
        self.inner_activation = activation()
        self.out_activation = None
        self.lat_dim = lat_dim
        self.dropout = dropout
        self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
        # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])

        self.flat = Flatten(to=re_shape)

        self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
                                    conv_kernel=5,
                                    conv_padding=2,
                                    conv_stride=1,
                                    use_norm=use_norm,
                                    activation=self.activation,
                                    interpolation_scale=2,
                                    dropout=self.dropout
                                    )

        self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
                                    conv_kernel=3,
                                    conv_padding=1,
                                    conv_stride=1,
                                    use_norm=use_norm,
                                    activation=self.activation,
                                    interpolation_scale=2,
                                    dropout=self.dropout
                                    )

        self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
                                    conv_kernel=3,
                                    conv_padding=1,
                                    conv_stride=1,
                                    use_norm=use_norm,
                                    activation=self.activation,
                                    interpolation_scale=2,
                                    dropout=self.dropout
                                    )

        self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
                                    conv_kernel=3,
                                    conv_padding=1,
                                    # use_norm=use_norm,
                                    activation=self.out_activation
                                    )

    def forward(self, z):
        tensor = self.l1(z)
        tensor = self.inner_activation(tensor)
        tensor = self.flat(tensor)
        tensor = self.deconv1(tensor)
        tensor = self.deconv2(tensor)
        tensor = self.deconv3(tensor)
        tensor = self.deconv4(tensor)
        return tensor

    def size(self):
        return self.shape


class UnitGenerator(Generator):

    def __init__(self, *args, **kwargs):
        kwargs.update(use_norm=True)
        super(UnitGenerator, self).__init__(*args, **kwargs)
        self.norm_f = nn.BatchNorm1d(self.l1.out_features, eps=1e-04, affine=False)
        self.norm1 = nn.BatchNorm2d(self.deconv1.conv_filters, eps=1e-04, affine=False)
        self.norm2 = nn.BatchNorm2d(self.deconv2.conv_filters, eps=1e-04, affine=False)
        self.norm3 = nn.BatchNorm2d(self.deconv3.conv_filters, eps=1e-04, affine=False)

    def forward(self, z_c1_c2_c3):
        z, c1, c2, c3 = z_c1_c2_c3
        tensor = self.l1(z)
        tensor = self.inner_activation(tensor)
        tensor = self.norm_f(tensor)
        tensor = self.flat(tensor)

        tensor = self.deconv1(tensor) + c3
        tensor = self.inner_activation(tensor)
        tensor = self.norm1(tensor)

        tensor = self.deconv2(tensor) + c2
        tensor = self.inner_activation(tensor)
        tensor = self.norm2(tensor)

        tensor = self.deconv3(tensor) + c1
        tensor = self.inner_activation(tensor)
        tensor = self.norm3(tensor)

        tensor = self.deconv4(tensor)
        return tensor


class BaseEncoder(nn.Module):
    @property
    def shape(self):
        x = torch.randn(self.in_shape).unsqueeze(0)
        output = self(x)
        return output.shape[1:]

    # noinspection PyUnresolvedReferences
    def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
                 latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
                 filters: List[int] = None):
        super(BaseEncoder, self).__init__()
        assert filters, '"Filters" has to be a list of int len 3'

        # Optional Padding for odd image-sizes
        # Obsolete, already done by the autopadding module on incoming tensors
        # in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]

        # Parameters
        self.lat_dim = lat_dim
        self.in_shape = in_shape
        self.use_bias = use_bias
        self.latent_activation = latent_activation() if latent_activation else None

        # Modules
        self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
                                conv_kernel=3,
                                conv_padding=1,
                                conv_stride=1,
                                pooling_size=2,
                                use_norm=use_norm,
                                dropout=dropout,
                                activation=activation
                                )

        self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
                                conv_kernel=3,
                                conv_padding=1,
                                conv_stride=1,
                                pooling_size=2,
                                use_norm=use_norm,
                                dropout=dropout,
                                activation=activation
                                )

        self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
                                conv_kernel=5,
                                conv_padding=2,
                                conv_stride=1,
                                pooling_size=2,
                                use_norm=use_norm,
                                dropout=dropout,
                                activation=activation
                                )

        self.flat = Flatten()

    def forward(self, x):
        tensor = self.conv1(x)
        tensor = self.conv2(tensor)
        tensor = self.conv3(tensor)
        tensor = self.flat(tensor)
        return tensor


class UnitEncoder(BaseEncoder):
    # noinspection PyUnresolvedReferences
    def __init__(self, *args, **kwargs):
        kwargs.update(use_norm=True)
        super(UnitEncoder, self).__init__(*args, **kwargs)
        self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)

    def forward(self, x):
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        tensor = self.flat(c3)
        l1 = self.l1(tensor)
        return c1, c2, c3, l1


class VariationalEncoder(BaseEncoder):
    # noinspection PyUnresolvedReferences
    def __init__(self, *args, **kwargs):
        super(VariationalEncoder, self).__init__(*args, **kwargs)

        self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
        self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)

    @staticmethod
    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        tensor = super(VariationalEncoder, self).forward(x)
        mu = self.mu(tensor)
        logvar = self.logvar(tensor)
        z = self.reparameterize(mu, logvar)
        return mu, logvar, z


class Encoder(BaseEncoder):
    # noinspection PyUnresolvedReferences
    def __init__(self, *args, **kwargs):
        super(Encoder, self).__init__(*args, **kwargs)

        self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)

    def forward(self, x):
        tensor = super(Encoder, self).forward(x)
        tensor = self.l1(tensor)
        tensor = self.latent_activation(tensor) if self.latent_activation else tensor
        return tensor
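The reparameterization trick above keeps sampling differentiable: z = mu + eps * std with eps ~ N(0, I) and std = exp(0.5 * logvar). A standalone check with illustrative shapes:

    import torch

    mu, logvar = torch.zeros(4, 256), torch.zeros(4, 256)   # batch of 4, lat_dim=256
    std = torch.exp(0.5 * logvar)          # logvar stores log(sigma^2)
    z = mu + torch.randn_like(std) * std   # differentiable w.r.t. mu and logvar
    print(z.shape)                         # torch.Size([4, 256])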
modules/utils.py Normal file (+201)
@@ -0,0 +1,201 @@
from abc import ABC
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl


# Utility - Modules
###################


class Flatten(nn.Module):

    @property
    def shape(self):
        try:
            x = torch.randn(self.in_shape).unsqueeze(0)
            output = self(x)
            return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
        except Exception as e:
            print(e)
            return -1

    def __init__(self, in_shape=None, to=-1):
        # in_shape is optional; it is only needed by the shape property.
        assert isinstance(to, int) or isinstance(to, tuple)
        super(Flatten, self).__init__()
        self.in_shape = in_shape
        self.to = (to,) if isinstance(to, int) else to

    def forward(self, x):
        return x.view(x.size(0), *self.to)


class Interpolate(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.mode = mode

    def forward(self, x):
        x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
                        mode=self.mode, align_corners=self.align_corners)
        return x


class AutoPad(nn.Module):

    def __init__(self, interpolations=3, base=2):
        super(AutoPad, self).__init__()
        self.fct = base ** interpolations

    def forward(self, x):
        # Pad the last two dimensions up to the next multiple of base ** interpolations.
        # noinspection PyUnresolvedReferences
        x = F.pad(x,
                  [0,
                   (x.shape[-1] // self.fct + 1) * self.fct - x.shape[-1] if x.shape[-1] % self.fct != 0 else 0,
                   (x.shape[-2] // self.fct + 1) * self.fct - x.shape[-2] if x.shape[-2] % self.fct != 0 else 0,
                   0])
        return x


class WeightInit:

    def __init__(self, in_place_init_function):
        self.in_place_init_function = in_place_init_function

    def __call__(self, m):
        if hasattr(m, 'weight'):
            if isinstance(m.weight, torch.Tensor):
                if m.weight.ndim < 2:
                    m.weight.data.fill_(0.01)
                else:
                    self.in_place_init_function(m.weight)
        if hasattr(m, 'bias'):
            if isinstance(m.bias, torch.Tensor):
                m.bias.data.fill_(0.01)


class LightningBaseModule(pl.LightningModule, ABC):

    @classmethod
    def name(cls):
        raise NotImplementedError('Give your model a name!')

    @property
    def shape(self):
        try:
            x = torch.randn(self.in_shape).unsqueeze(0)
            output = self(x)
            return output.shape[1:]
        except Exception as e:
            print(e)
            return -1

    def __init__(self, hparams):
        super(LightningBaseModule, self).__init__()
        self.hparams = hparams

        # Data loading
        # =============================================================================
        # Map Object
        # self.map_storage = MapStorage(self.hparams.data_param.map_root)

    def size(self):
        return self.shape

    def _move_to_model_device(self, x):
        return x.cuda() if next(self.parameters()).is_cuda else x.cpu()

    def save_to_disk(self, model_path):
        Path(model_path).mkdir(parents=True, exist_ok=True)
        if not (model_path / 'model_class.obj').exists():
            with (model_path / 'model_class.obj').open('wb') as f:
                torch.save(self.__class__, f)
        return True

    @property
    def data_len(self):
        return len(self.dataset.train_dataset)

    @property
    def n_train_batches(self):
        return len(self.train_dataloader())

    def configure_optimizers(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        raise NotImplementedError

    def test_step(self, *args, **kwargs):
        raise NotImplementedError

    def test_epoch_end(self, outputs):
        raise NotImplementedError

    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
        self.apply(weight_initializer)

    # Dataloaders
    # ================================================================================
    # Train Dataloader
    def train_dataloader(self):
        return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
                          batch_size=self.hparams.train_param.batch_size,
                          num_workers=self.hparams.data_param.worker)

    # Test Dataloader
    def test_dataloader(self):
        return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
                          batch_size=self.hparams.train_param.batch_size,
                          num_workers=self.hparams.data_param.worker)

    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
                          batch_size=self.hparams.train_param.batch_size,
                          num_workers=self.hparams.data_param.worker)


class FilterLayer(nn.Module):

    def __init__(self):
        super(FilterLayer, self).__init__()

    def forward(self, x):
        # Keep only the last element along the second dimension.
        tensor = x[:, -1]
        return tensor


class MergingLayer(nn.Module):

    def __init__(self):
        super(MergingLayer, self).__init__()

    def forward(self, x):
        # ToDo: Which ones to combine?
        return


class FlipTensor(nn.Module):
    def __init__(self, dim=-2):
        super(FlipTensor, self).__init__()
        self.dim = dim

    def forward(self, x):
        # Reverse the order of elements along self.dim.
        idx = [i for i in range(x.size(self.dim) - 1, -1, -1)]
        idx = torch.as_tensor(idx).long()
        inverted_tensor = x.index_select(self.dim, idx)
        return inverted_tensor
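A minimal sketch of how WeightInit plugs into nn.Module.apply (the toy model is made up for illustration):

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    model.apply(WeightInit(in_place_init_function=nn.init.xavier_uniform_))
    # apply() visits every submodule: 2D+ weights get xavier_uniform_,
    # 1D weights and all biases are filled with 0.01.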
utils/__init__.py Normal file (+0)
utils/config.py Normal file (+135)
@@ -0,0 +1,135 @@
import ast

from argparse import Namespace
from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path

from ml_lib.models.generators.cnn import CNNRouteGeneratorModel
from ml_lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated

from ml_lib.models.homotopy_classification.cnn_based import ConvHomDetector
from ml_lib.utils.model_io import ModelParameters
from ml_lib.utils.transforms import AsArray


def is_jsonable(x):
    import json
    try:
        json.dumps(x)
        return True
    except TypeError:
        return False


class Config(ConfigParser):

    # TODO: Do this programmatically; this did not work:
    # Initialize Default Sections
    # for section in self.default_sections:
    #     self.__setattr__(section, property(lambda x: x._get_namespace_for_section(section)))

    @property
    def model_class(self):
        model_dict = dict(ConvHomDetector=ConvHomDetector,
                          CNNRouteGenerator=CNNRouteGeneratorModel,
                          CNNRouteGeneratorDiscriminated=CNNRouteGeneratorDiscriminated
                          )
        try:
            return model_dict[self.get('model', 'type')]
        except KeyError:
            raise KeyError(f'The model alias you provided ("{self.get("model", "type")}") does not exist!\n'
                           f'Try one of these:\n{list(model_dict.keys())}')

    @property
    def main(self):
        return self._get_namespace_for_section('main')

    @property
    def model(self):
        return self._get_namespace_for_section('model')

    @property
    def train(self):
        return self._get_namespace_for_section('train')

    @property
    def data(self):
        return self._get_namespace_for_section('data')

    @property
    def project(self):
        return self._get_namespace_for_section('project')
    ###################################################

    @property
    def model_paramters(self):
        return ModelParameters(self.model, self.train, self.data)

    @property
    def tags(self, ):
        return [f'{key}: {val}' for key, val in self.serializable.items()]

    @property
    def serializable(self):
        return {f'{section}_{key}': val for section, params in self._sections.items()
                for key, val in params.items() if is_jsonable(val)}

    @property
    def as_dict(self):
        return self._sections

    def _get_namespace_for_section(self, item):
        return Namespace(**{key: self.get(item, key) for key in self[item]})

    def __init__(self, **kwargs):
        super(Config, self).__init__(**kwargs)

    @staticmethod
    def _sort_combined_section_key_mapping(dict_obj):
        sorted_dict = defaultdict(dict)
        for key in dict_obj:
            section, *attr_name = key.split('_')
            attr_name = '_'.join(attr_name)
            value = str(dict_obj[key])

            sorted_dict[section][attr_name] = value
        # noinspection PyTypeChecker
        return dict(sorted_dict)

    @classmethod
    def read_namespace(cls, namespace: Namespace):

        sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__)
        new_config = cls()
        new_config.read_dict(sorted_dict)
        return new_config

    def update(self, mapping):
        sorted_dict = self._sort_combined_section_key_mapping(mapping)
        for section in sorted_dict:
            if not self.has_section(section):
                self.add_section(section)
            for option, value in sorted_dict[section].items():
                self.set(section, option, value)
        return self

    def get(self, *args, **kwargs):
        item = super(Config, self).get(*args, **kwargs)
        try:
            # Parse ints, floats, booleans and lists out of their string form.
            return ast.literal_eval(item)
        except (SyntaxError, ValueError):
            return item

    def write(self, filepath, **kwargs):
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)

        with path.open('w') as configfile:
            super().write(configfile)
        return True
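A brief usage sketch (the config.ini path and its sections are hypothetical):

    config = Config()
    config.read('config.ini')           # plain ConfigParser loading
    print(config.model.type)            # each section is exposed as an argparse Namespace
    print(config.get('main', 'seed'))   # values are literal_eval'd, so this can come back as an int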
utils/logging.py Normal file (+108)
@@ -0,0 +1,108 @@
from pathlib import Path

from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger

from ml_lib.utils.config import Config
import numpy as np


class Logger(LightningLoggerBase):

    media_dir = 'media'

    @property
    def experiment(self):
        if self.debug:
            return self.testtubelogger.experiment
        else:
            return self.neptunelogger.experiment

    @property
    def log_dir(self):
        return Path(self.testtubelogger.experiment.get_logdir()).parent

    @property
    def name(self):
        return self.config.model.type

    @property
    def project_name(self):
        return f"{self.config.project.owner}/{self.config.project.name}"

    @property
    def version(self):
        return self.config.get('main', 'seed')

    @property
    def outpath(self):
        # ToDo: Add further path modification such as dataset config etc.
        return Path(self.config.train.outpath) / self.config.data.mode

    def __init__(self, config: Config):
        """
        params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
            Parameters are displayed in the experiment's Parameters section and each key-value pair can be
            viewed in experiments view as a column.
        properties (dict|None): Optional, default is {}. Properties of the experiment.
            They are editable after the experiment is created. Properties are displayed in the experiment's
            Details and each key-value pair can be viewed in experiments view as a column.
        tags (list|None): Optional, default []. Must be a list of str. Tags of the experiment.
            They are editable after the experiment is created (see: append_tag() and remove_tag()).
            Tags are displayed in the experiment's Details and can be viewed in experiments view as a column.
        """
        super(Logger, self).__init__()

        self.config = config
        self.debug = self.config.main.debug
        self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
        self._neptune_kwargs = dict(offline_mode=self.debug,
                                    api_key=self.config.project.neptune_key,
                                    project_name=self.project_name,
                                    upload_source_files=list())
        self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
        self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)

    def log_hyperparams(self, params):
        self.neptunelogger.log_hyperparams(params)
        self.testtubelogger.log_hyperparams(params)

    def log_metrics(self, metrics, step=None):
        self.neptunelogger.log_metrics(metrics, step=step)
        self.testtubelogger.log_metrics(metrics, step=step)

    def close(self):
        self.testtubelogger.close()
        self.neptunelogger.close()

    def log_config_as_ini(self):
        self.config.write(self.log_dir / 'config.ini')

    def log_metric(self, metric_name, metric_value, **kwargs):
        self.testtubelogger.log_metrics({metric_name: metric_value})
        self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)

    def log_image(self, name, image, **kwargs):
        self.neptunelogger.log_image(name, image, **kwargs)
        step = kwargs.get('step', None)
        name = f'{step}_{name}' if step is not None else name
        image.savefig(self.log_dir / self.media_dir / name)

    def save(self):
        self.testtubelogger.save()
        self.neptunelogger.save()

    def finalize(self, status):
        self.testtubelogger.finalize(status)
        self.neptunelogger.finalize(status)
        self.log_config_as_ini()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.finalize('success')
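A sketch of the intended wiring (assumes a Config carrying the main/model/train/data/project sections the properties above read; with debug=True Neptune runs in offline mode):

    config = Config()
    config.read('config.ini')          # hypothetical path
    with Logger(config) as logger:     # __exit__ finalizes both backends and writes config.ini
        logger.log_metrics({'loss': 0.5}, step=0)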
utils/model_io.py Normal file (+61)
@@ -0,0 +1,61 @@
from argparse import Namespace
from pathlib import Path

import torch
from natsort import natsorted
from torch import nn


# Hyperparameter Object
class ModelParameters(Namespace):

    _activations = dict(
        leaky_relu=nn.LeakyReLU,
        relu=nn.ReLU,
        sigmoid=nn.Sigmoid,
        tanh=nn.Tanh
    )

    def __init__(self, model_param, train_param, data_param):
        self.model_param = model_param
        self.train_param = train_param
        self.data_param = data_param
        kwargs = vars(model_param)
        kwargs.update(vars(train_param))
        kwargs.update(vars(data_param))
        super(ModelParameters, self).__init__(**kwargs)

    def __getattribute__(self, item):
        if item == 'activation':
            try:
                # Resolve the configured activation name (e.g. 'tanh') to its nn class;
                # fall back to ReLU when the name is missing or unknown.
                return self._activations[self.__dict__['activation']]
            except KeyError:
                return nn.ReLU
        return super(ModelParameters, self).__getattribute__(item)


class SavedLightningModels(object):

    @classmethod
    def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
        assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
        found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))

        found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
        if model is None:
            model = torch.load(models_root_path / 'model_class.obj')
        assert model is not None

        return cls(weights=found_checkpoints[n], model=model)

    def __init__(self, **kwargs):
        self.weights: str = kwargs.get('weights', '')

        self.model = kwargs.get('model', None)
        assert self.model is not None

    def restore(self):
        pretrained_model = self.model.load_from_checkpoint(self.weights)
        pretrained_model.eval()
        pretrained_model.freeze()
        return pretrained_model
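A usage sketch (the checkpoint directory is hypothetical; n=-1 selects the last checkpoint in natural sort order):

    from pathlib import Path

    saved = SavedLightningModels.load_checkpoint(Path('output/checkpoints'))
    model = saved.restore()   # loads the weights, then eval() and freeze()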
utils/parallel.py Normal file (+25)
@@ -0,0 +1,25 @@
import multiprocessing as mp
import time


def run_n_in_parallel(f, n, processes=0, **kwargs):
    processes = processes if processes else mp.cpu_count()
    output = mp.Queue()
    kwargs.update(output=output)
    # Setup a list of processes that we want to run
    # (note: the process cap computed above is currently unused; one process is spawned per run)
    processes = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
    # Run processes
    results = []
    for p in processes:
        p.start()
    while len(results) != n:
        time.sleep(1)
        # Get process results from the output queue
        results.extend([output.get() for _ in processes])

    # Exit the completed processes
    for p in processes:
        p.join()

    return results
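A sketch of the calling convention (worker is a hypothetical target; it must accept the injected output queue and put exactly one result on it):

    def worker(output, x=0):
        output.put(x * 2)   # one result per process, reported via the queue

    if __name__ == '__main__':
        results = run_n_in_parallel(worker, n=4, x=21)
        print(results)      # [42, 42, 42, 42]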
utils/tools.py Normal file (+23)
@@ -0,0 +1,23 @@
import pickle
import shelve
from pathlib import Path


def write_to_shelve(file_path, value):
    check_path(file_path)
    file_path.parent.mkdir(exist_ok=True, parents=True)
    with shelve.open(str(file_path), protocol=pickle.HIGHEST_PROTOCOL) as f:
        # Keys are consecutive integers (as strings); each call appends one entry.
        new_key = str(len(f))
        f[new_key] = value


def load_from_shelve(file_path, key):
    check_path(file_path)
    with shelve.open(str(file_path)) as d:
        return d[key]


def check_path(file_path):
    assert isinstance(file_path, Path)
    assert str(file_path).endswith('.pik')
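A round-trip sketch (the .pik path is hypothetical; keys are auto-assigned as '0', '1', ...):

    from pathlib import Path

    db = Path('output/data.pik')       # must end in '.pik' to pass check_path
    write_to_shelve(db, {'step': 1})   # stored under key '0'
    write_to_shelve(db, {'step': 2})   # stored under key '1'
    print(load_from_shelve(db, '1'))   # {'step': 2}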
utils/transforms.py Normal file (+11)
@@ -0,0 +1,11 @@
import numpy as np


class AsArray(object):
    def __init__(self, width, height):
        self.width = width
        self.height = height

    def __call__(self, x):
        # Placeholder: returns an empty target-sized array; x is not converted yet in this commit.
        array = np.zeros((self.width, self.height))
        return array
visualization/__init__.py Normal file (+0)
visualization/tools.py Normal file (+26)
@@ -0,0 +1,26 @@
from pathlib import Path
import matplotlib.pyplot as plt


class Plotter(object):
    def __init__(self, root_path=''):
        self.root_path = Path(root_path)

    def save_current_figure(self, path, extention='.png'):
        fig, _ = plt.gcf(), plt.gca()
        # Prepare save location and check img file extension
        path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
        path.parent.mkdir(exist_ok=True, parents=True)
        fig.savefig(path)
        fig.clf()

    def show_current_figure(self):
        fig, _ = plt.gcf(), plt.gca()
        fig.show()
        fig.clf()


if __name__ == '__main__':
    output_root = Path('..') / 'output'
    p = Plotter(output_root)
    p.save_current_figure('test.png')