Final Experiments and Plot adjustments
This commit is contained in:
@@ -65,8 +65,8 @@ class AddTaskDataset(Dataset):
|
||||
|
||||
|
||||
def set_checkpoint(model, out_path, epoch_n, final_model=False):
|
||||
epoch_n = str(epoch_n)
|
||||
if not final_model:
|
||||
epoch_n = str(epoch_n)
|
||||
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
|
||||
else:
|
||||
if isinstance(epoch_n, str):
|
||||
@@ -145,6 +145,7 @@ def plot_training_particle_types(path_to_dataframe):
|
||||
labels=fix_types.tolist(), colors=PALETTE)
|
||||
|
||||
ax.set(ylabel='Particle Count', xlabel='Epoch')
|
||||
ax.yaxis.get_major_locator().set_params(integer=True)
|
||||
# ax.set_title('Particle Type Count')
|
||||
|
||||
fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))
|
||||
@@ -219,6 +220,9 @@ def plot_network_connectivity_by_fixtype(path_to_trained_model):
|
||||
legend=False, estimator=None, lw=1)
|
||||
_ = sns.lineplot(y=[0, 1], x=[-1, df['Layer'].max()], legend=False, estimator=None, lw=0)
|
||||
ax.set_title(fixtype)
|
||||
ax.yaxis.get_major_locator().set_params(integer=True)
|
||||
ax.xaxis.get_major_locator().set_params(integer=True)
|
||||
ax.set_ylabel('Normalized Neuron Position (1/n)') # XAXIS Label
|
||||
lines = ax.get_lines()
|
||||
for line in lines:
|
||||
line.set_color(PALETTE[n])
|
||||
@@ -273,7 +277,7 @@ def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchme
|
||||
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors[:palette_len], ci=None)
|
||||
|
||||
ax[0].set_title(f'{metric_name} after particle dropout')
|
||||
ax[0].set_xlabel('Particle Type')
|
||||
# ax[0].set_xlabel('Particle Type') # XAXIS Label
|
||||
ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=30)
|
||||
|
||||
ax[1].pie(sorted_particle_dict.values(), labels=sorted_particle_dict.keys(),
|
||||
@@ -345,9 +349,11 @@ def highlight_fixpoints_vs_mnist_mean(mdl_path, dataloader):
|
||||
|
||||
fig, axs = plt.subplots(1, 3)
|
||||
|
||||
for idx, image in enumerate([binary_image, real_image, mnist_mean]):
|
||||
for idx, (image, title) in enumerate(zip([binary_image, real_image, mnist_mean],
|
||||
["Particle Count", "Particle Value", "MNIST mean"])):
|
||||
img = axs[idx].imshow(image.squeeze().detach().cpu())
|
||||
img.axes.axis('off')
|
||||
img.axes.set_title('Random Noise')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(mdl_path.parent / 'heatmap.png', dpi=300)
|
||||
|
||||
@@ -45,10 +45,9 @@ def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
|
||||
time_to_vergence = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
|
||||
time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
|
||||
row_headers = []
|
||||
|
||||
df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
|
||||
'Time to convergence', 'Time as fixpoint'])
|
||||
with tqdm(total=(seeds * noise_levels * len(networks))) as pbar:
|
||||
with tqdm(total=(seeds * noise_levels * len(networks)), desc='Per Particle Robustness') as pbar:
|
||||
for setting, fixpoint in enumerate(networks): # 1 / n
|
||||
row_headers.append(fixpoint.name)
|
||||
for seed in range(seeds): # n / 1
|
||||
|
||||
Reference in New Issue
Block a user