Transformer running

This commit is contained in:
Steffen Illium 2021-03-04 12:01:09 +01:00
parent 7edd3834a1
commit ad254dae92
14 changed files with 679 additions and 134 deletions

189
.idea/workspace.xml generated

@@ -1,18 +1,38 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AutoImportSettings">
<option name="autoReloadType" value="SELECTIVE" />
</component>
<component name="ChangeListManager">
<list default="true" id="2be1f675-29fe-4a7d-9fe6-9e96cd7c8055" name="Default Changelist" comment="">
<change afterPath="$PROJECT_DIR$/ml_lib/metrics/attention_rollout.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/ml_lib/utils/_basedatamodule.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/ml_lib/additions/__init__.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/ml_lib/additions/losses.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/ml_lib/utils/equal_sampler.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/models/performer.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/models/transformer_model.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/models/transformer_model_horizontal.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/multi_run.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/rebuild_dataset.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/_parameters.ini" beforeDir="false" afterPath="$PROJECT_DIR$/_parameters.ini" afterDir="false" />
<change beforePath="$PROJECT_DIR$/datasets/primates_librosa_datamodule.py" beforeDir="false" afterPath="$PROJECT_DIR$/datasets/primates_librosa_datamodule.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/main.py" beforeDir="false" afterPath="$PROJECT_DIR$/main.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/_templates/new_project/main.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/_templates/new_project/main.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/audio_toolset/audio_to_mel_dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/audio_toolset/audio_to_mel_dataset.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/audio_toolset/mel_dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/audio_toolset/mel_dataset.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/experiments.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/metrics/multi_class_classification.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/metrics/multi_class_classification.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/modules/blocks.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/modules/blocks.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/modules/model_parts.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/modules/model_parts.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/modules/util.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/modules/util.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/utils/logging.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/logging.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/utils/model_io.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/model_io.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/utils/_basedatamodule.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/_basedatamodule.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/utils/config.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/config.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/utils/logging.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/loggers.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/ml_lib/utils/tools.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/tools.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/cnn_baseline.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/cnn_baseline.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/util/loss_mixin.py" beforeDir="false" afterPath="$PROJECT_DIR$/util/loss_mixin.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/util/module_mixins.py" beforeDir="false" afterPath="$PROJECT_DIR$/util/module_mixins.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/util/optimizer_mixin.py" beforeDir="false" afterPath="$PROJECT_DIR$/util/optimizer_mixin.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/variables.py" beforeDir="false" afterPath="$PROJECT_DIR$/variables.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -27,7 +47,15 @@
</option>
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$/ml_lib" />
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="GitSEFilterConfiguration">
<file-type-list>
<filtered-out-file-type name="LOCAL_BRANCH" />
<filtered-out-file-type name="REMOTE_BRANCH" />
<filtered-out-file-type name="TAG" />
<filtered-out-file-type name="COMMIT_BY_MESSAGE" />
</file-type-list>
</component>
<component name="ProjectId" id="1oTEXjx0b8UPBPmOIGceYxEch8r" />
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
@@ -38,14 +66,53 @@
<component name="PropertiesComponent">
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="WebServerToolWindowFactoryState" value="true" />
<property name="WebServerToolWindowFactoryState" value="false" />
<property name="WebServerToolWindowPanel.toolwindow.highlight.mappings" value="true" />
<property name="WebServerToolWindowPanel.toolwindow.highlight.symlinks" value="true" />
<property name="WebServerToolWindowPanel.toolwindow.show.date" value="false" />
<property name="WebServerToolWindowPanel.toolwindow.show.permissions" value="false" />
<property name="WebServerToolWindowPanel.toolwindow.show.size" value="false" />
<property name="credentialsType com.jetbrains.python.remote.PyCreateRemoteInterpreterDialog$PyCreateRemoteSdkForm" value="Web Deployment" />
<property name="last_opened_file_path" value="$PROJECT_DIR$/../inter_challenge_2020" />
<property name="restartRequiresConfirmation" value="false" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
<component name="RunManager" selected="Python.main">
<component name="PyDebuggerOptionsProvider">
<option name="mySaveCallSignatures" value="true" />
</component>
<component name="RecentsManager">
<key name="CopyFile.RECENT_KEYS">
<recent name="C:\Users\steff\projects\compare_21\models" />
<recent name="C:\Users\steff\projects\compare_21\util" />
<recent name="C:\Users\steff\projects\compare_21" />
</key>
<key name="MoveFile.RECENT_KEYS">
<recent name="C:\Users\steff\projects\compare_21\util" />
</key>
</component>
<component name="RunManager" selected="Python.rebuild_dataset">
<configuration name="equal_sampler" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="compare_21" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/ml_lib/utils" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/ml_lib/utils/equal_sampler.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="main" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
<module name="compare_21" />
<option name="INTERPRETER_OPTIONS" value="" />
@@ -60,7 +127,7 @@
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="" />
<option name="PARAMETERS" value="--debug=True --num_worker=0 --sampler=WeightedRandomSampler --model_name=VisualTransformer" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
@@ -90,8 +157,62 @@
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="multi_run" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="compare_21" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/multi_run.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="rebuild_dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="compare_21" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/rebuild_dataset.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<list>
<item itemvalue="Python.main" />
<item itemvalue="Python.metadata_readout" />
<item itemvalue="Python.multi_run" />
<item itemvalue="Python.equal_sampler" />
<item itemvalue="Python.rebuild_dataset" />
</list>
<recent_temporary>
<list>
<item itemvalue="Python.rebuild_dataset" />
<item itemvalue="Python.multi_run" />
<item itemvalue="Python.equal_sampler" />
<item itemvalue="Python.metadata_readout" />
</list>
</recent_temporary>
@@ -106,25 +227,69 @@
<updated>1613302221903</updated>
<workItem from="1613302223434" duration="2570000" />
<workItem from="1613305247599" duration="11387000" />
<workItem from="1613381449219" duration="21710000" />
<workItem from="1613634319983" duration="160704000" />
<workItem from="1614498086696" duration="17996000" />
<workItem from="1614671803611" duration="595000" />
<workItem from="1614679476632" duration="11417000" />
<workItem from="1614760180356" duration="12583000" />
<workItem from="1614788597119" duration="8215000" />
</task>
<task id="LOCAL-00001" summary="Dataset rdy">
<created>1613467084268</created>
<option name="number" value="00001" />
<option name="presentableId" value="LOCAL-00001" />
<option name="project" value="LOCAL" />
<updated>1613467084268</updated>
</task>
<option name="localTasksCounter" value="2" />
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="3" />
</component>
<component name="VcsManagerConfiguration">
<MESSAGE value="Dataset rdy" />
<option name="LAST_COMMIT_MESSAGE" value="Dataset rdy" />
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/ml_lib/modules/util.py</url>
<line>231</line>
<option name="timeStamp" value="70" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/models/transformer_model_horizontal.py</url>
<line>68</line>
<option name="timeStamp" value="72" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/ml_lib/audio_toolset/mel_dataset.py</url>
<line>29</line>
<option name="timeStamp" value="83" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/main.py</url>
<line>11</line>
<option name="timeStamp" value="2" />
<line>46</line>
<option name="timeStamp" value="84" />
</line-breakpoint>
</breakpoints>
<default-breakpoints>
<breakpoint type="python-exception">
<properties notifyOnTerminate="true" exception="BaseException">
<option name="notifyOnTerminate" value="true" />
</properties>
</breakpoint>
</default-breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/compare_21$equal_sampler.coverage" NAME="equal_sampler Coverage Results" MODIFIED="1614070702295" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/ml_lib/utils" />
<SUITE FILE_PATH="coverage/compare_21$rebuild_dataset.coverage" NAME="rebuild_dataset Coverage Results" MODIFIED="1614849857426" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/compare_21$metadata_readout.coverage" NAME="metadata_readout Coverage Results" MODIFIED="1613306122664" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/datasets" />
<SUITE FILE_PATH="coverage/compare_21$main.coverage" NAME="main Coverage Results" MODIFIED="1613376576627" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/compare_21$main.coverage" NAME="main Coverage Results" MODIFIED="1614695189720" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/compare_21$multi_run.coverage" NAME="multi_run Coverage Results" MODIFIED="1614702462483" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component>
</project>

_parameters.ini

@@ -1,6 +1,6 @@
[project]
neptune_key = eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiZmI0OGMzNzUtOTg1NS00Yzg2LThjMzYtMWFiYjUwMDUyMjVlIn0=
debug = 1
debug = False
eval = True
seed = 69
owner = si11ium
@@ -8,56 +8,72 @@ model_name = CNNBaseline
data_name = PrimatesLibrosaDatamodule
[data]
num_worker = 0
num_worker = 10
data_root = data
reset = False
n_mels = 64
sr = 16000
hop_length = 256
n_fft = 512
hop_length = 128
n_fft = 256
random_apply_chance = 0.7
loudness_ratio = 0.0
shift_ratio = 0.0
noise_ratio = 0
mask_ratio = 0.3
speed_amount = 0
speed_min = 0
speed_max = 0
shift_ratio = 0.3
noise_ratio = 0.0
mask_ratio = 0.0
[model_cnn]
[CNNBaseline]
weight_init = xavier_normal_
activation = gelu
use_bias = True
use_norm = True
dropout = 0.2
lat_dim = 128
lat_dim = 32
features = 64
filters = [32, 64, 128, 64]
filters = [16, 32, 64, 128]
[model_attn]
name = VerticalVisualTransformer
[VisualTransformer]
weight_init = xavier_normal_
activation = gelu
use_bias = True
use_norm = True
use_residual = True
dropout = 0.2
lat_dim = 128
features = 64
patch_size = 3
attn_depth = 3
heads = 8
embedding_size = 64
lat_dim = 32
patch_size = 8
attn_depth = 12
heads = 4
embedding_size = 128
[HorizontalVisualTransformer]
weight_init = xavier_normal_
activation = gelu
use_bias = True
use_norm = True
use_residual = True
dropout = 0.3
lat_dim = 256
patch_size = 8
attn_depth = 12
heads = 6
embedding_size = 32
[train]
outpath = output
version = None
gpus=0
sampler = EqualSampler
loss = focal_loss_rob
sto_weight_avg = False
weight_decay = 0
opt_reset_interval = 0
epochs = 100
max_epochs = 200
batch_size = 30
lr = 0.01
lr_warmup_steps = 0
num_sanity_val_steps = 0
lr = 0.001
use_residual = True
lr_warm_restart_epochs = 0
num_sanity_val_steps = 2
check_val_every_n_epoch = 5
checkpoint_callback = True
gradient_clip_val = 0
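Note: the sections above are plain INI consumed through Python's configparser; parse_comandline_args_add_defaults (ml_lib/utils/config.py) merges them with command-line overrides. A minimal sketch of the default-loading half, assuming only the stock library:

import configparser

config = configparser.ConfigParser()
config.read('_parameters.ini')

model_name = config['project']['model_name']   # e.g. 'CNNBaseline' or 'VisualTransformer'
model_defaults = dict(config[model_name])      # keys of the matching [<model_name>] section, as strings
train_defaults = dict(config['train'])
print(model_defaults['lat_dim'], train_defaults['batch_size'])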

datasets/primates_librosa_datamodule.py

@@ -1,8 +1,8 @@
from multiprocessing.pool import ApplyResult
import multiprocessing as mp
from collections import defaultdict
from pathlib import Path
from typing import List
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply
from tqdm import tqdm
@@ -10,9 +10,8 @@ from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.transforms import ToTensor
import multiprocessing as mp
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
@@ -34,11 +33,16 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
def wav_folder(self):
return self.root / 'wav'
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
sample_segment_len=40, sample_hop_len=15):
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
sample_segment_len=40, sample_hop_len=15, random_apply_chance=0.5,
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
super(PrimatesLibrosaDatamodule, self).__init__()
self.sampler = sampler
self.samplers = None
self.sample_hop_len = sample_hop_len
self.sample_segment_len = sample_segment_len
self.num_worker = num_worker or 1
self.batch_size = batch_size
self.root = Path(data_root) / 'primates'
@@ -51,23 +55,22 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])
# Data Augmentations
self.random_apply_chance = random_apply_chance
self.mel_augmentations = Compose([
# ToDo: HP-search these parameters, make them adjustable from outside
RandomApply([NoiseInjection(0.2)], p=0.3),
RandomApply([LoudnessManipulator(0.5)], p=0.3),
RandomApply([ShiftTime(0.4)], p=0.3),
RandomApply([MaskAug(0.2)], p=0.3),
RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
self.utility_transforms])
def train_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
batch_size=self.batch_size, pin_memory=True,
num_workers=self.num_worker)
return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
batch_size=self.batch_size, num_workers=self.num_worker)
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], num_workers=self.num_worker, pin_memory=True,
sampler=self.samplers[DATA_OPTION_devel], batch_size=self.batch_size)
# Test Dataloader
def test_dataloader(self):
@@ -79,6 +82,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
slice_file_name, class_name = row.strip().split(',')
class_id = self.class_names.get(class_name, -1)
audio_file_path = self.wav_folder / slice_file_name
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
kwargs = self.__dict__
if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
@@ -91,7 +95,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
if build:
assert mel_dataset.build_mel()
return mel_dataset
return mel_dataset, class_id, slice_file_name
def prepare_data(self, *args, **kwargs):
datasets = dict()
@@ -103,22 +107,21 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
chunksize = len(all_rows) // max(self.num_worker, 1)
dataset = list()
with mp.Pool(processes=self.num_worker) as pool:
pbar = tqdm(total=len(all_rows))
def update():
pbar.update(chunksize)
from itertools import repeat
results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
chunksize=chunksize)
for sub_dataset in results.get():
dataset.append(sub_dataset)
update() # FIXME: will I ever get this to work?
dataset.append(sub_dataset[0])
datasets[data_option] = ConcatDataset(dataset)
self.datasets = datasets
return datasets
def setup(self, stage=None):
datasets = dict()
samplers = dict()
weights = dict()
for data_option in data_options:
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
# Exclude the header
@@ -126,7 +129,38 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
all_rows = list(f)
dataset = list()
for row in all_rows:
dataset.append(self._build_subdataset(row))
mel_dataset, class_id, _ = self._build_subdataset(row)
dataset.append(mel_dataset)
datasets[data_option] = ConcatDataset(dataset)
# Build Weighted Sampler for train and val
if data_option in [DATA_OPTION_train, DATA_OPTION_devel]:
if self.sampler == EqualSampler.__name__:
class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
for class_idx in range(len(self.class_names))
]
samplers[data_option] = EqualSampler(class_idxs)
elif self.sampler == WeightedRandomSampler.__name__:
class_counts = defaultdict(lambda: 0)
for _, __, label in datasets[data_option]:
class_counts[label] += 1
len_largest_class = max(class_counts.values())
weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
for i in range(len(datasets[data_option]))]
samplers[data_option] = WeightedRandomSampler(weights[data_option],
len_largest_class * len(self.class_names))
else:
samplers[data_option] = None
self.datasets = datasets
self.samplers = samplers
return datasets
def purge(self):
import shutil
shutil.rmtree(self.mel_folder, ignore_errors=True)
print('Mel Folder has been recursively deleted')
print(f'Folder still exists: {self.mel_folder.exists()}')
return not self.mel_folder.exists()
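Note: stripped of the datamodule plumbing, the WeightedRandomSampler branch in setup() boils down to inverse-frequency sample weights. A standalone sketch; the real code pulls labels out of the (file, tensor, label) triples in the ConcatDataset:

from collections import defaultdict
from torch.utils.data import WeightedRandomSampler

def build_weighted_sampler(labels, n_classes):
    # Weight every sample by the inverse frequency of its class, then
    # draw len(largest class) * n_classes samples with replacement.
    class_counts = defaultdict(int)
    for label in labels:
        class_counts[label] += 1
    class_weights = [1 / class_counts[c] for c in range(n_classes)]
    sample_weights = [class_weights[label] for label in labels]
    num_draws = max(class_counts.values()) * n_classes
    return WeightedRandomSampler(sample_weights, num_draws)

sampler = build_weighted_sampler([0, 0, 0, 1, 2], n_classes=3)  # toy labels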

67
main.py

@@ -1,45 +1,29 @@
import configparser
from argparse import ArgumentParser, Namespace
from argparse import Namespace
import warnings
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from ml_lib.utils.logging import Logger
from ml_lib.utils.config import parse_comandline_args_add_defaults
from ml_lib.utils.loggers import Logger
from ml_lib.utils.tools import locate_and_import_class, auto_cast
import variables as v
if __name__ == '__main__':
# Argument Parser and default Values
# =============================================================================
# Load Defaults from _parameters.ini file
config = configparser.ConfigParser()
config.read('_parameters.ini')
project = config['project']
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
data_class = locate_and_import_class(project['data_name'], 'datasets')
model_class = locate_and_import_class(project['model_name'], 'models')
tmp_params = dict()
for key in ['project', 'train', 'data', 'model_cnn']:
defaults = config[key]
tmp_params.update({key: auto_cast(val) for key, val in defaults.items()})
# Parse Command Line
parser = ArgumentParser()
for module in [Logger, Trainer, data_class, model_class]:
parser = module.add_argparse_args(parser)
cmd_args, _ = parser.parse_known_args()
tmp_params.update({key: val for key, val in vars(cmd_args).items() if val is not None})
hparams = Namespace(**tmp_params)
with Logger.from_argparse_args(hparams) as logger:
def run_lightning_loop(h_params, data_class, model_class):
with Logger.from_argparse_args(h_params) as logger:
# Callbacks
# =============================================================================
# Checkpoint Saving
ckpt_callback = ModelCheckpoint(
monitor='mean_loss',
filepath=str(logger.log_dir / 'ckpt_weights'),
dirpath=str(logger.log_dir),
filename='ckpt_weights',
verbose=False,
save_top_k=3,
)
@@ -47,22 +31,37 @@ if __name__ == '__main__':
# Learning Rate Logger
lr_logger = LearningRateMonitor(logging_interval='epoch')
#
# START
# =============================================================================
# Let Datamodule pull what it wants
datamodule = data_class.from_argparse_args(hparams)
datamodule = data_class.from_argparse_args(h_params)
datamodule.setup()
model_in_shape = datamodule.shape
# Let Trainer pull what it wants and add callbacks
trainer = Trainer.from_argparse_args(hparams, callbacks=[ckpt_callback, lr_logger])
trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=[ckpt_callback, lr_logger])
# Let Model pull what it wants
model = model_class.from_argparse_args(hparams, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
model.init_weights()
logger.log_hyperparams(dict(model.params))
trainer.fit(model, datamodule)
trainer.save_checkpoint(trainer.logger.save_dir)
# Log parameters
pytorch_total_params = sum(p.numel() for p in model.parameters())
# logger.log_text('n_parameters', pytorch_total_params)
trainer.save_checkpoint(logger.save_dir / 'weights.ckpt')
if __name__ == '__main__':
# Parse command line args, read config and get model
cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults('_parameters.ini')
# To Namespace
hparams = Namespace(**cmd_args)
# Start
# -----------------
run_lightning_loop(hparams, found_data_class, found_model_class)
print('done')
pass
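Note: the checkpoint hunk tracks the newer pytorch_lightning ModelCheckpoint API, which replaces the single filepath argument with separate dirpath and filename arguments; roughly:

from pytorch_lightning.callbacks import ModelCheckpoint

# old: ModelCheckpoint(monitor='mean_loss', filepath=str(log_dir / 'ckpt_weights'), ...)
ckpt_callback = ModelCheckpoint(
    monitor='mean_loss',
    dirpath='output/some_run',   # illustrative; the code above uses str(logger.log_dir)
    filename='ckpt_weights',
    verbose=False,
    save_top_k=3,
)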

models/cnn_baseline.py

@@ -1,8 +1,6 @@
import inspect
from argparse import Namespace
import variables as v
from torch import nn
from ml_lib.metrics.multi_class_classification import MultiClassScores
@@ -17,7 +15,7 @@ class CNNBaseline(CombinedModelMixins,
):
def __init__(self, in_shape, n_classes, weight_init, activation, use_bias, use_norm, dropout, lat_dim, features,
filters):
filters, lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval, loss):
# TODO: Move this to parent class, or make it much easier to access...
a = dict(locals())
@@ -50,4 +48,4 @@
return Namespace(main_out=tensor)
def additional_scores(self, outputs):
return MultiClassScores(self)
return MultiClassScores(self)(outputs)
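Note: the last hunk fixes a missing call. MultiClassScores(self) merely constructs the metric object; appending (outputs) invokes it. A hypothetical skeleton of that callable pattern (the real class lives in ml_lib/metrics/multi_class_classification.py, which this diff does not show):

class MultiClassScores:
    def __init__(self, model):
        self.model = model   # constructed once with the model...

    def __call__(self, outputs):
        return dict()        # ...then called per epoch with the collected outputs

additional_scores = MultiClassScores(model=None)(outputs=[])  # construct, then call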

0
models/performer.py Normal file

120
models/transformer_model.py Normal file

@@ -0,0 +1,120 @@
import inspect
from argparse import Namespace
import warnings
import torch
from torch import nn
from einops import rearrange, repeat
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import CombinedModelMixins
MIN_NUM_PATCHES = 16
class VisualTransformer(CombinedModelMixins,
LightningBaseModule
):
def __init__(self, in_shape, n_classes, weight_init, activation,
embedding_size, heads, attn_depth, patch_size, use_residual,
use_bias, use_norm, dropout, lat_dim, loss,
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
# TODO: Move this to parent class, or make it much easier to access... But how...
a = dict(locals())
params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
super(VisualTransformer, self).__init__(params)
self.in_shape = in_shape
assert len(self.in_shape) == 3, 'There need to be three dimensions'
channels, height, width = self.in_shape
# Model Parameters
# =============================================================================
# Additional parameters
self.embed_dim = self.params.embedding_size
# Automatic Image Shaping
self.patch_size = self.params.patch_size
image_size = (max(height, width) // self.patch_size) * self.patch_size
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
# This should be obsolete
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (self.image_size // self.patch_size) ** 2
patch_dim = channels * self.patch_size ** 2
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
f'attention. Try decreasing your patch size'
# Correct the Embedding Dim
if not self.embed_dim % self.params.heads == 0:
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
message = ('Embedding dimension was fixed to be divisible by the number' +
f' of attention heads, is now: {self.embed_dim}')
for func in print, warnings.warn:
func(message)
# Utility Modules
self.autopad = AutoPadToShape((self.image_size, self.image_size))
# Modules with Parameters
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
heads=self.params.heads, depth=self.params.attn_depth,
dropout=self.params.dropout, use_norm=self.params.use_norm,
activation=self.params.activation, use_residual=self.params.use_residual
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
else F_x(self.embed_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.dropout = nn.Dropout(self.params.dropout)
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(self.embed_dim),
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, n_classes),
nn.Softmax(dim=-1)
)
def forward(self, x, mask=None, return_attn_weights=False):
"""
:param x: the sequence to the encoder (required).
:param mask: the mask for the src sequence (optional).
:return:
"""
tensor = self.autopad(x)
p = self.params.patch_size
tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
tensor = self.patch_to_embedding(tensor)
b, n, _ = tensor.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
tensor = torch.cat((cls_tokens, tensor), dim=1)
tensor += self.pos_embedding[:, :(n + 1)]
tensor = self.dropout(tensor)
if return_attn_weights:
tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
else:
attn_weights = None
tensor = self.transformer(tensor, mask)
tensor = self.to_cls_token(tensor[:, 0])
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor, attn_weights=attn_weights)
def additional_scores(self, outputs):
return MultiClassScores(self)(outputs)
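Note: the forward pass cuts the padded spectrogram into a grid of square patches with einops.rearrange. A shape check under assumed sizes (64x64 input, patch_size 8):

import torch
from einops import rearrange

x = torch.randn(2, 1, 64, 64)   # (batch, channels, height, width) - sizes assumed for illustration
p = 8
patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
print(patches.shape)            # torch.Size([2, 64, 64]): an 8x8 grid of flattened 8*8*1 patches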

models/transformer_model_horizontal.py

@@ -0,0 +1,120 @@
import inspect
from argparse import Namespace
import warnings
import torch
from einops import repeat
from torch import nn
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import CombinedModelMixins
from variables import N_CLASS_multi
MIN_NUM_PATCHES = 16
class HorizontalVisualTransformer(CombinedModelMixins,
LightningBaseModule
):
def __init__(self, in_shape, n_classes, weight_init, activation,
embedding_size, heads, attn_depth, patch_size, use_residual,
use_bias, use_norm, dropout, lat_dim, features, loss,
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
# TODO: Move this to parent class, or make it much easier to access... But how...
a = dict(locals())
params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
super(HorizontalVisualTransformer, self).__init__(params)
self.in_shape = in_shape
assert len(self.in_shape) == 3, 'There need to be three dimensions'
channels, height, width = self.in_shape
# Model Parameters
# =============================================================================
# Additional parameters
self.n_classes = N_CLASS_multi
self.embed_dim = self.params.embedding_size
self.height = height
self.width = width
self.channels = channels
self.new_height = ((self.height - self.params.patch_size)//1) + 1
num_patches = self.new_height - (self.params.patch_size // 2)
patch_dim = channels * self.params.patch_size * self.width
assert patch_dim
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
f'attention. Try decreasing your patch size'
# Correct the Embedding Dim
if not self.embed_dim % self.params.heads == 0:
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
message = ('Embedding dimension was fixed to be divisible by the number' +
f' of attention heads, is now: {self.embed_dim}')
for func in print, warnings.warn:
func(message)
# Utility Modules
self.autopad = AutoPadToShape((self.new_height, self.width))
self.dropout = nn.Dropout(self.params.dropout)
self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.params.patch_size, self.width),
keepdim=False)
# Modules with Parameters
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
heads=self.params.heads, depth=self.params.attn_depth,
dropout=self.params.dropout, use_norm=self.params.use_norm,
activation=self.params.activation
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
else F_x(self.embed_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(self.embed_dim),
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, n_classes),
nn.Softmax(dim=-1)
)
def forward(self, x, mask=None, return_attn_weights=False):
"""
:param x: the sequence to the encoder (required).
:param mask: the mask for the src sequence (optional).
:param return_attn_weights: whether to return the attn weights (optional)
:return:
"""
tensor = self.autopad(x)
tensor = self.slider(tensor)
tensor = self.patch_to_embedding(tensor)
b, n, _ = tensor.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
tensor = torch.cat((cls_tokens, tensor), dim=1)
tensor += self.pos_embedding[:, :(n + 1)]
tensor = self.dropout(tensor)
if return_attn_weights:
tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
else:
attn_weights = None
tensor = self.transformer(tensor, mask)
tensor = self.to_cls_token(tensor[:, 0])
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor, attn_weights=attn_weights)
def additional_scores(self, outputs):
return MultiClassScores(self)(outputs)
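Note: the horizontal variant slides full-width windows along the frequency axis instead of cutting square patches. ml_lib's SlidingWindow is not part of this diff; torch.Tensor.unfold approximates the shapes it produces:

import torch

x = torch.randn(2, 1, 64, 100)         # (batch, channels, n_mels, time) - sizes assumed
patch_size = 8
windows = x.unfold(2, patch_size, 1)   # stride-1 windows over height -> (2, 1, 57, 100, 8)
# new_height = ((64 - patch_size) // 1) + 1 = 57, matching the formula in __init__ above
windows = windows.permute(0, 2, 1, 4, 3).flatten(2)
print(windows.shape)                   # torch.Size([2, 57, 800]); 800 = channels * patch_size * width = patch_dim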

48
multi_run.py Normal file

@@ -0,0 +1,48 @@
from argparse import Namespace
from tqdm import tqdm
from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults
import itertools
if __name__ == '__main__':
# Set new values
hparams_dict = dict(model_name=['VisualTransformer'],
max_epochs=[150],
batch_size=[50],
random_apply_chance=[0.5],
loudness_ratio=[0],
shift_ratio=[0.3],
noise_ratio=[0.3],
mask_ratio=[0.3],
lr=[0.001],
dropout=[0.2],
lat_dim=[32, 64],
patch_size=[8, 12],
attn_depth=[12],
heads=[6],
embedding_size=[16, 32],
loss=['ce_loss'],
sampler=['WeightedRandomSampler']
)
keys, values = zip(*hparams_dict.items())
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
for permutations_dict in tqdm(permutations_dicts, total=len(permutations_dicts)):
# Parse command line args, read config and get model
cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults(
'_parameters.ini', overrides=permutations_dict)
hparams = dict(**cmd_args)
hparams.update(permutations_dict)
hparams = Namespace(**hparams)
# RUN
# ---------------------------------------
print(f'Running Loop, parameters are: {permutations_dict}')
run_lightning_loop(hparams, found_data_class, found_model_class)
print(f'Done, parameters were: {permutations_dict}')
pass
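Note: the value lists expand into a full cartesian grid via itertools.product, one run per combination; with the lists above, lat_dim x patch_size x embedding_size alone already gives 2 * 2 * 2 = 8 runs:

import itertools

grid = dict(lat_dim=[32, 64], patch_size=[8, 12], embedding_size=[16, 32])
keys, values = zip(*grid.items())
runs = [dict(zip(keys, combo)) for combo in itertools.product(*values)]
print(len(runs))   # 8
print(runs[0])     # {'lat_dim': 32, 'patch_size': 8, 'embedding_size': 16}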

34
rebuild_dataset.py Normal file

@@ -0,0 +1,34 @@
from argparse import Namespace
import warnings
from ml_lib.utils.config import parse_comandline_args_add_defaults
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
def rebuild_dataset(h_params, data_class):
# START
# =============================================================================
# Let Datamodule pull what it wants
datamodule = data_class.from_argparse_args(h_params)
assert datamodule.purge()
datasets = datamodule.prepare_data()
datasets = datamodule.setup()
print(f'Dataset length is: {len(datasets)}')
if __name__ == '__main__':
# Parse command line args, read config and get model
cmd_args, found_data_class, _ = parse_comandline_args_add_defaults('_parameters.ini')
# To Namespace
hparams = Namespace(**cmd_args)
# Start
# -----------------
rebuild_dataset(hparams, found_data_class)
print('done')
pass

util/loss_mixin.py

@@ -1,5 +1,7 @@
from torch import nn
from ml_lib.additions.losses import FocalLoss, FocalLossRob
class LossMixin:
@@ -7,3 +9,5 @@ class LossMixin:
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
ce_loss = nn.CrossEntropyLoss()
focal_loss = FocalLoss(None)
focal_loss_rob = FocalLossRob()
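Note: these class attributes let the training step (util/module_mixins.py, below) pick the criterion by its config name via self.__getattribute__(self.params.loss). The same dispatch in isolation:

import torch
from torch import nn

class LossMixin:
    nll_loss = nn.NLLLoss()
    ce_loss = nn.CrossEntropyLoss()

loss_fn = getattr(LossMixin(), 'ce_loss')   # name would come from [train] loss = ... in _parameters.ini
logits, targets = torch.randn(4, 5), torch.tensor([0, 1, 2, 3])
print(loss_fn(logits, targets))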

util/module_mixins.py

@@ -1,6 +1,7 @@
from abc import ABC
import torch
import pandas as pd
from ml_lib.modules.util import LightningBaseModule
from util.loss_mixin import LossMixin
@@ -11,9 +12,15 @@ class TrainMixin:
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
batch_files, batch_x, batch_y = batch_xy
y = self(batch_x).main_out
loss = self.ce_loss(y.squeeze(), batch_y.long())
if self.params.loss == 'focal_loss_rob':
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=5)
loss = self.__getattribute__(self.params.loss)(y, labels_one_hot)
else:
loss = self.__getattribute__(self.params.loss)(y, batch_y.long())
return dict(loss=loss)
def training_epoch_end(self, outputs):
@@ -23,21 +30,23 @@
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
for key in summary_dict.keys():
self.log(key, summary_dict[key])
summary_dict.update(epoch=self.current_epoch)
self.log_dict(summary_dict)
class ValMixin:
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
batch_files, batch_x, batch_y = batch_xy
model_out = self(batch_x)
y = model_out.main_out
val_loss = self.ce_loss(y.squeeze(), batch_y.long())
val_loss = self.ce_loss(y, batch_y.long())
return dict(val_loss=val_loss,
self.metrics.update(y, batch_y) # torch.argmax(y, -1), batch_y)
return dict(val_loss=val_loss, batch_files=batch_files,
batch_idx=batch_idx, y=y, batch_y=batch_y)
def validation_epoch_end(self, outputs, *_, **__):
@@ -49,40 +58,40 @@
for output in outputs]))
for key in keys if 'loss' in key}
)
# Sklearn Scores
additional_scores = self.additional_scores(outputs)
summary_dict.update(**additional_scores)
for key in summary_dict.keys():
self.log(key, summary_dict[key])
pl_metrics, pl_images = self.metrics.compute_and_prepare()
self.metrics.reset()
summary_dict.update(**pl_metrics)
summary_dict.update(epoch=self.current_epoch)
self.log_dict(summary_dict, on_epoch=True)
for name, image in pl_images.items():
self.logger.log_image(name, image, step=self.global_step)
pass
class TestMixin:
def test_step(self, batch_xy, batch_idx, *_, **__):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
batch_files, batch_x, batch_y = batch_xy
model_out = self(batch_x)
y = model_out.main_out
test_loss = self.ce_loss(y.squeeze(), batch_y.long())
return dict(test_loss=test_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
return dict(batch_files=batch_files, batch_idx=batch_idx, y=y)
def test_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
summary_dict = dict()
keys = list(outputs[0].keys())
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
)
y_arg_max = torch.argmax(outputs[0]['y'], dim=-1)
additional_scores = self.additional_scores(outputs)
summary_dict.update(**additional_scores)
pd.DataFrame(data=dict(filenames=outputs[0]['batch_files'], prediction=y_arg_max))
for key in summary_dict.keys():
self.log(key, summary_dict[key])
# No logging, just inference.
# self.log_dict(summary_dict, on_epoch=True)
class CombinedModelMixins(LossMixin,
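Note: for the focal_loss_rob branch in training_step above, integer labels are one-hot encoded first, e.g.:

import torch

batch_y = torch.tensor([0, 3, 1])
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=5)
print(labels_one_hot)
# tensor([[1, 0, 0, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 1, 0, 0, 0]])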

util/optimizer_mixin.py

@@ -1,6 +1,6 @@
from collections import defaultdict
from torch.optim import Adam
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
from torchcontrib.optim import SWA
@@ -20,12 +20,12 @@ class OptimizerMixin:
# 'monitor': 'mean_val_loss' # Metric to monitor
)
optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
optimizer = AdamW(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
if self.params.sto_weight_avg:
optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
optimizer_dict.update(optimizer=optimizer)
if self.params.lr_warmup_steps:
scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
if self.params.lr_warm_restart_epochs:
scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warm_restart_epochs)
else:
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
optimizer_dict.update(lr_scheduler=scheduler)
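Note: with the rename, lr_warm_restart_epochs now selects between cosine warm restarts and a plain per-epoch exponential decay. Stripped of the Lightning plumbing, the choice is roughly:

from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR

model = nn.Linear(10, 5)      # stand-in module for illustration
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)

lr_warm_restart_epochs = 10   # assumed; 0 (the _parameters.ini default) picks the decay branch
if lr_warm_restart_epochs:
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=lr_warm_restart_epochs)
else:
    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)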

variables.py

@@ -3,8 +3,6 @@ from pathlib import Path
sr = 16000
PRIMATES_Root = Path(__file__).parent / 'data' / 'primates'
N_CLASS_multi = 4
N_CLASS_multi = 5