Added gripper control mechanism to gym_manipulator
Moved the HilSerl env config to configs/env/configs.py; fixes in actor_server, modeling_sac, and configuration_sac; added the possibility of ignoring missing keys in env_cfg in the get_features_from_env_config function.
commit 05a237ce10
parent 88cc2b8fc8
@@ -14,10 +14,12 @@
 import abc
 from dataclasses import dataclass, field
+from typing import Any, Dict, Optional, Tuple

 import draccus

 from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
+from lerobot.common.robot_devices.robots.configs import RobotConfig
 from lerobot.configs.types import FeatureType, PolicyFeature


@@ -159,20 +161,84 @@ class XarmEnv(EnvConfig):
 @dataclass
 class VideoRecordConfig:
     """Configuration for video recording in ManiSkill environments."""

     enabled: bool = False
     record_dir: str = "videos"
     trajectory_name: str = "trajectory"


 @dataclass
 class WrapperConfig:
     """Configuration for environment wrappers."""

     delta_action: float | None = None
     joint_masking_action_space: list[bool] | None = None


+@dataclass
+class EEActionSpaceConfig:
+    """Configuration parameters for end-effector action space."""
+
+    x_step_size: float
+    y_step_size: float
+    z_step_size: float
+    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
+    use_gamepad: bool = False
+
+
+@dataclass
+class EnvWrapperConfig:
+    """Configuration for environment wrappers."""
+
+    display_cameras: bool = False
+    delta_action: float = 0.1
+    use_relative_joint_positions: bool = True
+    add_joint_velocity_to_observation: bool = False
+    add_ee_pose_to_observation: bool = False
+    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
+    resize_size: Optional[Tuple[int, int]] = None
+    control_time_s: float = 20.0
+    fixed_reset_joint_positions: Optional[Any] = None
+    reset_time_s: float = 5.0
+    joint_masking_action_space: Optional[Any] = None
+    ee_action_space_params: Optional[EEActionSpaceConfig] = None
+    use_gripper: bool = False
+
+
+@EnvConfig.register_subclass(name="gym_manipulator")
+@dataclass
+class HILSerlRobotEnvConfig(EnvConfig):
+    """Configuration for the HILSerlRobotEnv environment."""
+
+    robot: Optional[RobotConfig] = None
+    wrapper: Optional[EnvWrapperConfig] = None
+    fps: int = 10
+    name: str = "real_robot"
+    mode: str = None  # Either "record", "replay", None
+    repo_id: Optional[str] = None
+    dataset_root: Optional[str] = None
+    task: str = ""
+    num_episodes: int = 10  # only for record mode
+    episode: int = 0
+    device: str = "cuda"
+    push_to_hub: bool = True
+    pretrained_policy_name_or_path: Optional[str] = None
+    reward_classifier: dict[str, str | None] = field(
+        default_factory=lambda: {
+            "pretrained_path": None,
+            "config_path": None,
+        }
+    )
+
+    def gym_kwargs(self) -> dict:
+        return {}
+
+
 @EnvConfig.register_subclass("maniskill_push")
 @dataclass
 class ManiskillEnvConfig(EnvConfig):
     """Configuration for the ManiSkill environment."""

     name: str = "maniskill/pushcube"
     task: str = "PushCube-v1"
     image_size: int = 64
@@ -185,7 +251,7 @@ class ManiskillEnvConfig(EnvConfig):
     render_mode: str = "rgb_array"
     render_size: int = 64
     device: str = "cuda"
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
     wrapper: WrapperConfig = field(default_factory=WrapperConfig)
     features: dict[str, PolicyFeature] = field(
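For orientation, a minimal sketch of how the dataclasses added above could be composed into an environment config. This is illustrative only and not part of the commit: the numeric values, bounds, and task name are placeholders, and the robot field is left at its None default. The import path matches the updated import used later in this commit.

# Illustrative composition of the new config dataclasses (placeholder values).
from lerobot.common.envs.configs import (
    EEActionSpaceConfig,
    EnvWrapperConfig,
    HILSerlRobotEnvConfig,
)

ee_space = EEActionSpaceConfig(
    x_step_size=0.02,
    y_step_size=0.02,
    z_step_size=0.02,
    bounds={"min": [-0.2, -0.2, 0.0], "max": [0.2, 0.2, 0.3]},
    use_gamepad=True,
)
wrapper = EnvWrapperConfig(
    control_time_s=20.0,
    ee_action_space_params=ee_space,
    use_gripper=True,  # enables the extra gripper dimension handled by EEActionWrapper
)
env_cfg = HILSerlRobotEnvConfig(robot=None, wrapper=wrapper, fps=10, task="placeholder_task")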
@@ -53,9 +53,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:

         # sanity check that images are channel last
         _, h, w, c = img.shape
-        assert c < h and c < w, (
-            f"expect channel last images, but instead got {img.shape=}"
-        )
+        assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

         # sanity check that images are uint8
         assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@@ -95,7 +93,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
         else:
             feature = ft

-        policy_key = env_cfg.features_map[key]
+        policy_key = env_cfg.features_map.get(key, key)
         policy_features[policy_key] = feature

     return policy_features
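The .get(key, key) change is what the commit message describes as ignoring missing keys in env_cfg: a feature key with no entry in features_map now maps to itself instead of raising KeyError. A tiny standalone sketch of the behaviour, using made-up key names:

# Hypothetical features_map and key, purely to illustrate the fallback behaviour.
features_map = {"pixels": "observation.image"}
key = "gripper_state"  # not present in features_map
policy_key = features_map.get(key, key)
assert policy_key == "gripper_state"  # previously features_map[key] would raise KeyError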
@@ -20,7 +20,7 @@ from typing import Any, Optional
 from lerobot.common.optim.optimizers import MultiAdamConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType


 @dataclass
@@ -29,7 +29,6 @@ class ConcurrencyConfig:
     learner: str = "threads"

-

 @dataclass
 class ActorLearnerConfig:
     learner_host: str = "127.0.0.1"
@@ -110,6 +109,7 @@ class SACConfig(PreTrainedConfig):
         use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
         grad_clip_norm: Gradient clipping norm for the SAC algorithm.
     """

     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
             "VISUAL": NormalizationMode.MEAN_STD,
@@ -183,22 +183,12 @@ class SACConfig(PreTrainedConfig):
     grad_clip_norm: float = 40.0

     # Network configuration
-    critic_network_kwargs: CriticNetworkConfig = field(
-        default_factory=CriticNetworkConfig
-    )
-    actor_network_kwargs: ActorNetworkConfig = field(
-        default_factory=ActorNetworkConfig
-    )
-    policy_kwargs: PolicyConfig = field(
-        default_factory=PolicyConfig
-    )
+    critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
+    actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
+    policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)

-    actor_learner_config: ActorLearnerConfig = field(
-        default_factory=ActorLearnerConfig
-    )
-    concurrency: ConcurrencyConfig = field(
-        default_factory=ConcurrencyConfig
-    )
+    actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
+    concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

     def __post_init__(self):
         super().__post_init__()
@@ -218,18 +208,20 @@ class SACConfig(PreTrainedConfig):
         return None

     def validate_features(self) -> None:
-        if "observation.image" not in self.input_features:
-            raise ValueError("You must provide 'observation.image' in the input features")
+        has_image = any(key.startswith("observation.image") for key in self.input_features)
+        has_state = "observation.state" in self.input_features

-        if "observation.state" not in self.input_features:
-            raise ValueError("You must provide 'observation.state' in the input features")
+        if not (has_state or has_image):
+            raise ValueError(
+                "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
+            )

         if "action" not in self.output_features:
             raise ValueError("You must provide 'action' in the output features")

     @property
     def image_features(self) -> list[str]:
-        return [key for key in self.input_features.keys() if 'image' in key]
+        return [key for key in self.input_features.keys() if "image" in key]

     @property
     def observation_delta_indices(self) -> list:
@@ -243,9 +235,13 @@ class SACConfig(PreTrainedConfig):
     def reward_delta_indices(self) -> None:
         return None


 if __name__ == "__main__":
     import draccus

     config = SACConfig()
     draccus.set_config_type("json")
-    draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
+    draccus.dump(
+        config=config,
+        stream=open(file="run_config.json", mode="w"),
+    )
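The relaxed validate_features no longer demands both an image and a state input; either one is enough. A standalone restatement of the accepted and rejected feature layouts (key names are examples only, not from the diff):

def passes_validation(input_features: dict) -> bool:
    # Mirrors the new has_image / has_state check from validate_features above.
    has_image = any(key.startswith("observation.image") for key in input_features)
    has_state = "observation.state" in input_features
    return has_state or has_image

assert passes_validation({"observation.image.front": None})   # image-only is now accepted
assert passes_validation({"observation.state": None})         # state-only is now accepted
assert not passes_validation({"observation.depth": None})     # neither -> still rejected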
@@ -39,7 +39,6 @@ from lerobot.common.policies.utils import get_device_from_parameters
 class SACPolicy(
     PreTrainedPolicy,
 ):

     config_class = SACConfig
     name = "sac"

@@ -53,9 +52,7 @@ class SACPolicy(
         self.config = config

         if config.dataset_stats is not None:
-            input_normalization_params = _convert_normalization_params_to_tensor(
-                config.dataset_stats
-            )
+            input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
             self.normalize_inputs = Normalize(
                 config.input_features,
                 config.normalization_mapping,
@@ -64,9 +61,7 @@ class SACPolicy(
         else:
             self.normalize_inputs = nn.Identity()

-        output_normalization_params = _convert_normalization_params_to_tensor(
-            config.dataset_stats
-        )
+        output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)

         # HACK: This is hacky and should be removed
         dataset_stats = dataset_stats or output_normalization_params
@@ -138,7 +133,6 @@ class SACPolicy(
         self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
         self.temperature = self.log_alpha.exp().item()

-
     def get_optim_params(self) -> dict:
         return {
             "actor": self.actor.parameters_to_optimize,
@@ -655,9 +649,10 @@ class SACObservationEncoder(nn.Module):
 class DefaultImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
+        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
         self.image_enc_layers = nn.Sequential(
             nn.Conv2d(
-                in_channels=config.input_features["observation.image"].shape[0],
+                in_channels=config.input_features[image_key].shape[0],
                 out_channels=config.image_encoder_hidden_dim,
                 kernel_size=7,
                 stride=2,
@@ -685,7 +680,9 @@ class DefaultImageEncoder(nn.Module):
             ),
             nn.ReLU(),
         )
-        dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
+        # Get first image key from input features
+        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
+        dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
         with torch.inference_mode():
             self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
         self.image_enc_layers.extend(
@@ -844,8 +841,10 @@ if __name__ == "__main__":
     import draccus

     from lerobot.configs import parser

     @parser.wrap()
     def main(config: SACConfig):
         policy = SACPolicy(config=config)
         print("yolo")

     main()
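DefaultImageEncoder no longer hard-codes "observation.image"; it picks the first input feature whose key starts with "observation.image", so camera-specific keys also work. A small sketch of that key-selection expression with example keys (not from the diff):

# Example feature keys, illustrating the key-selection expression used above.
input_features = {"observation.state": None, "observation.image.wrist": None}
image_key = next(key for key in input_features if key.startswith("observation.image"))
assert image_key == "observation.image.wrist"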
@@ -28,7 +28,6 @@ from torch.multiprocessing import Event, Queue
 # TODO: Remove the import of maniskill
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
-from lerobot.common.robot_devices.robots.utils import Robot, make_robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.utils import (
@@ -268,7 +267,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):

 def act_with_policy(
     cfg: TrainPipelineConfig,
-    robot: Robot,
+    # robot: Robot,
     reward_classifier: nn.Module,
     shutdown_event: any,  # Event,
     parameters_queue: Queue,
@@ -287,7 +286,7 @@ def act_with_policy(

     logging.info("make_env online")

-    online_env = make_robot_env( cfg=cfg.env)
+    online_env = make_robot_env(cfg=cfg.env)

     set_seed(cfg.seed)
     device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -503,7 +502,6 @@ def actor_cli(cfg: TrainPipelineConfig):
     mp.set_start_method("spawn")

     init_logging(log_file="actor.log")
-    robot = make_robot(robot_type=cfg.env.robot)

     shutdown_event = setup_process_handlers(use_threads(cfg))

@@ -563,18 +561,17 @@ def actor_cli(cfg: TrainPipelineConfig):
     # HACK: FOR MANISKILL we do not have a reward classifier
     # TODO: Remove this once we merge into main
     reward_classifier = None
-    if (
-        cfg.env.reward_classifier["pretrained_path"] is not None
-        and cfg.env.reward_classifier["config_path"] is not None
-    ):
-        reward_classifier = get_classifier(
-            pretrained_path=cfg.env.reward_classifier["pretrained_path"],
-            config_path=cfg.env.reward_classifier["config_path"],
-        )
+    # if (
+    #     cfg.env.reward_classifier["pretrained_path"] is not None
+    #     and cfg.env.reward_classifier["config_path"] is not None
+    # ):
+    #     reward_classifier = get_classifier(
+    #         pretrained_path=cfg.env.reward_classifier["pretrained_path"],
+    #         config_path=cfg.env.reward_classifier["config_path"],
+    #     )

     act_with_policy(
         cfg=cfg,
-        robot=robot,
         reward_classifier=reward_classifier,
         shutdown_event=shutdown_event,
         parameters_queue=parameters_queue,
@@ -29,6 +29,9 @@ class InputController:
         self.z_step_size = z_step_size
         self.running = True
         self.episode_end_status = None  # None, "success", or "failure"
+        self.intervention_flag = False
+        self.open_gripper_command = False
+        self.close_gripper_command = False

     def start(self):
         """Start the controller and initialize resources."""
@@ -70,6 +73,19 @@ class InputController:
         self.episode_end_status = None  # Reset after reading
         return status

+    def should_intervene(self):
+        """Return True if intervention flag was set."""
+        return self.intervention_flag
+
+    def gripper_command(self):
+        """Return the current gripper command."""
+        if self.open_gripper_command == self.close_gripper_command:
+            return "no-op"
+        elif self.open_gripper_command:
+            return "open"
+        elif self.close_gripper_command:
+            return "close"
+

 class KeyboardController(InputController):
     """Generate motion deltas from keyboard input."""
@@ -326,7 +342,6 @@ class GamepadControllerHID(InputController):
         self.buttons = {}
         self.quit_requested = False
         self.save_requested = False
-        self.intervention_flag = False

     def find_device(self):
         """Look for the gamepad device by vendor and product ID."""
@@ -416,7 +431,13 @@ class GamepadControllerHID(InputController):
             buttons = data[5]

             # Check if RB is pressed then the intervention flag should be set
-            self.intervention_flag = data[6] == 2
+            self.intervention_flag = data[6] in [2, 6, 10, 14]
+
+            # Check if RT is pressed
+            self.open_gripper_command = data[6] in [8, 10, 12]
+
+            # Check if LT is pressed
+            self.close_gripper_command = data[6] in [4, 6, 12]

             # Check if Y/Triangle button (bit 7) is pressed for saving
             # Check if X/Square button (bit 5) is pressed for failure
@@ -676,12 +697,8 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
 if __name__ == "__main__":
     from lerobot.common.robot_devices.robots.configs import RobotConfig
     from lerobot.common.robot_devices.robots.utils import make_robot_from_config
-    from lerobot.scripts.server.gym_manipulator import (
-        EEActionSpaceConfig,
-        EnvWrapperConfig,
-        HILSerlRobotEnvConfig,
-        make_robot_env,
-    )
+    from lerobot.scripts.server.gym_manipulator import make_robot_env
+    from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig

     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(
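The data[6] membership checks read like combinations of one HID report byte, where RB appears to contribute 2, LT 4, and RT 8 (so 6 = RB+LT, 10 = RB+RT, 12 = LT+RT, 14 = RB+LT+RT); the exact report layout is device-specific, so treat that reading as an assumption. The resulting open/close flags are then collapsed by gripper_command() into a three-way command; a standalone restatement of that decision rule (not from the diff):

def gripper_command(open_cmd: bool, close_cmd: bool) -> str:
    # Same rule as InputController.gripper_command() above:
    # equal flags (both pressed or both released) cancel out to a no-op.
    if open_cmd == close_cmd:
        return "no-op"
    return "open" if open_cmd else "close"

assert gripper_command(False, False) == "no-op"
assert gripper_command(True, True) == "no-op"
assert gripper_command(True, False) == "open"
assert gripper_command(False, True) == "close"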
@@ -1,9 +1,8 @@
 import logging
 import sys
 import time
-from dataclasses import dataclass
 from threading import Lock
-from typing import Annotated, Any, Dict, Optional, Tuple
+from typing import Annotated, Any, Dict, Tuple

 import gymnasium as gym
 import numpy as np
@@ -17,66 +16,13 @@ from lerobot.common.robot_devices.control_utils import (
     is_headless,
     reset_follower_position,
 )
-from lerobot.common.robot_devices.robots.configs import RobotConfig
 from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.common.utils.utils import log_say
 from lerobot.configs import parser
 from lerobot.scripts.server.kinematics import RobotKinematics

 logging.basicConfig(level=logging.INFO)
+MAX_GRIPPER_COMMAND = 25

-
-@dataclass
-class EEActionSpaceConfig:
-    """Configuration parameters for end-effector action space."""
-
-    x_step_size: float
-    y_step_size: float
-    z_step_size: float
-    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
-    use_gamepad: bool = False
-
-
-@dataclass
-class EnvWrapperConfig:
-    """Configuration for environment wrappers."""
-
-    display_cameras: bool = False
-    delta_action: float = 0.1
-    use_relative_joint_positions: bool = True
-    add_joint_velocity_to_observation: bool = False
-    add_ee_pose_to_observation: bool = False
-    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
-    resize_size: Optional[Tuple[int, int]] = None
-    control_time_s: float = 20.0
-    fixed_reset_joint_positions: Optional[Any] = None
-    reset_time_s: float = 5.0
-    joint_masking_action_space: Optional[Any] = None
-    ee_action_space_params: Optional[EEActionSpaceConfig] = None
-    reward_classifier_pretrained_path: Optional[str] = None
-    reward_classifier_config_file: Optional[str] = None
-
-
-@EnvConfig.register_subclass(name="gym_manipulator")
-@dataclass
-class HILSerlRobotEnvConfig(EnvConfig):
-    """Configuration for the HILSerlRobotEnv environment."""
-
-    robot: Optional[RobotConfig] = None
-    wrapper: Optional[EnvWrapperConfig] = None
-    fps: int = 10
-    mode: str = None  # Either "record", "replay", None
-    repo_id: Optional[str] = None
-    dataset_root: Optional[str] = None
-    task: str = ""
-    num_episodes: int = 10  # only for record mode
-    episode: int = 0
-    device: str = "cuda"
-    push_to_hub: bool = True
-    pretrained_policy_name_or_path: Optional[str] = None
-
-    def gym_kwargs(self) -> dict:
-        return {}


 class HILSerlRobotEnv(gym.Env):
@@ -813,9 +759,10 @@ class BatchCompitableWrapper(gym.ObservationWrapper):


 class EEActionWrapper(gym.ActionWrapper):
-    def __init__(self, env, ee_action_space_params=None):
+    def __init__(self, env, ee_action_space_params=None, use_gripper=False):
         super().__init__(env)
         self.ee_action_space_params = ee_action_space_params
+        self.use_gripper = use_gripper

         # Initialize kinematics instance for the appropriate robot type
         robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
@@ -829,10 +776,12 @@ class EEActionWrapper(gym.ActionWrapper):
                 ee_action_space_params.z_step_size,
             ]
         )
+        if self.use_gripper:
+            action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
         ee_action_space = gym.spaces.Box(
             low=-action_space_bounds,
             high=action_space_bounds,
-            shape=(3,),
+            shape=(3 + int(self.use_gripper),),
             dtype=np.float32,
         )
         if isinstance(self.action_space, gym.spaces.Tuple):
@@ -848,6 +797,10 @@ class EEActionWrapper(gym.ActionWrapper):
         if isinstance(action, tuple):
             action, _ = action

+        if self.use_gripper:
+            gripper_command = action[-1]
+            action = action[:-1]
+
         current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
         current_ee_pos = self.fk_function(current_joint_pos)
         if isinstance(action, torch.Tensor):
@@ -863,6 +816,12 @@ class EEActionWrapper(gym.ActionWrapper):
             position_only=True,
             fk_func=self.fk_function,
         )
+        if self.use_gripper:
+            gripper_command = gripper_command * MAX_GRIPPER_COMMAND
+            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
+            target_joint_pos[-1] = gripper_action
+
         return target_joint_pos, is_intervention

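A standalone restatement of the gripper handling added to EEActionWrapper.action(): the last action channel is treated as a normalized command, scaled by MAX_GRIPPER_COMMAND, applied as a delta to the current gripper reading, and clipped into range. Illustrative only; the current-position value below is a placeholder, and +1.0 is the value GamepadControlWrapper appends for "open".

import numpy as np

MAX_GRIPPER_COMMAND = 25

def apply_gripper_delta(current_gripper_pos: float, gripper_command: float) -> float:
    # gripper_command is expected in [-1, 1] (+1 open, -1 close, 0 no-op).
    delta = gripper_command * MAX_GRIPPER_COMMAND
    return float(np.clip(current_gripper_pos + delta, 0, MAX_GRIPPER_COMMAND))

print(apply_gripper_delta(12.0, 1.0))   # -> 25.0, fully open from a placeholder reading
print(apply_gripper_delta(12.0, -1.0))  # -> 0.0, fully closed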
@@ -912,6 +871,7 @@ class GamepadControlWrapper(gym.Wrapper):
         x_step_size=1.0,
         y_step_size=1.0,
         z_step_size=1.0,
+        use_gripper=False,
         auto_reset=False,
         input_threshold=0.001,
     ):
@@ -948,6 +908,7 @@ class GamepadControlWrapper(gym.Wrapper):
                 z_step_size=z_step_size,
             )
         self.auto_reset = auto_reset
+        self.use_gripper = use_gripper
         self.input_threshold = input_threshold
         self.controller.start()

@@ -977,6 +938,15 @@ class GamepadControlWrapper(gym.Wrapper):
         # Create action from gamepad input
         gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)

+        if self.use_gripper:
+            gripper_command = self.controller.gripper_command()
+            if gripper_command == "open":
+                gamepad_action = np.concatenate([gamepad_action, [1.0]])
+            elif gripper_command == "close":
+                gamepad_action = np.concatenate([gamepad_action, [-1.0]])
+            else:
+                gamepad_action = np.concatenate([gamepad_action, [0.0]])
+
         # Check episode ending buttons
         # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
         episode_end_status = self.controller.get_episode_end_status()
@@ -1023,6 +993,7 @@ class GamepadControlWrapper(gym.Wrapper):
                 final_action = (torch.from_numpy(gamepad_action), False)
             else:
                 final_action = torch.from_numpy(gamepad_action)

         else:
             # Use the original action
             final_action = action
@@ -1138,7 +1109,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     if cfg.wrapper.ee_action_space_params is not None:
-        env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
+        env = EEActionWrapper(
+            env=env,
+            ee_action_space_params=cfg.wrapper.ee_action_space_params,
+            use_gripper=cfg.wrapper.use_gripper,
+        )
     if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
         # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
@@ -1146,6 +1121,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
             x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
             y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
             z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
+            use_gripper=cfg.wrapper.use_gripper,
         )
     else:
         env = KeyboardInterfaceWrapper(env=env)
@@ -1184,7 +1160,7 @@ def get_classifier(cfg):
     return model


-def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
+def record_dataset(env, policy, cfg):
     """
     Record a dataset of robot interactions using either a policy or teleop.
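With use_gripper=True threaded from EnvWrapperConfig into both EEActionWrapper and GamepadControlWrapper above, the environment action becomes three end-effector deltas plus one gripper channel. A sketch of the resulting layout (delta values are placeholders, not from the diff):

import numpy as np

delta_x, delta_y, delta_z = 0.01, 0.0, -0.02  # placeholder EE deltas
gripper = 1.0  # +1.0 open, -1.0 close, 0.0 no-op, as encoded by GamepadControlWrapper
action = np.array([delta_x, delta_y, delta_z, gripper], dtype=np.float32)
assert action.shape == (3 + 1,)  # matches Box(shape=(3 + int(use_gripper),)) above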