262 lines
8.7 KiB
Python
262 lines
8.7 KiB
Python
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.functional as F
|
|
|
|
|
|
class Reshape(nn.Module):
    """Reshape every sample of a batch to a fixed per-sample shape.

    The leading (batch) dimension is kept and the remaining dimensions are
    reshaped to ``args``, e.g. ``Reshape(40, 1, 1)`` maps ``(N, 40)`` to
    ``(N, 40, 1, 1)``.
    """

    def __init__(self, *args):
        super(Reshape, self).__init__()
        # Stored as ``target_shape`` rather than ``to``: an instance attribute
        # named ``to`` shadows ``nn.Module.to`` and makes ``module.to(device)``
        # raise ``TypeError: 'tuple' object is not callable``.
        self.target_shape = args

    def forward(self, x):
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs;
        # for contiguous inputs it returns a view, exactly as before.
        return x.reshape(x.shape[0], *self.target_shape)

    def extra_repr(self):
        # Shown in ``repr(model)`` so the target shape is visible when printing.
        return ', '.join(str(dim) for dim in self.target_shape)
|
|
|
|
class Flatten(nn.Module):
    """Collapse all non-batch dimensions: ``(N, d1, d2, ...) -> (N, d1*d2*...)``."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # ``reshape`` instead of ``view`` so non-contiguous inputs (e.g. after a
        # transpose/permute) are accepted; contiguous inputs still get a view.
        return x.reshape(x.shape[0], -1)
|
|
|
|
|
|
class AE(nn.Module):
    """Fully connected autoencoder with an 8-unit bottleneck.

    Each sample is flattened and passed through
    ``in_dim -> 64 -> 64 -> 8 -> 64 -> 64 -> in_dim`` with a ReLU after every
    linear layer (including the last one, so reconstructions are >= 0).
    """

    def __init__(self, in_dim=400):
        super(AE, self).__init__()
        widths = [in_dim, 64, 64, 8, 64, 64, in_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, data):
        """Return ``(reconstruction, flattened_input)``.

        ``data`` is flattened to ``(N, -1)``; the first layer expects
        ``in_dim`` features per sample.
        """
        flat = data.view(data.shape[0], -1)
        return self.net(flat), flat

    def train_loss(self, data):
        """Mean squared reconstruction error over the whole batch (scalar)."""
        reconstruction, target = self.forward(data)
        return nn.functional.mse_loss(reconstruction, target)

    def test_loss(self, data):
        """Per-sample summed squared reconstruction error, shape ``(N,)``."""
        reconstruction, target = self.forward(data)
        feature_dims = tuple(range(1, reconstruction.dim()))
        return (reconstruction - target).pow(2).sum(dim=feature_dims)

    def init_weights(self):
        """Xavier-uniform weights (ReLU gain) and 0.01 biases.

        Applied to every submodule whose ``weight`` / ``bias`` attribute is a
        tensor.
        """
        relu_gain = nn.init.calculate_gain('relu')

        def _weight_init(module):
            weight = getattr(module, 'weight', None)
            if isinstance(weight, torch.Tensor):
                torch.nn.init.xavier_uniform_(weight, gain=relu_gain)
            bias = getattr(module, 'bias', None)
            if isinstance(bias, torch.Tensor):
                bias.data.fill_(0.01)

        self.apply(_weight_init)
|
|
|
|
|
|
|
|
class SubSpecCAE(nn.Module):
    """Convolutional autoencoder for a single sub-band patch.

    Input and output are ``(N, 1, F, T)`` tensors. The encoder compresses each
    patch to a 16-dimensional code; the decoder ends in a sigmoid, so
    reconstructions lie in ``[0, 1]``.

    Args:
        F: height of the patch (frequency bins).
        T: width of the patch (time frames).
        norm: ``'batch'`` for BatchNorm2d, anything else for InstanceNorm2d.
        activation: ``'relu'`` for ReLU, anything else for LeakyReLU.
        dropout_prob: dropout probability used in encoder and decoder.
    """

    def __init__(self, F=20, T=80, norm='batch',
                 activation='relu', dropout_prob=0.25):
        super(SubSpecCAE, self).__init__()
        self.T = T
        self.F = F
        self.activation = activation
        self.loss_fn = nn.MSELoss()
        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        a, b = 20, 40

        # First pooling window; max(1, ...) keeps the kernel valid when F < 10.
        pool_f, pool_t = max(1, F // 10), 5
        # Spatial size after the first pooling (10 x 16 for the defaults); the
        # decoder upsamples back through the same sizes.
        mid_size = (F // pool_f, T // pool_t)

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=a, kernel_size=7, stride=1, padding=3),  # a x F x T
            Norm(a),
            Activation(),
            nn.MaxPool2d((pool_f, pool_t)),  # a x mid_size
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=a, out_channels=b, kernel_size=7, stride=1, padding=3),  # b x mid_size
            Norm(b),
            Activation(),
            # Global max pool over the whole feature map -> b x 1 x 1.
            # The former MaxPool2d(4, T) (kernel 4, stride T) only looked at the
            # top-left 4x4 window, so for the default 20 x 80 patch the code
            # depended on roughly the first 38 time frames and 17 frequency bins.
            nn.AdaptiveMaxPool2d(1),
            nn.Dropout(dropout_prob),
            nn.Flatten(),  # b
            nn.Linear(b, 16),
        )

        self.decoder = nn.Sequential(
            nn.Linear(16, b),
            nn.Unflatten(1, (b, 1, 1)),  # b x 1 x 1
            nn.Upsample(size=mid_size, mode='bilinear', align_corners=False),
            nn.ConvTranspose2d(in_channels=b, out_channels=a, kernel_size=7, stride=1, padding=3),
            Norm(a),
            Activation(),
            nn.Upsample(size=(F, T), mode='bilinear', align_corners=False),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=a, out_channels=1, kernel_size=7, stride=1, padding=3),  # 1 x F x T
            nn.Sigmoid(),
        )

    def forward(self, sub_x):
        """Encode and decode ``sub_x``; return ``(reconstruction, sub_x)``."""
        encoded = self.encoder(sub_x)
        decoded = self.decoder(encoded)
        return decoded, sub_x

    def init_weights(self):
        """Kaiming-uniform weights and 0.2 biases for Conv2d and Linear layers.

        NOTE(review): ConvTranspose2d is not a Conv2d subclass, so the decoder's
        transposed convolutions keep PyTorch's default init; confirm intended.
        """
        def weight_init(m):
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.2)

        self.apply(weight_init)
|
|
|
|
|
|
class FullSubSpecCAE(nn.Module):
    """Collection of SubSpecCAE models, one per selected band (or one shared).

    ``forward`` slices ``x[:, band]`` for every band in ``self.bands``, feeds
    each slice (as ``(N, 1, ...)``) to its sub-autoencoder and concatenates the
    results along dim 1, so channel ``i`` of the outputs holds band
    ``self.bands[i]``.

    Args:
        F, T, norm, activation, dropout_prob: forwarded to every SubSpecCAE.
        weight_sharing: if True, a single SubSpecCAE serves all bands.
        sub_bands: band indices to model; defaults to ``0..6``.
    """

    def __init__(self, F=20, T=80,
                 norm='batch', activation='relu',
                 dropout_prob=0.25, weight_sharing=False,
                 sub_bands=None):
        super(FullSubSpecCAE, self).__init__()
        # Copy so the instance never aliases a caller's list (the old mutable
        # default list was shared by every instance).
        self.bands = list(range(7)) if sub_bands is None else list(sub_bands)
        self.weight_sharing = weight_sharing
        n_aes = 1 if weight_sharing else len(self.bands)
        self.aes = nn.ModuleList([
            SubSpecCAE(F, T, norm, activation, dropout_prob) for _ in range(n_aes)
        ])
        for ae in self.aes:
            ae.init_weights()
        self.loss_fn = nn.MSELoss()

    @staticmethod
    def _band_slots(bands):
        """Map each band index to the position of its sub-autoencoder in ``aes``.

        When every band is a valid position (``0 <= band < len(bands)``, as for
        the default ``0..6``), band ``b`` keeps using ``aes[b]`` so existing
        checkpoints stay valid. Otherwise (e.g. ``[2, 5]``, which used to raise
        IndexError) the i-th listed band uses ``aes[i]``.
        """
        n = len(bands)
        if all(0 <= band < n for band in bands):
            return {band: band for band in bands}
        return {band: i for i, band in enumerate(bands)}

    def select_sub_ae(self, band):
        """Return the sub-autoencoder responsible for ``band``."""
        if self.weight_sharing:
            return self.aes[0]
        return self.aes[self._band_slots(self.bands)[band]]

    def select_sub_band(self, x, band):
        """Return channel ``band`` of ``x`` with a singleton channel axis kept."""
        return x[:, band].unsqueeze(1)

    def forward(self, x):
        """Reconstruct every selected band of ``x``.

        Returns ``(y_hat, y)``: reconstructions and the corresponding input
        slices, concatenated along dim 1 in ``self.bands`` order.
        """
        y_hat, y = [], []
        for band in self.bands:
            sub_ae = self.select_sub_ae(band)
            decoded, target = sub_ae(self.select_sub_band(x, band))
            y_hat.append(decoded)
            y.append(target)
        return torch.cat(y_hat, dim=1), torch.cat(y, dim=1)

    def train_loss(self, data):
        """Training objective: per-band MSE averaged over the batch, summed over bands."""
        y_hat, y = self.forward(data)
        return self.complete_mse(y_hat, y)

    def test_loss(self, data):
        """Per-sample summed squared reconstruction error, shape ``(N,)``."""
        y_hat, y = self.forward(data)
        return torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))

    def sub_band_mse(self, y_hat, y):
        """Per-sample, per-band mean squared error, shape ``(N, len(self.bands))``.

        ``y_hat`` and ``y`` must be laid out as returned by ``forward``
        (channel ``i`` holds band ``self.bands[i]``); they are therefore indexed
        by channel position, not by band value.
        """
        losses = []
        for channel in range(len(self.bands)):
            sub_y = self.select_sub_band(y, channel)
            sub_y_hat = self.select_sub_band(y_hat, channel)
            # mean over every non-batch dimension -> shape (N,)
            losses.append(torch.mean((sub_y_hat - sub_y) ** 2,
                                     dim=tuple(range(1, sub_y.dim()))))
        return torch.stack(losses, dim=1)

    def complete_mse(self, y_hat, y):
        """Scalar loss: batch-mean of each band's MSE, summed over bands."""
        return self.sub_band_mse(y_hat, y).mean(dim=0).sum()

    def gather_predictions(self, dataloader):
        """Collect per-band MSE scores and labels over ``dataloader``.

        Each batch must unpack as ``(data, labels)`` with ``labels`` supporting
        ``.tolist()``. Runs in eval mode without gradients and restores the
        previous train/eval mode afterwards, even if a batch raises.

        Returns:
            ``(predictions, labels)``: a NumPy array of shape
            ``(num_samples, len(self.bands))`` and a flat list of labels.
        """
        device = next(self.parameters()).device
        predictions = []
        labels = []
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for data, batch_labels in dataloader:
                    y_hat, y = self.forward(data.to(device))
                    predictions.append(self.sub_band_mse(y_hat, y))
                    labels += batch_labels.tolist()
        finally:
            self.train(was_training)
        if not predictions:
            # torch.cat([]) would raise on an empty loader
            return np.empty((0, len(self.bands)), dtype=np.float32), labels
        return torch.cat(predictions).cpu().numpy(), labels
