refactor: add fps parameter to reset_environment and implement frame rate control during reset

Ke-Wang1017 2025-04-03 17:14:38 +01:00
parent bb86dcb5b3
commit e55e904c3e
3 changed files with 651 additions and 2 deletions

@@ -284,7 +284,7 @@ def control_loop(
             break


-def reset_environment(robot, events, reset_time_s):
+def reset_environment(robot, events, reset_time_s, fps):
     # TODO(rcadene): refactor warmup_record and reset_environment
     if has_method(robot, "teleop_safety_stop"):
         robot.teleop_safety_stop()
@@ -297,6 +297,7 @@ def reset_environment(robot, events, reset_time_s):
     with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
         last_update = 0  # Track the last update time
         while timestamp < reset_time_s:
+            start_loop_t = time.perf_counter()
             robot.teleop_step(record_data=False)
             timestamp = time.perf_counter() - start_vencod_t
@@ -310,6 +311,8 @@ def reset_environment(robot, events, reset_time_s):
                 events["exit_early"] = False
                 break

+            dt_s = time.perf_counter() - start_loop_t
+            busy_wait(1 / fps - dt_s)

 def stop_recording(robot, listener, display_cameras):
     robot.disconnect()
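The hunk above applies the same fixed-rate pattern lerobot uses in its other control loops: time each iteration and busy-wait for the remainder of the frame period. A minimal sketch of that pattern in isolation (not part of the commit; it assumes only busy_wait from lerobot.common.robot_devices.utils and a stand-in step callable):

import time
from lerobot.common.robot_devices.utils import busy_wait

def run_at_fixed_rate(step, duration_s: float, fps: int):
    """Call step() repeatedly for duration_s seconds, holding each iteration to 1/fps seconds."""
    start_t = time.perf_counter()
    while time.perf_counter() - start_t < duration_s:
        start_loop_t = time.perf_counter()
        step()  # e.g. robot.teleop_step(record_data=False)
        dt_s = time.perf_counter() - start_loop_t
        busy_wait(1 / fps - dt_s)  # if the iteration overran, this returns immediately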

@@ -0,0 +1,215 @@
import time
import multiprocessing
import numpy as np
import inputs
from typing import Tuple
from dataclasses import dataclass
from enum import Enum
class ControllerType(Enum):
PS5 = "ps5"
XBOX = "xbox"
@dataclass
class ControllerConfig:
resolution: dict
scale: dict
class JoystickInterface:
"""
This class provides an interface to the Joystick/Gamepad.
It continuously reads the joystick state and provides
a "get_action" method to get the latest action and button state.
"""
CONTROLLER_CONFIGS = {
ControllerType.PS5: ControllerConfig(
# PS5 controller joystick values have 8 bit resolution [0, 255]
resolution={
'ABS_X': 2**8,
'ABS_Y': 2**8,
'ABS_RX': 2**8,
'ABS_RY': 2**8,
'ABS_Z': 2**8,
'ABS_RZ': 2**8,
'ABS_HAT0X': 1.0,
},
scale={
'ABS_X': 0.4,
'ABS_Y': 0.4,
'ABS_RX': 0.5,
'ABS_RY': 0.5,
'ABS_Z': 0.8,
'ABS_RZ': 1.2,
'ABS_HAT0X': 0.5,
}
),
ControllerType.XBOX: ControllerConfig(
# XBOX controller joystick values have 16 bit resolution [0, 65535]
resolution={
'ABS_X': 2**16,
'ABS_Y': 2**16,
'ABS_RX': 2**16,
'ABS_RY': 2**16,
'ABS_Z': 2**8,
'ABS_RZ': 2**8,
'ABS_HAT0X': 1.0,
},
scale={
'ABS_X': -0.01,
'ABS_Y': -0.005,
'ABS_RX': 0.015,
'ABS_RY': 0.015,
'ABS_Z': 0.0025,
'ABS_RZ': 0.0025,
'ABS_HAT0X': 0.03,
}
),
}
def __init__(self, controller_type=ControllerType.XBOX):
self.controller_type = controller_type
self.controller_config = self.CONTROLLER_CONFIGS[controller_type]
# Manager to handle shared state between processes
self.manager = multiprocessing.Manager()
self.latest_data = self.manager.dict()
self.latest_data["action"] = [0.0] * 6
self.latest_data["buttons"] = [False, False, False, False, False, False]
# Start a process to continuously read Joystick state
self._process = multiprocessing.Process(target=self._read_joystick)
self._process.daemon = True
self._process.start()
def _read_joystick(self):
"""Add a try-except to prevent thread crashes"""
action = [0.0] * 6
buttons = [False, False, False, False, False, False]
while True:
try:
# Get fresh events
events = inputs.get_gamepad()
# Process events
for event in events:
if event.code in self.controller_config.resolution:
# Calculate relative changes based on the axis
# Normalize the joystick input values to range [-1, 1] expected by the environment
resolution = self.controller_config.resolution[event.code]
if self.controller_type == ControllerType.PS5:
normalized_value = (event.state - (resolution / 2)) / (resolution / 2)
else:
normalized_value = event.state / (resolution / 2)
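# e.g. a PS5 stick reports 0..255, so 255 -> (255 - 128) / 128 ~= +0.99; an Xbox stick reports a signed 16-bit value, so 32767 / 32768 ~= +1.0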
scaled_value = normalized_value * self.controller_config.scale[event.code]
if event.code == 'ABS_Y':
action[0] = scaled_value
elif event.code == 'ABS_X':
action[1] = scaled_value
elif event.code == 'ABS_RZ':
action[2] = scaled_value
elif event.code == 'ABS_Z':
# Flip sign so this will go in the down direction
action[2] = -scaled_value
elif event.code == 'ABS_RX':
action[3] = scaled_value
elif event.code == 'ABS_RY':
action[4] = scaled_value
elif event.code == 'ABS_HAT0X':
action[5] = scaled_value
# Handle button events
elif event.code == 'BTN_TL':
buttons[0] = bool(event.state)
elif event.code == 'BTN_TR':
buttons[1] = bool(event.state)
# Go back to home, B button on xbox controller
elif event.code == 'BTN_EAST':
buttons[2] = bool(event.state)
# Indicate recording is starting, X button on xbox controller
elif event.code == 'BTN_NORTH':
buttons[3] = bool(event.state)
# Start intervention, A button on xbox controller
elif event.code == 'BTN_SOUTH':
buttons[4] = bool(event.state)
# E-Stop, Y button on xbox controller
elif event.code == 'BTN_WEST':
buttons[5] = bool(event.state)
# Update the shared state
self.latest_data["action"] = action
self.latest_data["buttons"] = buttons
except inputs.UnpluggedError:
print("No controller found. Retrying...")
time.sleep(1)
def get_action(self):
"""Returns the latest action and button state from the Joystick."""
action = self.latest_data["action"]
buttons = self.latest_data["buttons"]
return np.array(action), buttons
def close(self):
"""Close the joystick interface and cleanup resources."""
try:
if hasattr(self, '_process') and self._process is not None:
import signal
self._process.terminate()
self._process.join()
self._process = None
except (ImportError, AttributeError):
# Fallback if signal module is not available
if hasattr(self, '_process') and self._process is not None:
self._process.kill()
self._process.join()
self._process = None
class JoystickIntervention():
def __init__(self, controller_type=ControllerType.XBOX, gripper_enabled=True):
self.gripper_enabled = gripper_enabled
self.expert = JoystickInterface(controller_type=controller_type)
self.left, self.right, self.home, self.intervention_start, self.success, self.estop = False, False, False, False, False, False
def action(self) -> np.ndarray:
"""
Output:
- action: the joystick action if nonzero; otherwise, the policy action
"""
deadzone = 0.003
expert_a, buttons = self.expert.get_action()
self.left, self.right, self.home, self.intervention_start, self.success, self.estop = tuple(buttons)
# import logging
# logging.info(f"success on joystick: {self.success}")
for i, a in enumerate(expert_a):
if abs(a) <= deadzone:
expert_a[i] = 0.0
if abs(expert_a[0]) >= 0.003 and expert_a[1] >= 0.003 and expert_a[1] <= 0.005:
expert_a[1] = 0.0
expert_a[3:6] /= 2
if self.gripper_enabled:
if self.left: # close gripper
gripper_action = [0.0]
elif self.right: # open gripper
gripper_action = [0.08]
else:
gripper_action = [0.0]
expert_a = np.concatenate((expert_a, gripper_action), axis=0)
return expert_a
def get_intervention_start(self) -> tuple[bool, bool, bool]:
_, buttons = self.expert.get_action()
_, _, _, self.intervention_start, self.success, self.estop = tuple(buttons)
return self.intervention_start, self.success, self.estop
def close(self):
self.expert.close()
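A minimal usage sketch for the new interface (hypothetical smoke test, not part of the commit; the import path assumes the file lands at lerobot/common/robot_devices/robots/joystick_interface.py, as the Piper robot file below imports it, and that an Xbox-compatible gamepad is plugged in):

import time
from lerobot.common.robot_devices.robots.joystick_interface import JoystickIntervention, ControllerType

teleop = JoystickIntervention(controller_type=ControllerType.XBOX, gripper_enabled=True)
try:
    for _ in range(100):  # ~10 s at 10 Hz
        action = teleop.action()  # 6 scaled axis deltas + 1 gripper command when gripper_enabled
        intervention, success, estop = teleop.get_intervention_start()
        print(action, intervention, success, estop)
        time.sleep(0.1)
finally:
    teleop.close()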

@@ -0,0 +1,431 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run teleoperate:
# python lerobot/scripts/control_robot.py teleoperate --robot-path lerobot/configs/robot/piper.yaml --fps 30
import time
from dataclasses import dataclass, field, replace
import torch
import numpy as np
from lerobot.common.robot_devices.robots.joystick_interface import JoystickIntervention, ControllerType
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
from piper_sdk import *
from lerobot.common.robot_devices.utils import busy_wait
from datetime import datetime
from copy import deepcopy
@dataclass
class PiperRobotConfig:
robot_type: str | None = "piper"
fps: int = 20
cameras: dict[str, Camera] = field(default_factory=lambda: {})
leader_arms: dict = field(default_factory=lambda: {})
# TODO(aliberts): add feature with max_relative target
# TODO(aliberts): add comment on max_relative target
max_relative_target: list[float] | float | None = None
joint_position_relative_bounds: dict[str, np.ndarray] | None = None
class Rate:
def __init__(self, hz: float):
self.period = 1.0 / hz
self.last_time = time.perf_counter()
def sleep(self, elapsed: float):
if elapsed < self.period:
time.sleep(self.period - elapsed)
self.last_time = time.perf_counter()
def rectify_signal(current, previous):
"""Rectify a single signal value using its previous value"""
if abs(current - previous) >= .18:
if current > previous:
current -= .36
else:
current += .36
return current
class EulerAngles:
def __init__(self):
self.prev_rx = 0
self.prev_ry = 0
self.prev_rz = 0
def rectify(self, orientation):
rx_rectified = rectify_signal(orientation[0], self.prev_rx)
ry_rectified = rectify_signal(orientation[1], self.prev_ry)
rz_rectified = rectify_signal(orientation[2], self.prev_rz)
self.prev_rx = rx_rectified
self.prev_ry = ry_rectified
self.prev_rz = rz_rectified
return [rx_rectified, ry_rectified, rz_rectified]
class PiperRobot(ManipulatorRobot):
"""Wrapper of piper_sdk.robot.Robot"""
def __init__(self, config: PiperRobotConfig | None = None, **kwargs):
super().__init__()
if config is None:
config = PiperRobotConfig()
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
# get fps from **kwargs
self.fps = kwargs.get("--fps", self.config.fps)
self.robot_type = self.config.robot_type
self.leader_arms = self.config.leader_arms
self.cameras = self.config.cameras
self.is_connected = False
self.teleop = JoystickIntervention(controller_type=ControllerType.XBOX, gripper_enabled=True)
self.logs = {}
self.action_repeat = 2
self.state_keys = None
self.action_keys = None
self.euler_filter = EulerAngles()
# init piper robot
self.piper = C_PiperInterface("can0")
self.piper.ConnectPort()
self.piper.EnableArm(7)
self.piper.GripperCtrl(0,1000,0x01, 0)
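# The Piper SDK's end-pose and gripper interfaces take integer values (presumably 0.001 mm / 0.001 deg units); 1e6 maps between them and the metre-scale values used in get_state()/send_action().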
self.state_scaling_factor = 1e6
# self.default_pos = [0.200337, 0.020786, 0.289284, 0.179831, 0.010918, 0.173467, 0.0]
self.default_pos = [0.171642, -0.028, 0.165, 0.179831, 0.010918, 0.173467, 0.0]
self.joint_position_relative_bounds = self.config.joint_position_relative_bounds
self.previous_ee_position = np.array([0.0, 0.0, 0.0])
self.current_state = np.array([0.0, 0.0, 0.0, 0.0])
# self.data_rows = []
# self.csv_filename = f"arm_poses_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
# Register the signal handler
import signal
signal.signal(signal.SIGINT, self.signal_handler)
@property
def motor_features(self) -> dict:
# Get the base features from parent class
base_features = super().motor_features
# Modify or add new features
base_features["action"]["dtype"] = "float64" # Change dtype
# Or completely redefine the features
return {
"action": {
"dtype": "float64",
"shape": (4,), # Change shape
"names": ["joint1", "joint2"], # New names
},
"observation.state": {
"dtype": "float64",
"shape": (7,),
"names": ["joint1", "joint2", "velocity1", "velocity2"],
},
}
# Signal handler for graceful shutdown
def signal_handler(self, sig, frame):
print('\nSaving data and exiting...')
import sys
# self.save_to_csv()
self.disable_robot()
sys.exit(0)
def disable_robot(self):
self.piper.DisableArm(7)
self.piper.GripperCtrl(0,1000,0x02, 0)
# Function to save data to CSV
def save_to_csv(self):
import csv
with open(self.csv_filename, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
# Write header
writer.writerow(['Timestamp', 'X', 'Y', 'Z', 'RX', 'RY', 'RZ', 'grippers_angle', 'grippers_effort'])
# Write all stored data
writer.writerows(self.data_rows)
print(f"Data saved to {self.csv_filename}")
def startup_robot(self, piper: C_PiperInterface):
'''
Enable the robot arm and poll its enable status; if the motors are not enabled within 5 s, raise an error.
'''
enable_flag = False
# Timeout in seconds
timeout = 5
# Record the time before entering the loop
start_time = time.time()
elapsed_time_flag = False
while not (enable_flag):
elapsed_time = time.time() - start_time
print("--------------------")
enable_flag = piper.GetArmLowSpdInfoMsgs().motor_1.foc_status.driver_enable_status and \
piper.GetArmLowSpdInfoMsgs().motor_2.foc_status.driver_enable_status and \
piper.GetArmLowSpdInfoMsgs().motor_3.foc_status.driver_enable_status and \
piper.GetArmLowSpdInfoMsgs().motor_4.foc_status.driver_enable_status and \
piper.GetArmLowSpdInfoMsgs().motor_5.foc_status.driver_enable_status and \
piper.GetArmLowSpdInfoMsgs().motor_6.foc_status.driver_enable_status
print("enable status:",enable_flag)
piper.EnableArm(7)
piper.GripperCtrl(0,1000,0x01, 0)
print("--------------------")
# check if timeout
if elapsed_time > timeout:
print("enable timeout....")
elapsed_time_flag = True
enable_flag = False
break
time.sleep(1)
pass
if not elapsed_time_flag:
return enable_flag
else:
print("enable timeout, exit program")
raise RuntimeError("Failed to enable robot motors within timeout period")
def connect(self) -> None:
self.is_connected = self.startup_robot(self.piper)
for name in self.cameras:
self.cameras[name].connect()
self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
self.move_to_home()
# Used when connecting to robot
def move_to_home(self) -> None:
count = 0
while True:
if(count == 0):
print("1-----------")
action = [0.07,0,0.22,0,0.08,0,0.08]
# elif(count == 400):
# print("2-----------")
# action = [0.15,0.0,0.35,0.08,0.08,0.025,0.0] # 0.08 is maximum gripper position
elif(count == 300):
print("2-----------")
action = self.default_pos
count += 1
before_write_t = time.perf_counter()
state = self.get_state()
state = state["state"]
# state[3:6] = self.euler_filter.rectify(state[3:6])
self.send_action(action)
# self.rate.sleep(time.perf_counter() - before_write_t)
if count > 600:
break
# Used when returning to home after finishing a demo
def move_to_home_2(self):
count = 0
while True:
if count <= 100:
action = self.default_pos
action[6] = 0.08
elif count > 100:
action = self.default_pos
action[6] = 0.0
count += 1
before_write_t = time.perf_counter()
state = self.get_state()
state = state["state"]
# state[3:6] = self.euler_filter.rectify(state[3:6])
self.send_action(action)
# self.rate.sleep(time.perf_counter() - before_write_t)
if count > 300:
break
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
# TODO(aliberts): return ndarrays instead of torch.Tensors
if not self.is_connected:
raise ConnectionError()
before_read_t = time.perf_counter()
state = self.get_state()
self.get_intervention_start()
state = state["state"]
# state[3:6] = self.euler_filter.rectify(state[3:6])
# get relative action from joystick
action = self.teleop.action()
# print(action)
# Convert action to numpy array first
action = np.array(action, dtype=np.float32)
action_record = deepcopy(action[:2])
# action_record = np.concatenate([action[:3], [action[-1]]])
for i in range(self.action_repeat):
action[:2] = state[:2] + action_record*(i+1)/self.action_repeat
self.send_action(action)
busy_wait(0.03)
if self.teleop.home:
self.move_to_home_2()
if self.state_keys is None:
self.state_keys = list(state)
if not record_data:
return
# action_record[3:6] -= state[3:6] # just get delta orientation
# it has to be done after send_action
# state[:2] -= self.default_pos[:2]
state = torch.as_tensor(state).to(torch.float32)
action_record = torch.as_tensor(action_record).to(torch.float32)
# print(action_record)
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
# obs_dict.update({"gripper_effort": self.get_gripper_effort()})
action_dict["action"] = action_record
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict, action_dict
def get_state(self) -> dict:
end_effector_pose = self.piper.GetArmEndPoseMsgs()
gripper_pose = self.piper.GetArmGripperMsgs()
self.previous_ee_position = self.current_state[:3]
# Convert to float32 numpy array
self.current_state = np.array([
end_effector_pose.end_pose.X_axis,
end_effector_pose.end_pose.Y_axis,
end_effector_pose.end_pose.Z_axis,
gripper_pose.gripper_state.grippers_angle
], dtype=np.float32) / self.state_scaling_factor
# add velocity to state
velocity = (self.current_state[:3] - self.previous_ee_position) * self.fps
state = np.concatenate([self.current_state, velocity])
# print(f"state: {state}")
return {
"state": state,
}
def get_gripper_state(self) -> float:
gripper_pose = self.piper.GetArmGripperMsgs()
return np.array([gripper_pose.gripper_state.grippers_angle, gripper_pose.gripper_state.grippers_effort])
def get_intervention_start(self) -> tuple[bool, bool]:
intervention_start, success, estop = self.teleop.get_intervention_start()
if estop:
self.disable_robot()
return intervention_start, success
# def get_ee_pos(self) -> list[float]:
# end_effector_pose = self.piper.GetArmEndPoseMsgs()
# gripper_pose = self.piper.GetArmGripperMsgs()
# return [end_effector_pose.end_pose.X_axis,end_effector_pose.end_pose.Y_axis,end_effector_pose.end_pose.Z_axis,gripper_pose.gripper_state.grippers_angle]
def capture_observation(self) -> dict:
# TODO(aliberts): return ndarrays instead of torch.Tensors
before_read_t = time.perf_counter()
state = self.get_state()
self.get_intervention_start()
# state = state["state"]
# state["state"][3:6] = self.euler_filter.rectify(state["state"][3:6])
# state["state"][:2] -= self.default_pos[:2]
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if self.state_keys is None:
self.state_keys = list(state)
state = torch.as_tensor(np.array(list(state.values())).astype(np.float32))
state = state.squeeze(0)
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
# obs_dict.update({"gripper_effort": self.get_gripper_effort()})
return obs_dict
def send_action(self, action: list[float]) -> None:
if not self.is_connected:
raise ConnectionError()
# check if action is tensor and if it is, convert it to list
if isinstance(action, torch.Tensor):
action = action.tolist()
# clip rz value to be between -pi/2 and pi/2 for safety
# action[5] = max(-np.pi/2, min(np.pi/2, action[5]))
X = round(action[0]*self.state_scaling_factor)
Y = round(action[1]*self.state_scaling_factor)
Z = round(self.default_pos[2]*self.state_scaling_factor)
RX = round(self.default_pos[3]*self.state_scaling_factor)
RY = round(self.default_pos[4]*self.state_scaling_factor)
RZ = round(self.default_pos[5]*self.state_scaling_factor)
Gripper = round(action[-1]*self.state_scaling_factor)
self.piper.MotionCtrl_2(0x01, 0x00, 30, 0x00)
self.piper.EndPoseCtrl(X, Y, Z, RX, RY, RZ)
self.piper.GripperCtrl(abs(Gripper), 1000, 0x01, 0)
self.piper.MotionCtrl_2(0x01, 0x00, 30, 0x00)
# Give the CAN bus communication time to complete
busy_wait(0.01)
# TODO(aliberts): return action_sent when motion is limited
return torch.tensor(action)
def print_logs(self) -> None:
pass
# TODO(aliberts): move robot-specific logs logic here
def disconnect(self) -> None:
if self.teleop is not None:
self.teleop.close()
# if len(self.cameras) > 0:
# for cam in self.cameras.values():
# cam.disconnect()
self.is_connected = False
def __del__(self):
self.disconnect()
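A sketch of how this robot class is typically driven outside the control_robot.py entry point mentioned in the header comment (illustrative only, not part of the commit; it assumes a Piper arm on can0 and the config defaults above):

import time
from lerobot.common.robot_devices.utils import busy_wait

robot = PiperRobot()             # defaults from PiperRobotConfig: fps=20, no cameras
robot.connect()                  # enables the motors and moves to the home pose
for _ in range(10 * robot.fps):  # run for ~10 s; Ctrl-C is handled by the class's own SIGINT handler
    start_t = time.perf_counter()
    robot.teleop_step(record_data=False)  # read joystick deltas and stream an end-pose command
    busy_wait(1 / robot.fps - (time.perf_counter() - start_t))
robot.disconnect()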