diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors index 3c9447d7..c816148f 100644 Binary files a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors index 7dfbc3b3..bdecb18b 100644 Binary files a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors index 4c738f39..641771c6 100644 Binary files a/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors index 7a2e0e70..26d91924 100644 Binary files a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors differ diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index e79a94ff..ccdd204c 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -19,6 +19,7 @@ from pathlib import Path import torch from safetensors.torch import save_file +from lerobot import available_policies_per_env from lerobot.common.datasets.factory import make_dataset from lerobot.common.policies.factory import make_policy from lerobot.common.utils.utils import init_hydra_config, set_global_seed @@ -26,15 +27,14 @@ from lerobot.scripts.train import make_optimizer_and_scheduler from tests.utils import DEFAULT_CONFIG_PATH -def get_policy_stats(env_name, policy_name, extra_overrides=None): +def get_policy_stats(env_name, policy_name): cfg = init_hydra_config( DEFAULT_CONFIG_PATH, overrides=[ f"env={env_name}", f"policy={policy_name}", "device=cpu", - ] - + extra_overrides, + ], ) set_global_seed(1337) dataset = make_dataset(cfg) @@ -88,14 +88,14 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None): return output_dict, grad_stats, param_stats, actions -def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides): +def save_policy_to_safetensors(output_dir, env_name, policy_name): env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}" if env_policy_dir.exists(): shutil.rmtree(env_policy_dir) env_policy_dir.mkdir(parents=True, exist_ok=True) - output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides) + output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name) save_file(output_dict, env_policy_dir / "output_dict.safetensors") save_file(grad_stats, env_policy_dir / "grad_stats.safetensors") save_file(param_stats, env_policy_dir / "param_stats.safetensors") @@ -103,8 +103,6 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": - # Instructions: include the policies that you want to save artifacts for here. Please make sure to revert - # your changes when you are done. - env_policies = [] - for env, policy, extra_overrides in env_policies: - save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) + for env, policies in available_policies_per_env.items(): + for policy in policies: + save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy)