Kurz vorm durchdrehen
This commit is contained in:
@@ -1,36 +1,49 @@
|
||||
import torch
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.axisartist.axes_grid import ImageGrid
|
||||
from tqdm import tqdm
|
||||
from typing import List
|
||||
import lib.variables as V
|
||||
|
||||
|
||||
class GeneratorVisualizer(object):
|
||||
|
||||
def __init__(self, maps, trajectories, labels, val_result_dict):
|
||||
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
|
||||
self.generated_alternatives = val_result_dict['generated_alternative']
|
||||
self.pred_labels = val_result_dict['pred_label']
|
||||
self.alternatives = val_result_dict['generated_alternative']
|
||||
self.labels = labels
|
||||
self.trajectories = trajectories
|
||||
self.maps = maps
|
||||
self._map_width, self._map_height = self.maps[0].squeeze().shape
|
||||
self.column_dict_list = self._build_column_dict_list()
|
||||
self._cols = len(self.column_dict_list)
|
||||
self._rows = len(self.column_dict_list[0])
|
||||
|
||||
def _build_column_dict_list(self):
|
||||
dict_list = []
|
||||
for idx in range(self.maps.shape[0]):
|
||||
image = (self.maps[idx] + self.trajectories[idx] + self.generated_alternatives[idx]).cpu().numpy().squeeze()
|
||||
label = int(self.labels[idx])
|
||||
dict_list.append(dict(image=image, label=label))
|
||||
half_size = int(len(dict_list) // 2)
|
||||
return dict_list[:half_size], dict_list[half_size:]
|
||||
trajectories = []
|
||||
non_hom_alternatives = []
|
||||
hom_alternatives = []
|
||||
|
||||
for idx in range(self.alternatives.shape[0]):
|
||||
image = (self.alternatives[idx]).cpu().numpy().squeeze()
|
||||
label = self.labels[idx].item()
|
||||
if label == V.HOMOTOPIC:
|
||||
hom_alternatives.append(dict(image=image, label='Homotopic'))
|
||||
else:
|
||||
non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
|
||||
for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
|
||||
image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
|
||||
label = 'original'
|
||||
trajectories.append(dict(image=image, label=label))
|
||||
|
||||
return trajectories, hom_alternatives, non_hom_alternatives
|
||||
|
||||
def draw(self):
|
||||
fig = plt.figure()
|
||||
padding = 0.25
|
||||
additional_size = self._cols * padding + 3 * padding
|
||||
width = (self._map_width * self._cols) / V.DPI + additional_size
|
||||
height = (self._map_height * self._rows) / V.DPI + additional_size
|
||||
fig = plt.figure(figsize=(width, height), dpi=V.DPI)
|
||||
grid = ImageGrid(fig, 111, # similar to subplot(111)
|
||||
nrows_ncols=(len(self.column_dict_list[0]), len(self.column_dict_list)),
|
||||
axes_pad=0.2, # pad between axes in inch.
|
||||
nrows_ncols=(self._rows, self._cols),
|
||||
axes_pad=padding, # pad between axes in inch.
|
||||
)
|
||||
|
||||
for idx in range(len(grid.axes_all)):
|
||||
@@ -40,4 +53,5 @@ class GeneratorVisualizer(object):
|
||||
grid[idx].imshow(current_image)
|
||||
grid[idx].title.set_text(current_label)
|
||||
fig.cbar_mode = 'single'
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
|
||||
Reference in New Issue
Block a user