Added gripper control mechanism to gym_manipulator

Moved HilSerl env config to configs/env/configs.py
Fixes in actor_server, modeling_sac, and configuration_sac
Added the possibility of ignoring missing keys in env_cfg in the get_features_from_env_config function
Michel Aractingi 2025-03-28 08:21:36 +01:00
parent 88cc2b8fc8
commit 05a237ce10
7 changed files with 179 additions and 130 deletions

View File

@@ -14,10 +14,12 @@
 import abc
 from dataclasses import dataclass, field
+from typing import Any, Dict, Optional, Tuple

 import draccus

 from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
+from lerobot.common.robot_devices.robots.configs import RobotConfig
 from lerobot.configs.types import FeatureType, PolicyFeature
@@ -159,20 +161,84 @@ class XarmEnv(EnvConfig):
 @dataclass
 class VideoRecordConfig:
     """Configuration for video recording in ManiSkill environments."""

     enabled: bool = False
     record_dir: str = "videos"
     trajectory_name: str = "trajectory"


 @dataclass
 class WrapperConfig:
     """Configuration for environment wrappers."""

     delta_action: float | None = None
     joint_masking_action_space: list[bool] | None = None


+@dataclass
+class EEActionSpaceConfig:
+    """Configuration parameters for end-effector action space."""
+
+    x_step_size: float
+    y_step_size: float
+    z_step_size: float
+    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
+    use_gamepad: bool = False
+
+
+@dataclass
+class EnvWrapperConfig:
+    """Configuration for environment wrappers."""
+
+    display_cameras: bool = False
+    delta_action: float = 0.1
+    use_relative_joint_positions: bool = True
+    add_joint_velocity_to_observation: bool = False
+    add_ee_pose_to_observation: bool = False
+    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
+    resize_size: Optional[Tuple[int, int]] = None
+    control_time_s: float = 20.0
+    fixed_reset_joint_positions: Optional[Any] = None
+    reset_time_s: float = 5.0
+    joint_masking_action_space: Optional[Any] = None
+    ee_action_space_params: Optional[EEActionSpaceConfig] = None
+    use_gripper: bool = False
+
+
+@EnvConfig.register_subclass(name="gym_manipulator")
+@dataclass
+class HILSerlRobotEnvConfig(EnvConfig):
+    """Configuration for the HILSerlRobotEnv environment."""
+
+    robot: Optional[RobotConfig] = None
+    wrapper: Optional[EnvWrapperConfig] = None
+    fps: int = 10
+    name: str = "real_robot"
+    mode: str = None  # Either "record", "replay", None
+    repo_id: Optional[str] = None
+    dataset_root: Optional[str] = None
+    task: str = ""
+    num_episodes: int = 10  # only for record mode
+    episode: int = 0
+    device: str = "cuda"
+    push_to_hub: bool = True
+    pretrained_policy_name_or_path: Optional[str] = None
+    reward_classifier: dict[str, str | None] = field(
+        default_factory=lambda: {
+            "pretrained_path": None,
+            "config_path": None,
+        }
+    )
+
+    def gym_kwargs(self) -> dict:
+        return {}
+
+
 @EnvConfig.register_subclass("maniskill_push")
 @dataclass
 class ManiskillEnvConfig(EnvConfig):
     """Configuration for the ManiSkill environment."""

     name: str = "maniskill/pushcube"
     task: str = "PushCube-v1"
     image_size: int = 64
@@ -185,7 +251,7 @@ class ManiskillEnvConfig(EnvConfig):
     render_mode: str = "rgb_array"
     render_size: int = 64
     device: str = "cuda"
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
     wrapper: WrapperConfig = field(default_factory=WrapperConfig)
     features: dict[str, PolicyFeature] = field(
@@ -218,4 +284,4 @@ class ManiskillEnvConfig(EnvConfig):
             "control_mode": self.control_mode,
             "sensor_configs": {"width": self.image_size, "height": self.image_size},
             "num_envs": 1,
         }
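Note: a minimal sketch of how the relocated config classes compose for a gripper-enabled real-robot setup. The step sizes, bounds, and task string below are placeholder values for illustration only, not taken from this commit.

    from lerobot.common.envs.configs import (
        EEActionSpaceConfig,
        EnvWrapperConfig,
        HILSerlRobotEnvConfig,
    )

    # Placeholder end-effector limits, purely illustrative.
    ee_space = EEActionSpaceConfig(
        x_step_size=0.01,
        y_step_size=0.01,
        z_step_size=0.01,
        bounds={"min": [-0.2, -0.2, 0.0], "max": [0.2, 0.2, 0.3]},
        use_gamepad=True,
    )
    wrapper = EnvWrapperConfig(ee_action_space_params=ee_space, use_gripper=True)
    cfg = HILSerlRobotEnvConfig(wrapper=wrapper, fps=10, task="pick_cube")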

View File

@@ -53,9 +53,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
         # sanity check that images are channel last
         _, h, w, c = img.shape
-        assert c < h and c < w, (
-            f"expect channel last images, but instead got {img.shape=}"
-        )
+        assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

         # sanity check that images are uint8
         assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@@ -95,7 +93,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
         else:
             feature = ft

-        policy_key = env_cfg.features_map[key]
+        policy_key = env_cfg.features_map.get(key, key)
         policy_features[policy_key] = feature

     return policy_features
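Note: the switch from features_map[key] to features_map.get(key, key) is what implements the "ignoring missing keys" fix from the commit message; a key without a mapping now falls back to itself instead of raising KeyError. A tiny illustration with placeholder keys:

    features_map = {"observation.images.cam": "observation.image"}

    print(features_map.get("observation.images.cam", "observation.images.cam"))  # -> "observation.image"
    print(features_map.get("observation.state", "observation.state"))  # missing entry falls back to "observation.state"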

View File

@@ -20,7 +20,7 @@ from typing import Any, Optional

 from lerobot.common.optim.optimizers import MultiAdamConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType


 @dataclass
@@ -29,7 +29,6 @@ class ConcurrencyConfig:
     learner: str = "threads"


 @dataclass
 class ActorLearnerConfig:
     learner_host: str = "127.0.0.1"
@@ -110,6 +109,7 @@ class SACConfig(PreTrainedConfig):
         use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
         grad_clip_norm: Gradient clipping norm for the SAC algorithm.
     """

     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
             "VISUAL": NormalizationMode.MEAN_STD,
@@ -152,8 +152,8 @@ class SACConfig(PreTrainedConfig):
     camera_number: int = 1
     device: str = "cuda"
     storage_device: str = "cpu"
     # Set to "helper2424/resnet10" for hil serl
     vision_encoder_name: str | None = None
     freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
@@ -163,7 +163,7 @@ class SACConfig(PreTrainedConfig):
     online_env_seed: int = 10000
     online_buffer_capacity: int = 100000
     offline_buffer_capacity: int = 100000
     online_step_before_learning: int = 100
     policy_update_freq: int = 1

     # SAC algorithm parameters
@@ -181,24 +181,14 @@ class SACConfig(PreTrainedConfig):
     target_entropy: float | None = None
     use_backup_entropy: bool = True
     grad_clip_norm: float = 40.0

     # Network configuration
-    critic_network_kwargs: CriticNetworkConfig = field(
-        default_factory=CriticNetworkConfig
-    )
-    actor_network_kwargs: ActorNetworkConfig = field(
-        default_factory=ActorNetworkConfig
-    )
-    policy_kwargs: PolicyConfig = field(
-        default_factory=PolicyConfig
-    )
-    actor_learner_config: ActorLearnerConfig = field(
-        default_factory=ActorLearnerConfig
-    )
-    concurrency: ConcurrencyConfig = field(
-        default_factory=ConcurrencyConfig
-    )
+    critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
+    actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
+    policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
+
+    actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
+    concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

     def __post_init__(self):
         super().__post_init__()
@@ -218,18 +208,20 @@ class SACConfig(PreTrainedConfig):
         return None

     def validate_features(self) -> None:
-        if "observation.image" not in self.input_features:
-            raise ValueError("You must provide 'observation.image' in the input features")
-
-        if "observation.state" not in self.input_features:
-            raise ValueError("You must provide 'observation.state' in the input features")
+        has_image = any(key.startswith("observation.image") for key in self.input_features)
+        has_state = "observation.state" in self.input_features
+
+        if not (has_state or has_image):
+            raise ValueError(
+                "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
+            )

         if "action" not in self.output_features:
             raise ValueError("You must provide 'action' in the output features")

     @property
     def image_features(self) -> list[str]:
-        return [key for key in self.input_features.keys() if 'image' in key]
+        return [key for key in self.input_features.keys() if "image" in key]

     @property
     def observation_delta_indices(self) -> list:
@@ -243,9 +235,13 @@ class SACConfig(PreTrainedConfig):
     def reward_delta_indices(self) -> None:
         return None


 if __name__ == "__main__":
     import draccus

     config = SACConfig()
     draccus.set_config_type("json")
-    draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
+    draccus.dump(
+        config=config,
+        stream=open(file="run_config.json", mode="w"),
+    )
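Note: validate_features now accepts either a state vector or any image-prefixed input, instead of requiring both "observation.image" and "observation.state". A standalone sketch of the relaxed predicate (the feature keys below are illustrative):

    def is_valid(input_features: dict) -> bool:
        has_image = any(key.startswith("observation.image") for key in input_features)
        has_state = "observation.state" in input_features
        return has_state or has_image

    print(is_valid({"observation.state": None}))                # True: state-only input
    print(is_valid({"observation.images.side": None}))          # True: image-only input
    print(is_valid({"observation.environment_state": None}))    # False: neither is present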

View File

@@ -39,7 +39,6 @@ from lerobot.common.policies.utils import get_device_from_parameters
 class SACPolicy(
     PreTrainedPolicy,
 ):
     config_class = SACConfig
     name = "sac"
@@ -53,9 +52,7 @@ class SACPolicy(
         self.config = config

         if config.dataset_stats is not None:
-            input_normalization_params = _convert_normalization_params_to_tensor(
-                config.dataset_stats
-            )
+            input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
             self.normalize_inputs = Normalize(
                 config.input_features,
                 config.normalization_mapping,
@@ -64,12 +61,10 @@ class SACPolicy(
         else:
             self.normalize_inputs = nn.Identity()

-        output_normalization_params = _convert_normalization_params_to_tensor(
-            config.dataset_stats
-        )
+        output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)

         # HACK: This is hacky and should be removed
         dataset_stats = dataset_stats or output_normalization_params
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
@@ -138,7 +133,6 @@ class SACPolicy(
         self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
         self.temperature = self.log_alpha.exp().item()

     def get_optim_params(self) -> dict:
         return {
             "actor": self.actor.parameters_to_optimize,
@@ -655,9 +649,10 @@ class SACObservationEncoder(nn.Module):
 class DefaultImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
+        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
         self.image_enc_layers = nn.Sequential(
             nn.Conv2d(
-                in_channels=config.input_features["observation.image"].shape[0],
+                in_channels=config.input_features[image_key].shape[0],
                 out_channels=config.image_encoder_hidden_dim,
                 kernel_size=7,
                 stride=2,
@@ -685,7 +680,9 @@ class DefaultImageEncoder(nn.Module):
             ),
             nn.ReLU(),
         )
-        dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
+        # Get first image key from input features
+        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
+        dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
         with torch.inference_mode():
             self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
         self.image_enc_layers.extend(
@@ -844,8 +841,10 @@
 if __name__ == "__main__":
     import draccus

     from lerobot.configs import parser

     @parser.wrap()
     def main(config: SACConfig):
         policy = SACPolicy(config=config)
         print("yolo")

     main()
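Note: the encoder no longer hard-codes "observation.image"; it resolves the first input-feature key with that prefix, so configs that only expose namespaced camera keys still work. A small sketch of the lookup (the feature names here are illustrative):

    input_features = {
        "observation.state": None,
        "observation.images.front": None,
        "observation.images.wrist": None,
    }
    image_key = next(key for key in input_features if key.startswith("observation.image"))
    print(image_key)  # "observation.images.front": the first matching key wins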

View File

@@ -28,7 +28,6 @@ from torch.multiprocessing import Event, Queue
 # TODO: Remove the import of maniskill
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
-from lerobot.common.robot_devices.robots.utils import Robot, make_robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.utils import (
@@ -268,7 +267,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
 def act_with_policy(
     cfg: TrainPipelineConfig,
-    robot: Robot,
+    # robot: Robot,
     reward_classifier: nn.Module,
     shutdown_event: any,  # Event,
     parameters_queue: Queue,
@@ -287,7 +286,7 @@ def act_with_policy(
     logging.info("make_env online")

-    online_env = make_robot_env( cfg=cfg.env)
+    online_env = make_robot_env(cfg=cfg.env)

     set_seed(cfg.seed)
     device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -503,7 +502,6 @@ def actor_cli(cfg: TrainPipelineConfig):
     mp.set_start_method("spawn")
     init_logging(log_file="actor.log")

-    robot = make_robot(robot_type=cfg.env.robot)

     shutdown_event = setup_process_handlers(use_threads(cfg))
@@ -563,18 +561,17 @@ def actor_cli(cfg: TrainPipelineConfig):
     # HACK: FOR MANISKILL we do not have a reward classifier
     # TODO: Remove this once we merge into main
     reward_classifier = None
-    if (
-        cfg.env.reward_classifier["pretrained_path"] is not None
-        and cfg.env.reward_classifier["config_path"] is not None
-    ):
-        reward_classifier = get_classifier(
-            pretrained_path=cfg.env.reward_classifier["pretrained_path"],
-            config_path=cfg.env.reward_classifier["config_path"],
-        )
+    # if (
+    #     cfg.env.reward_classifier["pretrained_path"] is not None
+    #     and cfg.env.reward_classifier["config_path"] is not None
+    # ):
+    #     reward_classifier = get_classifier(
+    #         pretrained_path=cfg.env.reward_classifier["pretrained_path"],
+    #         config_path=cfg.env.reward_classifier["config_path"],
+    #     )

     act_with_policy(
         cfg=cfg,
-        robot=robot,
         reward_classifier=reward_classifier,
         shutdown_event=shutdown_event,
         parameters_queue=parameters_queue,

View File

@@ -29,6 +29,9 @@ class InputController:
         self.z_step_size = z_step_size
         self.running = True
         self.episode_end_status = None  # None, "success", or "failure"
+        self.intervention_flag = False
+        self.open_gripper_command = False
+        self.close_gripper_command = False

     def start(self):
         """Start the controller and initialize resources."""
@@ -70,6 +73,19 @@ class InputController:
         self.episode_end_status = None  # Reset after reading
         return status

+    def should_intervene(self):
+        """Return True if intervention flag was set."""
+        return self.intervention_flag
+
+    def gripper_command(self):
+        """Return the current gripper command."""
+        if self.open_gripper_command == self.close_gripper_command:
+            return "no-op"
+        elif self.open_gripper_command:
+            return "open"
+        elif self.close_gripper_command:
+            return "close"
+

 class KeyboardController(InputController):
     """Generate motion deltas from keyboard input."""
@@ -326,7 +342,6 @@ class GamepadControllerHID(InputController):
         self.buttons = {}
        self.quit_requested = False
         self.save_requested = False
-        self.intervention_flag = False

     def find_device(self):
         """Look for the gamepad device by vendor and product ID."""
@@ -416,7 +431,13 @@ class GamepadControllerHID(InputController):
                 buttons = data[5]

                 # Check if RB is pressed then the intervention flag should be set
-                self.intervention_flag = data[6] == 2
+                self.intervention_flag = data[6] in [2, 6, 10, 14]
+
+                # Check if RT is pressed
+                self.open_gripper_command = data[6] in [8, 10, 12]
+
+                # Check if LT is pressed
+                self.close_gripper_command = data[6] in [4, 6, 12]

                 # Check if Y/Triangle button (bit 7) is pressed for saving
                 # Check if X/Square button (bit 5) is pressed for failure
@@ -676,12 +697,8 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
 if __name__ == "__main__":
     from lerobot.common.robot_devices.robots.configs import RobotConfig
     from lerobot.common.robot_devices.robots.utils import make_robot_from_config
-    from lerobot.scripts.server.gym_manipulator import (
-        EEActionSpaceConfig,
-        EnvWrapperConfig,
-        HILSerlRobotEnvConfig,
-        make_robot_env,
-    )
+    from lerobot.scripts.server.gym_manipulator import make_robot_env
+    from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig

     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(
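Note: with the gamepad byte decoded into intervention and trigger flags, InputController.gripper_command() collapses the two booleans into a three-way command; when both or neither trigger is held, it reports a no-op. A standalone sketch of that resolution (resolve_gripper_command is a hypothetical free-function version of the method above):

    def resolve_gripper_command(open_pressed: bool, close_pressed: bool) -> str:
        # Both pressed or both released cancel out.
        if open_pressed == close_pressed:
            return "no-op"
        return "open" if open_pressed else "close"

    print(resolve_gripper_command(False, False))  # "no-op"
    print(resolve_gripper_command(True, False))   # "open"
    print(resolve_gripper_command(True, True))    # "no-op": conflicting inputs cancel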

View File

@@ -1,9 +1,8 @@
 import logging
 import sys
 import time
-from dataclasses import dataclass
 from threading import Lock
-from typing import Annotated, Any, Dict, Optional, Tuple
+from typing import Annotated, Any, Dict, Tuple

 import gymnasium as gym
 import numpy as np
@@ -17,66 +16,13 @@ from lerobot.common.robot_devices.control_utils import (
     is_headless,
     reset_follower_position,
 )
-from lerobot.common.robot_devices.robots.configs import RobotConfig
 from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.common.utils.utils import log_say
 from lerobot.configs import parser
 from lerobot.scripts.server.kinematics import RobotKinematics

 logging.basicConfig(level=logging.INFO)

+MAX_GRIPPER_COMMAND = 25
-
-@dataclass
-class EEActionSpaceConfig:
-    """Configuration parameters for end-effector action space."""
-
-    x_step_size: float
-    y_step_size: float
-    z_step_size: float
-    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
-    use_gamepad: bool = False
-
-
-@dataclass
-class EnvWrapperConfig:
-    """Configuration for environment wrappers."""
-
-    display_cameras: bool = False
-    delta_action: float = 0.1
-    use_relative_joint_positions: bool = True
-    add_joint_velocity_to_observation: bool = False
-    add_ee_pose_to_observation: bool = False
-    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
-    resize_size: Optional[Tuple[int, int]] = None
-    control_time_s: float = 20.0
-    fixed_reset_joint_positions: Optional[Any] = None
-    reset_time_s: float = 5.0
-    joint_masking_action_space: Optional[Any] = None
-    ee_action_space_params: Optional[EEActionSpaceConfig] = None
-    reward_classifier_pretrained_path: Optional[str] = None
-    reward_classifier_config_file: Optional[str] = None
-
-
-@EnvConfig.register_subclass(name="gym_manipulator")
-@dataclass
-class HILSerlRobotEnvConfig(EnvConfig):
-    """Configuration for the HILSerlRobotEnv environment."""
-
-    robot: Optional[RobotConfig] = None
-    wrapper: Optional[EnvWrapperConfig] = None
-    fps: int = 10
-    mode: str = None  # Either "record", "replay", None
-    repo_id: Optional[str] = None
-    dataset_root: Optional[str] = None
-    task: str = ""
-    num_episodes: int = 10  # only for record mode
-    episode: int = 0
-    device: str = "cuda"
-    push_to_hub: bool = True
-    pretrained_policy_name_or_path: Optional[str] = None
-
-    def gym_kwargs(self) -> dict:
-        return {}


 class HILSerlRobotEnv(gym.Env):
@@ -813,9 +759,10 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
 class EEActionWrapper(gym.ActionWrapper):
-    def __init__(self, env, ee_action_space_params=None):
+    def __init__(self, env, ee_action_space_params=None, use_gripper=False):
         super().__init__(env)
         self.ee_action_space_params = ee_action_space_params
+        self.use_gripper = use_gripper

         # Initialize kinematics instance for the appropriate robot type
         robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
@@ -829,10 +776,12 @@ class EEActionWrapper(gym.ActionWrapper):
                 ee_action_space_params.z_step_size,
             ]
         )
+        if self.use_gripper:
+            action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
         ee_action_space = gym.spaces.Box(
             low=-action_space_bounds,
             high=action_space_bounds,
-            shape=(3,),
+            shape=(3 + int(self.use_gripper),),
             dtype=np.float32,
         )
         if isinstance(self.action_space, gym.spaces.Tuple):
@@ -848,6 +797,10 @@ class EEActionWrapper(gym.ActionWrapper):
         if isinstance(action, tuple):
             action, _ = action

+        if self.use_gripper:
+            gripper_command = action[-1]
+            action = action[:-1]
+
         current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
         current_ee_pos = self.fk_function(current_joint_pos)
         if isinstance(action, torch.Tensor):
@@ -863,6 +816,12 @@ class EEActionWrapper(gym.ActionWrapper):
             position_only=True,
             fk_func=self.fk_function,
         )
+        if self.use_gripper:
+            gripper_command = gripper_command * MAX_GRIPPER_COMMAND
+            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
+            target_joint_pos[-1] = gripper_action
+
         return target_joint_pos, is_intervention
@@ -912,6 +871,7 @@ class GamepadControlWrapper(gym.Wrapper):
         x_step_size=1.0,
         y_step_size=1.0,
         z_step_size=1.0,
+        use_gripper=False,
         auto_reset=False,
         input_threshold=0.001,
     ):
@@ -948,6 +908,7 @@ class GamepadControlWrapper(gym.Wrapper):
                 z_step_size=z_step_size,
             )
         self.auto_reset = auto_reset
+        self.use_gripper = use_gripper
         self.input_threshold = input_threshold
         self.controller.start()
@@ -977,6 +938,15 @@ class GamepadControlWrapper(gym.Wrapper):
         # Create action from gamepad input
         gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)

+        if self.use_gripper:
+            gripper_command = self.controller.gripper_command()
+            if gripper_command == "open":
+                gamepad_action = np.concatenate([gamepad_action, [1.0]])
+            elif gripper_command == "close":
+                gamepad_action = np.concatenate([gamepad_action, [-1.0]])
+            else:
+                gamepad_action = np.concatenate([gamepad_action, [0.0]])
+
         # Check episode ending buttons
         # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
         episode_end_status = self.controller.get_episode_end_status()
@@ -1023,6 +993,7 @@ class GamepadControlWrapper(gym.Wrapper):
                 final_action = (torch.from_numpy(gamepad_action), False)
             else:
                 final_action = torch.from_numpy(gamepad_action)

         else:
             # Use the original action
             final_action = action
@@ -1138,7 +1109,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     if cfg.wrapper.ee_action_space_params is not None:
-        env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
+        env = EEActionWrapper(
+            env=env,
+            ee_action_space_params=cfg.wrapper.ee_action_space_params,
+            use_gripper=cfg.wrapper.use_gripper,
+        )
     if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
         # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
@@ -1146,6 +1121,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
             x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
             y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
             z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
+            use_gripper=cfg.wrapper.use_gripper,
         )
     else:
         env = KeyboardInterfaceWrapper(env=env)
@@ -1184,7 +1160,7 @@ def get_classifier(cfg):
     return model


-def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
+def record_dataset(env, policy, cfg):
     """
     Record a dataset of robot interactions using either a policy or teleop.
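Note: in EEActionWrapper, the last action component is interpreted as a gripper delta in [-1, 1] (as produced by the gamepad wrapper), scaled by MAX_GRIPPER_COMMAND, added to the current gripper position, and clipped to the valid range before being written as the last joint target. A minimal sketch of that computation; gripper_target is a hypothetical helper and current_gripper_pos stands in for the value read from the follower arm:

    import numpy as np

    MAX_GRIPPER_COMMAND = 25

    def gripper_target(current_gripper_pos: float, gripper_command: float) -> float:
        # gripper_command comes from the gamepad wrapper as -1.0 (close), 0.0 (no-op), or 1.0 (open).
        delta = gripper_command * MAX_GRIPPER_COMMAND
        return float(np.clip(current_gripper_pos + delta, 0, MAX_GRIPPER_COMMAND))

    print(gripper_target(10.0, 1.0))   # 25.0: fully-open request is clipped to the max
    print(gripper_target(10.0, -1.0))  # 0.0: fully-close request is clipped to the min
    print(gripper_target(10.0, 0.0))   # 10.0: no-op keeps the current position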