Visualization approach 1
This commit is contained in:
parent
18305a9e7e
commit
1386cdfd33
15
dataset.py
15
dataset.py
@ -44,12 +44,6 @@ def build_parse_commands():
|
|||||||
|
|
||||||
class AbstractDataset(ConcatDataset, ABC):
|
class AbstractDataset(ConcatDataset, ABC):
|
||||||
|
|
||||||
# maps = ['hotel', 'tum','gallery', 'queens', 'oet']
|
|
||||||
@property
|
|
||||||
def maps(self):
|
|
||||||
# return ['test', 'test2']
|
|
||||||
return ['hotel', 'tum','gallery', 'queens', 'oet']
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def raw_filenames(self):
|
def raw_filenames(self):
|
||||||
@ -71,8 +65,13 @@ class AbstractDataset(ConcatDataset, ABC):
|
|||||||
self.path = path
|
self.path = path
|
||||||
self.refresh = refresh
|
self.refresh = refresh
|
||||||
self.transforms = transforms or None
|
self.transforms = transforms or None
|
||||||
|
self.maps = list(set([x.name.split('_')[0] for x in os.scandir(os.path.join(self.path, 'raw'))]))
|
||||||
super(AbstractDataset, self).__init__(datasets=self._load_datasets())
|
super(AbstractDataset, self).__init__(datasets=self._load_datasets())
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
self.datasets = [dataset.to(device) for dataset in self.datasets]
|
||||||
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process(self, filepath):
|
def process(self, filepath):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -195,6 +194,10 @@ class Trajectories(Dataset):
|
|||||||
total_len = self.data.size()[0]
|
total_len = self.data.size()[0]
|
||||||
return total_len - (self.size * self.step - (self.step - 1))
|
return total_len - (self.size * self.step - (self.step - 1))
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
self.data = self.data.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class MapContainer(AbstractDataset):
|
class MapContainer(AbstractDataset):
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ class AdversarialAELightningOverrides(LightningModuleOverrides):
|
|||||||
|
|
||||||
# Calculate the mean over both the real and the fake acc
|
# Calculate the mean over both the real and the fake acc
|
||||||
# ToDo: do i need to compute this seperate?
|
# ToDo: do i need to compute this seperate?
|
||||||
d_loss = 0.5 * torch.add(d_loss_real, d_loss_fake)
|
d_loss = 0.5 * torch.add(d_loss_real, d_loss_fake) * 0.001
|
||||||
return {'loss': d_loss}
|
return {'loss': d_loss}
|
||||||
|
|
||||||
elif optimizer_i == 1:
|
elif optimizer_i == 1:
|
||||||
@ -69,7 +69,7 @@ class AdversarialAELightningOverrides(LightningModuleOverrides):
|
|||||||
# This is Fucked up, why do i need to put an additional empty list here?
|
# This is Fucked up, why do i need to put an additional empty list here?
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
return [Adam(self.network.discriminator.parameters(), lr=0.02),
|
return [Adam(self.network.discriminator.parameters(), lr=0.02),
|
||||||
Adam([*self.network.encoder.parameters(), *self.network.decoder.parameters()], lr=0.02)],\
|
Adam([*self.network.encoder.parameters(), *self.network.decoder.parameters()], lr=0.02), ],\
|
||||||
[]
|
[]
|
||||||
|
|
||||||
|
|
||||||
|
48
networks/attention_based_auto_enoder.py
Normal file
48
networks/attention_based_auto_enoder.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from torch.optim import Adam
|
||||||
|
|
||||||
|
from .modules import *
|
||||||
|
from torch.nn.functional import mse_loss
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
#######################
|
||||||
|
# Basic AE-Implementation
|
||||||
|
class AutoEncoder(AbstractNeuralNetwork, ABC):
|
||||||
|
|
||||||
|
def __init__(self, latent_dim: int=0, features: int = 0, **kwargs):
|
||||||
|
assert latent_dim and features
|
||||||
|
super(AutoEncoder, self).__init__()
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
self.features = features
|
||||||
|
self.encoder = Encoder(self.latent_dim)
|
||||||
|
self.decoder = Decoder(self.latent_dim, self.features)
|
||||||
|
|
||||||
|
def forward(self, batch: Tensor):
|
||||||
|
# Encoder
|
||||||
|
# outputs, hidden (Batch, Timesteps aka. Size, Features / Latent Dim Size)
|
||||||
|
z = self.encoder(batch)
|
||||||
|
# Decoder
|
||||||
|
# First repeat the data accordingly to the batch size
|
||||||
|
z_repeatet = Repeater((batch.shape[0], batch.shape[1], -1))(z)
|
||||||
|
x_hat = self.decoder(z_repeatet)
|
||||||
|
return z, x_hat
|
||||||
|
|
||||||
|
|
||||||
|
class AutoEncoderLightningOverrides(LightningModuleOverrides):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(AutoEncoderLightningOverrides, self).__init__()
|
||||||
|
|
||||||
|
def training_step(self, x, batch_nb):
|
||||||
|
# ToDo: We need a new loss function, fullfilling all attention needs
|
||||||
|
# z, x_hat
|
||||||
|
_, x_hat = self.forward(x)
|
||||||
|
loss = mse_loss(x, x_hat)
|
||||||
|
return {'loss': loss}
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
return [Adam(self.parameters(), lr=0.02)]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
raise PermissionError('Get out of here - never run this module')
|
@ -14,6 +14,9 @@ from torch.utils.data import DataLoader
|
|||||||
from dataset import DataContainer
|
from dataset import DataContainer
|
||||||
|
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
|
||||||
class LightningModuleOverrides:
|
class LightningModuleOverrides:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -25,8 +28,8 @@ class LightningModuleOverrides:
|
|||||||
|
|
||||||
@data_loader
|
@data_loader
|
||||||
def tng_dataloader(self):
|
def tng_dataloader(self):
|
||||||
num_workers = 0 # os.cpu_count() // 2
|
num_workers = 0 # os.cpu_count() // 2
|
||||||
return DataLoader(DataContainer('data', self.size, self.step),
|
return DataLoader(DataContainer(os.path.join('data', 'training'), self.size, self.step),
|
||||||
shuffle=True, batch_size=10000, num_workers=num_workers)
|
shuffle=True, batch_size=10000, num_workers=num_workers)
|
||||||
|
|
||||||
|
|
||||||
@ -236,6 +239,19 @@ class Encoder(Module):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionEncoder(Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(AttentionEncoder, self).__init__()
|
||||||
|
self.l_stack = TimeDistributed(EncoderLinearStack())
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
tensor = self.l_stack(x)
|
||||||
|
torch.bmm() # TODO Add Attention here
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
class PoolingEncoder(Module):
|
class PoolingEncoder(Module):
|
||||||
|
|
||||||
def __init__(self, lat_dim, variational=False):
|
def __init__(self, lat_dim, variational=False):
|
||||||
|
@ -4,9 +4,6 @@ from networks.modules import *
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
|
|
||||||
|
|
||||||
class SeperatingAdversarialAutoEncoder(Module):
|
class SeperatingAdversarialAutoEncoder(Module):
|
||||||
|
|
||||||
def __init__(self, latent_dim, features):
|
def __init__(self, latent_dim, features):
|
||||||
@ -58,7 +55,7 @@ class SeparatingAdversarialAELightningOverrides(LightningModuleOverrides):
|
|||||||
|
|
||||||
# Calculate the mean over bot the real and the fake acc
|
# Calculate the mean over bot the real and the fake acc
|
||||||
# ToDo: do i need to compute this seperate?
|
# ToDo: do i need to compute this seperate?
|
||||||
d_loss = 0.5 * torch.add(temporal_loss_real, temporal_loss_fake)
|
d_loss = 0.5 * torch.add(temporal_loss_real, temporal_loss_fake) * 0.001
|
||||||
return {'loss': d_loss}
|
return {'loss': d_loss}
|
||||||
|
|
||||||
if optimizer_i == 1:
|
if optimizer_i == 1:
|
||||||
@ -80,7 +77,7 @@ class SeparatingAdversarialAELightningOverrides(LightningModuleOverrides):
|
|||||||
|
|
||||||
# Calculate the mean over bot the real and the fake acc
|
# Calculate the mean over bot the real and the fake acc
|
||||||
# ToDo: do i need to compute this seperate?
|
# ToDo: do i need to compute this seperate?
|
||||||
d_loss = 0.5 * torch.add(spatial_loss_real, spatial_loss_fake)
|
d_loss = 0.5 * torch.add(spatial_loss_real, spatial_loss_fake) * 0.001
|
||||||
return {'loss': d_loss}
|
return {'loss': d_loss}
|
||||||
|
|
||||||
elif optimizer_i == 2:
|
elif optimizer_i == 2:
|
||||||
|
@ -22,7 +22,6 @@ args.add_argument('--model', default='Model')
|
|||||||
args.add_argument('--refresh', type=strtobool, default=False)
|
args.add_argument('--refresh', type=strtobool, default=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ToDo: How to implement this better?
|
# ToDo: How to implement this better?
|
||||||
# other_classes = [AutoEncoder, AutoEncoderLightningOverrides]
|
# other_classes = [AutoEncoder, AutoEncoderLightningOverrides]
|
||||||
class Model(AutoEncoderLightningOverrides, LightningModule):
|
class Model(AutoEncoderLightningOverrides, LightningModule):
|
||||||
|
30
viz/utils.py
Normal file
30
viz/utils.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def search_for_weights(func, folder):
|
||||||
|
while not os.path.exists(folder):
|
||||||
|
if len(os.path.split(folder)) >= 50:
|
||||||
|
raise FileNotFoundError(f'The folder "{folder}" could not be found')
|
||||||
|
folder = os.path.join(os.pardir, folder)
|
||||||
|
|
||||||
|
if any([x.name.endswith('.png') for x in os.scandir(folder)]):
|
||||||
|
return
|
||||||
|
|
||||||
|
if any(['.ckpt' in element.name and element.is_dir() for element in os.scandir(folder)]):
|
||||||
|
_, _, filenames = next(os.walk(os.path.join(folder, 'weights.ckpt')))
|
||||||
|
filenames.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
|
||||||
|
func(os.path.join(folder, 'weights.ckpt', filenames[-1]))
|
||||||
|
return
|
||||||
|
|
||||||
|
for element in os.scandir(folder):
|
||||||
|
if os.path.exists(element):
|
||||||
|
if element.is_dir():
|
||||||
|
search_for_weights(func, element.path)
|
||||||
|
elif element.is_file() and element.name.endswith('.ckpt'):
|
||||||
|
func(element)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
raise PermissionError('This file should not be called.')
|
@ -1,6 +1,3 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.manifold import TSNE
|
||||||
from sklearn.decomposition import PCA
|
from sklearn.decomposition import PCA
|
||||||
|
|
||||||
@ -12,57 +9,52 @@ from run_models import *
|
|||||||
sns.set()
|
sns.set()
|
||||||
|
|
||||||
|
|
||||||
def search_for_weights(folder):
|
|
||||||
while not os.path.exists(folder):
|
|
||||||
if len(os.path.split(folder)) >= 50:
|
|
||||||
raise FileNotFoundError(f'The folder "{folder}" could not be found')
|
|
||||||
folder = os.path.join(os.pardir, folder)
|
|
||||||
for element in os.scandir(folder):
|
|
||||||
if os.path.exists(element):
|
|
||||||
if element.is_dir():
|
|
||||||
search_for_weights(element.path)
|
|
||||||
elif element.is_file() and element.name.endswith('.ckpt'):
|
|
||||||
load_and_predict(element)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_predict(path_like_element):
|
def load_and_predict(path_like_element):
|
||||||
if any([x.name.endswith('.png') for x in os.scandir(os.path.dirname(path_like_element))]):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Define Loop to search for models and folder with visualizations
|
# Define Loop to search for models and folder with visualizations
|
||||||
model = globals()[path_like_element.path.split(os.sep)[-3]]
|
splitpath = path_like_element.split(os.sep)
|
||||||
|
base_dir = os.path.join(*splitpath[:4])
|
||||||
|
model = globals()[splitpath[2]]
|
||||||
|
print(f'... loading model named: "{Model.name}" from timestamp: {splitpath[3]}')
|
||||||
pretrained_model = model.load_from_metrics(
|
pretrained_model = model.load_from_metrics(
|
||||||
weights_path=path_like_element.path,
|
weights_path=path_like_element,
|
||||||
tags_csv=os.path.join(os.path.dirname(path_like_element), 'default', 'version_0', 'meta_tags.csv'),
|
tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
|
||||||
on_gpu=True if torch.cuda.is_available() else False,
|
on_gpu=True if torch.cuda.is_available() else False,
|
||||||
map_location=None
|
map_location=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init model and freeze its weights ( for faster inference)
|
# Init model and freeze its weights ( for faster inference)
|
||||||
|
pretrained_model = pretrained_model.to(device)
|
||||||
pretrained_model.eval()
|
pretrained_model.eval()
|
||||||
pretrained_model.freeze()
|
pretrained_model.freeze()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
||||||
# Load the data for prediction
|
# Load the data for prediction
|
||||||
dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5)
|
|
||||||
|
# TODO!!!!!!!!!:
|
||||||
|
# Hier müssen natürlich auch die date parameter geladen werden!
|
||||||
|
# Muss ich die val-sets automatisch neu setzen, also immer auf refresh haben, wenn ich validieren möchte?
|
||||||
|
# Was ist denn eigentlich mein Val Dataset?
|
||||||
|
# Hab ich irgendwo eine ganze karte?
|
||||||
|
# Wie sorge ich dafür, dass gewisse karten, also größenverhältnisse usw nicht überrepräsentiert sind?
|
||||||
|
dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)
|
||||||
|
dataloader = DataLoader(dataset, shuffle=True, batch_size=len(dataset))
|
||||||
|
|
||||||
# Do the inference
|
# Do the inference
|
||||||
prediction_dict = defaultdict(list)
|
test_pred = [pretrained_model(test_sample)[:-1] for test_sample in dataloader][0]
|
||||||
for i in tqdm(range(len(dataset)), total=len(dataset)):
|
|
||||||
p_X = pretrained_model(dataset[i].unsqueeze(0))
|
|
||||||
for idx in range(len(p_X) - 1):
|
|
||||||
prediction_dict[idx].append(p_X[idx])
|
|
||||||
|
|
||||||
predictions = [torch.cat(prediction).detach().numpy() for prediction in prediction_dict.values()]
|
for idx, prediction in enumerate(test_pred):
|
||||||
for idx, prediction in enumerate(predictions):
|
|
||||||
plot, _ = viz_latent(prediction)
|
plot, _ = viz_latent(prediction)
|
||||||
plot.savefig(os.path.join(os.path.dirname(path_like_element), f'latent_space_{idx}.png'))
|
plot.savefig(os.path.join(base_dir, f'latent_space_{idx}.png'))
|
||||||
|
|
||||||
|
|
||||||
def viz_latent(prediction):
|
def viz_latent(prediction):
|
||||||
|
try:
|
||||||
|
prediction = prediction.cpu()
|
||||||
|
prediction = prediction.numpy()
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
if prediction.shape[-1] <= 1:
|
if prediction.shape[-1] <= 1:
|
||||||
raise ValueError('How did this happen?')
|
raise ValueError('How did this happen?')
|
||||||
elif prediction.shape[-1] == 2:
|
elif prediction.shape[-1] == 2:
|
||||||
@ -91,4 +83,4 @@ def viz_latent(prediction):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
path = 'output'
|
path = 'output'
|
||||||
search_for_weights(path)
|
search_for_weights(search_for_weights, path)
|
50
viz/viz_map.py
Normal file
50
viz/viz_map.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
|
||||||
|
from dataset import *
|
||||||
|
# Plotting
|
||||||
|
# import matplotlib as mlp
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from matplotlib.patches import Polygon
|
||||||
|
from matplotlib.collections import LineCollection, PatchCollection
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
import matplotlib.cm as cmaps
|
||||||
|
|
||||||
|
from sklearn.manifold import TSNE
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
|
||||||
|
import seaborn as sns
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
from viz.utils import search_for_weights
|
||||||
|
|
||||||
|
from run_models import *
|
||||||
|
|
||||||
|
sns.set()
|
||||||
|
|
||||||
|
|
||||||
|
arguments = ArgumentParser()
|
||||||
|
arguments.add_argument('--data', default=os.path.join('data', 'validation'))
|
||||||
|
|
||||||
|
dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)
|
||||||
|
dataloader = DataLoader(dataset, shuffle=True, batch_size=len(dataset))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def viz_map(self, base_map: MapContainer):
|
||||||
|
# Base Map Plotting
|
||||||
|
# filled Triangle
|
||||||
|
patches = [Polygon(base_map.get_triangle_by_key(i), True, color='k') for i in range(len(base_map))]
|
||||||
|
patch_collection = PatchCollection(patches, color='k')
|
||||||
|
|
||||||
|
self.ax.add_collection(patch_collection)
|
||||||
|
print('Basemap Plotted')
|
||||||
|
|
||||||
|
patches = [Polygon(base_map.get_triangle_by_key(i), True, color='k') for i in range(len(base_map))]
|
||||||
|
return PatchCollection(patches, color='k')
|
||||||
|
|
||||||
|
def load_and_predict(folder):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
search_for_weights(load_and_predict, arguments.data)
|
||||||
|
# ToDo: THIS
|
Loading…
x
Reference in New Issue
Block a user