Add test to make sure policy dataclass configs match yaml configs (#292)
This commit is contained in:
parent
7d1542cae1
commit
342f429f1c
|
@ -28,9 +28,15 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
|
||||||
|
# issues with mutable defaults. This filter changes all lists to tuples.
|
||||||
|
def list_to_tuple(item):
|
||||||
|
return tuple(item) if isinstance(item, list) else item
|
||||||
|
|
||||||
policy_cfg = policy_cfg_class(
|
policy_cfg = policy_cfg_class(
|
||||||
**{
|
**{
|
||||||
k: v
|
k: list_to_tuple(v)
|
||||||
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
|
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
|
||||||
if k in expected_kwargs
|
if k in expected_kwargs
|
||||||
}
|
}
|
||||||
|
@ -80,7 +86,9 @@ def make_policy(
|
||||||
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
||||||
"""
|
"""
|
||||||
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||||
raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.")
|
raise ValueError(
|
||||||
|
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||||
|
)
|
||||||
|
|
||||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||||
|
|
||||||
|
@ -91,9 +99,10 @@ def make_policy(
|
||||||
else:
|
else:
|
||||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||||
# hyperparameters that we want to vary).
|
# hyperparameters that we want to vary).
|
||||||
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained
|
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
|
||||||
# weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should
|
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
|
||||||
# make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274.
|
# huggingface_hub should make it possible to avoid the hack:
|
||||||
|
# https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||||
policy = policy_cls(policy_cfg)
|
policy = policy_cls(policy_cfg)
|
||||||
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,7 @@ policy:
|
||||||
clip_sample_range: 1.0
|
clip_sample_range: 1.0
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
num_inference_steps: 100
|
num_inference_steps: null # if not provided, defaults to `num_train_timesteps`
|
||||||
|
|
||||||
# Loss computation
|
# Loss computation
|
||||||
do_mask_loss_for_padding: false
|
do_mask_loss_for_padding: false
|
||||||
|
|
|
@ -54,7 +54,7 @@ policy:
|
||||||
discount: 0.9
|
discount: 0.9
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
use_mpc: false
|
use_mpc: true
|
||||||
cem_iterations: 6
|
cem_iterations: 6
|
||||||
max_std: 2.0
|
max_std: 2.0
|
||||||
min_std: 0.05
|
min_std: 0.05
|
||||||
|
|
|
@ -108,16 +108,23 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env_policies = [
|
env_policies = [
|
||||||
# ("xarm", "tdmpc", []),
|
# ("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
|
||||||
# (
|
# (
|
||||||
# "pusht",
|
# "pusht",
|
||||||
# "diffusion",
|
# "diffusion",
|
||||||
# ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
# [
|
||||||
|
# "policy.n_action_steps=8",
|
||||||
|
# "policy.num_inference_steps=10",
|
||||||
|
# "policy.down_dims=[128, 256, 512]",
|
||||||
|
# ],
|
||||||
|
# "",
|
||||||
# ),
|
# ),
|
||||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
# ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||||
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||||
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||||
]
|
]
|
||||||
|
if len(env_policies) == 0:
|
||||||
|
raise RuntimeError("No policies were provided!")
|
||||||
for env, policy, extra_overrides, file_name_extra in env_policies:
|
for env, policy, extra_overrides, file_name_extra in env_policies:
|
||||||
save_policy_to_safetensors(
|
save_policy_to_safetensors(
|
||||||
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
||||||
|
|
|
@ -26,7 +26,11 @@ from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.utils import preprocess_observation
|
from lerobot.common.envs.utils import preprocess_observation
|
||||||
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
|
from lerobot.common.policies.factory import (
|
||||||
|
_policy_cfg_from_hydra_cfg,
|
||||||
|
get_policy_and_config_classes,
|
||||||
|
make_policy,
|
||||||
|
)
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
@ -210,6 +214,23 @@ def test_policy_defaults(policy_name: str):
|
||||||
policy_cls()
|
policy_cls()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env_name,policy_name",
|
||||||
|
[
|
||||||
|
("xarm", "tdmpc"),
|
||||||
|
("pusht", "diffusion"),
|
||||||
|
("aloha", "act"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_yaml_matches_dataclass(env_name: str, policy_name: str):
|
||||||
|
"""Check that dataclass configs match their respective yaml configs."""
|
||||||
|
hydra_cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"])
|
||||||
|
_, policy_cfg_cls = get_policy_and_config_classes(policy_name)
|
||||||
|
policy_cfg_from_hydra = _policy_cfg_from_hydra_cfg(policy_cfg_cls, hydra_cfg)
|
||||||
|
policy_cfg_from_dataclass = policy_cfg_cls()
|
||||||
|
assert policy_cfg_from_hydra == policy_cfg_from_dataclass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_name", available_policies)
|
@pytest.mark.parametrize("policy_name", available_policies)
|
||||||
def test_save_and_load_pretrained(policy_name: str):
|
def test_save_and_load_pretrained(policy_name: str):
|
||||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||||
|
@ -318,7 +339,10 @@ def test_normalize(insert_temporal_dim):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_name, policy_name, extra_overrides, file_name_extra",
|
"env_name, policy_name, extra_overrides, file_name_extra",
|
||||||
[
|
[
|
||||||
("xarm", "tdmpc", [], ""),
|
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
|
||||||
|
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
|
||||||
|
# to test with `policy.use_mpc=false`.
|
||||||
|
("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
|
||||||
(
|
(
|
||||||
"pusht",
|
"pusht",
|
||||||
"diffusion",
|
"diffusion",
|
||||||
|
@ -342,7 +366,8 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
|
||||||
include a report on what changed and how that affected the outputs.
|
include a report on what changed and how that affected the outputs.
|
||||||
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
|
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
|
||||||
add the policies you want to update the test artifacts for.
|
add the policies you want to update the test artifacts for.
|
||||||
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
|
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
|
||||||
|
should be updated.
|
||||||
4. Check that this test now passes.
|
4. Check that this test now passes.
|
||||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||||
6. Remember to stage and commit the resulting changes to `tests/data`.
|
6. Remember to stage and commit the resulting changes to `tests/data`.
|
||||||
|
|
Loading…
Reference in New Issue