diff --git a/.idea/workspace.xml b/.idea/workspace.xml index ff2b3b5..539e40a 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -2,14 +2,15 @@ <project version="4"> <component name="ChangeListManager"> <list default="true" id="5955480a-c876-43d5-afd7-8717f51f413e" name="Default Changelist" comment=""> - <change afterPath="$PROJECT_DIR$/networks/seperating_adversarial_auto_encoder.py" afterDir="false" /> - <change afterPath="$PROJECT_DIR$/viz/viz_latent.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/dataset.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/networks/adverserial_auto_encoder.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/adverserial_auto_encoder.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/networks/auto_encoder.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/auto_encoder.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/networks/modules.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/modules.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/networks/seperating_adversarial_auto_encoder.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/seperating_adversarial_auto_encoder.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/networks/variational_auto_encoder.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/variational_auto_encoder.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/run_models.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_models.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/viz/viz_latent.py" beforeDir="false" afterPath="$PROJECT_DIR$/viz/viz_latent.py" afterDir="false" /> </list> <option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" /> <option name="SHOW_DIALOG" value="false" /> @@ -68,7 +69,7 @@ </list> </option> </component> - <component name="RunManager" selected="Python.run_models"> + <component name="RunManager" selected="Python.viz_latent"> <configuration default="true" type="PythonConfigurationType" factoryName="Python"> <module name="ae_toolbox_torch" /> <option name="INTERPRETER_OPTIONS" value="" /> @@ -95,6 +96,9 @@ <module name="ae_toolbox_torch" /> <option name="INTERPRETER_OPTIONS" value="" /> <option name="PARENT_ENVS" value="true" /> + <envs> + <env name="PYTHONUNBUFFERED" value="1" /> + </envs> <option name="SDK_HOME" value="" /> <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" /> <option name="IS_MODULE_SDK" value="true" /> @@ -129,12 +133,33 @@ <option name="INPUT_FILE" value="" /> <method v="2" /> </configuration> + <configuration name="viz_latent" type="PythonConfigurationType" factoryName="Python" temporary="true"> + <module name="ae_toolbox_torch" /> + <option name="INTERPRETER_OPTIONS" value="" /> + <option name="PARENT_ENVS" value="true" /> + <option name="SDK_HOME" value="" /> + <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/viz" /> + <option name="IS_MODULE_SDK" value="true" /> + <option name="ADD_CONTENT_ROOTS" value="true" /> + <option name="ADD_SOURCE_ROOTS" value="true" /> + <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" /> + <option name="SCRIPT_NAME" value="$PROJECT_DIR$/viz/viz_latent.py" /> + <option name="PARAMETERS" value="" /> + <option name="SHOW_COMMAND_LINE" value="true" /> + <option name="EMULATE_TERMINAL" value="false" /> + <option name="MODULE_MODE" value="false" /> + <option name="REDIRECT_INPUT" value="false" /> + <option name="INPUT_FILE" value="" /> + <method v="2" /> + </configuration> <list> <item itemvalue="Python.run_basic_ae" /> <item itemvalue="Python.run_models" /> + <item itemvalue="Python.viz_latent" /> </list> <recent_temporary> <list> + <item itemvalue="Python.viz_latent" /> <item itemvalue="Python.run_models" /> <item itemvalue="Python.run_basic_ae" /> </list> @@ -153,7 +178,8 @@ <workItem from="1564587420277" duration="6891000" /> <workItem from="1565364574595" duration="1092000" /> <workItem from="1565592214301" duration="53660000" /> - <workItem from="1565793671730" duration="53296000" /> + <workItem from="1565793671730" duration="53351000" /> + <workItem from="1566372837067" duration="27899000" /> </task> <task id="LOCAL-00001" summary="Lightning integration basic ae, dataloaders and dataset"> <created>1565793753423</created> @@ -183,7 +209,14 @@ <option name="project" value="LOCAL" /> <updated>1566064016196</updated> </task> - <option name="localTasksCounter" value="5" /> + <task id="LOCAL-00005" summary="Done: First VIsualization ToDo: Visualization for all classes, latent space setups"> + <created>1566366992088</created> + <option name="number" value="00005" /> + <option name="presentableId" value="LOCAL-00005" /> + <option name="project" value="LOCAL" /> + <updated>1566366992088</updated> + </task> + <option name="localTasksCounter" value="6" /> <servers /> </component> <component name="TypeScriptGeneratedFilesManager"> @@ -205,32 +238,46 @@ <component name="VcsManagerConfiguration"> <MESSAGE value="Lightning integration basic ae, dataloaders and dataset" /> <MESSAGE value="Done: AE, VAE, AAE ToDo: Double AAE, Visualization All Modularized" /> - <option name="LAST_COMMIT_MESSAGE" value="Done: AE, VAE, AAE ToDo: Double AAE, Visualization All Modularized" /> + <MESSAGE value="Done: First VIsualization ToDo: Visualization for all classes, latent space setups" /> + <option name="LAST_COMMIT_MESSAGE" value="Done: First VIsualization ToDo: Visualization for all classes, latent space setups" /> </component> <component name="XDebuggerManager"> <breakpoint-manager> <breakpoints> <line-breakpoint enabled="true" suspend="THREAD" type="python-line"> - <url>file://$PROJECT_DIR$/run_models.py</url> + <url>file://$PROJECT_DIR$/networks/modules.py</url> + <line>206</line> + <option name="timeStamp" value="51" /> + </line-breakpoint> + <line-breakpoint enabled="true" suspend="THREAD" type="python-line"> + <url>file://$PROJECT_DIR$/networks/seperating_adversarial_auto_encoder.py</url> <line>23</line> - <option name="timeStamp" value="27" /> + <option name="timeStamp" value="52" /> + </line-breakpoint> + <line-breakpoint enabled="true" suspend="THREAD" type="python-line"> + <url>file://$PROJECT_DIR$/viz/viz_latent.py</url> + <line>67</line> + <option name="timeStamp" value="56" /> </line-breakpoint> </breakpoints> <default-breakpoints> <breakpoint type="python-exception"> - <properties notifyOnTerminate="true" exception="BaseException"> + <properties notifyOnlyOnFirst="true" notifyOnTerminate="true" ignoreLibraries="true" exception="BaseException"> + <option name="ignoreLibraries" value="true" /> <option name="notifyOnTerminate" value="true" /> + <option name="notifyOnlyOnFirst" value="true" /> </properties> </breakpoint> </default-breakpoints> </breakpoint-manager> </component> <component name="com.intellij.coverage.CoverageDataManagerImpl"> + <SUITE FILE_PATH="coverage/ae_toolbox_torch$viz_latent.coverage" NAME="viz_latent Coverage Results" MODIFIED="1566541302103" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/viz" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$basic_ae_lightning_torch.coverage" NAME="basic_ae_lightning_torch Coverage Results" MODIFIED="1565937164457" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$basic_ae_lightning.coverage" NAME="basic_ae_lightning Coverage Results" MODIFIED="1565956491159" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$basic_vae_lightning.coverage" NAME="basic_vae_lightning Coverage Results" MODIFIED="1565955311009" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$run_basic_ae.coverage" NAME="run_basic_ae Coverage Results" MODIFIED="1565966122607" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> - <SUITE FILE_PATH="coverage/ae_toolbox_torch$run_models.coverage" NAME="run_models Coverage Results" MODIFIED="1566306843739" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> + <SUITE FILE_PATH="coverage/ae_toolbox_torch$run_models.coverage" NAME="run_models Coverage Results" MODIFIED="1566537126647" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/ae_toolbox_torch$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1565772669750" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/data" /> </component> </project> \ No newline at end of file diff --git a/dataset.py b/dataset.py index 09763b8..a5c85a1 100644 --- a/dataset.py +++ b/dataset.py @@ -163,7 +163,7 @@ class Trajectories(Dataset): self.data = self.__init_data_(**kwargs) pass - def __init_data_(self, **kwargs): + def __init_data_(self, **kwargs: dict): dataDict = dict() for key, val in kwargs.items(): if key in self.isovistMeasures: @@ -177,6 +177,7 @@ class Trajectories(Dataset): return data def __iter__(self): + # FixMe: is that correct? for i in range(len(self)): yield self[i] diff --git a/networks/adverserial_auto_encoder.py b/networks/adverserial_auto_encoder.py index 996d5b0..53352fe 100644 --- a/networks/adverserial_auto_encoder.py +++ b/networks/adverserial_auto_encoder.py @@ -10,7 +10,7 @@ class AdversarialAutoEncoder(AutoEncoder): def __init__(self, *args, **kwargs): super(AdversarialAutoEncoder, self).__init__(*args, **kwargs) - self.discriminator = Discriminator(self.latent_dim, self.dataParams) + self.discriminator = Discriminator(self.latent_dim, self.features) def forward(self, batch): # Encoder @@ -18,7 +18,7 @@ class AdversarialAutoEncoder(AutoEncoder): z = self.encoder(batch) # Decoder # First repeat the data accordingly to the batch size - z_repeatet = Repeater((batch.shape[0], self.dataParams['size'], -1))(z) + z_repeatet = Repeater((batch.shape[0], batch.shape[1], -1))(z) x_hat = self.decoder(z_repeatet) return z, x_hat diff --git a/networks/auto_encoder.py b/networks/auto_encoder.py index cae1214..b72bc59 100644 --- a/networks/auto_encoder.py +++ b/networks/auto_encoder.py @@ -7,12 +7,13 @@ from torch import Tensor # Basic AE-Implementation class AutoEncoder(AbstractNeuralNetwork, ABC): - def __init__(self, latent_dim: int, dataParams: dict, **kwargs): + def __init__(self, latent_dim: int=0, features: int = 0, **kwargs): + assert latent_dim and features super(AutoEncoder, self).__init__() - self.dataParams = dataParams self.latent_dim = latent_dim + self.features = features self.encoder = Encoder(self.latent_dim) - self.decoder = Decoder(self.latent_dim, self.dataParams['features']) + self.decoder = Decoder(self.latent_dim, self.features) def forward(self, batch: Tensor): # Encoder @@ -20,7 +21,7 @@ class AutoEncoder(AbstractNeuralNetwork, ABC): z = self.encoder(batch) # Decoder # First repeat the data accordingly to the batch size - z_repeatet = Repeater((batch.shape[0], self.dataParams['size'], -1))(z) + z_repeatet = Repeater((batch.shape[0], batch.shape[1], -1))(z) x_hat = self.decoder(z_repeatet) return z, x_hat diff --git a/networks/modules.py b/networks/modules.py index bf73652..0cc5ccf 100644 --- a/networks/modules.py +++ b/networks/modules.py @@ -131,13 +131,13 @@ class AvgDimPool(Module): # Generators, Decoders, Encoders, Discriminators class Discriminator(Module): - def __init__(self, latent_dim, dataParams, dropout=.0, activation=ReLU): + def __init__(self, latent_dim, features, dropout=.0, activation=ReLU): super(Discriminator, self).__init__() - self.dataParams = dataParams + self.features = features self.latent_dim = latent_dim - self.l1 = Linear(self.latent_dim, self.dataParams['features'] * 10) - self.l2 = Linear(self.dataParams['features'] * 10, self.dataParams['features'] * 20) - self.lout = Linear(self.dataParams['features'] * 20, 1) + self.l1 = Linear(self.latent_dim, self.features * 10) + self.l2 = Linear(self.features * 10, self.features * 20) + self.lout = Linear(self.features * 20, 1) self.dropout = Dropout(dropout) self.activation = activation() self.sigmoid = Sigmoid() diff --git a/networks/seperating_adversarial_auto_encoder.py b/networks/seperating_adversarial_auto_encoder.py index 858a92f..5bf5fc5 100644 --- a/networks/seperating_adversarial_auto_encoder.py +++ b/networks/seperating_adversarial_auto_encoder.py @@ -6,17 +6,17 @@ import torch class SeperatingAdversarialAutoEncoder(Module): - def __init__(self, latent_dim, dataParams, **kwargs): + def __init__(self, latent_dim, features, **kwargs): assert latent_dim % 2 == 0, f'Your latent space needs to be even, not odd, but was: "{latent_dim}"' super(SeperatingAdversarialAutoEncoder, self).__init__() self.latent_dim = latent_dim - self.dataParams = dataParams + self.features = features self.spatial_encoder = PoolingEncoder(self.latent_dim // 2) self.temporal_encoder = Encoder(self.latent_dim // 2) - self.decoder = Decoder(self.latent_dim, self.dataParams['features']) - self.spatial_discriminator = Discriminator(self.latent_dim // 2, self.dataParams) - self.temporal_discriminator = Discriminator(self.latent_dim // 2, self.dataParams) + self.decoder = Decoder(self.latent_dim, self.features) + self.spatial_discriminator = Discriminator(self.latent_dim // 2, self.features) + self.temporal_discriminator = Discriminator(self.latent_dim // 2, self.features) def forward(self, batch): # Encoder @@ -25,7 +25,7 @@ class SeperatingAdversarialAutoEncoder(Module): # Decoder # First repeat the data accordingly to the batch size z_concat = torch.cat((z_spatial, z_temporal), dim=-1) - z_repeatet = Repeater((batch.shape[0], self.dataParams['size'], -1))(z_concat) + z_repeatet = Repeater((batch.shape[0], batch.shape[1], -1))(z_concat) x_hat = self.decoder(z_repeatet) return z_spatial, z_temporal, x_hat diff --git a/networks/variational_auto_encoder.py b/networks/variational_auto_encoder.py index fad33ef..64cb7a9 100644 --- a/networks/variational_auto_encoder.py +++ b/networks/variational_auto_encoder.py @@ -10,12 +10,13 @@ class VariationalAutoEncoder(AbstractNeuralNetwork, ABC): def name(self): return self.__class__.__name__ - def __init__(self, dataParams, **kwargs): + def __init__(self, latent_dim=0, features=0, **kwargs): + assert latent_dim and features super(VariationalAutoEncoder, self).__init__() - self.dataParams = dataParams - self.latent_dim = kwargs.get('latent_dim', 2) + self.features = features + self.latent_dim = latent_dim self.encoder = Encoder(self.latent_dim, variational=True) - self.decoder = Decoder(self.latent_dim, self.dataParams['features'], variational=True) + self.decoder = Decoder(self.latent_dim, self.features, variational=True) @staticmethod def reparameterize(mu, logvar): @@ -27,7 +28,7 @@ class VariationalAutoEncoder(AbstractNeuralNetwork, ABC): def forward(self, batch): mu, logvar = self.encoder(batch) z = self.reparameterize(mu, logvar) - repeat = Repeater((batch.shape[0], self.dataParams['size'], -1)) + repeat = Repeater((batch.shape[0], batch.shape[1], -1)) x_hat = self.decoder(repeat(z)) return x_hat, mu, logvar diff --git a/run_models.py b/run_models.py index 5b52e50..6362036 100644 --- a/run_models.py +++ b/run_models.py @@ -14,21 +14,35 @@ from torch.nn import BatchNorm1d from pytorch_lightning import Trainer from test_tube import Experiment +from argparse import Namespace +from argparse import ArgumentParser + +args = ArgumentParser() +args.add_argument('step') +args.add_argument('features') +args.add_argument('size') +args.add_argument('latent_dim') + + # ToDo: How to implement this better? # other_classes = [AutoEncoder, AutoEncoderLightningOverrides] -class Model(VariationalAutoEncoderLightningOverrides, LightningModule): +class Model(AutoEncoderLightningOverrides, LightningModule): - def __init__(self, dataParams: dict): + def __init__(self, latent_dim=0, size=0, step=0, features=0, **kwargs): + assert all([x in args for x in ['step', 'size', 'latent_dim', 'features']]) + self.size = args.size + self.latent_dim = args.latent_dim + self.features = args.features + self.step = args.step super(Model, self).__init__() - self.dataParams = dataParams - self.network = VariationalAutoEncoder(self.dataParams) + self.network = AutoEncoder(self.latent_dim, self.features) def configure_optimizers(self): return [Adam(self.parameters(), lr=0.02)] @data_loader def tng_dataloader(self): - return DataLoader(DataContainer('data', **self.dataParams), shuffle=True, batch_size=100) + return DataLoader(DataContainer('data', self.size, self.step), shuffle=True, batch_size=100) class AdversarialModel(AdversarialAELightningOverrides, LightningModule): @@ -37,11 +51,15 @@ class AdversarialModel(AdversarialAELightningOverrides, LightningModule): def name(self): return self.network.name - def __init__(self, dataParams: dict): + def __init__(self, args: Namespace, **kwargs): + assert all([x in args for x in ['step', 'size', 'latent_dim', 'features']]) + self.size = args.size + self.latent_dim = args.latent_dim + self.features = args.features + self.step = args.step super(AdversarialModel, self).__init__() - self.dataParams = dataParams self.normal = Normal(0, 1) - self.network = AdversarialAutoEncoder(self.dataParams) + self.network = AdversarialAutoEncoder(self.latent_dim, self.features) pass # This is Fucked up, why do i need to put an additional empty list here? @@ -52,17 +70,20 @@ class AdversarialModel(AdversarialAELightningOverrides, LightningModule): @data_loader def tng_dataloader(self): - return DataLoader(DataContainer('data', **self.dataParams), shuffle=True, batch_size=100) + return DataLoader(DataContainer('data', self.size, self.step), shuffle=True, batch_size=100) class SeparatingAdversarialModel(SeparatingAdversarialAELightningOverrides, LightningModule): - def __init__(self, latent_dim, dataParams: dict): + def __init__(self, args: Namespace, **kwargs): + assert all([x in args for x in ['step', 'size', 'latent_dim', 'features']]) + self.size = args.size + self.latent_dim = args.latent_dim + self.features = args.features + self.step = args.step super(SeparatingAdversarialModel, self).__init__() - self.latent_dim = latent_dim - self.dataParams = dataParams self.normal = Normal(0, 1) - self.network = SeperatingAdversarialAutoEncoder(self.latent_dim, self.dataParams) + self.network = SeperatingAdversarialAutoEncoder(self.latent_dim, self.features, **kwargs) pass # This is Fucked up, why do i need to put an additional empty list here? @@ -78,22 +99,24 @@ class SeparatingAdversarialModel(SeparatingAdversarialAELightningOverrides, Ligh @data_loader def tng_dataloader(self): - return DataLoader(DataContainer('data', **self.dataParams), shuffle=True, batch_size=100) + num_workers = os.cpu_count() // 2 + return DataLoader(DataContainer('data', self.size, self.step), shuffle=True, batch_size=100, num_workers=num_workers) if __name__ == '__main__': - features = 6 - latent_dim = 4 - model = SeparatingAdversarialModel(latent_dim=latent_dim, dataParams=dict(refresh=False, size=5, step=5, - features=features, transforms=[BatchNorm1d(features)] - ) - ) + tag_dict = dict(features=features, latent_dim=4, size=5, step=6, refresh=False, + transforms=[BatchNorm1d(features)]) + arguments = args.parse_args() + arguments.__dict__.update(tag_dict) + + model = SeparatingAdversarialModel(arguments) # PyTorch summarywriter with a few bells and whistles outpath = os.path.join(os.getcwd(), 'output', model.name, time.asctime().replace(' ', '_').replace(':', '-')) os.makedirs(outpath, exist_ok=True) exp = Experiment(save_dir=outpath) + exp.tag(tag_dict=tag_dict) from pytorch_lightning.callbacks import ModelCheckpoint @@ -101,9 +124,8 @@ if __name__ == '__main__': filepath=os.path.join(outpath, 'weights.ckpt'), save_best_only=True, verbose=True, - monitor='val_loss', + monitor='tng_loss', # val_loss mode='min', - ) trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback, max_nb_epochs=15) # gpus=[0...LoL] diff --git a/viz/viz_latent.py b/viz/viz_latent.py index e4a1b5f..140304c 100644 --- a/viz/viz_latent.py +++ b/viz/viz_latent.py @@ -4,6 +4,8 @@ import torch from torch.utils.data import DataLoader from pytorch_lightning import data_loader from dataset import DataContainer +from collections import defaultdict +from tqdm import tqdm import os from sklearn.manifold import TSNE @@ -12,30 +14,28 @@ from sklearn.decomposition import PCA import seaborn as sns; sns.set() import matplotlib.pyplot as plt -from run_models import SeparatingAdversarialModel - -path = 'output' -mylightningmodule = 'weired name, loaded from disk' - - -# FIXME: How to store hyperparamters in testtube element? - +from run_models import * def search_for_weights(folder): + while not os.path.exists(folder): + if len(os.path.split(folder)) >= 50: + raise FileNotFoundError(f'The folder "{folder}" could not be found') + folder = os.path.join(os.pardir, folder) for element in os.scandir(folder): if os.path.exists(element): if element.is_dir(): search_for_weights(element.path) elif element.is_file() and element.name.endswith('.ckpt'): - load_and_viz(element) + load_and_predict(element) else: continue -def load_and_viz(path_like_element): +def load_and_predict(path_like_element): # Define Loop to search for models and folder with visualizations - pretrained_model = SeparatingAdversarialModel.load_from_metrics( + model = globals()[path_like_element.path.split(os.sep)[-3]] + pretrained_model = model.load_from_metrics( weights_path=path_like_element.path, tags_csv=os.path.join(os.path.dirname(path_like_element), 'default', 'version_0', 'meta_tags.csv'), on_gpu=True if torch.cuda.is_available() else False, @@ -46,19 +46,26 @@ def load_and_viz(path_like_element): pretrained_model.eval() pretrained_model.freeze() - # Load the data fpr prediction - dataset = DataContainer('data', 5, 5) + # Load the data for prediction + dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5) # Do the inference - predictions = [] - for i in range(len(dataset)): - z, _ = pretrained_model(dataset[i]) - predictions.append(z) - predictions = torch.cat(predictions) - if predictions.shape[-1] <= 1: + prediction_dict = defaultdict(list) + for i in tqdm(range(len(dataset)), total=len(dataset)): + p_X = pretrained_model(dataset[i].unsqueeze(0)) + for idx in range(len(p_X) - 1): + prediction_dict[idx].append(p_X[idx]) + + predictions = [torch.cat(prediction).detach().numpy() for prediction in prediction_dict.values()] + for prediction in predictions: + viz_latent(prediction) + + +def viz_latent(prediction): + if prediction.shape[-1] <= 1: raise ValueError('How did this happen?') - elif predictions.shape[-1] == 2: - ax = sns.scatterplot(x=predictions[:, 0], y=predictions[:, 1]) + elif prediction.shape[-1] == 2: + ax = sns.scatterplot(x=prediction[:, 0], y=prediction[:, 1]) plt.show() return ax else: @@ -69,3 +76,7 @@ def load_and_viz(path_like_element): tsne_plot = sns.scatterplot(x=predictions_tsne[:, 0], y=predictions_tsne[:, 1], ax=axs[1]) plt.show() return fig, axs, pca_plot, tsne_plot + +if __name__ == '__main__': + path = 'output' + search_for_weights(path) \ No newline at end of file