from collections import defaultdict

import matplotlib.pyplot as plt
import matplotlib.cm as cmaps
from mpl_toolkits.axisartist.axes_grid import ImageGrid
from sklearn.cluster import Birch, DBSCAN, KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import numpy as np

import ml_lib.variables as V


class GeneratorVisualizer(object):

    def __init__(self, outputs, k=8):
        # Keep only the first k samples of every entry. Torch tensors are moved to
        # CPU numpy arrays, plain arrays are sliced as-is, and non-sliceable values
        # (e.g. the batch number) are stored separately.
        d = defaultdict(list)
        for key in outputs.keys():
            try:
                d[key] = outputs[key][:k].cpu().numpy()
            except AttributeError:   # already a numpy array (no .cpu())
                d[key] = outputs[key][:k]
            except TypeError:        # not sliceable, e.g. an int batch number
                self.batch_nb = outputs[key]
        # Attach every collected entry as an attribute (input, output, z, logvar, ...).
        for key in d.keys():
            setattr(self, key, d[key])
        # val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
        self._map_width, self._map_height = self.input.shape[1], self.input.shape[2]
        self.column_dict_list = self._build_column_dict_list()
        self._cols = len(self.column_dict_list)
        self._rows = len(self.column_dict_list[0])
        self.colormap = cmaps.tab20

    def _build_column_dict_list(self):
        # One column of original inputs ("trajectories") and one column of generated
        # outputs ("alternatives"), each entry a dict of image and display label.
        trajectories = []
        alternatives = []
        for idx in range(self.output.shape[0]):
            image = self.output[idx].squeeze()
            label = 'Homotopic' if self.labels[idx].item() == V.HOMOTOPIC else 'Alternative'
            alternatives.append(dict(image=image, label=label))
        for idx in range(len(alternatives)):
            image = self.input[idx].squeeze()
            label = 'Original'
            trajectories.append(dict(image=image, label=label))
        return trajectories, alternatives

    @staticmethod
    def cluster_data(data):
        cluster = Birch()
        labels = cluster.fit_predict(data)
        return labels

    def draw_latent(self):
        # Scatter the latent codes, colored by a KMeans clustering of the log-variances.
        plt.close('all')
        clusterer = KMeans(n_clusters=10)
        try:
            labels = clusterer.fit_predict(self.logvar)
        except ValueError:
            # Too few samples for the requested number of clusters; return an empty figure.
            fig = plt.figure()
            return fig
        if self.z.shape[-1] > 2:
            # Project high-dimensional latents to 2D with both t-SNE and PCA.
            fig, axs = plt.subplots(ncols=2, nrows=1)
            transformers = [TSNE(2), PCA(2)]
            for idx, transformer in enumerate(transformers):
                transformed = transformer.fit_transform(self.z)
                colored = self.colormap(labels)
                ax = axs[idx]
                ax.scatter(x=transformed[:, 0], y=transformed[:, 1], c=colored)
                ax.set_title(transformer.__class__.__name__)
                ax.set_xlim(np.min(transformed[:, 0]) * 1.1, np.max(transformed[:, 0]) * 1.1)
                ax.set_ylim(np.min(transformed[:, 1]) * 1.1, np.max(transformed[:, 1]) * 1.1)
        elif self.z.shape[-1] == 2:
            fig, axs = plt.subplots()
            # TODO: Build transformation for lat_dim_size >= 3
            print('All predictions successfully gathered and shaped.')
            axs.set_xlim(np.min(self.z[:, 0]), np.max(self.z[:, 0]))
            axs.set_ylim(np.min(self.z[:, 1]), np.max(self.z[:, 1]))
            # TODO: Insert normalization
            colored = self.colormap(labels)
            plt.scatter(self.z[:, 0], self.z[:, 1], c=colored)
        else:
            raise NotImplementedError("A one-dimensional latent space is not supported (yet).")
        return fig

    def draw_io_bundle(self):
        # One grid column per entry in column_dict_list (originals vs. generated alternatives).
        width, height = self._cols * 5, self._rows * 5
        additional_size = self._cols * V.PADDING + 3 * V.PADDING
        # width = (self._map_width * self._cols) / V.DPI + additional_size
        # height = (self._map_height * self._rows) / V.DPI + additional_size
        fig = plt.figure(figsize=(width, height), dpi=V.DPI)
        grid = ImageGrid(fig, 111,                              # similar to subplot(111)
                         nrows_ncols=(self._rows, self._cols),
                         axes_pad=V.PADDING,                    # pad between axes in inch.
                         cbar_mode='single',                    # one shared colorbar axes
                         )
        for idx in range(len(grid.axes_all)):
            row, col = divmod(idx, len(self.column_dict_list))
            if self.column_dict_list[col][row] is not None:
                current_image = self.column_dict_list[col][row]['image']
                current_label = self.column_dict_list[col][row]['label']
                grid[idx].imshow(current_image)
                grid[idx].title.set_text(current_label)
        fig.tight_layout()
        return fig
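
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the training pipeline).
# It assumes the validation step hands over a dict with the keys accessed
# above -- 'input', 'output', 'z', 'logvar', 'labels' and 'batch_nb' -- and
# that ml_lib.variables defines HOMOTOPIC, PADDING and DPI. Key names, shapes
# and the random data below are assumptions; the real loop passes torch
# tensors, which the AttributeError fallback in __init__ also handles.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    demo_outputs = dict(
        input=rng.random((16, 32, 32)),        # original trajectory maps
        output=rng.random((16, 32, 32)),       # generated alternatives
        z=rng.standard_normal((16, 2)),        # 2-D latent codes (scatter branch)
        logvar=rng.standard_normal((16, 2)),   # log-variances used for clustering
        labels=rng.integers(0, 2, 16),         # homotopic / alternative flags
        batch_nb=0,                            # plain int, stored via the TypeError path
    )
    viz = GeneratorVisualizer(demo_outputs, k=16)
    viz.draw_latent()
    viz.draw_io_bundle()
    plt.show()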