import numpy as np
from tqdm import tqdm
import librosa
import librosa.display
from matplotlib import pyplot as plt
from pathlib import Path


class Preprocessor:
    """Turn audio recordings into fixed-width mel-spectrogram image tiles.

    A recording is loaded with librosa, converted to a dB-scaled mel
    spectrogram, cut into overlapping windows along the time axis, and each
    window is rendered to a square JPEG.

    Args:
        sr: Sample rate passed to ``librosa.load`` and to the mel filterbank.
        n_mels: Number of mel bands.
        n_fft: FFT window size used for the STFT.
        hop_length: Stored and forwarded to ``librosa.feature.melspectrogram``.
            NOTE(review): the STFT and the image rendering both use a hop of
            ``n_fft // 2``, so this value does not appear to influence the
            frame rate. Confirm the intent before relying on it.
        chunk_size: Width of one exported window, in spectrogram frames.
        chunk_hop: Step between consecutive window starts, in frames.
        cmap: Matplotlib colormap name used when rendering.
    """

    def __init__(self, sr=16000, n_mels=64, n_fft=1024, hop_length=256, chunk_size=64, chunk_hop=32, cmap='viridis'):
        self.sr = sr
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.hop_length = hop_length
        self.chunk_size = chunk_size
        self.chunk_hop = chunk_hop
        self.cmap = cmap

    def process_audio(self, path, out_folder=None):
        """Export every full-width spectrogram window of one audio file.

        Windows are ``chunk_size`` frames wide and start every ``chunk_hop``
        frames. Images are written as ``<stem>_<index>.jpg``.

        Only windows that fit completely inside the spectrogram are exported.
        Slicing a numpy array past its end silently yields a narrower array
        (it never raises ``IndexError``), and such a truncated window would
        be stretched to the full image width when rendered.

        Args:
            path: Audio file to process (``str`` or ``pathlib.Path``).
            out_folder: Directory receiving the images. Defaults to the
                directory that contains ``path``.
        """
        path = Path(path)
        out_folder = path.parent if out_folder is None else Path(out_folder)

        mel_spec = self.to_mel_spec(path)
        # Last start index that still leaves room for a complete window; a
        # negative value makes the range empty for very short recordings.
        last_start = mel_spec.shape[1] - self.chunk_size
        for count, start in enumerate(range(0, last_start + 1, self.chunk_hop)):
            chunk = mel_spec[:, start:start + self.chunk_size]
            # TODO: revisit the output file naming scheme.
            out_path = out_folder / f'{path.stem}_{count}.jpg'
            self.mel_spec_to_img(chunk, out_path)

    def to_mel_spec(self, path):
        """Load an audio file and return its mel spectrogram in decibels.

        Args:
            path: Audio file to load; it is read as mono at ``self.sr``.

        Returns:
            2-D array with mel bands along axis 0 and STFT frames along
            axis 1, in dB relative to the maximum power of the recording.
        """
        audio, sr = librosa.load(str(path), sr=self.sr, mono=True)
        # center=False: frames are not padded at the edges of the signal.
        stft = librosa.stft(audio,
                            n_fft=self.n_fft,
                            hop_length=self.n_fft // 2,
                            center=False)
        power = np.abs(stft) ** 2
        # NOTE(review): a precomputed power spectrogram is supplied via S, so
        # hop_length is presumably not used by this call - confirm.
        mel_power = librosa.feature.melspectrogram(S=power,
                                                   sr=sr,
                                                   n_mels=self.n_mels,
                                                   hop_length=self.hop_length)
        # top_db=None disables clipping of the lower end of the dynamic range.
        return librosa.power_to_db(mel_power, ref=np.max, top_db=None)

    def mel_spec_to_img(self, spectrogram, out_path, size=227):
        """Render a spectrogram to a borderless square JPEG.

        Args:
            spectrogram: 2-D array to draw (mel bands x frames).
            out_path: Destination file of the image.
            size: Edge length of the image in pixels. The figure is 1x1 inch,
                so this value is used directly as the DPI.
        """
        fig = plt.figure(frameon=False, tight_layout=False)
        try:
            fig.set_size_inches(1, 1)
            # A single axes covering the whole canvas, with no ticks or frame.
            ax = plt.Axes(fig, [0., 0., 1., 1.])
            ax.set_axis_off()
            fig.add_axes(ax)

            # Draw on the explicit axes rather than the implicit current one.
            # The return value is deliberately unused: depending on the
            # librosa version it is an Axes or a QuadMesh, and it must not be
            # passed to fig.add_axes again.
            librosa.display.specshow(spectrogram,
                                     hop_length=self.n_fft // 2,
                                     fmax=self.sr/2,
                                     sr=self.sr,
                                     cmap=self.cmap,
                                     y_axis='mel',
                                     x_axis='time',
                                     ax=ax)

            fig.savefig(out_path, format='jpg', dpi=size)
        finally:
            # Always release the figure; otherwise a failed save would leak
            # it, which adds up when thousands of files are processed.
            plt.close(fig)

    def process_folder(self, folder_in, folder_out):
        """Process every ``*.wav`` file found directly inside a folder.

        Args:
            folder_in: Directory that is searched (non-recursively) for WAVs.
            folder_out: Directory for the images; created if it is missing.
        """
        folder_in = Path(folder_in)
        folder_out = Path(folder_out)
        folder_out.mkdir(parents=True, exist_ok=True)
        # Materialised so tqdm knows the total; sorted for a stable order.
        wavs = sorted(folder_in.glob('*.wav'))
        for wav in tqdm(wavs):
            self.process_audio(wav, folder_out)

if __name__ == '__main__':
    # Batch-convert the normal and abnormal recordings of every machine type
    # and machine id. Images are written to a 'melspec_images' sub-folder next
    # to the WAV files they were generated from.
    # Root of the local dataset directory; adjust to your checkout.
    data_root = Path('/home/robert/coding/audio_anomaly_detection/data/mimii')
    models = ['slider', 'pump', 'fan']
    model_ids = [0, 2, 4, 6]
    conditions = ['normal', 'abnormal']
    preprocessor = Preprocessor()
    for model in models:
        for model_id in model_ids:
            for condition in conditions:
                wav_dir = data_root / f'-6_dB_{model}' / f'id_0{model_id}' / condition
                preprocessor.process_folder(wav_dir, wav_dir / 'melspec_images')