Robert Müller 482f45df87 big update
2020-04-06 14:46:26 +02:00

262 lines
8.7 KiB
Python

import numpy as np
import torch
import torch.nn as nn
import torch.functional as F
class Reshape(nn.Module):
    """Reshape every sample to a fixed shape while keeping the batch axis.

    ``Reshape(*shape)`` maps a tensor of shape (N, ...) to (N, *shape) using
    ``Tensor.view``; one entry of ``shape`` may be -1.
    """

    def __init__(self, *args):
        super(Reshape, self).__init__()
        # Stored under ``target_shape`` rather than ``to``: an instance
        # attribute named ``to`` shadows ``nn.Module.to``, which made
        # ``Reshape(...).to(device)`` fail with "'tuple' object is not callable".
        self.target_shape = args

    def forward(self, x):
        return x.view(x.shape[0], *self.target_shape)
class Flatten(nn.Module):
    """Collapse all non-batch dimensions into one: (N, ...) -> (N, -1) via ``view``."""

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)
class AE(nn.Module):
    """Fully connected autoencoder over flattened inputs.

    Layer widths are in_dim -> 64 -> 64 -> 8 -> 64 -> 64 -> in_dim, with a
    ReLU after every Linear layer (the output layer included, so
    reconstructions are non-negative).
    """

    def __init__(self, in_dim=400):
        super().__init__()
        widths = [in_dim, 64, 64, 8, 64, 64, in_dim]
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(n_in, n_out), nn.ReLU()))
        self.net = nn.Sequential(*layers)

    def forward(self, data):
        """Flatten ``data`` to (batch, -1) and reconstruct it.

        Returns:
            ``(reconstruction, flattened_input)``; the flattened input is the
            reconstruction target.
        """
        flat = data.view(data.shape[0], -1)
        return self.net(flat), flat

    def train_loss(self, data):
        """Mean squared reconstruction error over the whole batch (a scalar)."""
        reconstruction, target = self.forward(data)
        return nn.functional.mse_loss(reconstruction, target)

    def test_loss(self, data):
        """Per-sample sum of squared reconstruction errors, shape (batch,)."""
        reconstruction, target = self.forward(data)
        squared = (reconstruction - target).pow(2)
        return squared.sum(dim=tuple(range(1, squared.dim())))

    def init_weights(self):
        """Xavier-uniform (ReLU gain) weights and 0.01 biases on every submodule that has them."""
        relu_gain = nn.init.calculate_gain('relu')

        def _init(module):
            weight = getattr(module, 'weight', None)
            if isinstance(weight, torch.Tensor):
                torch.nn.init.xavier_uniform_(weight, gain=relu_gain)
            bias = getattr(module, 'bias', None)
            if isinstance(bias, torch.Tensor):
                bias.data.fill_(0.01)

        self.apply(_init)
class SubSpecCAE(nn.Module):
    """Convolutional autoencoder for one single-channel input of size F x T.

    ``forward`` expects a (batch, 1, F, T) tensor (the first conv has one input
    channel) and returns ``(reconstruction, input)``. The decoder ends with
    ``Upsample((F, T))`` followed by a size-preserving transposed conv, so the
    reconstruction matches the input's spatial size. The encoder compresses
    each input to a 16-d code.

    Args:
        F: size of the input's height axis (presumably frequency bins); must be
            >= 10 because the first max-pool kernel height is ``F // 10``.
        T: size of the input's width axis (presumably time frames); must be >= 5
            because the first max-pool kernel width is 5.
        norm: ``'batch'`` selects BatchNorm2d, anything else InstanceNorm2d.
        activation: ``'relu'`` selects ReLU, anything else LeakyReLU.
        dropout_prob: probability for every Dropout layer.

    Raises:
        ValueError: if ``F < 10`` or ``T < 5``.
    """

    def __init__(self, F=20, T=80, norm='batch',
                 activation='relu', dropout_prob=0.25):
        super(SubSpecCAE, self).__init__()
        if F < 10 or T < 5:
            # Either the first pool's kernel height (F // 10) or its output
            # width (T // 5) would be 0, leaving no valid feature map.
            raise ValueError(
                'SubSpecCAE needs F >= 10 and T >= 5, got F={}, T={}'.format(F, T))
        self.T = T
        self.F = F
        self.activation = activation
        self.loss_fn = nn.MSELoss()  # not used inside this class; kept for callers
        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        a, b = 20, 40
        pool_f = F // 10
        # Feature-map size after the first max-pool (stride == kernel, floor
        # rounding): (10, 16) for the default F=20, T=80.
        pooled = (F // pool_f, T // 5)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=a, kernel_size=7, stride=1, padding=3),  # a x F x T
            Norm(a),
            Activation(),
            nn.MaxPool2d((pool_f, 5)),  # a x pooled
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=a, out_channels=b, kernel_size=7, stride=1, padding=3),  # b x pooled
            Norm(b),
            Activation(),
            # Global max-pool to b x 1 x 1 so Flatten yields exactly b features.
            # A fixed MaxPool2d(4, T) reached 1 x 1 only because its stride T
            # overshoots, so it read just the top-left 4 x 4 window of the map.
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Dropout(dropout_prob),
            Flatten(),
            nn.Linear(b, 16),
        )
        self.decoder = nn.Sequential(
            nn.Linear(16, b),
            Reshape(b, 1, 1),
            # Mirror the encoder: back to the pooled size, then to the input size.
            nn.Upsample(size=pooled, mode='bilinear', align_corners=False),
            nn.ConvTranspose2d(in_channels=b, out_channels=a, kernel_size=7, stride=1, padding=3),
            Norm(a),
            Activation(),
            nn.Upsample(size=(F, T), mode='bilinear', align_corners=False),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=a, out_channels=1, kernel_size=7, stride=1, padding=3),
            nn.Sigmoid(),
        )

    def forward(self, sub_x):
        """Return ``(reconstruction, sub_x)``; the input is passed through as the target."""
        encoded = self.encoder(sub_x)
        decoded = self.decoder(encoded)
        return decoded, sub_x

    def init_weights(self):
        """Kaiming-uniform weights and 0.2 biases for Conv2d and Linear submodules.

        ConvTranspose2d is not a Conv2d subclass, so the decoder's transposed
        convs keep PyTorch's default initialisation.
        """
        def weight_init(m):
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.2)
        self.apply(weight_init)
