From 4fb32c98c6d1a6cd0b969e401302d7428d4c718b Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Tue, 31 Aug 2021 17:27:19 +0200 Subject: [PATCH] Multithreaded Operation Debugging Doors, when no Doors are present Smaller Bugfixes --- environments/factory/assets/agents/move.png | Bin 0 -> 5933 bytes environments/factory/assets/agents/valid.png | Bin 1530 -> 5717 bytes environments/factory/base/base_factory.py | 130 +++++++++++-------- environments/factory/base/objects.py | 30 +++-- environments/factory/base/registers.py | 13 +- environments/factory/renderer.py | 17 +-- environments/factory/simple_factory.py | 15 ++- environments/helpers.py | 32 +++-- environments/logging/monitor.py | 25 ++-- main.py | 101 ++++++++------ reload_agent.py | 14 +- 11 files changed, 228 insertions(+), 149 deletions(-) create mode 100644 environments/factory/assets/agents/move.png diff --git a/environments/factory/assets/agents/move.png b/environments/factory/assets/agents/move.png new file mode 100644 index 0000000000000000000000000000000000000000..2a56ae404a14ec7eedda1ea1ab84f3c08fe2b842 GIT binary patch literal 5933 zcmV+|7t-j7P) zaB^>EX>4U6ba`-PAZ2)IW&i+q+Rd4JcHAlwhyQ04vjm`kg5~f$XLc~lpRZ8caS|st zxi@ppxZ74+YDrL!-vhb*pMM|sFMd+a(WR`lv{F1jsihV>AF6+T+Wi}Rx_|Fa&aY$P z>;3NiKxCJ&mFHjWukUyE=PxI`-6M>z_q+N!66@T#_Qz*`NUsaE z{d_2`>c^pu_-yCvi+3omU(tK%_lwy7?&FT6L}_KdB{6OZ;`tlpH3G5UAf0XHuSno~ zK{B7SsJjRBH4*~y)%4!B-hT|xS3!RJ08uBSpo%d4Widy9^p!xsl3mwQ^_rTk_->wPs|)ptKh z&hpa_UNf{1$;xpdg%x@@Vc+cxlOd+Kqq2>0k|^6(MPrLoEKkW^;Y5`tiws7YaXZqO zB)*R&-21kBzZoibUV*R1z{LWm{P5*|cJOyz?wn~M3ckL_ig87&T!vwpv%h?cgoN|T zTfPB)d%e={KhzJUN(S={b7g|vkLMC2hcC4yH_yWNGTWysg46Z703qVunZX1h1HMrG zp#-lnY6!%!lAzAWqog>J>xvru~vldmY+O$`xspeX$)>>QbEj9rIb1O|- zZLRelUFSxfYjwWg*)!shkwzY6=%}NOJ}IA>W}ap0th3F&;*teStg>{~)mGoZZIEK8 zop;%~>u$RrYVF8LC!cck)YDFX&)O@i-@Se#YvGl(cr#_+)Ay{g*Qvf;A{d-x%NZGq zDUfls3{cR~a^`DDDO%<%XLL7Bkt2i5mK(Q%WsDT29b!J=J$GL+_e*(mx&Bh#;$LOX zS?c~o<{YW}K5w67Z5hwao3W=0Ev8=RK929&U9895tRw8~t2MB)oK~c_DJk!87X4Rq%ylnvxqE2ov^r)5A_rk2Mp{z|fM}O8 zW@%dnd0tF#J-S-r@7wnd8-mS3$_~Ak z7(olQwkLVHFW+<6+J-|`^brQAl|3cDNl?jIxw0Q^QVl^40R>F}!%ZdGY-c@X`$cz# z)HFnpj~3yE%Bzoj#NtRhd5de9q;l>hy2~hM5`EdK#GF9&)sun^w4GHZ2sx;i{rqF9 z^!K)8XP>IYWKBQ{CBu@q3tgC`H7atAw$wJH=I#%%KBsZA_D~DaAgC)cfCzG0b*W6t+**rIvfb$#jsQYb15F1Z zEjval&|Lr{5JvyUerlfp4tb-bC~iu)wF9aE1+h;x<$sdzw;(*3yPpuAAGd1(ox3V! z0LPIm1YsHZmST#8Bgi0Y+<1|9P;%o`PNucE9bp*o6Q$C@yA77YT7>f&mJ@JTf%!Sb zqR`ST%4`yXq0YULd2xbe7KE@yqeU8ZzwB}?YHjGF7E~k5lVOutQXov|`C2k$dugXY zR#aOyX_BcKt~<(x`r^2~KNAYYS#cG#1!ctzp{RV>6~tu(n|hRXfL+iwWMxlOODw-( zNv~D|i@EmU`1M(+-DmDNkSC12+Kl=I<_4FTsEug_*iI08m3Hn(DT=JfAxtjr9gWT> zEM%k0of$#6)S7i;gMYbb*l85%Yho_oiY|*%UsF3O10-PU#=O*Va)HE|%8-Qm(32T5 z3^9qV?A5zQkV8Q7ab9IVR+V|)TEKnU6Mtj-`mhOuZ7$F^L$AiClqw{8kunVMCSswTMKeMA>0}ThqPnZNQ_BDdnvsz$(VzVTf7r% zdCcfZb~dTwSKG)9NmL&xT8M76Bq}?dsZcGt%!H9;k7h)ZMq!tHy$X8N6ASOg`FiAThrGbbbq?SU&7v-~WW@#XFC29`@Lt&d*t ztH@PB*N!1i);d7_*ZFsU1ZV$0FND^hEQnkO&wxrF6A4G2AnuG2(8qg%_J~P#Bk=_l zfW}9q&)yPB(wsmpt~?A&q5NH0@@(Djosg4}KHRY4UaM$Z5jRq33tQ0i$$Z{Hk}UHF z7!&Fl7Qpf{3-MOYab`>dr~*|i=!otDR=ce!BrJ>&BlCOJI&$Lx5hdU%=!46{Wt$ut z+K@{IpJ9v$g?pmXug$|dbW=9m49rFqfqy#%6GoeXipeSKUX;0YFT=sbVX+P_uSGNc zSnv*}JdpMvnu~uZ|9^R*Z4iX#DUdGQD7lHisS`As&C12#=qzyq>;o3ovl3Ec2iF>; zZfn43CFm%46OxNd64#LMTM-qBG?eBQ4~`TfMF6mnUUrfa~%kf)eYVj3+VQnPL0Nd!#7No#xx(mU6>?W9fvMpBaXz*oz zbK+Yob$!`m*$URr1d01C1!Jx`o9(A(uq4=i3<+~ZTY*R93}56^DT&Iq^@40iF@yAC z#G%|Mz(q6oz%0d{JNWT5PqiQ*h=bd&N0Q*7;+NLj6 z&M;0Sa5(x^`fc7#j4CZ(4!Y@#hE*NDUqthm0Gx5-*Q^)aJj4iS-8WK8%2!J>-?j)e z;Ozm%T@8u_{C-x0Vz8fQH3$ojYdQK}MTo5A|RMXy2E8%g^dLE5fvzby1hMMh}H6&i7MiAv|iLH&k*6~QlJ2(OS$Ss72J8rN%W7dMyw za`o3WQ~-<&iPa7o%xbg3u^82^>YwK*C``<0qmf`~OC88bVvlG<)d48sI4~c`!7hkR zZ$kigO`Om1z&vUXNocnn6z)sh&M2SUVtdbuYqwzzN~og(#g%TR;jWVPFKX^9F_#D> z)_d6U6(-zwQj8p>+*7p0Ew^kXt%IAGuYjdX_gO+QxA#0tkYVDA=uVMmBKF_G9fX>{X9tj~Z!mp8ia{d#@jpT?bFdQA-8Fa!V)Q=ppI*S= zsd**(2;cpwoHwqVjjqK!1y!_ZBJPCW zZ2hjj899Y)4-z3KTN_d4A_yDKL%C1`q0l4nKR73ZBu3CF3O6~0cZUP?i8(+yA*oez za4Ugj=7Uj|-PR%Q^ORFUX=1BeB>Fy2k#wJ@ps*(svFC!edikYn+_!aepQXgv&_PTF z_#8?KwSY_safy8g)WMH+mzv`9JVi%!AXwnehos#HDhmZx-X5sXrc3O3*|L@1^1QJT zQ*C9HyOnm)mP&Q=Cy!IkSEX>4Tx0C=2z zkv&MmKpe$iTct%>9NIxdFhg~+AS&u8RV;#q(pG5I!Q|2}Xws0RxHt-~1qVMCs}3&C zx;nTDg5U>;o12rOir6bq^0;?_xa5{oJ1;Bo_<@cm(1(rs*c}2J!T!sd3&Xj2yKjN3Kf_zi}?v?B$seJ)521~4$GXkI4jjUYu}T< zFj!DmGF+!Qganq5L<#~V)KNhdCStUzq!>uke%!@BX!}#-lE_s7BgX=2P$1fV@IUz7 ztx=quaFT*?p!dbHK1KllF3@UN*7vbxwN3!vGjOGL{Iw=9`$>AEqeYK^{%zpmx}!;Z zz~v4w@T7^lXiGkt-eM7WKcjET1L0esZ_Vkgy^qreAWK~>-v9@Pz-WoG*InKn=y{D4^000SaNLh0L01FcU01FcV0GgZ_ z00007bV*G`2jvJK4LJz9(7HkZ00k*YL_t(|+U=dqOPp5}ho5^#9UTlALl@GzX=x!F z5iImyh)Hl~!G#7HF@z~yNP~+&R^0?saSc*BVN;Bf1 zqsI4kaj$Wl%=i1w_xUYE6lS;&=bU@a{Ww?A42qk6jv2);jA8`Cbzlh103LlnH=qFz zI7M@W#~WY|#SVro6x)Q^B%#uc9*%0*H{UfkIN>&$o507w0nh_Ps`#`+-~_k=?BekX z%}bK%m-dGmGh=wf(2S$VV2I*58w^V4pQlJ+=*43LO|qVVRh?Y>9x#s~0z?~NJSh~1 zz%pT0YaJp~4XyGjumF66;*>M`1GUI%EI33OLpQJlERt0Ftpjz|+{f^(8%l)c(;=Wy zJY-FM*W6HJW(-IG*MO@uzxILufOwEGR|ROD31AvSKdmN38dzsd#hU}_ocWwHo}(CS zNub$a4X+6?$y4=v2WTHa9{?+D1{6Fru;Rv7M@zk9 zp3ucqgL2OL<+g8m#*gaj+Cz<*r@#b?D5i~(D!`BeekQ5rYUloU%oAV&+8ik0#uo)9 z+_-B1HD(^r#T2ToT|S@1M(JXzk{GMX1Ue0@U{`@URR_RV@fDy(acQyK1WSL|N{y~3@eGQeQ&4njUlU2qA6sJujL zp}3{SO+Oj_1;WCsY#r0f2icp6I_KZ_P@D>^Oo zNXJ)B)l^!Bh}-ncf$9;14qJf=PEp4%2dZXV*!Z&8@yk$cNtzKdz9y^;cg>9ps2;a4 z;`%;I*C;34zJRK^Dd~OyWq>^vigF2m3J3|A01~WTqMBnyv0DWq|2p89QH^03dxBZy zUp*LxHHs0M^SEOxDuQA}yGskA;Ie@08W<81foe#j84wPER#pSr-BS_;m#_9|?8Prp zaO&0|i+w={fb*+~q5uPc2H8FdSf?7zk#Gof_CM0@HnJ$V+5e3O_Jl;B+S4d@ghPN* z>}U*I*t?n{|2o95rBQ5S8xV;6>jcHNCd?+CeU?8+T_Mb-c4vTgr3COj>jUhPQ0epC z)K^lh`p!=+$9Hvq4=+%pgp99QKZWL{{qujAKSJ@+%jWH~88aoHv;Gw2RK!fwYuQ%-Zm@|@bX4SIxa zL)>XG6zoFNqjH-nR)(lEdi>?8;+}uf>#rD@y7c<9pz&I={hLMsGX9aO-% z!eii|?Xm(po&(@aw|%G)V5lf3d%W666|bpd>%SX|&qxjlcf^eg(qFc1z|* P00000NkvXXu0mjf>*Gd{ literal 0 HcmV?d00001 diff --git a/environments/factory/assets/agents/valid.png b/environments/factory/assets/agents/valid.png index ae7c768f910fad2411ae7284a662b041dc12e50b..8341ab6a215b05bd47754b9ebf12581113e78fd9 100644 GIT binary patch literal 5717 zcmV-b7OLrqP) zaB^>EX>4U6ba`-PAZ2)IW&i+q+SQp`mK!+^ME|*pSpr`m$mQ^P&g@{8KNn1?+OM|D z?fDs(T`8whnIr;;2taQC=ikTui$6IP@3LqutrX9n+;WSZ7u7$1+Wj4TzQ6Auk=J|S z>+|m8jmW9QF+Km){<^+%UjF*P`x#+;ecsjAJF&iQ^m^lc!(>m6yjWi@lJoI7AM)!) zZ9gwctNP_qM|_U+b>Usg>tE7)>yMk*|J|1-vJ#~wT$N(nDJ0MDm|iPL_AjK)vGO~m zz;z)BuO+GDkkflh1mvsjy^r318KAF%{C?%WjsC~0Z-O8D>wNkqmgN;AUi|QlAU_WN z4e=KbP9H0ZU%#<}Ki_TV_3xZj&)LuJdM-q?yfVtWcXaGCTsR1Mx*se2DF1}ldVd-} z)ptJuXZh&|KQputLFKrR!wNl|uEf9=p z#_fPHr}(j!aG%@m^Jb_#c?B+wftv+B^0zuiK!|v@u$YiQ zz!$PVl;AZ+4S_gT0P3tfrW^+d1d+H{WTY}mbQ){4XY&~i_u442L^};P0feNI6l_}R zfR(e5Kbk8!)Fnw4ks>P9Nt0eu$wf*jT8!knCe12ZRJCf;UQ^9gYN=XlZMC=95*#S4 zG;OuD)_ZiF2X*e%xxRB|#F0iGGRn|VM;m>TJ~L05W$LW6&A#GF3z}GE>8h)(zJuE^ z#m-xH*}Cg)yB~7xq?3=Fa`e>GPX9pdh3a=dzkynKp%yPvcAb8p##yKOzC|!P5yK3` zVh$j#h5!j24KrUu&e4!F%zTR!g+vx9hK-}zK@14<4pC0{!0s>PK82f;`YGJvuaI*F z-TwkP2Xx=#_5;+G>D;`CJymEi^@-}^{aw3@wVSB*?~V9$ccQ$Zk{5mKhg^Gg)30pqDvggBC5MZ8$k8 zQ&Q{1sbT>gGp*j|@jZ{g#y6|-v+JA^n%5TA=Fzqo(xnv%gA+1cCXqCa6~a;v!gWs_ zg%f<7u%mSBjaTTR?TxUk*vgovw^l_O$PjfnJNCXyBhr12d`P%W4VIa@$Vf7#Uq*nb zX_Q%^>FN{qVK-3=e5E05nKUg1q)+w2{d9*$8gjo+RcB&;V3gu4&s6jUwGUI4Y#y*}+l~E+-!^XjqB&ajz#;R9~xHpn} zX%BfMo|34+PNt!@War8{LC$&bsysiJx)w>2SvkQI>=0^-!jnlL{vd$<|J~7jd15UR zg4{m6wv+mnk}2Jik%LT7Shj~`rlw8ljntbI*e9{BQZS=^L5Cu0`Ga{2o^)& zhal?!^Cn%Z&*#_@$+?+Y359!e-5NDH3amnLJLWRW33Y3y*jWe3E+C1CSgDWWfsDX6 zDg`Vh0#P_nmeD!b1$;K~ElOD4>Gzw*l4YN&4&y>Xv!Fcirn^v2*qB)%EAss{&4 zIo-BQcp`Cw8ADqsv{*GJEc0=s^${p~fb4}{;#cpA2nY9#{Y<^mrKbL;SXdCi336q7f& z)*gi1C^UpABxdBAFFIKMajT4>ww6tyWDVRDGK0-4PDeMID5jS1C6V0gVLAB83fXsE zIu9fOjYNcH3~XafHKI3Mn$s>gZiz)nN@_N*mTxZbGpI9u&!) zNIIt7dtJr;LNK(imoNsV@V-Jp!Dl)26uM&Xn4QF`sczcD4Y+Iw-mVns2&YxHLo`P9 zJ9X7dblf|$?#O3_Jv# zkda-%ALxGJL(t%z3p+;aw~p>J7vhj)Lk9%;P$?Z_a7?raD(!5VO0_uFZq+|4utD^8 zaY}=Y%qXL=QiK)}Z-?;kfn(N!U{C~bI~&Ye(jvRufCu3UWxHJp*Bva$DL>bP(>+*v zl0}dp);fghUmD^(Fc-qZM=7LHIVjCoJrIRxf}ap1RJhwY7-8w?c$^whe|Q1^K#9QQ z9p*PHv7l!R6PhLc$rhZ{EGg=$Fm zWF|}|2CnPwZ5HTk*%oE*W5@1swIE97X*}#y*(e2q6J?(iX{bYaQvckJU84DSnd~>u=QG`CrK5}a{;9*LF-XC}nxsDm99=$o{Rs6u!zAZD}f@D{Jdx3=m;WK0#8mo10LV!iL zW-o41<)%ZXgL`w9`zLJpqvQY6jRoUIK;KMkBXy3`)@K{o>(N}nk?4t+nR-%2IN>rGNyre?uRu={TbdnCm$^VNuxldg< zXgP?)!7R}2kHXFh)+{xQdrS0tzBpQBJh!r|3=c!xX4o!qW8cyx?uFsPV-=xMp|zwt zurg}_+dESXW&mMGI!ts~m(8Vd*LEbRebdk)rvn0SW+AT7hDT_>kXhpQl5JXuA$=|z z>>cX5MKy1hK&@S7Tvbtj{3BK`gOJHoPD`wLq+`o!c8ayV5|Vpjoc@IW-S2z=H6(ra zBePdz$)N79#yX;*$hIY?$ub(?Goi^9p^)ZDzKaTQOd8B9+hY*ZsNM`SyNBZ0n=LYWA>JTVV+{2m)CoYYH(*Y9K*13e4J{qNmt&56?sfMa36;QCY_fzKKSnv{C`g=XVQ`UP=#R7=R0P5r3Srxo3LislLf(Q}d@ zR!xhTw>bT~5%pwQ2oUzl=6f&@ zJV-;M#0X>;7`I}G5cB|nG+GZDbZF95XpX`lS&l!p$6^yP3V7#n^n;i#|N#UN4?uR*rXO zLJdoQMkX;tw*HJv@OCHcvo_AI_DREH6E?H-q>Zz@xg(R@$ett3I+P-XoKhvH7;PGC$KSph-Z3q5qg>*)YrfLiK+x2pGUS`4_AH4rx;#qnO;r z2rA&e%qh?wR5fX}$$u{m=G#JM+vMNIz6TQ3Rt+ln$T~+!G>ZfUX)}qkF)q6xsp_>& z-oC#s)Qo+{32lUJUz%){pjI_P5ofkW!mX!uMBfUI)IX$0yLsx^9z-|7xnLO}B{BD| zxH~-FL^YSq4Uoy{=$Kg8vackB#&ci;Xqv3rR$kX5&z1W#q=yQ;ygjnV@ma@&L4mTn z!(GxZ@rbeI_Dke7AEX7LiItb^)b^DPAnWn5LZk)TrGXr;4mPFxwQtG3GqET3JN}RO zahosWJZ;8^;X=647$F~#0X#h^XP@a{C4fy-!1 z?>8_Qjz~~U5n^|c^SEEZ)MWFr*gzQPk7vm5`9G=Lc+0-MdaTS4%h6W~+{r)oSWN?( zP5uMT2E-^4fS)f{>Z&vMmcXD!b3$Z_+E;T*whzl2^zz1Qd3O^hkvQ=hydTTxqoO|@ z{ySEPA)zVxFSH`I^}qV2hyVZqg=s@WP)S2WAaHVTW@&6?004NLeUUv#!$2IxUt6U` zS{&LzL@+~jvLGtzC{-+ih0<1N)xqS_FKE(`q_{W=t_24_7OM^}&bm6d3WDGVh?|>} zqKlOHzogJ2#)IR2yu0_fdk1)%Ri+p{CyubPC=;I(kLh$l;zzDa4!?0O z+3e++5j~rpCyo#cr7o7bn3Z*fc#1eGD=OuCa}LX#w>T@+I&0sPzc5%(S2A3uIfMk3 zkVFasB-Bws6((Y|s-zf5(|+8=KWO_?_x@b#2 zn%-g&ct4|W$^+qBpl{9Tt-X)a2OvvbE#CkKhrnowve#YS9q8`u-!tw0egMbWa>>zB zi@5**00v@9M??Vs0RI60puMM)00009a7bBm000XU000XU0RWnu7ytkO2XskIMF-^w z9}PJyhSOe?000FnNklYJOgKdgDK|w$8xa)z3xfPH?h4!p z7G+m=QnV0SwFo3}5!p?n0$T*NQ?u6Xf&_CT6uGfz5p~5ur*56o;ypGqI(NMHo^!o3 zci!)QW;ow@-XG8V<9RbHflo~7*e16{uvP9^Y{tWQ5Df(6=-?Lq#5KVcxl4k}8ZLUt z?5IB{8`)nz?RmL93cK(GZr}k-4)EJ<;%_{Jt1=f9zS8vd_)BfsDf6n>eANHS8{8cv3&I(LdhA>w=ELRIL#? z$d_hj_Ty8TFw#M&UUD4J zbY{?BFi}D1QM@CwVT^>5$?%*RnRC9qe^Nj2p3I?DG8hfQdTh7HX6IcPe)DfRjTZ#d za^p)}X=vk1O|K3wQe%Yf!}M4+`K5^8w*gP~8?h68MCehOL&6#mx;SDt;2AR}dhKH; zdK9`J?_lGa61q5IcbIX_IX_hIQ{z1JyD^(K4L12KiuEjtBYtm@QrxGC+6-4jv)G4CQ?5!1SH(73a-XS&f zve6uL7zd_He|JbN>`iVeMeN?iQe#Rwm&{Zeu;!fh{Jhl6Zh30@6=$r|a(m{b7IuyB zP~~`#l^&6m6aKa;+trEN{A<1GpkteO;5Jd68#=ZHxh>)e=0tTKklPXjTdmCFj;oPL z!PZb-T1eD43wSn&w?Gp0u{ns>Y7+JFa1ie)CFNEuT_BCR}&L8 zj{kwQeNr=>AYPtL)W@wL-bPN;$DaY$l80Kk76ez4hdRf8SAyIn@vdf~IydDm1;J&p z070TUe+w=L4Hw0V4vFeKq~T&H&j4LbdsG#P{;qn-?Cssu3+ZG^MT*kn|9O9KUeHcv zDvkPWg|Fs=7Fupe2YstBrx}+~7oMd$o4i!&Q_s!^?R=l`F{uj}kap?M;^q&h{G#w! z8npTx^)tUZ=<&NFHF82W`m%1TOzJF7EFQBK($FgFHYIu`*U-@FE;S|s$J1IsAGtoq z@2tpn_x{v6$E1C$-e;%Lm*u6t;=151hEuX_jEJ6=C~ddgh2NjW5}b*pxvJ^OrNg&7 zhs(-hiMy@M(!JVRcd7B(^_e7gNug(!s||SX1;3UbWu8h2!TFMRY-+OYqudE%C zGAHt~gp_JhV-c1EiZyuy*t2UQV&^0DP`^Jq*oCz7E|@Wa9kMa2 zViDW!4IJyI$yaOsi@SnY>36&qU@BMoO*N?(W=bofCiOy6_4FY~jom@2^Dn(fu?k+K zSV($1Wr(C#euhkXkLtb%-aWLJ(}MF}_08z=hfylH!Jif00000NkvXX Hu0mjfEr=C6 delta 1513 zcmVNC95ft1hxG0%a z#Eq2-7DOmYXW|cP5q}Wz2dRobc~jJa6e_e>HzGDObRpa{qs8v4Ig38qjRomY~;HGwNiQp)mK4kpxbQ*vavE^d~gntmzXHL z?fc$-p&Flw9tF7?bf@pZ`qw9OpnroL^67K^L9Ikbfqw%iWBwn$%Nhj}hMi!dc*w_{ zvO}%Be>*yK3Sn1Y0&P$|tIr-v2mtng>-5M6!%Fo$i?V-j#lxGLtLR*n8TkKo5 zKGqIuJj+W7)MSCS=r5kU3N}$eXQ)XW`rUwqP40$~WWQ?2(jA@R*v?SP(J#Pm_fZV& z6~B@%-n8D+tq--Im!92vrxr`vwV|5ii0J9QUVnOa6RgUy4z_wluw~y{S&(QF$)*Xb^P|MNl zAWv@2jQwWA9q%B#N51&N%8@lgHJfw``WG;|*@nM~$FsnwT(fdBu^FoIrRd9H#&>vy zaes&W!#Tv|ykv)qf9u zdDXmymDLPXmy=*yn3esb2TQeQWh)t5YVfJHxa{PNPYtub9V=U@WcCknsx2=&ITNRb z6_p(;TPd@m3TuB`S{8EBSo?=8l?5$Jk;_@CkWaNFW+PV5r-sbG4J=!kdgkAVhdnA= zXlu(g?NJ%Au$!74s?pH2n`)ZG)PLku*!R!CHtyMyl>USYl?)tPE^Vn$iCsXbrAp6t z&04BZ^_L#1IT+lo>d$Dbg!)uxuR#!y2gwKLR?mDHi;-=bo;?)x1VJ<5c*PC_WtYK z?PsMw2_4j|rZ+dJi7N=b{yHrStkAa#Z4HZc?(x^aKgg*aWI*cB^Na14G4zl6fDGsn) zMZ;HqT!}f#hwo820&^9Q0oV3+?aOIb@Rey2<2o!K@_uLkv?Cg@Z~XlapB`q@Q#_`F P00000NkvXXu0mjf6tnmc diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index 3e96e90..7749898 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -60,9 +60,12 @@ class BaseFactory(gym.Env): omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True, verbose=False, doors_have_area=True, env_seed=time.time_ns(), **kwargs): assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1." + if kwargs: + print(f'Following kwargs were passed, but ignored: {kwargs}') # Attribute Assignment self.env_seed = env_seed + self.seed(env_seed) self._base_rng = np.random.default_rng(self.env_seed) self.movement_properties = movement_properties self.level_name = level_name @@ -85,11 +88,6 @@ class BaseFactory(gym.Env): self.parse_doors = parse_doors self.doors_have_area = doors_have_area - # Actions - self._actions = Actions(self.movement_properties, can_use_doors=self.parse_doors) - if additional_actions := self.additional_actions: - self._actions.register_additional_items(additional_actions) - # Reset self.reset() @@ -123,11 +121,17 @@ class BaseFactory(gym.Env): self.NO_POS_TILE = Tile(c.NO_POS.value) # Doors - parsed_doors = h.one_hot_level(parsed_level, c.DOOR) - if np.any(parsed_doors): - door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)] - doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor, is_blocking_light=True) - entities.update({c.DOORS: doors}) + if self.parse_doors: + parsed_doors = h.one_hot_level(parsed_level, c.DOOR) + if np.any(parsed_doors): + door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)] + doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor, is_blocking_light=True) + entities.update({c.DOORS: doors}) + + # Actions + self._actions = Actions(self.movement_properties, can_use_doors=self.parse_doors) + if additional_actions := self.additional_actions: + self._actions.register_additional_items(additional_actions) # Agents agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape) @@ -155,8 +159,8 @@ class BaseFactory(gym.Env): # Optionally Pad this obs cube for pomdp cases if r := self.pomdp_r: x, y = self._level_shape + # was c.SHADOW self._padded_obs_cube = np.full((obs_cube_z, x + r*2, y + r*2), c.SHADOWED_CELL.value, dtype=np.float32) - # self._padded_obs_cube[0] = c.OCCUPIED_CELL.value self._padded_obs_cube[:, r:r+x, r:r+y] = self._obs_cube def reset(self) -> (np.ndarray, int, bool, dict): @@ -170,7 +174,10 @@ class BaseFactory(gym.Env): return obs def step(self, actions): - actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions + + if self.n_agents == 1: + actions = [int(actions)] + assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]' self._steps += 1 done = False @@ -180,9 +187,10 @@ class BaseFactory(gym.Env): # Move this in a seperate function? for action, agent in zip(actions, self[c.AGENT]): - agent.clear_temp_sate() + agent.clear_temp_state() action_obj = self._actions[int(action)] - if self._actions.is_moving_action(action_obj): + self.print(f'Action #{action} has been resolved to: {action_obj}') + if h.MovingAction.is_member(action_obj): valid = self._move_or_colide(agent, action_obj) elif h.EnvActions.NOOP == agent.temp_action: valid = c.VALID @@ -210,7 +218,8 @@ class BaseFactory(gym.Env): # Step the door close intervall if self.parse_doors: - self[c.DOORS].tick_doors() + if doors := self[c.DOORS]: + doors.tick_doors() # Finalize reward, reward_info = self.calculate_reward() @@ -229,15 +238,18 @@ class BaseFactory(gym.Env): return obs, reward, done, info def _handle_door_interaction(self, agent) -> c: - # Check if agent really is standing on a door: - if self.doors_have_area: - door = self[c.DOORS].get_near_position(agent.pos) - else: - door = self[c.DOORS].by_pos(agent.pos) - if door is not None: - door.use() - return c.VALID - # When he doesn't... + if doors := self[c.DOORS]: + # Check if agent really is standing on a door: + if self.doors_have_area: + door = doors.get_near_position(agent.pos) + else: + door = doors.by_pos(agent.pos) + if door is not None: + door.use() + return c.VALID + # When he doesn't... + else: + return c.NOT_VALID else: return c.NOT_VALID @@ -284,8 +296,9 @@ class BaseFactory(gym.Env): state_array_dict[c.AGENT][0, agent.x, agent.y] += agent.encoding if r := self.pomdp_r: + self._padded_obs_cube[:] = c.SHADOWED_CELL.value # Was c.SHADOW + # self._padded_obs_cube[0] = c.OCCUPIED_CELL.value x, y = self._level_shape - self._padded_obs_cube[:] = c.SHADOWED_CELL.value self._padded_obs_cube[:, r:r + x, r:r + y] = self._obs_cube global_x, global_y = map(sum, zip(agent.pos, (r, r))) x0, x1 = max(0, global_x - self.pomdp_r), global_x + self.pomdp_r + 1 @@ -297,20 +310,22 @@ class BaseFactory(gym.Env): if self.cast_shadows: obs_block_light = [obs[idx] != c.OCCUPIED_CELL.value for idx in shadowing_idxs] door_shadowing = False - if door := self[c.DOORS].by_pos(agent.pos): - if door.is_closed: - for group in door.connectivity_subgroups: - if agent.last_pos not in group: - door_shadowing = True - if self.pomdp_r: - blocking = [tuple(np.subtract(x, agent.pos) + (self.pomdp_r, self.pomdp_r)) - for x in group] - xs, ys = zip(*blocking) - else: - xs, ys = zip(*group) + if self.parse_doors: + if doors := self[c.DOORS]: + if door := doors.by_pos(agent.pos): + if door.is_closed: + for group in door.connectivity_subgroups: + if agent.last_pos not in group: + door_shadowing = True + if self.pomdp_r: + blocking = [tuple(np.subtract(x, agent.pos) + (self.pomdp_r, self.pomdp_r)) + for x in group] + xs, ys = zip(*blocking) + else: + xs, ys = zip(*group) - # noinspection PyUnresolvedReferences - obs_block_light[0][xs, ys] = False + # noinspection PyUnresolvedReferences + obs_block_light[0][xs, ys] = False light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int)) if self.pomdp_r: @@ -361,22 +376,24 @@ class BaseFactory(gym.Env): return tile, valid if self.parse_doors and agent.last_pos != c.NO_POS: - if door := self[c.DOORS].by_pos(new_tile.pos): - if door.can_collide: - return agent.tile, c.NOT_VALID - else: # door.is_closed: - pass + if doors := self[c.DOORS]: + if self.doors_have_area: + if door := doors.by_pos(new_tile.pos): + if door.can_collide: + return agent.tile, c.NOT_VALID + else: # door.is_closed: + pass - if door := self[c.DOORS].by_pos(agent.pos): - if door.is_open: - pass - else: # door.is_closed: - if door.is_linked(agent.last_pos, new_tile.pos): + if door := doors.by_pos(agent.pos): + if door.is_open: pass - else: - return agent.tile, c.NOT_VALID - else: - pass + else: # door.is_closed: + if door.is_linked(agent.last_pos, new_tile.pos): + pass + else: + return agent.tile, c.NOT_VALID + else: + pass else: pass @@ -391,7 +408,9 @@ class BaseFactory(gym.Env): if self._actions.is_moving_action(agent.temp_action): if agent.temp_valid: # info_dict.update(movement=1) - reward -= 0.00 + # info_dict.update({f'{agent.name}_failed_action': 1}) + # reward += 0.00 + pass else: # self.print('collision') reward -= 0.01 @@ -400,16 +419,17 @@ class BaseFactory(gym.Env): elif h.EnvActions.USE_DOOR == agent.temp_action: if agent.temp_valid: + # reward += 0.00 self.print(f'{agent.name} did just use the door at {agent.pos}.') info_dict.update(door_used=1) else: - reward -= 0.00 + # reward -= 0.00 self.print(f'{agent.name} just tried to use a door at {agent.pos}, but failed.') info_dict.update({f'{agent.name}_failed_action': 1}) info_dict.update({f'{agent.name}_failed_door_open': 1}) elif h.EnvActions.NOOP == agent.temp_action: info_dict.update(no_op=1) - reward -= 0.00 + # reward -= 0.00 additional_reward, additional_info_dict = self.calculate_additional_reward(agent) reward += additional_reward diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py index aa67b14..fda6607 100644 --- a/environments/factory/base/objects.py +++ b/environments/factory/base/objects.py @@ -24,15 +24,27 @@ class Object: @property def identifier(self): - return self._enum_ident - - def __init__(self, enum_ident: Union[Enum, None] = None, is_blocking_light=False, **kwargs): - self._enum_ident = enum_ident if self._enum_ident is not None: - self._name = f'{self.__class__.__name__}[{self._enum_ident.name}]' + return self._enum_ident + elif self._str_ident is not None: + return self._str_ident else: + return self._name + + def __init__(self, str_ident: Union[str, None] = None, enum_ident: Union[Enum, None] = None, is_blocking_light=False, **kwargs): + self._str_ident = str_ident + self._enum_ident = enum_ident + + if self._enum_ident is not None and self._str_ident is None: + self._name = f'{self.__class__.__name__}[{self._enum_ident.name}]' + elif self._str_ident is not None and self._enum_ident is None: + self._name = f'{self.__class__.__name__}[{self._str_ident}]' + elif self._str_ident is None and self._enum_ident is None: self._name = f'{self.__class__.__name__}#{self._u_idx}' - Object._u_idx += 1 + Object._u_idx += 1 + else: + raise ValueError('Please use either of the idents.') + self._is_blocking_light = is_blocking_light if kwargs: print(f'Following kwargs were passed, but ignored: {kwargs}') @@ -166,7 +178,7 @@ class Door(Entity): @property def encoding(self): - return 1 if self.is_closed else -1 + return 1 if self.is_closed else 0.5 @property def access_area(self): @@ -274,10 +286,10 @@ class Agent(MoveableEntity): def __init__(self, *args, **kwargs): super(Agent, self).__init__(*args, **kwargs) - self.clear_temp_sate() + self.clear_temp_state() # noinspection PyAttributeOutsideInit - def clear_temp_sate(self): + def clear_temp_state(self): # for attr in self.__dict__: # if attr.startswith('temp'): self.temp_collisions = [] diff --git a/environments/factory/base/registers.py b/environments/factory/base/registers.py index 6cb9b21..4f09641 100644 --- a/environments/factory/base/registers.py +++ b/environments/factory/base/registers.py @@ -53,7 +53,10 @@ class Register: return next(v for i, v in enumerate(self._register.values()) if i == item) except StopIteration: return None - return self._register[item] + try: + return self._register[item] + except KeyError: + return None def __repr__(self): return f'{self.__class__.__name__}({self._register})' @@ -84,8 +87,8 @@ class EntityObjectRegister(ObjectRegister, ABC): @classmethod def from_tiles(cls, tiles, *args, **kwargs): # objects_name = cls._accepted_objects.__name__ - entities = [cls._accepted_objects(tile, **kwargs) - for tile in tiles] + entities = [cls._accepted_objects(tile, str_ident=i, **kwargs) + for i, tile in enumerate(tiles)] register_obj = cls(*args) register_obj.register_additional_items(entities) return register_obj @@ -294,10 +297,10 @@ class Actions(Register): if self.allow_square_movement: self.register_additional_items([self._accepted_objects(enum_ident=direction) - for direction in h.ManhattanMoves]) + for direction in h.MovingAction.square()]) if self.allow_diagonal_movement: self.register_additional_items([self._accepted_objects(enum_ident=direction) - for direction in h.DiagonalMoves]) + for direction in h.MovingAction.diagonal()]) self._movement_actions = self._register.copy() if self.can_use_doors: self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.USE_DOOR)]) diff --git a/environments/factory/renderer.py b/environments/factory/renderer.py index 42491db..a4ca734 100644 --- a/environments/factory/renderer.py +++ b/environments/factory/renderer.py @@ -79,14 +79,15 @@ class Renderer: rects = [] for i, j in product(range(-self.view_radius, self.view_radius+1), range(-self.view_radius, self.view_radius+1)): - if bool(view[self.view_radius+j, self.view_radius+i]): - visibility_rect = bp['dest'].copy() - visibility_rect.centerx += i*self.cell_size - visibility_rect.centery += j*self.cell_size - shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) - pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect()) - shape_surf.set_alpha(64) - rects.append(dict(source=shape_surf, dest=visibility_rect)) + if view is not None: + if bool(view[self.view_radius+j, self.view_radius+i]): + visibility_rect = bp['dest'].copy() + visibility_rect.centerx += i*self.cell_size + visibility_rect.centery += j*self.cell_size + shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) + pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect()) + shape_surf.set_alpha(64) + rects.append(dict(source=shape_surf, dest=visibility_rect)) return rects def render(self, entities): diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 667e145..1df8bef 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -94,6 +94,10 @@ class DirtRegister(MovingEntityObjectRegister): return c.NOT_VALID return c.VALID + def __repr__(self): + s = super(DirtRegister, self).__repr__() + return f'{s[:-1]}, {self.amount})' + def softmax(x): """Compute softmax values for each sets of scores in x.""" @@ -149,7 +153,10 @@ class SimpleFactory(BaseFactory): return c.NOT_VALID def trigger_dirt_spawn(self): - free_for_dirt = self[c.FLOOR].empty_tiles + free_for_dirt = [x for x in self[c.FLOOR] + if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), Dirt)) + ] + self._dirt_rng.shuffle(free_for_dirt) new_spawn = self._dirt_rng.uniform(0, self.dirt_properties.max_spawn_ratio) n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt))) self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles]) @@ -216,7 +223,7 @@ class SimpleFactory(BaseFactory): self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.') info_dict.update(dirt_cleaned=1) else: - reward -= 0.00 + reward -= 0.01 self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.') info_dict.update({f'{agent.name}_failed_action': 1}) info_dict.update({f'{agent.name}_failed_action': 1}) @@ -235,8 +242,8 @@ if __name__ == '__main__': factory = SimpleFactory(n_agents=1, done_at_collision=False, frames_to_stack=0, level_name='rooms', max_steps=400, combin_agent_obs=True, - omit_agent_in_obs=True, parse_doors=True, pomdp_r=2, - record_episodes=False, verbose=True + omit_agent_in_obs=True, parse_doors=False, pomdp_r=2, + record_episodes=False, verbose=True, cast_shadows=False ) # noinspection DuplicatedCode diff --git a/environments/helpers.py b/environments/helpers.py index b9de4b1..c9127b9 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -50,19 +50,28 @@ class Constants(Enum): return bool(self.value) -class ManhattanMoves(Enum): +class MovingAction(Enum): NORTH = 'north' EAST = 'east' SOUTH = 'south' WEST = 'west' - - -class DiagonalMoves(Enum): NORTHEAST = 'north_east' SOUTHEAST = 'south_east' SOUTHWEST = 'south_west' NORTHWEST = 'north_west' + @classmethod + def is_member(cls, other): + return any([other == direction for direction in cls]) + + @classmethod + def square(cls): + return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST] + + @classmethod + def diagonal(cls): + return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST] + class EnvActions(Enum): NOOP = 'no_op' @@ -71,14 +80,13 @@ class EnvActions(Enum): ITEM_ACTION = 'item_action' -d = DiagonalMoves -m = ManhattanMoves +m = MovingAction c = Constants -ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), d.NORTHEAST: (-1, +1), - m.EAST: (0, 1), d.SOUTHEAST: (1, 1), - m.SOUTH: (1, 0), d.SOUTHWEST: (+1, -1), - m.WEST: (0, -1), d.NORTHWEST: (-1, -1) +ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1), + m.EAST: (0, 1), m.SOUTHEAST: (1, 1), + m.SOUTH: (1, 0), m.SOUTHWEST: (+1, -1), + m.WEST: (0, -1), m.NORTHWEST: (-1, -1) } ) @@ -126,8 +134,10 @@ def asset_str(agent): return 'agent_collision', 'blank' elif not agent.temp_valid or c.LEVEL.name in col_names or c.AGENT.name in col_names: return c.AGENT.value, 'invalid' - elif agent.temp_valid: + elif agent.temp_valid and not MovingAction.is_member(agent.temp_action): return c.AGENT.value, 'valid' + elif agent.temp_valid and MovingAction.is_member(agent.temp_action): + return c.AGENT.value, 'move' else: return c.AGENT.value, 'idle' diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py index 93786d0..9ded10b 100644 --- a/environments/logging/monitor.py +++ b/environments/logging/monitor.py @@ -1,4 +1,5 @@ import pickle +from collections import defaultdict from pathlib import Path from typing import List, Dict @@ -17,7 +18,7 @@ class MonitorCallback(BaseCallback): super(MonitorCallback, self).__init__() self.filepath = Path(filepath) self._monitor_df = pd.DataFrame() - self._monitor_dict = dict() + self._monitor_dicts = defaultdict(dict) self.plotting = plotting self.started = False self.closed = False @@ -69,16 +70,22 @@ class MonitorCallback(BaseCallback): def _on_step(self, alt_infos: List[Dict] = None, alt_dones: List[bool] = None) -> bool: infos = alt_infos or self.locals.get('infos', []) - dones = alt_dones or self.locals.get('dones', None) or self.locals.get('done', [None]) - for _, info in enumerate(infos): - self._monitor_dict[self.num_timesteps] = {key: val for key, val in info.items() - if key not in ['terminal_observation', 'episode'] - and not key.startswith('rec_')} + if alt_dones is not None: + dones = alt_dones + elif self.locals.get('dones', None) is not None: + dones =self.locals.get('dones', None) + elif self.locals.get('dones', None) is not None: + dones = self.locals.get('done', [None]) + else: + dones = [] - for env_idx, done in enumerate(dones): + for env_idx, (info, done) in enumerate(zip(infos, dones)): + self._monitor_dicts[env_idx][self.num_timesteps - env_idx] = {key: val for key, val in info.items() + if key not in ['terminal_observation', 'episode'] + and not key.startswith('rec_')} if done: - env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index') - self._monitor_dict = dict() + env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index') + self._monitor_dicts[env_idx] = dict() columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS] env_monitor_df = env_monitor_df.aggregate( {col: 'mean' if col.endswith('ount') else 'sum' for col in columns} diff --git a/main.py b/main.py index 5967248..d64a472 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,7 @@ import time import pandas as pd from stable_baselines3.common.callbacks import CallbackList +from stable_baselines3.common.vec_env import SubprocVecEnv from environments.factory.double_task_factory import DoubleTaskFactory, ItemProperties from environments.factory.simple_factory import DirtProperties, SimpleFactory @@ -84,8 +85,20 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List print('Plotting done.') +def make_env(env_kwargs_dict): + + def _init(): + with SimpleFactory(**env_kwargs_dict) as init_env: + return init_env + + return _init + + if __name__ == '__main__': + # combine_runs(Path('debug_out') / 'A2C_1630314192') + # exit() + # compare_runs(Path('debug_out'), 1623052687, ['step_reward']) # exit() @@ -93,65 +106,67 @@ if __name__ == '__main__': from algorithms.reg_dqn import RegDQN # from sb3_contrib import QRDQN - dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20, - max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05, - dirt_smear_amount=0.0, agent_can_interact=False) + dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20, + max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05, + dirt_smear_amount=0.0, agent_can_interact=True) item_props = ItemProperties(n_items=5, agent_can_interact=True) - move_props = MovementProperties(allow_diagonal_movement=True, + move_props = MovementProperties(allow_diagonal_movement=False, allow_square_movement=True, allow_no_op=False) - train_steps = 6e5 + train_steps = 1e6 time_stamp = int(time.time()) out_path = None for modeL_type in [A2C, PPO, DQN]: # ,RegDQN, QRDQN]: for seed in range(3): + env_kwargs = dict(n_agents=1, + # with_dirt=True, + # item_properties=item_props, + dirt_properties=dirt_props, + movement_properties=move_props, + pomdp_r=2, max_steps=400, parse_doors=True, + level_name='simple', frames_to_stack=6, + omit_agent_in_obs=True, combin_agent_obs=True, record_episodes=False, + cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False, + ) - with SimpleFactory(n_agents=1, - # with_dirt=True, - # item_properties=item_props, - dirt_properties=dirt_props, - movement_properties=move_props, - pomdp_radius=2, max_steps=500, parse_doors=True, - level_name='rooms', frames_to_stack=3, - omit_agent_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False, - cast_shadows=True, doors_have_area=False, seed=seed, verbose=False, - ) as env: + # env = make_env(env_kwargs)() + env = SubprocVecEnv([make_env(env_kwargs) for _ in range(12)], start_method="spawn") - if modeL_type.__name__ in ["PPO", "A2C"]: - kwargs = dict(ent_coef=0.01) - elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]: - kwargs = dict(buffer_size=50000, - learning_starts=64, - batch_size=64, - target_update_interval=5000, - exploration_fraction=0.25, - exploration_final_eps=0.025) - else: - raise NameError(f'The model "{model.__name__}" has the wrong name.') - model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs) + if modeL_type.__name__ in ["PPO", "A2C"]: + kwargs = dict(ent_coef=0.01) + elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]: + kwargs = dict(buffer_size=50000, + learning_starts=64, + batch_size=64, + target_update_interval=5000, + exploration_fraction=0.25, + exploration_final_eps=0.025) + else: + raise NameError(f'The model "{modeL_type.__name__}" has the wrong name.') + model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs) - out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' + out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' - # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' - identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' - out_path /= identifier + # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' + identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' + out_path /= identifier - callbacks = CallbackList( - [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False), - RecorderCallback(filepath=out_path / f'recorder_{identifier}.json', occupation_map=False, - trajectory_map=False - )] - ) + callbacks = CallbackList( + [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False), + RecorderCallback(filepath=out_path / f'recorder_{identifier}.json', occupation_map=False, + trajectory_map=False + )] + ) - model.learn(total_timesteps=int(train_steps), callback=callbacks) + model.learn(total_timesteps=int(train_steps), callback=callbacks) - save_path = out_path / f'model_{identifier}.zip' - save_path.parent.mkdir(parents=True, exist_ok=True) - model.save(save_path) - env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml') - print("Model Trained and saved") + save_path = out_path / f'model_{identifier}.zip' + save_path.parent.mkdir(parents=True, exist_ok=True) + model.save(save_path) + env.env_method('save_params', out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml') + print("Model Trained and saved") print("Model Group Done.. Plotting...") if out_path: diff --git a/reload_agent.py b/reload_agent.py index f018df7..80b5e49 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -3,7 +3,7 @@ from pathlib import Path import yaml from natsort import natsorted -from stable_baselines3 import PPO +from stable_baselines3 import PPO, DQN, A2C from stable_baselines3.common.evaluation import evaluate_policy from environments.factory.simple_factory import DirtProperties, SimpleFactory @@ -12,16 +12,19 @@ from environments.factory.double_task_factory import ItemProperties, DoubleTaskF warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) +model_map = dict(PPO=PPO, DQN=DQN, A2C=A2C) if __name__ == '__main__': - model_name = 'A2C_1630073286' + model_name = 'A2C_1630414444' run_id = 0 + seed=69 out_path = Path(__file__).parent / 'debug_out' model_path = out_path / model_name with (model_path / f'env_{model_name}.yaml').open('r') as f: env_kwargs = yaml.load(f, Loader=yaml.FullLoader) + env_kwargs.update(verbose=True, env_seed=seed) if False: env_kwargs.update(dirt_properties=DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20, max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05, @@ -30,9 +33,10 @@ if __name__ == '__main__': with SimpleFactory(**env_kwargs) as env: # Edit THIS: + env.seed(seed) model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip'))) this_model = model_files[0] - - model = PPO.load(this_model) - evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True) + model_cls = next(val for key, val in model_map.items() if key in model_name) + model = model_cls.load(this_model) + evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True) print(evaluation_result)