From 4d7d41cdee1e2406746ff38739fda2c58586e811 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 16 May 2024 15:43:25 +0100 Subject: [PATCH] Fix act action queue (#185) --- lerobot/common/policies/act/modeling_act.py | 6 +++--- .../aloha_act/actions.safetensors | Bin 5104 -> 5104 bytes .../aloha_act/grad_stats.safetensors | Bin 31688 -> 31688 bytes 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 4a8df1ce..3aab03cf 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -98,13 +98,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self._action_queue) == 0: - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - actions = self.model(batch)[0][: self.config.n_action_steps] + actions = self.model(batch)[0][:, : self.config.n_action_steps] # TODO(rcadene): make _forward return output dictionary? actions = self.unnormalize_outputs({"action": actions})["action"] + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors index 7e7ad8e1df015d0ff52d689b317b8d77b3f380fa..3c9447d7fa0b68143216f21c4d9cf5c075253fe4 100644 GIT binary patch literal 5104 zcmb8y`CpCc8wc=ECJFH+l}a6H;gptB=P1v0x19DxyO6P*bCPsUY15ExgprtL#?n}# zO<7tjW9_<7N>MWzW*AG85mBVbl-BPt|AFVX=Xt%}ulw~}_kDd{_Z2--<-cb=&zk4B zo+nNECSJyKO=;fc^7K`{cf+c*Oba~S9rEyVFiL#_bC{;ci4RDXtiTL;m9 zC3E;VeSuRc&AN7 z{j+4uYd#-ORt><@(V>`-UkWFTJZRT-1sQLXv-8$wC>z@Vm3PKN-14d5{K*_&ymF!H zQYqVXY!-s*Ak@4LMxoVND3E(nvTQXeCdrw8=UANftOle%jD(nNqv1%)Y&49qBfBwC zjCVi>otO&C_6foHukzrxSRcB$HI?SQNo3MLXQB^R3(cYrT>4=(*j&fQ#@BNx^0thP z+%**)T}H65zCdhsI0^?+{U~F7JS8+GG0m0snAuzdnvOl(?_ZAxD=i_$@AjeGUK#UA z(nXaED$Ft`9Oqstf|_Vw>bkmy94{v@@lrA3jVcg7?cxM~O@v*KgsAQ5OZ6|J*!?se zymnrN4gL{|=Pnii&(D`ei{t6xp9w5Y>WeN{Zo=N5KXMN<^q@FlHjX;$PCh~jyI-e^ z5y=DKb0`cWZyg2ASMJmrxQep>Q81J9({Z8wWyo$G;=Fh9pf%GTvo?Cs%Lh{S+{h4h zXQ;5}b3;+II3KK%y(m|rpgl2h%+JyhPv+FX;UiyxMUxTO|89wo{$k{nEM>nn&cc9` zeK4&l1jFM`!wGdynrIh8o1^2Hc(w`l7vF)LRyEl0SQ|#Xw89Q!Px`e-!j=`7qyAU@ zaQok29KE*;*135R$D&AkY#f_+Wi0-9`zBIRo=v6}VzE|AaUv#0})EEPwIns}LQg(a2A*LSeg{|TR_`azSB7@!N z>u2%gTo%hRt9f|EunHyztAk-D;6B56>BR{O|zWi3`W^1w~M&??Klr zlW1>mEaShMi^~lg!APbBfrXPHQ^Lc4Ivwcbex;M!;M#&7Sn3jvy_e2_eZL!hyC|N* zz2$6PJ`cC5-Gkao6X52(F(6hoLGxdL4$YP^hjG@Ja;^td48t(n_ZKMUxs!iIEUi$J zv!tCGc*?W^DtbqP^hF2PwAB=|c-9oVB8ug78(^PhKl~9Ch9Sv?U@_B+R@$wl(%2+6 zu-gL1uKpFQp_^N`q=T!QXofz8LYiw4#f;V&;D>R8kT4C9U8+{@L&|it zJA4h2QZzsuH5CFnXJfj*6Aj;xu@ZkXoTjP5B9;f^qo6|YKjcB?W@{*Ohnx)*i!gKe z4h)^01i$%Bh0e4Y_|Kp71QrSu@sXlXE*mMYc{nJ zE`68;ho)(O@YYN;;W<*>BPk1fXoD-0`{B;tfmlEBEC}zo)7ERzq|1{tzZ5k*`9~f2 zG^oNx|1NG{y*Ww)EGc$R6w_8W#_x@X!9*(veeUGLDCR{KO%(Aa&nL)vx@JmF!fR$TtD2; zrFDPeI`ssYG{v2+#YeGa;l|j$Wf)R7g`(807`9Y;Q~w+VwbduEdPhLvpEV%J>*4ek zj)arP1;`KapigNl*|H->xGZuQM0vre`mO*jHh7b3WGwlOOJpq<-SOL$Tkr(>IC0)s zSbTRD^5Z!=dvOJGE1iaQ)dP@mBN*qcJqDXg80{%tP3yQgmYVUZqPZO5;2d=jY|??H zn)Y~3&xNu)q|Cw42($H7Sg>OdcHPg1jLoj}r&b(U*2J-j{s|cIst&};j8I&h3X==@ z_}BMB3Og@hu2(FvKD-Y!W(J{1R07HKT_|%Pmd@UYW8WR^udvbo6(ar~56!I_a7ku` zfch-oF|FqrO2Y8zIOp(Kt)xYku9x7>n=J0l^Z=mR&P?A7+W zHuN|~%DNBeVr2FpcrFXVShYM@YUWN=_fkl2dkWK8{l4O+pbk`ae&k{rJ2|MWw=jJ}HNIVku);Xal(Tk2} zNLk!NU5r~j1h))Bv8?bo>`?QhNM+wTijtYwKo>i{y9v$%-P|z;b+CNzgj!dR|0x(ZZ_9^gogNg_C8t1#WcJ3x7;jvwhHhamR~s}DzWaG5dPLjPR+)r3 zm1v{$oqpKg7>cdaioy1vn9AE$QGutNWuD-nX>SE+r=F-@HDSO|j zhu*V>!E9>?)|@^G< zz^y14({^5jnL{3w)+M8M<$9-EG;nP|4J3Jtgpt#pbEi%CI6lCFcC3)GmP=Ex)_D-# z90|snWyP>W%ZGk_yoSb2PiFU{XJU+1Ef^_2a9_`R$+{QUK-1rs@(xB*WKRO~d*p(*=2SuLwr(!iO}br3wdseVgdIjWB=$OsQfzw z-`AgjmT+Ivez%ewOcR-(t2dhMx(P`edbr~oG~j}!Io_&tra!eM%%pHChWYk^@sm(o z@bnl!x|o6wt)_BW95Z=lhG`L3LE1LRbsm}oFI{ahOC~1uT~gK-sgKSxhu}Mt5S&q- z4-pDa%0mV1WAV%+nc-HKS}?c#3MLgyg4`cw;df7+$jesB^y+7!x6LQ;S{{NGMW^6} zwFgbEm6Kq5JUhJH9L>hpL&k+spmtFmR;}S(rg6%9&$gBscvW`r4*0^=Nh$exQU3g`HdmB~QmBIjAIwu#T*FCA!HJ0r4*Rsr} z|KXVhH=yM}AJ@D<9rP|aV#-4?DSnl*H~&t@5n3utF-bWG+%E=*^Q5+>STcQ?#1115 z{rV6x$9?3uNOefpb3&_I9u!+8WAROfSY@cfCVK|sl~KoG9(Ynh@M^mKDv9wMEil~Z z8g%n|xs@v>L2HB&o*#CgampMFbuz|niv~fwFbt=iD}Y8nF$r!a(&6uwbJc4b>={`B zCYw}Y->0##T{H{3KRMA}K@=O%HAS5z%KJ-K2o{zXfxU(&9rjHj*8w@J(DBAoWwp?) zqY0DM#zO8hV8v_}86K;!34Ey1oECgSSI0s`D%Kv_#px?9POe)dF4@;_Hp6LLm zS^I#SanB47Oq@-Q%c7XgkP*H;JqUNxL-0q;gO5+V$;xv*J$RbL{FHrKsk|3Rx4-5p zrat9zw9T<;qb>cU8^wO8fM zTtlNR+Prn6#9b>H&(Rq7st-fK^e~(!EQHc5AL?*hLznj@va~-U@UKlbfzLX)2RS35 zyI>}|EpVjAtrFJ8>7(?_0Gu=o#t5r?z_H3aW^&R>S1|qVaky*Z74UmC3LLJg!zR zGLEiCN}0E>C0dpCLv!)FDQd4+^Slz?DEhd**Y%w01wM*2 zt;sY-#w}0y6ZvhLCG}LHHD>J8jHMuYVh-GV-xYy7by<|xd(>_xpFVrh|~GIz@iFu%GAybiqM22MZY9;pe@t;Ulw z_R5&sCO!OP=`cL*3qgyyMNl}#i`JOPsrUONc3#;J?$I|uaIu}cul|~w5Glmzi@ZrA vBZ{rh(8bkZ!(cEz7|*`WhvcW8q<3{SHIyk`wgdjHS_PY}-f@d24RZen2?GM@ literal 5104 zcmb7{`CpCc8^=#k6d^5WG3{DT+D>&^p6k}4PDIXGT~etenfW$oNElI+$kM)ikNFQgzdg_E_5S5^-`D&3T(PRW|DFwIj%J=4%p{7h zlcZ*zW&tkFW_+`h*p;8&mI(Pyf@OR&$x6k_<%uydDN;p>@~b0XAP_Fw#5Z<vEyL zg&uC7Sm)1Wl^oy6MIiVDyL>ME3FJGudpLih1)ohTLB6xI!0{99`nmKcknij!{4X-M z&!&|i-^I!0ztG%2m;MCuU0j9#r6&ArS_$$6jxPVD=JC1oC-BovQKq)Z#63J7uXsy5GJxp@I&J{)!?{s0-R}@1}BCr@!DxulB-DB zxF`n{W%h#S{0J-@J_F4wW)phI>CTEo)>37QN5r-8V5S!AIjaStM>BA3D@R#}B&^8P z9@`UoL4x6^@k-p~o?pyz68XT9d|de7!JFb^|8ayyw=uRfn%E-7woAkOtmMSy7-d`ZW!L z!=#0HptAr}bLP@zT^UubRIrjM5g1oq4xzR} zuF%beqO7^pwk4ig(iM#Fu^87LssK^O2d)UUA%3eF+Gcu@`rT;e*lvP~kU==3DaP_I zb76m-9|=w;k++td83`@XTS4IVZVbHppbZvJ_~`nEkuONuRW&QTXvt$Y8p3gVd;$1Y z1dzq}R4O~cc>gaH`78iv2QmcaT~vq^eg zPC7|)<~^u^fi^YpL{k;cZEffF2istCo)g)BCuROiO>w!}2;6=WhQ1lOkku%n_>zq@ zusM|-ZZ*b9EIzuzsv7p1c@D=Obd6wa^5Y>Wsj>tr6JXnFsr;gDB&NHRP$I zV29raU_eV1H0|x?ROe}e^EN9qZFZw!el)Y5Ivu0Z`=LQijFwutu%y<9I^APvr%cYq zj(byT=yVb4#bbbHG#Lcr9I+|Ki*~J*FxPozc=7HK47Y~k&%fk@pxB>GQsOD_p7I=g zC!>o`1rTB0Eq*3eC&a(o5OM0*>iAsyDvqWt|Hx)a&|el zv-Ee5D!BJR zma5v6fv86nV^kjfO3RB$My z3@UT_xDSUua&wJ@_na)p=TBLp>uXQCQxU~1zA(lQYx>~Iui?m5<^iwUmyY|yQKM=iGb%DfpGgFn zd(}bkzbVl8hXWq8VI=jIFr)LPxEO{&YkU|k)-Hssf!X9|E~oVUiA)wRz=L(Q@N@G7 zxKgbLp>8v9&esC!d>PFe7dv7_aSzNBh2fvk#ZWu!M`vfqsJ14N)p>q_Uh+GTwr~Q> zPS$`u|5)SWPH#$E9?kNO&BP5IJ&>|76h|dqgge2rY1-yh^e#1#`RA))AJ#y|X%)zm zzTs#bA2re)X>_=RL5KmGWDWp-ZYbWKQ2+vG<@~uOQ{XRYOnh+~zI}KT2A_9vi_<%} zukY~j&nIrQW0{1tXBePDHUz7i=Ha}tIbgt_Lm8!U)VyOYYf`#=cVG?hFZFQKTZg!k zUGCU;JAig%Nm%3&L%fhT2n$z5pkMtd=u`T7+w*uTcTQ#BuLUe~y8*K1E-vjKRp|Z8 z9nDQdRChwkTDBVCiQ++cksgLm#U~*AUJwPHiKAb?RNj|#e-ufpV5!eXZu6z_aHGu> zLwr05XG$4=w*~TAhrs@4rFZ|$hKLa#DqX6eX}1$t(>7zY_^}iUs?;DbbPSm8wZrCc zpo@oMSpF9_*!_aXoF0Z_ta%>X)0j=FS!-#kwVXB8x}r_=P1xwI1=Z`+z(!z>+ACZs z*eiy02Dszm#e<-l5su$%J_}7UKdQN#Of6k+V8tG%16*k_3kl7LLk zNLlnWSG=YCKZq`eVOV$(EPCos@1G>n^Ib}=YYdhit%7rPL)@%A_qlazt+7ewME5Vn zu*)CJu~U=Bw)uqP=2OREnQb8b(72v1pGsvVcc!ED)f(`B)yWxo-sjGpvcbeYcRKic z44Zt!9KV~a!fsWCV8yAUV3!(1*Z*BZe;!O@l}BdcP;nKA_rBvYeA>9PRRRPZ5j}qt z!~S?^fg9KGSaawC^wli@FHaFw-&sSAjtZt2Np-o2x>}>4ftQ=H;3e8Z|?Lm!qqgeSj z(~u?i!oa%)xIZHw@-=;F{-SuAaW;XS`c@zP`>#RS5e@jJQ4f+f@p0pHAdjUI#u=L6 z_Td3AeAL`@cKMUsG?ttmCoswJQE0fL76cxnA#(dWZpUnE6h3gE{xk{eoN0(NNBSYw zU5raQ@*s0TAcY=Ep|W-b^UpTKr>Ad0$v`is`>~6AaLyXdF1XPmg@n~=8{+xuL3naK z42_4fA!Bn8Z3$mZNA9FD-7Z^v5LgR}mqT3T3LYHsa={Nr11P^i!u(Z?@Ib);e6?&5 zjvhG);xZ9A?~kWaixg&I%Hj5hRUk6$;o2sThL$!Ld@@c%_e!N~BN?FjkpW0QyZ{d! zDS+myB0Bvvj=W_lOuWk%mk(FNy3BrVS)vw%eYC_vPj7F$qO zwlCKYA8A*CHID~YDTXk@S)=yP-Zc9AD0Xwlbkx!ufr*I`SQ303_Pv}%yKc#-WxkwM z4v)rB76ifi7|t`g8kq ztXy3SXH;|{wL>5NxM+iCg54NNzxbXZS{B$Q2SHCQRRC8rdJ{L#12jpz(qQ|BE z0o7nsrUw7B>f(NlQ2H&?nI`)y=g89-_m_=8t!5amRvd@j`AW~br_n{1G-h#L4Ud%8 zz~EnfT$!?)TY5U+vLl^fxsP#()T5RM0mbKsm^An8=d=-@WxeTkllV_#LmVC4vx z@{Kn9)NO{7CwkL{vM5$xHy!VK4?>}`?{2u84O`lMD1KKW1wWFpmIf7Ed>CQU2~Egr z(+9qxJw}h2MYBW_W_`mPFI0@czYb!wzMc!MfBRFax`L8i5?E=wA?{pP3-{_LgV_mP zhzPL2^;g~KSVc6eb9cs`!hYcCiE;9lvoOEGm*#FuBHb9J2S<&;XFuNtg~?>lzdaV_ zY_P(gUjeQ5ie_fn4tVBXKQu3kzU3 zfPE0XA}Q+)H^Slz!=RC(?5&S-A#r()FJ7?3*j@6`*kia@d+2b`$gm0By=~S#f#C6zVRlN_% zGNp`H`~{{i?S~&`EX3`%3n3y#L{FTPNN19qDW=ax)iXE2`IrW5I;jhSi{@zG>`o@i zy>aSGd)zzeBRD23K<|AeQ1D{_<@w5JaE@}{47bJhtG8hK6%B|IjDbd5D;y_eq}-pR_6sSyUCLYp zQ*gI>KU669#v)Z^|J0jLSNF!!gh{#=lj<<)iIfRfQ(}A0S}~5&!@I diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors index 5188d8f428cdddf509da7f91f29950d9c80ad309..7dfbc3b35cc394e87cd77018d4f427928b9aa637 100644 GIT binary patch delta 110 zcmV-!0FnR5_W{WF0k9BvQS|F*yWiJSJZkQmIefA!Jk{F(xgs{;JMQ7>IWSb)JK|an zIg=~7I~5V|xpws>JPm4_xc^&QI}Rt-xeTT(Je8;;IfRu%JHo>fxUd`sI~9{*cNYn| QUje!FRLwdIWJV(JKkCj zIg=~7I}{P{xpws>JPvA`xc^&QI}az;xeTT(Je8;;IfRu%JHo>fxUd`sI}(#&cNYn` QUje!FRLwd=vw?R<1in}@TmS$7