class FullSubSpecCAE(nn.Module):
    """Reconstruct a multi-channel input with one SubSpecCAE per selected band.

    Each band id in ``sub_bands`` selects channel ``band`` of the input
    (``x[:, band]``). That channel is reconstructed by its own SubSpecCAE, or
    by a single shared one when ``weight_sharing`` is True. Outputs are
    concatenated along dim 1 in ``self.bands`` order, so channel ``i`` of the
    outputs corresponds to band ``self.bands[i]``.
    """

    def __init__(self, F=20, T=80,
                 norm='batch', activation='relu',
                 dropout_prob=0.25, weight_sharing=False,
                 sub_bands=(0, 1, 2, 3, 4, 5, 6)):
        super(FullSubSpecCAE, self).__init__()
        # Copy into a fresh list so instances never share (or alias the
        # caller's) band list.
        self.bands = list(sub_bands)
        self.weight_sharing = weight_sharing
        n_models = 1 if weight_sharing else len(self.bands)
        self.aes = nn.ModuleList([
            SubSpecCAE(F, T, norm, activation, dropout_prob)
            for _ in range(n_models)
        ])
        for ae in self.aes:
            ae.init_weights()
        self.loss_fn = nn.MSELoss()  # not used by this class's losses; kept for callers

    def select_sub_ae(self, band):
        """Return the sub-autoencoder responsible for band id ``band``.

        Raises:
            ValueError: if ``band`` is not in ``self.bands`` (without weight sharing).
        """
        if self.weight_sharing:
            return self.aes[0]
        # self.aes is ordered like self.bands; index by the band's position,
        # not by its id (they differ for subsets such as [2, 5]).
        return self.aes[self.bands.index(band)]

    def select_sub_band(self, x, band):
        """Return channel ``band`` of ``x`` with a singleton channel axis kept."""
        return x[:, band, :, ].unsqueeze(1)

    def forward(self, x):
        """Reconstruct every selected band of ``x``.

        Returns:
            ``(reconstruction, target)``, both concatenated along dim 1 in
            ``self.bands`` order.
        """
        y_hat, y = [], []
        for band in self.bands:
            sub_ae = self.select_sub_ae(band)
            sub_x = self.select_sub_band(x, band)
            decoded, target = sub_ae(sub_x)
            y.append(target)
            y_hat.append(decoded)
        decoded = torch.cat(y_hat, dim=1)
        y = torch.cat(y, dim=1)
        return decoded, y

    def train_loss(self, data):
        """Training objective: sum over bands of each band's batch-mean MSE."""
        y_hat, y = self.forward(data)  # each (batch, len(self.bands), ...)
        loss = self.complete_mse(y_hat, y)
        return loss

    def test_loss(self, data):
        """Per-sample sum of squared errors over all bands, shape (batch,)."""
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def sub_band_mse(self, y_hat, y):
        """Per-sample, per-band mean squared error, shape (batch, len(self.bands)).

        ``y_hat`` and ``y`` are laid out like ``forward``'s outputs: channel
        ``i`` holds band ``self.bands[i]``, so channels are read by position.
        """
        losses = []
        for idx in range(len(self.bands)):
            sub_y = self.select_sub_band(y, idx)
            sub_y_hat = self.select_sub_band(y_hat, idx)
            sub_loss = torch.mean((sub_y_hat - sub_y) ** 2, dim=tuple(range(1, sub_y.dim())))
            losses.append(sub_loss)  # (batch,)
        return torch.stack(losses, dim=1)

    def complete_mse(self, y_hat, y):
        """Sum over bands of the batch-mean per-band MSE (a scalar)."""
        return self.sub_band_mse(y_hat, y).mean(dim=0).sum()

    def gather_predictions(self, dataloader):
        """Compute per-sample, per-band MSE scores for every batch in ``dataloader``.

        Assumes each batch unpacks as ``(data, labels)`` with ``labels`` a
        tensor (``.tolist()`` is called on it). Runs in eval mode without
        gradients and restores the previous train/eval mode afterwards, even
        on error.

        Returns:
            ``(scores, labels)``: a numpy array of shape
            (n_samples, len(self.bands)) and a flat list of labels.
        """
        device = next(self.parameters()).device
        predictions = []
        labels = []
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for data, batch_labels in dataloader:
                    y_hat, y = self.forward(data.to(device))
                    predictions.append(self.sub_band_mse(y_hat, y))
                    labels += batch_labels.tolist()
        finally:
            self.train(was_training)
        if not predictions:
            # torch.cat([]) would raise; an empty loader yields empty scores.
            return np.empty((0, len(self.bands)), dtype=np.float32), labels
        return torch.cat(predictions).cpu().numpy(), labels
