diff --git a/lerobot/__init__.py b/lerobot/__init__.py index f5134a59..a5a90fb4 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -129,7 +129,7 @@ available_datasets = list( itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets) ) -# refers to attribute "name" of policy instance +# lists all available policies from `lerobot/common/policies` by their class attribute: `name`. available_policies = [ "act", "diffusion", diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index f867a5e8..c8c0b6cd 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -77,7 +77,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides): batch = next(iter(dataloader)) obs = {} for k in batch: - if "observation" in k: + if k.startswith("observation"): obs[k] = batch[k] if "n_action_steps" in cfg.policy: