From f3bba0270d6885f0363b638acf28b6a6d7c4f0b7 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sun, 5 May 2024 11:26:12 +0100 Subject: [PATCH] Remove EMA model from Diffusion Policy (#134) --- .../diffusion/configuration_diffusion.py | 9 -- .../policies/diffusion/modeling_diffusion.py | 84 +----------------- lerobot/configs/policy/diffusion.yaml | 13 +-- lerobot/scripts/eval.py | 2 +- lerobot/scripts/train.py | 3 - poetry.lock | 4 +- .../pusht_diffusion/actions.safetensors | Bin 4600 -> 4600 bytes .../pusht_diffusion/grad_stats.safetensors | Bin 47424 -> 47424 bytes .../pusht_diffusion/param_stats.safetensors | Bin 98776 -> 49120 bytes tests/scripts/save_policy_to_safetensor.py | 12 +-- tests/test_policies.py | 11 +++ 11 files changed, 21 insertions(+), 117 deletions(-) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index b5188488..73fabefa 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -118,15 +118,6 @@ class DiffusionConfig: # Inference num_inference_steps: int | None = None - # --- - # TODO(alexander-soare): Remove these from the policy config. - use_ema: bool = True - ema_update_after_step: int = 0 - ema_min_alpha: float = 0.0 - ema_max_alpha: float = 0.9999 - ema_inv_gamma: float = 1.0 - ema_power: float = 0.75 - def __post_init__(self): """Input validation (not exhaustive).""" if not self.vision_backbone.startswith("resnet"): diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index c639e2f9..f5f64d80 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -3,12 +3,8 @@ TODO(alexander-soare): - Remove reliance on Robomimic for SpatialSoftmax. - Remove reliance on diffusers for DDPMScheduler and LR scheduler. 
- - Move EMA out of policy. - - Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy. - - One more pass on comments and documentation. """ -import copy import math from collections import deque from typing import Callable @@ -21,7 +17,6 @@ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from huggingface_hub import PyTorchModelHubMixin from robomimic.models.base_nets import SpatialSoftmax from torch import Tensor, nn -from torch.nn.modules.batchnorm import _BatchNorm from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.normalize import Normalize, Unnormalize @@ -71,13 +66,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): self.diffusion = DiffusionModel(config) - # TODO(alexander-soare): This should probably be managed outside of the policy class. - self.ema_diffusion = None - self.ema = None - if self.config.use_ema: - self.ema_diffusion = copy.deepcopy(self.diffusion) - self.ema = DiffusionEMA(config, model=self.ema_diffusion) - def reset(self): """ Clear observation and action queues. Should be called on `env.reset()` @@ -109,9 +97,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that "horizon" may not the best name to describe what the variable actually means, because this period is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. - - Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model - weights. 
""" assert "observation.image" in batch assert "observation.state" in batch @@ -123,10 +108,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): if len(self._queues["action"]) == 0: # stack n latest observations from the queue batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} - if not self.training and self.ema_diffusion is not None: - actions = self.ema_diffusion.generate_actions(batch) - else: - actions = self.diffusion.generate_actions(batch) + actions = self.diffusion.generate_actions(batch) # TODO(rcadene): make above methods return output dictionary? actions = self.unnormalize_outputs({"action": actions})["action"] @@ -612,67 +594,3 @@ class DiffusionConditionalResidualBlock1d(nn.Module): out = self.conv2(out) out = out + self.residual_conv(x) return out - - -class DiffusionEMA: - """ - Exponential Moving Average of models weights - """ - - def __init__(self, config: DiffusionConfig, model: nn.Module): - """ - @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models - you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 - at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 - at 10K steps, 0.9999 at 215.4k steps). - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. - power (float): Exponential factor of EMA warmup. Default: 2/3. - min_alpha (float): The minimum EMA decay rate. Default: 0. 
- """ - - self.averaged_model = model - self.averaged_model.eval() - self.averaged_model.requires_grad_(False) - - self.update_after_step = config.ema_update_after_step - self.inv_gamma = config.ema_inv_gamma - self.power = config.ema_power - self.min_alpha = config.ema_min_alpha - self.max_alpha = config.ema_max_alpha - - self.alpha = 0.0 - self.optimization_step = 0 - - def get_decay(self, optimization_step): - """ - Compute the decay factor for the exponential moving average. - """ - step = max(0, optimization_step - self.update_after_step - 1) - value = 1 - (1 + step / self.inv_gamma) ** -self.power - - if step <= 0: - return 0.0 - - return max(self.min_alpha, min(value, self.max_alpha)) - - @torch.no_grad() - def step(self, new_model): - self.alpha = self.get_decay(self.optimization_step) - - for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True): - # Iterate over immediate parameters only. - for param, ema_param in zip( - module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True - ): - if isinstance(param, dict): - raise RuntimeError("Dict parameter not supported") - if isinstance(module, _BatchNorm) or not param.requires_grad: - # Copy BatchNorm parameters, and non-trainable parameters directly. - ema_param.copy_(param.to(dtype=ema_param.dtype).data) - else: - ema_param.mul_(self.alpha) - ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha) - - self.optimization_step += 1 diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index aa90afdf..60061c38 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -1,5 +1,9 @@ # @package _global_ +# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy. +# Note: We do not track EMA model weights as we discovered it does not improve the results. 
See +# https://github.com/huggingface/lerobot/pull/134 for more details. + seed: 100000 dataset_repo_id: lerobot/pusht @@ -91,12 +95,3 @@ policy: # Inference num_inference_steps: 100 - - # --- - # TODO(alexander-soare): Remove these from the policy config. - use_ema: true - ema_update_after_step: 0 - ema_min_alpha: 0.0 - ema_max_alpha: 0.9999 - ema_inv_gamma: 1.0 - ema_power: 0.75 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index e3afac41..e9aa3041 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -121,7 +121,7 @@ def rollout( max_steps = env.call("_max_episode_steps")[0] progbar = trange( max_steps, - desc=f"Running rollout with {max_steps} steps (maximum) per rollout", + desc=f"Running rollout with at most {max_steps} steps", disable=not enable_progbar, leave=False, ) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index f58dbd06..6cbc8265 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -89,9 +89,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None): if lr_scheduler is not None: lr_scheduler.step() - if hasattr(policy, "ema") and policy.ema is not None: - policy.ema.step(policy.diffusion) - if isinstance(policy, PolicyWithUpdate): # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). policy.update() diff --git a/poetry.lock b/poetry.lock index 1121e68c..616f4a6a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -2407,7 +2407,6 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -2428,7 +2427,6 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, diff --git 
a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors index d9b2031708df0f0770c29ef3a44f5c6d0c1279cf..730f5b2bc2a801d15b4ade3593c90f95650f5472 100644 GIT binary patch literal 4600 zcmb7Hd0dUzAC5^=A;d7YwArt;sD><^^AI&ng%Beu*|JPZB4o59LTO4dEeaK3k|^BL zaLZOn5-M8IK5bg=t>62eYntD``h5DF_q^wMzTfA&oO9mW)K7-}?N`@UpR-@xarZvw zjp}pMP4slt1?sL_?D^TR2D1dZwgPoW``z~If7`Ulb>nVV?ya^!Uq{#Wpg>*cTbUeB zprfrjizn-TE14$>boBKMc(&d*vpJ$bS9{h^JbTu+vU#FFSKmOJXX}46no4kopUeE- zi_3Yp(U>Hm9kq#c&d_r9e($jr;MdHrH_$tFR?r+xe;$8|5`bOtobrueaI1jOmh*5fOf_W+TRae{+OQ z(V%HnS;(``r02_ekT6pdYqsgZY+(l!BHO6^UrkV5n?yGi%*UriA;j%*4`f_KRIv9m zIy$@IYrF&>Y9w@vqcT3_T;sHe$lRw(zUD`dW;IWwJ_qed)8KaKnEoF&`?CvEw65eLu1|c-_1{m8ks)-XrF6FV*|`Mhx9)5=Q<>UyXrJRG@h~m-M^; z3tE;R#{Cs3RK`mT+3mK}aitw~ukON)4l&%G3TbIv0<~XW2VIFRRl5HN(J~ z+U01C9HR#;FTxq6#3EaxFJF&?-(a(=3o%tor5~nz!i)X_C{5c=Hz<^l{<&eyPL($x z(^5-{UM+>z>F@EI@D3TcDFWE_oUQY5c%{v!W^TXH3(w`Ss(KNzyxGO_9}zi>s`?b+ z!RdD-e4v2g6Bo5%&lWT8*_h~#kih)lioW692N|6|Gu)&@&q%mhJC>)2 zG4A~lcza0ulJC#@n+^DDk2@-dU*|MR@aB;xr&&ZYbY^3~0X3FS9;fW155HX>kX(CS z+E?5=m$ue;g43Kb3e%=qSi*AO>LO~@J+<6WT*Rka>K`=`nvMSl=lCp9xYyr1VK^~lybMSh)` zgOQ~bIFT(+HRsO522l^1^)IqGcwAm#DLEXylFgHwyXj2lIw+i_NVCj?Rjg8MSCZ(c!x)Zz)$Wu@`M>IPsQ3^#P>`&YVPJj7$ZUd`<+PrY)DOpwlI8a z(HW{&9}UH^UaSt6SHD5#*d36IaUjRquOM4vA?a1qK+?x}rYHB)C~WiSf)AIQ#&wmn z#@&e7xsoCwo&$U_H;n5$t{&{4tLvDbtGFDA$l!-!Hd>x3!wKcjq#~;fmlk+q;PlYG zcq|O8af>#xdG$O6Tr(WmXD22xjqy8L?aAv9?%d{lDBhK$8&3=A%jf5C-z1Z*Z`4hJ zLtrPcY@)>52fcuUe-C1b4`gx4U+av@2U%Io%O}n z^V8;t52ghEiNU|jrBTZ_qk4usZSxvQf6Sgl^SrJy{NdiOb>?H|cuLQ>yX-6fWxfn; z`@Rzk-EC>)TuujvDb49Y#*APZkwM6`xNf#@U(8TM&$3tu_di0e;!Ma_TR_D%TWW{n z2t(aqGhLN_s4rr*==W?7%e~DD2kLzG1W6w_0Jp}B)R*{Bjzl zlZv3XZZbq2!=ZCjs%74~Vsd?f7!5BTpuf>=gon(h2A+fIrInA+u|pnBmmbmeZT?WU zI?e2uPPqY%mh-G%tlhZ0aoU2lJK!+Ym_74xxBL51{SDm=|8E|oJ@J40mcvI4}8X+n5aVxr1e6 zHZtE%pAI4F)gzDzm+1I?m6-6Vlq}Zi#m@6~)cR-@3gTt^>Y_8LjQP76(g{E18fGuK zCJk$D6~gAwJ`$i8s+C~$eGgPN2jS(`DTqriWqA+2@Dnx2 
zu0(?GF0lIRrOk>hz2WHpQ1%3ItxIL8NQg$U%-EpsQDkgz03jGW% zu4WHf@51HE0rK6*B8Z=!hwjlV{OI0<2}!v~*nN!ISnP5Ojk6OGxTlLcrTxI{$Bn2& zLg5Bfe#yn!fWyS7dwpL#nlZnVF)^>%JpJi%lx9|8dsHLqL&3+P)Ymr(OPpHBhN=^E zO44@hw4lV&^*P=p8NyPg*wu=fAl`YR)pY?y9mK<>M^KC}(@LZ0kn5Z2{9YXu(>l@I(mHAuF)$O9Qh% zZAdDMr?U1k?j$dye#V8=DmE5_CZ?g?Ae{03-7h)=OsL1lU>aoDB8}aR|CTT&a>ioZ z_YXw8rZZ0ee3fK9?a_EF5ujxFQeJn3-e{0~l!B(oGrV8_f zO9))8%yQUY5Xf|y+jZi0dJauh6Ou*VPq174FQ(HWfUkjw*d;gPqK7Q?%sx%WE-Ate z(O}#vucAkmHbY{zkWRmR0=2Z4v~Asw{X4n)m8+lG;xu~KYYDj*a1f^cJusMk3yvv1 zWKF9Hi}U8GU2I%AsRt1~T`)N*fy@3Q5Zk^+a`kB3*rZFR+n-_ntKT8~I9Y!>?~?>sV}~Nyn7eN||Ci^# z?bFe(xC6%nEa~Q%S1~W;G@hhJA>u&`Mo&(sR(rYcR1SaOMF&mF{GR^)rUD_8^Kq`T8PDd7M_FqUGKMT64?d@3 znwl5~HZ`MZOEuI^%F@;MF5*%98`>|wA3SD-Al5$&d0K^Z<}i*wPzx<~{BIW#@$}{H zGHx6nRzb7I{X^otrTxhJ&)YH;oxz0+m*1O5MG@<%-^gp^ZRK$~J|PrxHqS`;mxEMc z^)9+Q^8~r6*#(uA8LV&l`bNht!j#TWp&puqVYB@4Bi4`l>h`LR3d&wKMo^3*Z4Rci#k}&W5R_rVX}0; zIUasN*Qa%|Ug=zwPcsjcL1XrxsH^+}i>4H6o@7sSN}OOmfvY3uFK%fy;%E3mrr;GB zdT}deC4h}WDbdV;QrnPP&heU^Aexfi5A)+WTB;^>LL+{Otc`&JdWm2(;I^?QmPY)!p6uw;kTxTO)y)v58B7K->6R znH*1GY-BQnC!2gLnI{U2%}i(VY}0RMb3}oO(TpE?_Ka_3^F)D(*-RszZT8Jhe zm;1dB3B_5cO_k85h7>wWrINi5IO+s`&GBXv&2Y7c?!5lH zN3_i>B+-Um%+82_-SAv=gsyxGI4&KJ$n-Qk6G@^-)J*>01*kpToIAUhsQp}Jj2 z4R+O`v#SR|*7H$QFQLD!Q$=hXPL=%Fn zA~NLuVzfH!N1;VMDMr_byR#kHMU6&)GU3M^L5}b=$d>#y?lUFEw+C^X+by|MRgU z>Ewf@$O-!oIrE)>;Zx?eVZX-$?%A3ct4W{~;M5=gIA0z1etp9`Qll2qZ{OS@8Gh6Y0fXjh?n2X@b3+a znf-tB{qHzIvl2m{`xjaqDLhrOdR-GAaZvs)h118sESo!oeo zy)wn&&8Nun>I+DNkSbf2vAE069L1}uTpZaIM6<40la=u=iBi4~F&(Z-)g)KRhTCr# zZkzBi32gt;ALqc0f0IWU5k%dhn@lnM%-)-+tYUgj`su*rSrZ9~-a*V%x8skGm*A8v zC9|K|kj%pkSXFs|jQ3u`;@dWI36xTW@C-@9;Q&X*KbxUXOI%(N;aMMKbuDB#!+p=l z8Lf6Wr1wIj>M#!Y$@-G-&&JzLxbEYNL25TSjS@6G+RJGcl5ArO*zVC{`Q&k$2h9+? 
z?H^oih?Dgd_s*wn6b;Ym{<;@s=Llaf9=~x@JHx3o-D0{%D~@Dy1^Zcu7@5iWl_E~v z6cf08=b>FphR--iEPD7}l<~*zUW6?+iMTe?gS1r@;d}84l4~G^-iNbL{cscK=Y^nk zS{4qL%Jg`wy+(y2>lokhuX?h3i7p12m!Z03In{L>Ob2Gj;kwyyv`=YedU!w2%RiyO z@Dy>LUWjoO9}rcbNHu2vgtbB`y3MYzICxxTNd-A|(V5NjTYKnquSN_#OL2QiF&r-` z_uDH!(}nVR33RDWJuTn$J&RW@HUftfp5s|=4$D)e(E$<^S&4>wMYta?!NB`n$g4A_ z3sl4m-(DI?fB19}!zLbJbx2zA7C93)V|4r)a=QH*N_6HBsg@23J|;3f4{OI_lV3N2 zxZJd^6482JOJ?U@x`gZ*90bc~uJ5>duz#+u>I~&@Ss~S%*ocN*a+$hUGM+$E!)7as%z8Sdu zGFxYT@%8+4I(!hi$Ie4(@oXCBupadWigaAy7^+_|g%%&U&hUr(zSfzKo#QD!%O9 zu51g2vV9n{V>|5@^u=+kUzY9U^2TWkH|oR&J!|&N$6X$Lh?=eKVfcUZAnS?$*|!|d z)h?c#zHpq=Y)wuS@inq0u__Wgdw2$W=Jw)m<9KExbks%MZx!Q-r37)7BhVQq`D)vm z3_ExizfSm#`Sz;ahxpfzP^y_kC+rr%>U9ONG48{vI9KX;q#7?1<@@V0ead_0FFm3Q zA*%Jv-i`W9*xf0?`hB}eoEhJ9|F73+XC1s!G-%BQBQ(}1W7ZHUCan*}s}0joon685 ze){r{)T}^+&%vvRUL-+<*A-@0-*F906fQB1*;m&vySCfWpRYvED2BVYEsdJ0zsB+S zWJ;szAmD2D=*vB9zqW@A98-#pr*W8eBp(xeTcDF#go`_lG8^-@-$AoQ3J&^o)2*4K znf(Mc5wc3wLex=&bs_tSe$VRucz%fgooL3tVe`9B9Z;56jZ+tzSsxmGRG~pZ7eF?O ziAVJbs+a1Cy>^t?c|XV7RC73KdNbbG>m^7m?Zp!B78F^f;CkRE_C5_eSS(WuTJXrU z5|-v(P@E)1SEC#975Q&2Avt^e3z95)VL7W0w^F;&7{i_G-20*2T@m3KjqSt$7s|Qw zmFumY8@XEZb>P-dm0V}NuivAFjl~XZy&$Zdt^ZZpM0QkcV0Bpds~@u`>p3BjRP~~~ z^cofnYan5cN_26p2>GTB`5gMukeFYJTP5+1SE}oR#%cFJL5j)VA^(XIt z-^=y%^r8ajZxqA-cssU5E+H+mma+A%nx1G;wnGfgTXd*&&P0j+FZs~wvK8B0BUI<%wofJLqBa?I2vwn~t7Kn7eo)fBp^8nVJlQw4&?;Ln65>xT{x6oNHer@iEZE$cxwN_bh?G`H4u_DX{|`{lc&cD z!sxj9rSK3cA*-^Q1}|+z*Md1z>*@)JX&(`7@W=04xcil>pWVDn`t-nj@-QR-w!u>T zY;gyk=?BTuFE%XBKTmCAqw^FgVx`@fdr|@)|HJ5Ud4roZOT$qbS4IAjMbyk;55A=A54F$l|VRA1sT@d zeard3YW1#P$AGd<1cWT1o2FlfReBg6XIwzc12HtFX47T6xbIXB@BgY3j9 z^T~s{Z0KqABEY>B%^o!{Iw?<`|Ga{y?QdyI@c{VEh`{CGXp|V1P@|DA5cQoQ#IF2r z7a{Qp;_fnTREz#Vi_|}p^MSH{=%FLLgyE|CifKX2YI*uV(3I7NY>H(m&tp8yIp_C1XRMW)t0cXyo^0#d$Y`X+-d&f}kfCg( zh7GfxAG(sKFmuj$ShgX%O08Y1x+IS%tFeySJBX^C$PGz z26@V6Ah}3EPug6BBO|+26ItcpgZJSY{U2CfcsY8k7V4fGqg&hZq5COcy`0B@r%3@m z-rf$P2S>4`T~c_q z#y=8LyeT}xEcY)?&K+YC>2KlIf}46>YXxvP8vgj_=RwjNz@+x(fob0eF`rQmer*G) 
zu#hhxFliL4aSMUX5(#s$UBdEXF2%>CGAJ7+iK@9s_+HzNh#WWx%dTdt`p(5eF1H)A znkV7H4t;#(n&iK@oNQp`j3VG0!vvyeXa`#qQ7^k41Wz*3@!=N-z#}yWvki;|ubVy1 zw{AMavrc^_-<}71+w7SNM@Qg9pXIxY5lEW+h|AKz)u7iOOFRm2gN$%1+*nA$32z(K zfv$NZGV2tCMr`~)$a3HP$B;Pxo+Uh`CzR`tn6u)|^}|mrfaaUy-jW_PZJSI_Z!(7N zKdh9|I&=JX|% z5$ud9#!Aopf;iM0w`+w@5Qmdg7qI-_RC}Z@)}&toCC+l(VVM{{2u{JSI2?lb@3jcK zfChLVI4D&?(iz|;y==kvZ=m9LE@wXC@!kn)!^EF}#bl}SvrCrHBl?(Dbao=Sbc5$wLkjW5O?*JDT4al~0a&)E3nDjY225J8d zrC{U@9x)op_FFpdJ@FiL^`)!bcg9GdqAQqLrC9^z)|rfB-uoShOqxr>!2e=Mg$(udzP-k1L zfGq3l%4f4%phmfp3BEan>aIjF6NyuQF4U0tmik4if?u5@AgDFvL)qkus71ZEyklmGLcm5XQ0$n zE=Py$MRbtweGu+XaZlgz2V2nTXejbSZQE_#OYc1pyo%q&_t)#xw<1RD$dTXmq1Uuc#^* zxLiT5sOhI74?mYuap?h6)IgDNAblULot6W#^!Da&;)+-S$3gtOhZt`OZMtVwMl9;h zH`iXyamNUUZO($kxrc<4w?82;$%mS} zJ%S?Ra+YtXhJp)3wnWE+t5-@`ns6}Pb<)2RnNwAQxqZ7>AcR-d;7fXacC)el`g(M@ zY7+DO$0!lg^cI`;A0;xw50%#2zGh!|w&re85}KTCRK9LIhSonU7K|_3 zfKCkkhUs24VslPCHDf_3=LQ>f)w77dwFmEV=|Vw{IDTKzi<$=%xUY1RDR@~Pz^mKb z;EsikIx|m3*2fyFRJYRKP3W|*M~;5yWV((Ii_zfe3rf$vwWuxbOTh=LJ=pv;@7_Tk zrDYoKKKB(D-ng{k$xAm->d$Kk^{Z=0wP%V5t(k_{nFd_(@C+23Jm0#k$yK;QtCXqy z+Lh#Rs_UgMGQJ=qa%C#hK}#~gtBL9DegF#pX}V=^-@9MnY7<5ax48Q$MT=#mOUp&| zm%QU7H*AUdq5C+gNHJ}hT31IN{K-hYqc({=YF^XorqCC5v~)8QUMpF_&pi1fSCXIhKBN!^*>hBQS`Mi*)~^nuqDXmqyLx0p2x+r+ SFSD?MOWDM4p>I4kq5cO59`5P@ delta 1719 zcmX|*YOjY=o|h%nd~f%@u2BcKkXcJ)=c_f!0RdI|Wq^e94u4}nws z2*1K97~F(HTGlFL_+E$nI8hRm4G{R1x8d-Ex-}l&a}E~X&Q^3?h=v^PTYS?G1W31P zVM`1X{^7*gK+ov!hxBDJxb*Q>#zpfKxBzYa)$-rShzK^2`JjTw}~#*nuXj5~$Z3*w8a>CO@R{ z4!yluCciD1Ll1h{Xev7Y?H^Ue48?*{RY@dbMEf4@1l*n&icL3skP{XbKYoXR< zpGX1m=V3|wt0w$wOY+p0i)jlEhY3*G3k6?*%S4epsKOMqHm;}gRVow{Z>jhtu#5?i z)|!>vA;%x)$T6J?AfIOOXFJ0nE-(+{5eX1bvkuDs)FG~4$U-;T*~G^0#vu7)DCzd} zJ+^-&obl&Y?>I9Fri(-?cwZh7LUDT#y}Zp3u9#=gcH-rq8gc{lG+fV$n;OSLN8ybA za%#3s28l5e{G(|;sBLzp+#)rJ^E-0+k5eu~xj995>5`;i;@=9zyazMyUS(4tfe`a9 z`S^?%$mC1u-S?(Y-OUL4W$YAFK#yA&J4@D!)OgN*nSJ2({zpu!Uk2=0YrI|Y2v!Bz z;0_I?NPkx*CeDGGTgMp#RfVp%4;u5e;uZD}FPtdEdEPhBaY 
zS9zZl9Gn!9QK|l9L{GkJW$G?UEjb4wQd=8SPv$cLrWJz84!pT3w4vXu7+<0Fp|Rqw zFJ4%ZC*Nxl$moMrM&(al1_BBM}ln>ZCDAmrXkq17ANR8 z(NJcHlZd*bY=!){)K39+fgM@scl8X%=K}(CGwq6eeYX)RjQZAP)e2`u-{#Q2k3-5< zhVkAW*M>=l7A#@@CzSHDA>Q=*E^65^g@;tWfn(FPSnkkyke<8L%xZ9S)mAH_OHViw z%n55mlj(G8!%|?$IfUBSePt*aL8Pa)Du*nCiE5*r R^v5OIq*b&R)je!L{soDW=coVx diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors index a26868dbdd2782b49aff295071cb6b42d7d6c617..ade6a9e03118ac410f6f0ff9357b6779074437b6 100644 GIT binary patch delta 300 zcmcc7%=X|v6X%Z|3}8?_S&&^}^1e-$8*6e|5 z)So^fk5PYe{AS+i=J||z(|_eL8gLq#8<^+-35)3;bAe)s`HaSsb+&-jt4ud2VC0&9 zKA+Km-O$|72&im&K>?%Tbn^m6qsf1f6tPcVSOC3W)1$W_B6bo!RWn zh7S#-*fKop6==As*T6&;D3H**V4*?1IN}^)>h|^9^AigdG*@0rRAe@ z`xmc$;r_MN<+2)!BkpyV0_AxByWL={4ppeV#C!{g`hx z<}D2`VZ>tQJB_xb>tjmGw=mjR)b`c~h*(Iw-Dul8-|}ux?OtQS-ux*>EM}qASg`i~ z+O^f|o8EF3I*pFC{RNO%P-l0eW1RqUrWa1VeRSH5j&%$QAhn=gqidT3>}1~EtlMgI zZDX(vA{Nr^G`@60Cq&Ro8Ua!%&x8Ap#^;?a;z4PwX z((g2yyY1bda>TZP0gko(g?&a{|G)rcoq+;ME(*S962A6abFQ%0$VE~9O!gO{)S_U4 zO_K&L)X_8`fUyij0izZJ8*G|%a6)lzZh$nJCMPU`)PmrIO_LOkInE6Zh(^<7g$0aS z46Lwe(!vSF+a0K3)8vLFkX#U^z$Ay{6t*6DLx3PQO@gD(eEsTL$5Dy!TemQn? 
z_X!FxKv@T&aMzcM!Ymjqv*rT_FhE%cqVQ~^76sRAnOw6|=YazVU@Y500izZJ(`=bc zb3*a%ayW);nN+g`QVW7-woINm=6H8GNM_3 zil~+dk8GJdvIKNKaDdgAHyLG`Czb~kaMl~hwhp?*0yGdUVi6VSlP16#^Cs8q)Op|l z0vOv?vhBG>?H5dQ-ej7SVw{^jAOP70qs{{d7~t5(VXV%s=Xf9*^Crn0UpL>|8z6vU zp8@w44@hI)sF&lb=3G1Q%Xy<;)_mXq1}N(w6z=-!8iHlc8!fZu0|zibSqGxd0|zMh zT!rO5uWp?O4j_QBYzqa9+98B=_A2 z7@(}1LY-X>6p$=82^O_YSKXi61g+6F*<%$^>>FSJS|?;%iMY>u-r)eN(KZ=nsldQf z%Y#?8O=!>tX)?`8G2S+fK22%c zM8@R=klG!7qSEA<(=Gw6h=N$=t*`DQ>(^ppe0H18n z@=3n1cY)wwvOUWox4pLdW5XHSvz+nB+S;MHRs5cb^LoP+7iO7a>Gj1QA1{rUpRuuT zy}a|<@mQSSY)lfuoiKw_7N?-E54vL-4zWvfr6V zYDw_Bh4NA27=#a^?K`}-bnMXbI-c4_PuNSfml8B@p{#j!oLn5nqHKF6nzwBbzlE~+ zDL}PA_}@a=|0;R-#hv|!)~ z6I@MC1PQdPZB)?Y(l91@DUD&?&cRE&c`vQxspUaSdwDG__B+#<%Q9%<+0|+pA zOs1ddQ@sSNMlbKBadBZ%)bgOEy}XtZmpB=ei-VQ+@>W_wQ%i%Q_VS7v!<=glEbvT| zQoNMT&Ou;%d4X->qn|i9u6+&(6mSMbN#oSs!Df3Fn~iAR=0RwC7NM;Iss+MldlsJ^ z@#^{|NR6JwW|K6vG{|hvBC}aG{|hvBC}Qt zA@S1qvf9#byT%`|XwJ2~pXbJ_gJ9_Yl%RG1TH4QRY3nqdLju+4=e4v7rxph*?dPpD zL33_=fHnGgF|7uw1;R`Fc`psTx}FFUXxYu7f+m-SNy$rT4D+@QUfR!lX(dlB4_eyK zYiSwhyk~&~T6P<$psA%nO8a>!jbYx-;gGeT_tHw9S{}5tpV!g}uBIn~)ad7}w1TFV z1}W|5r8I_lI|ncA=e@L&Czppo$!lpD=UsEK(th4bD`;wIkWzdPd{#wbnzMC;k!api zChK6TrecL;5oo&o6i?7xZ($P!US3XlNfSJ|Jg6#gSyzp5?he9Uo4fP+T1k{kgvVQi|9a*=Rcl+M5i?`nVzzMZH3;(s(L6H| zQ7#ec3!r&qCP4QZ#F1=sw?%%76XRWjkYC{BftjudeGSrpA#KHs#rQC;HGapl_D3km zA}}>F2Ub=mzmS6!SLfH^2p}ajC)E^n81edmBvm&F@vbr;$VpAbERI|b#Wu=W8jV5% z<6Tjsl%*(UDMOJzzE~7l9G;XTuUeOyRDMe9AHke!(r7}uU%Cg*Tqf|yAN>-d)9ECEZPRh|YuzY7DN0bvmKKIYE{+lzB`H}<|PUqqar~O&OZkNWdX1 zx0Q^~ngAJ{k$87Wy#L>flA5xGxw}SfjnbN|IJs-I)+n*5f>x3Q`3wyl!gBYU;t63` z*m7ZX)+n_pTbR3RJpYU$o2)pwI39mSc}*Fbd#zDgqp&6$Of8H-O<-vgGF1e5Aeqtx zoNXJKO#c+87l#>%GMYvg^4O7WM4U!0 zPrxC}I3`=AxjRRTjY6REm(trgT5OaBmBYE$94$7Agfi0P?SvK^B}3(5?#|I%K*U$TFczyZDr?5sjRjD7}-49x7lvFi?pq9X>DgZKfknkN(_1_w0ZWVBudq@>d zFAnWbr&w*<7!ovZyO?GLU`9-TwvMU)V-g_26pe28gDo>M3qm-s|uV%vBouikggRd++y*v&|?wB>K z@$$)_zUFA9>Da48aWy>=9Kdo?8kgqo9KAH1dNs4Sz>_;iEluZM&A_?W91mI3saJWL 
zTpFb`oq07A=I$K5G@W{7;mPGuOVhbmC0z0F%x8@Nk!FyT1WhlEW0F%;uDK%2-8njI zI`wMirSz^j3T!&}s(_1*>U^zZ>J@NVMM>k--ZAtFwCtcnG;i~md(}(_U(LRn&h7z4 zUe;&Frxe&JYyJKiK%gbRBp|M?u3v&fm~B$F!Q8E*&!&^FW)|n|9CbFGeKiB;z1IOm znm8wcCa*fmY&!R9W@+xu(Pz`aR~DXJ9(6XIe3ihJj8FDUfI!RJ?0AKjXmV+k*>vXB z%+lPQqtB*OuPi*dJnC#Z_i6^tz2<1M>C~${O)rgOk~?Az-|#$r=<%QZ14W*3Lz#@_ zlqZ)*olWOn6|N;fu5(DdG+#e^vk9*+|M(QvbgotD$9JzJ9;BwTtkMLz1Zru>uUbC5 zxNn*=KX>%t{%JDy{>5uwxPNVR88+;$ghCpEs}zu05Cf~g8MWNN@$Ldlssd)DpmvO2 z4B8(us!SAbcbH8D&}iK&Ah{r_W=N$naM~ES{j2UCQ8hzEl?sy!qh^M@DuLzh7d114 zSLrypIBI4{vNF@Wdp2XNnjyzZg~^3cF++-#z;btuiWxGjbevusXJo9HMjG>lr#Exi zX_;XKpmh`B;w}K?0;!oH$4V%0`ZdaWffUUUWaRvozlNs^4v58S*b9& zFluJVu@YGBbw|w%F;+TGE{@U}Qmo80_ijlA4Ix%4OfQTRGCndDSnjS-O+$p0j+2X{ zpN0f0BdxG`$(so?tN^r%YCC+V)pUDS3uKHHcvfSb@cg?abF7*n$V#}PQ=nWR9X2Fc z8F&k|+!7qX#%HD~OfHNb8*;1!mUrDT#R@p!Fch{)0A-_t;$z7w#hV&{Ut>$jY^U@GtBUi$5T zG8;myJV3cXx@<_XGV+!XQ@y@;_{gC;=RiP`tL!c*u3j$ZXMB*Vm+L)&0FYE=zu8u7 zpvtAvc08(h8fG{0spWfM&_%P!@o4fz;cOlO+^bW4EN*hNvtLP%e;4 z8`83jylwvOf3qA&a+TdB#nsE@{Fv>Q^Ro>Lvv7S#13*%(3X`VFrNVac<*ZhdiJ;CM zh)8m+LQ};R%O#Om`#^1bZc(=(asV?caHqyY5;0`ihi6=vyL}MgR>;#bZ3qRN+%?!Q zX!c|Sp1`}5S#vz-M;mWRxTwzyNXVeXF6L2)ge zJh?ooXNcn}PW!C#=ACPcX=CJb2$oix05;=Lc07|xma3hi0d*^w|!}d@6il6FjFV8 zl}#-(BhaNdz;&K&>?B*@cEVl1rsVhiEWM zsl_k7b9R(29n!&^Ou0-7bqEO4k=`}Rps-d53UgECQmNM=Ff0!poi5e>9%=FbO?|c@ zwe6;%a-nqW5FO?y)ZJ8ScE}L(WA$P=N~uHN$_rd}Pklq8V~0R7W2I%LTqYGego^1% zcS~v6AzaK&l}n{$hmbK9>h)dVz_dE{$Qc6f8(s)-ZBi$XPWrxr){o)3y+)+w) z2p_XTOYW(!Q95=AAagSHGC4$6#U6Kbt4w!CY1tu+OfS^iQA&0QB~zi^HOgQzFpZ8~ zyu8k5QihWOX;kbo(%VsHl(j-enSODd9c5G*pvLV8W9#}<1^}Q`A5IrlK2X$kh}S2? 
zI4INi^W-gtab(~ar(?>|%i%awA46}yW{JcOXQL>~A&o3IgH=$uP`Yx6By%v`%Z<7m zvdR2dxmenA2q`mC%MPgzEKj6EP??h{mq}F)VP!hf-BQ|e2rY9{ z2shJ_?vC<6Iz*nisdA~5w-t9ag0U} zUG>jt!2{GP*=eZUQ#y8tR?GJ^V!E43%?=rBeym(9Ejxs*8L72>7dSGNNh&k-GC4=K zFp>p+(4B!qsiz_EnDp{Bem?1@;NPVWSTaXaDmHAxlF2Z2y@fRbhnfT(jnB% zO_fWfEQfG673#ew0tcq;6Vm05I`4^~5$PJD>D@^mUD|R8d($tjca2h(L+G0oI=K;h zy_LalV45DeCLJv?<-w&YhwwMOO!pe4Er&=rH&rf`vK)fqRH%E6@<2L7$2pmDnN;Ou zm|XnuQy!ukYm~N}OqLU`(x>1$g=de$vlFO#shlI*k#uVsnlyRC-nmTa*vVwO ze4$R?1r126VwaDG-5sT6Cll^0g$AnJQA&0)=B^rAa!-AY(y^0icV?zsCKWpwdS^kp zJ4(w=X5R@^xl~GaG6F9T_4+PwWE$;gykI6%|`LGrA&86Y1zqKJb@~Al#-o{ z#~VQptgJ2`+Q)C`;df!y`pr(};(=6U*3~4nBnILER5dV3khf>d!fPe7@TxDLb2kNr zs496*l;vG<%)o0UGw`a4a;`Hl2ogrZ|1`<;JFtYQn6l&nq!A~h@5;SCVFQP`4fqH=)oXXED&*Hh;NO4W3*k!M& z_Z~)RO-9&di0+M&!kWykvjOEUQdX0pby=RfibC9u_3)oGUu zwZk{bRDHwTES|fI6xw90UG|zfcT~Vgw9L(EI@$&d#LACVY`?5&qCBXS-DI>~bWtOq zyOH$YWWt@9DVIqdP6pkjk!9hKkIEQz2cUU$%0Rt94vtBU$EP@0_u<6B-A2lBGV(5a zO}*=pYMczd%MjhWD#bV%eP;v8U8EN$!|$>@cNcl4n~cA+BIP1!#>oJ@3gXt?*!M;R zSO8Xjyg1rar{9njlu9m;UYrcT%NFSFBGovVeP>1LMRIN`KUG|PQQd2iVw_C8%arIf z9C$Il(#>Kx;~jM}rqiP>^`=1@krp+sK#E-?jKr$j)UhMmj&Me~m$eUZNCUF+iZ zxk$J%h*fSpo9JATeSXSBy+lq=r6A8N(cMZo zGMM)H*;Zs=%4I^4@!gJ9o6n3go!cv6Bsyltx4)7_pks7uY_K>$o=;}tjqV{cFRQnc zuw)>G;hEe=oCi9Dfmrz=h(wVkqCBjyWB`TH*#f%PBup7hVR|B_UM7d9@(_8(b#<>v zm@XiR@inXSGC5Pge`;Vn4T$9?j&p(NaOGnKPc&X1R)H>N?R_?PJ$?x z2wMiwI66h3dri`qlNoxc-8IQn%Vdlmkfz~@&aZr}WP%=$rr}8=^&&Am$?&}Rx<*8A zGnt(SsA+(zpmL#f=45nUMWNe7%h^mCb22o~NR&&YFel^k@<4Yhai~0>OwBVh*X)2x06)nWVl`q>0XmKTb@s*>lumiuu_!GN z5)Io;1yL@M;)^dR*RSgI{6yYUfg{r~JyWJ!Ce^n(z%MbdTaxCx0|EwOJ-D{HWTOHdsv&g`d%cS~N$M(%Yy4NJlw>q*fPt;4~=va=5C!4^elhgH~ zuSvRZbzI-fBE6lY`c}vFm5{{)Kj)gH`Bq2u<%x19sloWrv29HjOLVVE9u`-}^vx{N z+eu1tbxdCY8J+eyTS;ZcrvWTiHjPw!2^fi{yH-T>HWNHTO#@T~)eFS{;gr#QY0EY# zZ$5yLXu5Tcsw@-b5-H5p(R}g1ZUeeoNoTH(=Cd&6GO5hifAmgz@0EaoXfx(wCj(I~ zk-}UZ$v5Mw`tOxYa%ub1@m&nnv3G;Y#wt((6 zNoTH(;hT9`o%c%6fV3PJpQdk9C5%M-jca-$?p>1<=IZ#p8CSK&n&e?|b^M-%DGw}_ 
zxjKHYfW)Y}&QbC5U|2p<(VPnmxeQ8ib+lfo4EIW-30H^dr73zT9D?zY@ao#p!;6Pj zUb?(*W%bChjrFlZgx)>ike99}%QJV#yGr#lOsH-zSiEV(SY zZivdWqd?KCVnQBJqrE~ysYNju52Vrk#X?%&SlmFb7GLZs3S$w{@HCWM6h$^f;0dG^ z$NVt|Z$9MUX(+iUifahN+ahf&9zD3cwX7De<%)9lhngBv@GJnm0QP?@rP<+`lp(IpKydYeXRaZg&W4bSpfHZY z=x(_9;?LfEPXDd9{^?-H^glO#|MY*3{zv+spLgPgC(ro1>xNIf^xIDT#$`u`TYo40 z@9%uW`KMYBUN^k(8UNjXa{Ea3-~YuQ`qfjPY&~W8zR!J6|B=6W%J47W^vcbHAKGzg zj#K{gbIyEe@4fduW%#B`|GIbfA3Y_@{pZJ?-+k8eA09mMGjBifeJ^=3K zyiw-1e(jkjUUDa1=f@w@e$3nM9Nc-?f8YD^SNzKGvb!$qUGi7A4R3t#-@Bh~{@38l z{_FZ{PkUhSiBtc(dCfDf8eZ1>?&HhNtA~%A^`m>g@jw22@Ua`3{pVfsz~GjJ=WqVb z6^|Q!;R9cH&wsve=Wyp&cJKYl7kiif>&g=bYk##T?nnE&dw1Ob;lVRHJNj?y-y7Ba zHTPWczVG9Y5py6M`hhVTFNjXmccKX>^4x6JjQ@XnpXFP!|DPm{`0F2p0nqZAA0-XPv(FB#NXU==iv5d ze(UB{w_h~8@-uJj{qQgB9NuyMQ}?`L<>tYw-t)}kzx(9)?;kzqFZ)mao+}2MFFyD9 z`Z;@}pMU3lgHs>8|NP;tcYf^Tm;dd&;m4kJ*{>E9C4ZeEACCAVI zZ2bDYrO%#v;y-i-H~qwabdG-}{`3EO)<;i0bob+jC;NAGZ(83u{Kz|g?9`9_=3sF4 zH(%ZP$RqK;f9oR`?fJkJgTbwT{L$_WpO621>5dEczV)4tAO65~_n-X4%XbdXyZP?^ zxl3O&_@$TL(fO9|KQp>M|8~J|ZGP!BHx2gx(vNlryW;*l_k$N~?)vlB5ANUHJ9*Z_ zw+%kN{OrAR-*w*boU0$(|HbF;9DdV%PrCO@KlQS~#$6YkxH!?@H(Yhi=1+e6ErZYf z-9MgafBB|}fAr9a&2wLL>)@wecXe;=?{6Dic-`%LZ~5HW!&iUg*}Z)`&l{e#c;%j_ zzx&d`cf8}C?pX`*d|dacAKLqmuRSq1|6i|p zVDREwcHev3+bI3%={^F{;dpCT`1A|BJPm?>#_`dq+{pr#B)1&vNNAFKNc8vFf zNAFL?{qg_%{sj8ud1>_DZ~x%`T=DeVo_u-pmrq@O%gXaF-*MwtuNeP?zkK;w=bsn- G^Zx-TuGq@} diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index 70337c17..29e9a34f 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -88,14 +88,8 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": - env_policies = [ - # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), - ( - "pusht", - "diffusion", - ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], - ), - ("aloha", "act", ["policy.n_action_steps=10"]), - ] + 
# Instructions: include the policies that you want to save artifacts for here. Please make sure to revert + # your changes when you are done. + env_policies = [] for env, policy, extra_overrides in env_policies: save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) diff --git a/tests/test_policies.py b/tests/test_policies.py index 7d2f19ba..12beec92 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -249,6 +249,17 @@ def test_normalize(insert_temporal_dim): # pass if it's run on another platform due to floating point errors @require_x86_64_kernel def test_backward_compatibility(env_name, policy_name, extra_overrides): + """ + NOTE: If this test does not pass, and you have intentionally changed something in the policy: + 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should + include a report on what changed and how that affected the outputs. + 2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensor.py` and + add the policies you want to update the test artifacts for. + 3. Run `python tests/scripts/save_policy_to_safetensor.py`. The test artifact should be updated. + 4. Check that this test now passes. + 5. Remember to restore `tests/scripts/save_policy_to_safetensor.py` to its original state. + 6. Remember to stage and commit the resulting changes to `tests/data`. + """ env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}" saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors") saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")