|
|
|
|
|
|
|
|
class GlobalCAE(nn.Module):
    """Convolutional autoencoder over a whole single-channel input map.

    ``forward`` inserts a channel axis at dim 1 and reconstructs the input
    through a 64-dimensional code. Layer sizes are fixed: an 80 x 80 input
    yields the 60 x 8 x 8 map that the encoder's final
    ``Linear(60 * 8 * 8, 64)`` requires, and the decoder always emits 80 x 80.
    NOTE(review): ``F`` and ``T`` are accepted but not used.
    """

    def __init__(self, F=20, T=80, norm='batch', activation='relu', dropout_prob=0.25):
        super(GlobalCAE, self).__init__()
        self.activation = activation
        norm_layer = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        act_layer = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        bottleneck = (60, 8, 8)
        flat_dim = 60 * 8 * 8
        code_dim = 64

        # Shape comments assume an 80 x 80 input.
        encoder_layers = [
            nn.Conv2d(1, 20, 8, stride=2),          # 20 x 37 x 37
            norm_layer(20),
            act_layer(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(20, 40, 5, stride=2),         # 40 x 17 x 17
            norm_layer(40),
            act_layer(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(40, 60, 3, stride=2),         # 60 x 8 x 8
            norm_layer(60),
            Flatten(),
            nn.Linear(flat_dim, code_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(code_dim, flat_dim),
            Reshape(*bottleneck),                   # 60 x 8 x 8
            nn.ConvTranspose2d(60, 40, 3, stride=2),  # 40 x 17 x 17
            norm_layer(40),
            act_layer(),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(40, 20, 5, stride=2),  # 20 x 37 x 37
            norm_layer(20),
            act_layer(),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(20, 1, 8, stride=2),   # 1 x 80 x 80
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(reconstruction, x_with_channel_axis)``."""
        with_channel = x.unsqueeze(1)
        return self.decoder(self.encoder(with_channel)), with_channel

    def train_loss(self, data):
        """Mean squared reconstruction error over the whole batch (scalar)."""
        reconstruction, target = self.forward(data)
        return nn.functional.mse_loss(reconstruction, target)

    def test_loss(self, data):
        """Per-sample summed squared reconstruction error, shape ``(N,)``."""
        reconstruction, target = self.forward(data)
        feature_dims = tuple(range(1, reconstruction.dim()))
        return (reconstruction - target).pow(2).sum(dim=feature_dims)

    def init_weights(self):
        """Kaiming-uniform weights and 0.01 biases for Conv2d and Linear layers.

        NOTE(review): ConvTranspose2d layers are not matched by this check and
        keep PyTorch's default init; confirm intended.
        """
        def _init(module):
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                return
            torch.nn.init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.fill_(0.01)

        self.apply(_init)