big update
This commit is contained in:
163
models/ae.py
163
models/ae.py
@@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.functional as F
|
||||
@@ -67,42 +68,176 @@ class AE(nn.Module):
|
||||
|
||||
|
||||
class SubSpecCAE(nn.Module):
    """Convolutional autoencoder for a single sub-spectrogram band.

    The encoder is two conv / norm / activation / max-pool / dropout stages
    followed by a linear projection to a 16-dimensional code. The decoder
    mirrors it with linear, bilinear upsampling and transposed convolutions,
    ending in a sigmoid.

    ``Flatten`` and ``Reshape`` are helper modules that are not defined in
    this block; they are presumably declared elsewhere in this module.

    Args:
        F: Height of the input band (first spatial dimension).
        T: Width of the input band (second spatial dimension).
        norm: ``'batch'`` selects ``nn.BatchNorm2d``; any other value selects
            ``nn.InstanceNorm2d``.
        activation: ``'relu'`` selects ``nn.ReLU``; any other value selects
            ``nn.LeakyReLU``.
        dropout_prob: Probability passed to every ``nn.Dropout`` layer.
    """

    def __init__(self, F=20, T=80, norm='batch',
                 activation='relu', dropout_prob=0.25):
        super(SubSpecCAE, self).__init__()
        self.T = T
        self.F = F
        self.activation = activation
        self.loss_fn = nn.MSELoss()
        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        a, b = 20, 40  # channel widths of the first / second conv stage

        # Spatial size after the first max-pool. The decoder upsamples back
        # to this size, so it is derived from F and T rather than hard-coded
        # (for the defaults F=20, T=80 this is (10, 16)).
        pool_f = F // 10
        pooled_size = (F // pool_f, T // 5)

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=a, kernel_size=7, stride=1, padding=3),  # a x F x T
            Norm(a),
            Activation(),
            nn.MaxPool2d((pool_f, 5)),  # a x pooled_size
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=a, out_channels=b, kernel_size=7, stride=1, padding=3),  # b x pooled_size
            Norm(b),
            Activation(),
            # kernel_size=4, stride=T: with the default sizes only one 4x4
            # window fits, giving b x 1 x 1 (required by Linear(b, 16)).
            nn.MaxPool2d(4, T),
            nn.Dropout(dropout_prob),
            Flatten(),
            nn.Linear(b, 16)
        )
        self.decoder = nn.Sequential(
            nn.Linear(16, b),
            Reshape(b, 1, 1),
            nn.Upsample(size=pooled_size, mode='bilinear', align_corners=False),
            nn.ConvTranspose2d(in_channels=b, out_channels=a, kernel_size=7, stride=1, padding=3),
            Norm(a),
            Activation(),
            # Upsample to the input size so the reconstruction always has the
            # same shape as the band it is compared against.
            nn.Upsample(size=(F, T), mode='bilinear', align_corners=False),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=a, out_channels=1, kernel_size=7, stride=1, padding=3),
            nn.Sigmoid()
        )

    def forward(self, sub_x):
        """Encode and decode one band.

        Args:
            sub_x: Input tensor that is fed to the encoder unchanged; the
                first conv expects a single channel.

        Returns:
            Tuple ``(decoded, sub_x)``: the reconstruction and the untouched
            input, which serves as the reconstruction target.
        """
        encoded = self.encoder(sub_x)
        decoded = self.decoder(encoded)
        return decoded, sub_x

    def init_weights(self):
        """Re-initialise conv / linear layers of this module in place.

        Weights get Kaiming-uniform initialisation and biases are set to 0.2.
        NOTE(review): the isinstance check only names ``nn.Conv2d`` and
        ``nn.Linear``, so the decoder's ``nn.ConvTranspose2d`` layers appear
        to keep their default initialisation - confirm this is intended.
        """
        def weight_init(m):
            if isinstance(m, nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.2)
        self.apply(weight_init)
