51 lines
1.3 KiB
Python
51 lines
1.3 KiB
Python
|
|
from dataset import *
|
|
# Plotting
|
|
# import matplotlib as mlp
|
|
from matplotlib import pyplot as plt
|
|
from matplotlib.patches import Polygon
|
|
from matplotlib.collections import LineCollection, PatchCollection
|
|
import matplotlib.colors as mcolors
|
|
import matplotlib.cm as cmaps
|
|
|
|
from sklearn.manifold import TSNE
|
|
from sklearn.decomposition import PCA
|
|
|
|
import seaborn as sns
|
|
from argparse import ArgumentParser
|
|
|
|
from viz.utils import search_for_weights
|
|
|
|
from run_models import *
|
|
|
|
# Apply seaborn's default plot styling globally.
sns.set()

# Command-line interface.
# NOTE(review): the ``--data`` option is only consumed by the ``__main__``
# block; ``dataset`` below is always built from a hard-coded path relative to
# the parent directory instead -- confirm which location is intended.
arguments = ArgumentParser()
arguments.add_argument(
    '--data',
    default=os.path.join('data', 'validation'),
)

# Validation data is loaded eagerly at import time and ``.to(device)`` is
# applied to it. NOTE(review): the meaning of the positional ``9`` and ``6``
# is not visible in this file -- confirm against DataContainer.
dataset = DataContainer(
    os.path.join(os.pardir, 'data', 'validation'),
    9,
    6,
).to(device)

# One shuffled batch that spans the entire dataset.
dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=True)
|
|
|
|
|
|
|
|
def viz_map(self, base_map: MapContainer):
    """Plot the triangles of *base_map* as filled black patches.

    Every triangle ``base_map.get_triangle_by_key(i)`` for ``i`` in
    ``range(len(base_map))`` becomes a closed ``Polygon``. The polygons are
    drawn on ``self.ax`` as one black ``PatchCollection``, and a second,
    independent black ``PatchCollection`` of the same polygons is returned.

    NOTE(review): this is a module-level function that nevertheless takes
    ``self``; it appears to have been lifted out of a class. Callers must pass
    an object exposing an ``ax`` attribute (presumably a matplotlib ``Axes``).

    Args:
        self: Object whose ``ax`` attribute receives the drawn collection.
        base_map: Map whose triangles are plotted.

    Returns:
        A new ``PatchCollection`` of the same triangles that has not been
        added to any axes.
    """
    # Build the polygons once. PatchCollection derives its own paths from the
    # polygons, so the same list can back both collections below.
    # ``closed`` is passed by keyword because positional use is deprecated in
    # recent matplotlib. No per-patch colour is set: a PatchCollection ignores
    # patch properties unless ``match_original=True``.
    patches = [
        Polygon(base_map.get_triangle_by_key(i), closed=True)
        for i in range(len(base_map))
    ]

    # Base map: filled black triangles on the target axes.
    self.ax.add_collection(PatchCollection(patches, color='k'))
    print('Basemap Plotted')

    # Hand back a separate, unattached collection -- presumably so the caller
    # can place it on another axes (one matplotlib artist cannot live on two).
    return PatchCollection(patches, color='k')
|
|
|
|
def load_and_predict(folder):
    """Placeholder callback handed to ``search_for_weights``; not implemented.

    Args:
        folder: Supplied by the caller; currently ignored.

    Returns:
        None -- the body is still an unimplemented stub.
    """
|
|
|
|
|
|
if __name__ == '__main__':
    # ``arguments`` is the ArgumentParser itself, which has no ``data``
    # attribute; parse the command line to obtain the ``--data`` value.
    args = arguments.parse_args()
    search_for_weights(load_and_predict, args.data)

    # TODO (original note: "THIS"): presumably finish load_and_predict, which
    # is currently a no-op stub, so this run does no loading or prediction.
|