Model blocks, Model files, rearrange project structure
This commit is contained in:
0
lib/evaluation/__init__.py
Normal file
0
lib/evaluation/__init__.py
Normal file
37
lib/evaluation/classification.py
Normal file
37
lib/evaluation/classification.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
|
||||
|
||||
class ROCEvaluation(object):
|
||||
|
||||
BINARY_PROBLEM = 2
|
||||
linewidth = 2
|
||||
|
||||
def __init__(self, save_fig=True):
|
||||
self.epoch = 0
|
||||
pass
|
||||
|
||||
def __call__(self, prediction, label, prepare_fig=True):
|
||||
|
||||
# Compute ROC curve and ROC area
|
||||
fpr, tpr, _ = roc_curve(prediction, label)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
if prepare_fig:
|
||||
fig = self._prepare_fig()
|
||||
fig.plot(fpr, tpr, color='darkorange',
|
||||
lw=2, label=f'ROC curve (area = {roc_auc})')
|
||||
self._prepare_fig()
|
||||
return roc_auc
|
||||
|
||||
def _prepare_fig(self):
|
||||
fig = plt.gcf()
|
||||
fig.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
|
||||
fig.xlim([0.0, 1.0])
|
||||
fig.ylim([0.0, 1.05])
|
||||
fig.xlabel('False Positive Rate')
|
||||
fig.ylabel('True Positive Rate')
|
||||
|
||||
fig.legend(loc="lower right")
|
||||
return fig
|
||||
0
lib/evaluation/homotopic.py
Normal file
0
lib/evaluation/homotopic.py
Normal file
@@ -156,6 +156,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
self.apply(_weight_init)
|
||||
|
||||
|
||||
class FilterLayer(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(FilterLayer, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
tensor = x[:, -1]
|
||||
return tensor
|
||||
|
||||
|
||||
class MergingLayer(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MergingLayer, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# ToDo: Which ones to combine?
|
||||
return
|
||||
|
||||
|
||||
#
|
||||
# Sub - Modules
|
||||
###################
|
||||
@@ -241,6 +261,32 @@ class DeConvModule(nn.Module):
|
||||
return self.shape
|
||||
|
||||
|
||||
class RecurrentModule(nn.Module):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
|
||||
def __init__(self, in_shape, hidden_size, num_layers=1, cell_type=nn.GRU, use_bias=True, dropout=0):
|
||||
super(RecurrentModule, self).__init__()
|
||||
self.use_bias = use_bias
|
||||
self.num_layers = num_layers
|
||||
self.in_shape = in_shape
|
||||
self.hidden_size = hidden_size
|
||||
self.dropout = dropout
|
||||
self.rnn = cell_type(self.in_shape[-1] * self.in_shape[-2], hidden_size,
|
||||
num_layers=num_layers,
|
||||
bias=self.use_bias,
|
||||
batch_first=True,
|
||||
dropout=self.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.rnn(x)
|
||||
return tensor
|
||||
|
||||
|
||||
#
|
||||
# Full Model Parts
|
||||
###################
|
||||
|
||||
29
lib/models/cnn.py
Normal file
29
lib/models/cnn.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from lib.models.blocks import RecurrentModule, ConvModule, DeConvModule, Generator, LightningBaseModule
|
||||
|
||||
|
||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
pass
|
||||
|
||||
def configure_optimizers(self):
|
||||
pass
|
||||
|
||||
def validation_step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def validation_end(self, outputs):
|
||||
pass
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *params):
|
||||
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||
|
||||
def forward(self, x):
|
||||
pass
|
||||
0
lib/models/full.py
Normal file
0
lib/models/full.py
Normal file
0
lib/models/recurrent.py
Normal file
0
lib/models/recurrent.py
Normal file
0
lib/preprocessing/__init__.py
Normal file
0
lib/preprocessing/__init__.py
Normal file
121
lib/preprocessing/generator.py
Normal file
121
lib/preprocessing/generator.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import shelve
|
||||
from collections import defaultdict
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
from lib.objects.map import Map
|
||||
|
||||
|
||||
class Generator:
|
||||
|
||||
possible_modes = ['one_patching']
|
||||
|
||||
def __init__(self, data_root, map_obj, binary=True):
|
||||
self.binary: bool = binary
|
||||
self.map: Map = map_obj
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
|
||||
def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
|
||||
trajectories_with_alternatives = list()
|
||||
for _ in trange(n, desc='Processing Trajectories'):
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
|
||||
trajectories_with_alternatives.append(dict(trajectory=trajectory, alternatives=alternatives, labels=labels))
|
||||
return trajectories_with_alternatives
|
||||
|
||||
def generate_alternatives(self, trajectory, output: Union[mp.
|
||||
Queue, None] = None, mode='one_patching'):
|
||||
start, dest = trajectory.endpoints
|
||||
if mode == 'one_patching':
|
||||
patch = self.map.get_valid_position()
|
||||
alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
|
||||
else:
|
||||
raise RuntimeError(f'mode checking went wrong...')
|
||||
|
||||
if output:
|
||||
output.put(alternative)
|
||||
return alternative
|
||||
|
||||
def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
|
||||
mode='one_patching', equal_samples=True):
|
||||
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
||||
# Define an output queue
|
||||
output = mp.Queue()
|
||||
# Setup a list of processes that we want to run
|
||||
processes = [mp.Process(target=self.generate_alternatives,
|
||||
kwargs=dict(trajectory=trajectory, output=output, mode=mode))
|
||||
for _ in range(n)]
|
||||
# Run processes
|
||||
for p in processes:
|
||||
p.start()
|
||||
# Exit the completed processes
|
||||
for p in processes:
|
||||
p.join()
|
||||
# Get process results from the output queue
|
||||
results = [output.get() for _ in processes]
|
||||
|
||||
# label per homotopic class
|
||||
homotopy_classes = defaultdict(list)
|
||||
homotopy_classes[0].append(trajectory)
|
||||
for i in range(len(results)):
|
||||
alternative = results[i]
|
||||
class_not_found, label = True, None
|
||||
# check for homotopy class
|
||||
for label in homotopy_classes.keys():
|
||||
if self.map.are_homotopic(homotopy_classes[label][0], alternative):
|
||||
homotopy_classes[label].append(alternative)
|
||||
class_not_found = False
|
||||
break
|
||||
if class_not_found:
|
||||
label = len(homotopy_classes)
|
||||
homotopy_classes[label].append(alternative)
|
||||
|
||||
# There should be as much homotopic samples as non-homotopic samples
|
||||
if equal_samples:
|
||||
homotopy_classes = self._remove_unequal(homotopy_classes)
|
||||
|
||||
# Compose lists of alternatives with labels
|
||||
alternatives, labels = list(), list()
|
||||
for key in homotopy_classes.keys():
|
||||
alternatives.extend([homotopy_classes[key]])
|
||||
labels.extend([key] * len(homotopy_classes[key]))
|
||||
|
||||
# Write to disk
|
||||
if dataset_name:
|
||||
self.write_to_disk(dataset_name, trajectory, alternatives, labels)
|
||||
|
||||
# Return
|
||||
return alternatives, labels
|
||||
|
||||
def write_to_disk(self, filepath, trajectory, alternatives, labels):
|
||||
dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
|
||||
self.data_root.mkdir(exist_ok=True, parents=True)
|
||||
with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
|
||||
new_key = len(f)
|
||||
f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
|
||||
trajectory=trajectory,
|
||||
labels=labels)
|
||||
if 'map' not in f:
|
||||
f['map'] = dict(map=self.map, name=f'map_{self.map.name}')
|
||||
|
||||
@staticmethod
|
||||
def _remove_unequal(hom_dict):
|
||||
hom_dict = hom_dict.copy()
|
||||
|
||||
counter = len(hom_dict)
|
||||
while sum([len(hom_dict[class_id]) for class_id in range(len(hom_dict))]) > len(hom_dict[0]):
|
||||
if counter > len(hom_dict):
|
||||
counter = len(hom_dict)
|
||||
if counter in hom_dict:
|
||||
if len(hom_dict[counter]) == 0:
|
||||
del hom_dict[counter]
|
||||
else:
|
||||
del hom_dict[counter][-1]
|
||||
counter -= 1
|
||||
return hom_dict
|
||||
0
lib/visualization/__init__.py
Normal file
0
lib/visualization/__init__.py
Normal file
0
lib/visualization/bars.py
Normal file
0
lib/visualization/bars.py
Normal file
26
lib/visualization/tools.py
Normal file
26
lib/visualization/tools.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class Plotter(object):
|
||||
def __init__(self, root_path=''):
|
||||
self.root_path = Path(root_path)
|
||||
|
||||
def save_current_figure(self, path, extention='.png'):
|
||||
fig, _ = plt.gcf(), plt.gca()
|
||||
# Prepare save location and check img file extention
|
||||
path = self.root_path / Path(path if str(path).endswith(extention) else f'{str(path)}{extention}')
|
||||
path.parent.mkdir(exist_ok=True, parents=True)
|
||||
fig.savefig(path)
|
||||
fig.clf()
|
||||
|
||||
def show_current_figure(self):
|
||||
fig, _ = plt.gcf(), plt.gca()
|
||||
fig.show()
|
||||
fig.clf()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
output_root = Path('..') / 'output'
|
||||
p = Plotter(output_root)
|
||||
p.save_current_figure('test.png')
|
||||
Reference in New Issue
Block a user