journal linspace basins
This commit is contained in:
19
network.py
19
network.py
@ -1,5 +1,6 @@
|
||||
# from __future__ import annotations
|
||||
import copy
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
@ -9,7 +10,12 @@ import numpy as np
|
||||
from torch import optim, Tensor
|
||||
|
||||
|
||||
def prng():
|
||||
return random.random()
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def create_target_weights(input_weight_matrix: Tensor) -> Tensor:
|
||||
""" Outputting a tensor with the target weights. """
|
||||
@ -171,3 +177,16 @@ class Net(nn.Module):
|
||||
SA_steps = 1
|
||||
|
||||
return other_net.apply_weights(my_evaluation)
|
||||
|
||||
def apply_noise(self, noise_size: float):
|
||||
""" Changing the weights of a network to values + noise """
|
||||
for layer_id, layer_name in enumerate(self.state_dict()):
|
||||
for line_id, line_values in enumerate(self.state_dict()[layer_name]):
|
||||
for weight_id, weight_value in enumerate(self.state_dict()[layer_name][line_id]):
|
||||
# network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
|
||||
if prng() < 0.5:
|
||||
self.state_dict()[layer_name][line_id][weight_id] = weight_value + noise_size * prng()
|
||||
else:
|
||||
self.state_dict()[layer_name][line_id][weight_id] = weight_value - noise_size * prng()
|
||||
|
||||
return self
|
||||
|
Reference in New Issue
Block a user