transition

Si11ium
2021-02-01 09:59:56 +01:00
parent 4c489237d7
commit 578727d043
35 changed files with 177 additions and 305 deletions

@@ -1,9 +1,8 @@
from typing import Union
from functools import reduce
from statistics import stdev
from sklearn.cluster import Birch, KMeans, DBSCAN
from sklearn.cluster import Birch, KMeans
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
@@ -16,7 +15,7 @@ from matplotlib.collections import LineCollection, PatchCollection
import matplotlib.colors as mcolors
import matplotlib.cm as cmaps
from math import pi
from math import pi, cos, sin
def search_for_weights(func, folder, file_type='latent_space'):
@@ -24,10 +23,13 @@ def search_for_weights(func, folder, file_type='latent_space'):
if len(os.path.split(folder)) >= 50:
raise FileNotFoundError(f'The folder "{folder}" could not be found')
folder = os.path.join(os.pardir, folder)
if any([file_type in x.name for x in os.scandir(folder)]):
return
elif folder == 'weights' and os.path.isdir(folder):
return
if any(['.ckpt' in element.name and element.is_dir() for element in os.scandir(folder)]):
if any(['weights.ckpt' in element.name and element.is_dir() for element in os.scandir(folder)]) and False:
_, _, filenames = next(os.walk(os.path.join(folder, 'weights.ckpt')))
filenames.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
func(os.path.join(folder, 'weights.ckpt', filenames[-1]))
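As a side note, a minimal standalone sketch of the digit-based sort used for the checkpoint filenames above (the filenames here are made up); it ensures 'epoch=10' sorts after 'epoch=9', which a plain lexicographic sort would not:

# Hypothetical checkpoint names; lexicographic order would put 'epoch=10.ckpt' before 'epoch=9.ckpt'.
filenames = ['epoch=9.ckpt', 'epoch=10.ckpt', 'epoch=2.ckpt']
filenames.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
print(filenames)  # ['epoch=2.ckpt', 'epoch=9.ckpt', 'epoch=10.ckpt']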
@@ -37,7 +39,7 @@ def search_for_weights(func, folder, file_type='latent_space'):
if os.path.exists(element):
if element.is_dir():
search_for_weights(func, element.path, file_type=file_type)
elif element.is_file() and element.name.endswith('.ckpt'):
elif element.is_file() and element.name.endswith('weights.ckpt'):
func(element.path)
else:
continue
@@ -47,16 +49,15 @@ class Printer(object):
def __init__(self, model: AbstractNeuralNetwork, ax=None):
self.norm = mcolors.Normalize(vmin=0, vmax=1)
self.colormap = cmaps.gist_rainbow
self.colormap = cmaps.tab20
self.network = model
self.fig = plt.figure(dpi=300)
self.ax = ax if ax else plt.subplot(1, 1, 1)
pass
def colorize(self, x, min_val: Union[float, None] = None, max_val: Union[float, None] = None,
colormap=cmaps.rainbow, **kwargs):
def colorize(self, x, min_val: Union[float, None] = None, max_val: Union[float, None] = None, **kwargs):
norm = mcolors.Normalize(vmin=min_val, vmax=max_val)
colored = colormap(norm(x))
colored = self.colormap(norm(x))
return colored
@staticmethod
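A minimal standalone sketch of what the updated colorize does with the instance colormap; the input values are illustrative:

import numpy as np
import matplotlib.cm as cmaps
import matplotlib.colors as mcolors

values = np.array([0.2, 1.5, 3.0])                              # illustrative scalars to colorize
norm = mcolors.Normalize(vmin=values.min(), vmax=values.max())  # rescale into [0, 1]
colored = cmaps.tab20(norm(values))                             # (3, 4) array of RGBA colors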
@@ -79,20 +80,26 @@ class Printer(object):
clusterer.init = np.asarray(centers)
else:
# clusterer = Birch(n_clusters=None)
clusterer = Birch()
clusterer = KMeans(3)
labels = clusterer.fit_predict(data)
print('Birch Clustering Successful')
return labels
def print_possible_latent_spaces(self, data: Trajectories, n: Union[int, str] = 1000, **kwargs):
predictions, _ = self._gather_predictions(data, n)
def print_possible_latent_spaces(self, data: Trajectories, n: Union[int, str] = 1000,
cluster_by_motion=True, **kwargs):
predictions, motion_sequence = self._gather_predictions(data, n)
if len(predictions) >= 2:
predictions += (torch.cat(predictions, dim=-1), )
labels = self.cluster_data(predictions[-1])
if cluster_by_motion:
motion_analyzer = MotionAnalyser()
labels = motion_analyzer.cluster_motion(motion_sequence)
else:
labels = self.cluster_data(predictions[-1])
for idx, prediction in enumerate(predictions):
self.print_latent_space(prediction, labels, running_index=idx, **kwargs)
self.print_latent_space(prediction, labels.squeeze(), running_index=idx, **kwargs)
def print_latent_space(self, prediction, labels, running_index=0, save=None):
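A minimal sketch of the KMeans call that replaces Birch in cluster_data, run on toy data (the array here is made up):

import numpy as np
from sklearn.cluster import KMeans

data = np.random.rand(200, 4)                    # toy stand-in for the latent/motion features
labels = KMeans(n_clusters=3).fit_predict(data)  # one label in {0, 1, 2} per row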
@@ -179,12 +186,13 @@ class Printer(object):
print("Gathering Predictions")
n = n if isinstance(n, int) and n else len(data) - (data.size * data.step)
idxs = np.random.choice(np.arange(len(data) - data.step * data.size), n, replace=False)
idxs = np.random.choice(np.arange(len(data)), n, replace=True)
complete_data = torch.stack([data.get_both_by_key(idx) for idx in idxs], dim=0)
segment_coords, trajectories = complete_data[:, :, :2], complete_data[:, :, 2:]
if color_by_movement:
motion_analyser = MotionAnalyser()
predictions = (motion_analyser.cluster_motion(segment_coords), )
predictions = (motion_analyser.cluster_motion(segment_coords,
clustering=kwargs.get('clustering', 'kmeans')), )
else:
with torch.no_grad():
@@ -193,7 +201,7 @@ class Printer(object):
return predictions, segment_coords
@staticmethod
def colorize_as_hsv(self, x, min_val: Union[float, None] = None, max_val: Union[float, None] = None,
def colorize_as_hsv(x, min_val: Union[float, None] = None, max_val: Union[float, None] = None,
colormap=cmaps.rainbow, **kwargs):
norm = mcolors.Normalize(vmin=min_val, vmax=max_val)
colored = colormap(norm(x))
@@ -248,11 +256,12 @@ class Printer(object):
patches = [Polygon(base_map[i], True, color='black') for i in range(len(base_map))]
return PatchCollection(patches, color='black')
def print_trajec_on_basemap(self, data, base_map: Map, save=False, color_by_movement=False, **kwargs):
def print_trajec_on_basemap(self, data, base_map: Map, save=False, show=False, color_by_movement=False, **kwargs):
"""
:rtype: object
"""
prediction_segments = self._gather_predictions(data, color_by_movement=color_by_movement, **kwargs)
trajectory_shapes = self._build_trajectory_shapes(*prediction_segments, **kwargs)
map_shapes = self._build_map_shapes(base_map)
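A small, self-contained sketch of the Polygon/PatchCollection pattern used by _build_map_shapes, with made-up triangles standing in for the base map:

import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

base_map = [[(0, 0), (1, 0), (0, 1)], [(2, 2), (3, 2), (2, 3)]]   # hypothetical obstacles
patches = [Polygon(corners, closed=True, color='black') for corners in base_map]

fig, ax = plt.subplots()
ax.add_collection(PatchCollection(patches, color='black'))
ax.autoscale()
plt.show()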
@@ -266,7 +275,8 @@ class Printer(object):
self.save(save)
else:
self.save(base_map.name)
pass
if show:
self.show()
@staticmethod
def show():
@@ -284,15 +294,25 @@ class MotionAnalyser(object):
pass
def _sequential_pairwise_map(self, func, xy_sequence, on_deltas=False):
zipped_list = [x for x in zip(xy_sequence[:-1], xy_sequence[1:])]
if on_deltas:
zipped_list = [x for x in zip(xy_sequence[:-1], xy_sequence[1:])]
zipped_list = [self.delta(*movement) for movement in zipped_list]
else:
pass
zipped_list = xy_sequence
return [func(*xy) for xy in zipped_list]
@staticmethod
def _rotatePoint(point, center, angle, is_rad=True):
angle = (angle) * (pi / 180) if not is_rad else angle  # Convert to radians if given in degrees
rotatedX = cos(angle) * (point[0] - center[0]) - sin(angle) * (point[1] - center[1]) + center[0]
rotatedY = sin(angle) * (point[0] - center[0]) + cos(angle) * (point[1] - center[1]) + center[1]
return rotatedX, rotatedY
@staticmethod
def delta(x1y1, x2y2):
x1, y1 = x1y1
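A quick standalone check of the rotation helper introduced above, written as a hypothetical free function; rotating (1, 0) by 90 degrees about the origin should give roughly (0, 1):

from math import pi, cos, sin

def rotate_point(point, center, angle):
    # Same formula as _rotatePoint, with angle already in radians.
    x = cos(angle) * (point[0] - center[0]) - sin(angle) * (point[1] - center[1]) + center[0]
    y = sin(angle) * (point[0] - center[0]) + cos(angle) * (point[1] - center[1]) + center[1]
    return x, y

print(rotate_point((1.0, 0.0), (0.0, 0.0), pi / 2))   # ~(0.0, 1.0)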
@@ -306,10 +326,16 @@ class MotionAnalyser(object):
return r
@staticmethod
def get_theta(deltax, deltay, rad=False):
def get_theta(deltax, deltay, as_radians=True):
# https://mathinsight.org/polar_coordinates
try:
deltax = torch.as_tensor(deltax)
deltay = torch.as_tensor(deltay)
except:
pass
theta = torch.atan2(deltay, deltax)
return theta if rad else theta * 180 / pi
return theta if as_radians else theta * 180 / pi
def get_theta_for_sequence(self, xy_sequence):
ts = self._sequential_pairwise_map(self.get_theta, xy_sequence, on_deltas=True)
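A one-line sanity check of the get_theta convention (torch.atan2 takes deltay first, deltax second), with made-up deltas:

import torch
from math import pi

theta = torch.atan2(torch.as_tensor(1.0), torch.as_tensor(0.0))   # step straight "up"
print(theta.item(), theta.item() * 180 / pi)                      # ~1.5708 rad, ~90 degrees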
@@ -319,38 +345,90 @@ class MotionAnalyser(object):
rs = self._sequential_pairwise_map(self.get_r, xy_sequence, on_deltas=True)
return rs
def move_to_zero(self, xy_sequence):
old_origin = xy_sequence[0]
return xy_sequence - old_origin
def get_unique_seq_identifier(self, xy_sequence):
xy_sequence = xy_sequence.cpu()
# Move all points so that the first point is always (0, 0)
# moved_sequence = self.move_to_zero(xy_sequence)
moved_sequence = xy_sequence
# Rotate, so that x is zero for last point
angle = self.get_theta(*self.delta(moved_sequence[0], moved_sequence[1]))
rotated_sequence = torch.as_tensor([self._rotatePoint(point, moved_sequence[0], -angle)
for point in moved_sequence[1:]])
rotated_sequence = torch.cat((moved_sequence[0].unsqueeze(0), rotated_sequence))
# rotated_sequence = moved_sequence
std, mean = torch.std_mean(rotated_sequence)
rotated_sequence = (rotated_sequence - mean) / std
def centroid_for(arr):
try:
arr = torch.as_tensor(arr)
except:
pass
size = arr.shape[0]
sum_x = torch.sum(arr[:, 0])
sum_y = torch.sum(arr[:, 1])
return sum_x/size, sum_y/size
# Globals
global_delta = self.delta(xy_sequence[0], xy_sequence[-1])
global_theta = self.get_theta(*global_delta)
global_delta = self.delta(rotated_sequence[0], rotated_sequence[-1])
global_r = self.get_r(*global_delta)
def f(*args):
return args
centroid = centroid_for(self._sequential_pairwise_map(f, rotated_sequence, on_deltas=True))
hull_length = sum(self.get_r_for_sequence(torch.cat((rotated_sequence, rotated_sequence[0].unsqueeze(0)))))
# For Each
theta_seq = self.get_theta_for_sequence(xy_sequence)
theta_seq = self.get_theta_for_sequence(rotated_sequence)
mean_theta = sum(theta_seq) / len(theta_seq)
theta_sum = sum([abs(theta) for theta in theta_seq])
std_theta = stdev(map(float, theta_seq))
return torch.stack((global_r, torch.as_tensor(std_theta), mean_theta, global_theta))
return torch.stack((centroid[0], centroid[1], torch.as_tensor(std_theta), mean_theta, theta_sum, hull_length))
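Read as a whole, get_unique_seq_identifier rotates each trajectory so its first step lies on the x-axis, z-scores it, and condenses it into a short feature vector. A compact, hypothetical restatement (names and exact features are illustrative, not the class's API):

import torch

def motion_features(seq):
    # seq: (T, 2) tensor of xy points, assumed already rotated/normalized as above.
    deltas = seq[1:] - seq[:-1]
    thetas = torch.atan2(deltas[:, 1], deltas[:, 0])   # per-step heading
    path_length = deltas.norm(dim=1).sum()             # total travelled distance
    centroid = seq.mean(dim=0)
    return torch.stack((centroid[0], centroid[1],
                        thetas.std(), thetas.mean(), thetas.abs().sum(), path_length))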
def cluster_motion(self, trajectory_samples, cluster_class=KMeans):
cluster_class = cluster_class(3)
def cluster_motion(self, trajectory_samples, clustering='kmeans'):
if clustering.lower() == 'kmeans':
cluster_class = KMeans(3)
std, mean = torch.std_mean(trajectory_samples, dim=0)
trajectory_samples = (trajectory_samples - mean) / std
std, mean = torch.std_mean(trajectory_samples, dim=0)
trajectory_samples = (trajectory_samples - mean) / std
unique_seq_identifiers = torch.stack([self.get_unique_seq_identifier(trajectory)
for trajectory in trajectory_samples])
unique_seq_identifiers = torch.stack([self.get_unique_seq_identifier(trajectory)
for trajectory in trajectory_samples])
clustered_movement = cluster_class.fit_predict(unique_seq_identifiers)
elif clustering.lower() == 'fastdtw':
# Move all points so that the first point is always (0, 0)
moved_sequence = self.move_to_zero(trajectory_samples)
rotated_sequences = []
for sequence in moved_sequence:
# Rotate, so that x is zero for last point
angle = self.get_theta(*self.delta(sequence[0], sequence[1]))
rotated_sequence = torch.as_tensor([self._rotatePoint(point, sequence[0], -angle)
for point in sequence[1:]])
rotated_sequence = torch.cat((sequence[0].unsqueeze(0), rotated_sequence)).unsqueeze(0)
rotated_sequences.append(rotated_sequence)
# deltas = [self._sequential_pairwise_map(self.delta, x, on_deltas=False) for x in rotated_sequence]
t = torch.cat(rotated_sequences)
# t = torch.as_tensor(deltas)
z = torch.zeros((t.shape[0], t.shape[0]))
clustered_movement = cluster_class.fit_predict(unique_seq_identifiers)
if False:
from sklearn.decomposition import PCA
p = PCA(2)
t = p.fit_transform(unique_seq_identifiers)
f = plt.figure()
plt.scatter(t[:, 0], t[:,1])
plt.show()
import fastdtw
for idx, x in tqdm(enumerate(t), total=z.shape[0]):
for idy, y in enumerate(t):
z[idx, idy] = fastdtw.fastdtw(x, y)[0]
from sklearn.cluster.hierarchical import AgglomerativeClustering
clusterer = KMeans(3)
clustered_movement = clusterer.fit_predict(z)
else:
raise NotImplementedError
return clustered_movement.reshape(-1, 1)
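For the 'fastdtw' branch, a minimal sketch of building the pairwise DTW distance matrix and clustering it, mirroring the KMeans-on-distances approach above with made-up trajectories (a clusterer that accepts precomputed distances would be the more principled choice):

import numpy as np
import fastdtw
from sklearn.cluster import KMeans

# Toy (T, 2) trajectories standing in for the rotated sequences above.
trajectories = [np.cumsum(np.random.randn(20, 2), axis=0) for _ in range(10)]

dist = np.zeros((len(trajectories), len(trajectories)))
for i, a in enumerate(trajectories):
    for j, b in enumerate(trajectories):
        dist[i, j] = fastdtw.fastdtw(a, b)[0]   # fastdtw returns (distance, warp_path)

labels = KMeans(n_clusters=3).fit_predict(dist)   # rows of the distance matrix as features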