application losses II
This commit is contained in:
@ -105,10 +105,9 @@ class RobustnessComparisonExperiment:
|
|||||||
|
|
||||||
for i, fixpoint in enumerate(self.id_functions):
|
for i, fixpoint in enumerate(self.id_functions):
|
||||||
row_headers.append(fixpoint.name)
|
row_headers.append(fixpoint.name)
|
||||||
|
loss_per_application = [[0 for _ in range(10)] for _ in range(len(self.id_functions))]
|
||||||
for seed in range(10):
|
for seed in range(10):
|
||||||
for noise_level in noise_range:
|
for noise_level in noise_range:
|
||||||
application_losses = []
|
|
||||||
|
|
||||||
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
|
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
|
||||||
f"{fixpoint.name}_clone_noise10e-{noise_level}")
|
f"{fixpoint.name}_clone_noise10e-{noise_level}")
|
||||||
clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
|
clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
|
||||||
@ -129,7 +128,8 @@ class RobustnessComparisonExperiment:
|
|||||||
clone_weight_post_application = clone.input_weight_matrix()
|
clone_weight_post_application = clone.input_weight_matrix()
|
||||||
target_data_post_application = clone.create_target_weights(clone_weight_post_application)
|
target_data_post_application = clone.create_target_weights(clone_weight_post_application)
|
||||||
|
|
||||||
application_losses.append(F.l1_loss(target_data_pre_application, target_data_post_application))
|
loss_per_application[seed][noise_level] = (F.l1_loss(target_data_pre_application,
|
||||||
|
target_data_post_application))
|
||||||
|
|
||||||
|
|
||||||
if print_it:
|
if print_it:
|
||||||
|
Reference in New Issue
Block a user