refactor: add fps parameter to reset_environment and implement frame rate control during reset
This commit is contained in:
parent
bb86dcb5b3
commit
e55e904c3e
|
@ -284,7 +284,7 @@ def control_loop(
|
|||
break
|
||||
|
||||
|
||||
def reset_environment(robot, events, reset_time_s):
|
||||
def reset_environment(robot, events, reset_time_s, fps):
|
||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
@ -297,6 +297,7 @@ def reset_environment(robot, events, reset_time_s):
|
|||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
last_update = 0 # Track the last update time
|
||||
while timestamp < reset_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
robot.teleop_step(record_data=False)
|
||||
timestamp = time.perf_counter() - start_vencod_t
|
||||
|
||||
|
@ -310,6 +311,8 @@ def reset_environment(robot, events, reset_time_s):
|
|||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
robot.disconnect()
|
||||
|
|
|
@ -0,0 +1,215 @@
|
|||
import time
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
import inputs
|
||||
from typing import Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ControllerType(Enum):
|
||||
PS5 = "ps5"
|
||||
XBOX = "xbox"
|
||||
|
||||
@dataclass
|
||||
class ControllerConfig:
|
||||
resolution: dict
|
||||
scale: dict
|
||||
|
||||
class JoystickInterface:
|
||||
"""
|
||||
This class provides an interface to the Joystick/Gamepad.
|
||||
It continuously reads the joystick state and provides
|
||||
a "get_action" method to get the latest action and button state.
|
||||
"""
|
||||
|
||||
CONTROLLER_CONFIGS = {
|
||||
ControllerType.PS5: ControllerConfig(
|
||||
# PS5 controller joystick values have 8 bit resolution [0, 255]
|
||||
resolution={
|
||||
'ABS_X': 2**8,
|
||||
'ABS_Y': 2**8,
|
||||
'ABS_RX': 2**8,
|
||||
'ABS_RY': 2**8,
|
||||
'ABS_Z': 2**8,
|
||||
'ABS_RZ': 2**8,
|
||||
'ABS_HAT0X': 1.0,
|
||||
},
|
||||
scale={
|
||||
'ABS_X': 0.4,
|
||||
'ABS_Y': 0.4,
|
||||
'ABS_RX': 0.5,
|
||||
'ABS_RY': 0.5,
|
||||
'ABS_Z': 0.8,
|
||||
'ABS_RZ': 1.2,
|
||||
'ABS_HAT0X': 0.5,
|
||||
}
|
||||
),
|
||||
ControllerType.XBOX: ControllerConfig(
|
||||
# XBOX controller joystick values have 16 bit resolution [0, 65535]
|
||||
resolution={
|
||||
'ABS_X': 2**16,
|
||||
'ABS_Y': 2**16,
|
||||
'ABS_RX': 2**16,
|
||||
'ABS_RY': 2**16,
|
||||
'ABS_Z': 2**8,
|
||||
'ABS_RZ': 2**8,
|
||||
'ABS_HAT0X': 1.0,
|
||||
},
|
||||
scale={
|
||||
'ABS_X': -0.01,
|
||||
'ABS_Y': -0.005,
|
||||
'ABS_RX': 0.015,
|
||||
'ABS_RY': 0.015,
|
||||
'ABS_Z': 0.0025,
|
||||
'ABS_RZ': 0.0025,
|
||||
'ABS_HAT0X': 0.03,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, controller_type=ControllerType.XBOX):
|
||||
self.controller_type = controller_type
|
||||
self.controller_config = self.CONTROLLER_CONFIGS[controller_type]
|
||||
|
||||
# Manager to handle shared state between processes
|
||||
self.manager = multiprocessing.Manager()
|
||||
self.latest_data = self.manager.dict()
|
||||
self.latest_data["action"] = [0.0] * 6
|
||||
self.latest_data["buttons"] = [False, False, False, False, False, False]
|
||||
|
||||
# Start a process to continuously read Joystick state
|
||||
self._process = multiprocessing.Process(target=self._read_joystick)
|
||||
self._process.daemon = True
|
||||
self._process.start()
|
||||
|
||||
|
||||
def _read_joystick(self):
|
||||
"""Add a try-except to prevent thread crashes"""
|
||||
action = [0.0] * 6
|
||||
buttons = [False, False, False, False, False, False]
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Get fresh events
|
||||
events = inputs.get_gamepad()
|
||||
|
||||
# Process events
|
||||
for event in events:
|
||||
if event.code in self.controller_config.resolution:
|
||||
# Calculate relative changes based on the axis
|
||||
# Normalize the joystick input values to range [-1, 1] expected by the environment
|
||||
resolution = self.controller_config.resolution[event.code]
|
||||
if self.controller_type == ControllerType.PS5:
|
||||
normalized_value = (event.state - (resolution / 2)) / (resolution / 2)
|
||||
else:
|
||||
normalized_value = event.state / (resolution / 2)
|
||||
scaled_value = normalized_value * self.controller_config.scale[event.code]
|
||||
|
||||
if event.code == 'ABS_Y':
|
||||
action[0] = scaled_value
|
||||
elif event.code == 'ABS_X':
|
||||
action[1] = scaled_value
|
||||
elif event.code == 'ABS_RZ':
|
||||
action[2] = scaled_value
|
||||
elif event.code == 'ABS_Z':
|
||||
# Flip sign so this will go in the down direction
|
||||
action[2] = -scaled_value
|
||||
elif event.code == 'ABS_RX':
|
||||
action[3] = scaled_value
|
||||
elif event.code == 'ABS_RY':
|
||||
action[4] = scaled_value
|
||||
elif event.code == 'ABS_HAT0X':
|
||||
action[5] = scaled_value
|
||||
|
||||
# Handle button events
|
||||
elif event.code == 'BTN_TL':
|
||||
buttons[0] = bool(event.state)
|
||||
elif event.code == 'BTN_TR':
|
||||
buttons[1] = bool(event.state)
|
||||
# Go back to home, B button on xbox controller
|
||||
elif event.code == 'BTN_EAST':
|
||||
buttons[2] = bool(event.state)
|
||||
# Indicate recording is starting, X button on xbox controller
|
||||
elif event.code == 'BTN_NORTH':
|
||||
buttons[3] = bool(event.state)
|
||||
# Start intervention, A button on xbox controller
|
||||
elif event.code == 'BTN_SOUTH':
|
||||
buttons[4] = bool(event.state)
|
||||
# E-Stop, Y button on xbox controller
|
||||
elif event.code == 'BTN_WEST':
|
||||
buttons[5] = bool(event.state)
|
||||
|
||||
# Update the shared state
|
||||
self.latest_data["action"] = action
|
||||
self.latest_data["buttons"] = buttons
|
||||
|
||||
except inputs.UnpluggedError:
|
||||
print("No controller found. Retrying...")
|
||||
time.sleep(1)
|
||||
|
||||
def get_action(self):
|
||||
"""Returns the latest action and button state from the Joystick."""
|
||||
action = self.latest_data["action"]
|
||||
buttons = self.latest_data["buttons"]
|
||||
return np.array(action), buttons
|
||||
|
||||
def close(self):
|
||||
"""Close the joystick interface and cleanup resources."""
|
||||
try:
|
||||
if hasattr(self, '_process') and self._process is not None:
|
||||
import signal
|
||||
self._process.terminate()
|
||||
self._process.join()
|
||||
self._process = None
|
||||
except (ImportError, AttributeError):
|
||||
# Fallback if signal module is not available
|
||||
if hasattr(self, '_process') and self._process is not None:
|
||||
self._process.kill()
|
||||
self._process.join()
|
||||
self._process = None
|
||||
|
||||
|
||||
class JoystickIntervention():
|
||||
def __init__(self, controller_type=ControllerType.XBOX, gripper_enabled=True):
|
||||
self.gripper_enabled = gripper_enabled
|
||||
self.expert = JoystickInterface(controller_type=controller_type)
|
||||
self.left, self.right, self.home, self.intervention_start, self.success, self.estop = False, False, False, False, False, False
|
||||
|
||||
def action(self) -> np.ndarray:
|
||||
"""
|
||||
Output:
|
||||
- action: joystick action if nonezero; else, policy action
|
||||
"""
|
||||
deadzone = 0.003
|
||||
|
||||
expert_a, buttons = self.expert.get_action()
|
||||
self.left, self.right, self.home, self.intervention_start, self.success, self.estop = tuple(buttons)
|
||||
# import logging
|
||||
# logging.info(f"success on joystick: {self.success}")
|
||||
|
||||
for i, a in enumerate(expert_a):
|
||||
if abs(a) <= deadzone:
|
||||
expert_a[i] = 0.0
|
||||
if abs(expert_a[0]) >= 0.003 and expert_a[1] >= 0.003 and expert_a[1] <= 0.005:
|
||||
expert_a[1] = 0.0
|
||||
expert_a[3:6] /= 2
|
||||
|
||||
if self.gripper_enabled:
|
||||
if self.left: # close gripper
|
||||
gripper_action = [0.0]
|
||||
elif self.right: # open gripper
|
||||
gripper_action = [0.08]
|
||||
else:
|
||||
gripper_action = [0.0]
|
||||
expert_a = np.concatenate((expert_a, gripper_action), axis=0)
|
||||
|
||||
return expert_a
|
||||
|
||||
def get_intervention_start(self) -> bool:
|
||||
_, buttons = self.expert.get_action()
|
||||
_, _, _, self.intervention_start, self.success, self.estop = tuple(buttons)
|
||||
return self.intervention_start, self.success, self.estop
|
||||
|
||||
def close(self):
|
||||
self.expert.close()
|
|
@ -0,0 +1,431 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# To run teleoperate:
|
||||
# python lerobot/scripts/control_robot.py teleoperate --robot-path lerobot/configs/robot/piper.yaml --fps 30
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field, replace
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from lerobot.common.robot_devices.robots.joystick_interface import JoystickIntervention, ControllerType
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
from piper_sdk import *
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
|
||||
from datetime import datetime
|
||||
from copy import deepcopy
|
||||
|
||||
@dataclass
|
||||
class PiperRobotConfig:
|
||||
robot_type: str | None = "piper"
|
||||
fps: int = 20
|
||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||
leader_arms: dict = field(default_factory=lambda: {})
|
||||
# TODO(aliberts): add feature with max_relative target
|
||||
# TODO(aliberts): add comment on max_relative target
|
||||
max_relative_target: list[float] | float | None = None
|
||||
joint_position_relative_bounds: dict[np.ndarray] | None = None
|
||||
|
||||
|
||||
class Rate:
|
||||
def __init__(self, hz: float):
|
||||
self.period = 1.0 / hz
|
||||
self.last_time = time.perf_counter()
|
||||
|
||||
def sleep(self, elapsed: float):
|
||||
if elapsed < self.period:
|
||||
time.sleep(self.period - elapsed)
|
||||
self.last_time = time.perf_counter()
|
||||
|
||||
|
||||
def rectify_signal(current, previous):
|
||||
"""Rectify a single signal value using its previous value"""
|
||||
if abs(current - previous) >= .18:
|
||||
if current > previous:
|
||||
current -= .36
|
||||
else:
|
||||
current += .36
|
||||
return current
|
||||
|
||||
|
||||
class EulerAngles:
|
||||
def __init__(self):
|
||||
self.prev_rx = 0
|
||||
self.prev_ry = 0
|
||||
self.prev_rz = 0
|
||||
|
||||
def rectify(self, orientation):
|
||||
rx_rectified = rectify_signal(orientation[0], self.prev_rx)
|
||||
ry_rectified = rectify_signal(orientation[1], self.prev_ry)
|
||||
rz_rectified = rectify_signal(orientation[2], self.prev_rz)
|
||||
|
||||
self.prev_rx = rx_rectified
|
||||
self.prev_ry = ry_rectified
|
||||
self.prev_rz = rz_rectified
|
||||
|
||||
return [rx_rectified, ry_rectified, rz_rectified]
|
||||
|
||||
|
||||
class PiperRobot(ManipulatorRobot):
|
||||
"""Wrapper of piper_sdk.robot.Robot"""
|
||||
|
||||
def __init__(self, config: PiperRobotConfig | None = None, **kwargs):
|
||||
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = PiperRobotConfig()
|
||||
# Overwrite config arguments using kwargs
|
||||
self.config = replace(config, **kwargs)
|
||||
# get fps fron **kwargs
|
||||
self.fps = kwargs.get("--fps", self.config.fps)
|
||||
|
||||
self.robot_type = self.config.robot_type
|
||||
self.leader_arms = self.config.leader_arms
|
||||
self.cameras = self.config.cameras
|
||||
self.is_connected = False
|
||||
self.teleop = JoystickIntervention(controller_type=ControllerType.XBOX, gripper_enabled=True)
|
||||
self.logs = {}
|
||||
self.action_repeat = 2
|
||||
self.state_keys = None
|
||||
self.action_keys = None
|
||||
self.euler_filter = EulerAngles()
|
||||
# init piper robot
|
||||
self.piper = C_PiperInterface("can0")
|
||||
self.piper.ConnectPort()
|
||||
self.piper.EnableArm(7)
|
||||
self.piper.GripperCtrl(0,1000,0x01, 0)
|
||||
self.state_scaling_factor = 1e6
|
||||
# self.default_pos = [0.200337, 0.020786, 0.289284, 0.179831, 0.010918, 0.173467, 0.0]
|
||||
self.default_pos = [0.171642, -0.028, 0.165, 0.179831, 0.010918, 0.173467, 0.0]
|
||||
self.joint_position_relative_bounds = self.config.joint_position_relative_bounds
|
||||
self.previous_ee_position = np.array([0.0, 0.0, 0.0])
|
||||
self.current_state = np.array([0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
# self.data_rows = []
|
||||
# self.csv_filename = f"arm_poses_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
|
||||
# Register the signal handler
|
||||
import signal
|
||||
signal.signal(signal.SIGINT, self.signal_handler)
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
# Get the base features from parent class
|
||||
base_features = super().motor_features
|
||||
|
||||
# Modify or add new features
|
||||
base_features["action"]["dtype"] = "float64" # Change dtype
|
||||
# Or completely redefine the features
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float64",
|
||||
"shape": (4,), # Change shape
|
||||
"names": ["joint1", "joint2"], # New names
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float64",
|
||||
"shape": (7,),
|
||||
"names": ["joint1", "joint2", "velocity1", "velocity2"],
|
||||
},
|
||||
}
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(self, sig, frame):
|
||||
print('\nSaving data and exiting...')
|
||||
import sys
|
||||
# self.save_to_csv()
|
||||
self.disable_robot()
|
||||
sys.exit(0)
|
||||
|
||||
def disable_robot(self):
|
||||
self.piper.DisableArm(7)
|
||||
self.piper.GripperCtrl(0,1000,0x02, 0)
|
||||
|
||||
# Function to save data to CSV
|
||||
def save_to_csv(self):
|
||||
import csv
|
||||
with open(self.csv_filename, 'w', newline='') as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
# Write header
|
||||
writer.writerow(['Timestamp', 'X', 'Y', 'Z', 'RX', 'RY', 'RZ', 'grippers_angle', 'grippers_effort'])
|
||||
# Write all stored data
|
||||
writer.writerows(self.data_rows)
|
||||
print(f"Data saved to {self.csv_filename}")
|
||||
|
||||
def startup_robot(self, piper:C_PiperInterface):
|
||||
'''
|
||||
enable robot and check enable status, try 5s, if enable timeout, exit program
|
||||
'''
|
||||
enable_flag = False
|
||||
# 设置超时时间(秒)
|
||||
timeout = 5
|
||||
# 记录进入循环前的时间
|
||||
start_time = time.time()
|
||||
elapsed_time_flag = False
|
||||
while not (enable_flag):
|
||||
elapsed_time = time.time() - start_time
|
||||
print("--------------------")
|
||||
enable_flag = piper.GetArmLowSpdInfoMsgs().motor_1.foc_status.driver_enable_status and \
|
||||
piper.GetArmLowSpdInfoMsgs().motor_2.foc_status.driver_enable_status and \
|
||||
piper.GetArmLowSpdInfoMsgs().motor_3.foc_status.driver_enable_status and \
|
||||
piper.GetArmLowSpdInfoMsgs().motor_4.foc_status.driver_enable_status and \
|
||||
piper.GetArmLowSpdInfoMsgs().motor_5.foc_status.driver_enable_status and \
|
||||
piper.GetArmLowSpdInfoMsgs().motor_6.foc_status.driver_enable_status
|
||||
print("enable status:",enable_flag)
|
||||
piper.EnableArm(7)
|
||||
piper.GripperCtrl(0,1000,0x01, 0)
|
||||
print("--------------------")
|
||||
# check if timeout
|
||||
if elapsed_time > timeout:
|
||||
print("enable timeout....")
|
||||
elapsed_time_flag = True
|
||||
enable_flag = False
|
||||
break
|
||||
time.sleep(1)
|
||||
pass
|
||||
if not elapsed_time_flag:
|
||||
return enable_flag
|
||||
else:
|
||||
print("enable timeout, exit program")
|
||||
raise RuntimeError("Failed to enable robot motors within timeout period")
|
||||
|
||||
def connect(self) -> None:
|
||||
self.is_connected = self.startup_robot(self.piper)
|
||||
|
||||
for name in self.cameras:
|
||||
self.cameras[name].connect()
|
||||
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||
|
||||
if not self.is_connected:
|
||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
raise ConnectionError()
|
||||
|
||||
self.move_to_home()
|
||||
|
||||
# Used when connecting to robot
|
||||
def move_to_home(self) -> None:
|
||||
count = 0
|
||||
while True:
|
||||
if(count == 0):
|
||||
print("1-----------")
|
||||
action = [0.07,0,0.22,0,0.08,0,0.08]
|
||||
# elif(count == 400):
|
||||
# print("2-----------")
|
||||
# action = [0.15,0.0,0.35,0.08,0.08,0.025,0.0] # 0.08 is maximum gripper position
|
||||
elif(count == 300):
|
||||
print("2-----------")
|
||||
action = self.default_pos
|
||||
count += 1
|
||||
before_write_t = time.perf_counter()
|
||||
state = self.get_state()
|
||||
state = state["state"]
|
||||
# state[3:6] = self.euler_filter.rectify(state[3:6])
|
||||
self.send_action(action)
|
||||
# self.rate.sleep(time.perf_counter() - before_write_t)
|
||||
if count > 600:
|
||||
break
|
||||
|
||||
# Used when returning to home after finishing a demo
|
||||
def move_to_home_2(self):
|
||||
count = 0
|
||||
while True:
|
||||
if count <= 100:
|
||||
action = self.default_pos
|
||||
action[6] = 0.08
|
||||
elif count > 100:
|
||||
action = self.default_pos
|
||||
action[6] = 0.0
|
||||
count += 1
|
||||
before_write_t = time.perf_counter()
|
||||
state = self.get_state()
|
||||
state = state["state"]
|
||||
# state[3:6] = self.euler_filter.rectify(state[3:6])
|
||||
self.send_action(action)
|
||||
# self.rate.sleep(time.perf_counter() - before_write_t)
|
||||
if count > 300:
|
||||
break
|
||||
|
||||
|
||||
def teleop_step(
|
||||
self, record_data=False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
before_read_t = time.perf_counter()
|
||||
state = self.get_state()
|
||||
self.get_intervention_start()
|
||||
state = state["state"]
|
||||
# state[3:6] = self.euler_filter.rectify(state[3:6])
|
||||
# get relative action from joystick
|
||||
action = self.teleop.action()
|
||||
# print(action)
|
||||
# Convert action to numpy array first
|
||||
action = np.array(action, dtype=np.float32)
|
||||
action_record = deepcopy(action[:2])
|
||||
# action_record = np.concatenate([action[:3], [action[-1]]])
|
||||
|
||||
for i in range(self.action_repeat):
|
||||
action[:2] = state[:2] + action_record*(i+1)/self.action_repeat
|
||||
self.send_action(action)
|
||||
busy_wait(0.03)
|
||||
|
||||
if self.teleop.home:
|
||||
self.move_to_home_2()
|
||||
if self.state_keys is None:
|
||||
self.state_keys = list(state)
|
||||
|
||||
if not record_data:
|
||||
return
|
||||
# action_record[3:6] -= state[3:6] # just get delta orientation
|
||||
# it has to be done after send_action
|
||||
# state[:2] -= self.default_pos[:2]
|
||||
state = torch.as_tensor(state).to(torch.float32)
|
||||
action_record = torch.as_tensor(action_record).to(torch.float32)
|
||||
# print(action_record)
|
||||
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
# Populate output dictionnaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
obs_dict["observation.state"] = state
|
||||
# obs_dict.update({"gripper_effort": self.get_gripper_effort()})
|
||||
action_dict["action"] = action_record
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = images[name]
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
def get_state(self) -> dict:
|
||||
end_effector_pose = self.piper.GetArmEndPoseMsgs()
|
||||
gripper_pose = self.piper.GetArmGripperMsgs()
|
||||
self.previous_ee_position = self.current_state[:3]
|
||||
|
||||
# Convert to float32 numpy array
|
||||
self.current_state = np.array([
|
||||
end_effector_pose.end_pose.X_axis,
|
||||
end_effector_pose.end_pose.Y_axis,
|
||||
end_effector_pose.end_pose.Z_axis,
|
||||
gripper_pose.gripper_state.grippers_angle
|
||||
], dtype=np.float32) / self.state_scaling_factor
|
||||
# add velocity to state
|
||||
velocity = (self.current_state[:3] - self.previous_ee_position) * self.fps
|
||||
state = np.concatenate([self.current_state, velocity])
|
||||
|
||||
# print(f"state: {state}")
|
||||
|
||||
return {
|
||||
"state": state,
|
||||
}
|
||||
|
||||
def get_gripper_state(self) -> float:
|
||||
gripper_pose = self.piper.GetArmGripperMsgs()
|
||||
return np.array([gripper_pose.gripper_state.grippers_angle, gripper_pose.gripper_state.grippers_effort])
|
||||
|
||||
def get_intervention_start(self) -> bool:
|
||||
intervention_start, success, estop = self.teleop.get_intervention_start()
|
||||
if estop:
|
||||
self.disable_robot()
|
||||
return intervention_start, success
|
||||
|
||||
# def get_ee_pos(self) -> list[float]:
|
||||
# end_effector_pose = self.piper.GetArmEndPoseMsgs()
|
||||
# gripper_pose = self.piper.GetArmGripperMsgs()
|
||||
# return [end_effector_pose.end_pose.X_axis,end_effector_pose.end_pose.Y_axis,end_effector_pose.end_pose.Z_axis,gripper_pose.gripper_state.grippers_angle]
|
||||
|
||||
def capture_observation(self) -> dict:
|
||||
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
||||
before_read_t = time.perf_counter()
|
||||
state = self.get_state()
|
||||
self.get_intervention_start()
|
||||
# state = state["state"]
|
||||
# state["state"][3:6] = self.euler_filter.rectify(state["state"][3:6])
|
||||
# state["state"][:2] -= self.default_pos[:2]
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
if self.state_keys is None:
|
||||
self.state_keys = list(state)
|
||||
|
||||
state = torch.as_tensor(np.array(list(state.values())).astype(np.float32))
|
||||
state = state.squeeze(0)
|
||||
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
# Populate output dictionnaries
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = state
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = images[name]
|
||||
# obs_dict.update({"gripper_effort": self.get_gripper_effort()})
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: list[float]) -> None:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
# check if action is tensor and if it is, convert it to list
|
||||
if isinstance(action, torch.Tensor):
|
||||
action = action.tolist()
|
||||
# clip rz value to be between -pi/2 and pi/2 for safety
|
||||
# action[5] = max(-np.pi/2, min(np.pi/2, action[5]))
|
||||
X = round(action[0]*self.state_scaling_factor)
|
||||
Y = round(action[1]*self.state_scaling_factor)
|
||||
Z = round(self.default_pos[2]*self.state_scaling_factor)
|
||||
RX = round(self.default_pos[3]*self.state_scaling_factor)
|
||||
RY = round(self.default_pos[4]*self.state_scaling_factor)
|
||||
RZ = round(self.default_pos[5]*self.state_scaling_factor)
|
||||
Gripper = round(action[-1]*self.state_scaling_factor)
|
||||
|
||||
self.piper.MotionCtrl_2(0x01, 0x00, 30, 0x00)
|
||||
self.piper.EndPoseCtrl(X, Y, Z, RX, RY, RZ)
|
||||
self.piper.GripperCtrl(abs(Gripper), 1000, 0x01, 0)
|
||||
self.piper.MotionCtrl_2(0x01, 0x00, 30, 0x00)
|
||||
# It is needed to give time for the CAN bus communication to complete
|
||||
busy_wait(0.01)
|
||||
|
||||
# TODO(aliberts): return action_sent when motion is limited
|
||||
return torch.tensor(action)
|
||||
|
||||
def print_logs(self) -> None:
|
||||
pass
|
||||
# TODO(aliberts): move robot-specific logs logic here
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.teleop is not None:
|
||||
self.teleop.close()
|
||||
|
||||
# if len(self.cameras) > 0:
|
||||
# for cam in self.cameras.values():
|
||||
# cam.disconnect()
|
||||
|
||||
self.is_connected = False
|
||||
|
||||
def __del__(self):
|
||||
self.disconnect()
|
Loading…
Reference in New Issue