hom_traj_gen/generator_eval.py
2020-04-08 14:50:16 +02:00

118 lines
4.4 KiB
Python

from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.cm as cmaps
from mpl_toolkits.axisartist.axes_grid import ImageGrid
from sklearn.cluster import Birch, DBSCAN, KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import ml_lib.variables as V
import numpy as np
class GeneratorVisualizer(object):
    """Build matplotlib figures from a dictionary of generator outputs.

    Every entry of ``outputs`` becomes an attribute of the instance (see
    ``_collect_outputs``).  The plotting methods read the attributes
    ``input``, ``output``, ``labels``, ``z`` and ``logvar``, so ``outputs``
    must provide those keys.
    """

    def __init__(self, outputs, k=8):
        """Store the first ``k`` samples of every output and prepare the grid.

        Args:
            outputs: Mapping from name to batch data (objects offering
                ``.cpu().numpy()``, other sliceable sequences, or scalars).
            k: Number of leading samples to keep from each sliceable entry.
        """
        for key, value in self._collect_outputs(outputs, k).items():
            setattr(self, key, value)

        # val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
        # NOTE(review): assumes ``input`` has at least three axes - confirm layout with caller.
        self._map_width, self._map_height = self.input.shape[1], self.input.shape[2]
        self.column_dict_list = self._build_column_dict_list()
        self._cols = len(self.column_dict_list)
        self._rows = len(self.column_dict_list[0])
        self.colormap = cmaps.tab20

    @staticmethod
    def _collect_outputs(outputs, k):
        """Return a plain dict holding the first ``k`` items of each output.

        Values exposing ``.cpu().numpy()`` are converted to numpy arrays,
        other sliceable values are kept as a slice, and values that cannot be
        sliced at all (plain scalars, 0-dim arrays) are stored unchanged under
        their own key, so several scalar entries no longer overwrite each
        other.
        """
        collected = {}
        for key, value in outputs.items():
            try:
                collected[key] = value[:k].cpu().numpy()
            except AttributeError:
                # Sliceable, but without ``.cpu()``/``.numpy()`` (e.g. list, ndarray).
                collected[key] = value[:k]
            except (TypeError, IndexError):
                # Not sliceable: keep the raw value under its own name.
                collected[key] = value
        return collected

    @staticmethod
    def _padded_limits(values, margin=0.1):
        """Return ``(low, high)`` axis limits pushed outwards by ``margin``.

        Each bound is moved away from the data by ``margin`` times its own
        magnitude.  For a range that straddles zero this matches scaling both
        bounds by ``1 + margin`` up to float rounding; unlike plain scaling it
        never moves a bound into the data when all values share one sign.
        """
        lower, upper = np.min(values), np.max(values)
        return lower - margin * abs(lower), upper + margin * abs(upper)

    def _build_column_dict_list(self):
        """Return ``(trajectories, alternatives)`` as lists of image/label dicts.

        ``alternatives`` holds one entry per sample of ``self.output`` labelled
        'Homotopic' or 'Alternative' depending on ``self.labels``;
        ``trajectories`` holds the matching ``self.input`` samples labelled
        'original'.
        """
        alternatives = [
            dict(
                image=self.output[idx].squeeze(),
                label='Homotopic' if self.labels[idx].item() == V.HOMOTOPIC else 'Alternative',
            )
            for idx in range(self.output.shape[0])
        ]
        trajectories = [
            dict(image=self.input[idx].squeeze(), label='original')
            for idx in range(len(alternatives))
        ]
        return trajectories, alternatives

    @staticmethod
    def cluster_data(data):
        """Cluster ``data`` with a default ``Birch`` model and return the labels."""
        return Birch().fit_predict(data)

    def draw_latent(self):
        """Scatter-plot the latent codes ``self.z`` coloured by cluster.

        Clusters are computed with ``KMeans(10)`` on ``self.logvar``.  Latent
        spaces with more than two dimensions are projected with t-SNE and PCA
        (one subplot each); two-dimensional latents are plotted directly.

        Returns:
            The created figure; an empty figure if clustering raises
            ``ValueError``.

        Raises:
            NotImplementedError: If the latent space has fewer than two dimensions.
        """
        plt.close('all')
        clusterer = KMeans(10)
        try:
            labels = clusterer.fit_predict(self.logvar)
        except ValueError:
            # Best effort: presumably raised when there are fewer samples than
            # clusters - hand back an empty figure instead of failing.
            return plt.figure()

        # The colours only depend on the cluster labels, so compute them once.
        colored = self.colormap(labels)
        latent_dim = self.z.shape[-1]

        if latent_dim > 2:
            fig, axs = plt.subplots(ncols=2, nrows=1)
            for ax, transformer in zip(axs, (TSNE(2), PCA(2))):
                transformed = transformer.fit_transform(self.z)
                ax.scatter(x=transformed[:, 0], y=transformed[:, 1], c=colored)
                ax.set_title(transformer.__class__.__name__)
                ax.set_xlim(*self._padded_limits(transformed[:, 0]))
                ax.set_ylim(*self._padded_limits(transformed[:, 1]))
        elif latent_dim == 2:
            fig, axs = plt.subplots()
            print('All Predictions sucesfully Gathered and Shaped ')
            axs.set_xlim(np.min(self.z[:, 0]), np.max(self.z[:, 0]))
            axs.set_ylim(np.min(self.z[:, 1]), np.max(self.z[:, 1]))
            # ToDo: Insert Normalization
            axs.scatter(self.z[:, 0], self.z[:, 1], c=colored)
        else:
            raise NotImplementedError("Latent Dimensions can not be one-dimensional (yet).")

        return fig

    def draw_io_bundle(self):
        """Draw inputs and generated outputs side by side in an image grid.

        The grid has one column per list in ``self.column_dict_list`` and one
        row per sample; each cell shows the sample image with its label as
        title.

        Returns:
            The created figure.
        """
        width, height = self._cols * 5, self._rows * 5
        fig = plt.figure(figsize=(width, height), dpi=V.DPI)
        grid = ImageGrid(
            fig, 111,  # similar to subplot(111)
            nrows_ncols=(self._rows, self._cols),
            axes_pad=V.PADDING,  # pad between axes in inch.
        )

        n_columns = len(self.column_dict_list)
        for idx, ax in enumerate(grid.axes_all):
            # Flat grid index -> (sample row, column list).
            row, col = divmod(idx, n_columns)
            entry = self.column_dict_list[col][row]
            if entry is None:
                continue
            ax.imshow(entry['image'])
            ax.title.set_text(entry['label'])

        # NOTE(review): ``cbar_mode`` is an ImageGrid constructor option; setting it
        # on the figure presumably has no effect - kept for compatibility, confirm.
        fig.cbar_mode = 'single'
        fig.tight_layout()
        return fig