From 40f3783fca0b3334d3abb42e8b10e476b178f0c1 Mon Sep 17 00:00:00 2001 From: Cadene Date: Sat, 23 Mar 2024 11:41:56 +0000 Subject: [PATCH 1/2] v1.2 --- lerobot/common/datasets/aloha.py | 2 +- lerobot/common/datasets/pusht.py | 2 +- .../data/aloha_sim_insertion_human/stats.pth | Bin 4370 -> 4370 bytes .../aloha_sim_insertion_scripted/stats.pth | Bin 4434 -> 4370 bytes .../aloha_sim_transfer_cube_human/stats.pth | Bin 4434 -> 4370 bytes .../stats.pth | Bin 4434 -> 4370 bytes tests/data/pusht/stats.pth | Bin 4306 -> 4306 bytes tests/test_policies.py | 3 +++ 8 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index e891ccdd..7c0c9d44 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -84,7 +84,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, - version: str | None = "v1.1", + version: str | None = "v1.2", batch_size: int = None, *, shuffle: bool = True, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index a8a47da8..bcbb10b8 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -87,7 +87,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, - version: str | None = "v1.1", + version: str | None = "v1.2", batch_size: int = None, *, shuffle: bool = True, diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth index d41ac18cdfeb94b610369f116db8b267cf642af8..87b18a24b3baf8546451f43a0525c1c7574fb7ef 100644 GIT binary patch delta 819 zcmYLHSx8h-7}mVFy>WDXMAi zQBX%q%}mi!A`!Q9b0JW{HWeQlmonNO(n`@%XzqRSe>?y3{onb%pH;HTF2S(#@5Gc1Z(|&PM&{{u5+e)v4zePH*b;KIHRu+O+ zvJ5bk5e#^cq|zKEYsWC}a1(48>i|wZHRw#A0k%EXqnTqLVckR}`sLLO3%_e%zncm5 zpdggk!J=fxRk$Y4K|F?8=H84;EFe{LESYG@<)TVWuxGx4G+eUcdL>h&D!mQJzO0bl z1;yNh)Lu|iltk7>J`wxP6t3#3evk+8>r^bW_6(EY+XTh_wZQXY4caks2Ns+4$hLd| z>N{_tm90mxej^#m9+;3WFdoGajM7xrD182gK%{7leTz#hAhS6vF%0K%X?FE+Df%WE zlW&rxlo?`WiA&l)q-&~(yH_;~#wRs|%ltr225Pz1qcHDw^BXlJyVMiyO^$~N zUtJNE=uGZrl?j`!(RVQ}ab%do9wyr8>jM@Q&LHmeL;BkfCGf$sv@SLhTy~qLIdd*J z%U%H$5*v{DJP%}Lri0+NGQs+$^(HMVn%*B68X6p?3JX)K)FEN&5Va~GfM=XJ?LUT> zMMg@vGfun*o5Tu=zZcAj)s$fRhB$=cjl)5b5K2%#TcW0T1uZ>t{8Hjh@v>o~bPwK; qC@5Yb%R`QXr0x_iiBQo=Qz_v`{E|hfg5sG80Is`IDS=1fuIm3H_)A>? delta 819 zcmbQFG)ZZL1Dit)hv1%BI|BCKc-v@iEzQ3-@AL-SiXx7^5q1K07vn!l~9emWOR7i?Z1s%9?L|SZAT_ zE6F+53LheDgi|D}qquh1Walv0PQClp>gL}Ww(1XgZMGkO0X8IEtg~|xyMx1%(_%J` z!W#RZoStTXo8NP{nFb%2wpp8I4;0&-5oEA`Yrz!zWD!~0<10AqHO?{EY2G!m136@$ z-;W2A8-Rw)i`uz2Mm>MOl!=f1al@o+FvtF zvJp|Ax$n6Zm#zA{|9e1&+)@1HH;Kc+!NHT$CPP|lf33%Kdj=plQOpmfZ;9-&1M%I? z8tk8UZi;>ES20^h3oiQ{O-8!|?;LDFX2kmH&zjr-G$d$4-`@U1rTZsz`P!c??cN)? zeeS--yCr);1{APKf#qbM`+*IfHFM@ZkofP;?YlvS+>`zKc@n3Cr0@?#r$=8DoEo#0 zoj~I2*1T8X0jlGf?9U@Rxs20Ga<|J~+iQXQY!kzGgRQSGyKjn6@r^S;IM<)KswheM)1qaqgg}#@Q8@L@DCiK0vei3ujMshc&t$^?%>-C40+TL!T zX`S%sn@uEppml`ML7U8PjJ8$xo?7kPKHFC1Kby_(qvAFoC#in^C_0Jfs>+UZ1PV&1r|^U0XgE6P59+mKtTeQ$>on? b0s9DK$wPhx7Ld!p2FM6Rv4GuCB`^a3N9jMg diff --git a/tests/data/aloha_sim_insertion_scripted/stats.pth b/tests/data/aloha_sim_insertion_scripted/stats.pth index 4e1c1884488d7dcf6aaac1174d2415ac0f97f5c9..7d149ca4a4f93981271a58152723f97945fca005 100644 GIT binary patch delta 1196 zcmZuwe@qis9KS2Q(z0s{k|3ral{sN6q%h9@(0li8MMbGZx`luOL6~mNjR~bQbZm(u zw3-NG;W1g83>-n1rBhJUIjuc}Y$4g|M1vud{xQQ6X69mWEN)cTy{j<)*h}8M&-;Ad zect!J-|yYZ?8+QENXBbCjVX50gWe-cFwVd@Ofnc}VNA;;Y{b|u3nbd(?4+GM+E|De z5g9tHa0eDLH`6UeKhfbu&Ju(fUk z*yKKa+JpFbo^QD^&wZKN1(JsRFmSF6PSV?4*MnsdQ$V@1*DJ-nsdzbs*;U2;7^MxjA`~o44Z}S&If4vbF-e0)I zy(8N#X@!IRl7>yL78dm&Udi))L-X7YrW-KDK6taC49+d>fM;Jj3DGxRIPRX~vNjKZ zWKp{^y9b7zl)&H3eei0#8)B1F>(1n$5z!>|;!J33T3Gn1BEt0qVv$_W@eA8O{1K$i z{u6515*DJ0Np9EC3MrC~iL^&Q+;b0>FsZKqTdiu3tbP7Iq8ib-BD8i2oNA>{nE99AA^zE)Y80Pie@ zK>J%?b5mQl1XsOC3Q>0jZ2PQ@wFYST*m)y}HGJXu>Yt5h#M zz{C`6G?~qeg)w72Fn@bD%2TJRR(cmX)FosM>QiURr6yZar^xP5D572~m-v uszj@3d?Hb=!wJY1bv)Bz=bk=aMH7j;k{;uVilfx7h{x*?%1_cMV(h;r?3+FS delta 1252 zcmZvbeN01{y``m}IKyP^Xp3ux`$%P!uHEjr=gJ~qwXFMNbS&Epgh4Wg$f%2G zh!vrUxDnb@HB?{?lZnn{=r$4;+Ojy4Wv*&KiqS1h-KIfjMl&WU>b&RPI$7dL-kbBB z-^p{{^S@+b3050G-wq4KrlI8+YOhgwLdp!I_OoU{mfEw2Njv@{ENtLS7-frE5p za6x)nwiO67ZjSDW)-Oku%TQs)Lukh`bns!cYZ-b!<3}NtCgJQ}@diPA1pS$Kqo7|A z^dJA~4neP{3bO34SER})f+(YqOQBn-CTOD7Q|oCp*R^@}Hy^6sf2g#?Mj$_BMqx^Z zCMYj^t7tVC{%J8fK3EO+iW(x}89ChGJOw)IP4H!3&v@5|_oHu9Zg?W^939!Ag2CIl zk&ULm8J8nd`rWFC8+j!xOB=4yy{&QZWug)s8?*p@LI&&uU7+rVmEh*9NxJR&WpMSk z8gOuID`+&S!O@u`V6kN#z$Q~!Gl4)R%RchO0u!9~1LbH4$}jGO6WQgk>G?1u2rv95 zmSh@Rhrp8@`?z8cywLj$9LPV$GweZp@YGAl%bEBlZ$t}=YT?L)T2jPovwn6NXdzP9)Ud|+c^hWe)T|8#m zb&c8ctB{CJ9Ez22i?pY3hp8bv|X_p9hOz*E^Il=ywEeoJUV@jZwQsTaN@y^GxubbBs=_<5^VbF z&*(7M!zbF^5od!O&NdwfitDEslI|S;S}TeL-*v<8)CfIiQNiEWNg|lx181LmEf0C+ zEZZHvN`Lb66_9AJ0`t|ypm`GkM&9ZKjTbw?gfdBgJ9`A8e4pqSQL3oV7VVw=@su@)6uZN=8jMHnw02)oc7#bdGuZx>}No#Z>{S!KSMe_%0X zjwnZh29$YX{8kXB$xA4q)N66uZCpWcD7pA=lp~GGpv4{ie|`x$R0c8MEo@hgd@8*d z*FEh%!AV?hhM&XaZuZgV;PwV75Mp?aj1Wz?<2d(VMZo z&>2Jl8IyUKwc&zZFhO4=!9bW`ACh1MOz=GOBamIWAoKUI=z!QIAogb#T@bq(#MWj7 zvztKdWL6EXM}h9T{HzQN{H&8F@W_bsGngnBx~q~;f;q!zh388A$~$f`g2 zKWn4IkDV`fyGbhUzf;>{@AoNqZ^??c`R@-y8oosFSqiuC>AGU1_4zdLq@-B$))no_u0E6QzI`*}r#6O)&{#^ZyEbw?>Y7f%V&aZ-hA&w@#F*S5gGvoa&`xz*X?6a zXx;xMuTX zlRNg+_psVU`Zn1oTwP%&DcohJ@p!3ysau_0QGl$ySC zr|<1QRJuR7T;HC#H(_6~Y|#E5;kEl1fS}EB-9G2rZTo{cU)%L9-?9JY?1uekudm$? zk^{L!qUe3yWCu`v*n)m1dKr{ zuJ|Q?g{#P|ku?cg9fQQVp_EXa1x^S3tg%Xn?}UAb;^ zl{edllfBqh&TP6(@`V6f(=aw$h274!u2Ujy9p@gj4OqPh9Ce|tDw`)e@LZKPH8Quf zv@kF=G&3_bG&Zp?H!}yh+18Xrb+Q4k1SdRAP7dVtWsILZpI32PK0hUli~c`~iwm TAPDegV+YY2CL0P$vq3}w-s6Cw delta 1212 zcmbQFbV+GL79%&KH(y(!m^Ys{W836HMqdbD9>PD$XaeD@gZYz~pj>^3v^CReF%54< zZ)PB7EGcx>fUtn<$={f?k)^zlqzsu+qymwoa!{orkffF~KY}%PS@kC`VC9^AfVIIP8r0%=*f5E=a=Vkk1m+0CDc(Zd{SyQlPas!)#gK8U#^*`wV+tm^uZKCzq zY^-h1+sxS8V6(j|!Zxz3-=>$>&USyxR~wC~LAJAhdD<5B?68$J^|J*Ta{6)ehe_-K z21}Pu-X~b`&(@o3qJ84hm3C_Jy>{nXm)c9PHrU->FJm9<9s7Q1=8P4-?_R@g}jciHWGxzxUAXO&&yK3V%euV&l-`#c3~ z;s4jC-c4=*8X|FJ`riISrTc@+_3d}gPuTZoThRU<;kEl1fS}5C-9GzG?fW|)zp|5g zxMTl~i4FT>f3MvSk^?#9SIDAwlQO!c}bk#_7fApuh0>ZsB8!`*rME>_2X?+PjY7<35#+IeUY| z_uHMA&%HOSdilQA&@}t?`QiH-B z<~N(&!E81S^_Oi_ zfdw2}AmcnHxADodfWjSW`T-yp6x@?1@`Gf50lA>C1#{<2HsQBqlHr=1&u_svXYvYu zd1hchP2R^Vt|G(D0L@4#i6x2p76IOjOdv;ba6p^}3Ja(WG6I%BgTn+Aq(LE|0h9+} h8HhS4Icf4l0e?XYsMi9#+1No6A0~ekkYhM&XaZuZgV;PwV75Mp?aj1Wz?<2d(VMZo z&>2Jl8IyUKwc&zZFhO4=!9bW`ACh1MOz=GOBamIWAoKUI=z!QIAogb#T@bq(#MWj7 zvztKdWL6EXM}h9T{HzQN{H&8F@W_bsGngnBx~q~;f;q!zh388A$~$f`g2 zKWn2y#?^JZ-;4j>r!CWDZ}w>G?y%2H`@eBU_Lm_6){rCp)ma*o5Ux+GQ=c+CFvu zCR@kLN9+=(oU&(NXs|b(^u?|{qRZ~&npyU*jK%HMrnK9OFYW@%feeT@m~>+D0d@}y zmM<3e4)Qg7bk~IJ-vnp7*LU%~{R~F?XU&|s&zCuIKS&OwdF#oL zhm##RTx^_dJgfxTSKGI~U27{Wa>A}-$tioG9eMVma7;eH;bF1ytfl?rBNO)gKN7gV_V1Qm^Ba%b261fLC#!UNzaTf~ z-go@QVB0mc5t{PV`a7V(gj=B z3`yHX3mvUROeWZNzPx7J;nicCAidniQ*4{@dF5EZ2?NMy?3gUZCoc_(W(#1DfG{r*BLK(b06qn2P)uT%@0i@p zr@#V^D3Ea+laKMqvw(vhZ2BaAc@|K3Prk?xlC%PHL4gbA#!N2Zw`97*Ie9(51!K(Q z7eG01NC7pf+~Hz?=B$*&l0PD$XaeD@gZYz~pj>^3v^CReF%54< zZ)PB7EGcx>fUtn<$={f?k)^zlqzsu+qymwoa!{orkffF~KY}%PS@kC`VC9^AfVIKle&f2`GhG<=znau!|Nh+0-42q>`$bRo*&90;*t<9^vAObk z<-QJ)zjn_`F75eY?QGArRo3>quBm;1H#^7mDPh+qH?TQ4@V}O~syXQ8$?Y}g4 z>(E(+wp{Bk+HPH^Z#$E9jm`Y~n{ADC7uhaNyKl2mAl24;o{{a=Gp4p6LpC3DS}}>; z#b#OJq+Qm6tL;d@2sm;V&sdu6-f{#i3; z?qkYN+z*liX|oZQ}!bYmvPbwmO*?Y_Hy&W_$0!dYeY3^|l3xvuz*PJhm}A9d3Jf zpOLNk*RSYOGcti(#lZn_8z?lOR>%lg0!==%fd)_>gk>P=pyZ^<69xPQEuel2@MdENNqm_6Q9zmvA_@Sq+L_4! diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index 039d5db3d53a93a92163692e0702cc1c5de65298..d7107185f4dd1f481913735dc24f68230a9d5a9d 100644 GIT binary patch delta 433 zcmcblcu8>qo4~WFsm^yUOmGhHX6JC-KV#`+0k#5xw`H>&9kr%7f&`O#mTFF(z*fZf z!f{OR1pV)s0LvH*91K*P_;j+=|7gZ*a1@I7wx z1a4MFX#;a(b5nB@3kxG7GfNXgBXe^DkeT7kEq;?X@JwU@2TeV%*yLzl1r~6S)bpCL zfCFMDuQ?0Y`(Jqo4{Vr>CTVcTAc&D**U_!`~FWBU@H)q|Cra&RP>f3NRVmL<@=K-uody` zc4laJ09N;j=gXzZ0_;V6sm=@xeNG@%3EyVPPoBVDAkckxm-7uVh`JfGTm2^sa1`;y zIeYc50INHfX1Qka1dam91MAD3%`B>%t8FU5X5L?VH=Kuofq`c-KacEWGfpqbyqSsC z-{$LCKWI5?39`a}QBD9t#Wv0};!REr43~gd0BpjNhRGX|tc=ozmd3^gCg#Q_X2!;bMrP)w2Ie3;BR{xX zPTs&Xkp&zy^}J$}qj?orz(G>aYsLZ&h@HIVEMV_{0W!c|*5@;00eiHbSA22|TT$BYH+K9B_q`4w0|?wSnJe1|`Z1?&bZff)e5 C9F{Bq diff --git a/tests/test_policies.py b/tests/test_policies.py index e6cfdfbc..92508dac 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -22,6 +22,9 @@ from .utils import DEVICE, init_config ("simxarm", "diffusion", []), ("pusht", "diffusion", []), ("aloha", "act", ["env.task=sim_insertion_scripted"]), + ("aloha", "act", ["env.task=sim_insertion_human"]), + ("aloha", "act", ["env.task=sim_transfer_cube_scripted"]), + ("aloha", "act", ["env.task=sim_transfer_cube_human"]), ], ) def test_concrete_policy(env_name, policy_name, extra_overrides): From d2ef43436c69a4df1b10f4923467894410774bad Mon Sep 17 00:00:00 2001 From: Cadene Date: Sat, 23 Mar 2024 13:34:35 +0000 Subject: [PATCH 2/2] move from cadene to lerobot --- lerobot/common/datasets/abstract.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 529bf6db..8295ed48 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from torchrl.envs.transforms.transforms import Compose +HF_USER = "lerobot" + class AbstractExperienceReplay(TensorDictReplayBuffer): def __init__( @@ -106,7 +108,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): if self.root is None: self.data_dir = Path( snapshot_download( - repo_id=f"cadene/{self.dataset_id}", repo_type="dataset", revision=self.version + repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version ) ) else: