From 155b3fdcf2e18e8362e82b039522af140990d292 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sun, 19 May 2024 19:30:20 +0000 Subject: [PATCH] update aloha_act --- .../aloha_act/actions.safetensors | Bin 5104 -> 130 bytes .../aloha_act/grad_stats.safetensors | Bin 31688 -> 130 bytes .../aloha_act/output_dict.safetensors | Bin 68 -> 127 bytes .../aloha_act/param_stats.safetensors | Bin 33408 -> 130 bytes tests/scripts/save_policy_to_safetensor.py | 18 ++++++++---------- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors index 3c9447d7fa0b68143216f21c4d9cf5c075253fe4..c816148f19e02efd4b7d8195b2f63452665dba44 100644 GIT binary patch literal 130 zcmWN?%MrpL5CG6SRnUNe<-18Y%L2lTN=7gTtJimVS9#AqUb3xq&O@ntU$;k{+yC}S z8;_@&XJK`L7(L6`$oOys3&lbX5JDxLED|IWHF{4GwIQ9O!=xkaY`Ge64%ssla@AMb NvzYx`qzB_r`~bw=CnEp= literal 5104 zcmb8y`CpCc8wc=ECJFH+l}a6H;gptB=P1v0x19DxyO6P*bCPsUY15ExgprtL#?n}# zO<7tjW9_<7N>MWzW*AG85mBVbl-BPt|AFVX=Xt%}ulw~}_kDd{_Z2--<-cb=&zk4B zo+nNECSJyKO=;fc^7K`{cf+c*Oba~S9rEyVFiL#_bC{;ci4RDXtiTL;m9 zC3E;VeSuRc&AN7 z{j+4uYd#-ORt><@(V>`-UkWFTJZRT-1sQLXv-8$wC>z@Vm3PKN-14d5{K*_&ymF!H zQYqVXY!-s*Ak@4LMxoVND3E(nvTQXeCdrw8=UANftOle%jD(nNqv1%)Y&49qBfBwC zjCVi>otO&C_6foHukzrxSRcB$HI?SQNo3MLXQB^R3(cYrT>4=(*j&fQ#@BNx^0thP z+%**)T}H65zCdhsI0^?+{U~F7JS8+GG0m0snAuzdnvOl(?_ZAxD=i_$@AjeGUK#UA z(nXaED$Ft`9Oqstf|_Vw>bkmy94{v@@lrA3jVcg7?cxM~O@v*KgsAQ5OZ6|J*!?se zymnrN4gL{|=Pnii&(D`ei{t6xp9w5Y>WeN{Zo=N5KXMN<^q@FlHjX;$PCh~jyI-e^ z5y=DKb0`cWZyg2ASMJmrxQep>Q81J9({Z8wWyo$G;=Fh9pf%GTvo?Cs%Lh{S+{h4h zXQ;5}b3;+II3KK%y(m|rpgl2h%+JyhPv+FX;UiyxMUxTO|89wo{$k{nEM>nn&cc9` zeK4&l1jFM`!wGdynrIh8o1^2Hc(w`l7vF)LRyEl0SQ|#Xw89Q!Px`e-!j=`7qyAU@ zaQok29KE*;*135R$D&AkY#f_+Wi0-9`zBIRo=v6}VzE|AaUv#0})EEPwIns}LQg(a2A*LSeg{|TR_`azSB7@!N z>u2%gTo%hRt9f|EunHyztAk-D;6B56>BR{O|zWi3`W^1w~M&??Klr zlW1>mEaShMi^~lg!APbBfrXPHQ^Lc4Ivwcbex;M!;M#&7Sn3jvy_e2_eZL!hyC|N* zz2$6PJ`cC5-Gkao6X52(F(6hoLGxdL4$YP^hjG@Ja;^td48t(n_ZKMUxs!iIEUi$J zv!tCGc*?W^DtbqP^hF2PwAB=|c-9oVB8ug78(^PhKl~9Ch9Sv?U@_B+R@$wl(%2+6 zu-gL1uKpFQp_^N`q=T!QXofz8LYiw4#f;V&;D>R8kT4C9U8+{@L&|it zJA4h2QZzsuH5CFnXJfj*6Aj;xu@ZkXoTjP5B9;f^qo6|YKjcB?W@{*Ohnx)*i!gKe z4h)^01i$%Bh0e4Y_|Kp71QrSu@sXlXE*mMYc{nJ zE`68;ho)(O@YYN;;W<*>BPk1fXoD-0`{B;tfmlEBEC}zo)7ERzq|1{tzZ5k*`9~f2 zG^oNx|1NG{y*Ww)EGc$R6w_8W#_x@X!9*(veeUGLDCR{KO%(Aa&nL)vx@JmF!fR$TtD2; zrFDPeI`ssYG{v2+#YeGa;l|j$Wf)R7g`(807`9Y;Q~w+VwbduEdPhLvpEV%J>*4ek zj)arP1;`KapigNl*|H->xGZuQM0vre`mO*jHh7b3WGwlOOJpq<-SOL$Tkr(>IC0)s zSbTRD^5Z!=dvOJGE1iaQ)dP@mBN*qcJqDXg80{%tP3yQgmYVUZqPZO5;2d=jY|??H zn)Y~3&xNu)q|Cw42($H7Sg>OdcHPg1jLoj}r&b(U*2J-j{s|cIst&};j8I&h3X==@ z_}BMB3Og@hu2(FvKD-Y!W(J{1R07HKT_|%Pmd@UYW8WR^udvbo6(ar~56!I_a7ku` zfch-oF|FqrO2Y8zIOp(Kt)xYku9x7>n=J0l^Z=mR&P?A7+W zHuN|~%DNBeVr2FpcrFXVShYM@YUWN=_fkl2dkWK8{l4O+pbk`ae&k{rJ2|MWw=jJ}HNIVku);Xal(Tk2} zNLk!NU5r~j1h))Bv8?bo>`?QhNM+wTijtYwKo>i{y9v$%-P|z;b+CNzgj!dR|0x(ZZ_9^gogNg_C8t1#WcJ3x7;jvwhHhamR~s}DzWaG5dPLjPR+)r3 zm1v{$oqpKg7>cdaioy1vn9AE$QGutNWuD-nX>SE+r=F-@HDSO|j zhu*V>!E9>?)|@^G< zz^y14({^5jnL{3w)+M8M<$9-EG;nP|4J3Jtgpt#pbEi%CI6lCFcC3)GmP=Ex)_D-# z90|snWyP>W%ZGk_yoSb2PiFU{XJU+1Ef^_2a9_`R$+{QUK-1rs@(xB*WKRO~d*p(*=2SuLwr(!iO}br3wdseVgdIjWB=$OsQfzw z-`AgjmT+Ivez%ewOcR-(t2dhMx(P`edbr~oG~j}!Io_&tra!eM%%pHChWYk^@sm(o z@bnl!x|o6wt)_BW95Z=lhG`L3LE1LRbsm}oFI{ahOC~1uT~gK-sgKSxhu}Mt5S&q- z4-pDa%0mV1WAV%+nc-HKS}?c#3MLgyg4`cw;df7+$jesB^y+7!x6LQ;S{{NGMW^6} zwFgbEm6Kq5JUhJH9L>hpL&k+spmtFmR;}S(rg6%9&$gBscvW`r4*0^=Nh$exQU3g`HdmB~QmBIjAIwu#T*FCA!HJ0r4*Rsr} z|KXVhH=yM}AJ@D<9rP|aV#-4?DSnl*H~&t@5n3utF-bWG+%E=*^Q5+>STcQ?#1115 z{rV6x$9?3uNOefpb3&_I9u!+8WAROfSY@cfCVK|sl~KoG9(Ynh@M^mKDv9wMEil~Z z8g%n|xs@v>L2HB&o*#CgampMFbuz|niv~fwFbt=iD}Y8nF$r!a(&6uwbJc4b>={`B zCYw}Y->0##T{H{3KRMA}K@=O%HAS5z%KJ-K2o{zXfxU(&9rjHj*8w@J(DBAoWwp?) zqY0DM#zO8hV8v_}86K;!34Ey1oECgSSI0s`D%Kv_#px?9POe)dF4@;_Hp6LLm zS^I#SanB47Oq@-Q%c7XgkP*H;JqUNxL-0q;gO5+V$;xv*J$RbL{FHrKsk|3Rx4-5p zrat9zw9T<;qb>cU8^wO8fM zTtlNR+Prn6#9b>H&(Rq7st-fK^e~(!EQHc5AL?*hLznj@va~-U@UKlbfzLX)2RS35 zyI>}|EpVjAtrFJ8>7(?_0Gu=o#t5r?z_H3aW^&R>S1|qVaky*Z74UmC3LLJg!zR zGLEiCN}0E>C0dpCLv!)FDQd4+^Slz?DEhd**Y%w01wM*2 zt;sY-#w}0y6ZvhLCG}LHHD>J8jHMuYVh-GV-xYy7by<|xd(>_xpFVrh|~GIz@iFu%GAybiqM22MZY9;pe@t;Ulw z_R5&sCO!OP=`cL*3qgyyMNl}#i`JOPsrUONc3#;J?$I|uaIu}cul|~w5Glmzi@ZrA vBZ{rh(8bkZ!(cEz7|*`WhvcW8q<3{SHIyk`wgdjHS_PY}-f@d24RZen2?GM@ diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors index 7dfbc3b35cc394e87cd77018d4f427928b9aa637..bdecb18b3485e24e9973faf6b2ebdf57c90a6b21 100644 GIT binary patch literal 130 zcmWN?%MrpL5CG6SRnUN8`7A8yMqm+UR5C(xuzG!$ckyTT@s@3^W9~{l`no;q-2S&S zZF#)ZJUFXMi_wdm?QYQT5laCVle6T(LgwNS5XmPBu4*5_S1J^%i3tN3(18qd%86Zd NV8rZSmR5Tr@dLHCCXWCB literal 31688 zcmb`Qdz2H^9mm5dTHjI-tu0zzv;`UWB$JuRWG0&cwOGaaKs+d-G3*9ix(~`m9tG(Z zXw_4TXz|#Bz*2enDCMupL^T`H4GHcy(=lFqc?Z-R<;+dT>|ttC@O3DXqQJ8XKo z{?5BHkcX3-u(SnTjr$X-PV2*Sc;f26v{ebYkg7 zP%*e8&4)`Lbwg7O?o3_j)Y45Q!X2p-Ue`RmsU_7or6G+UBu|YIRfz~k>jfb;5-}80 zFyNAHMP2!3P%#A)E=ebrZsK|=7;wq9BA0F|QNe^u(ut)h9Y6?ORl0gQ0QUKnaKu&0oxJ{`O0 zhO0|<7`Y3WREV!qu6%0wP^l2#rSkFeOG(A7`8wswr9@)m$#h4_ZW!%&-nXodKurB%F<;aVZS zY`OBOFef=qPoQ z0MSZ_wJdQwwRlKYh_x(kx}=68902Bm><1t>fZ?*P#mQGCk3M88#M+iPo?1L)OJ!w? zn=X0u5e@*OTYdmy1IR|cXu6GeH)QIaTV%;pCeu`1-;|s(wRv(wT@G`>>O~nOl7rB98WDC zHm9;S=b0{fr9rTwvJxi>rxp%@Q(1xIWXm3P2%O3aoH(9bJZ8%3n`e4h^1K9jQ-$*8 zto7b-YT+<9RcLNK*|J9+>ZS_Sttg&aJiJX6dRt4nVbY1XsGAtCP~bQqYC|wNRcLZ< zj_hQhbzz~_aZ$vkKmkc~sxFXz@x`6XY?LDtD78@Mcaa{xm7LPt zl1w*FOxMZ&+61Xe*u$*c$<&fzRbkeu(oHAwHUXj$&+JceJh^zxko78Vy6hr>T7_A) z@&iyCfL(=IyW-?a${V5)&um3;JhgcERhacFZo2Hz#~rdTt5|*jY6GyWFl$+^eCJ#s z`EdooN|@CwQ8=}5I98Z-EKauUaff7uS;-Q|lZ(exS+nv?mtFDDsW7WkqHt>A@F$J+ zCr-BHQAaf5MVr-GjmK{{6N^VO(nTlYf>Tscq7lzO8I;6Ri+7$46IfXxU3L(-;iU=0%mtt}058)7UY5_7yow-N(F9sn5>G82 zR;CH8tdK5y^dV-NK+IeKas!yIz{^6_6?hYZ5zc)0b!{vroLV@XOcOX+zHr&&4lUCJ zT2>NIEgn{;39PIqU3SIec9$lQvZ8Qm;czld;AHu1*`p3E(*;^q5>G51$w(KSh&eh# zy^{q9K>Q1GsG0*JH-sm51%Bqmk(>-LK#8`pyQN+fwJ8{yE-*B?XP4Pej$FU#jkUPTa%c>abjiKiA1E7Jv5R!EmU`Vcc+AZ9KAwE=jUF7Pt75^qAV zq6?&~D4be2oJjY3Nta#mkTP8$WkuoS!ZA^SljXBz zk2mK=1rg|hQQGHBx*|#G(#Y0o*>yl!O;wXqlsALwy-7yiWaWSn-Gn7{*}8Vo?1Mt z%n(>vAzgM55Hmv{W-b7=0eG1q@Uncq_wMmYbxUKCC)9CH;oSw36#xI@bfftHoTQ;Ua{83HToNta#mkTOFc zWkuoC!r^3wz{&F2vPT`a$_#;)mBdqvhm{!uD+`}6Z-Ssj1cJr|5gP&n9KN4@L z3{ax2e4_3*fuckNhQ>BknKyv~lIX%*=&6`Qof8BtA`mpsmShJ7M~es?O~j(M1x1T^ zC|a_<@y_I=bSjgqZ<1jP}_i`MLZlW z*^tVln@+$S;AjyKN8{y_%g2~K9L*zMb_VdYh=-@~3Dg$gX%P=kOHOT`*qlkGn(FeX zil0{m1}OJgzYG+$DfnB&!{2&IEg7t?%J(>Ku0=dVP6(kk0-K9?*c^``yUL(-5f82N z!BCrl-$huzo0du^OMa7E;r~EIJP^#)-XB430#j#APe_oR2ZS%e3ST)8Y9nwxlXX1- zLsIAw4+OK}{Sm|_kPbq5^Hyb_vj#vhS+Ntg(K~W#BTzb%RXR@w^OOVimMFp%lT|rC z0JQ<=o5|{%AYXDC!z2Wr=mww0a8p>;V3X_V~{zMmANumvXg?^nXKB;aMb2tcP49h zKD6QwrsUHZ49{c@k3=K4hBd;PUb!&Y!NK}W*7}G%YJ0FhleIq&T>kMu$^ZC(cp$_k zP=cT~0rN9`%&$O@9Sm;snLhI8f>0ZQ`>AN4Ci5Su_c2<6RN+RFd{ zMC=2FHhM=+ZUkSC=i`1o4--o641@!~+{^Rae2n-}8-VvkeY`K}vuEfib z4uapL9fIVJK)4e15x@L1qsP7vhB^g!U)0C@^7*n;fciy!)UO0VZ35;O^)Wy8slbZ7 z3?RV#M$VBya$}gjj{|x&COau;VAMwg1vqMRu)wH~1+uitybL7J{LTf1u@%#(tw92# zJ`(8Jn(W|k$1m#Rfg&EYJ!oLmM+5VZAIe^Rpng#w^(#S8n}GR6eax>wkR1%WKw~q++94P6JEl$1zDQD!zA>>A2gZP4Jzl~nF5LYB?4#~m!4WEM}HwQNa&2JsJ zY0~=#Hwi;QA46mx?93m{ULd(a7$SInE2Pk1u|`pEra36^d?bBcg zGjK_dL%2gIXWsI=1NngX$YNzr{yOmv!*8P(OdxL_Qr}wbkbv{iM)pa}A{=?CIo4S8 zW(Lot=59{*Dg%}S^jj8OAiY6nh4^?QdpM74hvfJiZ!CI=gUl<(7&RDI)9$7m$I{%Sx+RyFipItxbWP9N9!I7)qIoTeuC}^&D={Hv0 zy$j5T=Jc~WlYMfB#=c^Woc4g(x8YaToOJ)loqyY6T{`vQjy~-_w%+P&%^qA~*j+cM z9qlLm(wbVeZgX~5o3&C|8;(`~-a75=$JbOH-D$P{JhCM``4M~hX#?9Qu6fn^$$&u} zM~D2<{`cVf)PKBuuf4wZNBZZ1(__Q?FW&UTi!1CYH#Zn(_B}n;^VZ)$pzFEIavN`1 zWnZ*lXZxBTwb}UlUwv`FT05|PL2$xFOYHZIw%ERzFD3H-e)yG(vg4i^Ha!3H{H5n) zSD(Mu-mJZ44ys*ZUorB$@KYDOqBn1UKGHGfEoqC#&?_I0s>TlR)?>uR26d;VaZ z^6Xi;*M7Xw-u~-l#)KPxZa+5R7PIBjZT73?y^(Q~e__9OP&J#au-$jw>h=M{PPMo1 z8__=dPrtD5IBAZWJaUoUy`sD0x-EY~ZtD{*OSf86hK@Eq>t1IMoijuq-)FC7jeIk^ zc6Zk97aO@{!pukQ-Gk3HS6}zMy>QOykrxj=Zr5*Jncdv2*}v~I(D>aO{p`Js&HC8I zkJ)X@9}IQe7qdIw{cYs^U3;x>zq@$o;*Ymmt@Gx#x8~N{`_ngV{`>uVt=RvDc64W+ zu>eCM_wGU0dFgxq)=k36d3-qsCzQTUs%co`UpAoWy9Xs^C-x^?VKk`BQkj1O) z!!JF!@ss@{?d^B3>iEX9dThv+?|0lV`Jg@KhNIzm0oz=V-W1yPY{2~N3nTR4yZ1#7 zy>o$f`hPDrKbp5>)4309Fwf|>p)+^xe`~&1bK2!IYCB+MisL2<%sL@qfk=hnIe8J@EMG#P#ditmTb%BKXcY>!%YQjoz^b4`_*&zAE=o!|Bu$|nX3{*pV%2)e($64TyRM2 znw8&41n&5EcH`z);))OZ$DaA}^@(>jth76q*olRKRk`td=EltfkL9{|+#VmawY6*K zpb?4dPo8J)dHut9_j`-1)V6=sOqw*RdiKD7)C3khW%lngDzW3H4O^aC)E3{>_X}Iw zE*O;v9IeX*=9-E2Sp&BIVAWNL*2Pa-{U%$9^sFPDOOLe1@1H)Q^OfsT@gJVE%v$R62(C;OliSd(lVwcn5uJ literal 68 zcmcCufPiYHoc!WqC97(sl# diff --git a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors index 7a2e0e70e8cf106ba93c3f7c14b1e333fe982190..26d9192401bf3e5b0c9586feff331389420b625c 100644 GIT binary patch literal 130 zcmWN?NfN>!5CFh?Ucm2{l8OGjD`*FNgDW!Tc0deW(-I|h@El!tfqUzZt7^{whn;|I-CBQp)A z9IYRO*a))?M={`J<};Wp_|J@Lfa5si51RLk+&`xe>Amk!pGPc9!?BzjTxC6>qgZ7b^X0RkB1jQudw zW?)ZYUwxXm=tf#=b{M$}m{eGTa*=(k@CzD(?(Endat*ro( zVT$(U=|)APwgv|?MF*=FraLNrFH5vBB9Gc0zs)7eSrD#&b#`KOdiw|{;0Oz)8b@vp z3q@47x|`D-9%RlGWsVA@HVC;hMY;3vGW`R6npYeIBV2XT5W=a2L+eaY>m-Hi&H;DJ zOi}C9@zmmBbEar>zUi7*8U!1rC~>NAYT*z#QxrH!w(e1fz?q`JspHATW2U0M`KCuS z&r6UuQz>tJt(U^7g~QxTrMVTeb&opK%~YyeRXnwLc$=y8w!w67_sO`ZI~cE0;3Obw zLohj0X>wkU?qr~K5vA5iQN*S|0ZDbLu8^v)sMJrgpn#;hfXhhamhej~rR4d5GzSF^ zSi*faVuqP7LSi}>js-=uoyF2o_yXZV#!YLT$Ch@YiRBAn@|F5O?7Ak~QYn3Xq~S~9FEB3e~Cb24ueAR6(){#3`4i^mK_uac(g zE)u9!L{zIF0JQmGgFA&ZEL6$GF*0LzMq zmgULk=K{@-D+o3sqGqYWsfEL_BBEnSvUQI;Br76HmO7qXJffq>329jinl2kv6+7#T$65UBbYD}C#0ja+AP)#DYg!LeLRy{504hsHdiT*`kQQP8A zhKWK}g9lH*JUVbvDiKiCNpAA!3StVU}5V+xGDa6bJpf&(6vlL!d&eyz(Alk4L zT2>QJEgn{8DXgrLu6y($W|l(CJOFY7n6ARh!mSN>6M_*geE9WjEFqj)IGoHDa<=YKhn7VZT2>QJEFQ^7 zSDlC@Izzpa1qeX$3v#$s0wOnrCwCQo=Eu>T3@|{cwz9XS5{lXs3@xfKG_k2_n<0-0 zh89&Anw&&!34#_?2%0ZQcTl(&7F9T!ibZV;iWXHUn%S5)A=-#4w5%qcT0E>Qs<5(3 zy6zw#W>JNhc>v@FFldFBmGd>PB8W!3e8bnoQ;Ua{MHN<7N!LC45VNR4%sc>U1Msq_ z!pqD?ya~ZZR3T+m;nc$6WKo5am9uq^JG3mS(6X9%YVokLsKUyM>AEW(QWjN6Syec> za7LDzvO7o?1Mt%vM;LDbpG11sM>4s@j=3tFtemZT+@WQ*Ld$C6sl~&}Y=xB- z({)!oq|8=GSyec-a5$N*aI$i??or3BGFzc#HSyHqVP&?$$|5Jtn;>XRA!t$%u^}+P zk^7k+M{_d30HxZ>C+mI_Cp2mDUEjiG?qCcBV zWxC3!hM!ji1}N`YzYY|&Dfk=n@wXCbFpJgI_#Vg2HRdC7N(i+P*c|h*IT=HDl|k#6 zkJbfXsLjCdnCN$FQt4#PZ*m*_7pRB_Lb%!oBgjo)>Z0i>3A*!u@G(*N>VZ%jf$KS< z>nRwTLXUVLgbg2zAU1(?5bB$^sr#HY0EQ!qowAM6$f=D$=^Rn%d>O)14%Azs2sa#2 z<$?gz2B2?_sBemV&0T?bAcBTnJVu&RfVMfJwn>}N-4W0>N7S|;0CfsbHb+!8MZWG7 zplpt)Y{3X>6OcAXlr~>Nx{LJXY_rXx70Pl+hcwaeRcM4FySb+M~AgE2i{9*y-Cq5O} zke2}jSkTBN5=d?g(+_Y!zs7VY1r3Y^XrKZ|Z4MR~3$Q?u)|i)p1X|F!pfa{<8nrb@ zU@SlaeOuEV9PaqV0z6Q~qqYYPj0I?5`SC;Ds}IyK7NCAL2x=2BzgU3zRS3F+f&9e+ zi@6&2ex7$EZ&O9t^P;4&W5^ zHKO>$qUQGC=|BYdUloGf1pE&KwE&h0+Vg+}hT8%xum(bI1QrNJ(99ccK^whtA#O<4IEmwehA+X9n}ZvI7PJn$ zH0}L^n}p%E07DcX>?|M65|G>=3=zDb6;f%iTBE2p(>Mye07(@4tLBm0gCqi1|7tmb z(;gs95ybjeOf8Y#B4059u2_BsQv2!zK?Fp-OllzHMqr5emU{Jxq{7f15Ht}O^^)*l z$j!hLK~XP*GDTD3cs?KjrdR_ZHv&@xqrS1j0SoLNG z-=*elPWLJU76%3`3n`G^AYUN?-Y6c<6WXCUKJJZGFL4Mspw&Ebdt4i<-sVs*PJ4jd z9IM{%KqQh|kIEWK8-(Wiwu02GNW$|&SkxGh$_{XyM|_KdhOR=ozIE)KaJ zZj4p$gAg&Dy$yc3CXeR%{^yFMC!bzEAGyR5%a7b*dp7nQIwx?9|Mf$PAveQ4vFhy* z;w(zfa*NIAt}XtO2!Q;}!LOYXh}sZ;Sp+ze;TIo`J~>Yja|%!rag;rPtMfZZ4&fxp zT97oQJt?k{RWFGUmxLNdo)-7X(is)HSrqU zp1#z|^k6dEzp9mgk|4itma{l;VtMEQ=?(HVQZ!8>4=9cPxOJ9Z5~0B(x5urs^p1#n za1%REcXZr2ORtO|vdC?5+bq37qJEU#!1xWy(%U3xC~{L=F-vcjsF1R$ft7esba}HC zy-0UL+$u}2m#B+CZh=c==^Yb60KI8&X)L{Eq8SU1H;G$T1>qIa@ZxUP*OK+a24?%8%YvRP)CpOHQHEY)!pD9e=`aAbq-RBJd_nJNK zBW$wswHMxzc>2mqM()a89siFjdkRm@j>Yf3|3f2&b6Gq)@BTvjEGs@XGFpgVF+ZNU z?!LnJE}Is=_KtZYWB1$^H|~0_aM9su@!dPN7VcPZ*q!{T?S7y?fKGm9YzkXy#;gKsJa{qMR?!u1U_qzZ4{f7%AVH?-?l`TH(e=6=Whm*GoJ+w1;c?9C^4&K-5T{&Yn?apU>!hJ%;nE}J>( z#t-!kEqH4x+qvaX?#b7WyYt324K4pgh|QcaWq8diR(yWp@}Z8`LaeFj(;fVU^MC!u z(d}0p3bEtIH+Jj|wXh2wJ~}kxyCHTAKVNue3wv~DXz1Hh>T&n(2iFh1 z@@j}(JZ<{$6D`sBy2lO--ScdS9l_t z9Ce4LjIhIp=DF9dzk_Y*7X?E$zKZ&eZ9Z!4eJaEz;pcF33){Ex7uLaj zQ`!G7{(Sx`yN|oup4wy`{dtIu9(l`I+Yycb&ABx{b9~%AXU032YmK^VKl^NM`^@?7 zlEGPaTl1)U)158bTmBGYlabecMhjbo1{VkDiO*)l5)P0*hX?j!q z+s|(rIv0OFym8FAeZLj&eBK`ZpPetb^Zso~{_gfsch}>o!k4d`@BZYXp~}@p(`U7?Yxn;=^24{MvUd%ik>7i6Q~aN|j1N7zJH)PE zcrZV2+=@T`;g0;+4}R?SPWmHTGHukgZ*XGw-7?$#V&|Xjw)c&?>!02+bjwdF`QN$k zXrv2s;`P5ixrKG+9*lhW*i^P4^P~LgQ<~!UKYCy&_Ja`n(Bcma|L5#z{IxfK8(Z8t z=Kgfwxoq=kTc)bg;E`Gi3fOQ!A<0b5m z)Bb%nTXgw?p(*{-S*&vkyM1a?{6Xxm-h(0bTK-7xk-s7D_Gb#`^m#5=fY`)WpmF?T>0v661i=& zvM6zB;>GxxH%Ah~WBJ4>f4(~L(%Vgm+MoXcP=QSt diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index e79a94ff..ccdd204c 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -19,6 +19,7 @@ from pathlib import Path import torch from safetensors.torch import save_file +from lerobot import available_policies_per_env from lerobot.common.datasets.factory import make_dataset from lerobot.common.policies.factory import make_policy from lerobot.common.utils.utils import init_hydra_config, set_global_seed @@ -26,15 +27,14 @@ from lerobot.scripts.train import make_optimizer_and_scheduler from tests.utils import DEFAULT_CONFIG_PATH -def get_policy_stats(env_name, policy_name, extra_overrides=None): +def get_policy_stats(env_name, policy_name): cfg = init_hydra_config( DEFAULT_CONFIG_PATH, overrides=[ f"env={env_name}", f"policy={policy_name}", "device=cpu", - ] - + extra_overrides, + ], ) set_global_seed(1337) dataset = make_dataset(cfg) @@ -88,14 +88,14 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None): return output_dict, grad_stats, param_stats, actions -def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides): +def save_policy_to_safetensors(output_dir, env_name, policy_name): env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}" if env_policy_dir.exists(): shutil.rmtree(env_policy_dir) env_policy_dir.mkdir(parents=True, exist_ok=True) - output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides) + output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name) save_file(output_dict, env_policy_dir / "output_dict.safetensors") save_file(grad_stats, env_policy_dir / "grad_stats.safetensors") save_file(param_stats, env_policy_dir / "param_stats.safetensors") @@ -103,8 +103,6 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": - # Instructions: include the policies that you want to save artifacts for here. Please make sure to revert - # your changes when you are done. - env_policies = [] - for env, policy, extra_overrides in env_policies: - save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) + for env, policies in available_policies_per_env.items(): + for policy in policies: + save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy)