Added gripper control mechanism to gym_manipulator
Moved the HILSerl env config to configs/env/configs.py; fixed actor_server, modeling_sac, and configuration_sac accordingly; and added the option to ignore missing keys in env_cfg in the get_features_from_env_config function.
This commit is contained in:
parent 79e0f6e06c
commit 02b9ea9446
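For orientation, a minimal sketch of how the relocated config classes and the new use_gripper flag fit together. The step sizes, bounds, and fps below are placeholder values chosen for illustration, not values taken from this commit, and a real RobotConfig must be supplied before the env can actually be built:

from lerobot.common.envs.configs import (
    EEActionSpaceConfig,
    EnvWrapperConfig,
    HILSerlRobotEnvConfig,
)

# Placeholder step sizes and workspace bounds; real values depend on the robot setup.
ee_params = EEActionSpaceConfig(
    x_step_size=0.01,
    y_step_size=0.01,
    z_step_size=0.01,
    bounds={"min": [-0.2, -0.2, 0.0], "max": [0.2, 0.2, 0.3]},
    use_gamepad=True,
)
wrapper = EnvWrapperConfig(ee_action_space_params=ee_params, use_gripper=True)
cfg = HILSerlRobotEnvConfig(robot=None, wrapper=wrapper, fps=10)  # robot: set a real RobotConfig

# With a robot configured, the env is built the same way actor_server does it:
# from lerobot.scripts.server.gym_manipulator import make_robot_env
# env = make_robot_env(cfg=cfg)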
@@ -14,10 +14,12 @@
import abc
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import draccus

from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs.types import FeatureType, PolicyFeature

@@ -159,20 +161,84 @@ class XarmEnv(EnvConfig):
@dataclass
class VideoRecordConfig:
    """Configuration for video recording in ManiSkill environments."""

    enabled: bool = False
    record_dir: str = "videos"
    trajectory_name: str = "trajectory"


@dataclass
class WrapperConfig:
    """Configuration for environment wrappers."""

    delta_action: float | None = None
    joint_masking_action_space: list[bool] | None = None


@dataclass
class EEActionSpaceConfig:
    """Configuration parameters for end-effector action space."""

    x_step_size: float
    y_step_size: float
    z_step_size: float
    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
    use_gamepad: bool = False


@dataclass
class EnvWrapperConfig:
    """Configuration for environment wrappers."""

    display_cameras: bool = False
    delta_action: float = 0.1
    use_relative_joint_positions: bool = True
    add_joint_velocity_to_observation: bool = False
    add_ee_pose_to_observation: bool = False
    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
    resize_size: Optional[Tuple[int, int]] = None
    control_time_s: float = 20.0
    fixed_reset_joint_positions: Optional[Any] = None
    reset_time_s: float = 5.0
    joint_masking_action_space: Optional[Any] = None
    ee_action_space_params: Optional[EEActionSpaceConfig] = None
    use_gripper: bool = False


@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
    """Configuration for the HILSerlRobotEnv environment."""

    robot: Optional[RobotConfig] = None
    wrapper: Optional[EnvWrapperConfig] = None
    fps: int = 10
    name: str = "real_robot"
    mode: str = None  # Either "record", "replay", None
    repo_id: Optional[str] = None
    dataset_root: Optional[str] = None
    task: str = ""
    num_episodes: int = 10  # only for record mode
    episode: int = 0
    device: str = "cuda"
    push_to_hub: bool = True
    pretrained_policy_name_or_path: Optional[str] = None
    reward_classifier: dict[str, str | None] = field(
        default_factory=lambda: {
            "pretrained_path": None,
            "config_path": None,
        }
    )

    def gym_kwargs(self) -> dict:
        return {}


@EnvConfig.register_subclass("maniskill_push")
@dataclass
class ManiskillEnvConfig(EnvConfig):
    """Configuration for the ManiSkill environment."""

    name: str = "maniskill/pushcube"
    task: str = "PushCube-v1"
    image_size: int = 64

@@ -185,7 +251,7 @@ class ManiskillEnvConfig(EnvConfig):
    render_mode: str = "rgb_array"
    render_size: int = 64
    device: str = "cuda"
    robot: str = "so100"  # This is a hack to make the robot config work
    robot: str = "so100"  # This is a hack to make the robot config work
    video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
    wrapper: WrapperConfig = field(default_factory=WrapperConfig)
    features: dict[str, PolicyFeature] = field(

@@ -218,4 +284,4 @@ class ManiskillEnvConfig(EnvConfig):
            "control_mode": self.control_mode,
            "sensor_configs": {"width": self.image_size, "height": self.image_size},
            "num_envs": 1,
        }
        }
@@ -49,9 +49,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
        # sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, (
            f"expect channel last images, but instead got {img.shape=}"
        )
        assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

        # sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

@@ -91,7 +89,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
        else:
            feature = ft

        policy_key = env_cfg.features_map[key]
        policy_key = env_cfg.features_map.get(key, key)
        policy_features[policy_key] = feature

    return policy_features
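The switch from features_map[key] to features_map.get(key, key) is what the commit message calls ignoring missing keys: a feature whose key has no explicit mapping now passes through under its own name instead of raising KeyError. A standalone illustration of the new lookup (the keys below are hypothetical):

features_map = {"agent_pos": "observation.state"}

# Old behavior: features_map["pixels"] raises KeyError.
# New behavior: the unmapped key falls through unchanged.
print(features_map.get("agent_pos", "agent_pos"))  # -> observation.state
print(features_map.get("pixels", "pixels"))        # -> pixels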
@@ -20,7 +20,7 @@ from typing import Any, Optional
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType


@dataclass

@@ -29,7 +29,6 @@ class ConcurrencyConfig:
    learner: str = "threads"


@dataclass
class ActorLearnerConfig:
    learner_host: str = "127.0.0.1"

@@ -110,6 +109,7 @@ class SACConfig(PreTrainedConfig):
        use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
        grad_clip_norm: Gradient clipping norm for the SAC algorithm.
    """

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,

@@ -152,8 +152,8 @@ class SACConfig(PreTrainedConfig):
    camera_number: int = 1
    device: str = "cuda"
    storage_device: str = "cpu"
    # Set to "helper2424/resnet10" for hil serl
    vision_encoder_name: str | None = None
    # Set to "helper2424/resnet10" for hil serl
    vision_encoder_name: str | None = None
    freeze_vision_encoder: bool = True
    image_encoder_hidden_dim: int = 32
    shared_encoder: bool = True

@@ -163,7 +163,7 @@ class SACConfig(PreTrainedConfig):
    online_env_seed: int = 10000
    online_buffer_capacity: int = 100000
    offline_buffer_capacity: int = 100000
    online_step_before_learning: int = 100
    online_step_before_learning: int = 100
    policy_update_freq: int = 1

    # SAC algorithm parameters

@@ -181,24 +181,14 @@ class SACConfig(PreTrainedConfig):
    target_entropy: float | None = None
    use_backup_entropy: bool = True
    grad_clip_norm: float = 40.0

    # Network configuration
    critic_network_kwargs: CriticNetworkConfig = field(
        default_factory=CriticNetworkConfig
    )
    actor_network_kwargs: ActorNetworkConfig = field(
        default_factory=ActorNetworkConfig
    )
    policy_kwargs: PolicyConfig = field(
        default_factory=PolicyConfig
    )

    actor_learner_config: ActorLearnerConfig = field(
        default_factory=ActorLearnerConfig
    )
    concurrency: ConcurrencyConfig = field(
        default_factory=ConcurrencyConfig
    )
    critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
    actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
    policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)

    actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
    concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

    def __post_init__(self):
        super().__post_init__()

@@ -218,18 +208,20 @@ class SACConfig(PreTrainedConfig):
        return None

    def validate_features(self) -> None:
        if "observation.image" not in self.input_features:
            raise ValueError("You must provide 'observation.image' in the input features")

        if "observation.state" not in self.input_features:
            raise ValueError("You must provide 'observation.state' in the input features")

        has_image = any(key.startswith("observation.image") for key in self.input_features)
        has_state = "observation.state" in self.input_features

        if not (has_state or has_image):
            raise ValueError(
                "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
            )

        if "action" not in self.output_features:
            raise ValueError("You must provide 'action' in the output features")

    @property
    def image_features(self) -> list[str]:
        return [key for key in self.input_features.keys() if 'image' in key]
        return [key for key in self.input_features.keys() if "image" in key]

    @property
    def observation_delta_indices(self) -> list:
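The validation change replaces the hard requirement for both observation.image and observation.state with an either/or check, so a state-only setup (no camera) or an image-only setup no longer fails here. A minimal reproduction of the new predicate, using hypothetical feature dicts:

def passes_new_check(input_features: dict) -> bool:
    # Mirrors the new validate_features logic: at least one of state / image must be present.
    has_image = any(key.startswith("observation.image") for key in input_features)
    has_state = "observation.state" in input_features
    return has_state or has_image

print(passes_new_check({"observation.state": "..."}))         # True (state only)
print(passes_new_check({"observation.images.front": "..."}))  # True (image only)
print(passes_new_check({}))                                    # False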
@@ -243,9 +235,13 @@ class SACConfig(PreTrainedConfig):
    def reward_delta_indices(self) -> None:
        return None


if __name__ == "__main__":
    import draccus

    config = SACConfig()
    draccus.set_config_type("json")
    draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )

    draccus.dump(
        config=config,
        stream=open(file="run_config.json", mode="w"),
    )
@@ -39,7 +39,6 @@ from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy(
    PreTrainedPolicy,
):

    config_class = SACConfig
    name = "sac"

@@ -53,9 +52,7 @@ class SACPolicy(
        self.config = config

        if config.dataset_stats is not None:
            input_normalization_params = _convert_normalization_params_to_tensor(
                config.dataset_stats
            )
            input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
            self.normalize_inputs = Normalize(
                config.input_features,
                config.normalization_mapping,

@@ -64,12 +61,10 @@ class SACPolicy(
        else:
            self.normalize_inputs = nn.Identity()

        output_normalization_params = _convert_normalization_params_to_tensor(
            config.dataset_stats
        )
        output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)

        # HACK: This is hacky and should be removed
        dataset_stats = dataset_stats or output_normalization_params
        dataset_stats = dataset_stats or output_normalization_params
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

@@ -138,7 +133,6 @@ class SACPolicy(
        self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
        self.temperature = self.log_alpha.exp().item()

    def get_optim_params(self) -> dict:
        return {
            "actor": self.actor.parameters_to_optimize,

@@ -655,9 +649,10 @@ class SACObservationEncoder(nn.Module):
class DefaultImageEncoder(nn.Module):
    def __init__(self, config: SACConfig):
        super().__init__()
        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
        self.image_enc_layers = nn.Sequential(
            nn.Conv2d(
                in_channels=config.input_features["observation.image"].shape[0],
                in_channels=config.input_features[image_key].shape[0],
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=7,
                stride=2,

@@ -685,7 +680,9 @@ class DefaultImageEncoder(nn.Module):
            ),
            nn.ReLU(),
        )
        dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
        # Get first image key from input features
        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
        dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
        with torch.inference_mode():
            self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
        self.image_enc_layers.extend(

@@ -844,8 +841,10 @@ if __name__ == "__main__":
    import draccus

    from lerobot.configs import parser

    @parser.wrap()
    def main(config: SACConfig):
        policy = SACPolicy(config=config)
        print("yolo")
    main()

    main()
@@ -28,7 +28,6 @@ from torch.multiprocessing import Event, Queue
# TODO: Remove the import of maniskill
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.robots.utils import Robot, make_robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (

@@ -268,7 +267,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
def act_with_policy(
    cfg: TrainPipelineConfig,
    robot: Robot,
    # robot: Robot,
    reward_classifier: nn.Module,
    shutdown_event: any,  # Event,
    parameters_queue: Queue,

@@ -287,7 +286,7 @@ def act_with_policy(
    logging.info("make_env online")

    online_env = make_robot_env( cfg=cfg.env)
    online_env = make_robot_env(cfg=cfg.env)

    set_seed(cfg.seed)
    device = get_safe_torch_device(cfg.policy.device, log=True)

@@ -503,7 +502,6 @@ def actor_cli(cfg: TrainPipelineConfig):
    mp.set_start_method("spawn")

    init_logging(log_file="actor.log")
    robot = make_robot(robot_type=cfg.env.robot)

    shutdown_event = setup_process_handlers(use_threads(cfg))

@@ -563,18 +561,17 @@ def actor_cli(cfg: TrainPipelineConfig):
    # HACK: FOR MANISKILL we do not have a reward classifier
    # TODO: Remove this once we merge into main
    reward_classifier = None
    if (
        cfg.env.reward_classifier["pretrained_path"] is not None
        and cfg.env.reward_classifier["config_path"] is not None
    ):
        reward_classifier = get_classifier(
            pretrained_path=cfg.env.reward_classifier["pretrained_path"],
            config_path=cfg.env.reward_classifier["config_path"],
        )
    # if (
    # cfg.env.reward_classifier["pretrained_path"] is not None
    # and cfg.env.reward_classifier["config_path"] is not None
    # ):
    # reward_classifier = get_classifier(
    # pretrained_path=cfg.env.reward_classifier["pretrained_path"],
    # config_path=cfg.env.reward_classifier["config_path"],
    # )

    act_with_policy(
        cfg=cfg,
        robot=robot,
        reward_classifier=reward_classifier,
        shutdown_event=shutdown_event,
        parameters_queue=parameters_queue,
@@ -29,6 +29,9 @@ class InputController:
        self.z_step_size = z_step_size
        self.running = True
        self.episode_end_status = None  # None, "success", or "failure"
        self.intervention_flag = False
        self.open_gripper_command = False
        self.close_gripper_command = False

    def start(self):
        """Start the controller and initialize resources."""

@@ -70,6 +73,19 @@ class InputController:
        self.episode_end_status = None  # Reset after reading
        return status

    def should_intervene(self):
        """Return True if intervention flag was set."""
        return self.intervention_flag

    def gripper_command(self):
        """Return the current gripper command."""
        if self.open_gripper_command == self.close_gripper_command:
            return "no-op"
        elif self.open_gripper_command:
            return "open"
        elif self.close_gripper_command:
            return "close"


class KeyboardController(InputController):
    """Generate motion deltas from keyboard input."""

@@ -326,7 +342,6 @@ class GamepadControllerHID(InputController):
        self.buttons = {}
        self.quit_requested = False
        self.save_requested = False
        self.intervention_flag = False

    def find_device(self):
        """Look for the gamepad device by vendor and product ID."""

@@ -416,7 +431,13 @@ class GamepadControllerHID(InputController):
            buttons = data[5]

            # Check if RB is pressed then the intervention flag should be set
            self.intervention_flag = data[6] == 2
            self.intervention_flag = data[6] in [2, 6, 10, 14]

            # Check if RT is pressed
            self.open_gripper_command = data[6] in [8, 10, 12]

            # Check if LT is pressed
            self.close_gripper_command = data[6] in [4, 6, 12]

            # Check if Y/Triangle button (bit 7) is pressed for saving
            # Check if X/Square button (bit 5) is pressed for failure
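The magic values above are consistent with data[6] acting as a button bitmask in this HID report, with RB = 0x02, LT = 0x04, and RT = 0x08: 2, 6, 10, 14 all contain the RB bit, 8, 10, 12 the RT bit, and 4, 6, 12 the LT bit. That reading is an assumption about this particular gamepad's report layout, and the explicit lists also leave out a few combinations (e.g. 14 for the gripper), so they are not a pure bitmask check. If the assumption holds, an equivalent standalone decoding would look like:

# Assumed bit layout of the data[6] byte; not stated in the diff.
RB_MASK, LT_MASK, RT_MASK = 0x02, 0x04, 0x08

def decode_buttons(byte6: int) -> dict:
    return {
        "intervention": bool(byte6 & RB_MASK),
        "close_gripper": bool(byte6 & LT_MASK),
        "open_gripper": bool(byte6 & RT_MASK),
    }

print(decode_buttons(10))  # RB + RT pressed -> intervention and open_gripper both True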
@@ -676,12 +697,8 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
if __name__ == "__main__":
    from lerobot.common.robot_devices.robots.configs import RobotConfig
    from lerobot.common.robot_devices.robots.utils import make_robot_from_config
    from lerobot.scripts.server.gym_manipulator import (
        EEActionSpaceConfig,
        EnvWrapperConfig,
        HILSerlRobotEnvConfig,
        make_robot_env,
    )
    from lerobot.scripts.server.gym_manipulator import make_robot_env
    from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig

    parser = argparse.ArgumentParser(description="Test end-effector control")
    parser.add_argument(
@@ -1,9 +1,8 @@
import logging
import sys
import time
from dataclasses import dataclass
from threading import Lock
from typing import Annotated, Any, Dict, Optional, Tuple
from typing import Annotated, Any, Dict, Tuple

import gymnasium as gym
import numpy as np

@@ -17,66 +16,13 @@ from lerobot.common.robot_devices.control_utils import (
    is_headless,
    reset_follower_position,
)
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.utils.utils import log_say
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics

logging.basicConfig(level=logging.INFO)


@dataclass
class EEActionSpaceConfig:
    """Configuration parameters for end-effector action space."""

    x_step_size: float
    y_step_size: float
    z_step_size: float
    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
    use_gamepad: bool = False


@dataclass
class EnvWrapperConfig:
    """Configuration for environment wrappers."""

    display_cameras: bool = False
    delta_action: float = 0.1
    use_relative_joint_positions: bool = True
    add_joint_velocity_to_observation: bool = False
    add_ee_pose_to_observation: bool = False
    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
    resize_size: Optional[Tuple[int, int]] = None
    control_time_s: float = 20.0
    fixed_reset_joint_positions: Optional[Any] = None
    reset_time_s: float = 5.0
    joint_masking_action_space: Optional[Any] = None
    ee_action_space_params: Optional[EEActionSpaceConfig] = None
    reward_classifier_pretrained_path: Optional[str] = None
    reward_classifier_config_file: Optional[str] = None


@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
    """Configuration for the HILSerlRobotEnv environment."""

    robot: Optional[RobotConfig] = None
    wrapper: Optional[EnvWrapperConfig] = None
    fps: int = 10
    mode: str = None  # Either "record", "replay", None
    repo_id: Optional[str] = None
    dataset_root: Optional[str] = None
    task: str = ""
    num_episodes: int = 10  # only for record mode
    episode: int = 0
    device: str = "cuda"
    push_to_hub: bool = True
    pretrained_policy_name_or_path: Optional[str] = None

    def gym_kwargs(self) -> dict:
        return {}

MAX_GRIPPER_COMMAND = 25


class HILSerlRobotEnv(gym.Env):
@@ -813,9 +759,10 @@ class BatchCompitableWrapper(gym.ObservationWrapper):


class EEActionWrapper(gym.ActionWrapper):
    def __init__(self, env, ee_action_space_params=None):
    def __init__(self, env, ee_action_space_params=None, use_gripper=False):
        super().__init__(env)
        self.ee_action_space_params = ee_action_space_params
        self.use_gripper = use_gripper

        # Initialize kinematics instance for the appropriate robot type
        robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")

@@ -829,10 +776,12 @@ class EEActionWrapper(gym.ActionWrapper):
                ee_action_space_params.z_step_size,
            ]
        )
        if self.use_gripper:
            action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
        ee_action_space = gym.spaces.Box(
            low=-action_space_bounds,
            high=action_space_bounds,
            shape=(3,),
            shape=(3 + int(self.use_gripper),),
            dtype=np.float32,
        )
        if isinstance(self.action_space, gym.spaces.Tuple):

@@ -848,6 +797,10 @@ class EEActionWrapper(gym.ActionWrapper):
        if isinstance(action, tuple):
            action, _ = action

        if self.use_gripper:
            gripper_command = action[-1]
            action = action[:-1]

        current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
        current_ee_pos = self.fk_function(current_joint_pos)
        if isinstance(action, torch.Tensor):

@@ -863,6 +816,12 @@ class EEActionWrapper(gym.ActionWrapper):
            position_only=True,
            fk_func=self.fk_function,
        )
        if self.use_gripper:
            gripper_command = gripper_command * MAX_GRIPPER_COMMAND
            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
            target_joint_pos[-1] = gripper_action

        return target_joint_pos, is_intervention
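With MAX_GRIPPER_COMMAND = 25, the last action component is treated as a normalized delta: it is scaled to at most 25 encoder ticks, added to the current gripper position, and clipped into [0, 25] before being written as the last joint target. A standalone sketch of that arithmetic (the current-position values are made up for illustration):

import numpy as np

MAX_GRIPPER_COMMAND = 25

def apply_gripper_delta(current_gripper_pos: float, command: float) -> float:
    # Scale the normalized command to ticks, add it, and clip to the valid range.
    delta = command * MAX_GRIPPER_COMMAND
    return float(np.clip(current_gripper_pos + delta, 0, MAX_GRIPPER_COMMAND))

print(apply_gripper_delta(20.0, 1.0))   # 25.0 -> a full positive command saturates at the max
print(apply_gripper_delta(20.0, -1.0))  # 0.0  -> a full negative command saturates at the min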
@@ -912,6 +871,7 @@ class GamepadControlWrapper(gym.Wrapper):
        x_step_size=1.0,
        y_step_size=1.0,
        z_step_size=1.0,
        use_gripper=False,
        auto_reset=False,
        input_threshold=0.001,
    ):

@@ -948,6 +908,7 @@ class GamepadControlWrapper(gym.Wrapper):
            z_step_size=z_step_size,
        )
        self.auto_reset = auto_reset
        self.use_gripper = use_gripper
        self.input_threshold = input_threshold
        self.controller.start()

@@ -977,6 +938,15 @@ class GamepadControlWrapper(gym.Wrapper):
        # Create action from gamepad input
        gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)

        if self.use_gripper:
            gripper_command = self.controller.gripper_command()
            if gripper_command == "open":
                gamepad_action = np.concatenate([gamepad_action, [1.0]])
            elif gripper_command == "close":
                gamepad_action = np.concatenate([gamepad_action, [-1.0]])
            else:
                gamepad_action = np.concatenate([gamepad_action, [0.0]])

        # Check episode ending buttons
        # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
        episode_end_status = self.controller.get_episode_end_status()
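End to end, the gripper path is: the HID reader sets the open/close flags, InputController.gripper_command() collapses them into "open" / "close" / "no-op", the wrapper above appends +1.0 / -1.0 / 0.0 as a fourth action component, and EEActionWrapper later converts that scalar into a clipped gripper joint target. A compact restatement of the string-to-scalar step, equivalent to the if/elif chain above but not code from the commit:

import numpy as np

GRIPPER_DELTAS = {"open": 1.0, "close": -1.0, "no-op": 0.0}

def append_gripper(action_xyz: np.ndarray, command: str) -> np.ndarray:
    # Unknown commands fall back to 0.0, matching the wrapper's else branch.
    return np.concatenate([action_xyz, [GRIPPER_DELTAS.get(command, 0.0)]])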
@@ -1023,6 +993,7 @@ class GamepadControlWrapper(gym.Wrapper):
                final_action = (torch.from_numpy(gamepad_action), False)
            else:
                final_action = torch.from_numpy(gamepad_action)

        else:
            # Use the original action
            final_action = action

@@ -1138,7 +1109,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
    # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
    if cfg.wrapper.ee_action_space_params is not None:
        env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
        env = EEActionWrapper(
            env=env,
            ee_action_space_params=cfg.wrapper.ee_action_space_params,
            use_gripper=cfg.wrapper.use_gripper,
        )
    if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
        # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
        env = GamepadControlWrapper(

@@ -1146,6 +1121,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
            x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
            y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
            z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
            use_gripper=cfg.wrapper.use_gripper,
        )
    else:
        env = KeyboardInterfaceWrapper(env=env)

@@ -1184,7 +1160,7 @@ def get_classifier(cfg):
    return model


def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
def record_dataset(env, policy, cfg):
    """
    Record a dataset of robot interactions using either a policy or teleop.