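# MyModel pairs a frozen, randomly initialised "noise" network with a
# trainable student over extracted image features; HyperFraud is a
# hand-rolled gated recurrent cell over 512-dim sequence features.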
import torch
import torch.nn as nn
from torch.optim import Adam


class MyModel(nn.Module):
    def __init__(self, feature_extractor):
        super().__init__()
        self.feature_extractor = feature_extractor
        feature_size = feature_extractor.feature_size

        # Frozen, randomly initialised target network: its outputs serve as
        # fixed regression targets for the student below.
        self.noise = nn.Sequential(
            nn.Linear(feature_size, feature_size // 2),
            nn.ELU(),
            nn.Linear(feature_size // 2, feature_size // 4),
        )
        for p in self.noise.parameters():
            p.requires_grad = False
        self.noise.eval()

        # Trainable student that learns to reproduce the frozen network's
        # output; its prediction error is used as the score.
        self.student = nn.Sequential(
            nn.Linear(feature_size, feature_size // 2),
            nn.ELU(),
            nn.Linear(feature_size // 2, feature_size // 4),
            nn.ELU(),
            nn.Linear(feature_size // 4, feature_size // 4),
        )
        self.optimizer = Adam(self.student.parameters(), lr=0.0001, weight_decay=1e-7, amsgrad=True)

    def forward(self, imgs):
        # NOTE: squeeze() drops *all* singleton dimensions, including the
        # batch dimension when the batch size is 1.
        features = self.feature_extractor.F(imgs).squeeze()
        target = self.noise(features)
        prediction = self.student(features)
        return target, prediction

    def scores_from_dataloader(self, dataloader):
        # Per-sample score: squared error between the student's prediction
        # and the frozen network's target, summed over feature dimensions.
        scores = []
        with torch.no_grad():
            for imgs, names, seq_ids in dataloader:
                imgs = imgs.to('cuda')
                target, prediction = self.forward(imgs)
                preds = torch.sum((prediction - target) ** 2, dim=tuple(range(1, target.dim())))
                scores.append(preds.cpu())
        return torch.cat(scores)
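

# A minimal training-step sketch (not part of the original file): the student
# is fitted to the frozen noise network's targets with an MSE loss. The
# dataloader layout mirrors scores_from_dataloader; the helper name and the
# `device` argument are assumptions for illustration.
def train_student_epoch(model, dataloader, device='cuda'):
    for imgs, names, seq_ids in dataloader:
        imgs = imgs.to(device)
        target, prediction = model(imgs)
        # Detach the target so gradients flow only through the student.
        loss = torch.mean((prediction - target.detach()) ** 2)
        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()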


class HyperFraud(nn.Module):
    def __init__(self, hidden_dim=256):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Feature statistics; registered as buffers so they move with the
        # module across devices (unused in forward below).
        self.register_buffer('mean', torch.randn(1, 512))
        self.register_buffer('std', torch.randn(1, 512))

        # Gated recurrent cell over 512-dim inputs: W_* act on the input
        # x_t, U_* on the previous hidden state.
        self.W_forget = nn.Linear(512, hidden_dim)
        self.U_forget = nn.Linear(hidden_dim, hidden_dim)
        self.W_hidden = nn.Linear(512, hidden_dim)
        self.U_hidden = nn.Linear(hidden_dim, hidden_dim)
        self.b_forget = nn.Parameter(torch.randn(1, hidden_dim))
        self.b_hidden = nn.Parameter(torch.randn(1, hidden_dim))
    def forward(self, data, max_seq_len=10):
        # data: (batch, seq, dim)
        # TODO: random sequence sampling; currently the first max_seq_len
        # steps are consumed in order.
        h_prev = [torch.zeros(data.shape[0], self.hidden_dim, device=data.device)]
        for i in range(max_seq_len):
            x_t = data[:, i]
            h_t_prev = h_prev[-1]
            # Forget gate: f_t = sigmoid(W_f x_t + U_f h_{t-1} + b_f)
            forget_t = torch.sigmoid(self.W_forget(x_t) + self.U_forget(h_t_prev) + self.b_forget)
            # Convex blend of the previous state and a candidate computed
            # from the input and the gated previous state.
            h_t = forget_t * h_t_prev + (1.0 - forget_t) * torch.tanh(
                self.W_hidden(x_t) + self.U_hidden(forget_t * h_t_prev) + self.b_hidden
            )
            h_prev.append(h_t)
        # (batch, max_seq_len, hidden_dim): hidden states for every step.
        return torch.stack(h_prev[1:], dim=1)
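

# For reference, a sketch (not from the original file) of the same unrolling
# using PyTorch's built-in nn.GRUCell; the gating math differs from
# HyperFraud's cell, but the input layout (batch, seq, 512) and the stacked
# output shape (batch, max_seq_len, hidden_dim) match.
def gru_reference(data, hidden_dim=256, max_seq_len=10):
    cell = nn.GRUCell(input_size=512, hidden_size=hidden_dim)
    h_t = torch.zeros(data.shape[0], hidden_dim, device=data.device)
    states = []
    for i in range(max_seq_len):
        h_t = cell(data[:, i], h_t)
        states.append(h_t)
    return torch.stack(states, dim=1)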


if __name__ == '__main__':
    # Smoke test: 42 sequences, 10 steps, 512-dim features.
    hf = HyperFraud()
    rand_input = torch.randn(size=(42, 10, 512))
    print(hf(rand_input).shape)  # torch.Size([42, 10, 256])