From e89521dfa02f42701c3732e67154300257c5c443 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 9 May 2024 13:42:12 +0100 Subject: [PATCH] Enable tests for TD-MPC (#160) --- Makefile | 7 +++---- lerobot/configs/policy/tdmpc.yaml | 2 +- .../xarm_tdmpc/actions.safetensors | Bin 0 -> 928 bytes .../xarm_tdmpc/grad_stats.safetensors | Bin 0 -> 16904 bytes .../xarm_tdmpc/output_dict.safetensors | Bin 0 -> 240 bytes .../xarm_tdmpc/param_stats.safetensors | Bin 0 -> 36312 bytes tests/test_policies.py | 2 +- 7 files changed, 5 insertions(+), 6 deletions(-) create mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors diff --git a/Makefile b/Makefile index 07aa4e97..a0163f94 100644 --- a/Makefile +++ b/Makefile @@ -22,9 +22,8 @@ test-end-to-end: ${MAKE} test-act-ete-eval ${MAKE} test-diffusion-ete-train ${MAKE} test-diffusion-ete-eval - # TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc - # ${MAKE} test-tdmpc-ete-train - # ${MAKE} test-tdmpc-ete-eval + ${MAKE} test-tdmpc-ete-train + ${MAKE} test-tdmpc-ete-eval ${MAKE} test-default-ete-eval test-act-ete-train: @@ -80,7 +79,7 @@ test-tdmpc-ete-train: policy=tdmpc \ env=xarm \ env.task=XarmLift-v0 \ - dataset_repo_id=lerobot/xarm_lift_medium_replay \ + dataset_repo_id=lerobot/xarm_lift_medium \ wandb.enable=False \ training.offline_steps=2 \ training.online_steps=2 \ diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index eb89033b..7e736850 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,7 +1,7 @@ # @package _global_ seed: 1 -dataset_repo_id: lerobot/xarm_lift_medium_replay +dataset_repo_id: lerobot/xarm_lift_medium training: offline_steps: 25000 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..0339ca0e30448daa5b30bea25c2dba2687927db5 GIT binary patch literal 928 zcmbb3dB{yRuB^{;Wj6@JG+Q?MLBvwZ$C9xziK0ht3IJKl0 zC~BZ%Vqg$ktD|H{s!pIj9Se}&S_S95F?QNQfA>9a5#Cq#NZVe}uG8-ExhMNxTw&Q4 zd_~iKUT}!rlip4HmS$+~YkqEQ|4FCa?xORaeN6GT`~K9J+WYu$+NGKv-xqthVV~*| z9edl@t+qR3PVd{7Ww%eaP|;qz>8$PZ(o6evZy(uPSs-sO{AjM-P46rFKA2nV%kk2; zKWDCJ_d@RHK8~LZ`zj7<*=IbHv{U@`b)Wq%jeU=k6zym0|FShbdw!pA3ERE{lN9Yg zbnw`nwLQ1*F{j|Z8%3h_YdG%PN(NorCv3*JPil>Xy|4dDTb(H{_tk|x+UpsvX8-#c zgPqLcqx&}g5#P5%P|AMMG6_5LqEGwk+W+jWd?;nVyQ$Icf8OVP3LeMzUfd&R-`0J` zHpk)dzHE;xd!J8Hv7efL)Hb;C>AuQt*S&A^l%O*7Gy9ZF^X>i~ md$P~s%(1<0qDJg(fAl#@c{r>@UM3O literal 0 HcmV?d00001 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..5520c643c53ca0f1b417465c3d0336aa96a990e0 GIT binary patch literal 16904 zcmb`OZHQgf6^17UYY8MGRIF018UGmFj`!@d&p!7vNDHA=q6zAl(!}xVT*s;TFmp#j zB0(Z;u_DcnVnu5)hG;~YNo+C6#6hF@fmKs~SRzVQgs8N^?;;3-@0r}}*=wJ3=VmQB zfoL+h&sux!Is19fxigD=U_0pY;3Q93t%N8`3uR&j3pv20?o zR&j3n9Tu~#*ea77ek`4s!54mR{*K0M%@}2J<4Gxg6HBbJ zx%oR9vo%{%HW+^-J;JO}HkiI+aoZ}ZG#Gw{5?+<5R%tMQkH%~*wW=J9KN=98W95=6 z=ceCbG21HR+|Lc)ojhC{mCLGFyYY#mWo0hdpvG?&~X&{SWAyLkU$GA2uEYLhHVIqoWZ1% zgO#ID%N~#09<&^VUe*J`V-6(Hf}_dNxUE6VQE282CY_wcPK0{)c-;1&=P2~E7N2wQ za})~N(YUQa&{1gU3?`i%EFFcK_IT{}uo0oBWAQl$Pe-As9gW)>M6E+nXE5R9zydG0 zAUqzkJ)nRKPNrCT&Vi(KXlh5}wgyq_P}CVrI=Sp-IdrwhK zmgfLcIuv!*nsjn-wGLhF@wk(Ns&%O9SbENZ1X}KbkmWgm0LyKI_?!bv>Cn~LQad}r zog7rHLsc8;IR}!`p{X5>+ZsfzLs4fi>DL@wtwUFPJa&87h)~r=e9pntIyAMTaa)6^ zbtvi#CY>BytwUFPJZ^hXwGLHn#OEAtsC8&+N8`2zQR`6D8B97kxH=A9?eUoH0R>!e zG8ySP2NGz((d20C*6>I*6m3^CP~4`V=$I%v zBS{AaMaM+Z9*o;q_Dq}Tc~E+=0i$A~W{2W71ue%!%Na>JEoeC=TJ~V<#xQxJ5v<oLHaEA%#4ZON75odSR|W18 z5Ecvx3v(2ZP98V#hUACdHMq+_oG>I#bS)7HzsU@Em5_MRvjlSyprgV`Z8jDm+zb>z z5kG>=wTw;?_CoOeKO~ZLt)XQn4(O=h-$wPSz+D0YOF{z690jCXfee$74AZ*?cNvH@ z35hgaO9=A_=qe!zr&k5;5)gP25_sk)Al(UMpoCs=!?W0#`x;*Bk|; zJAr2}3CUo+YjBr=NS2UDcA`XkygB~=$WNfFq%fg%tH50X0$NfS&fy zYp|Dr%`S{+xe`)iBcUaQ39VZN?h=sCY+*htC?MPk@KK{Ep>?mpTn6wcQFMyuN(i6D zfUay|I_p+}y99(YTNus?3P^VX`OFsPv+gyx%RoG{B%a+iDt#sk4p{!K&a&8TVf#qn znn6kDh;(I1y7Fn<))1pCiBW(FpS^+umVc?uvbb#_7+DgGW>C_J;jxh=DafaBTSMHj zB<=tvy{?gKEXg&qEN)u}FqQZPMBqJ+d?3)BoNG?q!WYvQ(}Lg#%v8B z!2GjkthK2*HY`2;!2-`e{V9*#9)7}y2;AX>(*q4e|3V0f+!mqdl<2tw6xskR@CtTB z_5(ja0avgg;%_$qO{K)jZo!>B;Z6`@PKlU1K%pJLQYmpW<#F4Co>QV{hm(E{!qO?R zbVy{kh%F(a?f`|h01Lc=i=@Njwg+9OMAxo8=>%czlvq0?a$AJBQzGsTP-qXZR7%`U zdEEA(@094<;iMCU#ZzMOkjQNjB2P1sZ$1C`$MK(fKl$Ic{&elL*}vYheE#&0SA5hs z`BxtQmnU9-@2xwgKiK`T`1M;}dw6>1qo0Uh|IHUWr?9x*>gY`U(UDtg z4?gg~lPmB1x;lCOgZr;rc6RlLe_m>mz6r?W|{ z`rXlkzgYaUI{vqG{fZ|WuWWp1a^9O4HBNc%)yj8Qx9btPo~q;Xdw=<>(Z*p({zWm#{^-sqRA2{P53D)_+z`uX6DqHs@hx!)Q zfAaXU>KiYfQJ;7JHPstGKCAwtRY#^j+?_1F+rIhCd-wg5z5n2nkH#DNx$G}qe|pKY JFFtm7>GxZf1gQW3 literal 0 HcmV?d00001 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..2321f31c7f291eccead6ce8b742f141faae46cac GIT binary patch literal 240 zcmZo*fPiYH#FVncypqK9R3)owrIeD&0w6=l&DcmuN2xd?5yXo&GL6+yN=Yn9jL%O? zD^4vb28tQz7#bMF*6Jwb;yinwbf6OvJ)?7wd35N sV<+Hp&DLV_R$Di@AGSWSb8QbS_ps~RywG-gx}e>IGjnWPj@3H=05Eh_kpKVy literal 0 HcmV?d00001 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..5e8a694703af7e2abee62891a64aefa55cd3b815 GIT binary patch literal 36312 zcmb`QYm8l06~_lCj}XKM4-JaZ@$rS@IOp7R&pl`AWKckDQWB6@NdzaiblOhZPTQH= zVj)eNC}>C;2^gZb7owsPe8d;bbRY#1N`j=Q5mHbLNK8Zs4@Hy@rTAadbN5|)pS|zB zYeD%iy`8zgwf0))?En9qJ7;m@H-?t}TsypJ^O~uRBa@fTjf~>I)zi(n$xTzunc=Z( zhu5@rY?&G!8~$u98IFeM);HV5*ECzr$<1rm&P}!E@GqlLdgac@N1AJ`k$}=D zuY#_gnqIfQB~B`j;v@(Y<`pMVoVl=ASFAXTlFozig0ISsBY(CPWfdyWr35F*u8&FZ~>Tf)v+|ED(tckalf+;Hp7I@AH!Fa;-fCA3hnQGd74x}h?mDU|EZC~>u#mRfc~S~*B{;!1U@-RD4x5?84WO_~~nI&p=XVA^92TAjFB&3NMU z&HZ*B!5bDGgYJzDi2dz$At!6xFdXVbGmFiTz&*6ePah2N8q^Uuu z6IZASrmY;by5?#%;|bFP3OHwHO11kONT4}elc9-I!!1=;s0pU599ZBv7uAd>O%GCC zv!r^Jc9R1JDBc8vk|qU-u2~XIByGVU(KSn=8BCfO-P5+@d6jl!1173j(ri%Dq#(;R zOO}bGtrldtX2~*xi4#NRElFOb-OPXiiXZg`B~1!aT(hK@NZM*aiffh>Gnh0n$Z*Y) z;Z^Dl3|GrFOM(qbniS->X2~y+wAF(A)-3sDFll0t-kK%7Yj(^uH%(tPr{2DxzBNmI zU7|Qebioo}2Go`iI$X1KxXP5K2`Nr3DK<#uM+~5W$c4Zq3R46OKrYA(sH`M-QEKUO zl_^aV(wth-Y>>)(a-c=2rOqxUbLfcbI+Zj+> zNvM2kseF|wO%qz5T3TWB!T4U)x18OS?2asA0pvsh{2_KMJK45C* zD*QUr?8x+{=DMj#yqetD+%Yw))^gAhfuBhIB1o42vyfV5p+^91<#7R@TK=Jb4Cyjp zCsNB!1VhBiUorzes$2G=ZwTQcKu3gunvF#$n}Glb!YxRzWt2h03&HRISe7IhBhl?R zpd$jmjOrIbx&#=Ox@A~;1kknu9;R-2nEo-O%Ydb+Tb3pmLg`09kLs4g=@&t|1Q?#W zWq5i7&~^eIsBU?n{xPJ>fF-J1H#dSIlzs+uL>TDIE+dGSfL2?EsmBtuoq&g`TOO)^ z4Cyjpsp^)c3Wg|@yAvE%-Evs{B1o42!&SEoSC0VNPT<~4-SS}lV@Q_)OIEimStrC? ztJ!)x{1fO=-Rsao5u{6iL92TWnkRs^6YywtuSW~V5HADG_FA+`2(4nnq1C+(Efhhz z1o*Se>(4v^l$`({F??81!I(1#%Kwptr%__WtI!#G-+zEJDFv70Mnk=@HLs`Yq~6HS}-t~Wnh+2 z+KR!gWR_duG;wMuzhz4RQ(n7(0~WuE?6Rb3!DwWb(O5!hD+X7QS+0W9q^ZF^WR`sZ zOnHp~PLx?bq05q{1w)WohF}S$tr*n5VW~f-2~z_IFuwQ9PF>xcU8CLo!2*x({x(mX z9-i>A1l-`1)dLMgej>O;X^N2Nh9%D>P^kuBf#>On?tvemfb%p&{p$vxMGZ@p&4SCG zNGk|oZdk%x0+s3jR@AVx+2%>pgFH7Zc{VuhF$k4zSSsxj#VMjCmQa^KrCNXmp64Vf zdD8SC*9}XqO?ui2LbV%~YP&>fiV*IGCEO)YsUBcO4NJRio-{qkcf*ozgVRZ|!k+00HgpUS6m;QM85$yUWxT6e8bXA|Jp?lZDvlOzdr{ zQeprEv@0=Qs4!9RzzlUtEjB<#IWqums%5jJ$-)HW%mmm}rR)G0<;)0pq0&TQ26AQw zT&D6^1rpd$v&^s7iIYXIm@%-ERX&gkAfR0-@yraGBbw!aU^6JOr@PWMLq3W*}^;c3*{w$eD@o zLdA)qQ_M)XOl?(RC~{^f0$6FXFcmp76*g77ufka50b>z&?jT7Mg|Em1zJi$As^XqO z9`Fe(DZ z!a${o!dDalUqMWz2LKXSzgjV@I9c>6;4I7t0r3F7aJ&CD~9rwiK=j|O~)!HScG?*KK_GxZHv zViGviHUpL?OkE4u4>R*XUR~G^U_(tSw{vl_@E;ho7dp|@Hw=hL;85EPSe!6i2)LoH zsrOyn;D|>9K14fifq+B}^~`LK@`HiQ2&CA|0On>yKY)Rd#PxK|j872f3LG WzYn zK>!+%J`cjiDaaEg55i|enxZnQI92i=d{!ibD(}V;j+ovq>3*4lXoU1O$?Sykijwc( zvm>p=m)9ZUijwExGb9;QYigIqNb)(7Ri-#i@)~@Wq=i&ogAtIZ-ae`KRPX@x`3y6v zLEd`u7<|4YTfWl8f(E3|U$D^)@q5|!+7cN?S z#FE3{lPG;js}Ec<8hkKi54bpS@)~?jWtFYoYRPW!k(EKTJZb8q&vBSOzJ-mKj0YcP z=}TLCY?Jrkvn{J^?KVtPE`jz4o2iM?=hDDSkUje*BWsUsntvJd0hqpuD60)Du{{>j zB(D5jIHCgWyEWpXrN=g)MDEqcRC^nexCHOjN&}ZCPM!4KC^02b{@xr>3Eiuax<@(* z3bOBJ2}@diY?F=QgEp(jcGhX7xY}f9_}tAZTRF{ws6hKHjj2hA7AH-{hEL=4C9SSD zSsXr`vj<$9I2j#2tJAPMPi($j6E7$Vh4P%D{X)qXMSb?Up&pUc-64a-2Y?PzLG8Xz zeu>Wq?H@}!6@$W1EYEfwEMg~?xL@R__*~I}mZfeO-PPf9MU zqTv6c;Jf`3Y*-mex+ofEib#?}Bk9P2_)?5!og$cI+c3f83&WwoXQB4*meyy{lvITB zD1We9y2lYH2qO6$U5`lGlZ!^IBADdRQ2R23p;(e5L+jHJ1)+QvFjzySZ$l7@BsngO z)smlEWD@y|*Z#-52_`E5xr*u#h3MDx(%24t}(Q)B(l)>Yo(C>s~GWm?z zK`N+p8VG|}zT0Y@g;gHM*WSBA6bjX2%Ls)McZ(b-pI}=lpeK@gC!||)e8#OHn0R62 zLHQKi{;{+!15psl_t{|Zz|gu3GN62d?jVa&cZ)12AFeA1Chis)P(F9Je=P02Gen_K zJ=COf#4GQmAqf3Hw>%#{29FE#(S%sHXnrq(Nwy3VOxi6P?Tc6*F~NiA+C7hE`^J2@ zuQ>#|P;sKo2flBXu@yINBnxR;$+Du@;Seps=Y4{9zp%BQho*s9GE`mhfeA5 z#e+y#&xy&g;;NFL<1>GGRkeEvcm(yFY)DnPTL4E)f2Yi;D&H6A1|J{&Yp-AFg#aX~ zzgODpKkxwc87x#)d8G#&G5wt~ms#?vl5OH+fO%C7R(!0IKjO21IaPa|1|C8Etx{f* z0SBhf@t{-Ya!6iPvNwDLu)Timu}VIM&;2#j@_~1?Vv%X#!+rx`@nJv?g-`gca@ER3 zmV{6C4Wgw<(?DO4JttaOWvznI5MO{j`=Sk&Cyeg#$-QnV^h!n(dx7?xawcWvu?#54 zz8hC1S*vJZ5;)vGrK$l*+;I2w%0|<8V?K`8ZJR?jQlg z3n2eePR{MipPo@HfXqxeICqc$;subaDd*;Ho9RZ_ysfomTkGxV0a=`KZ0=wo$zt@gm6mEX~c``62unD*wE5*_p+j6YJxh*F!_QC%S(h8hY2W zH)f~e`b`&Cxp$0R(&PQBo>-e7$KJnT?}OPJ>$~qiaA5xK z-}T6U{ni5u|6=)%t!UrTZ^54`P%&A zDfRK0RVx}Z58(C7yC$w%7Zv~gTzfv-d0Ua+clzQJ&%HE$(-&4O{QmKM#m&d zKJ@b6pZxrx+lq5mzFJ)V@XGPA7dPbJ|KGmiu^+tmp0oDeIlld+7aC{Ww6BO-w=X{X z(>uraT()B2O}zh-=MU^Z=_^~uk6du>-G8{HT|D!iLt{T)wtM{Mqh~LC_}+Fg{n#T# zbadDF9}({`e*R-8?4CI5o3|9niyN|EZ)z7G{q9o>C!bay-}t>@U;DZLjUMkJ@)6=@qMW|o*!)Ner<`r^zA{uj}@6=-#LEpY(YD zDeKSJ|EQ;*9p1jnZXWr3$ebIrR?LYU*t>*iOFFrIjf8uWQdQW<&_~B6ZwHt41{