import torch
import torch.nn as nn
import torch.nn.functional as F  # unused in this module


class Reshape(nn.Module):
    """Reshape a batch of flat vectors to (batch, *target_shape)."""

    def __init__(self, *args):
        super().__init__()
        # Stored as `target_shape` rather than `to`, which would shadow nn.Module.to().
        self.target_shape = args

    def forward(self, x):
        return x.view(x.shape[0], *self.target_shape)


class Flatten(nn.Module):
    """Flatten every dimension except the batch dimension."""

    def forward(self, x):
        return x.view(x.shape[0], -1)

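
# Quick sanity check (a sketch kept as a comment so it is not executed at import):
# for a fixed trailing shape, Reshape undoes Flatten on a batch.
#   x = torch.randn(8, 64, 1, 1)
#   assert Reshape(64, 1, 1)(Flatten()(x)).shape == x.shape

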
class AE(nn.Module):
    """Fully connected autoencoder with an 8-dimensional bottleneck."""

    def __init__(self, in_dim=400):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 8),     # bottleneck
            nn.ReLU(),
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, in_dim),
            nn.ReLU()             # final ReLU keeps reconstructions non-negative
        )

    def forward(self, data):
        # Flatten each sample and return (reconstruction, flattened target).
        x = data.view(data.shape[0], -1)
        return self.net(x), x

    def train_loss(self, data):
        # Mean squared reconstruction error, averaged over the batch.
        criterion = nn.MSELoss()
        y_hat, y = self.forward(data)
        loss = criterion(y_hat, y)
        return loss

    def test_loss(self, data):
        # Per-sample squared reconstruction error (one anomaly score per example).
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def init_weights(self):
        def _weight_init(m):
            # Xavier-initialise any weight tensor (ReLU gain); biases get a small constant.
            if hasattr(m, 'weight') and isinstance(m.weight, torch.Tensor):
                nn.init.xavier_uniform_(m.weight,
                                        gain=nn.init.calculate_gain('relu'))
            if hasattr(m, 'bias') and isinstance(m.bias, torch.Tensor):
                m.bias.data.fill_(0.01)

        self.apply(_weight_init)

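
# Usage sketch (illustrative, kept as a comment; the optimiser choice is an
# assumption, not prescribed by this module):
#   model = AE(in_dim=400)
#   model.init_weights()
#   opt = torch.optim.Adam(model.parameters(), lr=1e-3)
#   opt.zero_grad()
#   model.train_loss(torch.rand(32, 400)).backward()
#   opt.step()

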
class SubSpecCAE(nn.Module):
    """Convolutional autoencoder over one sub-spectrogram band.

    Expects input of shape (batch, bands, F, T); `band` selects the band
    this instance reconstructs. F is frequency bins, T is time frames.
    """

    def __init__(self, F=20, T=80, norm='batch', activation='relu', dropout_prob=0.25, band=0):
        super().__init__()
        self.T = T
        self.F = F
        self.activation = activation
        self.band = band
        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        # Shape comments assume the defaults F=20, T=80.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=7, stride=1, padding=3),   # 32 x 20 x 80
            Norm(32),
            Activation(),
            nn.MaxPool2d((F // 10, 5)),                                                      # 32 x 10 x 16
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),  # 64 x 10 x 16
            Norm(64),
            Activation(),
            nn.MaxPool2d(4, T),                                                              # 64 x 1 x 1 (kernel 4, stride T)
            nn.Dropout(dropout_prob),
            Flatten(),                                                                       # 64
            nn.Linear(64, 16)                                                                # 16-dim code
        )
        self.decoder = nn.Sequential(
            nn.Linear(16, 64),
            Reshape(64, 1, 1),                                                   # 64 x 1 x 1
            nn.Upsample(size=(10, 16), mode='bilinear', align_corners=False),    # 64 x 10 x 16
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
            Norm(32),
            Activation(),
            nn.Upsample(size=(20, 80), mode='bilinear', align_corners=False),    # 32 x 20 x 80
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=7, stride=1, padding=3)  # 1 x 20 x 80
        )

    def forward(self, x):
        # Select a single sub-spectrogram band and restore the channel dimension.
        x = x[:, self.band, :, :].unsqueeze(1)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, x

    def train_loss(self, data):
        # Mean squared reconstruction error on the selected band.
        criterion = nn.MSELoss()
        y_hat, y = self.forward(data)
        loss = criterion(y_hat, y)
        return loss

    def test_loss(self, data):
        # Per-sample squared reconstruction error, summed over all non-batch dims.
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def init_weights(self):
        def weight_init(m):
            # He-initialise conv/linear weights; biases get a small constant.
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.02)

        self.apply(weight_init)
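

if __name__ == '__main__':
    # Smoke test: a minimal sketch on random data. The band count (4) and
    # batch size are illustrative assumptions, not fixed by the models above.
    ae = AE(in_dim=400)
    ae.init_weights()
    print('AE train loss:', ae.train_loss(torch.rand(8, 400)).item())

    cae = SubSpecCAE(F=20, T=80, band=0)
    cae.init_weights()
    spec = torch.rand(8, 4, 20, 80)  # (batch, bands, F, T)
    print('CAE train loss:', cae.train_loss(spec).item())
    print('CAE per-sample scores:', tuple(cae.test_loss(spec).shape))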