Viz updates

This commit is contained in:
Si11ium
2019-03-07 20:07:37 +01:00
parent 95c2ff4200
commit 7a76b1ba88
4 changed files with 35 additions and 41 deletions

View File

@ -14,7 +14,7 @@ import colorlover as cl
import dill
from sklearn.manifold.t_sne import TSNE
from sklearn.manifold.t_sne import TSNE, PCA
def build_args():
@ -31,8 +31,8 @@ def build_from_soup_or_exp(soup):
particle_dict = dict(
trajectory=[event['weights'] for event in particle],
time=[event['time'] for event in particle],
action=[event['action'] for event in particle],
counterpart=[event['counterpart'] for event in particle]
action=[event.get('action', None) for event in particle],
counterpart=[event.get('counterpart', None) for event in particle]
)
particle_list.append(particle_dict)
return particle_list
@ -101,15 +101,18 @@ def plot_latent_trajectories_3D(soup_or_experiment, filename='plot'):
scale = cl.interp(bupu, len(data_list)+1) # Map color scale to N bins
# Fit the embedding space
transformer = TSNE()
transformer = PCA(n_components=2)
array = []
for particle_dict in data_list:
array = np.asarray(particle_dict['trajectory'])
transformer.fit(array)
array.append(particle_dict['trajectory'])
transformer.fit(np.vstack(array))
# Transform data accordingly and plot it
data = []
for p_id, particle_dict in enumerate(data_list):
transformed = transformer._fit(particle_dict['trajectory'])
transformed = transformer.transform(particle_dict['trajectory'])
line_trace = go.Scatter3d(
x=transformed[:, 0],
y=transformed[:, 1],