Fix policy defaults (#113)
This commit is contained in:
parent
791506dfb8
commit
ccffa9e406
|
@ -77,7 +77,7 @@ class ActionChunkingTransformerConfig:
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
input_normalization_modes: dict[str, str] = field(
|
input_normalization_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.image": "mean_std",
|
"observation.images.top": "mean_std",
|
||||||
"observation.state": "mean_std",
|
"observation.state": "mean_std",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -43,7 +43,7 @@ class DiffusionPolicy(nn.Module):
|
||||||
name = "diffusion"
|
name = "diffusion"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -5,6 +5,8 @@ from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||||
|
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||||
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
|
@ -113,6 +115,15 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||||
new_policy.load_state_dict(policy.state_dict())
|
new_policy.load_state_dict(policy.state_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy])
|
||||||
|
def test_policy_defaults(policy_cls):
|
||||||
|
kwargs = {}
|
||||||
|
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
|
||||||
|
if policy_cls is DiffusionPolicy:
|
||||||
|
kwargs = {"lr_scheduler_num_training_steps": 1}
|
||||||
|
policy_cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"insert_temporal_dim",
|
"insert_temporal_dim",
|
||||||
[
|
[
|
||||||
|
|
Loading…
Reference in New Issue