Refactoring
This commit is contained in:
117
generator_eval.py
Normal file
117
generator_eval.py
Normal file
@ -0,0 +1,117 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.cm as cmaps
|
||||
from mpl_toolkits.axisartist.axes_grid import ImageGrid
|
||||
from sklearn.cluster import Birch, DBSCAN, KMeans
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
import lib.variables as V
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GeneratorVisualizer(object):
|
||||
|
||||
def __init__(self, outputs, k=8):
|
||||
d = defaultdict(list)
|
||||
for key in outputs.keys():
|
||||
try:
|
||||
d[key] = outputs[key][:k].cpu().numpy()
|
||||
except AttributeError:
|
||||
d[key] = outputs[key][:k]
|
||||
except TypeError:
|
||||
self.batch_nb = outputs[key]
|
||||
for key in d.keys():
|
||||
self.__setattr__(key, d[key])
|
||||
|
||||
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
|
||||
self._map_width, self._map_height = self.input.shape[1], self.input.shape[2]
|
||||
self.column_dict_list = self._build_column_dict_list()
|
||||
self._cols = len(self.column_dict_list)
|
||||
self._rows = len(self.column_dict_list[0])
|
||||
|
||||
self.colormap = cmaps.tab20
|
||||
|
||||
def _build_column_dict_list(self):
|
||||
trajectories = []
|
||||
alternatives = []
|
||||
|
||||
for idx in range(self.output.shape[0]):
|
||||
image = (self.output[idx]).squeeze()
|
||||
label = 'Homotopic' if self.labels[idx].item() == V.HOMOTOPIC else 'Alternative'
|
||||
alternatives.append(dict(image=image, label=label))
|
||||
|
||||
for idx in range(len(alternatives)):
|
||||
image = (self.input[idx]).squeeze()
|
||||
label = 'original'
|
||||
trajectories.append(dict(image=image, label=label))
|
||||
|
||||
return trajectories, alternatives
|
||||
|
||||
@staticmethod
|
||||
def cluster_data(data):
|
||||
|
||||
cluster = Birch()
|
||||
|
||||
labels = cluster.fit_predict(data)
|
||||
return labels
|
||||
|
||||
def draw_latent(self):
|
||||
plt.close('all')
|
||||
clusterer = KMeans(10)
|
||||
try:
|
||||
labels = clusterer.fit_predict(self.logvar)
|
||||
except ValueError:
|
||||
fig = plt.figure()
|
||||
return fig
|
||||
if self.z.shape[-1] > 2:
|
||||
fig, axs = plt.subplots(ncols=2, nrows=1)
|
||||
transformers = [TSNE(2), PCA(2)]
|
||||
for idx, transformer in enumerate(transformers):
|
||||
transformed = transformer.fit_transform(self.z)
|
||||
|
||||
colored = self.colormap(labels)
|
||||
ax = axs[idx]
|
||||
ax.scatter(x=transformed[:, 0], y=transformed[:, 1], c=colored)
|
||||
ax.set_title(transformer.__class__.__name__)
|
||||
ax.set_xlim(np.min(transformed[:, 0])*1.1, np.max(transformed[:, 0]*1.1))
|
||||
ax.set_ylim(np.min(transformed[:, 1]*1.1), np.max(transformed[:, 1]*1.1))
|
||||
elif self.z.shape[-1] == 2:
|
||||
fig, axs = plt.subplots()
|
||||
|
||||
# TODO: Build transformation for lat_dim_size >= 3
|
||||
print('All Predictions sucesfully Gathered and Shaped ')
|
||||
axs.set_xlim(np.min(self.z[:, 0]), np.max(self.z[:, 0]))
|
||||
axs.set_ylim(np.min(self.z[:, 1]), np.max(self.z[:, 1]))
|
||||
# ToDo: Insert Normalization
|
||||
colored = self.colormap(labels)
|
||||
plt.scatter(self.z[:, 0], self.z[:, 1], c=colored)
|
||||
else:
|
||||
raise NotImplementedError("Latent Dimensions can not be one-dimensional (yet).")
|
||||
|
||||
return fig
|
||||
|
||||
def draw_io_bundle(self):
|
||||
width, height = self._cols * 5, self._rows * 5
|
||||
additional_size = self._cols * V.PADDING + 3 * V.PADDING
|
||||
# width = (self._map_width * self._cols) / V.DPI + additional_size
|
||||
# height = (self._map_height * self._rows) / V.DPI + additional_size
|
||||
fig = plt.figure(figsize=(width, height), dpi=V.DPI)
|
||||
grid = ImageGrid(fig, 111, # similar to subplot(111)
|
||||
nrows_ncols=(self._rows, self._cols),
|
||||
axes_pad=V.PADDING, # pad between axes in inch.
|
||||
)
|
||||
|
||||
for idx in range(len(grid.axes_all)):
|
||||
row, col = divmod(idx, len(self.column_dict_list))
|
||||
if self.column_dict_list[col][row] is not None:
|
||||
current_image = self.column_dict_list[col][row]['image']
|
||||
current_label = self.column_dict_list[col][row]['label']
|
||||
grid[idx].imshow(current_image)
|
||||
grid[idx].title.set_text(current_label)
|
||||
else:
|
||||
continue
|
||||
fig.cbar_mode = 'single'
|
||||
fig.tight_layout()
|
||||
return fig
|
Reference in New Issue
Block a user