import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import List

try:
    # Original import path (ImageGrid preconfigured with axisartist Axes).
    from mpl_toolkits.axisartist.axes_grid import ImageGrid
except ImportError:
    # That module is not available in recent matplotlib releases;
    # axes_grid1 exposes the same ImageGrid API.
    from mpl_toolkits.axes_grid1 import ImageGrid


class GeneratorVisualizer(object):
    """Lay out generator validation samples in a two-column image grid.

    Each cell shows ``maps[i] + trajectories[i] + generated_alternative[i]``
    with ``labels[i]`` as its title. The first half of the samples fills the
    left column and the remaining samples fill the right column.
    """

    def __init__(self, maps, trajectories, labels, val_result_dict):
        # val_result_dict keys (per the original note): discriminated_bce_loss,
        # batch_nb, pred_label, label, generated_alternative.
        # NOTE(review): assumes maps, trajectories, labels and
        # generated_alternative are index-aligned per sample — confirm with caller.
        self.generated_alternatives = val_result_dict['generated_alternative']
        self.pred_labels = val_result_dict['pred_label']
        self.labels = labels
        self.trajectories = trajectories
        self.maps = maps
        self.column_dict_list = self._build_column_dict_list()

    @staticmethod
    def _to_image(image):
        """Return *image* in a form ``imshow`` can consume."""
        if isinstance(image, torch.Tensor):
            # imshow converts via numpy, which fails for grad-tracking or GPU tensors.
            return image.detach().cpu().numpy()
        return image

    @staticmethod
    def _to_text(label):
        """Return a printable title for *label* (single-element tensors unwrapped)."""
        if isinstance(label, torch.Tensor) and label.numel() == 1:
            return str(label.item())
        return str(label)

    def _build_column_dict_list(self):
        """Build the per-sample overlay dicts and split them into two columns.

        Returns:
            A 2-tuple ``(left, right)`` of lists of ``dict(image=..., label=...)``.
            ``left`` holds the first ``n // 2`` samples and ``right`` the rest, so
            ``right`` is one longer than ``left`` when ``n`` is odd.
        """
        dict_list = []
        for idx in range(len(self.maps)):
            # Overlay this sample's map, trajectory and generated alternative.
            image = (self.maps[idx]
                     + self.trajectories[idx]
                     + self.generated_alternatives[idx])
            dict_list.append(dict(image=image, label=self.labels[idx]))
        half_size = len(dict_list) // 2
        return dict_list[:half_size], dict_list[half_size:]

    def draw(self):
        """Render every sample into an ImageGrid and return the figure.

        Returns:
            The ``matplotlib.figure.Figure`` containing the grid. Cells with no
            sample (possible when the sample count is odd) are hidden.
        """
        fig = plt.figure()
        n_cols = len(self.column_dict_list)
        # The right column holds the extra sample for odd counts, so size by
        # the longest column; keep at least one row so the grid is valid.
        n_rows = max(1, max(len(column) for column in self.column_dict_list))
        grid = ImageGrid(fig, 111,  # similar to subplot(111)
                         nrows_ncols=(n_rows, n_cols),
                         axes_pad=0.2,  # pad between axes in inch.
                         )
        # axes_all is ordered row by row (ImageGrid's default direction).
        for idx, ax in enumerate(grid.axes_all):
            row, col = divmod(idx, n_cols)
            column = self.column_dict_list[col]
            if row >= len(column):
                ax.set_visible(False)
                continue
            entry = column[row]
            ax.imshow(self._to_image(entry['image']))
            ax.title.set_text(self._to_text(entry['label']))
        # NOTE(review): this only sets a plain attribute on the Figure and has no
        # layout effect; a real shared colorbar needs cbar_mode passed to
        # ImageGrid plus grid.cbar_axes[0].colorbar(...). Kept for compatibility.
        fig.cbar_mode = 'single'
        return fig