Final Experiments and Plot adjustments

Steffen Illium
2022-03-12 11:39:28 +01:00
parent dd2458da4a
commit 0ba3994325
6 changed files with 214 additions and 79 deletions


@@ -65,8 +65,8 @@ class AddTaskDataset(Dataset):
def set_checkpoint(model, out_path, epoch_n, final_model=False):
    epoch_n = str(epoch_n)
    if not final_model:
        epoch_n = str(epoch_n)
        ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
    else:
        if isinstance(epoch_n, str):
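
Editor's note on the checkpoint naming used in this hunk: a minimal, self-contained sketch of the zero-padded scheme. The output directory and epoch number below are made up for illustration, not taken from the repository.

from pathlib import Path

# Hypothetical values: zero-pad the epoch so checkpoint files sort in epoch order.
out_path, epoch_n = Path('output/run_1'), 7
ckpt_path = Path(out_path) / 'ckpt' / f'{str(epoch_n).zfill(4)}_model_ckpt.tp'
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
print(ckpt_path)  # e.g. output/run_1/ckpt/0007_model_ckpt.tp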
@@ -145,6 +145,7 @@ def plot_training_particle_types(path_to_dataframe):
                      labels=fix_types.tolist(), colors=PALETTE)
    ax.set(ylabel='Particle Count', xlabel='Epoch')
    ax.yaxis.get_major_locator().set_params(integer=True)
    # ax.set_title('Particle Type Count')
    fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))
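
The integer tick adjustment added above is a standard Matplotlib pattern; here is a minimal standalone sketch with made-up data rather than the repository's particle counts.

import matplotlib.pyplot as plt

# Illustrative only: force whole-number ticks on the y-axis so a count axis
# never shows fractional values.
fig, ax = plt.subplots()
ax.plot([0, 1, 2, 3], [1, 1, 2, 3])
ax.yaxis.get_major_locator().set_params(integer=True)
# equivalent, more explicit form: ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
ax.set(ylabel='Particle Count', xlabel='Epoch')
plt.show()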
@@ -219,6 +220,9 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
                         legend=False, estimator=None, lw=1)
        _ = sns.lineplot(y=[0, 1], x=[-1, df['Layer'].max()], legend=False, estimator=None, lw=0)
        ax.set_title(fixtype)
        ax.yaxis.get_major_locator().set_params(integer=True)
        ax.xaxis.get_major_locator().set_params(integer=True)
        ax.set_ylabel('Normalized Neuron Position (1/n)')  # XAXIS Label
        lines = ax.get_lines()
        for line in lines:
            line.set_color(PALETTE[n])
@@ -273,7 +277,7 @@ def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchme
    _ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors[:palette_len], ci=None)
    ax[0].set_title(f'{metric_name} after particle dropout')
    ax[0].set_xlabel('Particle Type')
    # ax[0].set_xlabel('Particle Type')  # XAXIS Label
    ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=30)
    ax[1].pie(sorted_particle_dict.values(), labels=sorted_particle_dict.keys(),
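
For reference, the bar-plus-pie layout in this hunk reduced to a self-contained sketch, using plain Matplotlib instead of seaborn; the particle-type counts below are invented.

import matplotlib.pyplot as plt

# Hypothetical counts standing in for the per-type metric and dropout counts.
particle_counts = {'identity_func': 12, 'fix_zero': 30, 'other_func': 7}

fig, ax = plt.subplots(1, 2)
ax[0].bar(list(particle_counts.keys()), list(particle_counts.values()))
ax[0].set_title('Metric after particle dropout')
ax[0].tick_params(axis='x', labelrotation=30)  # same effect as the rotated xticklabels in the diff
ax[1].pie(list(particle_counts.values()), labels=list(particle_counts.keys()))
plt.tight_layout()
plt.show()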
@@ -345,9 +349,11 @@ def highlight_fixpoints_vs_mnist_mean(mdl_path, dataloader):
    fig, axs = plt.subplots(1, 3)
    for idx, image in enumerate([binary_image, real_image, mnist_mean]):
    for idx, (image, title) in enumerate(zip([binary_image, real_image, mnist_mean],
                                             ["Particle Count", "Particle Value", "MNIST mean"])):
        img = axs[idx].imshow(image.squeeze().detach().cpu())
        img.axes.axis('off')
        img.axes.set_title('Random Noise')
    plt.tight_layout()
    plt.savefig(mdl_path.parent / 'heatmap.png', dpi=300)
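
The image-grid loop above, reduced to a self-contained sketch; random arrays stand in for binary_image, real_image and the MNIST mean, and the output path is illustrative.

import matplotlib.pyplot as plt
import numpy as np

# Stand-in images (not from the repository).
images = [np.random.rand(28, 28) for _ in range(3)]
titles = ["Particle Count", "Particle Value", "MNIST mean"]

fig, axs = plt.subplots(1, 3)
for idx, (image, title) in enumerate(zip(images, titles)):
    img = axs[idx].imshow(image)   # draw each panel
    img.axes.axis('off')           # hide ticks and frame
    img.axes.set_title(title)      # per-panel title
plt.tight_layout()
plt.savefig('heatmap.png', dpi=300)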


@@ -45,10 +45,9 @@ def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
    time_to_vergence = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
    time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
    row_headers = []
    df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
                               'Time to convergence', 'Time as fixpoint'])
    with tqdm(total=(seeds * noise_levels * len(networks))) as pbar:
    with tqdm(total=(seeds * noise_levels * len(networks)), desc='Per Particle Robustness') as pbar:
        for setting, fixpoint in enumerate(networks):  # 1 / n
            row_headers.append(fixpoint.name)
            for seed in range(seeds):  # n / 1
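
The desc argument added to tqdm here simply labels the progress bar, while total sizes it for the nested loops. A short illustrative sketch with made-up loop bounds, not the repository's networks or trial logic:

from tqdm import tqdm

# Hypothetical bounds for illustration only.
seeds, noise_levels, n_networks = 10, 10, 5
with tqdm(total=seeds * noise_levels * n_networks, desc='Per Particle Robustness') as pbar:
    for setting in range(n_networks):
        for seed in range(seeds):
            for level in range(noise_levels):
                # ... run one robustness trial for this combination ...
                pbar.update(1)  # advance the single shared bar once per trial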