train running dataset fixed
This commit is contained in:
parent
e3a9149f00
commit
fa3312e9d8
43
lib/visualization/generator_eval.py
Normal file
43
lib/visualization/generator_eval.py
Normal file
@ -0,0 +1,43 @@
|
||||
import torch
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.axisartist.axes_grid import ImageGrid
|
||||
from tqdm import tqdm
|
||||
from typing import List
|
||||
|
||||
|
||||
class GeneratorVisualizer(object):
|
||||
|
||||
def __init__(self, maps, trajectories, labels, val_result_dict):
|
||||
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
|
||||
self.generated_alternatives = val_result_dict['generated_alternative']
|
||||
self.pred_labels = val_result_dict['pred_label']
|
||||
self.labels = labels
|
||||
self.trajectories = trajectories
|
||||
self.maps = maps
|
||||
self.column_dict_list = self._build_column_dict_list()
|
||||
|
||||
def _build_column_dict_list(self):
|
||||
dict_list = []
|
||||
for idx in range(self.maps):
|
||||
image = self.maps[idx] + self.trajectories[idx] + self.generated_alternatives
|
||||
label = self.labels[idx]
|
||||
dict_list.append(dict(image=image, label=label))
|
||||
half_size = int(len(dict_list) // 2)
|
||||
return dict_list[:half_size], dict_list[half_size:]
|
||||
|
||||
def draw(self):
|
||||
fig = plt.figure()
|
||||
grid = ImageGrid(fig, 111, # similar to subplot(111)
|
||||
nrows_ncols=(len(self.column_dict_list[0]), len(self.column_dict_list)),
|
||||
axes_pad=0.2, # pad between axes in inch.
|
||||
)
|
||||
|
||||
for idx in grid.axes_all:
|
||||
row, col = divmod(idx, len(self.column_dict_list))
|
||||
current_image = self.column_dict_list[col]['image'][row]
|
||||
current_label = self.column_dict_list[col]['label'][row]
|
||||
grid[idx].imshow(current_image)
|
||||
grid[idx].title.set_text(current_label)
|
||||
fig.cbar_mode = 'single'
|
||||
return fig
|
Loading…
x
Reference in New Issue
Block a user