hom_traj_gen/lib/visualization/generator_eval.py
2020-03-09 22:01:10 +01:00

44 lines
1.7 KiB
Python

import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axisartist.axes_grid import ImageGrid
from tqdm import tqdm
from typing import List
class GeneratorVisualizer(object):
    """Lay out map + trajectory + generated-alternative images in a two-column grid.

    Each sample ``idx`` is rendered as
    ``maps[idx] + trajectories[idx] + generated_alternative[idx]`` and titled
    with ``labels[idx]``. Samples are split into two columns: the first half
    goes in the left column and the remainder in the right one.
    """

    def __init__(self, maps, trajectories, labels, val_result_dict):
        """Store the per-sample data needed to build the grid.

        Args:
            maps: indexable sequence of per-sample map images.
            trajectories: indexable sequence of per-sample trajectory images,
                added element-wise onto the maps.
            labels: per-sample labels, used as subplot titles.
            val_result_dict: validation results; only the keys
                ``'generated_alternative'`` and ``'pred_label'`` are read here.
        """
        # NOTE(review): val_result_dict presumably also holds
        # discriminated_bce_loss, batch_nb and label -- confirm against caller.
        self.generated_alternatives = val_result_dict['generated_alternative']
        self.pred_labels = val_result_dict['pred_label']
        self.labels = labels
        self.trajectories = trajectories
        self.maps = maps
        self.column_dict_list = self._build_column_dict_list()

    def _build_column_dict_list(self):
        """Return ``(left_column, right_column)``, two lists of ``{'image', 'label'}`` dicts.

        The left column receives the extra sample when the count is odd, so
        every sample gets a cell in ``draw`` (the grid's row count follows the
        left column).
        """
        dict_list = [
            dict(image=map_ + trajectory + alternative, label=label)
            for map_, trajectory, alternative, label in zip(
                self.maps, self.trajectories, self.generated_alternatives, self.labels
            )
        ]
        # Ceil division: the left column is never shorter than the right one.
        half_size = (len(dict_list) + 1) // 2
        return dict_list[:half_size], dict_list[half_size:]

    @staticmethod
    def _as_image(image):
        """Convert a torch tensor into something ``imshow`` accepts; pass other values through."""
        if isinstance(image, torch.Tensor):
            # imshow cannot take tensors that require grad or live on a GPU.
            # squeeze() drops singleton dims, e.g. a (1, H, W) channel axis --
            # assumes single-channel images; TODO confirm the image layout.
            return image.detach().cpu().squeeze().numpy()
        return image

    def draw(self):
        """Render all samples into a new figure and return it.

        Returns:
            A matplotlib Figure. It is left empty when there are no samples.
        """
        fig = plt.figure()
        n_cols = len(self.column_dict_list)
        n_rows = len(self.column_dict_list[0])
        if n_rows == 0:
            # ImageGrid cannot be built with zero rows.
            return fig
        grid = ImageGrid(fig, 111,  # similar to subplot(111)
                         nrows_ncols=(n_rows, n_cols),
                         axes_pad=0.2,  # pad between axes in inch.
                         )
        # axes_all is ordered row by row (ImageGrid's default direction).
        for idx, ax in enumerate(grid.axes_all):
            row, col = divmod(idx, n_cols)
            column = self.column_dict_list[col]
            if row >= len(column):
                # The right column is one short when the sample count is odd.
                ax.axis('off')
                continue
            entry = column[row]
            ax.imshow(self._as_image(entry['image']))
            label = entry['label']
            if isinstance(label, torch.Tensor) and label.numel() == 1:
                label = label.item()
            ax.set_title(str(label))
        # NOTE(review): this attribute has no effect on the layout; colorbar
        # mode is an ImageGrid constructor argument. Kept for compatibility.
        fig.cbar_mode = 'single'
        return fig