28 lines
803 B
Python
28 lines
803 B
Python
import numpy as np
import torch
import torch.nn as nn
|
|
|
|
|
|
class Subspectrogram(object):
    """Split a spectrogram into overlapping bands along the mel (axis 1) dimension.

    Windows of ``height`` mel bins are taken every ``hop_size`` bins and
    concatenated along axis 0. Only complete windows are kept, so every band
    has exactly ``height`` bins; trailing bins that cannot fill a whole window
    are dropped.

    Example: with ``height=20`` and ``hop_size=10``, an input of shape
    ``(1, 40, T)`` yields bands ``0:20``, ``10:30`` and ``20:40``, i.e. an
    output of shape ``(3, 20, T)``.

    Args:
        height: Number of mel bins in each sub-spectrogram.
        hop_size: Stride, in mel bins, between the starts of consecutive
            sub-spectrograms.

    Raises:
        ValueError: If ``height`` or ``hop_size`` is not positive.
    """

    def __init__(self, height, hop_size):
        # A zero hop would make ``range`` raise at call time, and a negative
        # one would silently yield no windows; reject both up front.
        if height <= 0 or hop_size <= 0:
            raise ValueError(
                f'height and hop_size must be positive, got height={height}, '
                f'hop_size={hop_size}')
        self.height = height
        self.hop_size = hop_size

    def __call__(self, sample):
        """Return the sub-spectrograms of ``sample`` concatenated along axis 0.

        Args:
            sample: Array-like of shape ``(num_mels, num_frames)`` or
                ``(C, num_mels, num_frames)`` (``C`` is 1 in the usual case);
                a 2-D input gets a leading axis added. Torch tensors are
                converted by ``np.concatenate``, so they must be convertible
                to numpy (e.g. on CPU, not requiring grad).

        Returns:
            ``np.ndarray`` of shape ``(num_windows * C, height, num_frames)``,
            ordered by window start.

        Raises:
            ValueError: If ``sample`` has fewer than ``height`` mel bins.
        """
        if len(sample.shape) < 3:
            # ``[None]`` adds a leading axis for numpy arrays and torch tensors
            # alike (``unsqueeze`` only exists on tensors).
            sample = sample[None]

        # sample shape: C x num_mels x num_frames
        num_mels = sample.shape[1]
        if num_mels < self.height:
            raise ValueError(
                f'sample has {num_mels} mel bins, fewer than the sub-spectrogram '
                f'height {self.height}')

        # The upper bound keeps the last window fully inside the spectrogram,
        # so all pieces share the same shape and can be concatenated.
        last_start = num_mels - self.height
        sub_specs = [
            sample[:, start:start + self.height]
            for start in range(0, last_start + 1, self.hop_size)
        ]
        return np.concatenate(sub_specs)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: run the transform on a random spectrogram and report how
    # the shape changes.
    import numpy as np

    X = np.random.rand(1, 60, 40)
    Y = Subspectrogram(height=20, hop_size=10)(X)

    print(f'\t Sub-Spectrogram transformation from shape {X.shape} to {Y.shape}')
    print('Done ...')