Offline Datasets res net optionality

This commit is contained in:
Si11ium
2020-03-12 18:32:23 +01:00
parent 2f99341cc3
commit bb47e07566
11 changed files with 638 additions and 140 deletions

View File

@@ -5,12 +5,13 @@ import lib.variables as V
class GeneratorVisualizer(object):
def __init__(self, maps, trajectories, labels, val_result_dict):
def __init__(self, **kwargs):
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
self.alternatives = val_result_dict['generated_alternative']
self.labels = labels
self.trajectories = trajectories
self.maps = maps
self.alternatives = kwargs.get('generated_alternative')
self.labels = kwargs.get('labels')
self.trajectories = kwargs.get('trajectories')
self.maps = kwargs.get('maps')
self._map_width, self._map_height = self.maps[0].squeeze().shape
self.column_dict_list = self._build_column_dict_list()
self._cols = len(self.column_dict_list)
@@ -24,10 +25,13 @@ class GeneratorVisualizer(object):
for idx in range(self.alternatives.shape[0]):
image = (self.alternatives[idx]).cpu().numpy().squeeze()
label = self.labels[idx].item()
# Dirty and Quick hack incomming.
if label == V.HOMOTOPIC:
hom_alternatives.append(dict(image=image, label='Homotopic'))
non_hom_alternatives.append(None)
else:
non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
hom_alternatives.append(None)
for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
label = 'original'
@@ -48,10 +52,13 @@ class GeneratorVisualizer(object):
for idx in range(len(grid.axes_all)):
row, col = divmod(idx, len(self.column_dict_list))
current_image = self.column_dict_list[col][row]['image']
current_label = self.column_dict_list[col][row]['label']
grid[idx].imshow(current_image)
grid[idx].title.set_text(current_label)
if self.column_dict_list[col][row] is not None:
current_image = self.column_dict_list[col][row]['image']
current_label = self.column_dict_list[col][row]['label']
grid[idx].imshow(current_image)
grid[idx].title.set_text(current_label)
else:
continue
fig.cbar_mode = 'single'
fig.tight_layout()
return fig