Fix policy defaults (#113)

This commit is contained in:
Alexander Soare 2024-04-29 08:26:59 +01:00 committed by GitHub
parent 791506dfb8
commit ccffa9e406
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 2 deletions

View File

@ -77,7 +77,7 @@ class ActionChunkingTransformerConfig:
# Normalization / Unnormalization # Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field( input_normalization_modes: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": "mean_std", "observation.images.top": "mean_std",
"observation.state": "mean_std", "observation.state": "mean_std",
} }
) )

View File

@ -43,7 +43,7 @@ class DiffusionPolicy(nn.Module):
name = "diffusion" name = "diffusion"
def __init__( def __init__(
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
): ):
""" """
Args: Args:

View File

@ -5,6 +5,8 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
@ -113,6 +115,15 @@ def test_policy(env_name, policy_name, extra_overrides):
new_policy.load_state_dict(policy.state_dict()) new_policy.load_state_dict(policy.state_dict())
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy])
def test_policy_defaults(policy_cls):
kwargs = {}
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
if policy_cls is DiffusionPolicy:
kwargs = {"lr_scheduler_num_training_steps": 1}
policy_cls(**kwargs)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"insert_temporal_dim", "insert_temporal_dim",
[ [