From b60f4efa7cf53b4c70e8450fe471d20978ac5638 Mon Sep 17 00:00:00 2001 From: Remi Date: Fri, 31 May 2024 12:02:26 +0200 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Alexander Soare --- lerobot/__init__.py | 2 +- tests/scripts/save_policy_to_safetensors.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/__init__.py b/lerobot/__init__.py index f5134a59..a5a90fb4 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -129,7 +129,7 @@ available_datasets = list( itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets) ) -# refers to attribute "name" of policy instance +# lists all available policies from `lerobot/common/policies` by their class attribute: `name`. available_policies = [ "act", "diffusion", diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index f867a5e8..c8c0b6cd 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -77,7 +77,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides): batch = next(iter(dataloader)) obs = {} for k in batch: - if "observation" in k: + if k.startswith("observation"): obs[k] = batch[k] if "n_action_steps" in cfg.policy: