Visualization approach 1

Si11ium
2019-09-13 13:36:13 +02:00
parent 18305a9e7e
commit 1386cdfd33
9 changed files with 185 additions and 50 deletions


@@ -1,6 +1,3 @@
from collections import defaultdict
from tqdm import tqdm
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
@@ -12,57 +9,52 @@ from run_models import *
sns.set()
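# Walk up the directory tree until `folder` exists, then recurse into it and
# run inference for every '*.ckpt' checkpoint that is found.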
def search_for_weights(folder):
    while not os.path.exists(folder):
        # os.path.split() always returns a 2-tuple; count path components instead
        if len(folder.split(os.sep)) >= 50:
            raise FileNotFoundError(f'The folder "{folder}" could not be found')
        folder = os.path.join(os.pardir, folder)
    for element in os.scandir(folder):
        if os.path.exists(element):
            if element.is_dir():
                search_for_weights(element.path)
            elif element.is_file() and element.name.endswith('.ckpt'):
                load_and_predict(element)
            else:
                continue
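# Restore a trained model from a checkpoint, run it over the validation data
# and save latent-space plots next to the training logs.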
def load_and_predict(path_like_element):
    if any([x.name.endswith('.png') for x in os.scandir(os.path.dirname(path_like_element))]):
        return
    # Derive the model class and the run directory from the checkpoint path
    model = globals()[path_like_element.path.split(os.sep)[-3]]
    splitpath = path_like_element.path.split(os.sep)
    base_dir = os.path.join(*splitpath[:4])
    model = globals()[splitpath[2]]
    print(f'... loading model named: "{model.name}" from timestamp: {splitpath[3]}')
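    # load_from_metrics restores the Lightning module from the checkpoint
    # weights plus the hyperparameters logged in meta_tags.csv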
    pretrained_model = model.load_from_metrics(
        weights_path=path_like_element.path,
        tags_csv=os.path.join(os.path.dirname(path_like_element), 'default', 'version_0', 'meta_tags.csv'),
        weights_path=path_like_element,
        tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
        on_gpu=torch.cuda.is_available(),
        map_location=None
    )
    # Init model and freeze its weights (for faster inference)
    pretrained_model = pretrained_model.to(device)
    pretrained_model.eval()
    pretrained_model.freeze()
    with torch.no_grad():
        # Load the data for prediction
        dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5)
        # TODO:
        # The data parameters obviously have to be loaded here as well!
        # Do I have to reset the val sets automatically, i.e. always refresh them, whenever I want to validate?
        # What actually is my val dataset?
        # Do I have a whole map anywhere?
        # How do I make sure that certain maps, i.e. size ratios etc., are not overrepresented?
        dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)
        dataloader = DataLoader(dataset, shuffle=True, batch_size=len(dataset))
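        # batch_size=len(dataset): a single batch holds the entire validation
        # set, so one forward pass yields all predictions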
        # Do the inference
        prediction_dict = defaultdict(list)
        for i in tqdm(range(len(dataset)), total=len(dataset)):
            p_X = pretrained_model(dataset[i].unsqueeze(0))
            for idx in range(len(p_X) - 1):
                prediction_dict[idx].append(p_X[idx])
        test_pred = [pretrained_model(test_sample)[:-1] for test_sample in dataloader][0]
        predictions = [torch.cat(prediction).detach().numpy() for prediction in prediction_dict.values()]
        for idx, prediction in enumerate(predictions):
        for idx, prediction in enumerate(test_pred):
            plot, _ = viz_latent(prediction)
            plot.savefig(os.path.join(os.path.dirname(path_like_element), f'latent_space_{idx}.png'))
            plot.savefig(os.path.join(base_dir, f'latent_space_{idx}.png'))
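# Scatter-plot a batch of latent vectors; tensors are moved to CPU/NumPy first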
def viz_latent(prediction):
    try:
        prediction = prediction.cpu()
        prediction = prediction.numpy()
    except AttributeError:
        pass
    if prediction.shape[-1] <= 1:
        raise ValueError('How did this happen?')
    elif prediction.shape[-1] == 2:
@@ -91,4 +83,4 @@ def viz_latent(prediction):
if __name__ == '__main__':
    path = 'output'
    search_for_weights(path)
    search_for_weights(search_for_weights, path)
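
Most of viz_latent's body falls outside the hunk shown above. For orientation only, a minimal sketch of how such a helper could handle latents with more than two dimensions, assuming the TSNE/PCA imports at the top of the file; the function name and the dimensionality thresholds are illustrative, not the repository's actual implementation:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

def viz_latent_sketch(latents: np.ndarray):
    # Project anything wider than 2 dimensions down to 2D before plotting
    if latents.shape[-1] > 2:
        if latents.shape[-1] > 50:
            # Pre-reduce with PCA to keep t-SNE tractable on wide latents
            latents = PCA(n_components=50).fit_transform(latents)
        latents = TSNE(n_components=2).fit_transform(latents)
    fig, ax = plt.subplots()
    sns.scatterplot(x=latents[:, 0], y=latents[:, 1], ax=ax)
    # Return the figure (for savefig) and the axes, mirroring viz_latent's usage
    return fig, ax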