|
||||
|
||||
|
||||
class FullSubSpecCAE(nn.Module):
    """Bank of ``SubSpecCAE`` autoencoders, one per selected sub-band.

    The input is indexed along dimension 1 to pick a band; each band is
    reconstructed by its own autoencoder (or by one shared autoencoder when
    ``weight_sharing`` is true) and the per-band results are concatenated
    along dimension 1 in the order given by ``sub_bands``.

    Args:
        F: Forwarded to ``SubSpecCAE`` (band height).
        T: Forwarded to ``SubSpecCAE`` (band width).
        norm: Forwarded to ``SubSpecCAE``.
        activation: Forwarded to ``SubSpecCAE``.
        dropout_prob: Forwarded to ``SubSpecCAE``.
        weight_sharing: If true, a single autoencoder is used for all bands.
        sub_bands: Indices (along dimension 1 of the input) of the bands to
            model. They need not be contiguous.
    """

    def __init__(self, F=20, T=80,
                 norm='batch', activation='relu',
                 dropout_prob=0.25, weight_sharing=False,
                 sub_bands=(0, 1, 2, 3, 4, 5, 6)):
        super(FullSubSpecCAE, self).__init__()
        # Copy into a fresh list so instances never share (or mutate) the
        # caller's sequence or the default value.
        self.bands = list(sub_bands)
        self.weight_sharing = weight_sharing
        n_autoencoders = 1 if weight_sharing else len(self.bands)
        self.aes = nn.ModuleList([
            SubSpecCAE(F, T, norm, activation, dropout_prob)
            for _ in range(n_autoencoders)
        ])
        for ae in self.aes:
            ae.init_weights()
        self.loss_fn = nn.MSELoss()

    def select_sub_ae(self, band):
        """Return the autoencoder responsible for input band ``band``.

        ``self.aes`` is ordered like ``self.bands``, so the band id is first
        translated to its position in ``self.bands``; indexing ``self.aes``
        with the raw id would fail for non-contiguous selections.

        Raises:
            ValueError: If ``band`` is not one of ``self.bands`` (only when
                weights are not shared).
        """
        if self.weight_sharing:
            return self.aes[0]
        return self.aes[self.bands.index(band)]

    def select_sub_band(self, x, band):
        """Return slice ``band`` of dimension 1 of ``x``, kept as a size-1 dim."""
        return x[:, band, :, ].unsqueeze(1)

    def forward(self, x):
        """Reconstruct every selected band of ``x``.

        Returns:
            Tuple ``(decoded, y)`` where both tensors hold one slice per entry
            of ``self.bands`` along dimension 1: the reconstructions and the
            corresponding input bands.
        """
        y_hat, y = [], []
        for band in self.bands:
            sub_ae = self.select_sub_ae(band)
            sub_x = self.select_sub_band(x, band)
            decoded, target = sub_ae(sub_x)
            y.append(target)
            y_hat.append(decoded)
        decoded = torch.cat(y_hat, dim=1)
        y = torch.cat(y, dim=1)
        return decoded, y

    def train_loss(self, data):
        """Return the scalar training loss (see ``complete_mse``) for ``data``."""
        y_hat, y = self.forward(data)
        return self.complete_mse(y_hat, y)

    def test_loss(self, data):
        """Return the per-sample sum of squared reconstruction errors."""
        y_hat, y = self.forward(data)
        return torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))

    def sub_band_mse(self, y_hat, y):
        """Return the mean squared error per sample and per band.

        ``y_hat`` and ``y`` are outputs of ``forward``: dimension 1 has one
        entry per element of ``self.bands`` in that order, so the slices are
        addressed by position rather than by band id.

        Returns:
            Tensor with one row per sample and one column per band.
        """
        losses = []
        for position in range(len(self.bands)):
            sub_y = self.select_sub_band(y, position)
            sub_y_hat = self.select_sub_band(y_hat, position)
            sub_loss = torch.mean(
                (sub_y_hat - sub_y) ** 2, dim=tuple(range(1, sub_y.dim()))
            )
            losses.append(sub_loss)
        return torch.stack(losses, dim=1)

    def complete_mse(self, y_hat, y):
        """Return the batch-mean MSE of every band, summed over bands."""
        return self.sub_band_mse(y_hat, y).mean(dim=0).sum()

    def gather_predictions(self, dataloader):
        """Compute per-band reconstruction errors for a whole dataloader.

        The model is switched to eval mode for the pass and afterwards put
        back into whichever mode it was in before the call.

        Args:
            dataloader: Iterable yielding ``(data, labels)`` batches, where
                ``labels`` supports ``.tolist()``.

        Returns:
            Tuple ``(predictions, labels)``: a NumPy array with one row per
            sample and one column per band, and a flat list of labels. For an
            empty dataloader the array has zero rows.
        """
        device = next(self.parameters()).device
        was_training = self.training
        predictions = []
        labels = []
        self.eval()
        try:
            with torch.no_grad():
                for data, batch_labels in dataloader:
                    data = data.to(device)
                    y_hat, y = self.forward(data)
                    predictions.append(self.sub_band_mse(y_hat, y))
                    labels.extend(batch_labels.tolist())
        finally:
            # Restore the caller's mode even if a batch raised.
            self.train(was_training)
        if not predictions:
            # torch.cat rejects an empty list; return a well-shaped empty result.
            return torch.empty(0, len(self.bands)).numpy(), labels
        return torch.cat(predictions).cpu().numpy(), labels
|
||||
|
||||
|
||||
|
||||
class GlobalCAE(nn.Module):
    """Convolutional autoencoder over a whole (un-split) spectrogram.

    Three strided convolutions reduce the input to a 60-channel map that is
    flattened into a 64-dimensional code; the decoder mirrors the encoder with
    transposed convolutions and a final sigmoid.

    ``Flatten`` and ``Reshape`` are helper modules that are not defined in
    this block; they are presumably declared elsewhere in this module.

    NOTE(review): ``nn.Linear(60*8*8, 64)`` fixes the encoder's last feature
    map at 8 x 8, which the three un-padded stride-2 convolutions produce for
    an 80 x 80 input. ``F`` and ``T`` are accepted but not used here.

    Args:
        F: Unused in this constructor.
        T: Unused in this constructor.
        norm: ``'batch'`` selects ``nn.BatchNorm2d``; any other value selects
            ``nn.InstanceNorm2d``.
        activation: ``'relu'`` selects ``nn.ReLU``; any other value selects
            ``nn.LeakyReLU``.
        dropout_prob: Probability passed to every ``nn.Dropout`` layer.
    """

    def __init__(self, F=20, T=80, norm='batch', activation='relu', dropout_prob=0.25):
        super(GlobalCAE, self).__init__()
        self.activation = activation
        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        # Shape comments below assume a 1 x 80 x 80 input.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=8, stride=2),  # 20 x 37 x 37
            Norm(20),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=20, out_channels=40, kernel_size=5, stride=2),  # 40 x 17 x 17
            Norm(40),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=40, out_channels=60, kernel_size=3, stride=2),  # 60 x 8 x 8
            Norm(60),
            Flatten(),
            nn.Linear(60*8*8, 64)
        )
        self.decoder = nn.Sequential(
            nn.Linear(64, 60*8*8),
            Reshape(60, 8, 8),
            nn.ConvTranspose2d(in_channels=60, out_channels=40, kernel_size=3, stride=2),  # 40 x 17 x 17
            Norm(40),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=40, out_channels=20, kernel_size=5, stride=2),  # 20 x 37 x 37
            Norm(20),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=20, out_channels=1, kernel_size=8, stride=2),  # 1 x 80 x 80
            nn.Sigmoid()
        )

    def forward(self, x):
        """Encode and decode ``x``.

        A singleton channel dimension is inserted at position 1 before the
        encoder is applied.

        Returns:
            Tuple ``(decoded, x)``: the reconstruction and the channel-expanded
            input that serves as the reconstruction target.
        """
        # This class has no ``band`` attribute; the input is used as a whole.
        x = x.unsqueeze(1)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, x
|
||||
@@ -123,5 +258,5 @@ class SubSpecCAE(nn.Module):
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.kaiming_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
m.bias.data.fill_(0.02)
|
||||
m.bias.data.fill_(0.01)
|
||||
self.apply(weight_init)
|
||||
2
models/utils.py
Normal file
2
models/utils.py
Normal file
@@ -0,0 +1,2 @@
|
||||
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of *model*.

    Only parameters whose ``requires_grad`` flag is truthy are counted; each
    contributes ``numel()`` elements.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
||||
Reference in New Issue
Block a user