Robert Müller 482f45df87 big update
2020-04-06 14:46:26 +02:00

88 lines
3.0 KiB
Python

import torch
import torch.nn as nn
from torch.optim import Adam
class MyModel(nn.Module):
    """Student-teacher anomaly scorer on top of a frozen feature extractor.

    A frozen, randomly initialized ``noise`` network acts as the teacher;
    the trainable ``student`` regresses the teacher's output from the same
    features. The per-sample squared regression error is the anomaly score.
    """

    def __init__(self, feature_extractor):
        super(MyModel, self).__init__()
        # assumes feature_extractor exposes .feature_size and a .F(imgs)
        # feature method — TODO confirm against the extractor class
        self.feature_extractor = feature_extractor
        feature_size = feature_extractor.feature_size
        # Frozen random projection: generates the regression targets.
        self.noise = nn.Sequential(
            nn.Linear(feature_size, feature_size // 2),
            nn.ELU(),
            nn.Linear(feature_size // 2, feature_size // 4)
        )
        for p in self.noise.parameters():
            p.requires_grad = False
        self.noise.eval()
        # Trainable student that mimics the frozen teacher.
        self.student = nn.Sequential(
            nn.Linear(feature_size, feature_size // 2),
            nn.ELU(),
            nn.Linear(feature_size // 2, feature_size // 4),
            nn.ELU(),
            nn.Linear(feature_size // 4, feature_size // 4)
        )
        # Only the student is optimized; the teacher stays frozen.
        self.optimizer = Adam(self.student.parameters(), lr=0.0001,
                              weight_decay=1e-7, amsgrad=True)

    def forward(self, imgs):
        """Return ``(target, prediction)`` for a batch of images.

        ``target`` comes from the frozen teacher, ``prediction`` from the
        student; both have shape (batch, feature_size // 4).
        """
        features = self.feature_extractor.F(imgs).squeeze()
        target = self.noise(features)
        prediction = self.student(features)
        return target, prediction

    def scores_from_dataloader(self, dataloader, device='cuda'):
        """Compute per-sample anomaly scores over an entire dataloader.

        Bug fix: the original built ``scores = []`` but never appended to it
        and returned nothing (it only printed ``preds.shape``). Scores are
        now collected and returned as a single 1-D CPU tensor, one score per
        sample. ``device`` generalizes the previously hard-coded ``'cuda'``.

        Args:
            dataloader: yields ``(imgs, names, seq_ids)`` batches.
            device: device the images are moved to (default ``'cuda'``).

        Returns:
            1-D tensor of squared teacher/student errors, one per sample.
        """
        scores = []
        with torch.no_grad():
            for batch in dataloader:
                imgs, names, seq_ids = batch
                imgs = imgs.to(device)
                target, prediction = self.forward(imgs)
                # Squared error summed over all non-batch dimensions.
                preds = torch.sum((prediction - target) ** 2,
                                  dim=tuple(range(1, target.dim())))
                scores.append(preds.cpu())
        return torch.cat(scores)
class HyperFraud(nn.Module):
    """Minimal GRU-like recurrent cell unrolled over a fixed sequence length.

    Bug fix / generalization: the candidate-state layers (``W_hidden`` /
    ``U_hidden``) were hard-coded to an output width of 256 while the gate
    path used ``hidden_dim``, so any ``hidden_dim != 256`` crashed with a
    shape mismatch when forming ``h_t``. They now use ``hidden_dim``
    consistently. The input width (previously hard-coded to 512) is exposed
    as ``input_dim``; both defaults preserve the original behavior.
    """

    def __init__(self, hidden_dim=256, input_dim=512):
        super(HyperFraud, self).__init__()
        self.hidden_dim = hidden_dim
        # NOTE(review): mean/std are never used below, and as plain tensors
        # (not buffers) they would not follow .to(device)/.state_dict();
        # kept to preserve the original interface — confirm before removal.
        self.mean = torch.randn(size=(1, input_dim))
        self.std = torch.randn(size=(1, input_dim))
        # Forget-gate transforms: input -> hidden and hidden -> hidden.
        self.W_forget = nn.Sequential(
            nn.Linear(input_dim, hidden_dim)
        )
        self.U_forget = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Candidate-state transforms (were hard-coded to 256).
        self.W_hidden = nn.Sequential(
            nn.Linear(input_dim, hidden_dim)
        )
        self.U_hidden = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.b_forget = nn.Parameter(torch.randn(1, hidden_dim))
        self.b_hidden = nn.Parameter(torch.randn(1, hidden_dim))

    def forward(self, data, max_seq_len=10):
        """Unroll the cell over the first ``max_seq_len`` steps of ``data``.

        Args:
            data: tensor of shape (batch, seq, input_dim); requires
                ``seq >= max_seq_len``.
            max_seq_len: number of time steps to unroll.

        Returns:
            Hidden states stacked along dim 1, shape
            (batch, max_seq_len, hidden_dim); the initial zero state is
            not included.
        """
        # Initial hidden state is all zeros.
        h_prev = [torch.zeros(size=(data.shape[0], self.hidden_dim))]
        for i in range(0, max_seq_len):
            x_t = data[:, i]
            h_t_prev = h_prev[-1]
            # Forget gate from current input and previous hidden state.
            W_x_t = self.W_forget(x_t)
            U_h_prev = self.U_forget(h_t_prev)
            forget_t = torch.sigmoid(W_x_t + U_h_prev + self.b_forget)
            # Candidate state, conditioned on the gated previous state.
            candidate = torch.tanh(
                self.W_hidden(x_t)
                + self.U_hidden(forget_t * h_t_prev)
                + self.b_hidden
            )
            # Convex mix of previous state and candidate.
            h_t = forget_t * h_t_prev + (1.0 - forget_t) * candidate
            h_prev.append(h_t)
        return torch.stack(h_prev[1:], dim=1)
#hf = HyperFraud()
#rand_input = torch.randn(size=(42, 10, 512))
#print(hf(rand_input).shape)