import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Reshape(nn.Module):
    """Reshapes the non-batch dimensions of the input to the given target shape."""

    def __init__(self, *args):
        super(Reshape, self).__init__()
        # Stored as `self.shape` rather than `self.to`, which would shadow nn.Module.to().
        self.shape = args

    def forward(self, x):
        return x.view(x.shape[0], *self.shape)


class Flatten(nn.Module):
    """Flattens all non-batch dimensions."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.shape[0], -1)


class AE(nn.Module):
    """Fully connected autoencoder with an 8-dimensional bottleneck."""

    def __init__(self, in_dim=400):
        super(AE, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 8), nn.ReLU(),
            nn.Linear(8, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, in_dim), nn.ReLU()
        )

    def forward(self, data):
        x = data.view(data.shape[0], -1)
        return self.net(x), x

    def train_loss(self, data):
        criterion = nn.MSELoss()
        y_hat, y = self.forward(data)
        loss = criterion(y_hat, y)
        return loss

    def test_loss(self, data):
        # Per-sample reconstruction error, used as an anomaly score.
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def init_weights(self):
        def _weight_init(m):
            if hasattr(m, 'weight'):
                if isinstance(m.weight, torch.Tensor):
                    torch.nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
            if hasattr(m, 'bias'):
                if isinstance(m.bias, torch.Tensor):
                    m.bias.data.fill_(0.01)

        self.apply(_weight_init)


class SubSpecCAE(nn.Module):
    """Convolutional autoencoder for a single F x T sub-spectrogram band."""

    def __init__(self, F=20, T=80, norm='batch', activation='relu', dropout_prob=0.25):
        super(SubSpecCAE, self).__init__()
        self.T = T
        self.F = F
        self.activation = activation
        self.loss_fn = nn.MSELoss()
        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        a, b = 20, 40
        # Shape comments below assume the defaults F=20, T=80.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=a, kernel_size=7, stride=1, padding=3),  # a x 20 x 80
            Norm(a),
            Activation(),
            nn.MaxPool2d((F // 10, 5)),  # a x 10 x 16
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=a, out_channels=b, kernel_size=7, stride=1, padding=3),  # b x 10 x 16
            Norm(b),
            Activation(),
            nn.MaxPool2d(4, T),  # stride T exceeds the feature map, so a single 4x4 window remains: b x 1 x 1
            nn.Dropout(dropout_prob),
            Flatten(),
            nn.Linear(b, 16)
        )
        self.decoder = nn.Sequential(
            nn.Linear(16, b),
            Reshape(b, 1, 1),
            nn.Upsample(size=(10, 16), mode='bilinear', align_corners=False),
            nn.ConvTranspose2d(in_channels=b, out_channels=a, kernel_size=7, stride=1, padding=3),
            Norm(a),
            Activation(),
            nn.Upsample(size=(20, 80), mode='bilinear', align_corners=False),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=a, out_channels=1, kernel_size=7, stride=1, padding=3),
            nn.Sigmoid()
        )

    def forward(self, sub_x):
        # x = x[:, self.band, :, ].unsqueeze(1)  # select a single subspec
        encoded = self.encoder(sub_x)
        decoded = self.decoder(encoded)
        return decoded, sub_x

    def init_weights(self):
        def weight_init(m):
            if isinstance(m, nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.2)

        self.apply(weight_init)


class FullSubSpecCAE(nn.Module):
    """Bank of SubSpecCAEs, one per sub-band (or a single shared one if weight_sharing=True)."""

    def __init__(self, F=20, T=80, norm='batch', activation='relu', dropout_prob=0.25,
                 weight_sharing=False, sub_bands=[0, 1, 2, 3, 4, 5, 6]):
        super(FullSubSpecCAE, self).__init__()
        self.bands = sub_bands
        self.weight_sharing = weight_sharing
        self.aes = nn.ModuleList([
            SubSpecCAE(F, T, norm, activation, dropout_prob)
            for band in range(1 if weight_sharing else len(sub_bands))
        ])
        for ae in self.aes:
            ae.init_weights()
        self.loss_fn = nn.MSELoss()

    def select_sub_ae(self, band):
        if self.weight_sharing:
            return self.aes[0]
        else:
            return self.aes[band]

    def select_sub_band(self, x, band):
        return x[:, band, :, ].unsqueeze(1)
    def forward(self, x):
        y_hat, y = [], []
        for band in self.bands:
            sub_ae = self.select_sub_ae(band)
            sub_x = self.select_sub_band(x, band)
            decoded, target = sub_ae(sub_x)
            y.append(target)
            y_hat.append(decoded)
        decoded = torch.cat(y_hat, dim=1)
        y = torch.cat(y, dim=1)
        return decoded, y

    def train_loss(self, data):
        y_hat, y = self.forward(data)  # e.g. torch.Size([96, 7, 20, 80])
        loss = self.complete_mse(y_hat, y)
        return loss

    def test_loss(self, data):
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def sub_band_mse(self, y_hat, y):
        # Mean squared error per sample and per sub-band.
        losses = []
        for band in self.bands:
            sub_y = self.select_sub_band(y, band)
            sub_y_hat = self.select_sub_band(y_hat, band)
            sub_loss = torch.mean((sub_y_hat - sub_y) ** 2, dim=tuple(range(1, sub_y.dim())))  # e.g. torch.Size([96])
            losses.append(sub_loss)
        losses = torch.stack(losses, dim=1)  # e.g. torch.Size([96, 7])
        return losses

    def complete_mse(self, y_hat, y):
        # Average over the batch, then sum over the sub-bands.
        return self.sub_band_mse(y_hat, y).mean(dim=0).sum()

    def gather_predictions(self, dataloader):
        device = next(self.parameters()).device
        predictions = []
        labels = []
        self.eval()
        with torch.no_grad():
            for batch in dataloader:
                data, l = batch
                data = data.to(device)
                y_hat, y = self.forward(data)
                mse = self.sub_band_mse(y_hat, y)  # e.g. 96 x 7
                predictions.append(mse)
                labels += l.tolist()
        predictions = torch.cat(predictions).cpu().numpy()
        self.train()
        return predictions, labels


class GlobalCAE(nn.Module):
    """Convolutional autoencoder over the full spectrogram.

    The encoder expects roughly 80 x 80 inputs, so that the flattened feature
    map matches the 60 * 8 * 8 bottleneck projection.
    """

    def __init__(self, F=20, T=80, norm='batch', activation='relu', dropout_prob=0.25):
        super(GlobalCAE, self).__init__()
        self.activation = activation
        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=8, stride=2),  # 20 x 37 x 37 for an 80 x 80 input
            Norm(20),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=20, out_channels=40, kernel_size=5, stride=2),  # 40 x 17 x 17
            Norm(40),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=40, out_channels=60, kernel_size=3, stride=2),  # 60 x 8 x 8
            Norm(60),
            Flatten(),
            nn.Linear(60 * 8 * 8, 64)
        )
        self.decoder = nn.Sequential(
            nn.Linear(64, 60 * 8 * 8),
            Reshape(60, 8, 8),
            nn.ConvTranspose2d(in_channels=60, out_channels=40, kernel_size=3, stride=2),
            Norm(40),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=40, out_channels=20, kernel_size=5, stride=2),
            Norm(20),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=20, out_channels=1, kernel_size=8, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.unsqueeze(1)  # add the channel dimension
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, x

    def train_loss(self, data):
        criterion = nn.MSELoss()
        y_hat, y = self.forward(data)
        loss = criterion(y_hat, y)
        return loss

    def test_loss(self, data):
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def init_weights(self):
        def weight_init(m):
            if isinstance(m, nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.01)

        self.apply(weight_init)
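

# --- Usage sketch (illustrative, not part of the original training pipeline) ---
# A minimal smoke test showing how these autoencoders can be driven, assuming
# batches of 7 sub-bands of 20 x 80 sub-spectrograms as suggested by the shape
# comments above. The batch size, optimizer choice, and random data are
# assumptions, not values taken from the original code.
if __name__ == '__main__':
    model = FullSubSpecCAE(F=20, T=80, weight_sharing=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.rand(8, 7, 20, 80)   # batch of 8 spectrograms, 7 sub-bands each
    loss = model.train_loss(x)     # batch-averaged MSE, summed over sub-bands
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    scores = model.test_loss(x)    # per-sample anomaly scores, shape (8,)
    print(loss.item(), scores.shape)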