Offline Datasets res net optionality
This commit is contained in:
@@ -5,12 +5,13 @@ import lib.variables as V
|
||||
|
||||
class GeneratorVisualizer(object):
|
||||
|
||||
def __init__(self, maps, trajectories, labels, val_result_dict):
|
||||
def __init__(self, **kwargs):
|
||||
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
|
||||
self.alternatives = val_result_dict['generated_alternative']
|
||||
self.labels = labels
|
||||
self.trajectories = trajectories
|
||||
self.maps = maps
|
||||
self.alternatives = kwargs.get('generated_alternative')
|
||||
self.labels = kwargs.get('labels')
|
||||
self.trajectories = kwargs.get('trajectories')
|
||||
self.maps = kwargs.get('maps')
|
||||
|
||||
self._map_width, self._map_height = self.maps[0].squeeze().shape
|
||||
self.column_dict_list = self._build_column_dict_list()
|
||||
self._cols = len(self.column_dict_list)
|
||||
@@ -24,10 +25,13 @@ class GeneratorVisualizer(object):
|
||||
for idx in range(self.alternatives.shape[0]):
|
||||
image = (self.alternatives[idx]).cpu().numpy().squeeze()
|
||||
label = self.labels[idx].item()
|
||||
# Dirty and Quick hack incomming.
|
||||
if label == V.HOMOTOPIC:
|
||||
hom_alternatives.append(dict(image=image, label='Homotopic'))
|
||||
non_hom_alternatives.append(None)
|
||||
else:
|
||||
non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
|
||||
hom_alternatives.append(None)
|
||||
for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
|
||||
image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
|
||||
label = 'original'
|
||||
@@ -48,10 +52,13 @@ class GeneratorVisualizer(object):
|
||||
|
||||
for idx in range(len(grid.axes_all)):
|
||||
row, col = divmod(idx, len(self.column_dict_list))
|
||||
current_image = self.column_dict_list[col][row]['image']
|
||||
current_label = self.column_dict_list[col][row]['label']
|
||||
grid[idx].imshow(current_image)
|
||||
grid[idx].title.set_text(current_label)
|
||||
if self.column_dict_list[col][row] is not None:
|
||||
current_image = self.column_dict_list[col][row]['image']
|
||||
current_label = self.column_dict_list[col][row]['label']
|
||||
grid[idx].imshow(current_image)
|
||||
grid[idx].title.set_text(current_label)
|
||||
else:
|
||||
continue
|
||||
fig.cbar_mode = 'single'
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
|
||||
Reference in New Issue
Block a user