big update
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self, feature_extractor):
|
||||
super(MyModel, self).__init__()
|
||||
self.feature_extractor = feature_extractor
|
||||
feature_size = feature_extractor.feature_size
|
||||
self.noise = nn.Sequential(
|
||||
nn.Linear(feature_size, feature_size // 2),
|
||||
nn.ELU(),
|
||||
nn.Linear(feature_size // 2, feature_size // 4)
|
||||
)
|
||||
for p in self.noise.parameters():
|
||||
p.requires_grad = False
|
||||
self.noise.eval()
|
||||
|
||||
self.student = nn.Sequential(
|
||||
nn.Linear(feature_size, feature_size // 2),
|
||||
nn.ELU(),
|
||||
nn.Linear(feature_size // 2, feature_size // 4),
|
||||
nn.ELU(),
|
||||
nn.Linear(feature_size // 4, feature_size // 4)
|
||||
)
|
||||
self.optimizer = Adam(self.student.parameters(), lr=0.0001, weight_decay=1e-7, amsgrad=True)
|
||||
|
||||
def forward(self, imgs):
|
||||
features = self.feature_extractor.F(imgs).squeeze()
|
||||
target = self.noise(features)
|
||||
prediction = self.student(features)
|
||||
return target, prediction
|
||||
|
||||
|
||||
def scores_from_dataloader(self, dataloader):
|
||||
scores = []
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
imgs, names, seq_ids = batch
|
||||
imgs = imgs.to('cuda')
|
||||
target, prediction = self.forward(imgs)
|
||||
preds = torch.sum((prediction - target) ** 2, dim=tuple(range(1, target.dim())))
|
||||
print(preds.shape)
|
||||
|
||||
|
||||
class HyperFraud(nn.Module):
|
||||
def __init__(self, hidden_dim=256):
|
||||
super(HyperFraud, self).__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.mean = torch.randn(size=(1, 512))
|
||||
self.std = torch.randn(size=(1, 512))
|
||||
|
||||
self.W_forget = nn.Sequential(
|
||||
nn.Linear(512, hidden_dim)
|
||||
)
|
||||
self.U_forget = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim)
|
||||
)
|
||||
self.W_hidden = nn.Sequential(
|
||||
nn.Linear(512, 256)
|
||||
)
|
||||
self.U_hidden = nn.Sequential(
|
||||
nn.Linear(hidden_dim, 256)
|
||||
)
|
||||
self.b_forget = nn.Parameter(torch.randn(1, hidden_dim))
|
||||
self.b_hidden = nn.Parameter(torch.randn(1, hidden_dim))
|
||||
|
||||
def forward(self, data, max_seq_len=10):
|
||||
# data. batch x seqs x dim
|
||||
# random seq sampling
|
||||
h_prev = [torch.zeros(size=(data.shape[0], self.hidden_dim))]
|
||||
for i in range(0, max_seq_len):
|
||||
x_t = data[:, i]
|
||||
h_t_prev = h_prev[-1]
|
||||
W_x_t = self.W_forget(x_t)
|
||||
U_h_prev = self.U_forget(h_t_prev)
|
||||
forget_t = torch.sigmoid(W_x_t + U_h_prev + self.b_forget)
|
||||
h_t = forget_t * h_t_prev + (1.0 - forget_t) * torch.tanh(self.W_hidden(x_t) + self.U_hidden(forget_t * h_t_prev) + self.b_hidden)
|
||||
h_prev.append(h_t)
|
||||
return torch.stack(h_prev[1:], dim=1)
|
||||
|
||||
|
||||
#hf = HyperFraud()
|
||||
#rand_input = torch.randn(size=(42, 10, 512))
|
||||
|
||||
#print(hf(rand_input).shape)
|
||||
Reference in New Issue
Block a user