Done: First VIsualization

ToDo: Visualization for all classes, latent space setups
This commit is contained in:
Si11ium 2019-08-21 07:56:31 +02:00
parent 8aa3b3616f
commit 744c0c50b7
8 changed files with 320 additions and 23 deletions

26
.idea/workspace.xml generated
View File

@ -1,7 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="5955480a-c876-43d5-afd7-8717f51f413e" name="Default Changelist" comment="" /> <list default="true" id="5955480a-c876-43d5-afd7-8717f51f413e" name="Default Changelist" comment="">
<change afterPath="$PROJECT_DIR$/networks/seperating_adversarial_auto_encoder.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/viz/viz_latent.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/networks/adverserial_auto_encoder.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/adverserial_auto_encoder.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/networks/auto_encoder.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/auto_encoder.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/networks/modules.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/modules.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/networks/variational_auto_encoder.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/variational_auto_encoder.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_models.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_models.py" afterDir="false" />
</list>
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" /> <option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" /> <option name="HIGHLIGHT_CONFLICTS" value="true" />
@ -144,7 +153,7 @@
<workItem from="1564587420277" duration="6891000" /> <workItem from="1564587420277" duration="6891000" />
<workItem from="1565364574595" duration="1092000" /> <workItem from="1565364574595" duration="1092000" />
<workItem from="1565592214301" duration="53660000" /> <workItem from="1565592214301" duration="53660000" />
<workItem from="1565793671730" duration="30478000" /> <workItem from="1565793671730" duration="53296000" />
</task> </task>
<task id="LOCAL-00001" summary="Lightning integration basic ae, dataloaders and dataset"> <task id="LOCAL-00001" summary="Lightning integration basic ae, dataloaders and dataset">
<created>1565793753423</created> <created>1565793753423</created>
@ -167,7 +176,14 @@
<option name="project" value="LOCAL" /> <option name="project" value="LOCAL" />
<updated>1565987964760</updated> <updated>1565987964760</updated>
</task> </task>
<option name="localTasksCounter" value="4" /> <task id="LOCAL-00004" summary="Done: AE, VAE, AAE&#10;ToDo: Double AAE, Visualization&#10;All Modularized">
<created>1566064016196</created>
<option name="number" value="00004" />
<option name="presentableId" value="LOCAL-00004" />
<option name="project" value="LOCAL" />
<updated>1566064016196</updated>
</task>
<option name="localTasksCounter" value="5" />
<servers /> <servers />
</component> </component>
<component name="TypeScriptGeneratedFilesManager"> <component name="TypeScriptGeneratedFilesManager">
@ -196,7 +212,7 @@
<breakpoints> <breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line"> <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/run_models.py</url> <url>file://$PROJECT_DIR$/run_models.py</url>
<line>20</line> <line>23</line>
<option name="timeStamp" value="27" /> <option name="timeStamp" value="27" />
</line-breakpoint> </line-breakpoint>
</breakpoints> </breakpoints>
@ -214,7 +230,7 @@
<SUITE FILE_PATH="coverage/ae_toolbox_torch$basic_ae_lightning.coverage" NAME="basic_ae_lightning Coverage Results" MODIFIED="1565956491159" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$basic_ae_lightning.coverage" NAME="basic_ae_lightning Coverage Results" MODIFIED="1565956491159" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/ae_toolbox_torch$basic_vae_lightning.coverage" NAME="basic_vae_lightning Coverage Results" MODIFIED="1565955311009" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$basic_vae_lightning.coverage" NAME="basic_vae_lightning Coverage Results" MODIFIED="1565955311009" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/ae_toolbox_torch$run_basic_ae.coverage" NAME="run_basic_ae Coverage Results" MODIFIED="1565966122607" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$run_basic_ae.coverage" NAME="run_basic_ae Coverage Results" MODIFIED="1565966122607" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/ae_toolbox_torch$run_models.coverage" NAME="run_models Coverage Results" MODIFIED="1565987843914" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$run_models.coverage" NAME="run_models Coverage Results" MODIFIED="1566306843739" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/ae_toolbox_torch$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1565772669750" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/data" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1565772669750" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/data" />
</component> </component>
</project> </project>

