Transformer running
This commit is contained in:
parent
7edd3834a1
commit
ad254dae92
189
.idea/workspace.xml
generated
189
.idea/workspace.xml
generated
@ -1,18 +1,38 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AutoImportSettings">
|
||||
<option name="autoReloadType" value="SELECTIVE" />
|
||||
</component>
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="2be1f675-29fe-4a7d-9fe6-9e96cd7c8055" name="Default Changelist" comment="">
|
||||
<change afterPath="$PROJECT_DIR$/ml_lib/metrics/attention_rollout.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/ml_lib/utils/_basedatamodule.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/ml_lib/additions/__init__.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/ml_lib/additions/losses.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/ml_lib/utils/equal_sampler.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/models/performer.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/models/transformer_model.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/models/transformer_model_horizontal.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/multi_run.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/rebuild_dataset.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/_parameters.ini" beforeDir="false" afterPath="$PROJECT_DIR$/_parameters.ini" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/datasets/primates_librosa_datamodule.py" beforeDir="false" afterPath="$PROJECT_DIR$/datasets/primates_librosa_datamodule.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/main.py" beforeDir="false" afterPath="$PROJECT_DIR$/main.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/_templates/new_project/main.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/_templates/new_project/main.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/audio_toolset/audio_to_mel_dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/audio_toolset/audio_to_mel_dataset.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/audio_toolset/mel_dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/audio_toolset/mel_dataset.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/experiments.py" beforeDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/metrics/multi_class_classification.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/metrics/multi_class_classification.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/modules/blocks.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/modules/blocks.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/modules/model_parts.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/modules/model_parts.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/modules/util.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/modules/util.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/utils/logging.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/logging.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/utils/model_io.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/model_io.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/utils/_basedatamodule.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/_basedatamodule.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/utils/config.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/config.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/utils/logging.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/loggers.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/ml_lib/utils/tools.py" beforeDir="false" afterPath="$PROJECT_DIR$/ml_lib/utils/tools.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/models/cnn_baseline.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/cnn_baseline.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/util/loss_mixin.py" beforeDir="false" afterPath="$PROJECT_DIR$/util/loss_mixin.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/util/module_mixins.py" beforeDir="false" afterPath="$PROJECT_DIR$/util/module_mixins.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/util/optimizer_mixin.py" beforeDir="false" afterPath="$PROJECT_DIR$/util/optimizer_mixin.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/variables.py" beforeDir="false" afterPath="$PROJECT_DIR$/variables.py" afterDir="false" />
|
||||
</list>
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||
@ -27,7 +47,15 @@
|
||||
</option>
|
||||
</component>
|
||||
<component name="Git.Settings">
|
||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$/ml_lib" />
|
||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||
</component>
|
||||
<component name="GitSEFilterConfiguration">
|
||||
<file-type-list>
|
||||
<filtered-out-file-type name="LOCAL_BRANCH" />
|
||||
<filtered-out-file-type name="REMOTE_BRANCH" />
|
||||
<filtered-out-file-type name="TAG" />
|
||||
<filtered-out-file-type name="COMMIT_BY_MESSAGE" />
|
||||
</file-type-list>
|
||||
</component>
|
||||
<component name="ProjectId" id="1oTEXjx0b8UPBPmOIGceYxEch8r" />
|
||||
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
|
||||
@ -38,14 +66,53 @@
|
||||
<component name="PropertiesComponent">
|
||||
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
|
||||
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
|
||||
<property name="WebServerToolWindowFactoryState" value="true" />
|
||||
<property name="WebServerToolWindowFactoryState" value="false" />
|
||||
<property name="WebServerToolWindowPanel.toolwindow.highlight.mappings" value="true" />
|
||||
<property name="WebServerToolWindowPanel.toolwindow.highlight.symlinks" value="true" />
|
||||
<property name="WebServerToolWindowPanel.toolwindow.show.date" value="false" />
|
||||
<property name="WebServerToolWindowPanel.toolwindow.show.permissions" value="false" />
|
||||
<property name="WebServerToolWindowPanel.toolwindow.show.size" value="false" />
|
||||
<property name="credentialsType com.jetbrains.python.remote.PyCreateRemoteInterpreterDialog$PyCreateRemoteSdkForm" value="Web Deployment" />
|
||||
<property name="last_opened_file_path" value="$PROJECT_DIR$/../inter_challenge_2020" />
|
||||
<property name="restartRequiresConfirmation" value="false" />
|
||||
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
|
||||
</component>
|
||||
<component name="RunManager" selected="Python.main">
|
||||
<component name="PyDebuggerOptionsProvider">
|
||||
<option name="mySaveCallSignatures" value="true" />
|
||||
</component>
|
||||
<component name="RecentsManager">
|
||||
<key name="CopyFile.RECENT_KEYS">
|
||||
<recent name="C:\Users\steff\projects\compare_21\models" />
|
||||
<recent name="C:\Users\steff\projects\compare_21\util" />
|
||||
<recent name="C:\Users\steff\projects\compare_21" />
|
||||
</key>
|
||||
<key name="MoveFile.RECENT_KEYS">
|
||||
<recent name="C:\Users\steff\projects\compare_21\util" />
|
||||
</key>
|
||||
</component>
|
||||
<component name="RunManager" selected="Python.rebuild_dataset">
|
||||
<configuration name="equal_sampler" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="compare_21" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/ml_lib/utils" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/ml_lib/utils/equal_sampler.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="main" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
|
||||
<module name="compare_21" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
@ -60,7 +127,7 @@
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="PARAMETERS" value="--debug=True --num_worker=0 --sampler=WeightedRandomSampler --model_name=VisualTransformer" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
@ -90,8 +157,62 @@
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="multi_run" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="compare_21" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/multi_run.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="rebuild_dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="compare_21" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/rebuild_dataset.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<list>
|
||||
<item itemvalue="Python.main" />
|
||||
<item itemvalue="Python.metadata_readout" />
|
||||
<item itemvalue="Python.multi_run" />
|
||||
<item itemvalue="Python.equal_sampler" />
|
||||
<item itemvalue="Python.rebuild_dataset" />
|
||||
</list>
|
||||
<recent_temporary>
|
||||
<list>
|
||||
<item itemvalue="Python.rebuild_dataset" />
|
||||
<item itemvalue="Python.multi_run" />
|
||||
<item itemvalue="Python.equal_sampler" />
|
||||
<item itemvalue="Python.metadata_readout" />
|
||||
</list>
|
||||
</recent_temporary>
|
||||
@ -106,25 +227,69 @@
|
||||
<updated>1613302221903</updated>
|
||||
<workItem from="1613302223434" duration="2570000" />
|
||||
<workItem from="1613305247599" duration="11387000" />
|
||||
<workItem from="1613381449219" duration="21710000" />
|
||||
<workItem from="1613634319983" duration="160704000" />
|
||||
<workItem from="1614498086696" duration="17996000" />
|
||||
<workItem from="1614671803611" duration="595000" />
|
||||
<workItem from="1614679476632" duration="11417000" />
|
||||
<workItem from="1614760180356" duration="12583000" />
|
||||
<workItem from="1614788597119" duration="8215000" />
|
||||
</task>
|
||||
<task id="LOCAL-00001" summary="Dataset rdy">
|
||||
<created>1613467084268</created>
|
||||
<option name="number" value="00001" />
|
||||
<option name="presentableId" value="LOCAL-00001" />
|
||||
<option name="project" value="LOCAL" />
|
||||
<updated>1613467084268</updated>
|
||||
</task>
|
||||
<option name="localTasksCounter" value="2" />
|
||||
<servers />
|
||||
</component>
|
||||
<component name="TypeScriptGeneratedFilesManager">
|
||||
<option name="version" value="3" />
|
||||
</component>
|
||||
<component name="VcsManagerConfiguration">
|
||||
<MESSAGE value="Dataset rdy" />
|
||||
<option name="LAST_COMMIT_MESSAGE" value="Dataset rdy" />
|
||||
</component>
|
||||
<component name="XDebuggerManager">
|
||||
<breakpoint-manager>
|
||||
<breakpoints>
|
||||
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
||||
<url>file://$PROJECT_DIR$/ml_lib/modules/util.py</url>
|
||||
<line>231</line>
|
||||
<option name="timeStamp" value="70" />
|
||||
</line-breakpoint>
|
||||
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
||||
<url>file://$PROJECT_DIR$/models/transformer_model_horizontal.py</url>
|
||||
<line>68</line>
|
||||
<option name="timeStamp" value="72" />
|
||||
</line-breakpoint>
|
||||
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
||||
<url>file://$PROJECT_DIR$/ml_lib/audio_toolset/mel_dataset.py</url>
|
||||
<line>29</line>
|
||||
<option name="timeStamp" value="83" />
|
||||
</line-breakpoint>
|
||||
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
||||
<url>file://$PROJECT_DIR$/main.py</url>
|
||||
<line>11</line>
|
||||
<option name="timeStamp" value="2" />
|
||||
<line>46</line>
|
||||
<option name="timeStamp" value="84" />
|
||||
</line-breakpoint>
|
||||
</breakpoints>
|
||||
<default-breakpoints>
|
||||
<breakpoint type="python-exception">
|
||||
<properties notifyOnTerminate="true" exception="BaseException">
|
||||
<option name="notifyOnTerminate" value="true" />
|
||||
</properties>
|
||||
</breakpoint>
|
||||
</default-breakpoints>
|
||||
</breakpoint-manager>
|
||||
</component>
|
||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||
<SUITE FILE_PATH="coverage/compare_21$equal_sampler.coverage" NAME="equal_sampler Coverage Results" MODIFIED="1614070702295" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/ml_lib/utils" />
|
||||
<SUITE FILE_PATH="coverage/compare_21$rebuild_dataset.coverage" NAME="rebuild_dataset Coverage Results" MODIFIED="1614849857426" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
<SUITE FILE_PATH="coverage/compare_21$metadata_readout.coverage" NAME="metadata_readout Coverage Results" MODIFIED="1613306122664" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/datasets" />
|
||||
<SUITE FILE_PATH="coverage/compare_21$main.coverage" NAME="main Coverage Results" MODIFIED="1613376576627" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
<SUITE FILE_PATH="coverage/compare_21$main.coverage" NAME="main Coverage Results" MODIFIED="1614695189720" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
<SUITE FILE_PATH="coverage/compare_21$multi_run.coverage" NAME="multi_run Coverage Results" MODIFIED="1614702462483" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
</component>
|
||||
</project>
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
neptune_key = eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiZmI0OGMzNzUtOTg1NS00Yzg2LThjMzYtMWFiYjUwMDUyMjVlIn0=
|
||||
debug = 1
|
||||
debug = False
|
||||
eval = True
|
||||
seed = 69
|
||||
owner = si11ium
|
||||
@ -8,56 +8,72 @@ model_name = CNNBaseline
|
||||
data_name = PrimatesLibrosaDatamodule
|
||||
|
||||
[data]
|
||||
num_worker = 0
|
||||
num_worker = 10
|
||||
data_root = data
|
||||
reset = False
|
||||
n_mels = 64
|
||||
sr = 16000
|
||||
hop_length = 256
|
||||
n_fft = 512
|
||||
hop_length = 128
|
||||
n_fft = 256
|
||||
|
||||
random_apply_chance = 0.7
|
||||
loudness_ratio = 0.0
|
||||
shift_ratio = 0.0
|
||||
noise_ratio = 0
|
||||
mask_ratio = 0.3
|
||||
speed_amount = 0
|
||||
speed_min = 0
|
||||
speed_max = 0
|
||||
shift_ratio = 0.3
|
||||
noise_ratio = 0.0
|
||||
mask_ratio = 0.0
|
||||
|
||||
[model_cnn]
|
||||
[CNNBaseline]
|
||||
weight_init = xavier_normal_
|
||||
activation = gelu
|
||||
use_bias = True
|
||||
use_norm = True
|
||||
dropout = 0.2
|
||||
lat_dim = 128
|
||||
lat_dim = 32
|
||||
features = 64
|
||||
filters = [32, 64, 128, 64]
|
||||
filters = [16, 32, 64, 128]
|
||||
|
||||
[model_attn]
|
||||
name = VerticalVisualTransformer
|
||||
[VisualTransformer]
|
||||
weight_init = xavier_normal_
|
||||
activation = gelu
|
||||
use_bias = True
|
||||
use_norm = True
|
||||
use_residual = True
|
||||
dropout = 0.2
|
||||
lat_dim = 128
|
||||
features = 64
|
||||
patch_size = 3
|
||||
attn_depth = 3
|
||||
heads = 8
|
||||
embedding_size = 64
|
||||
|
||||
lat_dim = 32
|
||||
patch_size = 8
|
||||
attn_depth = 12
|
||||
heads = 4
|
||||
embedding_size = 128
|
||||
|
||||
[HorizontalVisualTransformer]
|
||||
weight_init = xavier_normal_
|
||||
activation = gelu
|
||||
use_bias = True
|
||||
use_norm = True
|
||||
use_residual = True
|
||||
dropout = 0.3
|
||||
|
||||
lat_dim = 256
|
||||
patch_size = 8
|
||||
attn_depth = 12
|
||||
heads = 6
|
||||
embedding_size = 32
|
||||
|
||||
[train]
|
||||
outpath = output
|
||||
version = None
|
||||
gpus=0
|
||||
sampler = EqualSampler
|
||||
loss = focal_loss_rob
|
||||
sto_weight_avg = False
|
||||
weight_decay = 0
|
||||
opt_reset_interval = 0
|
||||
epochs = 100
|
||||
max_epochs = 200
|
||||
batch_size = 30
|
||||
lr = 0.01
|
||||
lr_warmup_steps = 0
|
||||
num_sanity_val_steps = 0
|
||||
|
||||
lr = 0.001
|
||||
use_residual = True
|
||||
lr_warm_restart_epochs = 0
|
||||
num_sanity_val_steps = 2
|
||||
check_val_every_n_epoch = 5
|
||||
checkpoint_callback = True
|
||||
gradient_clip_val = 0
|
||||
|
@ -1,8 +1,8 @@
|
||||
from multiprocessing.pool import ApplyResult
|
||||
import multiprocessing as mp
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from torch.utils.data import DataLoader, ConcatDataset
|
||||
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
|
||||
from torchvision.transforms import Compose, RandomApply
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -10,9 +10,8 @@ from ml_lib.audio_toolset.audio_io import NormalizeLocal
|
||||
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
|
||||
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
||||
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
|
||||
from ml_lib.utils.equal_sampler import EqualSampler
|
||||
from ml_lib.utils.transforms import ToTensor
|
||||
import multiprocessing as mp
|
||||
|
||||
|
||||
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
|
||||
|
||||
@ -34,11 +33,16 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
def wav_folder(self):
|
||||
return self.root / 'wav'
|
||||
|
||||
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
|
||||
sample_segment_len=40, sample_hop_len=15):
|
||||
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
|
||||
sample_segment_len=40, sample_hop_len=15, random_apply_chance=0.5,
|
||||
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
|
||||
super(PrimatesLibrosaDatamodule, self).__init__()
|
||||
self.sampler = sampler
|
||||
self.samplers = None
|
||||
|
||||
self.sample_hop_len = sample_hop_len
|
||||
self.sample_segment_len = sample_segment_len
|
||||
|
||||
self.num_worker = num_worker or 1
|
||||
self.batch_size = batch_size
|
||||
self.root = Path(data_root) / 'primates'
|
||||
@ -51,23 +55,22 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])
|
||||
|
||||
# Data Augmentations
|
||||
self.random_apply_chance = random_apply_chance
|
||||
self.mel_augmentations = Compose([
|
||||
# ToDo: HP Search this parameters, make it adjustable from outside
|
||||
RandomApply([NoiseInjection(0.2)], p=0.3),
|
||||
RandomApply([LoudnessManipulator(0.5)], p=0.3),
|
||||
RandomApply([ShiftTime(0.4)], p=0.3),
|
||||
RandomApply([MaskAug(0.2)], p=0.3),
|
||||
RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
|
||||
RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
|
||||
RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
|
||||
RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
|
||||
self.utility_transforms])
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
|
||||
batch_size=self.batch_size, pin_memory=True,
|
||||
num_workers=self.num_worker)
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
|
||||
sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
|
||||
batch_size=self.batch_size, num_workers=self.num_worker)
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], num_workers=self.num_worker, pin_memory=True,
|
||||
sampler=self.samplers[DATA_OPTION_devel], batch_size=self.batch_size)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
@ -79,6 +82,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
slice_file_name, class_name = row.strip().split(',')
|
||||
class_id = self.class_names.get(class_name, -1)
|
||||
audio_file_path = self.wav_folder / slice_file_name
|
||||
|
||||
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
|
||||
kwargs = self.__dict__
|
||||
if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
|
||||
@ -91,7 +95,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
|
||||
if build:
|
||||
assert mel_dataset.build_mel()
|
||||
return mel_dataset
|
||||
return mel_dataset, class_id, slice_file_name
|
||||
|
||||
def prepare_data(self, *args, **kwargs):
|
||||
datasets = dict()
|
||||
@ -103,22 +107,21 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
chunksize = len(all_rows) // max(self.num_worker, 1)
|
||||
dataset = list()
|
||||
with mp.Pool(processes=self.num_worker) as pool:
|
||||
pbar = tqdm(total=len(all_rows))
|
||||
|
||||
def update():
|
||||
pbar.update(chunksize)
|
||||
from itertools import repeat
|
||||
results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
|
||||
chunksize=chunksize)
|
||||
for sub_dataset in results.get():
|
||||
dataset.append(sub_dataset)
|
||||
update() # FIXME: will i ever get this to work?
|
||||
dataset.append(sub_dataset[0])
|
||||
datasets[data_option] = ConcatDataset(dataset)
|
||||
self.datasets = datasets
|
||||
return datasets
|
||||
|
||||
def setup(self, stag=None):
|
||||
datasets = dict()
|
||||
samplers = dict()
|
||||
weights = dict()
|
||||
|
||||
for data_option in data_options:
|
||||
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
@ -126,7 +129,38 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
all_rows = list(f)
|
||||
dataset = list()
|
||||
for row in all_rows:
|
||||
dataset.append(self._build_subdataset(row))
|
||||
mel_dataset, class_id, _ = self._build_subdataset(row)
|
||||
dataset.append(mel_dataset)
|
||||
datasets[data_option] = ConcatDataset(dataset)
|
||||
|
||||
# Build Weighted Sampler for train and val
|
||||
if data_option in [DATA_OPTION_train, DATA_OPTION_devel]:
|
||||
if self.sampler == EqualSampler.__name__:
|
||||
class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
|
||||
for class_idx in range(len(self.class_names))
|
||||
]
|
||||
samplers[data_option] = EqualSampler(class_idxs)
|
||||
elif self.sampler == WeightedRandomSampler.__name__:
|
||||
class_counts = defaultdict(lambda: 0)
|
||||
for _, __, label in datasets[data_option]:
|
||||
class_counts[label] += 1
|
||||
len_largest_class = max(class_counts.values())
|
||||
|
||||
weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
|
||||
weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
|
||||
for i in range(len(datasets[data_option]))]
|
||||
samplers[data_option] = WeightedRandomSampler(weights[data_option],
|
||||
len_largest_class * len(self.class_names))
|
||||
else:
|
||||
samplers[data_option] = None
|
||||
self.datasets = datasets
|
||||
self.samplers = samplers
|
||||
return datasets
|
||||
|
||||
def purge(self):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(self.mel_folder, ignore_errors=True)
|
||||
print('Mel Folder has been recursively deleted')
|
||||
print(f'Folder still exists: {self.mel_folder.exists()}')
|
||||
return not self.mel_folder.exists()
|
||||
|
67
main.py
67
main.py
@ -1,45 +1,29 @@
|
||||
import configparser
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
||||
|
||||
from ml_lib.utils.logging import Logger
|
||||
from ml_lib.utils.config import parse_comandline_args_add_defaults
|
||||
from ml_lib.utils.loggers import Logger
|
||||
from ml_lib.utils.tools import locate_and_import_class, auto_cast
|
||||
|
||||
import variables as v
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Argument Parser and default Values
|
||||
# =============================================================================
|
||||
# Load Defaults from _parameters.ini file
|
||||
config = configparser.ConfigParser()
|
||||
config.read('_parameters.ini')
|
||||
project = config['project']
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
data_class = locate_and_import_class(project['data_name'], 'datasets')
|
||||
model_class = locate_and_import_class(project['model_name'], 'models')
|
||||
|
||||
tmp_params = dict()
|
||||
for key in ['project', 'train', 'data', 'model_cnn']:
|
||||
defaults = config[key]
|
||||
tmp_params.update({key: auto_cast(val) for key, val in defaults.items()})
|
||||
|
||||
# Parse Command Line
|
||||
parser = ArgumentParser()
|
||||
for module in [Logger, Trainer, data_class, model_class]:
|
||||
parser = module.add_argparse_args(parser)
|
||||
cmd_args, _ = parser.parse_known_args()
|
||||
tmp_params.update({key: val for key, val in vars(cmd_args).items() if val is not None})
|
||||
hparams = Namespace(**tmp_params)
|
||||
|
||||
with Logger.from_argparse_args(hparams) as logger:
|
||||
def run_lightning_loop(h_params, data_class, model_class):
|
||||
with Logger.from_argparse_args(h_params) as logger:
|
||||
# Callbacks
|
||||
# =============================================================================
|
||||
# Checkpoint Saving
|
||||
ckpt_callback = ModelCheckpoint(
|
||||
monitor='mean_loss',
|
||||
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||
dirpath=str(logger.log_dir),
|
||||
filename='ckpt_weights',
|
||||
verbose=False,
|
||||
save_top_k=3,
|
||||
)
|
||||
@ -47,22 +31,37 @@ if __name__ == '__main__':
|
||||
# Learning Rate Logger
|
||||
lr_logger = LearningRateMonitor(logging_interval='epoch')
|
||||
|
||||
#
|
||||
# START
|
||||
# =============================================================================
|
||||
# Let Datamodule pull what it wants
|
||||
datamodule = data_class.from_argparse_args(hparams)
|
||||
datamodule = data_class.from_argparse_args(h_params)
|
||||
datamodule.setup()
|
||||
model_in_shape = datamodule.shape
|
||||
|
||||
# Let Trainer pull what it wants and add callbacks
|
||||
trainer = Trainer.from_argparse_args(hparams, callbacks=[ckpt_callback, lr_logger])
|
||||
trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=[ckpt_callback, lr_logger])
|
||||
|
||||
# Let Model pull what it wants
|
||||
model = model_class.from_argparse_args(hparams, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
|
||||
model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
|
||||
model.init_weights()
|
||||
|
||||
logger.log_hyperparams(dict(model.params))
|
||||
trainer.fit(model, datamodule)
|
||||
|
||||
trainer.save_checkpoint(trainer.logger.save_dir)
|
||||
# Log paramters
|
||||
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
||||
# logger.log_text('n_parameters', pytorch_total_params)
|
||||
|
||||
trainer.save_checkpoint(logger.save_dir / 'weights.ckpt')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Parse comandline args, read config and get model
|
||||
cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults('_parameters.ini')
|
||||
|
||||
# To NameSpace
|
||||
hparams = Namespace(**cmd_args)
|
||||
|
||||
# Start
|
||||
# -----------------
|
||||
run_lightning_loop(hparams, found_data_class, found_model_class)
|
||||
print('done')
|
||||
pass
|
||||
|
@ -1,8 +1,6 @@
|
||||
import inspect
|
||||
from argparse import Namespace
|
||||
|
||||
import variables as v
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
@ -17,7 +15,7 @@ class CNNBaseline(CombinedModelMixins,
|
||||
):
|
||||
|
||||
def __init__(self, in_shape, n_classes, weight_init, activation, use_bias, use_norm, dropout, lat_dim, features,
|
||||
filters):
|
||||
filters, lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval, loss):
|
||||
|
||||
# TODO: Move this to parent class, or make it much easieer to access....
|
||||
a = dict(locals())
|
||||
@ -50,4 +48,4 @@ class CNNBaseline(CombinedModelMixins,
|
||||
return Namespace(main_out=tensor)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)
|
||||
return MultiClassScores(self)(outputs)
|
||||
|
0
models/performer.py
Normal file
0
models/performer.py
Normal file
120
models/transformer_model.py
Normal file
120
models/transformer_model.py
Normal file
@ -0,0 +1,120 @@
|
||||
import inspect
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
|
||||
from util.module_mixins import CombinedModelMixins
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
|
||||
class VisualTransformer(CombinedModelMixins,
|
||||
LightningBaseModule
|
||||
):
|
||||
|
||||
def __init__(self, in_shape, n_classes, weight_init, activation,
|
||||
embedding_size, heads, attn_depth, patch_size,use_residual,
|
||||
use_bias, use_norm, dropout, lat_dim, loss,
|
||||
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
|
||||
|
||||
# TODO: Move this to parent class, or make it much easieer to access... But How...
|
||||
a = dict(locals())
|
||||
params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
|
||||
super(VisualTransformer, self).__init__(params)
|
||||
|
||||
self.in_shape = in_shape
|
||||
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
||||
channels, height, width = self.in_shape
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.embed_dim = self.params.embedding_size
|
||||
|
||||
# Automatic Image Shaping
|
||||
self.patch_size = self.params.patch_size
|
||||
image_size = (max(height, width) // self.patch_size) * self.patch_size
|
||||
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
|
||||
|
||||
# This should be obsolete
|
||||
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
|
||||
num_patches = (self.image_size // self.patch_size) ** 2
|
||||
patch_dim = channels * self.patch_size ** 2
|
||||
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||
f'attention. Try decreasing your patch size'
|
||||
|
||||
# Correct the Embedding Dim
|
||||
if not self.embed_dim % self.params.heads == 0:
|
||||
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
|
||||
message = ('Embedding Dimension was fixed to be devideable by the number' +
|
||||
f' of attention heads, is now: {self.embed_dim}')
|
||||
for func in print, warnings.warn:
|
||||
func(message)
|
||||
|
||||
# Utility Modules
|
||||
self.autopad = AutoPadToShape((self.image_size, self.image_size))
|
||||
|
||||
# Modules with Parameters
|
||||
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
|
||||
heads=self.params.heads, depth=self.params.attn_depth,
|
||||
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||
activation=self.params.activation, use_residual=self.params.use_residual
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
|
||||
else F_x(self.embed_dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
self.dropout = nn.Dropout(self.params.dropout)
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, n_classes),
|
||||
nn.Softmax()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
"""
|
||||
:param x: the sequence to the encoder (required).
|
||||
:param mask: the mask for the src sequence (optional).
|
||||
:return:
|
||||
"""
|
||||
tensor = self.autopad(x)
|
||||
p = self.params.patch_size
|
||||
tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
|
||||
tensor = self.patch_to_embedding(tensor)
|
||||
b, n, _ = tensor.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
|
||||
|
||||
tensor = torch.cat((cls_tokens, tensor), dim=1)
|
||||
tensor += self.pos_embedding[:, :(n + 1)]
|
||||
tensor = self.dropout(tensor)
|
||||
|
||||
if return_attn_weights:
|
||||
tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
|
||||
else:
|
||||
attn_weights = None
|
||||
tensor = self.transformer(tensor, mask)
|
||||
|
||||
tensor = self.to_cls_token(tensor[:, 0])
|
||||
tensor = self.mlp_head(tensor)
|
||||
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)(outputs)
|
120
models/transformer_model_horizontal.py
Normal file
120
models/transformer_model_horizontal.py
Normal file
@ -0,0 +1,120 @@
|
||||
import inspect
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from einops import repeat
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||
from util.module_mixins import CombinedModelMixins
|
||||
from variables import N_CLASS_multi
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
|
||||
class HorizontalVisualTransformer(CombinedModelMixins,
|
||||
LightningBaseModule
|
||||
):
|
||||
|
||||
def __init__(self, in_shape, n_classes, weight_init, activation,
|
||||
embedding_size, heads, attn_depth, patch_size,use_residual,
|
||||
use_bias, use_norm, dropout, lat_dim, features, loss,
|
||||
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
|
||||
|
||||
# TODO: Move this to parent class, or make it much easieer to access... But How...
|
||||
a = dict(locals())
|
||||
params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
|
||||
super(HorizontalVisualTransformer, self).__init__(params)
|
||||
|
||||
self.in_shape = in_shape
|
||||
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
||||
channels, height, width = self.in_shape
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.n_classes = N_CLASS_multi
|
||||
self.embed_dim = self.params.embedding_size
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channels = channels
|
||||
|
||||
self.new_height = ((self.height - self.params.patch_size)//1) + 1
|
||||
|
||||
num_patches = self.new_height - (self.params.patch_size // 2)
|
||||
patch_dim = channels * self.params.patch_size * self.width
|
||||
assert patch_dim
|
||||
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||
f'attention. Try decreasing your patch size'
|
||||
|
||||
# Correct the Embedding Dim
|
||||
if not self.embed_dim % self.params.heads == 0:
|
||||
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
|
||||
message = ('Embedding Dimension was fixed to be devideable by the number' +
|
||||
f' of attention heads, is now: {self.embed_dim}')
|
||||
for func in print, warnings.warn:
|
||||
func(message)
|
||||
|
||||
# Utility Modules
|
||||
self.autopad = AutoPadToShape((self.new_height, self.width))
|
||||
self.dropout = nn.Dropout(self.params.dropout)
|
||||
self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.params.patch_size, self.width),
|
||||
keepdim=False)
|
||||
|
||||
# Modules with Parameters
|
||||
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
|
||||
heads=self.params.heads, depth=self.params.attn_depth,
|
||||
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||
activation=self.params.activation
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
|
||||
else F_x(self.embed_dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, n_classes),
|
||||
nn.Softmax()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
"""
|
||||
:param x: the sequence to the encoder (required).
|
||||
:param mask: the mask for the src sequence (optional).
|
||||
:param return_attn_weights: wether to return the attn weights (optional)
|
||||
:return:
|
||||
"""
|
||||
tensor = self.autopad(x)
|
||||
tensor = self.slider(tensor)
|
||||
|
||||
tensor = self.patch_to_embedding(tensor)
|
||||
b, n, _ = tensor.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
|
||||
|
||||
tensor = torch.cat((cls_tokens, tensor), dim=1)
|
||||
tensor += self.pos_embedding[:, :(n + 1)]
|
||||
tensor = self.dropout(tensor)
|
||||
|
||||
if return_attn_weights:
|
||||
tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
|
||||
else:
|
||||
attn_weights = None
|
||||
tensor = self.transformer(tensor, mask)
|
||||
|
||||
tensor = self.to_cls_token(tensor[:, 0])
|
||||
tensor = self.mlp_head(tensor)
|
||||
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)(outputs)
|
48
multi_run.py
Normal file
48
multi_run.py
Normal file
@ -0,0 +1,48 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from main import run_lightning_loop
|
||||
from ml_lib.utils.config import parse_comandline_args_add_defaults
|
||||
|
||||
import itertools
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Set new values
|
||||
hparams_dict = dict(model_name=['VisualTransformer'],
|
||||
max_epochs=[150],
|
||||
batch_size=[50],
|
||||
random_apply_chance=[0.5],
|
||||
loudness_ratio=[0],
|
||||
shift_ratio=[0.3],
|
||||
noise_ratio=[0.3],
|
||||
mask_ratio=[0.3],
|
||||
lr=[0.001],
|
||||
dropout=[0.2],
|
||||
lat_dim=[32, 64],
|
||||
patch_size=[8, 12],
|
||||
attn_depth=[12],
|
||||
heads=[6],
|
||||
embedding_size=[16, 32],
|
||||
loss=['ce_loss'],
|
||||
sampler=['WeightedRandomSampler']
|
||||
)
|
||||
|
||||
keys, values = zip(*hparams_dict.items())
|
||||
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
|
||||
for permutations_dict in tqdm(permutations_dicts, total=len(permutations_dicts)):
|
||||
# Parse comandline args, read config and get model
|
||||
cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults(
|
||||
'_parameters.ini', overrides=permutations_dict)
|
||||
|
||||
hparams = dict(**cmd_args)
|
||||
hparams.update(permutations_dict)
|
||||
hparams = Namespace(**hparams)
|
||||
|
||||
# RUN
|
||||
# ---------------------------------------
|
||||
print(f'Running Loop, parameters are: {permutations_dict}')
|
||||
run_lightning_loop(hparams, found_data_class, found_model_class)
|
||||
print(f'Done, parameters were: {permutations_dict}')
|
||||
pass
|
34
rebuild_dataset.py
Normal file
34
rebuild_dataset.py
Normal file
@ -0,0 +1,34 @@
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
from ml_lib.utils.config import parse_comandline_args_add_defaults
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
|
||||
def rebuild_dataset(h_params, data_class):
|
||||
|
||||
# START
|
||||
# =============================================================================
|
||||
# Let Datamodule pull what it wants
|
||||
datamodule = data_class.from_argparse_args(h_params)
|
||||
assert datamodule.purge()
|
||||
datasets = datamodule.prepare_data()
|
||||
datasets = datamodule.setup()
|
||||
print(f'Dataset length is: {len(datasets)}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Parse comandline args, read config and get model
|
||||
cmd_args, found_data_class, _ = parse_comandline_args_add_defaults('_parameters.ini')
|
||||
|
||||
# To NameSpace
|
||||
hparams = Namespace(**cmd_args)
|
||||
|
||||
# Start
|
||||
# -----------------
|
||||
rebuild_dataset(hparams, found_data_class)
|
||||
print('done')
|
||||
pass
|
@ -1,5 +1,7 @@
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.additions.losses import FocalLoss, FocalLossRob
|
||||
|
||||
|
||||
class LossMixin:
|
||||
|
||||
@ -7,3 +9,5 @@ class LossMixin:
|
||||
nll_loss = nn.NLLLoss()
|
||||
bce_loss = nn.BCELoss()
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
focal_loss = FocalLoss(None)
|
||||
focal_loss_rob = FocalLossRob()
|
||||
|
@ -1,6 +1,7 @@
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
import pandas as pd
|
||||
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from util.loss_mixin import LossMixin
|
||||
@ -11,9 +12,15 @@ class TrainMixin:
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
batch_x, batch_y = batch_xy
|
||||
batch_files, batch_x, batch_y = batch_xy
|
||||
y = self(batch_x).main_out
|
||||
loss = self.ce_loss(y.squeeze(), batch_y.long())
|
||||
|
||||
if self.params.loss == 'focal_loss_rob':
|
||||
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=5)
|
||||
loss = self.__getattribute__(self.params.loss)(y, labels_one_hot)
|
||||
else:
|
||||
loss = self.__getattribute__(self.params.loss)(y, batch_y.long())
|
||||
|
||||
return dict(loss=loss)
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
@ -23,21 +30,23 @@ class TrainMixin:
|
||||
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key}
|
||||
for key in summary_dict.keys():
|
||||
self.log(key, summary_dict[key])
|
||||
summary_dict.update(epoch=self.current_epoch)
|
||||
self.log_dict(summary_dict)
|
||||
|
||||
|
||||
class ValMixin:
|
||||
|
||||
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
batch_x, batch_y = batch_xy
|
||||
batch_files, batch_x, batch_y = batch_xy
|
||||
model_out = self(batch_x)
|
||||
y = model_out.main_out
|
||||
|
||||
val_loss = self.ce_loss(y.squeeze(), batch_y.long())
|
||||
val_loss = self.ce_loss(y, batch_y.long())
|
||||
|
||||
return dict(val_loss=val_loss,
|
||||
self.metrics.update(y, batch_y) # torch.argmax(y, -1), batch_y)
|
||||
|
||||
return dict(val_loss=val_loss, batch_files=batch_files,
|
||||
batch_idx=batch_idx, y=y, batch_y=batch_y)
|
||||
|
||||
def validation_epoch_end(self, outputs, *_, **__):
|
||||
@ -49,40 +58,40 @@ class ValMixin:
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key}
|
||||
)
|
||||
|
||||
# Sklearn Scores
|
||||
additional_scores = self.additional_scores(outputs)
|
||||
summary_dict.update(**additional_scores)
|
||||
|
||||
for key in summary_dict.keys():
|
||||
self.log(key, summary_dict[key])
|
||||
pl_metrics, pl_images = self.metrics.compute_and_prepare()
|
||||
self.metrics.reset()
|
||||
summary_dict.update(**pl_metrics)
|
||||
summary_dict.update(epoch=self.current_epoch)
|
||||
|
||||
self.log_dict(summary_dict, on_epoch=True)
|
||||
|
||||
for name, image in pl_images.items():
|
||||
self.logger.log_image(name, image, step=self.global_step)
|
||||
pass
|
||||
|
||||
|
||||
class TestMixin:
|
||||
|
||||
def test_step(self, batch_xy, batch_idx, *_, **__):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
batch_x, batch_y = batch_xy
|
||||
batch_files, batch_x, batch_y = batch_xy
|
||||
model_out = self(batch_x)
|
||||
y = model_out.main_out
|
||||
test_loss = self.ce_loss(y.squeeze(), batch_y.long())
|
||||
return dict(test_loss=test_loss,
|
||||
batch_idx=batch_idx, y=y, batch_y=batch_y)
|
||||
return dict(batch_files=batch_files, batch_idx=batch_idx, y=y)
|
||||
|
||||
def test_epoch_end(self, outputs, *_, **__):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
summary_dict = dict()
|
||||
|
||||
keys = list(outputs[0].keys())
|
||||
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key}
|
||||
)
|
||||
y_arg_max = torch.argmax(outputs[0]['y'])
|
||||
|
||||
additional_scores = self.additional_scores(outputs)
|
||||
summary_dict.update(**additional_scores)
|
||||
pd.DataFrame(data=dict(filenames=outputs[0]['batch_files'], predtiction=y_arg_max))
|
||||
|
||||
for key in summary_dict.keys():
|
||||
self.log(key, summary_dict[key])
|
||||
# No logging, just inference.
|
||||
# self.log_dict(summary_dict, on_epoch=True)
|
||||
|
||||
|
||||
class CombinedModelMixins(LossMixin,
|
||||
|
@ -1,6 +1,6 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from torch.optim import Adam
|
||||
from torch.optim import Adam, AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
|
||||
from torchcontrib.optim import SWA
|
||||
|
||||
@ -20,12 +20,12 @@ class OptimizerMixin:
|
||||
# 'monitor': 'mean_val_loss' # Metric to monitor
|
||||
)
|
||||
|
||||
optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
|
||||
optimizer = AdamW(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
|
||||
if self.params.sto_weight_avg:
|
||||
optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
|
||||
optimizer_dict.update(optimizer=optimizer)
|
||||
if self.params.lr_warmup_steps:
|
||||
scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
|
||||
if self.params.lr_warm_restart_epochs:
|
||||
scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warm_restart_epochs)
|
||||
else:
|
||||
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
|
||||
optimizer_dict.update(lr_scheduler=scheduler)
|
||||
|
@ -3,8 +3,6 @@ from pathlib import Path
|
||||
|
||||
sr = 16000
|
||||
|
||||
|
||||
|
||||
PRIMATES_Root = Path(__file__).parent / 'data' / 'primates'
|
||||
|
||||
N_CLASS_multi = 4
|
||||
N_CLASS_multi = 5
|
||||
|
Loading…
x
Reference in New Issue
Block a user