This commit is contained in:
Si11ium
2019-03-07 12:05:58 +01:00
parent bae997feab
commit 95c2ff4200
4 changed files with 128 additions and 48 deletions

View File

@ -24,20 +24,24 @@ def build_args():
return arg_parser.parse_args()
def build_from_soup(soup):
def build_from_soup_or_exp(soup):
particles = soup.historical_particles
particle_dict = [dict(trajectory=[timestamp['weights'] for timestamp in particle],
fitted=[timestamp['fitted'] for timestamp in particle],
loss=[timestamp['loss'] for timestamp in particle],
time=[timestamp['time'] for timestamp in particle]) for particle in particles.values()]
return particle_dict
particle_list = []
for particle in particles.values():
particle_dict = dict(
trajectory=[event['weights'] for event in particle],
time=[event['time'] for event in particle],
action=[event['action'] for event in particle],
counterpart=[event['counterpart'] for event in particle]
)
particle_list.append(particle_dict)
return particle_list
def plot_latent_trajectories(soup_or_experiment, filename='latent_trajectory_plot'):
assert isinstance(soup_or_experiment, (Experiment, Soup))
bupu = cl.scales['11']['div']['RdYlGn']
data_dict = soup_or_experiment.data_storage if isinstance(soup_or_experiment, Experiment) \
else build_from_soup(soup_or_experiment)
data_dict = build_from_soup_or_exp(soup_or_experiment)
scale = cl.interp(bupu, len(data_dict)+1) # Map color scale to N bins
# Fit the mebedding space
@ -91,25 +95,22 @@ def plot_latent_trajectories_3D(soup_or_experiment, filename='plot'):
def norm(val, a=0, b=0.25):
return (val - a) / (b - a)
data_dict = soup_or_experiment.data_storage if isinstance(soup_or_experiment, Experiment) \
else build_from_soup(soup_or_experiment)
data_list = build_from_soup_or_exp(soup_or_experiment)
bupu = cl.scales['11']['div']['RdYlGn']
scale = cl.interp(bupu, len(data_dict)+1) # Map color scale to N bins
scale = cl.interp(bupu, len(data_list)+1) # Map color scale to N bins
# Fit the embedding space
transformer = TSNE()
for particle_dict in data_dict:
array = np.asarray([np.hstack([x.flatten() for x in timestamp]).flatten()
for timestamp in particle_dict['trajectory']])
particle_dict['trajectory'] = array
for particle_dict in data_list:
array = np.asarray(particle_dict['trajectory'])
transformer.fit(array)
# Transform data accordingly and plot it
data = []
for p_id, particle_dict in enumerate(data_dict):
for p_id, particle_dict in enumerate(data_list):
transformed = transformer._fit(particle_dict['trajectory'])
trace = go.Scatter3d(
line_trace = go.Scatter3d(
x=transformed[:, 0],
y=transformed[:, 1],
z=np.asarray(particle_dict['time']),
@ -120,9 +121,28 @@ def plot_latent_trajectories_3D(soup_or_experiment, filename='plot'):
# showlegend=True,
hoverinfo='text',
mode='lines')
data.append(trace)
layout = go.Layout(scene=dict(aspectratio=dict(x=2, y=2, z=1),
line_start = go.Scatter3d(mode='markers', x=[transformed[0, 0]], y=[transformed[0, 1]],
z=np.asarray(particle_dict['time'][0]),
marker=dict(
color='rgb(255, 0, 0)',
size=4
),
showlegend=False
)
line_end = go.Scatter3d(mode='markers', x=[transformed[-1, 0]], y=[transformed[-1, 1]],
z=np.asarray(particle_dict['time'][-1]),
marker=dict(
color='rgb(0, 0, 0)',
size=4
),
showlegend=False
)
data.extend([line_trace, line_start, line_end])
layout = go.Layout(scene=dict(aspectratio=dict(x=2, y=2, z=2),
xaxis=dict(tickwidth=1, title='Transformed X'),
yaxis=dict(tickwidth=1, title='transformed Y'),
zaxis=dict(tickwidth=1, title='Epoch')),