Added gripper control mechanism to gym_manipulator

Moved the HILSerl env configs to lerobot/common/envs/configs.py
Fixes in actor_server, modeling_sac, and configuration_sac
Added the option to ignore keys missing from the env_cfg features_map in env_to_policy_features
Michel Aractingi 2025-03-28 08:21:36 +01:00 committed by AdilZouitine
parent 79e0f6e06c
commit 02b9ea9446
7 changed files with 179 additions and 130 deletions

View File

@ -14,10 +14,12 @@
import abc
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import draccus
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs.types import FeatureType, PolicyFeature
@ -159,20 +161,84 @@ class XarmEnv(EnvConfig):
@dataclass
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
@dataclass
class WrapperConfig:
"""Configuration for environment wrappers."""
delta_action: float | None = None
joint_masking_action_space: list[bool] | None = None
@dataclass
class EEActionSpaceConfig:
"""Configuration parameters for end-effector action space."""
x_step_size: float
y_step_size: float
z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
use_gamepad: bool = False
@dataclass
class EnvWrapperConfig:
"""Configuration for environment wrappers."""
display_cameras: bool = False
delta_action: float = 0.1
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_ee_pose_to_observation: bool = False
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
resize_size: Optional[Tuple[int, int]] = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
"""Configuration for the HILSerlRobotEnv environment."""
robot: Optional[RobotConfig] = None
wrapper: Optional[EnvWrapperConfig] = None
fps: int = 10
name: str = "real_robot"
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
task: str = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
reward_classifier: dict[str, str | None] = field(
default_factory=lambda: {
"pretrained_path": None,
"config_path": None,
}
)
def gym_kwargs(self) -> dict:
return {}
@EnvConfig.register_subclass("maniskill_push")
@dataclass
class ManiskillEnvConfig(EnvConfig):
"""Configuration for the ManiSkill environment."""
name: str = "maniskill/pushcube"
task: str = "PushCube-v1"
image_size: int = 64
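
For reference, a minimal sketch of how the relocated configs might be instantiated with the new use_gripper flag. Field values here are illustrative assumptions, not taken from this commit:

from lerobot.common.envs.configs import (
    EEActionSpaceConfig,
    EnvWrapperConfig,
    HILSerlRobotEnvConfig,
)

ee_space = EEActionSpaceConfig(
    x_step_size=0.01,
    y_step_size=0.01,
    z_step_size=0.01,
    bounds={"min": [-0.3, -0.3, 0.0], "max": [0.3, 0.3, 0.4]},  # illustrative bounds
    use_gamepad=True,
)
wrapper = EnvWrapperConfig(ee_action_space_params=ee_space, use_gripper=True)
cfg = HILSerlRobotEnvConfig(wrapper=wrapper, fps=10, task="pick_cube")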

View File

@ -49,9 +49,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, (
f"expect channel last images, but instead got {img.shape=}"
)
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@ -91,7 +89,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
else:
feature = ft
policy_key = env_cfg.features_map[key]
policy_key = env_cfg.features_map.get(key, key)
policy_features[policy_key] = feature
return policy_features
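
The switch from features_map[key] to features_map.get(key, key) means a key absent from the map now falls back to its own name instead of raising KeyError. A toy illustration (not the real EnvConfig):

features_map = {"agent_pos": "observation.state"}

mapped = features_map.get("agent_pos", "agent_pos")  # "observation.state" (mapped as before)
fallback = features_map.get("pixels", "pixels")      # "pixels" (missing key kept as-is)
# Previously features_map["pixels"] raised KeyError, forcing every env key to be listed.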

View File

@ -29,7 +29,6 @@ class ConcurrencyConfig:
learner: str = "threads"
@dataclass
class ActorLearnerConfig:
learner_host: str = "127.0.0.1"
@ -110,6 +109,7 @@ class SACConfig(PreTrainedConfig):
use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
grad_clip_norm: Gradient clipping norm for the SAC algorithm.
"""
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
@ -183,22 +183,12 @@ class SACConfig(PreTrainedConfig):
grad_clip_norm: float = 40.0
# Network configuration
critic_network_kwargs: CriticNetworkConfig = field(
default_factory=CriticNetworkConfig
)
actor_network_kwargs: ActorNetworkConfig = field(
default_factory=ActorNetworkConfig
)
policy_kwargs: PolicyConfig = field(
default_factory=PolicyConfig
)
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
actor_learner_config: ActorLearnerConfig = field(
default_factory=ActorLearnerConfig
)
concurrency: ConcurrencyConfig = field(
default_factory=ConcurrencyConfig
)
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
def __post_init__(self):
super().__post_init__()
@ -218,18 +208,20 @@ class SACConfig(PreTrainedConfig):
return None
def validate_features(self) -> None:
if "observation.image" not in self.input_features:
raise ValueError("You must provide 'observation.image' in the input features")
has_image = any(key.startswith("observation.image") for key in self.input_features)
has_state = "observation.state" in self.input_features
if "observation.state" not in self.input_features:
raise ValueError("You must provide 'observation.state' in the input features")
if not (has_state or has_image):
raise ValueError(
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
)
if "action" not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def image_features(self) -> list[str]:
return [key for key in self.input_features.keys() if 'image' in key]
return [key for key in self.input_features.keys() if "image" in key]
@property
def observation_delta_indices(self) -> list:
@ -243,9 +235,13 @@ class SACConfig(PreTrainedConfig):
def reward_delta_indices(self) -> None:
return None
if __name__ == "__main__":
import draccus
config = SACConfig()
draccus.set_config_type("json")
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
draccus.dump(
config=config,
stream=open(file="run_config.json", mode="w"),
)
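
The relaxed validate_features accepts a state-only or image-only observation set instead of requiring both. A standalone sketch of the new rule, using toy feature dicts rather than the real PolicyFeature objects:

def check_features(input_features: dict, output_features: dict) -> None:
    # Mirrors the relaxed rule: at least one of state / image observations is required.
    has_image = any(k.startswith("observation.image") for k in input_features)
    has_state = "observation.state" in input_features
    if not (has_state or has_image):
        raise ValueError("Provide 'observation.state' or an 'observation.image*' input feature")
    if "action" not in output_features:
        raise ValueError("You must provide 'action' in the output features")

check_features({"observation.state": None}, {"action": None})          # passes (state only)
check_features({"observation.images.front": None}, {"action": None})   # passes (image only)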

View File

@ -39,7 +39,6 @@ from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy(
PreTrainedPolicy,
):
config_class = SACConfig
name = "sac"
@ -53,9 +52,7 @@ class SACPolicy(
self.config = config
if config.dataset_stats is not None:
input_normalization_params = _convert_normalization_params_to_tensor(
config.dataset_stats
)
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
self.normalize_inputs = Normalize(
config.input_features,
config.normalization_mapping,
@ -64,9 +61,7 @@ class SACPolicy(
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor(
config.dataset_stats
)
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
@ -138,7 +133,6 @@ class SACPolicy(
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
self.temperature = self.log_alpha.exp().item()
def get_optim_params(self) -> dict:
return {
"actor": self.actor.parameters_to_optimize,
@ -655,9 +649,10 @@ class SACObservationEncoder(nn.Module):
class DefaultImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
super().__init__()
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_features["observation.image"].shape[0],
in_channels=config.input_features[image_key].shape[0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
@ -685,7 +680,9 @@ class DefaultImageEncoder(nn.Module):
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
# Get first image key from input features
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
@ -844,8 +841,10 @@ if __name__ == "__main__":
import draccus
from lerobot.configs import parser
@parser.wrap()
def main(config: SACConfig):
policy = SACPolicy(config=config)
print("yolo")
main()
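
The DefaultImageEncoder change above stops hardcoding "observation.image" and instead sizes its first conv layer from whichever observation.image* key appears first in the input features. A toy sketch of that key selection, with assumed feature shapes:

# Shapes are assumed for illustration, not taken from the commit.
input_features = {
    "observation.state": (6,),
    "observation.images.front": (3, 128, 128),
    "observation.images.wrist": (3, 128, 128),
}
image_key = next(k for k in input_features if k.startswith("observation.image"))
in_channels = input_features[image_key][0]  # 3: channels of the first camera stream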

View File

@ -28,7 +28,6 @@ from torch.multiprocessing import Event, Queue
# TODO: Remove the import of maniskill
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.robots.utils import Robot, make_robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
@ -268,7 +267,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
def act_with_policy(
cfg: TrainPipelineConfig,
robot: Robot,
# robot: Robot,
reward_classifier: nn.Module,
shutdown_event: any, # Event,
parameters_queue: Queue,
@ -503,7 +502,6 @@ def actor_cli(cfg: TrainPipelineConfig):
mp.set_start_method("spawn")
init_logging(log_file="actor.log")
robot = make_robot(robot_type=cfg.env.robot)
shutdown_event = setup_process_handlers(use_threads(cfg))
@ -563,18 +561,17 @@ def actor_cli(cfg: TrainPipelineConfig):
# HACK: FOR MANISKILL we do not have a reward classifier
# TODO: Remove this once we merge into main
reward_classifier = None
if (
cfg.env.reward_classifier["pretrained_path"] is not None
and cfg.env.reward_classifier["config_path"] is not None
):
reward_classifier = get_classifier(
pretrained_path=cfg.env.reward_classifier["pretrained_path"],
config_path=cfg.env.reward_classifier["config_path"],
)
# if (
# cfg.env.reward_classifier["pretrained_path"] is not None
# and cfg.env.reward_classifier["config_path"] is not None
# ):
# reward_classifier = get_classifier(
# pretrained_path=cfg.env.reward_classifier["pretrained_path"],
# config_path=cfg.env.reward_classifier["config_path"],
# )
act_with_policy(
cfg=cfg,
robot=robot,
reward_classifier=reward_classifier,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,

View File

@ -29,6 +29,9 @@ class InputController:
self.z_step_size = z_step_size
self.running = True
self.episode_end_status = None # None, "success", or "failure"
self.intervention_flag = False
self.open_gripper_command = False
self.close_gripper_command = False
def start(self):
"""Start the controller and initialize resources."""
@ -70,6 +73,19 @@ class InputController:
self.episode_end_status = None # Reset after reading
return status
def should_intervene(self):
"""Return True if intervention flag was set."""
return self.intervention_flag
def gripper_command(self):
"""Return the current gripper command."""
if self.open_gripper_command == self.close_gripper_command:
return "no-op"
elif self.open_gripper_command:
return "open"
elif self.close_gripper_command:
return "close"
class KeyboardController(InputController):
"""Generate motion deltas from keyboard input."""
@ -326,7 +342,6 @@ class GamepadControllerHID(InputController):
self.buttons = {}
self.quit_requested = False
self.save_requested = False
self.intervention_flag = False
def find_device(self):
"""Look for the gamepad device by vendor and product ID."""
@ -416,7 +431,13 @@ class GamepadControllerHID(InputController):
buttons = data[5]
# Check if RB is pressed then the intervention flag should be set
self.intervention_flag = data[6] == 2
self.intervention_flag = data[6] in [2, 6, 10, 14]
# Check if RT is pressed
self.open_gripper_command = data[6] in [8, 10, 12]
# Check if LT is pressed
self.close_gripper_command = data[6] in [4, 6, 12]
# Check if Y/Triangle button (bit 7) is pressed for saving
# Check if X/Square button (bit 5) is pressed for failure
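
The membership tests on data[6] read individual button bits of the HID report byte; assuming RB, LT, and RT map to bits 1, 2, and 3 (which is what the listed values imply), a bitwise equivalent would look like this sketch:

RB_MASK, LT_MASK, RT_MASK = 0x02, 0x04, 0x08  # assumed bit layout

def decode_buttons(button_byte: int) -> dict:
    return {
        "intervention": bool(button_byte & RB_MASK),   # RB held -> human takes over
        "open_gripper": bool(button_byte & RT_MASK),   # RT held -> open command
        "close_gripper": bool(button_byte & LT_MASK),  # LT held -> close command
    }

decode_buttons(10)  # RB + RT pressed: intervention with an open-gripper command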
@ -676,12 +697,8 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
if __name__ == "__main__":
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import (
EEActionSpaceConfig,
EnvWrapperConfig,
HILSerlRobotEnvConfig,
make_robot_env,
)
from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig
parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument(

View File

@ -1,9 +1,8 @@
import logging
import sys
import time
from dataclasses import dataclass
from threading import Lock
from typing import Annotated, Any, Dict, Optional, Tuple
from typing import Annotated, Any, Dict, Tuple
import gymnasium as gym
import numpy as np
@ -17,66 +16,13 @@ from lerobot.common.robot_devices.control_utils import (
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.utils.utils import log_say
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO)
@dataclass
class EEActionSpaceConfig:
"""Configuration parameters for end-effector action space."""
x_step_size: float
y_step_size: float
z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
use_gamepad: bool = False
@dataclass
class EnvWrapperConfig:
"""Configuration for environment wrappers."""
display_cameras: bool = False
delta_action: float = 0.1
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_ee_pose_to_observation: bool = False
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
resize_size: Optional[Tuple[int, int]] = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
reward_classifier_pretrained_path: Optional[str] = None
reward_classifier_config_file: Optional[str] = None
@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
"""Configuration for the HILSerlRobotEnv environment."""
robot: Optional[RobotConfig] = None
wrapper: Optional[EnvWrapperConfig] = None
fps: int = 10
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
task: str = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
def gym_kwargs(self) -> dict:
return {}
MAX_GRIPPER_COMMAND = 25
class HILSerlRobotEnv(gym.Env):
@ -813,9 +759,10 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
class EEActionWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None):
def __init__(self, env, ee_action_space_params=None, use_gripper=False):
super().__init__(env)
self.ee_action_space_params = ee_action_space_params
self.use_gripper = use_gripper
# Initialize kinematics instance for the appropriate robot type
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
@ -829,10 +776,12 @@ class EEActionWrapper(gym.ActionWrapper):
ee_action_space_params.z_step_size,
]
)
if self.use_gripper:
action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
ee_action_space = gym.spaces.Box(
low=-action_space_bounds,
high=action_space_bounds,
shape=(3,),
shape=(3 + int(self.use_gripper),),
dtype=np.float32,
)
if isinstance(self.action_space, gym.spaces.Tuple):
@ -848,6 +797,10 @@ class EEActionWrapper(gym.ActionWrapper):
if isinstance(action, tuple):
action, _ = action
if self.use_gripper:
gripper_command = action[-1]
action = action[:-1]
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
current_ee_pos = self.fk_function(current_joint_pos)
if isinstance(action, torch.Tensor):
@ -863,6 +816,12 @@ class EEActionWrapper(gym.ActionWrapper):
position_only=True,
fk_func=self.fk_function,
)
if self.use_gripper:
gripper_command = gripper_command * MAX_GRIPPER_COMMAND
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
target_joint_pos[-1] = gripper_action
return target_joint_pos, is_intervention
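
With use_gripper enabled, the last action entry acts as a delta in [-1, 1] that is scaled by MAX_GRIPPER_COMMAND, added to the current gripper position, and clipped. A numeric sketch with illustrative positions:

import numpy as np

MAX_GRIPPER_COMMAND = 25

# Illustrative numbers only: gripper currently at 5, policy asks for a +0.4 delta.
gripper_state = 5.0
gripper_delta = 0.4 * MAX_GRIPPER_COMMAND                                  # = 10.0
target = np.clip(gripper_state + gripper_delta, 0, MAX_GRIPPER_COMMAND)   # = 15.0

# A -1.0 delta from the same state saturates at the lower bound of the range:
np.clip(gripper_state - 1.0 * MAX_GRIPPER_COMMAND, 0, MAX_GRIPPER_COMMAND)  # = 0.0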
@ -912,6 +871,7 @@ class GamepadControlWrapper(gym.Wrapper):
x_step_size=1.0,
y_step_size=1.0,
z_step_size=1.0,
use_gripper=False,
auto_reset=False,
input_threshold=0.001,
):
@ -948,6 +908,7 @@ class GamepadControlWrapper(gym.Wrapper):
z_step_size=z_step_size,
)
self.auto_reset = auto_reset
self.use_gripper = use_gripper
self.input_threshold = input_threshold
self.controller.start()
@ -977,6 +938,15 @@ class GamepadControlWrapper(gym.Wrapper):
# Create action from gamepad input
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
if self.use_gripper:
gripper_command = self.controller.gripper_command()
if gripper_command == "open":
gamepad_action = np.concatenate([gamepad_action, [1.0]])
elif gripper_command == "close":
gamepad_action = np.concatenate([gamepad_action, [-1.0]])
else:
gamepad_action = np.concatenate([gamepad_action, [0.0]])
# Check episode ending buttons
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
episode_end_status = self.controller.get_episode_end_status()
@ -1023,6 +993,7 @@ class GamepadControlWrapper(gym.Wrapper):
final_action = (torch.from_numpy(gamepad_action), False)
else:
final_action = torch.from_numpy(gamepad_action)
else:
# Use the original action
final_action = action
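
Putting the gamepad pieces together, the 3-DoF delta is extended to a 4-vector whose last entry encodes the gripper command. A toy sketch of the assembly for one control step (delta values are made up):

import numpy as np

delta_x, delta_y, delta_z = 0.2, 0.0, -0.1
gripper_command = "open"  # as returned by controller.gripper_command()

gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
gripper_value = {"open": 1.0, "close": -1.0}.get(gripper_command, 0.0)
gamepad_action = np.concatenate([gamepad_action, [gripper_value]])
# -> [0.2, 0.0, -0.1, 1.0], consumed downstream by the EEActionWrapper above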
@ -1138,7 +1109,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
env = EEActionWrapper(
env=env,
ee_action_space_params=cfg.wrapper.ee_action_space_params,
use_gripper=cfg.wrapper.use_gripper,
)
if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
env = GamepadControlWrapper(
@ -1146,6 +1121,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
use_gripper=cfg.wrapper.use_gripper,
)
else:
env = KeyboardInterfaceWrapper(env=env)
@ -1184,7 +1160,7 @@ def get_classifier(cfg):
return model
def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
def record_dataset(env, policy, cfg):
"""
Record a dataset of robot interactions using either a policy or teleop.