class GlobalCAE(nn.Module):
    """Convolutional autoencoder over a whole single-channel input.

    ``forward`` inserts the channel axis itself, so it takes a 3-D
    (batch, H, W) tensor. The decoder always emits (batch, 1, 80, 80)
    (8 -> 17 -> 37 -> 80), and an 80 x 80 input is what produces the
    60 x 8 x 8 bottleneck (80 -> 37 -> 17 -> 8), so the losses presumably
    assume 80 x 80 inputs. ``F`` and ``T`` are accepted but not used.
    """

    def __init__(self, F=20, T=80, norm='batch', activation='relu', dropout_prob=0.25):
        super().__init__()
        self.activation = activation
        norm_layer = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        act_layer = nn.ReLU if activation == 'relu' else nn.LeakyReLU

        def down(c_in, c_out, k):
            # stride-2 conv -> norm -> activation -> dropout
            return [nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=2),
                    norm_layer(c_out), act_layer(), nn.Dropout(dropout_prob)]

        def up(c_in, c_out, k):
            # stride-2 transposed conv -> norm -> activation -> dropout
            return [nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=2),
                    norm_layer(c_out), act_layer(), nn.Dropout(dropout_prob)]

        bottleneck = 60 * 8 * 8
        # Shape comments are for an 80 x 80 input.
        self.encoder = nn.Sequential(
            *down(1, 20, 8),   # 20 x 37 x 37
            *down(20, 40, 5),  # 40 x 17 x 17
            nn.Conv2d(in_channels=40, out_channels=60, kernel_size=3, stride=2),  # 60 x 8 x 8
            # NOTE(review): no activation after this norm, unlike the blocks above — confirm intended.
            norm_layer(60),
            Flatten(),
            nn.Linear(bottleneck, 64),
        )
        self.decoder = nn.Sequential(
            nn.Linear(64, bottleneck),
            Reshape(60, 8, 8),
            *up(60, 40, 3),  # 40 x 17 x 17
            *up(40, 20, 5),  # 20 x 37 x 37
            nn.ConvTranspose2d(in_channels=20, out_channels=1, kernel_size=8, stride=2),  # 1 x 80 x 80
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Add a channel axis to ``x`` and reconstruct it.

        Returns:
            ``(reconstruction, x.unsqueeze(1))``; the latter is the target.
        """
        target = x.unsqueeze(1)
        return self.decoder(self.encoder(target)), target

    def train_loss(self, data):
        """Mean squared reconstruction error over the whole batch (a scalar)."""
        reconstruction, target = self.forward(data)
        return nn.functional.mse_loss(reconstruction, target)

    def test_loss(self, data):
        """Per-sample sum of squared reconstruction errors, shape (batch,)."""
        reconstruction, target = self.forward(data)
        squared = (reconstruction - target).pow(2)
        return squared.sum(dim=tuple(range(1, squared.dim())))

    def init_weights(self):
        """Kaiming-uniform weights and 0.01 biases for Conv2d and Linear submodules.

        ConvTranspose2d is not a Conv2d subclass, so the decoder's transposed
        convs keep PyTorch's default initialisation.
        """
        def _init(module):
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                return
            torch.nn.init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.fill_(0.01)

        self.apply(_init)