44 lines
1.7 KiB
Python
44 lines
1.7 KiB
Python
import torch
|
|
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits.axisartist.axes_grid import ImageGrid
|
|
from tqdm import tqdm
|
|
from typing import List
|
|
|
|
|
|
class GeneratorVisualizer(object):
    """Lay out generator validation samples as a two-column image grid.

    Each sample is rendered as the element-wise sum of its map, its
    trajectory and the generator's alternative, titled with its label.
    The samples are split in half: the first half fills the left column and
    the second half fills the right column.
    """

    def __init__(self, maps, trajectories, labels, val_result_dict):
        """Collect per-sample inputs and pre-build the grid columns.

        Args:
            maps: Indexable batch of maps, one entry per sample.
            trajectories: Indexable batch of trajectories aligned with ``maps``.
            labels: Per-sample labels aligned with ``maps``; used as titles.
            val_result_dict: Validation output. Only the
                ``'generated_alternative'`` and ``'pred_label'`` keys are read
                here. (An earlier note says it also carries
                ``discriminated_bce_loss``, ``batch_nb`` and ``label`` --
                unverified.)

        Raises:
            KeyError: if a required key is missing from ``val_result_dict``.
            IndexError: if ``trajectories``, ``labels`` or the generated
                alternatives have fewer entries than ``maps``.
        """
        self.generated_alternatives = val_result_dict['generated_alternative']
        # Stored for callers; not used by the drawing code in this class.
        self.pred_labels = val_result_dict['pred_label']
        self.labels = labels
        self.trajectories = trajectories
        self.maps = maps
        # (left_column, right_column), each a list of {'image', 'label'} dicts.
        self.column_dict_list = self._build_column_dict_list()

    def _build_column_dict_list(self):
        """Build one ``{'image', 'label'}`` dict per sample and split them in two.

        Returns:
            A ``(left, right)`` tuple of lists. ``left`` holds the first
            ``len(maps) // 2`` samples and ``right`` holds the rest, so for an
            odd sample count ``right`` is one entry longer.
        """
        dict_list = []
        for idx, sample_map in enumerate(self.maps):
            # Overlay the layers by summing them; presumably they share a shape
            # (e.g. masks/heatmaps) -- TODO confirm against the caller.
            # The generated alternatives are indexed per sample: adding the
            # whole batch would combine every alternative with every map.
            image = sample_map + self.trajectories[idx] + self.generated_alternatives[idx]
            dict_list.append(dict(image=image, label=self.labels[idx]))
        half_size = len(dict_list) // 2
        return dict_list[:half_size], dict_list[half_size:]

    @staticmethod
    def _as_plottable(image):
        """Return ``image`` in a form ``Axes.imshow`` can render.

        Torch tensors are detached and moved to CPU NumPy arrays (imshow
        cannot consume tensors that require grad or live on an accelerator).
        A leading singleton axis on a 3-D array is dropped, because imshow
        cannot render shape ``(1, H, W)``; other shapes pass through unchanged.
        """
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
        if getattr(image, 'ndim', 0) == 3 and image.shape[0] == 1:
            image = image[0]
        return image

    @staticmethod
    def _as_title(label):
        """Return a readable title string for ``label``.

        Single-element tensors become their Python scalar, so the title reads
        ``1`` rather than ``tensor(1)``. Larger tensors become nested lists.
        """
        if isinstance(label, torch.Tensor):
            label = label.item() if label.numel() == 1 else label.tolist()
        return str(label)

    def draw(self):
        """Render the samples into a new Matplotlib figure.

        Samples fill an ``ImageGrid`` with one column per half of
        ``column_dict_list`` and as many rows as the longer half. A cell with
        no sample (the last left cell when the count is odd) is hidden.

        Returns:
            The created ``matplotlib.figure.Figure``.

        Raises:
            ValueError: if there are no samples to draw.
        """
        # NOTE(review): the module-level ``mpl_toolkits.axisartist.axes_grid``
        # import is deprecated in recent Matplotlib; ``axes_grid1`` is the
        # supported home of ``ImageGrid``.
        from mpl_toolkits.axes_grid1 import ImageGrid

        columns = self.column_dict_list
        n_cols = len(columns)
        n_rows = max((len(column) for column in columns), default=0)
        if n_rows == 0:
            raise ValueError("GeneratorVisualizer.draw: no samples to draw")

        fig = plt.figure()
        grid = ImageGrid(
            fig,
            111,  # similar to subplot(111)
            nrows_ncols=(n_rows, n_cols),
            axes_pad=0.2,  # pad between axes in inch.
        )
        # ImageGrid's default direction is "row", so ``axes_all`` is row-major
        # and divmod by the column count recovers (row, col).
        for idx, ax in enumerate(grid.axes_all):
            row, col = divmod(idx, n_cols)
            column = columns[col]
            if row >= len(column):
                ax.set_axis_off()
                continue
            entry = column[row]
            ax.imshow(self._as_plottable(entry['image']))
            ax.set_title(self._as_title(entry['label']))
        # Setting an attribute such as ``fig.cbar_mode`` does not affect the
        # layout; a shared colorbar requires passing ``cbar_mode="single"`` to
        # ImageGrid and drawing into ``grid.cbar_axes[0]``.
        return fig