From c77633c38cadb8e99628d5de84d31c1eb06dd757 Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Sat, 4 May 2024 16:20:30 +0200 Subject: [PATCH] Add regression tests (#119) - Add `tests/scripts/save_policy_to_safetensor.py` to generate test artifacts - Add `test_backward_compatibility to test generated outputs from the policies against artifacts --- .../common/policies/tdmpc/modeling_tdmpc.py | 3 +- lerobot/configs/policy/tdmpc.yaml | 1 + lerobot/scripts/train.py | 86 ++++++++------- .../aloha_act/actions.safetensors | Bin 0 -> 5104 bytes .../aloha_act/grad_stats.safetensors | Bin 0 -> 31688 bytes .../aloha_act/output_dict.safetensors | Bin 0 -> 196 bytes .../aloha_act/param_stats.safetensors | Bin 0 -> 33408 bytes .../pusht_diffusion/actions.safetensors | Bin 0 -> 4600 bytes .../pusht_diffusion/grad_stats.safetensors | Bin 0 -> 47424 bytes .../pusht_diffusion/output_dict.safetensors | Bin 0 -> 68 bytes .../pusht_diffusion/param_stats.safetensors | Bin 0 -> 98776 bytes tests/scripts/save_policy_to_safetensor.py | 101 ++++++++++++++++++ tests/test_datasets.py | 2 +- tests/test_policies.py | 39 ++++++- tests/utils.py | 47 ++++++++ 15 files changed, 236 insertions(+), 43 deletions(-) create mode 100644 tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors create mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors create mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors create mode 100644 tests/scripts/save_policy_to_safetensor.py diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index eab0f94e..1fba43d0 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -80,7 +80,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): self.config = config self.model = TDMPCTOLD(config) self.model_target = deepcopy(self.model) - self.model_target.eval() + for param in self.model_target.parameters(): + param.requires_grad = False if config.input_normalization_modes is not None: self.normalize_inputs = Normalize( diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 71dfa9c9..eb89033b 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,6 +1,7 @@ # @package _global_ seed: 1 +dataset_repo_id: lerobot/xarm_lift_medium_replay training: offline_steps: 25000 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 268185a3..f58dbd06 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -25,6 +25,51 @@ from lerobot.common.utils.utils import ( from lerobot.scripts.eval import eval_policy +def make_optimizer_and_scheduler(cfg, policy): + if cfg.policy.name == "act": + optimizer_params_dicts = [ + { + "params": [ + p + for n, p in policy.named_parameters() + if not n.startswith("backbone") and p.requires_grad + ] + }, + { + "params": [ + p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad + ], + "lr": cfg.training.lr_backbone, + }, + ] + optimizer = torch.optim.AdamW( + optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay + ) + lr_scheduler = None + elif cfg.policy.name == "diffusion": + optimizer = torch.optim.Adam( + policy.diffusion.parameters(), + cfg.training.lr, + cfg.training.adam_betas, + cfg.training.adam_eps, + cfg.training.adam_weight_decay, + ) + assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training." + lr_scheduler = get_scheduler( + cfg.training.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.training.lr_warmup_steps, + num_training_steps=cfg.training.offline_steps, + ) + elif policy.name == "tdmpc": + optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) + lr_scheduler = None + else: + raise NotImplementedError() + + return optimizer, lr_scheduler + + def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None): start_time = time.time() policy.train() @@ -276,46 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None): # Create optimizer and scheduler # Temporary hack to move optimizer out of policy - if cfg.policy.name == "act": - optimizer_params_dicts = [ - { - "params": [ - p - for n, p in policy.named_parameters() - if not n.startswith("backbone") and p.requires_grad - ] - }, - { - "params": [ - p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad - ], - "lr": cfg.training.lr_backbone, - }, - ] - optimizer = torch.optim.AdamW( - optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay - ) - lr_scheduler = None - elif cfg.policy.name == "diffusion": - optimizer = torch.optim.Adam( - policy.diffusion.parameters(), - cfg.training.lr, - cfg.training.adam_betas, - cfg.training.adam_eps, - cfg.training.adam_weight_decay, - ) - assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training." - lr_scheduler = get_scheduler( - cfg.training.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=cfg.training.lr_warmup_steps, - num_training_steps=cfg.training.offline_steps, - ) - elif policy.name == "tdmpc": - optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) - lr_scheduler = None - else: - raise NotImplementedError() + optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..70c9b6d811a396f7e554b424beb0b0ced73048d0 GIT binary patch literal 5104 zcmb7{`CHBD8^=#a!(=RJFv+MSVpLB1!gD{?L{ZXuI;|6(>XfCJP^pY%Letj_WyX+Y zlD!Nj&gXtgS}Y+Wv`O1gX+71c(@DO^{0E-jKF@W%uIstp_ve1U?_130{daa52n_6Z z8HB{_*u25O-oU}a+`!l%S{(H0%h%S}%+k-;AS5UzC}7jZjnNxoqE)v7V@pe0zj$K< zv(F_}oUyH)^(WT+b6FKt zY5kuWDjvykX{(rSQs`m#r7jpbCBi>wF(jC-VplJQvAQG`H#LbHvw~0Cmv)gu&x*+} z=gjfkuaB|8HHv1xn2A|BJz&%8gBIZhaPhJ;^IN@>osRdTg2GLlqlX&hjqW7XYwL(z zfhE%MwP;uxNuLGk;koiI$ki8OMNI+Z|K`NDcg8awUql77^*QM>1q|lCCMU9M$^OGO zI9K`{^B{t@e$dB#|MY|V8!rqH6o7M=3)3h}WH(HMY4@o=x%MCh)b6h!gUaV*jEN0i z(r!V!SK;(Y_6$^R>xZfVZwwA#kPYitMo%2GpA$s2Yc6pu*9W0}dKoEitRl^`7h#6q zdmJ2=(9@0D7?z@dIq!s6c(54m6s}|aEF5W^hX z@vvGXo&4`qsi@T6z?8MZJl9{MmNEfH8arVcQ@TKS-Nw9cF4tZGM!d-80YjHTu zduf0tGk9oXj!0(PA`7r7*1<1 z>0yb#0vc5HvpmMZc70ct9J_~&@d%`I=cRFr?1msi@RGE=X(TUvY%$@k0yWo5sHw9K zuCn?7BS9iOuU`lntuD;&mjmp7Cj;n#o+n(9=OC!>8zI%gAu{Jr8+^|c_;#~|o^;m6 zvEx3#;v5mKu3->Zxt48>+|TYm@~2S>o}q;4*f=JJKOV164;UnZfMBidM+ zs|2$dKDdQf3XK<>*%fIbi@O#`544=(PQ>(r;nzws_TCfn^qoDX-DtzDE)rV!W-5O3 zy%LN{d@$zZJ&659_0Au6GmCoxH0|9TPPU~N>Z99<@woRyZJhu$$2H@))1kB}QWu3W zJ@9^90ID_J1wP+_)jyFi=WK5p*z3lbTpNbI zp5ucKXY%1{x-)B-xQ*#3L{vXk$gPj$(*mCklB!-q>Km+4qoM`5-zC)FT^IjqmcfY3 z8xN>_u&&UF-8%dO+xe?6%@ONx>RL)TU{XhxAFd=zP3_QkY#sh+5l)v5=;Jfd5OmG< z#F>c*dG;=>WKSFm`Zb7d{`W99RX7BJ3{78!QF|?2ToOA3+PPl% z*hC7;|8!+`F0m}ID3EU6aEv=vqX1EO8Cm-LA2PfLaQy3jOne+cwc@|Qm19O=<`N<5 zeSHtMrn)kxWywrtA4n^|e9o2M8-zpO%1Neu1KC_auslPK``<=THhl{AbbNq$=|W8T zz+m8l3u}Cm$a?$(sP~pGZie9?)CRPW=A;plYB3+v%$o6%OiVX-&qD3WUNHYV02i=3 z@Xzk0tR^a!Et2@q0_k_$&n_RpdwDncE$l7{J#C7;i+SveX<@YU$sBy<+Yi4J5eD}# zIFq@8`MPap!$saSaTjoTi$+n)m^Sh<>nhnd#te^F)Zzk-C@L#g_21|LaJLm=^>7|^ zPh7=r)g-Xn&i*vV-jMTj=2OA)R$^IGMtr7Q;OUZfTzNc_#?GIGC++&-y_yh%cz577 zab(eUNv!;&hz43+m+Eg&Lc!}Qa#{9}81Jz{XSZft`*j3$CvyrfZ1x(q z*F1?`{5F`zY9w;O(}zGm@d2^5DJHv1?D6TvVN`TVs2%q$PW0o^FHL=M$)FTs$E;!3 zChcQ!;vlO3Fqae08HA88`bhJ(=R~vG4mC~;0Azj?%(g_=1_~S|A`>=PdBU>>g zfJG`jsD9Z#?pPU*w%i{gL+VAuyGwx1k0xG4H`G4x#jB$W;re4Imh2tL z#J7cXa5#yp9M7Yd)s>{Iy^(mCSm3XZ8!%p#f2}vZ#Wn6-pg7`-9WjLvJ>H2~MQ&!& zO)8)NXu}Dj2B0sXo$L&$Ba#JHI6=1tk8TL32ac=qT}uuvAHC6hK>=jMuVr6-wT(?Z z7eMtV9_IpP%Rw{xEy*^iAZ6*6I26-~3k2a*uy;Bd&5(mUP>44+Fc^96!u}fA%sMLl zY3$Uq+=mPOP^4E!E-EU?d8HjTRP^Ge{VM;gnu;aW3dp}H!p4{)@cYSy)h~}@V!0pv zGQWYN4`g6CP908ec}KJ-*x~k(UYxa0LZ9SK#!dMOQ0(x-Wo-rEwq-5bGkX{F|JIMr z-L2#yP-^|tH~7S~13rxV9+Se0AS`4h>+cY=_E=B4 z_?a_za)uIi_VMAz^)lo_kp-wT6S$W0IG z+7`%FJXNEUV@61ddI@O@7vSDm71;AwLe{h&Fb zBnG;p;jpJEM*Z82=A%RD;zP4A>O?PGuk^=0-9m5+U&bu2ZDr|3UbOyphqUL$5V$Lp zq`nD>ph|#CT0dfhSt!koFv70aeQ?%Ym49I+u<36Hwtra!d&N9y&K3=BSiq-c^9RYX z(2L~VHgkL!UxnB2NT~jdd1#s`hlyG~NKWL#(V|tXwa8%fjfsWJ8S0QG#BC7fyJ`fhn$GJ8s6aJAVYw#6A98`FuGHn!X}(YesZB zRP)R55H^`cP}dCAo^$RaB;@#D>R*L$;{6(Co4beIdf-pRcmL#m3+{u}uXu1@T1G}) zx5x4+Lzpx#g1%ia0}l&7!X1$iAMVJ9xPE7Lv>}=4Jo2My6AHMAnSD?e*Gq)mZN&A$ z984~&#eMam6#ku!7az+YEh7N8W!!@|xP-YR#IaNpRsJ#%poBJd%){<0a@c!Yh`9xKLGpAZbNi6QvV;9-=EPE|)@(j)oW_H!jt8Xcu_@+f zG~tg2B~;*NgpaQF!7f$)jW@gpX4Wg2`S?9-+8bZWJ2O)nzex##H9drz@{kNi3NYbX z6>fuY8o{53BMnMux#opMZUrz>yPB0x+Rt)qgK4I}1$Rz51QnxNh&f>-`RxKc`a=ia zR`uWYirM&kz7oo{eQ-xOgHa-9Ha>hWs|gLF^&Lw%H=7}t{HBgrJKrPKEmr7sU4~6W zLPy!^q1KC!&~(ukf2a3A7Us-CU!^c{YM|SpqgWlGrTbX1)&O}Z_`g^TFu2BwQsOnE~ays z4!XCtL0WJiE=a!*j;B|!TSj55R8_yc+`U{yjshyT4U*({4~eI@F;3JU!Amp3=!s>z z7;DoGO$YsOqR%5p-Ra0Cei6(v6TN9=_aQE8p&Hdp>LHS-8zkk0DXw#pq) z861&%V(Ez@c(Glzzt8%CS>Ft%_b+5~?5Z5{vR)C*gaX1mEzs{p7fv#ZpvGTLLt9l= z)%keib@M{F_P~{mv~FeZYlEm@{8g?XTMqiCo{;P3OUbDOTdY1lfDL0K=}#$NqhGun zv~qm$snLDt6T7nfh6H9XIf!Ot)pEN$d*O2bAQ7&tA)*5|7#%)<0%O%&UN{LobLCKN kBf|FMMPRMv!n8NUv#@!ARJ=sN9Zcgx-^7?-Af7^_iasU7T literal 0 HcmV?d00001 diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..2e8451891e477a10a94c36901734822c69470a04 GIT binary patch literal 31688 zcmb`Qdvp}l9miLJBWi7xhXtQB#TN`|c4l{Wc4o5+9K^#>G!-~1D%*uD63BxlD-Xdo zAc~c$D7Ju#MnElCut-8S3BF3SrL|ho8ho^+$EZ+WRq#}cw*3ytvUhjp-r3(>_{Rgs z^ZopO^SgKEvy;irm-Bo_{wyk~sjEy@mnLSUs_JUvvyzF*(&<%+hImafQCngzDyd8@ zoSQ5$OU7$ziBi%qD`Efb+Uu2)%0wy=ubVluA(?8xzmzF~>u*+gX$`4LN?1TK+{5N4 zt8SQ;A{S>UepMJqtk171emxsFJuzd>^t#$)>5RJCdHzz#BK^9emU)E|3suV$m6^-Y zLaP%CllA`6vf?UM1B%KF$|Idxx}j*y(A>jS*@LaAism-KoC9+-omjf2D*?Aj_J;e5 zJL&;!HNi~t(ut)90!lzI(w^|6$Ax;52pCFGFw?wrV(CFu335lO7Q5Dix}tMu>PRP+ zu9qn~cchtc>7%X(6rDR$M>@51LkV$5YKK?W&97}p)Xc3;;xCe?#*nIngrjwX5E}{U ziXj+q-nOEye4|V;1QX6nCzfvDdMOxi-nJr_ZYW{Fg!9sgrH54|%$=-O?4}Ltioc9I zR7X6qc0?ndJJn3Q^u9$l^0|X`&JYBR7WzpFkS zx#()8c{_~U1x(7%S1Ct6wS1_QpYKwcc=@HIV%B_}a^zFXhfewVJ{2f-A0u1w^Oeex zPc0ux<>xz9CSLyN!>aszt#ag(%g1o}US-F-?J)8dfMEIghQ-5Bn}KNg`KA>pdLzTN z{CwGRLvlA6+dfP;&^KDkSsrIS=@9<4MjKr%mvvCKyCoTWnGJtuSy<$$d;eAEpa@xc*vH@ z$`&_W^5`QR07kdG0K^86jeOB`Yi_JgRoS=5;)ztMwzR4?KDWMZc3f&JfTXCbeL1=j z(WtG#!Bp143Wdpz%D$Jy+8B{XZO^{V#mbo%uI|RvkZIhB<;Dv;VB#w!#!4v5+iOimS= zoRcFv8EBm*)H*JT*c2!riB8p7QtlO%`bicPkVF@79*Nu%euyQMJQtAUpuho(xzFY} zweVTgwxD;Ko8H-*RPNk`I7GS(KO0Fc5|*a9S=y0A*@KMxXquahv7yvL?Vm-u_*Q&w zT|+!sGd)=;`(qQNDw>N~Ig_a+!>Tmas*<(G@-_jY5zp*TaXh(r%#igeZo2Fufm&&- zT6qDe4ZyB6)~-1DlJbUV#4}q_98WDCex1RNC~-j4h9F4+R+0+f$W8{5 z6ksK(a1^yExKn_2CjqHAaRvpX!q!7BiQE#_1M69Zv?Mzy_*a1SF9M6&mi=THD`a^n ze{nxvK{Vp|Cxg6rYVr27VFD}5rppcjH@pG@F>?T@4ZzC+0x!$tOI}3~tpo&GmKRSg z9#$3*SXnk*_UJ>*0s=8}0LTqsx&klrmloko2u3*b;n%UTm~d+0aI%2F$ufn@9(QP2 zK%ixL@zmmBWdVVe^`y(Lc--y^2&610oLV@XEFf^QOt$P%hn58eT9y}2EFQ^77oCVX zIzzpa1qeX=3$nkI10pwsCwB#Y=Ejkn3@|{6wz9LOUKF({7+O$ZXlzrtHbWj03@s=y zG(L&i5(F(M5Hwei?4WQjEGTd^5sTUu6fG!FG_^QyLbMVTXjxu7wRl)rP+(=*blE{b z%z^?ja{$NW!ZGuqYp9D z1!Cp^P#b`k=>jj)it#1{BbvYc>g;czlt;AEL>*`tnIWx7Di^5Ut*!^(7lm1)P!n;>W*fuM0g#D>5Chwo=@ z9LdQ51C(ehAFKOKpeP}Mp|MRB=S`r1B)TwXd&(tI=LA6u2?WixCD}p2(Lw@86S1gm zLD51kiWaY`xgkC?nMlQ}YUdX9I|N8qLM~>;Cs12}n}uB5%oQLz5g1y?#n40yY8!C0 zkc*?ms}reY?J<}G94+MHXuN!K`53c{qq)S(&H$bka`7}if!YE*E#%^9@%p;yb*Xrw zwlae%`guiQfO4Mo3qVnug1?1a{H>SNkizOJevjkkTF6D@gb->Yu(^qc#V- zGg!OxpyhrrC7;eZ zF9QG&u@7Y1=p8w^5qv$Khx_$BOencC5DoxyFVA)JG2%yU0Nxk&@V-m~lDh)&KzI$i z7%xLQ2ws!636eVk;Y!#;{4&ps9{D~P>J;F8VGr-iKFDd5_$kK}QGLS&?Iu~TemP@0y z1_=y%NT6$LvV+4Nzp#e~ig?uapn+iz4a_`#D0}sR`h`8zFAqU&0_GR?FuyE8b}*2? zu!sCPAmm0cd=K}_Fl4VnP`|K;`sE?0O~Cx{y|&DE=DO%c79lwfn_y$arvV3s*ozIY zDdKBH&l8K1+rv%=!o&Y^5acG{e;|kjFhh`@2PDv6=3#+(5OO21Krp;!UWBCC={BN} zf%on}dK31%Jw&kQKuLdWvGZ+6AtOHuAvXdW#1~9^ZS?GgxF}h(NjA=F_#7O$Ik+Kc zUhBX~liokLN$4;0Fhusj&dky51(F+tA%f?%Lb46!Y83TmnvDX_LlW8ka(U$TAc?>g zzFLmrqz4F71hMcHlTW0#Xs;L#SIj&EDSdT~ zP$*7%fVMg2yx)OHB)4dbW6m2N7+}2jwTi_)Eo=io;n*UNLT<&j#hkZ4h+C1Kk!_4Q zuYo9tLvF`5#+>&-u$URuMSi)a9c|~kpDW^ye0urz$T^mndE^$`v$*HbZGw$)zkYBr z`T$vlci_Nl(hw$efo%uuDP=BTvir$s=!wxOrI5pRcYPX`TE?P5F5f)F+=g zescMFt1s-Rs7ppdukAK>J@mcEfT{i3|9)}b$du2T!e5=+H}cMhYucwyY_}Fqysd3b z_xX|U?(b|}RJp{OHR4dXWA=>5p?lv6ue$nr>-2S-((l%tW8Jo7c=)`r=UZbNPDs~p zI%Gck<;m%o?*~@*vH`(4k&nz<&bT$ywC7cG&Z_0D_s9QY4jp|?`Wt1nxuN^GwqKli zjnzHmSIv7;>&=Vad#&}9->x?&tv#{zu0F4u=ia*ZiD{4h-dwY4>5kbS-*5TOnbLCb z$-U;Rp>Kq`)~~ZhRlKeSmMyjRo;^i>@6G;^#>yqlBWG^3p8NO3ddXY;BRy~Z_4%GT zWlY=sQ#M=sb`RW9f5#)1&(}DrW6T!(_vRui(YPQ{HjoiD&xMSF_tzB)) zTb^Dt*zy^lxBTtZA(6(DZql|K9BX}Z+riMY_x#bk?(XA9^gDfn+1PRC_UaKEtU+6c zKehMFJ*N4@fVO@+(Z`%Mdh_hnR{gm9Lm&MwZCy5DUTEl!d#o{!k1{rQmRWBfIB~~M zXZ5xAU;Fcx&MVhgqnjU8V_VO%PMguu`ebL9IcV$eHtjiKyV*2ygnom6leK8$w}Znj z=ra4f|7LphwNF`%o9_%(Y`ohVK5d2m-)YZUjaS^RzdmoBwPgB>=?lIbWxd^6p}TR z1 z^1_rR>*~)hd1}5Aw#MDOIo^RZbI_a*! zi05lVKN>f*W$c}28rmOM2iE^&voZ11ZS6Zx|8M!$OA=caMh}%YRZi|$(>Wow;8)|@ zfBE<$QQy>rQMKZ-nD2nvHs)X?_OA_JnD=hDG*+>-$!w~zVjp&0W!`h$>ge;YY)n6T z`PiuMi-@J}j>iVPbZ+{uzL%myuK$je4F0G5$K_{~uK4Ja@|8>fYQCMC5WC~?&hW^sgFi3#&3?ida@_dXu9j^(esf<_ z^sT;qc0PXg_?Yi-Wm})6Myz7-z@5t)-n@TK4dNti=lgvENQx*51`=#eOt=S=%-5ofSLr$xCB{b}hiZ eJG5i9O>`{4I!Uz9LdVc3wzgJ5!MyZ$K8+-aX(%sFeC7G_i-sBbOR7Z1rS85>HlTP(EwqD!V zk-27By0NwK+=$s|G!9&m;y-=va-*>$l}RQ0Iy(o_nF0Kp7Q?#yIzy&5km;a^S%%{; zc2&A-=@l7ragGr-m4d|j!ln_n%YyBxC0}gs>rFQ=>FZq)Zq_U^Y#U}vkSMWGv&Arl zz8opEJ9SOEKiu3>U&W?nm_nlh)TyOAMnq_uzqqM8*b&o+_;s*iV2-F0OOM!w<=4rg zc(}e%x3JU{J*`qFmTp;wrD$nE{K&Xa9}<>hL=`=)QYV%kHI1mWrDnZzJ!%`aw5J|* zV(E5^VM|+D64xGe+cIovPd)0?(j9|ITgt^d`mX97NcAl1PU8p7M z4M)-7s%1r;`A&=BC^}rFPAuKQeyM11)v_X&?ievehpW_yrN>MoChcsb-bEX;jc|*! zsUCS^?U07Nw5ug~?R5)n=t~>x(I=M=EE2h>a*O5ie#?qFeP96_Wsm)As9``)VPAY2 z-*h9*RV$2~0wfidqm)OVT0T%JEN7{by#7=&(Q7$QdGx8}1E<1rp0es)$Iv#ya-{O; zQ_BZZh2>0DlGh)7P*qrtRUUnE`Dm`3tGK-13L`H809IH|STY-G7=Tt-PFhyo3mL2x zmcy1upISa(D=eq2lDz&@ve9cfaC!8p<%6`sB5HM{m*9M$zj^#AI)Itf-PM~;;iIG0 zApxX~ut-^Id1~>1tguK~QguxXg*Xg^4LQhw901J~SxcgCYK}f&D=gBMTAo@wV9OL? zORBCp`VfbKfLlQZ!~oESzN))D%eymO{D>@>%4B+*yLyw$`una-YAppg8Ky{Io-`^q zYBVsIDKc2CX1cB7XIUbR5&fv~_%W9VXF+y-%QNGx(_2Tdf*oO^RN9flVWNoWR&zMr z<^kqR5$33d)PR6HQ-nK@UvFQ3kLHd8VTh|98iF{raA2J&Vx6RL-9F&7%oMRsEl({T zG-rx5=c}%{(?HlTMTk?0Qws;cnIgbRv~@=v0B4E-r5j)I$01B7Jx(xvXy> zneJ&%cj*4w1gb{FhpfEH)RI9}5s|9Wy+`wy0Md{b)~8yYTs(Rxa+OqFcai|DA|hG^ z8Bha&t|B5`N%S@04bqSomZDmoT0HnFBJ!0~U3c_xLKYDbE69Kv0F)IGDa)hJp9Pv9 zS0HReM9fl&Qws-UMMTDuXzPwUAS)t5mRg=%Ji029m9M()j0a9dM4VEIQws+_St374 zv^7T^(vVjzRz4e#95)k-hceVv58^UY)O8vL6T=cg$umk4BWfT(k|jb?4L-V^0VG)> zB-L(34Fz_xM0Qd*)d$XSfm3bsq0)&Q3G+eZtXd-JHVXV@iTp+IqQ>HHhKWE{WfiXP z*DH{Qy!>HMB~LA$zZ<5Yva-5vBXHnlDZtEQKn(ycvlLub(%0NYAZ=I*EUS{I77r@3 z6jWAL*ByO;nWX?Tj{!LVx~t%_aC05p1Yw8^AAUUxOAx0P4koh{Ojasfcie$xmIBMF z*;vMPCE@lb}k>Om~Q8R|(E zOkgD6ki*RqBXS_TxvStaKOfD`zy?;TrRZQo%c|t5#e>SC3MwnB>ox*l7FB?m z$ABCFjaG13Nndjpfi&dh3%*L8T0E#Ms-Uv6y6)%$%%Tb~^B7PAfXku^E;H-lCI}l* z1(a2YQws-^MHNg|($*b!U|CdwWmWRj;z4Cm1(g-lb!R-FEUJLA3UPAb=%|9pO4_=k z4lIi*u&hd+T0E%CR#2HK!x`#^3=D2*DXgfWKxnpt z&_t*DxCt(XtspeH6EzY5%~k-KuaRz}z-YFD(NtdKSeO$EM2pnNO^}AXeCMu`rxp(? zvlUcUR@ZF=z|2;Fna6+{09CgqJ=_Fgh|8br72@Q=(N_hNm9%xo9av^7u&hd+T0E%CR!~_%U3bO<%4`Lc zRftmy2b0+fCM#*{jyjIYYz3B8$y195mDvg^iySp>0-!Mkph=C0f#3ljIiLCYXm$oR zuu?7MqjlZ{7lSDXO$=2ZH^Bu?s*SmvsnUtsCjc5#0Ge+k-9~}Yn1az%Ues7XH0Fb7 z$*!KI$)TbTCzKp zN%tOwK7i4f52MNS$>pQXK8)rguiFD~8uQ^axdSx&o3>V97l zHn8$O^=q)Ah62AaAAT!34P-F8>R;nHTw^{Ur!=7k0nITVnv;3x&N5&f^TE0R8)_Kv z9TWL(MJk=F`b}<~e*+crun>0pU<+~xbX_DprGsui0DMdYzFJJELBM*B$a)G7O`wN7 zEQAFgY(WeGbvV@4Z&UX@YZw@g2zJUC#g3ltew;ZTPq7XM65#@pmr~!a) zj)-pxea%$?c~}HByD&zYT>!Q@BDP6G=&lH0nQ94M=_%Y1I#%h%+=_n+bN)(BcdJ6jv5Yh=ZJI{;8yv; zl>9gY;W;AVk=)49Fh@kvtJO@maiBg&q&}h_H6G~C5$VszuJn4K=3jh(JS@~CP-Q_4 z0rGPK$gk|6+ZY`4IRWtJF`)(l`#AyZSMt!^g@Ass0Q9S}AclZC9BS*gHkZKwCTboi z$0)X(90XsF7r=gnmkBjj2E<_?oaOm0J^}ow0l)|q}!y%~C z+(C0iK-`E0fM4mI(edwtp>_e>7YpFNlD=*ifPS$6^sBO2?Yj7z@Bag&j2a4fudV|Tm~n&1w9MOZL4&nMgs)K0wB;g znr`E8;uj0xK$Rag9xyN#fPtmg4|R7RpkFKi{i-aeAwYhy0P-t4=r#uM7Yl$tj|n*l znjgS^B@f+Q2+# zxjeW!ECTqi!h#$E{09d$0G1rI`vC|Hw**jNl?gcrC=eb&J+A}OxH`9}rQyXDNDsmL z9RPxbfztli!u7dHEiFG{LJk5N#1~8lE%fq6ToP^_1|~ki&oz z;i8@fB^OPI{D|unWLO@lDJMJMvWH)L3ydgYQ)H zhSS|;;KdDtriIjy9+1zF0B#hA^9XimwvTgT#U%~_16t`vj>oaF;x>m`?X(-n$+6;o z2cjc6A_vEc8y^J2Wbs2p2pIX3gI}DY5j7CMECN1~<`-U#K00p^a~D`iz)|)BuFmfuxe0fY%mqm( z?M`uwthgjXoDym?a<@277H^2)T$<|cT3LL*d!V^}d1q(3KiStlknUfR8ovg&t0%QI zJ&?@wEohp1DUc@74^|raICT~;iO}#P$K%vlyd$C(yYZE$ z+d9si#VaF-UgTJuHj6h%)NZBMFn&N;yiI~;MGnOgvv{*a*(sCiUy6&OON*_*MYJv3ScwEe7-eyxRp_T;5xDDD1FDsSBJIi`Pxm=s}Lb zL9ux2M6iWkCpaV)Z=R^lf*gWl;`r?o*Ug$WD>v(g{E}b&!~MnbnZuncM&0{AaDd&h zePUu|*ZCXn&0ZRB+;v6%(K)gBu@Aj>L(KU^{JZ;pm59%L{mo{hJ*ny4tpImaUd-kbY^V_!G?RGqJUw+5l@4B5YUY1{Q-tqB;EhleC zyknJng0nk!*ZGUxDHBg0cBVb$9zXllhqsl(Z%chlCx?uAdT8v4v%LTt~3H=N~d(Rk$Cn{#brWA1m3oyc5k#9i>W&DnFN&vWnn zv|%?lj<_FvB4kZ@Da0l~Z~3$)w&R)5;9KzHx!;LgFT8m3VD^|3*wRy1*~d+7h_87e zG`JLhe%gvH&MmvF_=K12;jYcw-R27}%H7>M;?CZb%D-^cJh!2LeYWF$BW~B|#-W2x zhuD9S`z!Bgg6>|U<6i~8-N!Vsf6kpU^vky=+JvTbLA%5ewlhzNd3$YdFzLX1% zS@9RAwdM9d{yVZ<^y?^6G20dE$tB%l6RVA=Jb5d!T>(&qI;D z_?hRqetHvoY~@aC*&ipd$4=RnyLv)HeCq18gImzY1s^_Tc=4QQJUjg#u@&uG+-={U z#qOLo<__(i+q%5-nfM21PHvsEYApWlFQ2gS`Bh(Un>cjphN(BqN?dy9_gl}IIy2GR z|DpVabI06c);z%$9~yPo%?D_>EcXsyTp{w^4$8la0`@+L(2VK-P zKCbvV_xQ}g^k?4AezdhccjK1B?z`~cUxWYtKI2a;zHih$>5h!EFgNC=nnKR??;3M= zJv*iK^DT|;eeN3jnuQ^Eeo_Azs@eW6WjgZi`IePh1gE? zaoVCe?Bet98fv|5GD|Kwj@>$`A-;NZXs8?WbWiTh?EU|R-mRPRGqZCN&g&=TyVstP zXu4}^{MG$mZrgOx-5aj{`*wHH)AMti_Kdo#vTYk)e&&qC=!Yif7oRyJ@#KwXCQg3- zYpwH>ixOv_@mlM1EuT)Dv*GuNt~VDZE==r*zvr8g#PF6}V!~^eCU*X>AyNJF{{j9S BMs5HA literal 0 HcmV?d00001 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..d9b2031708df0f0770c29ef3a44f5c6d0c1279cf GIT binary patch literal 4600 zcmb7HdsvNG`wmHrlH^>C4h}WDbdV;QrnPP&heU^Aexfi5A)+WTB;^>LL+{Otc`&JdWm2(;I^?QmPY)!p6uw;kTxTO)y)v58B7K->6R znH*1GY-BQnC!2gLnI{U2%}i(VY}0RMb3}oO(TpE?_Ka_3^F)D(*-RszZT8Jhe zm;1dB3B_5cO_k85h7>wWrINi5IO+s`&GBXv&2Y7c?!5lH zN3_i>B+-Um%+82_-SAv=gsyxGI4&KJ$n-Qk6G@^-)J*>01*kpToIAUhsQp}Jj2 z4R+O`v#SR|*7H$QFQLD!Q$=hXPL=%Fn zA~NLuVzfH!N1;VMDMr_byR#kHMU6&)GU3M^L5}b=$d>#y?lUFEw+C^X+by|MRgU z>Ewf@$O-!oIrE)>;Zx?eVZX-$?%A3ct4W{~;M5=gIA0z1etp9`Qll2qZ{OS@8Gh6Y0fXjh?n2X@b3+a znf-tB{qHzIvl2m{`xjaqDLhrOdR-GAaZvs)h118sESo!oeo zy)wn&&8Nun>I+DNkSbf2vAE069L1}uTpZaIM6<40la=u=iBi4~F&(Z-)g)KRhTCr# zZkzBi32gt;ALqc0f0IWU5k%dhn@lnM%-)-+tYUgj`su*rSrZ9~-a*V%x8skGm*A8v zC9|K|kj%pkSXFs|jQ3u`;@dWI36xTW@C-@9;Q&X*KbxUXOI%(N;aMMKbuDB#!+p=l z8Lf6Wr1wIj>M#!Y$@-G-&&JzLxbEYNL25TSjS@6G+RJGcl5ArO*zVC{`Q&k$2h9+? z?H^oih?Dgd_s*wn6b;Ym{<;@s=Llaf9=~x@JHx3o-D0{%D~@Dy1^Zcu7@5iWl_E~v z6cf08=b>FphR--iEPD7}l<~*zUW6?+iMTe?gS1r@;d}84l4~G^-iNbL{cscK=Y^nk zS{4qL%Jg`wy+(y2>lokhuX?h3i7p12m!Z03In{L>Ob2Gj;kwyyv`=YedU!w2%RiyO z@Dy>LUWjoO9}rcbNHu2vgtbB`y3MYzICxxTNd-A|(V5NjTYKnquSN_#OL2QiF&r-` z_uDH!(}nVR33RDWJuTn$J&RW@HUftfp5s|=4$D)e(E$<^S&4>wMYta?!NB`n$g4A_ z3sl4m-(DI?fB19}!zLbJbx2zA7C93)V|4r)a=QH*N_6HBsg@23J|;3f4{OI_lV3N2 zxZJd^6482JOJ?U@x`gZ*90bc~uJ5>duz#+u>I~&@Ss~S%*ocN*a+$hUGM+$E!)7as%z8Sdu zGFxYT@%8+4I(!hi$Ie4(@oXCBupadWigaAy7^+_|g%%&U&hUr(zSfzKo#QD!%O9 zu51g2vV9n{V>|5@^u=+kUzY9U^2TWkH|oR&J!|&N$6X$Lh?=eKVfcUZAnS?$*|!|d z)h?c#zHpq=Y)wuS@inq0u__Wgdw2$W=Jw)m<9KExbks%MZx!Q-r37)7BhVQq`D)vm z3_ExizfSm#`Sz;ahxpfzP^y_kC+rr%>U9ONG48{vI9KX;q#7?1<@@V0ead_0FFm3Q zA*%Jv-i`W9*xf0?`hB}eoEhJ9|F73+XC1s!G-%BQBQ(}1W7ZHUCan*}s}0joon685 ze){r{)T}^+&%vvRUL-+<*A-@0-*F906fQB1*;m&vySCfWpRYvED2BVYEsdJ0zsB+S zWJ;szAmD2D=*vB9zqW@A98-#pr*W8eBp(xeTcDF#go`_lG8^-@-$AoQ3J&^o)2*4K znf(Mc5wc3wLex=&bs_tSe$VRucz%fgooL3tVe`9B9Z;56jZ+tzSsxmGRG~pZ7eF?O ziAVJbs+a1Cy>^t?c|XV7RC73KdNbbG>m^7m?Zp!B78F^f;CkRE_C5_eSS(WuTJXrU z5|-v(P@E)1SEC#975Q&2Avt^e3z95)VL7W0w^F;&7{i_G-20*2T@m3KjqSt$7s|Qw zmFumY8@XEZb>P-dm0V}NuivAFjl~XZy&$Zdt^ZZpM0QkcV0Bpds~@u`>p3BjRP~~~ z^cofnYan5cN_26p2>GTB`5gMukeFYJTP5+1SE}oR#%cFJL5j)VA^(XIt z-^=y%^r8ajZxqA-cssU5E+H+mma+A%nx1G;wnGfgTXd*&&P0j+FZs~wvK8B0BUI<%wofJLqBa?I2vwn~t7Kn7eo)fBp^8nVJlQw4&?;Ln65>xT{x6oNHer@iEZE$cxwN_bh?G`H4u_DX{|`{lc&cD z!sxj9rSK3cA*-^Q1}|+z*Md1z>*@)JX&(`7@W=04xcil>pWVDn`t-nj@-QR-w!u>T zY;gyk=?BTuFE%XBKTmCAqw^FgVx`@fdr|@)|HJ5Ud4roZOT$qbS4IAjMbyk;55A=A54F$l|VRA1sT@d zeard3YW1#P$AGd<1cWT1o2FlfReBg6XIwzc12HtFX47T6xbIXB@BgY3j9 z^T~s{Z0KqABEY>B%^o!{Iw?<`|Ga{y?QdyI@c{VEh`{CGXp|V1P@|DA5cQoQ#IF2r z7a{Qp;_fnTREz#Vi_|}p^MSH{=%FLLgyE|CifKX2YIXxUo_Iu=vXysxcIlRa{?ctD9HC#G`IUzwM%W}>UGv6=O(PQG0GI(HAtOR5Z4A9L zqC&F-k(d#oR5%C14PqNnN}V$d&3Z&5MwC);4~7{gY8o+0)l-y+-ol7sQPWGGhA@M~ zVn&Q!`H1fJ6tj&uz4%!SGfX^U#HqC(Sy^#?ST|1GG)!vwb3n{MX2>w99UwT_8^?Wo zm@&howm}Yv8^|^+S|4C1)7{Oih+)y%pc#Z2#4-(=)&VJo=5A((44c*fIT&Ub8_#!I z^#w(9H?wWSp_g87H*+F}L$AE{YH>^>9HLi0%VCxU3~;FB&)H|V>mL}Ps6CK_Vuyn7 zk%TW+94^k;YuKTL{E_TWLUBXE0>dN?j2EJ5KmbE&h#U+z3~Vq=(!mTxb8`d42$P)9 z1L6jP6NX7rnBr(|a6lMgk`?A)xM5(0VUiYRD7w{w8iq-3=mD_LtS1{mN_JTnzY^HB_h5g|!t4u)AB5I~{# zfc6#-h!G*`W$Fzr-3Hu1;;@nEmnk>bK0yHnC~70*JoVY3&%i$O@LQ+i+h#Lr=86kOQile*BK{6vG z$;`oU!@w^iB)`m1bgKisjF9xw17Zh47m%b;VCY))m@!(Nixk$7~L{8eVUTCiKODBLEP$yPgIgTGZjelQ4E9;BS~fs zh8qTc86){+hN4>?=w*zkm!dk^X9+hD{4z%L%K~3GfB}lOFL+SwQ0N7sWfu6t0Sr*o zO`)JiFo@a5CYy+)|U={~_=g>kPHl4;y%5XHDx z6e%#>%E2GwUVmip*zwRPUZo7i;YGEeH3!R#do7bp<3@vWnqK9kz%(ld7I@?h<)>Sj z<}u>|1)N7n)$ugf97u-gZ>K&_R{RWNs!!51ZZs&W>945yVY-!rtD63lQ1KS6_nZR>G*X}V(Ad$?CjRQ0Kbmgk;IO8@!;*O1co13BUu1J} zO(!e``z27o5q1=R95)MEQB|wHnky$THvvOd8=X>PpqYce&WtzvW9CX&A=qy&7&Uf>1ZCszw8R_eD zzkCR*MG@BQ_y5?1}U}FQYyoAO9wBt)Lt6MFcYAX$(aic*|ZMCAxFwHdw7I>si$-R}Dm4m?AYJm;oqn|i9F8CZ0 zDBuW+;*aCj4mN93Y*wP_77s#eQ-n4M$PEaewJAQ^#4FS zDO`bm2@+@oOYx;~qd{hEip+Xpx|PErt4;CQ03J6Ubk?TmY%WfB&B10Jip~1cn9%?N zjNK;MehC(M1WWOTOLq!89)8;^MQ1h7A#rV7pP(gr=a9HGep#*Or(NX_ScHosx}WF9 z)j=TCKTF_N09xv(wKVcGokId;IBG2o!g0gFN*%S8`p`5tKEMn|EvCUhZa{dcqxRA? zuTW0}2{h&A5I|!`L#L>vREFu64qoc0y)=-=jR!4t)LQDrX+E<+0!_J%1kkwAAf=94 zN@bXC<#5RAsJ%3h$BhRqb<|p#!4>FfKQRFwbV|3Rb+ke+>!)hP36*Sd@Qtts5DVYzPNS; zgm+4Vlmf5M|MEqFeI;}J=#pl?_u%LFG>8BRn%_cK)7a63UP@Dxn(ka)n0BRvUrM8t z$UJsDA(zrTC4XF2)pbR$X_AsON{I$z2NO~$O;J*^w96)>QW~N}#j(Q)p_FDQDQMao zpO8sud=d@D4<~KONrAbNww2YGLt~D`C2{1o_ zl<%Ym#tkM0DuLyvk@;AS)@_ zO}MEvfQigw#}le5O=9xLWgh|ByQDZBm8LP#VC-N*TcwFiN|tungt$tRnW#8+IH9f5 zgeC<|*GRw-Fx6I)I%@)?>5N2om&Et~!_uTCWiaik32l|8HBoWwstIeACN>4oJd(gZ zLjy;^RL`5-6Tq z&3g*MWj1KtR{=5g2{hV{l6}$F^b>FdjMyeWMboaFuvlpbl=@ccR!&%~Gz&_F(_VAJ zVx^H#L>hZL5f&>=hEl_{D<@1=8W2V1@#Bd`@jZa`Y$-mUO5_`uFj;9-lwUX>8atXW zS!r678m3!0F)<1}!iJI?PIJwPu~EP&w-o#*QYW)PL?(UYK^}gqQjczM}B>@kFDjj#<-d+nVvNQBxnRl@uBgfi8fIws#tTmFzw0-XZ4?Y zmG@Tat~nvF{&TN#aMDqoUg^ZtE8vuh;*aCjP7J*QP1z_CO}BVr?p4@-@Kyf1saZXs z1W)O+sZ$Eynj)2iRr5Q}Sbi!x-Ctu|ar&~Fpv;MQM^5Arz zbpQz(vrp1A_No&y>p%A@Z#3=737_>Jd`020;|ZPhpM2$m^B5o7F98BgZL_Hr?xC@x z37Peud6hSscIAZ6`cJ*0@YwN$&ic>2%7fEhbHZl*r(UUP{Ai+0R7b4o-|(D$>Cr#? z2TJgW8%m}vvpjY@p|k#TuX64sKw9UJxU{&w@?n#{KKsiT!sOafM zpTLeFw6tk{)y?gj{#n`ki^eJ6exF!=c(pS*)+6yzdAG(7I*duQKk*~$wfX;ai>99+ zzIpc{ch&>XJ8#|fq1$fRN~h1vgYN7_HT4HNKj|KPe@NY@bC$RRte35=t)6j<`t(^> zv2T^Te1ALVjZ5acFMbqiJg&IG?RN44Yu}bFZsgsY%%UN`b9eW)o3!c(A^I%LVRfi~E~zjaly+AIzx#^`}p{&JL&Ui$P=EEBhv`yU%zi`}h8@^}dFi z>K}0b{C%YH^i7YsMMc{~9}R!d9rxnkhM6x`yKP>6-~MLy5_jmy9h*7`D{=IqsINpBXxc07H|7q=+6=%6cTQ6Ai&WC5a zb)mDYK4W$`pLLy4|Hb+3++Ck^sI5HwYv-$Hm)1UC{=M_s?smynOEx42e|CJsg&&tC ziYg|rIy-!*c=IXeHQe~x>cqP@tZMAHY)+!C&EAH&%R40o4V=2^)NW5FZ;ZBS+}3G+ zV%nqA+!?JqCI?z`t)u#PL*8LMe|PgFcfxgN)vu^)>sDR&u({;kzdA*IzKQHy_@Pr& zw6)iuqn}7VJN8I(L#s#J>Yj5g*tLw(Ni??6faN@R8+{ZiJ(_mDeDDs7doN##W z+p9M|wQuA{?yYT4tp9e~GWWKrHTDnf7r8Aij@6%Aaa7{oE9RST-t~~Xyii%$zJEGIh+D{JkbC=Z~+c5N%1J1f<$2WXF?Et?16GH2zeVbh0 zA{rYs=PkG2H-EBPy}itBH?5n!z3vV7v}?QDy*51Nw7qJMRXwN1*?#aHd+8l56TkS^ zQRba@x5ocp)T3o+l{2GwTlDEr8|Rt7Z8B${yvbg$a8CTuH5=`sqEV}-4}3Tg+Ouc< znED%>-4_fnI}SP588`FC+CA6yb&h(Y+?p`#9;a+To7!zlk9I=c``2G}(L3&*-`*Cz z>(Wo~|8HG!T>k;izJXn%kG%bX-Q|~S>h@2YXRrTpVO^Ko@5Y~>Q2h4+^Aa@|*F=VP zu5hZ4{f*V7o#7m2oo=n@aSpz|N7fAMH7!|s{KMw91s#%8XWr^=J92#Dyz$H2UF#l8 zoUroMWcj%L$s1=4PLAKTsQ1cQSJk(;<9>JitSwK?x?vHn-+o2;{vGzFjUzX_-C?nP zQmgNtTJh3A`;jFJYfH|)8h`%os?(QWl9>0>J!=o0)zz8O<8pI%>wR{=OM1qqeD$8) zweKK1cJh38arG7PMeUa0?;lw1EO^|m8*_N<=%XIByPjNR9&zTc?Hdm)iyyq|a{PIy z_=&X_B|2U6%KFVa|I3-W>@)M+p?mEK8@@Ha9JJS7cKgKo%P$z6*s}1yYdccaQx%Lx;S^d`8TuW^KJ3x(^n@Bo}K)|*RQWXvhk2}&c!>d zZN0WT7xl2Lw`yK=p4e4mU3mVfZjbgOtS2`|+_%p@yz!punTd51+tt+7$48U1@iq;pAq?}5<)wf+9RIMH+Cla1FE-;(HA z@?*o+=Waf4djeS$Eu@XtD3r#*xYKiP^KZ dG``nmSYq|*S2p}`OzULVDgEkS{@~=~e*h>$TSx!^ literal 0 HcmV?d00001 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..77472bb50b2896c48a0a739eee9848e821769021 GIT binary patch literal 68 zcmcCufPiYHoc!WqC97(sl#3W)1$W_B6bo!RWn zh7S#-*fKop6==As*T6&;D3H**V4*?1IN}^)>h|^9^AigdG*@0rRAe@ z`xmc$;r_MN<+2)!BkpyV0_AxByWL={4ppeV#C!{g`hx z<}D2`VZ>tQJB_xb>tjmGw=mjR)b`c~h*(Iw-Dul8-|}ux?OtQS-ux*>EM}qASg`i~ z+O^f|o8EF3I*pFC{RNO%P-l0eW1RqUrWa1VeRSH5j&%$QAhn=gqidT3>}1~EtlMgI zZDX(vA{Nr^G`@60Cq&Ro8Ua!%&x8Ap#^;?a;z4PwX z((g2yyY1bda>TZP0gko(g?&a{|G)rcoq+;ME(*S962A6abFQ%0$VE~9O!gO{)S_U4 zO_K&L)X_8`fUyij0izZJ8*G|%a6)lzZh$nJCMPU`)PmrIO_LOkInE6Zh(^<7g$0aS z46Lwe(!vSF+a0K3)8vLFkX#U^z$Ay{6t*6DLx3PQO@gD(eEsTL$5Dy!TemQn? z_X!FxKv@T&aMzcM!Ymjqv*rT_FhE%cqVQ~^76sRAnOw6|=YazVU@Y500izZJ(`=bc zb3*a%ayW);nN+g`QVW7-woINm=6H8GNM_3 zil~+dk8GJdvIKNKaDdgAHyLG`Czb~kaMl~hwhp?*0yGdUVi6VSlP16#^Cs8q)Op|l z0vOv?vhBG>?H5dQ-ej7SVw{^jAOP70qs{{d7~t5(VXV%s=Xf9*^Crn0UpL>|8z6vU zp8@w44@hI)sF&lb=3G1Q%Xy<;)_mXq1}N(w6z=-!8iHlc8!fZu0|zibSqGxd0|zMh zT!rO5uWp?O4j_QBYzqa9+98B=_A2 z7@(}1LY-X>6p$=82^O_YSKXi61g+6F*<%$^>>FSJS|?;%iMY>u-r)eN(KZ=nsldQf z%Y#?8O=!>tX)?`8G2S+fK22%c zM8@R=klG!7qSEA<(=Gw6h=N$=t*`DQ>(^ppe0H18n z@=3n1cY)wwvOUWox4pLdW5XHSvz+nB+S;MHRs5cb^LoP+7iO7a>Gj1QA1{rUpRuuT zy}a|<@mQSSY)lfuoiKw_7N?-E54vL-4zWvfr6V zYDw_Bh4NA27=#a^?K`}-bnMXbI-c4_PuNSfml8B@p{#j!oLn5nqHKF6nzwBbzlE~+ zDL}PA_}@a=|0;R-#hv|!)~ z6I@MC1PQdPZB)?Y(l91@DUD&?&cRE&c`vQxspUaSdwDG__B+#<%Q9%<+0|+pA zOs1ddQ@sSNMlbKBadBZ%)bgOEy}XtZmpB=ei-VQ+@>W_wQ%i%Q_VS7v!<=glEbvT| zQoNMT&Ou;%d4X->qn|i9u6+&(6mSMbN#oSs!Df3Fn~iAR=0RwC7NM;Iss+MldlsJ^ z@#^{|NR6JwW|K6vG{|hvBC}aG{|hvBC}Qt zA@S1qvf9#byT%`|XwJ2~pXbJ_gJ9_Yl%RG1TH4QRY3nqdLju+4=e4v7rxph*?dPpD zL33_=fHnGgF|7uw1;R`Fc`psTx}FFUXxYu7f+m-SNy$rT4D+@QUfR!lX(dlB4_eyK zYiSwhyk~&~T6P<$psA%nO8a>!jbYx-;gGeT_tHw9S{}5tpV!g}uBIn~)ad7}w1TFV z1}W|5r8I_lI|ncA=e@L&Czppo$!lpD=UsEK(th4bD`;wIkWzdPd{#wbnzMC;k!api zChK6TrecL;5oo&o6i?7xZ($P!US3XlNfSJ|Jg6#gSyzp5?he9Uo4fP+T1k{kgvVQi|9a*=Rcl+M5i?`nVzzMZH3;(s(L6H| zQ7#ec3!r&qCP4QZ#F1=sw?%%76XRWjkYC{BftjudeGSrpA#KHs#rQC;HGapl_D3km zA}}>F2Ub=mzmS6!SLfH^2p}ajC)E^n81edmBvm&F@vbr;$VpAbERI|b#Wu=W8jV5% z<6Tjsl%*(UDMOJzzE~7l9G;XTuUeOyRDMe9AHke!(r7}uU%Cg*Tqf|yAN>-d)9ECEZPRh|YuzY7DN0bvmKKIYE{+lzB`H}<|PUqqar~O&OZkNWdX1 zx0Q^~ngAJ{k$87Wy#L>flA5xGxw}SfjnbN|IJs-I)+n*5f>x3Q`3wyl!gBYU;t63` z*m7ZX)+n_pTbR3RJpYU$o2)pwI39mSc}*Fbd#zDgqp&6$Of8H-O<-vgGF1e5Aeqtx zoNXJKO#c+87l#>%GMYvg^4O7WM4U!0 zPrxC}I3`=AxjRRTjY6REm(trgT5OaBmBYE$94$7Agfi0P?SvK^B}3(5?#|I%K*U$TFczyZDr?5sjRjD7}-49x7lvFi?pq9X>DgZKfknkN(_1_w0ZWVBudq@>d zFAnWbr&w*<7!ovZyO?GLU`9-TwvMU)V-g_26pe28gDo>M3qm-s|uV%vBouikggRd++y*v&|?wB>K z@$$)_zUFA9>Da48aWy>=9Kdo?8kgqo9KAH1dNs4Sz>_;iEluZM&A_?W91mI3saJWL zTpFb`oq07A=I$K5G@W{7;mPGuOVhbmC0z0F%x8@Nk!FyT1WhlEW0F%;uDK%2-8njI zI`wMirSz^j3T!&}s(_1*>U^zZ>J@NVMM>k--ZAtFwCtcnG;i~md(}(_U(LRn&h7z4 zUe;&Frxe&JYyJKiK%gbRBp|M?u3v&fm~B$F!Q8E*&!&^FW)|n|9CbFGeKiB;z1IOm znm8wcCa*fmY&!R9W@+xu(Pz`aR~DXJ9(6XIe3ihJj8FDUfI!RJ?0AKjXmV+k*>vXB z%+lPQqtB*OuPi*dJnC#Z_i6^tz2<1M>C~${O)rgOk~?Az-|#$r=<%QZ14W*3Lz#@_ zlqZ)*olWOn6|N;fu5(DdG+#e^vk9*+|M(QvbgotD$9JzJ9;BwTtkMLz1Zru>uUbC5 zxNn*=KX>%t{%JDy{>5uwxPNVR88+;$ghCpEs}zu05Cf~g8MWNN@$Ldlssd)DpmvO2 z4B8(us!SAbcbH8D&}iK&Ah{r_W=N$naM~ES{j2UCQ8hzEl?sy!qh^M@DuLzh7d114 zSLrypIBI4{vNF@Wdp2XNnjyzZg~^3cF++-#z;btuiWxGjbevusXJo9HMjG>lr#Exi zX_;XKpmh`B;w}K?0;!oH$4V%0`ZdaWffUUUWaRvozlNs^4v58S*b9& zFluJVu@YGBbw|w%F;+TGE{@U}Qmo80_ijlA4Ix%4OfQTRGCndDSnjS-O+$p0j+2X{ zpN0f0BdxG`$(so?tN^r%YCC+V)pUDS3uKHHcvfSb@cg?abF7*n$V#}PQ=nWR9X2Fc z8F&k|+!7qX#%HD~OfHNb8*;1!mUrDT#R@p!Fch{)0A-_t;$z7w#hV&{Ut>$jY^U@GtBUi$5T zG8;myJV3cXx@<_XGV+!XQ@y@;_{gC;=RiP`tL!c*u3j$ZXMB*Vm+L)&0FYE=zu8u7 zpvtAvc08(h8fG{0spWfM&_%P!@o4fz;cOlO+^bW4EN*hNvtLP%e;4 z8`83jylwvOf3qA&a+TdB#nsE@{Fv>Q^Ro>Lvv7S#13*%(3X`VFrNVac<*ZhdiJ;CM zh)8m+LQ};R%O#Om`#^1bZc(=(asV?caHqyY5;0`ihi6=vyL}MgR>;#bZ3qRN+%?!Q zX!c|Sp1`}5S#vz-M;mWRxTwzyNXVeXF6L2)ge zJh?ooXNcn}PW!C#=ACPcX=CJb2$oix05;=Lc07|xma3hi0d*^w|!}d@6il6FjFV8 zl}#-(BhaNdz;&K&>?B*@cEVl1rsVhiEWM zsl_k7b9R(29n!&^Ou0-7bqEO4k=`}Rps-d53UgECQmNM=Ff0!poi5e>9%=FbO?|c@ zwe6;%a-nqW5FO?y)ZJ8ScE}L(WA$P=N~uHN$_rd}Pklq8V~0R7W2I%LTqYGego^1% zcS~v6AzaK&l}n{$hmbK9>h)dVz_dE{$Qc6f8(s)-ZBi$XPWrxr){o)3y+)+w) z2p_XTOYW(!Q95=AAagSHGC4$6#U6Kbt4w!CY1tu+OfS^iQA&0QB~zi^HOgQzFpZ8~ zyu8k5QihWOX;kbo(%VsHl(j-enSODd9c5G*pvLV8W9#}<1^}Q`A5IrlK2X$kh}S2? zI4INi^W-gtab(~ar(?>|%i%awA46}yW{JcOXQL>~A&o3IgH=$uP`Yx6By%v`%Z<7m zvdR2dxmenA2q`mC%MPgzEKj6EP??h{mq}F)VP!hf-BQ|e2rY9{ z2shJ_?vC<6Iz*nisdA~5w-t9ag0U} zUG>jt!2{GP*=eZUQ#y8tR?GJ^V!E43%?=rBeym(9Ejxs*8L72>7dSGNNh&k-GC4=K zFp>p+(4B!qsiz_EnDp{Bem?1@;NPVWSTaXaDmHAxlF2Z2y@fRbhnfT(jnB% zO_fWfEQfG673#ew0tcq;6Vm05I`4^~5$PJD>D@^mUD|R8d($tjca2h(L+G0oI=K;h zy_LalV45DeCLJv?<-w&YhwwMOO!pe4Er&=rH&rf`vK)fqRH%E6@<2L7$2pmDnN;Ou zm|XnuQy!ukYm~N}OqLU`(x>1$g=de$vlFO#shlI*k#uVsnlyRC-nmTa*vVwO ze4$R?1r126VwaDG-5sT6Cll^0g$AnJQA&0)=B^rAa!-AY(y^0icV?zsCKWpwdS^kp zJ4(w=X5R@^xl~GaG6F9T_4+PwWE$;gykI6%|`LGrA&86Y1zqKJb@~Al#-o{ z#~VQptgJ2`+Q)C`;df!y`pr(};(=6U*3~4nBnILER5dV3khf>d!fPe7@TxDLb2kNr zs496*l;vG<%)o0UGw`a4a;`Hl2ogrZ|1`<;JFtYQn6l&nq!A~h@5;SCVFQP`4fqH=)oXXED&*Hh;NO4W3*k!M& z_Z~)RO-9&di0+M&!kWykvjOEUQdX0pby=RfibC9u_3)oGUu zwZk{bRDHwTES|fI6xw90UG|zfcT~Vgw9L(EI@$&d#LACVY`?5&qCBXS-DI>~bWtOq zyOH$YWWt@9DVIqdP6pkjk!9hKkIEQz2cUU$%0Rt94vtBU$EP@0_u<6B-A2lBGV(5a zO}*=pYMczd%MjhWD#bV%eP;v8U8EN$!|$>@cNcl4n~cA+BIP1!#>oJ@3gXt?*!M;R zSO8Xjyg1rar{9njlu9m;UYrcT%NFSFBGovVeP>1LMRIN`KUG|PQQd2iVw_C8%arIf z9C$Il(#>Kx;~jM}rqiP>^`=1@krp+sK#E-?jKr$j)UhMmj&Me~m$eUZNCUF+iZ zxk$J%h*fSpo9JATeSXSBy+lq=r6A8N(cMZo zGMM)H*;Zs=%4I^4@!gJ9o6n3go!cv6Bsyltx4)7_pks7uY_K>$o=;}tjqV{cFRQnc zuw)>G;hEe=oCi9Dfmrz=h(wVkqCBjyWB`TH*#f%PBup7hVR|B_UM7d9@(_8(b#<>v zm@XiR@inXSGC5Pge`;Vn4T$9?j&p(NaOGnKPc&X1R)H>N?R_?PJ$?x z2wMiwI66h3dri`qlNoxc-8IQn%Vdlmkfz~@&aZr}WP%=$rr}8=^&&Am$?&}Rx<*8A zGnt(SsA+(zpmL#f=45nUMWNe7%h^mCb22o~NR&&YFel^k@<4Yhai~0>OwBVh*X)2x06)nWVl`q>0XmKTb@s*>lumiuu_!GN z5)Io;1yL@M;)^dR*RSgI{6yYUfg{r~JyWJ!Ce^n(z%MbdTaxCx0|EwOJ-D{HWTOHdsv&g`d%cS~N$M(%Yy4NJlw>q*fPt;4~=va=5C!4^elhgH~ zuSvRZbzI-fBE6lY`c}vFm5{{)Kj)gH`Bq2u<%x19sloWrv29HjOLVVE9u`-}^vx{N z+eu1tbxdCY8J+eyTS;ZcrvWTiHjPw!2^fi{yH-T>HWNHTO#@T~)eFS{;gr#QY0EY# zZ$5yLXu5Tcsw@-b5-H5p(R}g1ZUeeoNoTH(=Cd&6GO5hifAmgz@0EaoXfx(wCj(I~ zk-}UZ$v5Mw`tOxYa%ub1@m&nnv3G;Y#wt((6 zNoTH(;hT9`o%c%6fV3PJpQdk9C5%M-jca-$?p>1<=IZ#p8CSK&n&e?|b^M-%DGw}_ zxjKHYfW)Y}&QbC5U|2p<(VPnmxeQ8ib+lfo4EIW-30H^dr73zT9D?zY@ao#p!;6Pj zUb?(*W%bChjrFlZgx)>ike99}%QJV#yGr#lOsH-zSiEV(SY zZivdWqd?KCVnQBJqrE~ysYNju52Vrk#X?%&SlmFb7GLZs3S$w{@HCWM6h$^f;0dG^ z$NVt|Z$9MUX(+iUifahN+ahf&9zD3cwX7De<%)9lhngBv@GJnm0QP?@rP<+`lp(IpKydYeXRaZg&W4bSpfHZY z=x(_9;?LfEPXDd9{^?-H^glO#|MY*3{zv+spLgPgC(ro1>xNIf^xIDT#$`u`TYo40 z@9%uW`KMYBUN^k(8UNjXa{Ea3-~YuQ`qfjPY&~W8zR!J6|B=6W%J47W^vcbHAKGzg zj#K{gbIyEe@4fduW%#B`|GIbfA3Y_@{pZJ?-+k8eA09mMGjBifeJ^=3K zyiw-1e(jkjUUDa1=f@w@e$3nM9Nc-?f8YD^SNzKGvb!$qUGi7A4R3t#-@Bh~{@38l z{_FZ{PkUhSiBtc(dCfDf8eZ1>?&HhNtA~%A^`m>g@jw22@Ua`3{pVfsz~GjJ=WqVb z6^|Q!;R9cH&wsve=Wyp&cJKYl7kiif>&g=bYk##T?nnE&dw1Ob;lVRHJNj?y-y7Ba zHTPWczVG9Y5py6M`hhVTFNjXmccKX>^4x6JjQ@XnpXFP!|DPm{`0F2p0nqZAA0-XPv(FB#NXU==iv5d ze(UB{w_h~8@-uJj{qQgB9NuyMQ}?`L<>tYw-t)}kzx(9)?;kzqFZ)mao+}2MFFyD9 z`Z;@}pMU3lgHs>8|NP;tcYf^Tm;dd&;m4kJ*{>E9C4ZeEACCAVI zZ2bDYrO%#v;y-i-H~qwabdG-}{`3EO)<;i0bob+jC;NAGZ(83u{Kz|g?9`9_=3sF4 zH(%ZP$RqK;f9oR`?fJkJgTbwT{L$_WpO621>5dEczV)4tAO65~_n-X4%XbdXyZP?^ zxl3O&_@$TL(fO9|KQp>M|8~J|ZGP!BHx2gx(vNlryW;*l_k$N~?)vlB5ANUHJ9*Z_ zw+%kN{OrAR-*w*boU0$(|HbF;9DdV%PrCO@KlQS~#$6YkxH!?@H(Yhi=1+e6ErZYf z-9MgafBB|}fAr9a&2wLL>)@wecXe;=?{6Dic-`%LZ~5HW!&iUg*}Z)`&l{e#c;%j_ zzx&d`cf8}C?pX`*d|dacAKLqmuRSq1|6i|p zVDREwcHev3+bI3%={^F{;dpCT`1A|BJPm?>#_`dq+{pr#B)1&vNNAFKNc8vFf zNAFL?{qg_%{sj8ud1>_DZ~x%`T=DeVo_u-pmrq@O%gXaF-*MwtuNeP?zkK;w=bsn- G^Zx-TuGq@} literal 0 HcmV?d00001 diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py new file mode 100644 index 00000000..70337c17 --- /dev/null +++ b/tests/scripts/save_policy_to_safetensor.py @@ -0,0 +1,101 @@ +import shutil +from pathlib import Path + +import torch +from safetensors.torch import save_file + +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.utils.utils import init_hydra_config, set_global_seed +from lerobot.scripts.train import make_optimizer_and_scheduler +from tests.utils import DEFAULT_CONFIG_PATH + + +def get_policy_stats(env_name, policy_name, extra_overrides=None): + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[ + f"env={env_name}", + f"policy={policy_name}", + "device=cpu", + ] + + extra_overrides, + ) + set_global_seed(1337) + dataset = make_dataset(cfg) + policy = make_policy(cfg, dataset_stats=dataset.stats) + policy.train() + optimizer, _ = make_optimizer_and_scheduler(cfg, policy) + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=0, + batch_size=cfg.training.batch_size, + shuffle=False, + ) + + batch = next(iter(dataloader)) + output_dict = policy.forward(batch) + output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} + loss = output_dict["loss"] + + loss.backward() + grad_stats = {} + for key, param in policy.named_parameters(): + if param.requires_grad: + grad_stats[f"{key}_mean"] = param.grad.mean() + grad_stats[f"{key}_std"] = ( + param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0)) + ) + + optimizer.step() + param_stats = {} + for key, param in policy.named_parameters(): + param_stats[f"{key}_mean"] = param.mean() + param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0)) + + optimizer.zero_grad() + policy.reset() + + # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension + dataset.delta_timestamps = None + batch = next(iter(dataloader)) + obs = { + k: batch[k] + for k in batch + if k in ["observation.image", "observation.images.top", "observation.state"] + } + + actions_queue = ( + cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats + ) + actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} + return output_dict, grad_stats, param_stats, actions + + +def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides): + env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}" + + if env_policy_dir.exists(): + shutil.rmtree(env_policy_dir) + + env_policy_dir.mkdir(parents=True, exist_ok=True) + output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides) + save_file(output_dict, env_policy_dir / "output_dict.safetensors") + save_file(grad_stats, env_policy_dir / "grad_stats.safetensors") + save_file(param_stats, env_policy_dir / "param_stats.safetensors") + save_file(actions, env_policy_dir / "actions.safetensors") + + +if __name__ == "__main__": + env_policies = [ + # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), + ( + "pusht", + "diffusion", + ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], + ), + ("aloha", "act", ["policy.n_action_steps=10"]), + ] + for env, policy, extra_overrides in env_policies: + save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e50d4108..22b271be 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -265,7 +265,7 @@ def test_backward_compatibility(repo_id): for key in new_frame: assert torch.isclose( - new_frame[key], old_frame[key], rtol=1e-05, atol=1e-08 + new_frame[key], old_frame[key] ).all(), f"{key=} for index={i} does not contain the same value" # test2 first frames of first episode diff --git a/tests/test_policies.py b/tests/test_policies.py index 51cdb93e..7d2f19ba 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,8 +1,10 @@ import inspect +from pathlib import Path import pytest import torch from huggingface_hub import PyTorchModelHubMixin +from safetensors.torch import load_file from lerobot import available_policies from lerobot.common.datasets.factory import make_dataset @@ -13,7 +15,8 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config -from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env +from tests.scripts.save_policy_to_safetensor import get_policy_stats +from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel @pytest.mark.parametrize("policy_name", available_policies) @@ -228,3 +231,37 @@ def test_normalize(insert_temporal_dim): new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None) new_unnormalize.load_state_dict(unnormalize.state_dict()) unnormalize(output_batch) + + +@pytest.mark.parametrize( + "env_name, policy_name, extra_overrides", + [ + # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), + ( + "pusht", + "diffusion", + ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], + ), + ("aloha", "act", ["policy.n_action_steps=10"]), + ], +) +# As artifacts have been generated on an x86_64 kernel, this test won't +# pass if it's run on another platform due to floating point errors +@require_x86_64_kernel +def test_backward_compatibility(env_name, policy_name, extra_overrides): + env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}" + saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors") + saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors") + saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors") + saved_actions = load_file(env_policy_dir / "actions.safetensors") + + output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides) + + for key in saved_output_dict: + assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all() + for key in saved_grad_stats: + assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all() + for key in saved_param_stats: + assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all() + for key in saved_actions: + assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all() diff --git a/tests/utils.py b/tests/utils.py index f3fe5790..6a706694 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,5 @@ +import platform + import pytest import torch @@ -9,6 +11,51 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +def require_x86_64_kernel(func): + """ + Decorator that skips the test if plateform device is not an x86_64 cpu. + """ + from functools import wraps + + @wraps(func) + def wrapper(*args, **kwargs): + if platform.machine() != "x86_64": + pytest.skip("requires x86_64 plateform") + return func(*args, **kwargs) + + return wrapper + + +def require_cpu(func): + """ + Decorator that skips the test if device is not cpu. + """ + from functools import wraps + + @wraps(func) + def wrapper(*args, **kwargs): + if DEVICE != "cpu": + pytest.skip("requires cpu") + return func(*args, **kwargs) + + return wrapper + + +def require_cuda(func): + """ + Decorator that skips the test if cuda is not available. + """ + from functools import wraps + + @wraps(func) + def wrapper(*args, **kwargs): + if not torch.cuda.is_available(): + pytest.skip("requires cuda") + return func(*args, **kwargs) + + return wrapper + + def require_env(func): """ Decorator that skips the test if the required environment package is not installed.