# TODO: THIS
"""Find trained model checkpoints, run them over the dataset and plot their outputs.

Run as a script: looks for ``output`` (climbing parent directories if needed),
loads every ``*.ckpt`` found below it and shows scatter plots of each model
output (PCA / t-SNE projections when the output is wider than 2 dimensions).
"""
import seaborn as sb
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import data_loader
from dataset import DataContainer
from collections import defaultdict
from tqdm import tqdm
import os
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns
sns.set()
import matplotlib.pyplot as plt
# Model classes are resolved by name via globals() in load_and_predict, so this
# star import must stay (and stay last, as it may shadow earlier names).
from run_models import *


def search_for_weights(folder, max_depth=50):
    """Recursively visit every ``.ckpt`` file under *folder* with ``load_and_predict``.

    If *folder* does not exist relative to the current working directory,
    ``os.pardir`` is prepended repeatedly (searching upward) until it is found.

    Args:
        folder: Directory name or path to search.
        max_depth: Maximum number of parent levels to climb before giving up.

    Raises:
        FileNotFoundError: If *folder* is not found within *max_depth* levels.
    """
    candidate = folder
    climbed = 0
    while not os.path.exists(candidate):
        # Count climbed levels explicitly: the previous guard,
        # len(os.path.split(...)) >= 50, could never be true because
        # os.path.split always returns a 2-tuple, so this looped forever.
        if climbed >= max_depth:
            raise FileNotFoundError(f'The folder "{folder}" could not be found')
        candidate = os.path.join(os.pardir, candidate)
        climbed += 1

    # `with` guarantees the scandir iterator is closed even if a nested
    # load/plot raises part-way through the directory.
    with os.scandir(candidate) as entries:
        for entry in entries:
            if entry.is_dir():
                search_for_weights(entry.path, max_depth)
            elif entry.is_file() and entry.name.endswith('.ckpt'):
                load_and_predict(entry)


def load_and_predict(path_like_element):
    """Load one checkpoint, run it over the whole dataset and plot its outputs.

    Args:
        path_like_element: ``os.DirEntry`` of a ``.ckpt`` file (as yielded by
            ``os.scandir``); its ``.path`` attribute is used.
    """
    # The model class is looked up by name: path component [-3] of the
    # checkpoint path must name a class in this module's globals (presumably
    # provided by `from run_models import *`).
    model_cls = globals()[path_like_element.path.split(os.sep)[-3]]
    checkpoint_dir = os.path.dirname(path_like_element.path)
    pretrained_model = model_cls.load_from_metrics(
        weights_path=path_like_element.path,
        tags_csv=os.path.join(checkpoint_dir, 'default', 'version_0', 'meta_tags.csv'),
        on_gpu=torch.cuda.is_available(),
        map_location=None,
    )

    # Switch to eval mode and freeze weights for faster inference.
    pretrained_model.eval()
    pretrained_model.freeze()

    # NOTE(review): data location and the (5, 5) arguments are hard-coded;
    # their meaning is defined by DataContainer — confirm before reuse.
    dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5)

    # Collect each model output (by position) across all samples.
    prediction_dict = defaultdict(list)
    with torch.no_grad():
        for i in tqdm(range(len(dataset)), total=len(dataset)):
            outputs = pretrained_model(dataset[i].unsqueeze(0))
            # NOTE(review): the model's last output is deliberately skipped
            # here (range(len - 1)); confirm this is intended.
            for idx in range(len(outputs) - 1):
                prediction_dict[idx].append(outputs[idx])

    # .cpu() makes .numpy() safe even if outputs live on a GPU.
    predictions = [
        torch.cat(batches).detach().cpu().numpy()
        for batches in prediction_dict.values()
    ]
    for prediction in predictions:
        viz_latent(prediction)


def viz_latent(prediction):
    """Scatter-plot *prediction*, projecting to 2-D first when it is wider.

    Args:
        prediction: Array indexed as ``prediction[:, k]``, i.e. assumed
            (n_samples, n_features) — TODO confirm for all model outputs.

    Returns:
        The seaborn axes for 2-feature input; otherwise
        ``(fig, axs, pca_plot, tsne_plot)`` for the side-by-side PCA / t-SNE
        plots.

    Raises:
        ValueError: If the last dimension has fewer than 2 features.
    """
    n_features = prediction.shape[-1]
    if n_features <= 1:
        raise ValueError('How did this happen?')
    if n_features == 2:
        ax = sns.scatterplot(x=prediction[:, 0], y=prediction[:, 1])
        plt.show()
        return ax

    # Project to 2-D. Previously the unfitted estimator objects themselves
    # were indexed, which raised TypeError; fit_transform returns the
    # (n_samples, 2) embedding.
    # NOTE(review): recent scikit-learn TSNE requires n_samples > perplexity
    # (default 30) — small datasets may fail here.
    predictions_pca = PCA(n_components=2).fit_transform(prediction)
    predictions_tsne = TSNE(n_components=2).fit_transform(prediction)

    fig, axs = plt.subplots(ncols=2)
    pca_plot = sns.scatterplot(x=predictions_pca[:, 0], y=predictions_pca[:, 1], ax=axs[0])
    tsne_plot = sns.scatterplot(x=predictions_tsne[:, 0], y=predictions_tsne[:, 1], ax=axs[1])
    plt.show()
    return fig, axs, pca_plot, tsne_plot


if __name__ == '__main__':
    path = 'output'
    search_for_weights(path)