From a1d16fb4009fc49a02c5f4fcbe1369189fbb61b1 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Tue, 11 Feb 2025 17:37:00 +0700 Subject: [PATCH] [Port HIL-SERL] Add resnet-10 as default encoder for HIL-SERL (#696) Co-authored-by: Khalil Meftah Co-authored-by: Adil Zouitine Co-authored-by: Michel Aractingi Co-authored-by: Ke Wang --- .../policies/hilserl/classifier/configuration_classifier.py | 2 +- lerobot/common/policies/sac/configuration_sac.py | 2 +- lerobot/configs/policy/hilserl_classifier.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py index de3742ec..fe7eb142 100644 --- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -10,7 +10,7 @@ class ClassifierConfig: num_classes: int = 2 hidden_dim: int = 256 dropout_rate: float = 0.1 - model_name: str = "microsoft/resnet-50" + model_name: str = "helper2424/resnet10" device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" num_cameras: int = 2 diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index bcca8976..e9d78fdd 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -61,7 +61,7 @@ class SACConfig: ) camera_number: int = 1 # Add type annotations for these fields: - vision_encoder_name: str | None = field(default="microsoft/resnet-18") + vision_encoder_name: str | None = field(default="helper2424/resnet10") freeze_vision_encoder: bool = True image_encoder_hidden_dim: int = 32 shared_encoder: bool = True diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml index a315902b..1a95f000 100644 --- a/lerobot/configs/policy/hilserl_classifier.yaml +++ b/lerobot/configs/policy/hilserl_classifier.yaml @@ -36,7 +36,7 @@ eval: policy: name: "hilserl/classifier/push_green_cube_hf_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_1" - model_name: "facebook/convnext-base-224" + model_name: "helper2424/resnet10" model_type: "cnn" num_cameras: 2 # Has to be len(training.image_keys)