From 010176e80b89bda29af6d217a8c75f6da79552fa Mon Sep 17 00:00:00 2001 From: Si11ium Date: Mon, 1 Feb 2021 10:23:22 +0100 Subject: [PATCH] transition --- __pycache__/__init__.cpython-37.pyc | Bin 0 -> 149 bytes audio_toolset/audio_io.py | 2 - audio_toolset/audio_to_mel_dataset.py | 5 +- audio_toolset/mel_dataset.py | 10 ++- metrics/generative_task_evaluation.py | 68 ++++++++++++++ modules/__pycache__/__init__.cpython-37.pyc | Bin 0 -> 157 bytes .../geometric_blocks.cpython-37.pyc | Bin 0 -> 3110 bytes modules/__pycache__/util.cpython-37.pyc | Bin 0 -> 10602 bytes modules/blocks.py | 83 ++++++++++-------- modules/model_parts.py | 1 + modules/util.py | 17 ---- .../__pycache__/__init__.cpython-37.pyc | Bin 0 -> 163 bytes .../__pycache__/point_io.cpython-37.pyc | Bin 0 -> 1401 bytes utils/__pycache__/__init__.cpython-37.pyc | Bin 0 -> 155 bytes utils/__pycache__/config.cpython-37.pyc | Bin 0 -> 7781 bytes utils/__pycache__/model_io.cpython-37.pyc | Bin 0 -> 3458 bytes utils/__pycache__/tools.cpython-37.pyc | Bin 0 -> 1377 bytes utils/logging.py | 8 +- 18 files changed, 133 insertions(+), 61 deletions(-) create mode 100644 __pycache__/__init__.cpython-37.pyc create mode 100644 metrics/generative_task_evaluation.py create mode 100644 modules/__pycache__/__init__.cpython-37.pyc create mode 100644 modules/__pycache__/geometric_blocks.cpython-37.pyc create mode 100644 modules/__pycache__/util.cpython-37.pyc create mode 100644 point_toolset/__pycache__/__init__.cpython-37.pyc create mode 100644 point_toolset/__pycache__/point_io.cpython-37.pyc create mode 100644 utils/__pycache__/__init__.cpython-37.pyc create mode 100644 utils/__pycache__/config.cpython-37.pyc create mode 100644 utils/__pycache__/model_io.cpython-37.pyc create mode 100644 utils/__pycache__/tools.cpython-37.pyc diff --git a/__pycache__/__init__.cpython-37.pyc b/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7cf4388c8670319d8901dd7a9d64a88111b255a GIT binary patch literal 149 zcmZ?b<>g`kf=6#M<3RLd5CH>>K!yVl7qb9~6oz01O-8?!3`HPe1o6wx*(xTqIJKxa zCNn1|votrRpeR2pHMyiXrXW8vuOz-CKfa(SGdHs&vn(|xHzz(PGbtuMJ~J<~BtBlR Wpz;=nO>TZlX-=vg$lT9B%m4s4z$V=Q literal 0 HcmV?d00001 diff --git a/audio_toolset/audio_io.py b/audio_toolset/audio_io.py index cbbe053..530daf0 100644 --- a/audio_toolset/audio_io.py +++ b/audio_toolset/audio_io.py @@ -1,5 +1,3 @@ -from typing import Union - import numpy as np try: diff --git a/audio_toolset/audio_to_mel_dataset.py b/audio_toolset/audio_to_mel_dataset.py index cf91c2b..058a326 100644 --- a/audio_toolset/audio_to_mel_dataset.py +++ b/audio_toolset/audio_to_mel_dataset.py @@ -20,7 +20,7 @@ class _AudioToMelDataset(Dataset, ABC): def sampling_rate(self): raise NotImplementedError - def __init__(self, audio_file_path, label, sample_segment_len=1, sample_hop_len=1, reset=False, + def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False, audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs): self.ignored_kwargs = kwargs self.mel_kwargs = mel_kwargs @@ -46,7 +46,7 @@ class _AudioToMelDataset(Dataset, ABC): return self.dataset[item] except FileNotFoundError: assert self._build_mel() - return self.dataset[item] + return self.dataset[item] def __len__(self): return len(self.dataset) @@ -79,7 +79,6 @@ class LibrosaAudioToMelDataset(_AudioToMelDataset): MelToImage() ]) - def _build_mel(self): if self.reset: self.mel_file_path.unlink(missing_ok=True) diff --git a/audio_toolset/mel_dataset.py b/audio_toolset/mel_dataset.py index 1e7b97a..6b6f245 100644 --- a/audio_toolset/mel_dataset.py +++ b/audio_toolset/mel_dataset.py @@ -13,13 +13,16 @@ class TorchMelDataset(Dataset): super(TorchMelDataset, self).__init__() self.sampling_rate = sampling_rate self.audio_file_len = audio_file_len - self.padding = AutoPadToShape((n_mels , sub_segment_len)) if auto_pad_to_shape else None + self.padding = AutoPadToShape((n_mels, sub_segment_len)) if auto_pad_to_shape and sub_segment_len else None self.path = Path(mel_path) self.sub_segment_len = sub_segment_len self.mel_hop_len = mel_hop_len self.sub_segment_hop_len = sub_segment_hop_len self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1) - self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len)) + if self.sub_segment_len and self.sub_segment_hop_len: + self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len)) + else: + self.offsets = [0] self.label = label self.transform = transform @@ -29,7 +32,8 @@ class TorchMelDataset(Dataset): with self.path.open('rb') as mel_file: mel_spec = pickle.load(mel_file, fix_imports=True) start = self.offsets[item] - snippet = mel_spec[: , start: start + self.sub_segment_len] + duration = self.sub_segment_len if self.sub_segment_len and self.sub_segment_hop_len else mel_spec.shape[1] + snippet = mel_spec[:, start: start + duration] if self.transform: snippet = self.transform(snippet) if self.padding: diff --git a/metrics/generative_task_evaluation.py b/metrics/generative_task_evaluation.py new file mode 100644 index 0000000..6c823e8 --- /dev/null +++ b/metrics/generative_task_evaluation.py @@ -0,0 +1,68 @@ +from itertools import cycle + +import numpy as np +import torch +from sklearn.metrics import roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix +from scipy.spatial.distance import cdist + +from ml_lib.metrics._base_score import _BaseScores + +from matplotlib import pyplot as plt + + +class GenerativeTaskEval(_BaseScores): + + def __init__(self, *args): + super(GenerativeTaskEval, self).__init__(*args) + pass + + def __call__(self, outputs): + summary_dict = dict() + ####################################################################################### + # Additional Score - Histogram Distances - Image Plotting + ####################################################################################### + # + # INIT + y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy() + + y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy() + + attn_weights = torch.cat([output['attn_weights'] for output in outputs]).squeeze().cpu().numpy() + + ###################################################################################### + # + # Histogram comparission + + y_true_hist = np.histogram(y_true, bins=128)[0] # Todo: Find a better value + y_pred_hist = np.histogram(y_pred, bins=128)[0] # Todo: Find a better value + + # L2 norm == euclidean distance + hist_euc_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0), + metric='euclidean') + + # Manhattan Distance + hist_manhattan_dist = cdist(np.expand_dims(y_true_hist, axis=0), np.expand_dims(y_pred_hist, axis=0), + metric='cityblock') + + summary_dict.update(hist_manhattan_dist=hist_manhattan_dist, hist_euc_dist=hist_euc_dist) + + ####################################################################################### + # + idx = np.random.choice(np.arange(y_true.shape[0]), 1).item() + + ax = plt.imshow(y_true[idx].squeeze()) + # Plot using a small number of colors, with unevenly spaced boundaries. + ax2 = plt.imshow(attn_weights[idx].sq, interpolation='nearest', aspect='auto', extent=ax.get_extent()) + self.model.logger.log_image('ROC', image=plt.gcf(), step=self.model.current_epoch) + plt.clf() + + + ####################################################################################### + # + + + ####################################################################################### + # + + plt.close('all') + return summary_dict \ No newline at end of file diff --git a/modules/__pycache__/__init__.cpython-37.pyc b/modules/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18b0c5c0e5b731bcc446bf1cfec0397b0b9db2d9 GIT binary patch literal 157 zcmZ?b<>g`kf=6#M<3RLd5CH>>K!yVl7qb9~6oz01O-8?!3`HPe1o6ww*(xTqIJKxa zCNn1|votrRpeR2pHMyiXrXW8vuOz-CKfa(SGdHs&vn(|xHzz(PGbtuFKczG$wKyg| aJ~J<~BtBlRpz;=n4MfxqWd3I$W&i-2$SOJj literal 0 HcmV?d00001 diff --git a/modules/__pycache__/geometric_blocks.cpython-37.pyc b/modules/__pycache__/geometric_blocks.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f53e2911f53b00dfe56cd06f1d43c5c12373f63 GIT binary patch literal 3110 zcmbVOOK%)S5bmCynVns)pMZ%0L4W{R39NuXNWqFi9Fa?6B9cgGv@jY^w>Pt%$EJJM zG0X1BmXZTT;vaaA{1W~{Upe{4Ib5jf*%{jq4oR=Nwz~RNUG-Jf+?bzl5omw>{MV0v zHwgI?2UW9yxei_Z0)!Jz6XN5)m5`MBlp@a(D`h@2yq(yo<2$MAyJ^F3q@M4kO}|OW zEy5k{9un>f(x?5_0rTfT*5Dq{?aGkHL!}pcNxmH>K^pD{gFH|4;%=4&uvExF zo`i*n@Vu(A8KW)%BZyBq@vSFhhjPlTL(gZNar=e#bd8H3 z(2fj?IM-B4JYL%oJhfwMTZv>xwm~H4kyt>$EX^Q@l2EB2_?B;M|(=1HpC6L~5M8Ari(l1IC0eejc$eJ|9+uT1h|uneYuRM8M1;K8#}pt97?_f1BE9Nx=2>mmNFo^2t-`qa5JXjkaM#2Lxa+V9F#--c7b6Y5aLYtt^@(~?67eT5Awe`6*lRC(7-pGBREHUV7^y@>t5?<*H@n`18HIi=F+6(x zyq5pPF+LvGZ$ek!0|5*|X!k9|7$6gp#E~_n6IvkZ*nx$3JhsR5*x_snVcvIlaNdGF zjath39=8F}XOZe`os>Y7ufa7HC|xf0;z)Fw<;zxK3ZNAh67o0E%}hKFMnVDrZ50E9 zFV}7;!>lLdWf*kq(k>T(mupDgKyn_81}_&uu1tNP7ibGWxJpN>|1wVnMJhjl)g4#j zCiyOs_mCLJAo%3_Na|ZEXBfz#3sH#@a&*BlPK`iPx0iz)yL0>Q{c{*Y&mqp#RUqdv zR)HOS&wYY{qiF?rQ?)XZru})h9WECFwySqI1{0i zsJx!Lfw^TXm+7axV`*#sil#AaMrWU9lxMAtBvwU~r-N%BqjIPzBT3jm8EaU4FJhP53xei0DLjWhu5|L?_=fmv;2AL?Tm@}ZWGth75=Snx@Zal05 zy)GR#plj4Vt5t-mVne^;-moRY_(%KGX4@&5e$hl2S`~g95+&#Gi*La8LLU|1+?{Mxchzk;2{ zi<<5la8<6_zzMX*%u}cEG0VSznX~C}xzecyP=t908~)b?{w55dz+)+fuNx@ZWtn#& z8QPreHszJVf+7ziNSRpTF-<~5plgAnu*)^&+D0cEB~4^dlM5k0@TO9Z8Bijg%4am5 zm)?ALKTx~z!1tgumt4!#XTv`aC72&($#0cLuY)B=Oem3+IA#L zI*Zag1)ZTi^L??NWo5vxm|;+Vh~WfoF27AMshRX|tfJi#HKuNxOlsh0?l|+(x(sS> aqx^EXh8fL#d4Nu|ZN0@_U0GRJY5xOT2yD;* literal 0 HcmV?d00001 diff --git a/modules/__pycache__/util.cpython-37.pyc b/modules/__pycache__/util.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..889848bec7a7c0fb4d8291bad337412d181d346d GIT binary patch literal 10602 zcmb_iTZ|l6TCP)7eVv}3Tic$olZ|7dkYSf0?7{-UWF0$kHraRxaaPGT6N;X$>Y488 zuI`+wiS5o9b~6siB88g}LPDT15)Zr}AtYW95-;!oPayG95=dNL{I<_QT)yu=RehQ1 zkrRh%oztgIom1yO|NZ|@U7epVTeuG2|G_W)%j1^yAH2z~92!?pqOU8<@-1I=t)8kY z#kSqGdrsYvw$pX%F8bWA+soB+(w9R!UoWK3i}jM9_Y2)}Z>~P4td}jn=$GzUekmw* zDuFt5>+|R<`*Y}KNqbh%Q}fTD=M3Il!`)f`9PZAg&(EM|)n7x;TH13KJ?H%k=(&*goI}rJ{^RI* zJXk$+>T9_FlK%wmpGZfXN6$t7N%TCK_FV9to7VcJ2gK&ON}Lz2Uc;aJdL!Q5M8gh4 z)aBNG*o@o#u+e>Bn}G{A`hL)Ty`dYuAP#iYU4z&Aa{Fg9p&acS0GTQTgRoZj@g1?C86@STJ z#{ay(;@9wB_0RZc@xS1o^H=e|7%X*`J1cnqnt%SDs@FcUT&sQt8@`aNyw={?jl*`h zbF~o#*v)=7Xo7>};y=qZlqiqt!0K3-rihcG=Y`m{=b`O1yV0LodJaQAy?X7rt#3!b z#a6r9ZSVKC2D;w~nsKx>=(oezi~HU{w|nil{a&!u>w4Yx_Et~UjkfmVc6VcNCvo}a zWv)s*TlR2u2ACThG=kN|I7)}W`ZgMOo{6m^1&HnY>Ig``ukPaRKpm>P_JQ4TecQJ_ zwAZj7*NSs{W&MV-?m|X}l{kOk98z+?6F&NW#d+KMjuqN(;|{xbH#L#+K&t=6wO&Ya zabGuglbmjZewY;7p%?8o20>EV52JVYgJ2jWxwPk{+s$A=E=qF1Lm02Si5msoR-$ev z`Tl-9*pCwx2rf-k!D70G>MTlBMP=ExT2@6>Rg0*HYqO9et2G6EoPa`hWQi`Jime#f zIJA%K`ylZF+MW};#12|H*(a!+2h{t6K3;b-U5q*hmmRY#W zyGJ>&UjD$okCAVIael#i!vX^-Fi?K;P0CMjyB)<|e{bDR3cz3x#*uy!s|^<*4nfyo zZ==86N$gwO>jh1)l3Z`kZ|fv4lcU7#4}wrH@`)QY-V5{-Xs$bgiKJ}aGk_9HRO=W= zB^8dRH2Upm@5eAL662!IstVX@RShrBqRZr@DcTUbLL0Ix5N35$$3EnC!ADNwbb~M{ z_>CA^Ca%xNx&awUZ%?6Ik&ivh`#OOW>e;iIB^x*e5YaUS(2=?it!n_X^_TIuzBq=X z+i&=RJ_VG8&@&^q8*y_th@P8(=!a&Z>Dc%wz(!n5p^B{!U=|EeDp}d=$FKAT-Jl0< z2>h3{?wMqQJCLP!0!o29iJPd%G@OHf@l<;of$y(vI{&7s# zUuAVn2N-sPjq{>77`%uP8CXxw0!v7pjeZ2|oDN7S4j}Rb;^he-9-j$BHgcqvh{@A= z$B@UU;G|&e$O+0+W&L2#Z|-_Q=)XDv#HCY#7!4mmLVB@W|bjk9^ zNnH^OWyVD}2p=5-OO;R2XOqL%&uXgur0jXdA%Uf~k3%48ZwFB+#hiynk z&(lB4XY;8YZ3f+LW=qN6P*qg+THFIlE3H6u9#zqjS{ z;}sMk<18Ol9Ah^r-=sftqkX#_PDunsmhXr@Mn$zBtJv;1E&I@pYOxFDCf@G%X%OI6 z?3MI)q^#kd^LeGC`nP`wCE zPDd^*evGlkrGvAUqr(Qa%WsZ=q{I-%Hwh$3VQdq?67TO20PeZ^DcY;7@1|HUUwp%L zXH(PV(hI2I;m~Yk?#3G1$5qWv0GqxKAt%1*hi0nghsPM)kcS(WyZvUP8$G`(RDb;pEtbyor{Fy(qKajtu}M+%gf=Sx^p1$xpQ&)fo6U%)U;t8DZOv6?yC z(oQwOcn4HNdnl>$tEyoI%WsUFYOc?V|09z^DLF9yk)S18LWm;o9xp&m)Pcx@wZUK zC$r$1R~5ph#dKSU1G}2n?D-6AP!IuRK80++Ov)r16qC7E z;8N>BAIuPJCWSCS$OUak4a+MtWWo=$(bExf>~!1?_`u6ctXwb@KXB1gI&e^zGwRP7 zN>fc4N^>KcuTSQ{5g7--s84dR z*3=bn@La<-$6(6ow=jR^qlfU;uVY@89{}jXl}V5@9&;C>HU@9{djzZ*HQlC(fYZ z8)N$5 znO{ohxA45w8O;$X$ZaNh6U_-NcHXc3*~|Yny!Q>Td%F>>mywEWL@@7KY_8r0bS52^ ze=`Up=u$WI+r7kPiYUpokdpQUIhuq>F_fs@G5j#q35AS!>rXH&T13SVvdYw$U6Iys z8f9km3|Etw)1*UsTlpGi)KGziz_1@eN3@_> zn3ETCVnm!6$~>c0w39p>IFNG!iQ=k%4bT1(CDm?=uojb0WYYX-Hi9Vl^X%masvXUq zCd%mIu6TVO;#6(~dI#yHFG8Ij)_gd1{+f`b)HyS`>Phl&)KnypvV$Nz(UY{={C|~ck2S!0?53{EfD{0E|Sp%~WV^i}C zRg>Y0(a~&~`8&Q_;1=&@49ZEWa1=eqYn*ZwmC2@1|7J&9zO#zMT;Dx(m=nRtgLC9E z0e6(UPoYvrdGXwE4sq8x;2~#joEst0xzn3`J(Zkf7>cyHCo>EBB@EDvl#a8&H_`VG zD8gz;C_Fu}rIZ22xsHjH!Z@!|(3^bkMOH+dQQ@UVqs7pXU+{@*Y(1LE{ibpj3 zl)l8G&hLX;$*Bi8kQIXf?&jPz8VA#MCfaZiKeSu!WrXX(%_coZ-6bA^_-zQ5$j(Wr zopF#JZ~(WTdKll6 zu$VEY#Yf_i+`-g|4b$iv_b^85A5kO>D=MUqOPYOCOiVW5HO0ga$$~9^v#;C3ei%2p zHwWD|(lWZ4y@&r48U9E2P)++MvBT48#2$9^B_@_70?&!#OBI*jVibGXu#Z+ZU~O zUZbtRol~>!x(ATOEV#Ri6bR)vlT^wpk{~%sD%&kcC!`tY-Z=lx)7O7@`7I-G-;3%c`S+meC6lb<#KGVsC*2c# zJn{^GfOn>Xn+K-oz(1|3_VD}+Jbf`wh0O5pTupSoU8!jHG*L!Z#^%wp=*oh<$*+M% z`Kw|*m*%gK+QgSS)c=KKaSSwFtO>`SDuC-=LlHj+l!K~~pc_mt{|bRQd!b2NP8w9F zAF>z)%1i~U#F+O?mZo0maWUVVXBt_g<>?kUCCAQg3R^D5~s1<)C7Q^rNq1VK1)75ksWC6S(2Ihm(i*(4erR7Pxy)} z{fsQJ2i_drCWg5x3_F0js3oZ8juYHhAoSzcIP PtS&8=@vAH@EwB9_8xlAo literal 0 HcmV?d00001 diff --git a/modules/blocks.py b/modules/blocks.py index f3ebbe8..2d3a359 100644 --- a/modules/blocks.py +++ b/modules/blocks.py @@ -12,7 +12,7 @@ from einops import rearrange import sys sys.path.append(str(Path(__file__).parent)) -from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten, ResidualBlock, PreNorm +from .util import AutoPad, Interpolate, ShapeMixin, F_x, Flatten DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -85,7 +85,6 @@ class ConvModule(ShapeMixin, nn.Module): else: pass - def forward(self, x): tensor = self.norm(x) tensor = self.conv(tensor) @@ -100,12 +99,13 @@ class PreInitializedConvModule(ShapeMixin, nn.Module): def __init__(self, in_shape, weight_matrix): super(PreInitializedConvModule, self).__init__() self.in_shape = in_shape + self.weight_matrix = weight_matrix raise NotImplementedError # ToDo Get the weight_matrix shape and init a conv_module of similar size, # override the weights then. def forward(self, x): - + x = torch.matmul(x, self.weight_matrix) # ToDo: This is an Placeholder return x @@ -214,8 +214,9 @@ class RecurrentModule(ShapeMixin, nn.Module): tensor = self.rnn(x) return tensor + class FeedForward(nn.Module): - def __init__(self, dim, hidden_dim, dropout = 0.): + def __init__(self, dim, hidden_dim, dropout=0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), @@ -224,31 +225,35 @@ class FeedForward(nn.Module): nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) + def forward(self, x): return self.net(x) + class Attention(nn.Module): - def __init__(self, dim, heads = 8, dropout = 0.): + def __init__(self, dim, heads=8, dropout=0.): super().__init__() self.heads = heads - self.scale = dim ** -0.5 + self.scale = dim / heads ** -0.5 - self.to_qkv = nn.Linear(dim, dim * 3, bias = False) + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) self.to_out = nn.Sequential( nn.Linear(dim, dim), nn.Dropout(dropout) ) - def forward(self, x, mask = None): + def forward(self, x, mask=None, return_attn_weights=False): + # noinspection PyTupleAssignmentBalance b, n, _, h = *x.shape, self.heads - qkv = self.to_qkv(x).chunk(3, dim = -1) - q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h = h) for t in qkv] + + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale mask_value = -torch.finfo(dots.dtype).max if mask is not None: - mask = F.pad(mask.flatten(1), [1, 0], value = True) + mask = F.pad(mask.flatten(1), (1, 0), value=True) assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' mask = mask[:, None, :] * mask[:, :, None] dots.masked_fill_(~mask, mask_value) @@ -258,39 +263,47 @@ class Attention(nn.Module): out = torch.einsum('bhij,bhjd->bhid', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') - out = self.to_out(out) - return out - -class Transformer(nn.Module): - def __init__(self, dim, depth, heads, mlp_dim, dropout): - super().__init__() - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append(nn.ModuleList([ - ResidualBlock(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))), - ResidualBlock(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) - ])) - - def forward(self, x, mask = None, *_, **__): - for attn, ff in self.layers: - x = attn(x, mask = mask) - x = ff(x) - return x + out = self.to_out(out) + if return_attn_weights: + return out, attn + else: + return out class TransformerModule(ShapeMixin, nn.Module): - def __init__(self, in_shape, hidden_size, n_heads, num_layers=1, dropout=None, use_norm=False, activation='gelu'): + def __init__(self, in_shape, depth, heads, mlp_dim, dropout=None, use_norm=False, activation='gelu'): super(TransformerModule, self).__init__() self.in_shape = in_shape self.flat = Flatten(self.in_shape) if isinstance(self.in_shape, (tuple, list)) else F_x(in_shape) - self.transformer = Transformer(dim=self.flat.flat_shape, depth=num_layers, heads=n_heads, - mlp_dim=hidden_size, dropout=dropout) + self.layers = nn.ModuleList([]) + self.embedding_dim = self.flat.flat_shape + self.norm = nn.LayerNorm(self.embedding_dim) + self.attns = nn.ModuleList([Attention(self.embedding_dim, heads=heads, dropout=dropout) for _ in range(depth)]) + self.mlps = nn.ModuleList([FeedForward(self.embedding_dim, mlp_dim, dropout=dropout) for _ in range(depth)]) - def forward(self, x, mask=None, key_padding_mask=None): + def forward(self, x, mask=None, return_attn_weights=False, **_): tensor = self.flat(x) - tensor = self.transformer(tensor, mask, key_padding_mask) - return tensor + attn_weights = list() + + for attn, mlp in zip(self.attns, self.mlps): + # Attention + skip_connection = tensor.clone() + tensor = self.norm(tensor) + if return_attn_weights: + tensor, attn_weight = attn(tensor, mask=mask, return_attn_weights=return_attn_weights) + attn_weights.append(attn_weight) + else: + tensor = attn(tensor, mask=mask) + tensor = tensor + skip_connection + + # MLP + skip_connection = tensor.clone() + tensor = self.norm(tensor) + tensor = mlp(tensor) + tensor = tensor + skip_connection + + return (tensor, attn_weights) if return_attn_weights else tensor diff --git a/modules/model_parts.py b/modules/model_parts.py index 2150e64..833faed 100644 --- a/modules/model_parts.py +++ b/modules/model_parts.py @@ -96,6 +96,7 @@ class Generator(ShapeMixin, nn.Module): super(Generator, self).__init__() assert filters, '"Filters" has to be a list of int.' assert filters, '"Filters" has to be a list of int.' + kernels = kernels if kernels else [3] * len(filters) assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.' interpolations = interpolations or [2, 2, 2] diff --git a/modules/util.py b/modules/util.py index 680a088..be56c35 100644 --- a/modules/util.py +++ b/modules/util.py @@ -150,23 +150,6 @@ class F_x(ShapeMixin, nn.Module): return x -class ResidualBlock(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - def forward(self, x, **kwargs): - return self.fn(x, **kwargs) + x - - -class PreNorm(nn.Module): - def __init__(self, dim, fn): - super().__init__() - self.norm = nn.LayerNorm(dim) - self.fn = fn - def forward(self, x, **kwargs): - return self.fn(self.norm(x), **kwargs) - - class SlidingWindow(ShapeMixin, nn.Module): def __init__(self, in_shape, kernel, stride=1, padding=0, keepdim=False): super(SlidingWindow, self).__init__() diff --git a/point_toolset/__pycache__/__init__.cpython-37.pyc b/point_toolset/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..587046070b7a06b993cd7a8b95dbcb6a37b112ee GIT binary patch literal 163 zcmZ?b<>g`kg09Dh<3RLd5CH>>K!yVl7qb9~6oz01O-8?!3`HPe1o6wq*(xTqIJKxa zCNn1|votrRpeR2pHMyiXrXW8vuOz-CKfa(SGdHs&vn(|xHzz(PGYKw}pHrM#5)&Vv dnU`4-AFo$Xd5gm)H$SB`C)EyQ%V!{F006h1EYko0 literal 0 HcmV?d00001 diff --git a/point_toolset/__pycache__/point_io.cpython-37.pyc b/point_toolset/__pycache__/point_io.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4756e198b635230c6284dbd91f467437b0adc5b5 GIT binary patch literal 1401 zcmZux&2A($5VqZR&rHu>gd$jAPYA7;OCA745!zje3!2Mn50+mn2$1DEyNh znNI;&WU)gxWJ}&4j2zbQJG@KbB_W9x)BehvnN>%H6vC_?@EUj6rnrdmeZTL!8b!Dq7xu_La)#;t2QAXO`s%nl_ z8{a=xMSUuYlRg~qHxRmh@=kboe7T62xlb4pOB0^tjNv;wC5}h7MDdWIrkaobtBd) z;1TC9*M;gNvBY&a%#U{+P|D!My>~OL#~_bR#)g9Tuj`0Dp3umdG5rBB^TU`EelD9@ kT3wY#G0u<)pEjrADTwO#wmZ~?f82R^65g8Fa1ea*AJsix!Tg`kf=6#M<3RLd5CH>>K!yVl7qb9~6oz01O-8?!3`HPe1o6wo*(xTqIJKxa zCNn1|votrRpeR2pHMyiXrXW8vuOz-CKfa(SGdHs&vn(|xHzz(PGbyIDBr~TtCO$qh cFS8^*Uaz3?7Kcr4eoARhsvXGm&p^xo074NdZ~y=R literal 0 HcmV?d00001 diff --git a/utils/__pycache__/config.cpython-37.pyc b/utils/__pycache__/config.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..496676267a234d379f808e8f73a05238e3f00e54 GIT binary patch literal 7781 zcmcgx%X1vZdGDT?ePDM1k|0G9q(m-hJ-D(A%d+JMLopH`P8`LEY%-F)hO{-<9smQ* zKG-uLf%UQ~B``(0T#}RW!3SFv;O0~=x#wtK^AF6`r<`-sDZk%0vx5Z$naagjbniDk z{rcXIy*@izQ~28*|KY>`KBLs1_%Qh^0eA~79-%YJ=t#Bs?L@lgv>i>@johBs_5}8# zQm@=D_bTm*JeQ(sZ>BvXa5<{=>g~F~m8j91ZO;m9lHsVAvfDw*k9vU%%0E(wQZ{7sNxq8>Dk@+qKPJ&`;j= zZ5&w4sheQc-;5Fyb`v}{ZVdXX;ac(9d)H6ayPCV>&q1rA-9U?9Lua&6k$R}N=f-+h zkJX;Wm7UwlQ;Bo8W?#{YALHIuKLIIlQE$m|vpe5!xv94i5Bh1zZ1#q6I&)`x7~HaU zU~yx;WoIz!-<6%k*JsynTv`4{TsK}0qbS_$Ef4KrBj_ga@^BFLlTI?|3~ks8lkk49 z+>1IJuV+i5+FJ3P-{i2`mKH_$0v)sC)dPtR$449{H-PX4?ch;QP` zfocQSwiCF(*?FXPblU^$8c*O7;F2i|Tn1b*Re>vjXG~4tD&V?l2s{IL*31c91ANRJ z7q|}igqat(0r;djCGafZXUu7V=Kw!z&Io)A@Pauj@NvNB%yR;t0Q|f;FYr9z3#KXX zNpsO$f*_rGtX8%8ih1Etwf#(R+Pr99lK0OBXMzRu^6#{{Y+8>zV0jiJ=gcb@`KozU zMxHZYGmF^cdGmGi8otk)Z* zH9DVNf3XjZTD%$!x_%U2TP&ugy3_I?*O1Dz)Ek%}O1%Vfoz_>w{#sxmrTqj)dT|h~ zwrlbFz$P8;b?Efe;~T4(*x7e_8jIPKDd(4eje`xY2}RO@{{vdM;}fqB{|P_cnNpIJICkJIqE%# zUbZwHX(_R|M7uKBOop&6mRQ;acF&>vf@@X5i4Lb<#GJ{!j^+27yelqN#w0e8h~EQ9 z)UMuxnMxd9Y)3zI#*Wbo5VSWHAQzA)P)Yjon`-My;*Fi%(jLU`ruyKOUtr#DS%jLZ zT^&15)PmZr?5RvZHQql_Z>oE5y|0pLGK1N8m#My~f9!msmYs+0Lk|*JPHJN))^2@I z%q2z|xLavl0+||1`}9aX<0pPvelW0i18d2Vw31uC0#%|&R8OMCO>`ke{|pWjB%df~ zjiT;+>i*if=aCckG^PpPTdw80)V&*Qr|x|}qGH4_tPqzeRj6I1bLb9w!)ss1bRmCP zjb6&hKPB2zDp4#ct~!y`+uVbqFov`)P&k~-DP(to+e$Wvhvl+*HS&8a#=rIrEc{VPZAi-3||)}?l7a~J7V`bQ|rD>G&4IkuSW|xLZ6>k;6+Jc6RSo^Nz_ML;vC0+ z882$h{O`_s7&r01y4e?@B}}_Mhy&A122C3zo3`Kdo8PpbjjHMR`zv7n2_0qsZ;W_ zVNo#svd1)^dJJcLh&DlHa*Tu!U-pNBW7{EoY-l#PYdrKDScLeQ#=mC(M)%brJv=;Jg4?*vmB zbnWQ7m4m^zV;MzQrjKwW3_vm$t?l;+Gdas#uvdW$Gl3HbHDV6?9BO z_P^Nm?(JU>3UdKZKcaYv_-6MSAQIaUu-H$I&ULq0hrCvCZkTdUX0|Y8IrU7#0TV z6dtm?7eXO?W!NhLMX`XBDGImIL=@IU81hh*=^J8X2BQC(2cjeAN;3|7AEzutdD@6!i^aFdKH-;E#ebc%LU2aJ~<@FnRMkj6{C0I8OjMy^e(L0Snr{rkY!Vu z_-acm4j41!+RPuOa%_qE$%TJLn@I2*If=g5%^%4TZRsFzp!nt`z}ZC{E3k~pdAIQD z(d0a(M;D)ZUcn$zfyAqGIa9_&K|a-vv*E0Q2A{x9Ov)*4T0+!~9F>NK#*NuR2L*>L zO(@@e;M=tr;mk}Yi;Iz6?3?*_vFI@mAx;We&uEg^Pc|y>-vxvB^oN@1HmXVP*hBruLpl%hvjD5bR2xYVlK!!GRo$CI!VZ5W)(n9cF@;+3 zBqZQW4omSc3KRP*M#KuY$|$npMQ5?#U+g1mE>he!-!bY;YyIE> zuwo54V{yUfbWWBP_xTN`OBq194Wzv_J$fb=&V%!a1jGmK9)_wwVxroD1hpzF^$`Ko zp&_F{#ua3|u?HD&1sSP-EHTfO?K%|8eZ=g~un%~@j4#_kmwGaRh}jR=2{MNrt~y?xt!rg%`G_1`ShT2R|;Att`yd-zZ*Ny`s?I0?hPK>X1@m5$@snX z(ijlRbgp>W)9|t$v>z$x7*=6i+H+A9{_gwg-tqUbYldw}Q0?YeIY4HXB7 zLJ*R|yy{cQ^;Hx=Dm-aZw`q8iqK(F?2}bpjZ$w0?WLpZHr@N?d!7CQaa$m-!%{JN$?p%`!yw$Ak-2k$wLQ zLzzK?k>}4YbMnK>sFvf5RJY(d^_!&y@jll9APGVVNQzp}W)`&Yq>am@1m+6A!|dr_ z&C|~BsNE86zE;#nH zIO1~feUjN`+kb^ll$H`10!ngFp3R7Gqtm&!=|}mSDl$iS-DF!T=UJL?5qoD9M`~zR z)esB9fpqu;o)-v-*b&w zf%#l+)Y0Qz)x~$i(Z?%j)v8-9RqKrz*`;;DhS($Z_;YAh@pb)`u5~$BEsU|Cz}o|} zK@?@>?>LJ98W}G$=VNJxC^HG`$ntB3eHB9zif8JxZv_j{Ppelm{-Fv6tme_d_&IRp L;66XN^UVJN>anTd literal 0 HcmV?d00001 diff --git a/utils/__pycache__/model_io.cpython-37.pyc b/utils/__pycache__/model_io.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9580cee3ec5a637108395d418e0a37ee1a3ba44 GIT binary patch literal 3458 zcma)9O>f-B8J^+CE|=PsR<4l7j*Tis5@g+0r#1*u*Kp(1El@`=DjP`%H3Y>ONtC$c zc4p*Mi**lP0XaE;pdXNq{Xad;HK*JP^wvxIyu)36taT_CGkiIGIrDxz^SrYk_j)nI z^{;RL{iR%D>|1)6Um=7CX!<2O!2~bakSCmz?3LcoPkbl)WiSkr(8)pB8n%hTV?J+Ue#=Bh#3xL|l1qOYB+D3C5?zdR<;v7g`l7YP z2EFfK2ODtP+02GgkFs3iq5Ww#8Wq(Z#PC^F)D?tCNIA;u(J|!UNoMvpA#|$D=vo;m z@ZeQdPP5Ig)=bbP1j!OESmH@vaN(V>DNh37i{OMMp$J6__m*gj2={hlc@llicDb`T zN;)!@O9x%F9@;Y6is-!xlD^Yf7Aq%Fvida(SaJhi_3g^1wUFhLOu;K7mCpIh7!P+c z>>-+dADv+ncEG3n1)uO!&-i!>CP>Upb`vRi#m|tO3CE=Vrj4`Q6wfo`BBChEgHTaQ zZIq^>$W5Bs*3n30M%sX$Z8yF6L=QY0Xj$$i{Sg@kyV7A3=zlWR$29-_`O%rQ}NI&)h?(wx?d}kp%8&|%_DM)nDotg_4O@ZJb7&jTh!I6Ijc(Z}0R&cZY z(D~4Gt2DBdrhC#9Mh?^T4hFQF*BigJxcVi%^F&d?uB<&zRI|2c{Zdwq0o)m|xOzKH zp^>Il2orXS@Dm~h8RBg|#kNC~k^0OS%sJ5j9*@oF-$9=<%H(^smj7;6a zo;1N>tUl-Lg~!==4c~eA?i1l}f%Uf7XT*JY_ZRPT_64iZ)94`BbQ{dJ5nyTZ{YThp z6ayRTqmfjqLvr6ChkP$~j*P_iTmJR@;WtXvYT!5E3@Q+TomQ9s#?cv zKSOi7h4F|ldmiuO7xBOwufLuu0Q=Ii07P8M?C>~MvOKbZ>)X268`g!efyt`<^FR`) zDTpqR??C{tx#3e!^(Xwmd%=xwz(7;)Al${9Z(MO`VAry7p^HjmzFaDr$YD?Yh$I&f zLRwXThh|7fRvpz42hO<$7=4HIi5C>@Q*{^J12pXz8;EG}@pb5ICjJ4D9fcFoj?;A5 za7_xTOaF9~m9sYymTXd-Y~W(9dL%{YZ09ir@DK9wGaID zQ^e0D-$m{l^p!s#X9uM3v_Fd3&D4bvC1^Pn_ zT`F3!3{$RH&p5Kzj4L z?{NWxalS9}L&xbktw>iJ?VNi4fd|whcE5Tmq73uH(l)vPYK(ND(dW3ov zpmv3m(CL=Y^(P_pf<~`;&jio`ah;;fo>84^nbu@xwTNDM6aOZ|_a5ep)JxMojj%yG zZ?QA}!gHe?{f(g&ndwa6F~;Io9`LxdIqp2&m+K>(1=h#?bx}*bUe#t@{u!kB_BcL& z`Tlry;q72!+`7Ag2$(C(OPs%j=731N(!idpA7c<<`-X^QN^-*=%A-zcVD+?a~4)bW*hRnMSR*U#8?*?BZrsG|2 z{LcRqwDVjM^AJ>jkES=!5mf11cLwH939Q&8g1IM*91yrV`mXU`;tbeu`$zc7GUp}8 z(+8a@8gc{bSCHnVtks-i7Tj}0@EO!*oVDyN^YNW)otQJ?73+z7AJXT9EJE6;tMULn zS&v#zhrZLGxi=^boo&gQi|ekchnL{wY`1}Yo|U$>tE3#uIRW=aC|@coDk+F7Wx3#q zZ7Zn(3;9PfhnSEdC6u7??Q8v+spCAVegWmpK}Y=rclA^1sD{}k2VNX}_iLdjPwHdp z9#KbJuO3t9`2V7Fr1^8aAT?)4gc5jbIcg~y9jpS}Y7f-}fT#RpmAWi(HU;>#qIlYt zIi`ONZb=)^e}(37+$KQdsM>6b2T)Dgvp)`QTR8r2>Q@Up4s*ICud~ob_Yg!JZtoLU SXsRt;j-N*@A}P4>;r{@KR6p?m literal 0 HcmV?d00001 diff --git a/utils/__pycache__/tools.cpython-37.pyc b/utils/__pycache__/tools.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63eacc3ee8ec70e2e4f82be6a430b00a3db62629 GIT binary patch literal 1377 zcmZuwO^+Nk5VhSO(=$7}Sw*Wz2}tAsg3cl0j8=%SiU<;#4ap*+(TJwg_GCQM-EP^Q z{nG4(;Tq*H%pUni$l}H+e*rE$ckixXg)w)Zx0BD{`-%geh&!w(;L@= zVeu8F`5A%}PD?WGG#yjetYlR%4rsR*j-%Em9>=Zjje9)c;jd(z@QBCIecs~<^nmyI z0D6mW@gek()0ZUOevc?*lu{Qy&+V-6(ft4_5@8--n$I9C+0Yf4vlX4POM1aLT`}wh ze(w?|Q?^f53^%gTdkX(FbgXV%bS|_qE;RC7xH#8&{YIq0xG(u>mTR4#jkl`&G%L#7 z7-3vb)gqfI`y0`HT=LhGgD)mO7@^HXmZe-)lSZq#D6E+@O4c^BDr>Z?q?N~FQk7XL zUr&}+mS$p=D$V2O%njgkb*oRhi6IExrF&FwVO4mret7L!w~Rl>!jd`NFz^dJLrQG8 ziT-8%=Zq6Dr=v91A^vnRcyd*Nwh-KfRW556S*44a1KoO=mzh67)(3^Er}9mvUe8;` zo$x^2-c%w-WtnxCwSSmhoZL{@q53ZDo1*#HK5Z?9S6EoGB5ObgY&u8D@UgcHRDyYE zL1s;DJV(R4W$zwbFdNxq#nyC9FM<`yy9)Nnk7%>+$qVud#ohETk)@VuI66w1OT?)( zHd71NYjQ1WYrrT$iL5GAnGV2iCW=MYptD@uxwxoWaH(BvW}-Y6E>w-EU0^J>8d)q# z;dY;W`^~ebFORb4FAk3m4-UU`K?$FzD3uZDtNv7$qPsLY&da56bn1GwILQ{`>~re9(eHqL>lOaP%Z#ONBM1(yQRu`Hr1-Sg831guj=e!>Q7<+f2(@v z9-yUbW_B$5g?11(AqPzJ(C@U+*yucb+{i^5xq&pYHa4#dq5X?@J5TWZw3VuUs=ypOL# s`wH7WZmZBT?K0?Va4V)u_AAAgrTEf6SF;C^ppHZ4WfCMYzJ3z_14cnUTmS$7 literal 0 HcmV?d00001 diff --git a/utils/logging.py b/utils/logging.py index ec4119d..d0d983f 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -3,6 +3,7 @@ from pathlib import Path from pytorch_lightning.loggers.base import LightningLoggerBase from pytorch_lightning.loggers.neptune import NeptuneLogger +from neptune.api_exceptions import ProjectNotFound # noinspection PyUnresolvedReferences from pytorch_lightning.loggers.csv_logs import CSVLogger @@ -71,7 +72,12 @@ class Logger(LightningLoggerBase, ABC): experiment_name=self.name, project_name=self.project_name, params=self.config.model_paramters) - self.neptunelogger = NeptuneLogger(**self._neptune_kwargs) + try: + self.neptunelogger = NeptuneLogger(**self._neptune_kwargs) + except ProjectNotFound as e: + print(f'The project "{self.project_name}"') + print(e) + self.csvlogger = CSVLogger(**self._csvlogger_kwargs) self.log_config_as_ini()