Kurz vorm durchdrehen

This commit is contained in:
Si11ium
2020-03-11 17:10:19 +01:00
parent 1b5a7dc69e
commit 1f4edae95c
12 changed files with 157 additions and 93 deletions

View File

@@ -1,36 +1,49 @@
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axisartist.axes_grid import ImageGrid
from tqdm import tqdm
from typing import List
import lib.variables as V
class GeneratorVisualizer(object):
def __init__(self, maps, trajectories, labels, val_result_dict):
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
self.generated_alternatives = val_result_dict['generated_alternative']
self.pred_labels = val_result_dict['pred_label']
self.alternatives = val_result_dict['generated_alternative']
self.labels = labels
self.trajectories = trajectories
self.maps = maps
self._map_width, self._map_height = self.maps[0].squeeze().shape
self.column_dict_list = self._build_column_dict_list()
self._cols = len(self.column_dict_list)
self._rows = len(self.column_dict_list[0])
def _build_column_dict_list(self):
dict_list = []
for idx in range(self.maps.shape[0]):
image = (self.maps[idx] + self.trajectories[idx] + self.generated_alternatives[idx]).cpu().numpy().squeeze()
label = int(self.labels[idx])
dict_list.append(dict(image=image, label=label))
half_size = int(len(dict_list) // 2)
return dict_list[:half_size], dict_list[half_size:]
trajectories = []
non_hom_alternatives = []
hom_alternatives = []
for idx in range(self.alternatives.shape[0]):
image = (self.alternatives[idx]).cpu().numpy().squeeze()
label = self.labels[idx].item()
if label == V.HOMOTOPIC:
hom_alternatives.append(dict(image=image, label='Homotopic'))
else:
non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
label = 'original'
trajectories.append(dict(image=image, label=label))
return trajectories, hom_alternatives, non_hom_alternatives
def draw(self):
fig = plt.figure()
padding = 0.25
additional_size = self._cols * padding + 3 * padding
width = (self._map_width * self._cols) / V.DPI + additional_size
height = (self._map_height * self._rows) / V.DPI + additional_size
fig = plt.figure(figsize=(width, height), dpi=V.DPI)
grid = ImageGrid(fig, 111, # similar to subplot(111)
nrows_ncols=(len(self.column_dict_list[0]), len(self.column_dict_list)),
axes_pad=0.2, # pad between axes in inch.
nrows_ncols=(self._rows, self._cols),
axes_pad=padding, # pad between axes in inch.
)
for idx in range(len(grid.axes_all)):
@@ -40,4 +53,5 @@ class GeneratorVisualizer(object):
grid[idx].imshow(current_image)
grid[idx].title.set_text(current_label)
fig.cbar_mode = 'single'
fig.tight_layout()
return fig