118 lines
4.4 KiB
Python
118 lines
4.4 KiB
Python
from collections import defaultdict
|
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.cm as cmaps
|
|
from mpl_toolkits.axisartist.axes_grid import ImageGrid
|
|
from sklearn.cluster import Birch, DBSCAN, KMeans
|
|
from sklearn.decomposition import PCA
|
|
from sklearn.manifold import TSNE
|
|
|
|
import ml_lib.variables as V
|
|
import numpy as np
|
|
|
|
|
|
class GeneratorVisualizer(object):
    """Build matplotlib figures from a dict of model outputs.

    Only the first ``k`` items of every sliceable output are kept and exposed
    as attributes named after their keys. ``__init__`` itself needs ``input``,
    ``output`` and ``labels``; ``draw_latent`` additionally reads ``z`` and
    ``logvar``.
    """

    def __init__(self, outputs, k=8):
        """Copy the first ``k`` items of every entry of ``outputs`` onto ``self``.

        Args:
            outputs: mapping from attribute name to value. Values exposing
                ``.cpu()`` (presumably torch tensors -- confirm with callers)
                are sliced and converted via ``.cpu().numpy()``; other
                sliceable values (numpy arrays, lists) are sliced as-is.
            k: number of leading items to keep per entry.

        Note:
            If processing a value raises ``TypeError`` (e.g. a plain int, which
            cannot be sliced), that value is stored as ``self.batch_nb``
            regardless of its key, so only one such entry should be present.
        """
        truncated = {}
        for key, value in outputs.items():
            try:
                truncated[key] = value[:k].cpu().numpy()
            except AttributeError:
                # Sliceable but has no ``.cpu()``: keep the plain slice.
                truncated[key] = value[:k]
            except TypeError:
                # Not sliceable; the batch number is the expected case here.
                self.batch_nb = value
        for key, value in truncated.items():
            setattr(self, key, value)

        # NOTE(review): an older comment listed the validation-result keys as
        # discriminated_bce_loss, batch_nb, pred_label, label and
        # generated_alternative, but this class reads input, output, labels,
        # z and logvar -- confirm which producer feeds it.
        # Assumes ``input`` has rank >= 3; axes 1 and 2 are taken as the map size.
        self._map_width, self._map_height = self.input.shape[1], self.input.shape[2]
        self.column_dict_list = self._build_column_dict_list()
        self._cols = len(self.column_dict_list)
        self._rows = len(self.column_dict_list[0])

        self.colormap = cmaps.tab20

    def _build_column_dict_list(self):
        """Return ``(trajectories, alternatives)``, the two image-grid columns.

        ``alternatives`` holds one ``{'image', 'label'}`` dict per row of
        ``self.output``, labelled 'Homotopic' when the matching entry of
        ``self.labels`` equals ``V.HOMOTOPIC`` and 'Alternative' otherwise.
        ``trajectories`` holds the same number of ``self.input`` rows, all
        labelled 'original'. Every image is squeezed before being stored.
        """
        n_rows = self.output.shape[0]
        alternatives = [
            dict(image=self.output[idx].squeeze(),
                 label='Homotopic' if self.labels[idx].item() == V.HOMOTOPIC else 'Alternative')
            for idx in range(n_rows)
        ]
        trajectories = [
            dict(image=self.input[idx].squeeze(), label='original')
            for idx in range(n_rows)
        ]
        return trajectories, alternatives

    @staticmethod
    def cluster_data(data):
        """Cluster ``data`` with a default-configured Birch model; return per-sample labels."""
        cluster = Birch()
        labels = cluster.fit_predict(data)
        return labels

    @staticmethod
    def _padded_limits(values, pad=0.05):
        """Return ``(lower, upper)`` axis limits enclosing all of ``values``.

        Each side gets a margin of ``pad`` times the data span, so every point
        stays visible whatever the sign of the data. A zero span (all values
        equal) falls back to a unit margin so the limits never coincide.
        """
        lower = float(np.min(values))
        upper = float(np.max(values))
        span = upper - lower
        margin = pad * span if span > 0 else 1.0
        return lower - margin, upper + margin

    def draw_latent(self):
        """Scatter-plot the latent codes ``self.z`` coloured by KMeans clusters.

        ``self.logvar`` is clustered into 10 groups with KMeans. If that raises
        ``ValueError``, an empty figure is returned. For a latent size
        (last axis of ``self.z``) above 2, TSNE and PCA 2-D projections are
        drawn side by side; for exactly 2, ``z`` is plotted directly.

        Returns:
            The matplotlib figure.

        Raises:
            NotImplementedError: if the latent size is below 2.
        """
        plt.close('all')
        clusterer = KMeans(n_clusters=10)
        try:
            labels = clusterer.fit_predict(self.logvar)
        except ValueError:
            # NOTE(review): KMeans rejects fewer samples than clusters; with the
            # default ``k=8`` in ``__init__`` this empty-figure path appears to be
            # taken every time -- confirm the intended sample count.
            return plt.figure()
        colored = self.colormap(labels)

        latent_dim = self.z.shape[-1]
        if latent_dim > 2:
            fig, axs = plt.subplots(ncols=2, nrows=1)
            for ax, transformer in zip(axs, (TSNE(n_components=2), PCA(n_components=2))):
                transformed = transformer.fit_transform(self.z)
                ax.scatter(x=transformed[:, 0], y=transformed[:, 1], c=colored)
                ax.set_title(type(transformer).__name__)
                # Pad by a fraction of the span: scaling min/max by 1.1 narrows
                # the range (and clips points) whenever the data has one sign.
                ax.set_xlim(*self._padded_limits(transformed[:, 0]))
                ax.set_ylim(*self._padded_limits(transformed[:, 1]))
        elif latent_dim == 2:
            fig, ax = plt.subplots()
            print('All Predictions sucesfully Gathered and Shaped ')
            # ToDo: Insert Normalization
            ax.scatter(self.z[:, 0], self.z[:, 1], c=colored)
            ax.set_xlim(*self._padded_limits(self.z[:, 0]))
            ax.set_ylim(*self._padded_limits(self.z[:, 1]))
        else:
            raise NotImplementedError("Latent Dimensions can not be one-dimensional (yet).")

        return fig

    def draw_io_bundle(self):
        """Draw the input/output image columns side by side in an ``ImageGrid``.

        Column 0 shows the original inputs and column 1 the generated
        alternatives (see ``_build_column_dict_list``). Each cell is allotted
        5x5 inches and is titled with its label.

        Returns:
            The matplotlib figure.
        """
        width, height = self._cols * 5, self._rows * 5
        fig = plt.figure(figsize=(width, height), dpi=V.DPI)
        grid = ImageGrid(fig, 111,  # similar to subplot(111)
                         nrows_ncols=(self._rows, self._cols),
                         axes_pad=V.PADDING,  # pad between axes in inch.
                         )

        for idx, ax in enumerate(grid.axes_all):
            # ImageGrid's default direction='row' orders axes row by row,
            # so idx == row * n_cols + col.
            row, col = divmod(idx, self._cols)
            cell = self.column_dict_list[col][row]
            if cell is None:
                continue
            ax.imshow(cell['image'])
            ax.title.set_text(cell['label'])

        fig.tight_layout()
        return fig
|