import torch
import torch.nn as nn
from torch.optim import Adam


class MyModel(nn.Module):
    """Teacher/student distillation head on top of a frozen feature extractor.

    A frozen random projection (``noise``) produces regression targets that a
    trainable ``student`` network learns to reproduce; the per-sample squared
    teacher/student error serves as an anomaly score.
    """

    def __init__(self, feature_extractor):
        super(MyModel, self).__init__()
        self.feature_extractor = feature_extractor
        feature_size = feature_extractor.feature_size

        # Frozen random "teacher": its parameters are never updated.
        self.noise = nn.Sequential(
            nn.Linear(feature_size, feature_size // 2),
            nn.ELU(),
            nn.Linear(feature_size // 2, feature_size // 4),
        )
        for p in self.noise.parameters():
            p.requires_grad = False
        # NOTE(review): a later model.train() call flips this submodule back
        # to train mode; harmless here (no dropout/batchnorm) but confirm.
        self.noise.eval()

        # Trainable student that regresses the teacher's output.
        self.student = nn.Sequential(
            nn.Linear(feature_size, feature_size // 2),
            nn.ELU(),
            nn.Linear(feature_size // 2, feature_size // 4),
            nn.ELU(),
            nn.Linear(feature_size // 4, feature_size // 4),
        )
        self.optimizer = Adam(
            self.student.parameters(), lr=0.0001, weight_decay=1e-7, amsgrad=True
        )

    def forward(self, imgs):
        """Return ``(target, prediction)`` for a batch of images.

        ``target`` comes from the frozen teacher, ``prediction`` from the
        trainable student; both are computed on the same extracted features.
        """
        # NOTE(review): .squeeze() drops *all* singleton dimensions, so a
        # batch of size 1 would lose its batch dimension — confirm against
        # the extractor's output shape before relying on batch_size == 1.
        features = self.feature_extractor.F(imgs).squeeze()
        target = self.noise(features)
        prediction = self.student(features)
        return target, prediction

    def scores_from_dataloader(self, dataloader):
        """Compute per-sample anomaly scores for every batch in *dataloader*.

        Each dataloader batch is expected to yield ``(imgs, names, seq_ids)``;
        only ``imgs`` is used here.

        Returns a 1-D tensor with one squared teacher/student error per
        sample, in dataloader order (empty tensor for an empty loader).

        Fixes vs. original: scores were computed but never collected or
        returned (the function only printed their shape), and the device was
        hard-coded to ``'cuda'`` — it is now taken from the model itself.
        """
        device = next(self.student.parameters()).device
        scores = []
        with torch.no_grad():
            for imgs, names, seq_ids in dataloader:
                imgs = imgs.to(device)
                target, prediction = self.forward(imgs)
                # Sum the squared error over all non-batch dimensions.
                err = torch.sum(
                    (prediction - target) ** 2, dim=tuple(range(1, target.dim()))
                )
                scores.append(err)
        return torch.cat(scores) if scores else torch.empty(0, device=device)


class HyperFraud(nn.Module):
    """Minimal gated (MGU-style) recurrent cell unrolled over a sequence.

    ``forward`` consumes ``data`` of shape (batch, seq, 512) and returns the
    stacked hidden states, shape (batch, min(max_seq_len, seq), hidden_dim).
    """

    def __init__(self, hidden_dim=256):
        super(HyperFraud, self).__init__()
        self.hidden_dim = hidden_dim
        # Registered as buffers so .to(device) / state_dict handle them.
        # (They were plain tensor attributes before, which stay on the CPU
        # when the module is moved to GPU.)  Currently unused by forward().
        self.register_buffer("mean", torch.randn(size=(1, 512)))
        self.register_buffer("std", torch.randn(size=(1, 512)))

        # Forget-gate transforms.
        self.W_forget = nn.Sequential(nn.Linear(512, hidden_dim))
        self.U_forget = nn.Sequential(nn.Linear(hidden_dim, hidden_dim))
        # Candidate-hidden transforms.  These were hard-coded to 256, which
        # silently broke any hidden_dim != 256 (b_hidden already used
        # hidden_dim); identical behaviour at the default hidden_dim=256.
        self.W_hidden = nn.Sequential(nn.Linear(512, hidden_dim))
        self.U_hidden = nn.Sequential(nn.Linear(hidden_dim, hidden_dim))

        self.b_forget = nn.Parameter(torch.randn(1, hidden_dim))
        self.b_hidden = nn.Parameter(torch.randn(1, hidden_dim))

    def forward(self, data, max_seq_len=10):
        """Unroll the cell over ``data`` (batch x seq x 512).

        Returns the hidden state at every processed step, stacked along
        dim 1: shape (batch, min(max_seq_len, seq), hidden_dim).
        """
        # Don't index past the actual sequence length (the original assumed
        # seq >= max_seq_len and raised an IndexError otherwise).
        steps = min(max_seq_len, data.shape[1])
        # Initial hidden state on the same device as the input.
        h_t = torch.zeros(data.shape[0], self.hidden_dim, device=data.device)
        outputs = []
        for i in range(steps):
            x_t = data[:, i]
            forget_t = torch.sigmoid(
                self.W_forget(x_t) + self.U_forget(h_t) + self.b_forget
            )
            candidate = torch.tanh(
                self.W_hidden(x_t) + self.U_hidden(forget_t * h_t) + self.b_hidden
            )
            # Convex combination of previous state and the new candidate.
            h_t = forget_t * h_t + (1.0 - forget_t) * candidate
            outputs.append(h_t)
        return torch.stack(outputs, dim=1)


#hf = HyperFraud()
#rand_input = torch.randn(size=(42, 10, 512))
#print(hf(rand_input).shape)