diff --git a/audio_toolset/mel_augmentation.py b/audio_toolset/mel_augmentation.py index c74d683..9a72d78 100644 --- a/audio_toolset/mel_augmentation.py +++ b/audio_toolset/mel_augmentation.py @@ -13,8 +13,8 @@ class NoiseInjection(object): def __call__(self, x: np.ndarray): if self.noise_factor: - noise = np.random.normal(loc=self.mu, scale=self.sigma, size=x.shape) - augmented_data = x + self.noise_factor * noise + noise = np.random.uniform(0, self.noise_factor, size=x.shape) + augmented_data = x + x * noise # Cast back to same data type augmented_data = augmented_data.astype(x.dtype) return augmented_data diff --git a/utils/config.py b/utils/config.py index 9797a2c..bf15954 100644 --- a/utils/config.py +++ b/utils/config.py @@ -39,6 +39,7 @@ class Config(ConfigParser, ABC): h = hashlib.md5() params = deepcopy(self.as_dict) del params['model']['type'] + del params['model']['secondary_type'] del params['data']['worker'] del params['main'] h.update(str(params).encode()) diff --git a/visualization/tools.py b/visualization/tools.py index 02d6e64..32dd67c 100644 --- a/visualization/tools.py +++ b/visualization/tools.py @@ -14,8 +14,10 @@ class Plotter(object): if naked: plt.axis('off') fig.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0) - fig.savefig(path) - fig.clf() + fig.clf() + else: + fig.savefig(path) + fig.clf() def show_current_figure(self): fig, _ = plt.gcf(), plt.gca()