diff --git a/lib/visualization/generator_eval.py b/lib/visualization/generator_eval.py new file mode 100644 index 0000000..e370f79 --- /dev/null +++ b/lib/visualization/generator_eval.py @@ -0,0 +1,43 @@ +import torch + +import matplotlib.pyplot as plt +from mpl_toolkits.axisartist.axes_grid import ImageGrid +from tqdm import tqdm +from typing import List + + +class GeneratorVisualizer(object): + + def __init__(self, maps, trajectories, labels, val_result_dict): + # val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative) + self.generated_alternatives = val_result_dict['generated_alternative'] + self.pred_labels = val_result_dict['pred_label'] + self.labels = labels + self.trajectories = trajectories + self.maps = maps + self.column_dict_list = self._build_column_dict_list() + + def _build_column_dict_list(self): + dict_list = [] + for idx in range(self.maps): + image = self.maps[idx] + self.trajectories[idx] + self.generated_alternatives + label = self.labels[idx] + dict_list.append(dict(image=image, label=label)) + half_size = int(len(dict_list) // 2) + return dict_list[:half_size], dict_list[half_size:] + + def draw(self): + fig = plt.figure() + grid = ImageGrid(fig, 111, # similar to subplot(111) + nrows_ncols=(len(self.column_dict_list[0]), len(self.column_dict_list)), + axes_pad=0.2, # pad between axes in inch. + ) + + for idx in grid.axes_all: + row, col = divmod(idx, len(self.column_dict_list)) + current_image = self.column_dict_list[col]['image'][row] + current_label = self.column_dict_list[col]['label'][row] + grid[idx].imshow(current_image) + grid[idx].title.set_text(current_label) + fig.cbar_mode = 'single' + return fig