From ccffa9e406ceeafc891523f3fcbdeb4cc88c861f Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Mon, 29 Apr 2024 08:26:59 +0100
Subject: [PATCH] Fix policy defaults (#113)

---
 lerobot/common/policies/act/configuration_act.py    |  2 +-
 .../common/policies/diffusion/modeling_diffusion.py |  2 +-
 tests/test_policies.py                              | 11 +++++++++++
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 7564e6f7..16be36df 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -77,7 +77,7 @@ class ActionChunkingTransformerConfig:
     # Normalization / Unnormalization
     input_normalization_modes: dict[str, str] = field(
         default_factory=lambda: {
-            "observation.image": "mean_std",
+            "observation.images.top": "mean_std",
             "observation.state": "mean_std",
         }
     )
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 9e52ae92..4427296b 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -43,7 +43,7 @@ class DiffusionPolicy(nn.Module):
     name = "diffusion"
 
     def __init__(
-        self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
+        self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
     ):
         """
         Args:
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 3b1959d5..1a9e6674 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -5,6 +5,8 @@ from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
+from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
@@ -113,6 +115,15 @@ def test_policy(env_name, policy_name, extra_overrides):
     new_policy.load_state_dict(policy.state_dict())
 
 
+@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy])
+def test_policy_defaults(policy_cls):
+    kwargs = {}
+    # TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
+    if policy_cls is DiffusionPolicy:
+        kwargs = {"lr_scheduler_num_training_steps": 1}
+    policy_cls(**kwargs)
+
+
 @pytest.mark.parametrize(
     "insert_temporal_dim",
     [
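
Note: the new test above only checks that each policy can be constructed from its defaults. A minimal standalone sketch of that same check, assuming nothing beyond the constructor behavior shown in this patch, looks like:

# Sketch of default construction after this patch; mirrors test_policy_defaults.
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# ACT: no arguments needed; its default config now keys image normalization on
# "observation.images.top" instead of "observation.image".
ActionChunkingTransformerPolicy()

# Diffusion: cfg now defaults to None; passing lr_scheduler_num_training_steps=1 is the
# temporary kwargs workaround noted in the test's TODO.
DiffusionPolicy(lr_scheduler_num_training_steps=1)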