working commit

This commit is contained in:
Robert Müller
2020-03-19 16:59:49 +01:00
parent f4606a7f6c
commit cc9e9b50a4
5 changed files with 112 additions and 20 deletions

View File

@ -2,6 +2,23 @@ import torch
import torch.nn as nn
import torch.functional as F
class Reshape(nn.Module):
    """Reshape all dimensions after the first one to a fixed target shape.

    The sizes passed to the constructor describe the output shape *without*
    dim 0; the size of dim 0 is taken from the input at call time.
    """

    def __init__(self, *args):
        """Store the target sizes for every dimension following dim 0."""
        super().__init__()
        self.to = args

    def forward(self, x):
        """Return ``x`` with shape ``(x.shape[0], *self.to)``."""
        # reshape() returns a view whenever view() would succeed, but also
        # accepts non-contiguous inputs (e.g. after transpose/permute)
        # instead of raising a RuntimeError.
        return x.reshape(x.shape[0], *self.to)
class Flatten(nn.Module):
    """Collapse every dimension after the first one into a single dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with shape ``(x.shape[0], -1)``."""
        # reshape() returns a view whenever view() would succeed, but also
        # accepts non-contiguous inputs (e.g. after transpose/permute)
        # instead of raising a RuntimeError.
        return x.reshape(x.shape[0], -1)
class AE(nn.Module):
def __init__(self, in_dim=400):
super(AE, self).__init__()
@ -17,7 +34,7 @@ class AE(nn.Module):
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, in_dim),
nn.ReLU(),
nn.ReLU()
)
def forward(self, data):
@ -45,4 +62,65 @@ class AE(nn.Module):
if isinstance(m.bias, torch.Tensor):
m.bias.data.fill_(0.01)
self.apply(_weight_init)
self.apply(_weight_init)
class SubSpecCAE(nn.Module):
    """Convolutional autoencoder that reconstructs a single sub-spectrogram.

    The encoder applies two conv / norm / activation / max-pool / dropout
    stages, flattens the result and projects it to a 16-dimensional code.
    The decoder mirrors this with bilinear upsampling and transposed
    convolutions and produces a one-channel map of size ``F x T``.

    Args:
        F: Height of the sub-spectrogram fed to the encoder (must be >= 10).
        T: Width of the sub-spectrogram fed to the encoder (must be >= 20).
        norm: ``'batch'`` selects ``nn.BatchNorm2d``; any other value selects
            ``nn.InstanceNorm2d``.
        activation: ``'relu'`` selects ``nn.ReLU``; any other value selects
            ``nn.LeakyReLU``.
        dropout_prob: Probability passed to every ``nn.Dropout`` layer.

    Raises:
        ValueError: If ``F`` or ``T`` is too small for the pooling layers.
    """

    def __init__(self, F=20, T=80, norm='batch', activation='relu', dropout_prob=0.25):
        super().__init__()
        # The first pooling kernel is (F // 10, 5) and the second one is 4x4,
        # so smaller inputs would only fail later inside forward().
        if F < 10:
            raise ValueError(f"F must be at least 10, got {F}")
        if T < 20:
            raise ValueError(f"T must be at least 20, got {T}")

        self.T = T
        self.F = F
        self.activation = activation

        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU

        # Spatial size after the first max-pool (stride defaults to the kernel
        # size); the decoder upsamples back to it. For the defaults this is
        # (10, 16), the value that used to be hard-coded.
        pool_kernel = (F // 10, 5)
        pooled_size = (F // pool_kernel[0], T // pool_kernel[1])

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=7, stride=1, padding=3),  # 32 x F x T
            Norm(32),
            Activation(),
            nn.MaxPool2d(pool_kernel),
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),  # 64 x pooled_size
            Norm(64),
            Activation(),
            # NOTE(review): the original positional call MaxPool2d(4, T) means
            # kernel_size=4, stride=T, so a single 4x4 window is pooled and the
            # output is 64 x 1 x 1. If a (4, T) kernel or global pooling was
            # intended, confirm and adjust; behaviour is kept as-is here.
            nn.MaxPool2d(kernel_size=4, stride=T),
            nn.Dropout(dropout_prob),
            Flatten(),
            nn.Linear(64, 16)
        )
        self.decoder = nn.Sequential(
            nn.Linear(16, 64),
            Reshape(64, 1, 1),
            nn.Upsample(size=pooled_size, mode='bilinear', align_corners=False),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
            Norm(32),
            Activation(),
            # Upsample to the input resolution so the reconstruction always
            # matches the target shape, also for non-default F / T.
            nn.Upsample(size=(F, T), mode='bilinear', align_corners=False),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=7, stride=1, padding=3)
        )

    def forward(self, x):
        """Encode and decode one sub-spectrogram.

        Index 3 along dim 1 of ``x`` is selected and given a singleton
        channel dimension.
        (assumes x is batch x num_sub_specs x F x T -- TODO confirm)

        Returns:
            Tuple ``(decoded, target)``: the reconstruction and the selected
            slice it is meant to reproduce.
        """
        x = x[:, 3, :].unsqueeze(1)  # select a single sub-spectrogram
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, x

    def train_loss(self, data):
        """Return the mean squared error between reconstruction and target."""
        criterion = nn.MSELoss()
        y_hat, y = self.forward(data)
        loss = criterion(y_hat, y)
        return loss

    def test_loss(self, data):
        """Return one score per sample: squared error summed over all non-batch dims."""
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def init_weights(self):
        """Apply Kaiming-uniform weights and a constant 0.02 bias to Conv2d and Linear layers."""
        # NOTE(review): nn.ConvTranspose2d is not a subclass of nn.Conv2d, so
        # the decoder's transposed convolutions keep their default
        # initialisation -- confirm this is intended.
        def weight_init(m):
            if isinstance(m, nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.02)

        self.apply(weight_init)

View File

@ -1,3 +1,4 @@
import numpy as np
import torch
import torch.nn as nn
@ -9,7 +10,7 @@ class Subspectrogram(object):
def __call__(self, sample):
if len(sample.shape) < 3:
sample = sample.unsqueeze(0)
sample = sample.reshape(1, *sample.shape)
# sample shape: 1 x num_mels x num_frames
sub_specs = []
for i in range(0, sample.shape[1], self.hop_size):