Transformer running

Steffen Illium 2021-03-04 12:01:09 +01:00
parent 7edd3834a1
commit ad254dae92
14 changed files with 679 additions and 134 deletions

.idea/workspace.xml generated

@@ -1,18 +1,38 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
+  <component name="AutoImportSettings">
+    <option name="autoReloadType" value="SELECTIVE" />
+  </component>
   <component name="ChangeListManager">
     <list default="true" id="2be1f675-29fe-4a7d-9fe6-9e96cd7c8055" name="Default Changelist" comment="">
-      <change afterPath="$PROJECT_DIR$/ml_lib/metrics/attention_rollout.py" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/ml_lib/utils/_basedatamodule.py" afterDir="false" />
-      <change beforePath="$PROJECT_DIR$/ml_lib/experiments.py" beforeDir="false" />
+      <change afterPath="$PROJECT_DIR$/ml_lib/additions/__init__.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/ml_lib/additions/losses.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/ml_lib/utils/equal_sampler.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/models/performer.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/models/transformer_model.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/models/transformer_model_horizontal.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/multi_run.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/rebuild_dataset.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/_parameters.ini" beforeDir="false" afterPath="$PROJECT_DIR$/_parameters.ini" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/datasets/primates_librosa_datamodule.py" beforeDir="false" afterPath="$PROJECT_DIR$/datasets/primates_librosa_datamodule.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/main.py" beforeDir="false" afterPath="$PROJECT_DIR$/main.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/ml_lib/_templates/new_project/main.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/_templates/new_project/main.py" afterDir="false" />
       <change beforePath="$PROJECT_DIR$/ml_lib/audio_toolset/audio_to_mel_dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/audio_toolset/audio_to_mel_dataset.py" afterDir="false" />
       <change beforePath="$PROJECT_DIR$/ml_lib/audio_toolset/mel_dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/audio_toolset/mel_dataset.py" afterDir="false" />
       <change beforePath="$PROJECT_DIR$/ml_lib/metrics/multi_class_classification.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/metrics/multi_class_classification.py" afterDir="false" />
       <change beforePath="$PROJECT_DIR$/ml_lib/modules/blocks.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/modules/blocks.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/ml_lib/modules/model_parts.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/modules/model_parts.py" afterDir="false" />
       <change beforePath="$PROJECT_DIR$/ml_lib/modules/util.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/modules/util.py" afterDir="false" />
-      <change beforePath="$PROJECT_DIR$/ml_lib/utils/logging.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/logging.py" afterDir="false" />
-      <change beforePath="$PROJECT_DIR$/ml_lib/utils/model_io.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/model_io.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/ml_lib/utils/_basedatamodule.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/_basedatamodule.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/ml_lib/utils/config.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/config.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/ml_lib/utils/logging.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/loggers.py" afterDir="false" />
       <change beforePath="$PROJECT_DIR$/ml_lib/utils/tools.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/tools.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/models/cnn_baseline.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/cnn_baseline.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/util/loss_mixin.py" beforeDir="false" afterPath="$PROJECT_DIR$/util/loss_mixin.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/util/module_mixins.py" beforeDir="false" afterPath="$PROJECT_DIR$/util/module_mixins.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/util/optimizer_mixin.py" beforeDir="false" afterPath="$PROJECT_DIR$/util/optimizer_mixin.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/variables.py" beforeDir="false" afterPath="$PROJECT_DIR$/variables.py" afterDir="false" />
     </list>
     <option name="SHOW_DIALOG" value="false" />
     <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -27,7 +47,15 @@
     </option>
   </component>
   <component name="Git.Settings">
-    <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$/ml_lib" />
+    <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
+  </component>
+  <component name="GitSEFilterConfiguration">
+    <file-type-list>
+      <filtered-out-file-type name="LOCAL_BRANCH" />
+      <filtered-out-file-type name="REMOTE_BRANCH" />
+      <filtered-out-file-type name="TAG" />
+      <filtered-out-file-type name="COMMIT_BY_MESSAGE" />
+    </file-type-list>
   </component>
   <component name="ProjectId" id="1oTEXjx0b8UPBPmOIGceYxEch8r" />
   <component name="ProjectLevelVcsManager" settingsEditedManually="true" />
@@ -38,14 +66,53 @@
   <component name="PropertiesComponent">
     <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
     <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
-    <property name="WebServerToolWindowFactoryState" value="true" />
+    <property name="WebServerToolWindowFactoryState" value="false" />
     <property name="WebServerToolWindowPanel.toolwindow.highlight.mappings" value="true" />
     <property name="WebServerToolWindowPanel.toolwindow.highlight.symlinks" value="true" />
     <property name="WebServerToolWindowPanel.toolwindow.show.date" value="false" />
     <property name="WebServerToolWindowPanel.toolwindow.show.permissions" value="false" />
     <property name="WebServerToolWindowPanel.toolwindow.show.size" value="false" />
+    <property name="credentialsType com.jetbrains.python.remote.PyCreateRemoteInterpreterDialog$PyCreateRemoteSdkForm" value="Web Deployment" />
+    <property name="last_opened_file_path" value="$PROJECT_DIR$/../inter_challenge_2020" />
+    <property name="restartRequiresConfirmation" value="false" />
+    <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
   </component>
-  <component name="RunManager" selected="Python.main">
+  <component name="PyDebuggerOptionsProvider">
+    <option name="mySaveCallSignatures" value="true" />
+  </component>
+  <component name="RecentsManager">
+    <key name="CopyFile.RECENT_KEYS">
+      <recent name="C:\Users\steff\projects\compare_21\models" />
+      <recent name="C:\Users\steff\projects\compare_21\util" />
+      <recent name="C:\Users\steff\projects\compare_21" />
+    </key>
+    <key name="MoveFile.RECENT_KEYS">
+      <recent name="C:\Users\steff\projects\compare_21\util" />
+    </key>
+  </component>
+  <component name="RunManager" selected="Python.rebuild_dataset">
+    <configuration name="equal_sampler" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
+      <module name="compare_21" />
+      <option name="INTERPRETER_OPTIONS" value="" />
+      <option name="PARENT_ENVS" value="true" />
+      <envs>
+        <env name="PYTHONUNBUFFERED" value="1" />
+      </envs>
+      <option name="SDK_HOME" value="" />
+      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/ml_lib/utils" />
+      <option name="IS_MODULE_SDK" value="true" />
+      <option name="ADD_CONTENT_ROOTS" value="true" />
+      <option name="ADD_SOURCE_ROOTS" value="true" />
+      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/ml_lib/utils/equal_sampler.py" />
+      <option name="PARAMETERS" value="" />
+      <option name="SHOW_COMMAND_LINE" value="false" />
+      <option name="EMULATE_TERMINAL" value="false" />
+      <option name="MODULE_MODE" value="false" />
+      <option name="REDIRECT_INPUT" value="false" />
+      <option name="INPUT_FILE" value="" />
+      <method v="2" />
+    </configuration>
     <configuration name="main" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
       <module name="compare_21" />
       <option name="INTERPRETER_OPTIONS" value="" />
@@ -60,7 +127,7 @@
       <option name="ADD_SOURCE_ROOTS" value="true" />
       <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
       <option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
-      <option name="PARAMETERS" value="" />
+      <option name="PARAMETERS" value="--debug=True --num_worker=0 --sampler=WeightedRandomSampler --model_name=VisualTransformer" />
       <option name="SHOW_COMMAND_LINE" value="false" />
       <option name="EMULATE_TERMINAL" value="false" />
       <option name="MODULE_MODE" value="false" />
@@ -90,8 +157,62 @@
       <option name="INPUT_FILE" value="" />
       <method v="2" />
     </configuration>
+    <configuration name="multi_run" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
+      <module name="compare_21" />
+      <option name="INTERPRETER_OPTIONS" value="" />
+      <option name="PARENT_ENVS" value="true" />
+      <envs>
+        <env name="PYTHONUNBUFFERED" value="1" />
+      </envs>
+      <option name="SDK_HOME" value="" />
+      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
+      <option name="IS_MODULE_SDK" value="true" />
+      <option name="ADD_CONTENT_ROOTS" value="true" />
+      <option name="ADD_SOURCE_ROOTS" value="true" />
+      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/multi_run.py" />
+      <option name="PARAMETERS" value="" />
+      <option name="SHOW_COMMAND_LINE" value="false" />
+      <option name="EMULATE_TERMINAL" value="false" />
+      <option name="MODULE_MODE" value="false" />
+      <option name="REDIRECT_INPUT" value="false" />
+      <option name="INPUT_FILE" value="" />
+      <method v="2" />
+    </configuration>
+    <configuration name="rebuild_dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
+      <module name="compare_21" />
+      <option name="INTERPRETER_OPTIONS" value="" />
+      <option name="PARENT_ENVS" value="true" />
+      <envs>
+        <env name="PYTHONUNBUFFERED" value="1" />
+      </envs>
+      <option name="SDK_HOME" value="" />
+      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
+      <option name="IS_MODULE_SDK" value="true" />
+      <option name="ADD_CONTENT_ROOTS" value="true" />
+      <option name="ADD_SOURCE_ROOTS" value="true" />
+      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/rebuild_dataset.py" />
+      <option name="PARAMETERS" value="" />
+      <option name="SHOW_COMMAND_LINE" value="false" />
+      <option name="EMULATE_TERMINAL" value="false" />
+      <option name="MODULE_MODE" value="false" />
+      <option name="REDIRECT_INPUT" value="false" />
+      <option name="INPUT_FILE" value="" />
+      <method v="2" />
+    </configuration>
+    <list>
+      <item itemvalue="Python.main" />
+      <item itemvalue="Python.metadata_readout" />
+      <item itemvalue="Python.multi_run" />
+      <item itemvalue="Python.equal_sampler" />
+      <item itemvalue="Python.rebuild_dataset" />
+    </list>
     <recent_temporary>
       <list>
+        <item itemvalue="Python.rebuild_dataset" />
+        <item itemvalue="Python.multi_run" />
+        <item itemvalue="Python.equal_sampler" />
         <item itemvalue="Python.metadata_readout" />
       </list>
     </recent_temporary>
@@ -106,25 +227,69 @@
     <updated>1613302221903</updated>
     <workItem from="1613302223434" duration="2570000" />
     <workItem from="1613305247599" duration="11387000" />
+    <workItem from="1613381449219" duration="21710000" />
+    <workItem from="1613634319983" duration="160704000" />
+    <workItem from="1614498086696" duration="17996000" />
+    <workItem from="1614671803611" duration="595000" />
+    <workItem from="1614679476632" duration="11417000" />
+    <workItem from="1614760180356" duration="12583000" />
+    <workItem from="1614788597119" duration="8215000" />
   </task>
+  <task id="LOCAL-00001" summary="Dataset rdy">
+    <created>1613467084268</created>
+    <option name="number" value="00001" />
+    <option name="presentableId" value="LOCAL-00001" />
+    <option name="project" value="LOCAL" />
+    <updated>1613467084268</updated>
+  </task>
+  <option name="localTasksCounter" value="2" />
   <servers />
 </component>
 <component name="TypeScriptGeneratedFilesManager">
   <option name="version" value="3" />
 </component>
+<component name="VcsManagerConfiguration">
+  <MESSAGE value="Dataset rdy" />
+  <option name="LAST_COMMIT_MESSAGE" value="Dataset rdy" />
+</component>
 <component name="XDebuggerManager">
   <breakpoint-manager>
     <breakpoints>
+      <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
+        <url>file://$PROJECT_DIR$/ml_lib/modules/util.py</url>
+        <line>231</line>
+        <option name="timeStamp" value="70" />
+      </line-breakpoint>
+      <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
+        <url>file://$PROJECT_DIR$/models/transformer_model_horizontal.py</url>
+        <line>68</line>
+        <option name="timeStamp" value="72" />
+      </line-breakpoint>
+      <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
+        <url>file://$PROJECT_DIR$/ml_lib/audio_toolset/mel_dataset.py</url>
+        <line>29</line>
+        <option name="timeStamp" value="83" />
+      </line-breakpoint>
       <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
         <url>file://$PROJECT_DIR$/main.py</url>
-        <line>11</line>
-        <option name="timeStamp" value="2" />
+        <line>46</line>
+        <option name="timeStamp" value="84" />
       </line-breakpoint>
     </breakpoints>
+    <default-breakpoints>
+      <breakpoint type="python-exception">
+        <properties notifyOnTerminate="true" exception="BaseException">
+          <option name="notifyOnTerminate" value="true" />
+        </properties>
+      </breakpoint>
+    </default-breakpoints>
   </breakpoint-manager>
 </component>
 <component name="com.intellij.coverage.CoverageDataManagerImpl">
+  <SUITE FILE_PATH="coverage/compare_21$equal_sampler.coverage" NAME="equal_sampler Coverage Results" MODIFIED="1614070702295" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/ml_lib/utils" />
+  <SUITE FILE_PATH="coverage/compare_21$rebuild_dataset.coverage" NAME="rebuild_dataset Coverage Results" MODIFIED="1614849857426" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
   <SUITE FILE_PATH="coverage/compare_21$metadata_readout.coverage" NAME="metadata_readout Coverage Results" MODIFIED="1613306122664" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/datasets" />
-  <SUITE FILE_PATH="coverage/compare_21$main.coverage" NAME="main Coverage Results" MODIFIED="1613376576627" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
+  <SUITE FILE_PATH="coverage/compare_21$main.coverage" NAME="main Coverage Results" MODIFIED="1614695189720" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
+  <SUITE FILE_PATH="coverage/compare_21$multi_run.coverage" NAME="multi_run Coverage Results" MODIFIED="1614702462483" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
 </component>
</project>

_parameters.ini

@@ -1,6 +1,6 @@
 [project]
 neptune_key = eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiZmI0OGMzNzUtOTg1NS00Yzg2LThjMzYtMWFiYjUwMDUyMjVlIn0=
-debug = 1
+debug = False
 eval = True
 seed = 69
 owner = si11ium
@@ -8,56 +8,72 @@ model_name = CNNBaseline
 data_name = PrimatesLibrosaDatamodule

 [data]
-num_worker = 0
+num_worker = 10
 data_root = data
 reset = False
 n_mels = 64
 sr = 16000
-hop_length = 256
-n_fft = 512
+hop_length = 128
+n_fft = 256
+random_apply_chance = 0.7
 loudness_ratio = 0.0
-shift_ratio = 0.0
-noise_ratio = 0
-mask_ratio = 0.3
-speed_amount = 0
-speed_min = 0
-speed_max = 0
+shift_ratio = 0.3
+noise_ratio = 0.0
+mask_ratio = 0.0

-[model_cnn]
+[CNNBaseline]
 weight_init = xavier_normal_
 activation = gelu
 use_bias = True
 use_norm = True
 dropout = 0.2
-lat_dim = 128
+lat_dim = 32
 features = 64
-filters = [32, 64, 128, 64]
+filters = [16, 32, 64, 128]

-[model_attn]
-name = VerticalVisualTransformer
+[VisualTransformer]
 weight_init = xavier_normal_
 activation = gelu
 use_bias = True
 use_norm = True
+use_residual = True
 dropout = 0.2
-lat_dim = 128
-features = 64
-patch_size = 3
-attn_depth = 3
-heads = 8
-embedding_size = 64
+lat_dim = 32
+patch_size = 8
+attn_depth = 12
+heads = 4
+embedding_size = 128
+
+[HorizontalVisualTransformer]
+weight_init = xavier_normal_
+activation = gelu
+use_bias = True
+use_norm = True
+use_residual = True
+dropout = 0.3
+lat_dim = 256
+patch_size = 8
+attn_depth = 12
+heads = 6
+embedding_size = 32

 [train]
 outpath = output
 version = None
-gpus=0
+sampler = EqualSampler
+loss = focal_loss_rob
 sto_weight_avg = False
 weight_decay = 0
 opt_reset_interval = 0
-epochs = 100
+max_epochs = 200
 batch_size = 30
-lr = 0.01
-lr_warmup_steps = 0
-num_sanity_val_steps = 0
+lr = 0.001
+use_residual = True
+lr_warm_restart_epochs = 0
+num_sanity_val_steps = 2
+check_val_every_n_epoch = 5
+checkpoint_callback = True
+gradient_clip_val = 0
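Note: the config sections are now named after the model classes ([CNNBaseline], [VisualTransformer], [HorizontalVisualTransformer]), so per-model defaults can be selected from model_name alone. A minimal sketch of that lookup, assuming an auto_cast helper along the lines of the one the old main.py imported; the real logic lives in ml_lib.utils.config.parse_comandline_args_add_defaults, which is not part of this diff:

import configparser
from ast import literal_eval

def auto_cast(value: str):
    # Best-effort cast of ini strings to Python values ('False' -> bool, '[16, 32]' -> list, ...)
    try:
        return literal_eval(value)
    except (ValueError, SyntaxError):
        return value

config = configparser.ConfigParser()
config.read('_parameters.ini')

model_name = config['project']['model_name']           # e.g. 'CNNBaseline'
defaults = {}
for section in ('project', 'data', 'train', model_name):
    defaults.update({k: auto_cast(v) for k, v in config[section].items()})

print(defaults['filters'], defaults['max_epochs'])      # [16, 32, 64, 128] 200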

datasets/primates_librosa_datamodule.py

@@ -1,8 +1,8 @@
-from multiprocessing.pool import ApplyResult
+import multiprocessing as mp
+from collections import defaultdict
 from pathlib import Path
-from typing import List

-from torch.utils.data import DataLoader, ConcatDataset
+from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
 from torchvision.transforms import Compose, RandomApply
 from tqdm import tqdm
@@ -10,9 +10,8 @@ from ml_lib.audio_toolset.audio_io import NormalizeLocal
 from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
 from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
 from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
+from ml_lib.utils.equal_sampler import EqualSampler
 from ml_lib.utils.transforms import ToTensor
-import multiprocessing as mp

 data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
@@ -34,11 +33,16 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
     def wav_folder(self):
         return self.root / 'wav'

-    def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
-                 sample_segment_len=40, sample_hop_len=15):
+    def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
+                 sample_segment_len=40, sample_hop_len=15, random_apply_chance=0.5,
+                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
         super(PrimatesLibrosaDatamodule, self).__init__()
+        self.sampler = sampler
+        self.samplers = None
         self.sample_hop_len = sample_hop_len
         self.sample_segment_len = sample_segment_len
         self.num_worker = num_worker or 1
         self.batch_size = batch_size
         self.root = Path(data_root) / 'primates'
@@ -51,23 +55,22 @@
         self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])

         # Data Augmentations
+        self.random_apply_chance = random_apply_chance
         self.mel_augmentations = Compose([
-            # ToDo: HP Search this parameters, make it adjustable from outside
-            RandomApply([NoiseInjection(0.2)], p=0.3),
-            RandomApply([LoudnessManipulator(0.5)], p=0.3),
-            RandomApply([ShiftTime(0.4)], p=0.3),
-            RandomApply([MaskAug(0.2)], p=0.3),
+            RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
+            RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
+            RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
+            RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
             self.utility_transforms])

     def train_dataloader(self):
-        return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
-                          batch_size=self.batch_size, pin_memory=True,
-                          num_workers=self.num_worker)
+        return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
+                          sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)

     # Validation Dataloader
     def val_dataloader(self):
-        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
-                          batch_size=self.batch_size, num_workers=self.num_worker)
+        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], num_workers=self.num_worker, pin_memory=True,
+                          sampler=self.samplers[DATA_OPTION_devel], batch_size=self.batch_size)

     # Test Dataloader
     def test_dataloader(self):
@@ -79,6 +82,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
             slice_file_name, class_name = row.strip().split(',')
             class_id = self.class_names.get(class_name, -1)
             audio_file_path = self.wav_folder / slice_file_name
+
             # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
             kwargs = self.__dict__
             if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
@@ -91,7 +95,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
         mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
         if build:
             assert mel_dataset.build_mel()
-        return mel_dataset
+        return mel_dataset, class_id, slice_file_name

     def prepare_data(self, *args, **kwargs):
         datasets = dict()
@@ -103,22 +107,21 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
             chunksize = len(all_rows) // max(self.num_worker, 1)
             dataset = list()
             with mp.Pool(processes=self.num_worker) as pool:
-                pbar = tqdm(total=len(all_rows))
-
-                def update():
-                    pbar.update(chunksize)

                 from itertools import repeat
                 results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
                                              chunksize=chunksize)
                 for sub_dataset in results.get():
-                    dataset.append(sub_dataset)
-                    update()
+                    dataset.append(sub_dataset[0])
+                    # FIXME: will i ever get this to work?
             datasets[data_option] = ConcatDataset(dataset)

         self.datasets = datasets
         return datasets

     def setup(self, stag=None):
         datasets = dict()
+        samplers = dict()
+        weights = dict()

         for data_option in data_options:
             with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                 # Exclude the header
@@ -126,7 +129,38 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
                 all_rows = list(f)
             dataset = list()
             for row in all_rows:
-                dataset.append(self._build_subdataset(row))
+                mel_dataset, class_id, _ = self._build_subdataset(row)
+                dataset.append(mel_dataset)
             datasets[data_option] = ConcatDataset(dataset)
+
+            # Build Weighted Sampler for train and val
+            if data_option in [DATA_OPTION_train, DATA_OPTION_devel]:
+                if self.sampler == EqualSampler.__name__:
+                    class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
+                                  for class_idx in range(len(self.class_names))
+                                  ]
+                    samplers[data_option] = EqualSampler(class_idxs)
+                elif self.sampler == WeightedRandomSampler.__name__:
+                    class_counts = defaultdict(lambda: 0)
+                    for _, __, label in datasets[data_option]:
+                        class_counts[label] += 1
+                    len_largest_class = max(class_counts.values())
+
+                    weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
+
+                    weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
+                                            for i in range(len(datasets[data_option]))]
+                    samplers[data_option] = WeightedRandomSampler(weights[data_option],
+                                                                  len_largest_class * len(self.class_names))
+                else:
+                    samplers[data_option] = None
         self.datasets = datasets
+        self.samplers = samplers
         return datasets
+
+    def purge(self):
+        import shutil
+        shutil.rmtree(self.mel_folder, ignore_errors=True)
+        print('Mel Folder has been recursively deleted')
+        print(f'Folder still exists: {self.mel_folder.exists()}')
+        return not self.mel_folder.exists()
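Note: the WeightedRandomSampler branch in setup() weights each sample inversely to its class frequency and then draws len_largest_class * n_classes indices with replacement, so rare calls are oversampled toward a uniform class mix. The same idea in isolation, on hypothetical toy labels and the standard PyTorch API:

from collections import Counter
from torch.utils.data import WeightedRandomSampler

labels = [0, 0, 0, 0, 1, 1, 2]                       # toy, heavily imbalanced labels
counts = Counter(labels)
class_weights = {c: 1.0 / n for c, n in counts.items()}
sample_weights = [class_weights[label] for label in labels]

# One "epoch" draws max(count) * n_classes indices, with replacement by default.
sampler = WeightedRandomSampler(sample_weights, num_samples=max(counts.values()) * len(counts))
print(list(sampler))                                  # indices into the dataset, minority-boosted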

main.py

@@ -1,45 +1,29 @@
-import configparser
-from argparse import ArgumentParser, Namespace
+from argparse import Namespace
+import warnings

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

-from ml_lib.utils.logging import Logger
+from ml_lib.utils.config import parse_comandline_args_add_defaults
+from ml_lib.utils.loggers import Logger
 from ml_lib.utils.tools import locate_and_import_class, auto_cast

 import variables as v

-if __name__ == '__main__':
-    # Argument Parser and default Values
-    # =============================================================================
-    # Load Defaults from _parameters.ini file
-    config = configparser.ConfigParser()
-    config.read('_parameters.ini')
-    project = config['project']
-    data_class = locate_and_import_class(project['data_name'], 'datasets')
-    model_class = locate_and_import_class(project['model_name'], 'models')
-
-    tmp_params = dict()
-    for key in ['project', 'train', 'data', 'model_cnn']:
-        defaults = config[key]
-        tmp_params.update({key: auto_cast(val) for key, val in defaults.items()})
-
-    # Parse Command Line
-    parser = ArgumentParser()
-    for module in [Logger, Trainer, data_class, model_class]:
-        parser = module.add_argparse_args(parser)
-    cmd_args, _ = parser.parse_known_args()
-    tmp_params.update({key: val for key, val in vars(cmd_args).items() if val is not None})
-    hparams = Namespace(**tmp_params)
-
-    with Logger.from_argparse_args(hparams) as logger:
+warnings.filterwarnings('ignore', category=FutureWarning)
+warnings.filterwarnings('ignore', category=UserWarning)
+
+
+def run_lightning_loop(h_params, data_class, model_class):
+    with Logger.from_argparse_args(h_params) as logger:
         # Callbacks
         # =============================================================================
         # Checkpoint Saving
         ckpt_callback = ModelCheckpoint(
             monitor='mean_loss',
-            filepath=str(logger.log_dir / 'ckpt_weights'),
+            dirpath=str(logger.log_dir),
+            filename='ckpt_weights',
             verbose=False,
             save_top_k=3,
         )
@@ -47,22 +31,37 @@
         # Learning Rate Logger
         lr_logger = LearningRateMonitor(logging_interval='epoch')

-        #
         # START
         # =============================================================================
         # Let Datamodule pull what it wants
-        datamodule = data_class.from_argparse_args(hparams)
+        datamodule = data_class.from_argparse_args(h_params)
         datamodule.setup()
-        model_in_shape = datamodule.shape

         # Let Trainer pull what it wants and add callbacks
-        trainer = Trainer.from_argparse_args(hparams, callbacks=[ckpt_callback, lr_logger])
+        trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=[ckpt_callback, lr_logger])

         # Let Model pull what it wants
-        model = model_class.from_argparse_args(hparams, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
+        model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
         model.init_weights()
+        logger.log_hyperparams(dict(model.params))

         trainer.fit(model, datamodule)
-        trainer.save_checkpoint(trainer.logger.save_dir)
+
+        # Log parameters
+        pytorch_total_params = sum(p.numel() for p in model.parameters())
+        # logger.log_text('n_parameters', pytorch_total_params)
+
+        trainer.save_checkpoint(logger.save_dir / 'weights.ckpt')
+
+
+if __name__ == '__main__':
+    # Parse command-line args, read config and get model
+    cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults('_parameters.ini')
+
+    # To NameSpace
+    hparams = Namespace(**cmd_args)
+
+    # Start
+    # -----------------
+    run_lightning_loop(hparams, found_data_class, found_model_class)
+    print('done')
+    pass
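Note: main.py now exposes run_lightning_loop(h_params, data_class, model_class) for reuse (multi_run.py imports it), and all defaults come from _parameters.ini with command-line overrides. The PyCharm run configuration stored in this commit's workspace.xml launches it as, for example:

python main.py --debug=True --num_worker=0 --sampler=WeightedRandomSampler --model_name=VisualTransformer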

models/cnn_baseline.py

@@ -1,8 +1,6 @@
 import inspect
 from argparse import Namespace

-import variables as v
-
 from torch import nn

 from ml_lib.metrics.multi_class_classification import MultiClassScores
@@ -17,7 +15,7 @@ class CNNBaseline(CombinedModelMixins,
                   ):

     def __init__(self, in_shape, n_classes, weight_init, activation, use_bias, use_norm, dropout, lat_dim, features,
-                 filters):
+                 filters, lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval, loss):

         # TODO: Move this to parent class, or make it much easier to access....
         a = dict(locals())
@@ -50,4 +48,4 @@ class CNNBaseline(CombinedModelMixins,
         return Namespace(main_out=tensor)

     def additional_scores(self, outputs):
-        return MultiClassScores(self)
+        return MultiClassScores(self)(outputs)
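Note: every model __init__ in this commit repeats the a = dict(locals()) / inspect.signature capture that the TODO above complains about. A sketch of how it could be hoisted into a shared base class (a suggestion, not code from this repo):

import inspect

class ParamCaptureBase:
    def capture_init_params(self, local_vars: dict) -> dict:
        # Keep exactly the arguments declared in the subclass __init__ signature.
        sig = inspect.signature(type(self).__init__)
        return {name: local_vars[name] for name in sig.parameters if name != 'self'}

class Model(ParamCaptureBase):
    def __init__(self, lr, dropout):
        # One call replaces the locals()/signature boilerplate in each model.
        self.params = self.capture_init_params(locals())

print(Model(lr=1e-3, dropout=0.2).params)   # {'lr': 0.001, 'dropout': 0.2}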

models/performer.py Normal file (empty)

models/transformer_model.py Normal file

@@ -0,0 +1,120 @@
import inspect
from argparse import Namespace

import warnings

import torch
from torch import nn
from einops import rearrange, repeat

from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import CombinedModelMixins

MIN_NUM_PATCHES = 16


class VisualTransformer(CombinedModelMixins,
                        LightningBaseModule
                        ):

    def __init__(self, in_shape, n_classes, weight_init, activation,
                 embedding_size, heads, attn_depth, patch_size, use_residual,
                 use_bias, use_norm, dropout, lat_dim, loss,
                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):

        # TODO: Move this to parent class, or make it much easier to access... But How...
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super(VisualTransformer, self).__init__(params)

        self.in_shape = in_shape
        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size

        # Automatic Image Shaping
        self.patch_size = self.params.patch_size
        image_size = (max(height, width) // self.patch_size) * self.patch_size
        self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size

        # This should be obsolete
        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'

        num_patches = (self.image_size // self.patch_size) ** 2
        patch_dim = channels * self.patch_size ** 2
        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                               f'attention. Try decreasing your patch size'

        # Correct the Embedding Dim
        if not self.embed_dim % self.params.heads == 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding Dimension was fixed to be divisible by the number' +
                       f' of attention heads, is now: {self.embed_dim}')
            for func in print, warnings.warn:
                func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.image_size, self.image_size))

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
                                             heads=self.params.heads, depth=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation, use_residual=self.params.use_residual
                                             )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.dropout = nn.Dropout(self.params.dropout)
        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, n_classes),
            nn.Softmax()
        )

    def forward(self, x, mask=None, return_attn_weights=False):
        """
        :param x: the sequence to the encoder (required).
        :param mask: the mask for the src sequence (optional).
        :return:
        """
        tensor = self.autopad(x)
        p = self.params.patch_size

        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)

        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        if return_attn_weights:
            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
        else:
            attn_weights = None
            tensor = self.transformer(tensor, mask)

        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)

    def additional_scores(self, outputs):
        return MultiClassScores(self)(outputs)
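Note: the rearrange call in forward() is the standard ViT patchify: the padded spectrogram is cut into patch_size x patch_size tiles, each flattened to a vector of length channels * patch_size**2. A quick shape check with n_mels = 64 and patch_size = 8 from _parameters.ini (the 64-frame time axis is an assumption for the example):

import torch
from einops import rearrange

b, c, h, w = 2, 1, 64, 64          # batch, channels, mel bins, frames (frame count assumed)
p = 8
x = torch.randn(b, c, h, w)
patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
print(patches.shape)               # torch.Size([2, 64, 64]): 8*8 = 64 patches, patch_dim = 8*8*1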

models/transformer_model_horizontal.py Normal file

@@ -0,0 +1,120 @@
import inspect
from argparse import Namespace

import warnings

import torch
from einops import repeat
from torch import nn

from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import CombinedModelMixins
from variables import N_CLASS_multi

MIN_NUM_PATCHES = 16


class HorizontalVisualTransformer(CombinedModelMixins,
                                  LightningBaseModule
                                  ):

    def __init__(self, in_shape, n_classes, weight_init, activation,
                 embedding_size, heads, attn_depth, patch_size, use_residual,
                 use_bias, use_norm, dropout, lat_dim, features, loss,
                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):

        # TODO: Move this to parent class, or make it much easier to access... But How...
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super(HorizontalVisualTransformer, self).__init__(params)

        self.in_shape = in_shape
        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.n_classes = N_CLASS_multi
        self.embed_dim = self.params.embedding_size
        self.height = height
        self.width = width
        self.channels = channels

        self.new_height = ((self.height - self.params.patch_size) // 1) + 1

        num_patches = self.new_height - (self.params.patch_size // 2)
        patch_dim = channels * self.params.patch_size * self.width
        assert patch_dim
        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                               f'attention. Try decreasing your patch size'

        # Correct the Embedding Dim
        if not self.embed_dim % self.params.heads == 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding Dimension was fixed to be divisible by the number' +
                       f' of attention heads, is now: {self.embed_dim}')
            for func in print, warnings.warn:
                func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.new_height, self.width))
        self.dropout = nn.Dropout(self.params.dropout)
        self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.params.patch_size, self.width),
                                    keepdim=False)

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
                                             heads=self.params.heads, depth=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation
                                             )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, n_classes),
            nn.Softmax()
        )

    def forward(self, x, mask=None, return_attn_weights=False):
        """
        :param x: the sequence to the encoder (required).
        :param mask: the mask for the src sequence (optional).
        :param return_attn_weights: whether to return the attn weights (optional)
        :return:
        """
        tensor = self.autopad(x)
        tensor = self.slider(tensor)
        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        if return_attn_weights:
            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
        else:
            attn_weights = None
            tensor = self.transformer(tensor, mask)

        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)

    def additional_scores(self, outputs):
        return MultiClassScores(self)(outputs)
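Note: unlike the square patches above, this variant tokenizes full-width horizontal strips, hence patch_dim = channels * patch_size * width. SlidingWindow comes from ml_lib and is not shown in this diff; torch.Tensor.unfold produces an equivalent strip extraction, sketched under that assumption:

import torch

b, c, h, w, p = 2, 1, 64, 64, 8
x = torch.randn(b, c, h, w)
# Slide a window of height p with stride 1 along the height (mel) axis:
strips = x.unfold(2, p, 1)                   # -> (b, c, h - p + 1, w, p)
strips = strips.permute(0, 2, 1, 4, 3).reshape(b, h - p + 1, c * p * w)
print(strips.shape)                          # torch.Size([2, 57, 512]): 57 strips, patch_dim = 1*8*64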

multi_run.py Normal file

@@ -0,0 +1,48 @@
from argparse import Namespace

from tqdm import tqdm

from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults

import itertools

if __name__ == '__main__':

    # Set new values
    hparams_dict = dict(model_name=['VisualTransformer'],
                        max_epochs=[150],
                        batch_size=[50],
                        random_apply_chance=[0.5],
                        loudness_ratio=[0],
                        shift_ratio=[0.3],
                        noise_ratio=[0.3],
                        mask_ratio=[0.3],
                        lr=[0.001],
                        dropout=[0.2],
                        lat_dim=[32, 64],
                        patch_size=[8, 12],
                        attn_depth=[12],
                        heads=[6],
                        embedding_size=[16, 32],
                        loss=['ce_loss'],
                        sampler=['WeightedRandomSampler']
                        )

    keys, values = zip(*hparams_dict.items())
    permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]

    for permutations_dict in tqdm(permutations_dicts, total=len(permutations_dicts)):
        # Parse command-line args, read config and get model
        cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults(
            '_parameters.ini', overrides=permutations_dict)

        hparams = dict(**cmd_args)
        hparams.update(permutations_dict)
        hparams = Namespace(**hparams)

        # RUN
        # ---------------------------------------
        print(f'Running Loop, parameters are: {permutations_dict}')
        run_lightning_loop(hparams, found_data_class, found_model_class)
        print(f'Done, parameters were: {permutations_dict}')
    pass
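Note: the zip/itertools.product idiom expands the dict of value lists into every combination; with two choices each for lat_dim, patch_size and embedding_size, the grid above yields 2 * 2 * 2 = 8 runs. In miniature:

import itertools

grid = dict(lat_dim=[32, 64], patch_size=[8, 12], embedding_size=[16, 32])
keys, values = zip(*grid.items())
runs = [dict(zip(keys, combo)) for combo in itertools.product(*values)]
print(len(runs))      # 8
print(runs[0])        # {'lat_dim': 32, 'patch_size': 8, 'embedding_size': 16}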

rebuild_dataset.py Normal file

@@ -0,0 +1,34 @@
from argparse import Namespace

import warnings

from ml_lib.utils.config import parse_comandline_args_add_defaults

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


def rebuild_dataset(h_params, data_class):
    # START
    # =============================================================================
    # Let Datamodule pull what it wants
    datamodule = data_class.from_argparse_args(h_params)
    assert datamodule.purge()
    datasets = datamodule.prepare_data()
    datasets = datamodule.setup()
    print(f'Dataset length is: {len(datasets)}')


if __name__ == '__main__':
    # Parse command-line args, read config and get model
    cmd_args, found_data_class, _ = parse_comandline_args_add_defaults('_parameters.ini')

    # To NameSpace
    hparams = Namespace(**cmd_args)

    # Start
    # -----------------
    rebuild_dataset(hparams, found_data_class)
    print('done')
    pass

util/loss_mixin.py

@@ -1,9 +1,13 @@
 from torch import nn

+from ml_lib.additions.losses import FocalLoss, FocalLossRob
+

 class LossMixin:

     absolute_loss = nn.L1Loss()
     nll_loss = nn.NLLLoss()
     bce_loss = nn.BCELoss()
     ce_loss = nn.CrossEntropyLoss()
+    focal_loss = FocalLoss(None)
+    focal_loss_rob = FocalLossRob()
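Note: ml_lib/additions/losses.py is added by this commit, but its contents are not shown in the diff. For orientation only, a textbook multi-class focal loss over probabilities and one-hot targets — roughly the contract training_step uses focal_loss_rob with, since the models end in nn.Softmax() — could look like this (a sketch, not the repo's implementation):

import torch
from torch import nn

class FocalLossSketch(nn.Module):
    # Lin et al. 2017: down-weight well-classified examples by (1 - p_t)**gamma.
    def __init__(self, gamma: float = 2.0, eps: float = 1e-8):
        super().__init__()
        self.gamma, self.eps = gamma, eps

    def forward(self, probs, one_hot_targets):
        # probs: (batch, n_classes) class probabilities; targets: one-hot, same shape.
        p_t = (probs * one_hot_targets).sum(dim=-1).clamp(min=self.eps)
        return (-(1 - p_t) ** self.gamma * p_t.log()).mean()

probs = torch.softmax(torch.randn(4, 5), dim=-1)
targets = torch.nn.functional.one_hot(torch.tensor([0, 1, 2, 4]), num_classes=5)
print(FocalLossSketch()(probs, targets))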

util/module_mixins.py

@@ -1,6 +1,7 @@
 from abc import ABC

 import torch
+import pandas as pd

 from ml_lib.modules.util import LightningBaseModule
 from util.loss_mixin import LossMixin
@@ -11,9 +12,15 @@ class TrainMixin:

     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
         assert isinstance(self, LightningBaseModule)
-        batch_x, batch_y = batch_xy
+        batch_files, batch_x, batch_y = batch_xy
         y = self(batch_x).main_out
-        loss = self.ce_loss(y.squeeze(), batch_y.long())
+
+        if self.params.loss == 'focal_loss_rob':
+            labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=5)
+            loss = self.__getattribute__(self.params.loss)(y, labels_one_hot)
+        else:
+            loss = self.__getattribute__(self.params.loss)(y, batch_y.long())
         return dict(loss=loss)

     def training_epoch_end(self, outputs):
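Note the hard-coded num_classes=5 in the focal-loss branch, matching the N_CLASS_multi = 5 change in variables.py at the bottom of this commit. For reference, the one-hot conversion:

import torch

batch_y = torch.tensor([0, 3, 4])
print(torch.nn.functional.one_hot(batch_y, num_classes=5))
# tensor([[1, 0, 0, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 1]])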
@@ -23,21 +30,23 @@ class TrainMixin:
         summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                for output in outputs]))
                         for key in keys if 'loss' in key}
-        for key in summary_dict.keys():
-            self.log(key, summary_dict[key])
+        summary_dict.update(epoch=self.current_epoch)
+        self.log_dict(summary_dict)


 class ValMixin:

     def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
         assert isinstance(self, LightningBaseModule)
-        batch_x, batch_y = batch_xy
+        batch_files, batch_x, batch_y = batch_xy
         model_out = self(batch_x)
         y = model_out.main_out

-        val_loss = self.ce_loss(y.squeeze(), batch_y.long())
-        return dict(val_loss=val_loss,
+        val_loss = self.ce_loss(y, batch_y.long())
+
+        self.metrics.update(y, batch_y)  # torch.argmax(y, -1), batch_y)
+
+        return dict(val_loss=val_loss, batch_files=batch_files,
                     batch_idx=batch_idx, y=y, batch_y=batch_y)

     def validation_epoch_end(self, outputs, *_, **__):
@@ -49,40 +58,40 @@ class ValMixin:
                                                                for output in outputs]))
                             for key in keys if 'loss' in key}
                            )
+        # Sklearn Scores
         additional_scores = self.additional_scores(outputs)
         summary_dict.update(**additional_scores)

-        for key in summary_dict.keys():
-            self.log(key, summary_dict[key])
+        pl_metrics, pl_images = self.metrics.compute_and_prepare()
+        self.metrics.reset()
+        summary_dict.update(**pl_metrics)
+        summary_dict.update(epoch=self.current_epoch)
+
+        self.log_dict(summary_dict, on_epoch=True)
+
+        for name, image in pl_images.items():
+            self.logger.log_image(name, image, step=self.global_step)
+        pass


 class TestMixin:

     def test_step(self, batch_xy, batch_idx, *_, **__):
         assert isinstance(self, LightningBaseModule)
-        batch_x, batch_y = batch_xy
+        batch_files, batch_x, batch_y = batch_xy
         model_out = self(batch_x)
         y = model_out.main_out
-        test_loss = self.ce_loss(y.squeeze(), batch_y.long())
-        return dict(test_loss=test_loss,
-                    batch_idx=batch_idx, y=y, batch_y=batch_y)
+        return dict(batch_files=batch_files, batch_idx=batch_idx, y=y)

     def test_epoch_end(self, outputs, *_, **__):
         assert isinstance(self, LightningBaseModule)
-        summary_dict = dict()
-        keys = list(outputs[0].keys())
-        summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
-                                                                    for output in outputs]))
-                             for key in keys if 'loss' in key}
-                            )
-        additional_scores = self.additional_scores(outputs)
-        summary_dict.update(**additional_scores)
-
-        for key in summary_dict.keys():
-            self.log(key, summary_dict[key])
+        y_arg_max = torch.argmax(outputs[0]['y'])
+
+        pd.DataFrame(data=dict(filenames=outputs[0]['batch_files'], predtiction=y_arg_max))
+
+        # No logging, just inference.
+        # self.log_dict(summary_dict, on_epoch=True)


 class CombinedModelMixins(LossMixin,

util/optimizer_mixin.py

@@ -1,6 +1,6 @@
 from collections import defaultdict

-from torch.optim import Adam
+from torch.optim import Adam, AdamW
 from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
 from torchcontrib.optim import SWA
@@ -20,12 +20,12 @@ class OptimizerMixin:
             # 'monitor': 'mean_val_loss'  # Metric to monitor
         )

-        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
+        optimizer = AdamW(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
         if self.params.sto_weight_avg:
             optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
         optimizer_dict.update(optimizer=optimizer)
-        if self.params.lr_warmup_steps:
-            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
+        if self.params.lr_warm_restart_epochs:
+            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warm_restart_epochs)
         else:
             scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
         optimizer_dict.update(lr_scheduler=scheduler)
@@ -42,4 +42,4 @@ class OptimizerMixin:
         if self.params.opt_reset_interval:
             if self.current_epoch % self.params.opt_reset_interval == 0:
                 for opt in self.trainer.optimizers:
                     opt.state = defaultdict(dict)
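Note: lr_warm_restart_epochs now feeds CosineAnnealingWarmRestarts directly as T_0; when it is 0 (the ini default), the LambdaLR branch applies a 5% multiplicative decay per epoch instead. Both schedules side by side on a dummy parameter:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR

param = [torch.nn.Parameter(torch.zeros(1))]
opt_a, opt_b = AdamW(param, lr=0.001), AdamW(param, lr=0.001)
warm = CosineAnnealingWarmRestarts(opt_a, T_0=10)                # cosine cycle, restart every 10 epochs
decay = LambdaLR(opt_b, lr_lambda=lambda epoch: 0.95 ** epoch)   # 0.001 * 0.95**epoch

for _ in range(5):
    opt_a.step(); opt_b.step()
    warm.step(); decay.step()
print(opt_a.param_groups[0]['lr'], opt_b.param_groups[0]['lr'])  # ~0.0005 vs ~0.000774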

variables.py

@@ -3,8 +3,6 @@ from pathlib import Path
 sr = 16000
 PRIMATES_Root = Path(__file__).parent / 'data' / 'primates'
-N_CLASS_multi = 4
+N_CLASS_multi = 5