Add test to make sure policy dataclass configs match yaml configs (#292)

This commit is contained in:
Alexander Soare 2024-06-26 09:09:40 +01:00 committed by GitHub
parent 7d1542cae1
commit 342f429f1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 54 additions and 13 deletions

View File

@ -28,9 +28,15 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
logging.warning( logging.warning(
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
) )
# OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
# issues with mutable defaults. This filter changes all lists to tuples.
def list_to_tuple(item):
return tuple(item) if isinstance(item, list) else item
policy_cfg = policy_cfg_class( policy_cfg = policy_cfg_class(
**{ **{
k: v k: list_to_tuple(v)
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items() for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
if k in expected_kwargs if k in expected_kwargs
} }
@ -80,7 +86,9 @@ def make_policy(
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`. policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
""" """
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None): if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.") raise ValueError(
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
)
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name) policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
@ -91,9 +99,10 @@ def make_policy(
else: else:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time # Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary). # hyperparameters that we want to vary).
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
# weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should # pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
# make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274. # huggingface_hub should make it possible to avoid the hack:
# https://github.com/huggingface/huggingface_hub/pull/2274.
policy = policy_cls(policy_cfg) policy = policy_cls(policy_cfg)
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()) policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())

View File

@ -99,7 +99,7 @@ policy:
clip_sample_range: 1.0 clip_sample_range: 1.0
# Inference # Inference
num_inference_steps: 100 num_inference_steps: null # if not provided, defaults to `num_train_timesteps`
# Loss computation # Loss computation
do_mask_loss_for_padding: false do_mask_loss_for_padding: false

View File

@ -54,7 +54,7 @@ policy:
discount: 0.9 discount: 0.9
# Inference. # Inference.
use_mpc: false use_mpc: true
cem_iterations: 6 cem_iterations: 6
max_std: 2.0 max_std: 2.0
min_std: 0.05 min_std: 0.05

View File

@ -108,16 +108,23 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
if __name__ == "__main__": if __name__ == "__main__":
env_policies = [ env_policies = [
# ("xarm", "tdmpc", []), # ("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
# ( # (
# "pusht", # "pusht",
# "diffusion", # "diffusion",
# ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], # [
# "policy.n_action_steps=8",
# "policy.num_inference_steps=10",
# "policy.down_dims=[128, 256, 512]",
# ],
# "",
# ), # ),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"), # ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]), # ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]), # ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
] ]
if len(env_policies) == 0:
raise RuntimeError("No policies were provided!")
for env, policy, extra_overrides, file_name_extra in env_policies: for env, policy, extra_overrides, file_name_extra in env_policies:
save_policy_to_safetensors( save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra "tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra

View File

@ -26,7 +26,11 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy from lerobot.common.policies.factory import (
_policy_cfg_from_hydra_cfg,
get_policy_and_config_classes,
make_policy,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
@ -210,6 +214,23 @@ def test_policy_defaults(policy_name: str):
policy_cls() policy_cls()
@pytest.mark.parametrize(
"env_name,policy_name",
[
("xarm", "tdmpc"),
("pusht", "diffusion"),
("aloha", "act"),
],
)
def test_yaml_matches_dataclass(env_name: str, policy_name: str):
"""Check that dataclass configs match their respective yaml configs."""
hydra_cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"])
_, policy_cfg_cls = get_policy_and_config_classes(policy_name)
policy_cfg_from_hydra = _policy_cfg_from_hydra_cfg(policy_cfg_cls, hydra_cfg)
policy_cfg_from_dataclass = policy_cfg_cls()
assert policy_cfg_from_hydra == policy_cfg_from_dataclass
@pytest.mark.parametrize("policy_name", available_policies) @pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(policy_name: str): def test_save_and_load_pretrained(policy_name: str):
policy_cls, _ = get_policy_and_config_classes(policy_name) policy_cls, _ = get_policy_and_config_classes(policy_name)
@ -318,7 +339,10 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name, policy_name, extra_overrides, file_name_extra", "env_name, policy_name, extra_overrides, file_name_extra",
[ [
("xarm", "tdmpc", [], ""), # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
( (
"pusht", "pusht",
"diffusion", "diffusion",
@ -342,7 +366,8 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
include a report on what changed and how that affected the outputs. include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and 2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for. add the policies you want to update the test artifacts for.
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated. 3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
should be updated.
4. Check that this test now passes. 4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`. 6. Remember to stage and commit the resulting changes to `tests/data`.