View File

@ -25,6 +25,10 @@ class AdversarialAutoEncoder(AutoEncoder):
class AdversarialAELightningOverrides: class AdversarialAELightningOverrides:
@property
def name(self):
return self.__class__.__name__
def forward(self, x): def forward(self, x):
return self.network.forward(x) return self.network.forward(x)
@ -46,6 +50,7 @@ class AdversarialAELightningOverrides:
d_loss_fake = mse_loss(d_fake_prediction, torch.ones(d_fake_prediction.shape)) d_loss_fake = mse_loss(d_fake_prediction, torch.ones(d_fake_prediction.shape))
# Calculate the mean over both the real and the fake acc # Calculate the mean over both the real and the fake acc
# ToDo: do i need to compute this seperate?
d_loss = 0.5 * torch.add(d_loss_real, d_loss_fake) d_loss = 0.5 * torch.add(d_loss_real, d_loss_fake)
return {'loss': d_loss} return {'loss': d_loss}

View File

@ -5,16 +5,12 @@ from torch import Tensor
####################### #######################
# Basic AE-Implementation # Basic AE-Implementation
class AutoEncoder(Module, ABC): class AutoEncoder(AbstractNeuralNetwork, ABC):
@property def __init__(self, latent_dim: int, dataParams: dict, **kwargs):
def name(self):
return self.__class__.__name__
def __init__(self, dataParams, **kwargs):
super(AutoEncoder, self).__init__() super(AutoEncoder, self).__init__()
self.dataParams = dataParams self.dataParams = dataParams
self.latent_dim = kwargs.get('latent_dim', 2) self.latent_dim = latent_dim
self.encoder = Encoder(self.latent_dim) self.encoder = Encoder(self.latent_dim)
self.decoder = Decoder(self.latent_dim, self.dataParams['features']) self.decoder = Decoder(self.latent_dim, self.dataParams['features'])
@ -31,6 +27,10 @@ class AutoEncoder(Module, ABC):
class AutoEncoderLightningOverrides: class AutoEncoderLightningOverrides:
@property
def name(self):
return self.__class__.__name__
def forward(self, x): def forward(self, x):
return self.network.forward(x) return self.network.forward(x)

View File

@ -1,9 +1,24 @@
import torch import torch
import pytorch_lightning as pl import pytorch_lightning as pl
from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU, AvgPool2d
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
#######################
# Abstract NN Class
class AbstractNeuralNetwork(Module):
@property
def name(self):
return self.__class__.__name__
def __init__(self):
super(AbstractNeuralNetwork, self).__init__()
def forward(self, batch):
pass
###################### ######################
# Abstract Network class following the Lightning Syntax # Abstract Network class following the Lightning Syntax
@ -102,6 +117,15 @@ class RNNOutputFilter(Module):
return out if not self.only_last else out[:, -1, :] return out if not self.only_last else out[:, -1, :]
class AvgDimPool(Module):
def __init__(self):
super(AvgDimPool, self).__init__()
def forward(self, x):
return x.mean(-2)
####################### #######################
# Network Modules # Network Modules
# Generators, Decoders, Encoders, Discriminators # Generators, Decoders, Encoders, Discriminators
@ -112,8 +136,8 @@ class Discriminator(Module):
self.dataParams = dataParams self.dataParams = dataParams
self.latent_dim = latent_dim self.latent_dim = latent_dim
self.l1 = Linear(self.latent_dim, self.dataParams['features'] * 10) self.l1 = Linear(self.latent_dim, self.dataParams['features'] * 10)
self.l2 = Linear(self.dataParams['features']*10, self.dataParams['features'] * 20) self.l2 = Linear(self.dataParams['features'] * 10, self.dataParams['features'] * 20)
self.lout = Linear(self.dataParams['features']*20, 1) self.lout = Linear(self.dataParams['features'] * 20, 1)
self.dropout = Dropout(dropout) self.dropout = Dropout(dropout)
self.activation = activation() self.activation = activation()
self.sigmoid = Sigmoid() self.sigmoid = Sigmoid()
@ -149,6 +173,7 @@ class EncoderLinearStack(Module):
def __init__(self): def __init__(self):
super(EncoderLinearStack, self).__init__() super(EncoderLinearStack, self).__init__()
# FixMe: Get Hardcoded shit out of here
self.l1 = Linear(6, 100, bias=True) self.l1 = Linear(6, 100, bias=True)
self.l2 = Linear(100, 10, bias=True) self.l2 = Linear(100, 10, bias=True)
self.activation = ReLU() self.activation = ReLU()
@ -188,6 +213,31 @@ class Encoder(Module):
return tensor return tensor
class PoolingEncoder(Module):
def __init__(self, lat_dim, variational=False):
self.lat_dim = lat_dim
self.variational = variational
super(PoolingEncoder, self).__init__()
self.p = AvgDimPool()
self.l = EncoderLinearStack()
if variational:
self.mu = Linear(10, self.lat_dim)
self.logvar = Linear(10, self.lat_dim)
else:
self.lat_dim_layer = Linear(10, self.lat_dim)
def forward(self, x):
tensor = self.p(x)
tensor = self.l(tensor)
if self.variational:
tensor = self.mu(tensor), self.logvar(tensor)
else:
tensor = self.lat_dim_layer(tensor)
return tensor
class Decoder(Module): class Decoder(Module):
def __init__(self, latent_dim, *args, variational=False): def __init__(self, latent_dim, *args, variational=False):

View File

@ -0,0 +1,96 @@
from networks.auto_encoder import AutoEncoder
from torch.nn.functional import mse_loss
from networks.modules import *
import torch
class SeperatingAdversarialAutoEncoder(Module):
def __init__(self, latent_dim, dataParams, **kwargs):
assert latent_dim % 2 == 0, f'Your latent space needs to be even, not odd, but was: "{latent_dim}"'
super(SeperatingAdversarialAutoEncoder, self).__init__()
self.latent_dim = latent_dim
self.dataParams = dataParams
self.spatial_encoder = PoolingEncoder(self.latent_dim // 2)
self.temporal_encoder = Encoder(self.latent_dim // 2)
self.decoder = Decoder(self.latent_dim, self.dataParams['features'])
self.spatial_discriminator = Discriminator(self.latent_dim // 2, self.dataParams)
self.temporal_discriminator = Discriminator(self.latent_dim // 2, self.dataParams)
def forward(self, batch):
# Encoder
# outputs, hidden (Batch, Timesteps aka. Size, Features / Latent Dim Size)
z_spatial, z_temporal = self.spatial_encoder(batch), self.temporal_encoder(batch)
# Decoder
# First repeat the data accordingly to the batch size
z_concat = torch.cat((z_spatial, z_temporal), dim=-1)
z_repeatet = Repeater((batch.shape[0], self.dataParams['size'], -1))(z_concat)
x_hat = self.decoder(z_repeatet)
return z_spatial, z_temporal, x_hat
class SeparatingAdversarialAELightningOverrides:
@property
def name(self):
return self.__class__.__name__
def forward(self, x):
return self.network.forward(x)
def training_step(self, batch, _, optimizer_i):
spatial_latent_fake, temporal_latent_fake, batch_hat = self.network.forward(batch)
if optimizer_i == 0:
# ---------------------
# Train temporal Discriminator
# ---------------------
# latent_fake, reconstruction
temporal_latent_real = self.normal.sample(temporal_latent_fake.shape)
# Evaluate the input
temporal_real_prediction = self.network.temporal_discriminator.forward(temporal_latent_real)
temporal_fake_prediction = self.network.temporal_discriminator.forward(temporal_latent_fake)
# Train the discriminator
temporal_loss_real = mse_loss(temporal_real_prediction, torch.zeros(temporal_real_prediction.shape))
temporal_loss_fake = mse_loss(temporal_fake_prediction, torch.ones(temporal_fake_prediction.shape))
# Calculate the mean over bot the real and the fake acc
# ToDo: do i need to compute this seperate?
d_loss = 0.5 * torch.add(temporal_loss_real, temporal_loss_fake)
return {'loss': d_loss}
if optimizer_i == 1:
# ---------------------
# Train spatial Discriminator
# ---------------------
# latent_fake, reconstruction
spatial_latent_real = self.normal.sample(spatial_latent_fake.shape)
# Evaluate the input
spatial_real_prediction = self.network.spatial_discriminator.forward(spatial_latent_real)
spatial_fake_prediction = self.network.spatial_discriminator.forward(spatial_latent_fake)
# Train the discriminator
spatial_loss_real = mse_loss(spatial_real_prediction, torch.zeros(spatial_real_prediction.shape))
spatial_loss_fake = mse_loss(spatial_fake_prediction, torch.ones(spatial_fake_prediction.shape))
# Calculate the mean over bot the real and the fake acc
# ToDo: do i need to compute this seperate?
d_loss = 0.5 * torch.add(spatial_loss_real, spatial_loss_fake)
return {'loss': d_loss}
elif optimizer_i == 2:
# ---------------------
# Train AutoEncoder
# ---------------------
loss = mse_loss(batch, batch_hat)
return {'loss': loss}
else:
raise RuntimeError('This should not have happened, catch me if u can.')
if __name__ == '__main__':
raise PermissionError('Get out of here - never run this module')

View File

@ -4,7 +4,7 @@ from torch.nn.functional import mse_loss
####################### #######################
# Basic AE-Implementation # Basic AE-Implementation
class VariationalAutoEncoder(Module, ABC): class VariationalAutoEncoder(AbstractNeuralNetwork, ABC):
@property @property
def name(self): def name(self):
@ -34,6 +34,10 @@ class VariationalAutoEncoder(Module, ABC):
class VariationalAutoEncoderLightningOverrides: class VariationalAutoEncoderLightningOverrides:
@property
def name(self):
return self.network.name
def forward(self, x): def forward(self, x):
return self.network.forward(x) return self.network.forward(x)

View File

@ -1,6 +1,9 @@
from networks.auto_encoder import * from networks.auto_encoder import *
import os
import time
from networks.variational_auto_encoder import * from networks.variational_auto_encoder import *
from networks.adverserial_auto_encoder import * from networks.adverserial_auto_encoder import *
from networks.seperating_adversarial_auto_encoder import *
from networks.modules import LightningModule from networks.modules import LightningModule
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -9,7 +12,7 @@ from dataset import DataContainer
from torch.nn import BatchNorm1d from torch.nn import BatchNorm1d
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from test_tube import Experiment
# ToDo: How to implement this better? # ToDo: How to implement this better?
# other_classes = [AutoEncoder, AutoEncoderLightningOverrides] # other_classes = [AutoEncoder, AutoEncoderLightningOverrides]
@ -30,6 +33,10 @@ class Model(VariationalAutoEncoderLightningOverrides, LightningModule):
class AdversarialModel(AdversarialAELightningOverrides, LightningModule): class AdversarialModel(AdversarialAELightningOverrides, LightningModule):
@property
def name(self):
return self.network.name
def __init__(self, dataParams: dict): def __init__(self, dataParams: dict):
super(AdversarialModel, self).__init__() super(AdversarialModel, self).__init__()
self.dataParams = dataParams self.dataParams = dataParams
@ -48,13 +55,61 @@ class AdversarialModel(AdversarialAELightningOverrides, LightningModule):
return DataLoader(DataContainer('data', **self.dataParams), shuffle=True, batch_size=100) return DataLoader(DataContainer('data', **self.dataParams), shuffle=True, batch_size=100)
class SeparatingAdversarialModel(SeparatingAdversarialAELightningOverrides, LightningModule):
def __init__(self, latent_dim, dataParams: dict):
super(SeparatingAdversarialModel, self).__init__()
self.latent_dim = latent_dim
self.dataParams = dataParams
self.normal = Normal(0, 1)
self.network = SeperatingAdversarialAutoEncoder(self.latent_dim, self.dataParams)
pass
# This is Fucked up, why do i need to put an additional empty list here?
def configure_optimizers(self):
return [Adam([*self.network.spatial_discriminator.parameters(), *self.network.spatial_encoder.parameters()]
, lr=0.02),
Adam([*self.network.temporal_discriminator.parameters(), *self.network.temporal_encoder.parameters()]
, lr=0.02),
Adam([*self.network.temporal_encoder.parameters(),
*self.network.spatial_encoder.parameters(),
*self.network.decoder.parameters()]
, lr=0.02)], []
@data_loader
def tng_dataloader(self):
return DataLoader(DataContainer('data', **self.dataParams), shuffle=True, batch_size=100)
if __name__ == '__main__': if __name__ == '__main__':
features = 6 features = 6
ae = AdversarialModel( latent_dim = 4
dataParams=dict(refresh=False, size=5, step=5, model = SeparatingAdversarialModel(latent_dim=latent_dim, dataParams=dict(refresh=False, size=5, step=5,
features=features, transforms=[BatchNorm1d(features)] features=features, transforms=[BatchNorm1d(features)]
) )
)
# PyTorch summarywriter with a few bells and whistles
outpath = os.path.join(os.getcwd(), 'output', model.name, time.asctime().replace(' ', '_').replace(':', '-'))
os.makedirs(outpath, exist_ok=True)
exp = Experiment(save_dir=outpath)
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(outpath, 'weights.ckpt'),
save_best_only=True,
verbose=True,
monitor='val_loss',
mode='min',
) )
trainer = Trainer() trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback, max_nb_epochs=15) # gpus=[0...LoL]
trainer.fit(ae) trainer.fit(model)
trainer.save_checkpoint(os.path.join(outpath, 'weights.ckpt'))
# view tensorflow logs
print(f'View tensorboard logs by running\ntensorboard --logdir {outpath}')
print('and going to http://localhost:6006 on your browser')

71
viz/viz_latent.py Normal file
View File

@ -0,0 +1,71 @@
# TODO: THIS
import seaborn as sb
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import data_loader
from dataset import DataContainer
import os
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
from run_models import SeparatingAdversarialModel
path = 'output'
mylightningmodule = 'weired name, loaded from disk'
# FIXME: How to store hyperparamters in testtube element?
def search_for_weights(folder):
for element in os.scandir(folder):
if os.path.exists(element):
if element.is_dir():
search_for_weights(element.path)
elif element.is_file() and element.name.endswith('.ckpt'):
load_and_viz(element)
else:
continue
def load_and_viz(path_like_element):
# Define Loop to search for models and folder with visualizations
pretrained_model = SeparatingAdversarialModel.load_from_metrics(
weights_path=path_like_element.path,
tags_csv=os.path.join(os.path.dirname(path_like_element), 'default', 'version_0', 'meta_tags.csv'),
on_gpu=True if torch.cuda.is_available() else False,
map_location=None
)
# Init model and freeze its weights ( for faster inference)
pretrained_model.eval()
pretrained_model.freeze()
# Load the data fpr prediction
dataset = DataContainer('data', 5, 5)
# Do the inference
predictions = []
for i in range(len(dataset)):
z, _ = pretrained_model(dataset[i])
predictions.append(z)
predictions = torch.cat(predictions)
if predictions.shape[-1] <= 1:
raise ValueError('How did this happen?')
elif predictions.shape[-1] == 2:
ax = sns.scatterplot(x=predictions[:, 0], y=predictions[:, 1])
plt.show()
return ax
else:
fig, axs = plt.subplots(ncols=2)
predictions_pca = PCA(n_components=2)
predictions_tsne = TSNE(n_components=2)
pca_plot = sns.scatterplot(x=predictions_pca[:, 0], y=predictions_pca[:, 1], ax=axs[0])
tsne_plot = sns.scatterplot(x=predictions_tsne[:, 0], y=predictions_tsne[:, 1], ax=axs[1])
plt.show()
return fig, axs, pca_plot, tsne_plot