journal linspace basins

This commit is contained in:
steffen-illium
2021-06-25 10:25:25 +02:00
parent cf6eec639f
commit 14d9a533cb
8 changed files with 69 additions and 100 deletions

View File

@ -1,5 +1,6 @@
# from __future__ import annotations
import copy
import random
from typing import Union
import torch
@ -9,7 +10,12 @@ import numpy as np
from torch import optim, Tensor
def prng():
return random.random()
class Net(nn.Module):
@staticmethod
def create_target_weights(input_weight_matrix: Tensor) -> Tensor:
""" Outputting a tensor with the target weights. """
@ -171,3 +177,16 @@ class Net(nn.Module):
SA_steps = 1
return other_net.apply_weights(my_evaluation)
def apply_noise(self, noise_size: float):
""" Changing the weights of a network to values + noise """
for layer_id, layer_name in enumerate(self.state_dict()):
for line_id, line_values in enumerate(self.state_dict()[layer_name]):
for weight_id, weight_value in enumerate(self.state_dict()[layer_name][line_id]):
# network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
if prng() < 0.5:
self.state_dict()[layer_name][line_id][weight_id] = weight_value + noise_size * prng()
else:
self.state_dict()[layer_name][line_id][weight_id] = weight_value - noise_size * prng()
return self