refactor: add fps parameter to reset_environment and implement frame rate control during reset
This commit is contained in:
parent
bb86dcb5b3
commit
e55e904c3e
|
@ -284,7 +284,7 @@ def control_loop(
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def reset_environment(robot, events, reset_time_s):
|
def reset_environment(robot, events, reset_time_s, fps):
|
||||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||||
if has_method(robot, "teleop_safety_stop"):
|
if has_method(robot, "teleop_safety_stop"):
|
||||||
robot.teleop_safety_stop()
|
robot.teleop_safety_stop()
|
||||||
|
@ -297,6 +297,7 @@ def reset_environment(robot, events, reset_time_s):
|
||||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||||
last_update = 0 # Track the last update time
|
last_update = 0 # Track the last update time
|
||||||
while timestamp < reset_time_s:
|
while timestamp < reset_time_s:
|
||||||
|
start_loop_t = time.perf_counter()
|
||||||
robot.teleop_step(record_data=False)
|
robot.teleop_step(record_data=False)
|
||||||
timestamp = time.perf_counter() - start_vencod_t
|
timestamp = time.perf_counter() - start_vencod_t
|
||||||
|
|
||||||
|
@ -310,6 +311,8 @@ def reset_environment(robot, events, reset_time_s):
|
||||||
events["exit_early"] = False
|
events["exit_early"] = False
|
||||||
break
|
break
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
def stop_recording(robot, listener, display_cameras):
|
def stop_recording(robot, listener, display_cameras):
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
|
|
|
@ -0,0 +1,215 @@
|
||||||
|
import time
|
||||||
|
import multiprocessing
|
||||||
|
import numpy as np
|
||||||
|
import inputs
|
||||||
|
from typing import Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ControllerType(Enum):
|
||||||
|
PS5 = "ps5"
|
||||||
|
XBOX = "xbox"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ControllerConfig:
|
||||||
|
resolution: dict
|
||||||
|
scale: dict
|
||||||
|
|
||||||
|
class JoystickInterface:
|
||||||
|
"""
|
||||||
|
This class provides an interface to the Joystick/Gamepad.
|
||||||
|
It continuously reads the joystick state and provides
|
||||||
|
a "get_action" method to get the latest action and button state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
CONTROLLER_CONFIGS = {
|
||||||
|
ControllerType.PS5: ControllerConfig(
|
||||||
|
# PS5 controller joystick values have 8 bit resolution [0, 255]
|
||||||
|
resolution={
|
||||||
|
'ABS_X': 2**8,
|
||||||
|
'ABS_Y': 2**8,
|
||||||
|
'ABS_RX': 2**8,
|
||||||
|
'ABS_RY': 2**8,
|
||||||
|
'ABS_Z': 2**8,
|
||||||
|
'ABS_RZ': 2**8,
|
||||||
|
'ABS_HAT0X': 1.0,
|
||||||
|
},
|
||||||
|
scale={
|
||||||
|
'ABS_X': 0.4,
|
||||||
|
'ABS_Y': 0.4,
|
||||||
|
'ABS_RX': 0.5,
|
||||||
|
'ABS_RY': 0.5,
|
||||||
|
'ABS_Z': 0.8,
|
||||||
|
'ABS_RZ': 1.2,
|
||||||
|
'ABS_HAT0X': 0.5,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
ControllerType.XBOX: ControllerConfig(
|
||||||
|
# XBOX controller joystick values have 16 bit resolution [0, 65535]
|
||||||
|
resolution={
|
||||||
|
'ABS_X': 2**16,
|
||||||
|
'ABS_Y': 2**16,
|
||||||
|
'ABS_RX': 2**16,
|
||||||
|
'ABS_RY': 2**16,
|
||||||
|
'ABS_Z': 2**8,
|
||||||
|
'ABS_RZ': 2**8,
|
||||||
|
'ABS_HAT0X': 1.0,
|
||||||
|
},
|
||||||
|
scale={
|
||||||
|
'ABS_X': -0.01,
|
||||||
|
'ABS_Y': -0.005,
|
||||||
|
'ABS_RX': 0.015,
|
||||||
|
'ABS_RY': 0.015,
|
||||||
|
'ABS_Z': 0.0025,
|
||||||
|
'ABS_RZ': 0.0025,
|
||||||
|
'ABS_HAT0X': 0.03,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, controller_type=ControllerType.XBOX):
|
||||||
|
self.controller_type = controller_type
|
||||||
|
self.controller_config = self.CONTROLLER_CONFIGS[controller_type]
|
||||||
|
|
||||||
|
# Manager to handle shared state between processes
|
||||||
|
self.manager = multiprocessing.Manager()
|
||||||
|
self.latest_data = self.manager.dict()
|
||||||
|
self.latest_data["action"] = [0.0] * 6
|
||||||
|
self.latest_data["buttons"] = [False, False, False, False, False, False]
|
||||||
|
|
||||||
|
# Start a process to continuously read Joystick state
|
||||||
|
self._process = multiprocessing.Process(target=self._read_joystick)
|
||||||
|
self._process.daemon = True
|
||||||
|
self._process.start()
|
||||||
|
|
||||||
|
|
||||||
|
def _read_joystick(self):
|
||||||
|
"""Add a try-except to prevent thread crashes"""
|
||||||
|
action = [0.0] * 6
|
||||||
|
buttons = [False, False, False, False, False, False]
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Get fresh events
|
||||||
|
events = inputs.get_gamepad()
|
||||||
|
|
||||||
|
# Process events
|
||||||
|
for event in events:
|
||||||
|
if event.code in self.controller_config.resolution:
|
||||||
|
# Calculate relative changes based on the axis
|
||||||
|
# Normalize the joystick input values to range [-1, 1] expected by the environment
|
||||||
|
resolution = self.controller_config.resolution[event.code]
|
||||||
|
if self.controller_type == ControllerType.PS5:
|
||||||
|
normalized_value = (event.state - (resolution / 2)) / (resolution / 2)
|
||||||
|
else:
|
||||||
|
normalized_value = event.state / (resolution / 2)
|
||||||
|
scaled_value = normalized_value * self.controller_config.scale[event.code]
|
||||||
|
|
||||||
|
if event.code == 'ABS_Y':
|
||||||
|
action[0] = scaled_value
|
||||||
|
elif event.code == 'ABS_X':
|
||||||
|
action[1] = scaled_value
|
||||||
|
elif event.code == 'ABS_RZ':
|
||||||
|
action[2] = scaled_value
|
||||||
|
elif event.code == 'ABS_Z':
|
||||||
|
# Flip sign so this will go in the down direction
|
||||||
|
action[2] = -scaled_value
|
||||||
|
elif event.code == 'ABS_RX':
|
||||||
|
action[3] = scaled_value
|
||||||
|
elif event.code == 'ABS_RY':
|
||||||
|
action[4] = scaled_value
|
||||||
|
elif event.code == 'ABS_HAT0X':
|
||||||
|
action[5] = scaled_value
|
||||||
|
|
||||||
|
# Handle button events
|
||||||
|
elif event.code == 'BTN_TL':
|
||||||
|
buttons[0] = bool(event.state)
|
||||||
|
elif event.code == 'BTN_TR':
|
||||||
|
buttons[1] = bool(event.state)
|
||||||
|
# Go back to home, B button on xbox controller
|
||||||
|
elif event.code == 'BTN_EAST':
|
||||||
|
buttons[2] = bool(event.state)
|
||||||
|
# Indicate recording is starting, X button on xbox controller
|
||||||
|
elif event.code == 'BTN_NORTH':
|
||||||
|
buttons[3] = bool(event.state)
|
||||||
|
# Start intervention, A button on xbox controller
|
||||||
|
elif event.code == 'BTN_SOUTH':
|
||||||
|
buttons[4] = bool(event.state)
|
||||||
|
# E-Stop, Y button on xbox controller
|
||||||
|
elif event.code == 'BTN_WEST':
|
||||||
|
buttons[5] = bool(event.state)
|
||||||
|
|
||||||
|
# Update the shared state
|
||||||
|
self.latest_data["action"] = action
|
||||||
|
self.latest_data["buttons"] = buttons
|
||||||
|
|
||||||
|
except inputs.UnpluggedError:
|
||||||
|
print("No controller found. Retrying...")
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def get_action(self):
|
||||||
|
"""Returns the latest action and button state from the Joystick."""
|
||||||
|
action = self.latest_data["action"]
|
||||||
|
buttons = self.latest_data["buttons"]
|
||||||
|
return np.array(action), buttons
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the joystick interface and cleanup resources."""
|
||||||
|
try:
|
||||||
|
if hasattr(self, '_process') and self._process is not None:
|
||||||
|
import signal
|
||||||
|
self._process.terminate()
|
||||||
|
self._process.join()
|
||||||
|
self._process = None
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
# Fallback if signal module is not available
|
||||||
|
if hasattr(self, '_process') and self._process is not None:
|
||||||
|
self._process.kill()
|
||||||
|
self._process.join()
|
||||||
|
self._process = None
|
||||||
|
|
||||||
|
|
||||||
|
class JoystickIntervention():
|
||||||
|
def __init__(self, controller_type=ControllerType.XBOX, gripper_enabled=True):
|
||||||
|
self.gripper_enabled = gripper_enabled
|
||||||
|
self.expert = JoystickInterface(controller_type=controller_type)
|
||||||
|
self.left, self.right, self.home, self.intervention_start, self.success, self.estop = False, False, False, False, False, False
|
||||||
|
|
||||||
|
def action(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Output:
|
||||||
|
- action: joystick action if nonezero; else, policy action
|
||||||
|
"""
|
||||||
|
deadzone = 0.003
|
||||||
|
|
||||||
|
expert_a, buttons = self.expert.get_action()
|
||||||
|
self.left, self.right, self.home, self.intervention_start, self.success, self.estop = tuple(buttons)
|
||||||
|
# import logging
|
||||||
|
# logging.info(f"success on joystick: {self.success}")
|
||||||
|
|
||||||
|
for i, a in enumerate(expert_a):
|
||||||
|
if abs(a) <= deadzone:
|
||||||
|
expert_a[i] = 0.0
|
||||||
|
if abs(expert_a[0]) >= 0.003 and expert_a[1] >= 0.003 and expert_a[1] <= 0.005:
|
||||||
|
expert_a[1] = 0.0
|
||||||
|
expert_a[3:6] /= 2
|
||||||
|
|
||||||
|
if self.gripper_enabled:
|
||||||
|
if self.left: # close gripper
|
||||||
|
gripper_action = [0.0]
|
||||||
|
elif self.right: # open gripper
|
||||||
|
gripper_action = [0.08]
|
||||||
|
else:
|
||||||
|
gripper_action = [0.0]
|
||||||
|
expert_a = np.concatenate((expert_a, gripper_action), axis=0)
|
||||||
|
|
||||||
|
return expert_a
|
||||||
|
|
||||||
|
def get_intervention_start(self) -> bool:
|
||||||
|
_, buttons = self.expert.get_action()
|
||||||
|
_, _, _, self.intervention_start, self.success, self.estop = tuple(buttons)
|
||||||
|
return self.intervention_start, self.success, self.estop
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.expert.close()
|
|
@ -0,0 +1,431 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# To run teleoperate:
|
||||||
|
# python lerobot/scripts/control_robot.py teleoperate --robot-path lerobot/configs/robot/piper.yaml --fps 30
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field, replace
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from lerobot.common.robot_devices.robots.joystick_interface import JoystickIntervention, ControllerType
|
||||||
|
|
||||||
|
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||||
|
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||||
|
from piper_sdk import *
|
||||||
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PiperRobotConfig:
|
||||||
|
robot_type: str | None = "piper"
|
||||||
|
fps: int = 20
|
||||||
|
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||||
|
leader_arms: dict = field(default_factory=lambda: {})
|
||||||
|
# TODO(aliberts): add feature with max_relative target
|
||||||
|
# TODO(aliberts): add comment on max_relative target
|
||||||
|
max_relative_target: list[float] | float | None = None
|
||||||
|
joint_position_relative_bounds: dict[np.ndarray] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Rate:
|
||||||
|
def __init__(self, hz: float):
|
||||||
|
self.period = 1.0 / hz
|
||||||
|
self.last_time = time.perf_counter()
|
||||||
|
|
||||||
|
def sleep(self, elapsed: float):
|
||||||
|
if elapsed < self.period:
|
||||||
|
time.sleep(self.period - elapsed)
|
||||||
|
self.last_time = time.perf_counter()
|
||||||
|
|
||||||
|
|
||||||
|
def rectify_signal(current, previous):
|
||||||
|
"""Rectify a single signal value using its previous value"""
|
||||||
|
if abs(current - previous) >= .18:
|
||||||
|
if current > previous:
|
||||||
|
current -= .36
|
||||||
|
else:
|
||||||
|
current += .36
|
||||||
|
return current
|
||||||
|
|
||||||
|
|
||||||
|
class EulerAngles:
|
||||||
|
def __init__(self):
|
||||||
|
self.prev_rx = 0
|
||||||
|
self.prev_ry = 0
|
||||||
|
self.prev_rz = 0
|
||||||
|
|
||||||
|
def rectify(self, orientation):
|
||||||
|
rx_rectified = rectify_signal(orientation[0], self.prev_rx)
|
||||||
|
ry_rectified = rectify_signal(orientation[1], self.prev_ry)
|
||||||
|
rz_rectified = rectify_signal(orientation[2], self.prev_rz)
|
||||||
|
|
||||||
|
self.prev_rx = rx_rectified
|
||||||
|
self.prev_ry = ry_rectified
|
||||||
|
self.prev_rz = rz_rectified
|
||||||
|
|
||||||
|
return [rx_rectified, ry_rectified, rz_rectified]
|
||||||
|
|
||||||
|
|
||||||
|
class PiperRobot(ManipulatorRobot):
|
||||||
|
"""Wrapper of piper_sdk.robot.Robot"""
|
||||||
|
|
||||||
|
def __init__(self, config: PiperRobotConfig | None = None, **kwargs):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
if config is None:
|
||||||
|
config = PiperRobotConfig()
|
||||||
|
# Overwrite config arguments using kwargs
|
||||||
|
self.config = replace(config, **kwargs)
|
||||||
|
# get fps fron **kwargs
|
||||||
|
self.fps = kwargs.get("--fps", self.config.fps)
|
||||||
|
|
||||||
|
self.robot_type = self.config.robot_type
|
||||||
|
self.leader_arms = self.config.leader_arms
|
||||||
|
self.cameras = self.config.cameras
|
||||||
|
self.is_connected = False
|
||||||
|
self.teleop = JoystickIntervention(controller_type=ControllerType.XBOX, gripper_enabled=True)
|
||||||
|
self.logs = {}
|
||||||
|
self.action_repeat = 2
|
||||||
|
self.state_keys = None
|
||||||
|
self.action_keys = None
|
||||||
|
self.euler_filter = EulerAngles()
|
||||||
|
# init piper robot
|
||||||
|
self.piper = C_PiperInterface("can0")
|
||||||
|
self.piper.ConnectPort()
|
||||||
|
self.piper.EnableArm(7)
|
||||||
|
self.piper.GripperCtrl(0,1000,0x01, 0)
|
||||||
|
self.state_scaling_factor = 1e6
|
||||||
|
# self.default_pos = [0.200337, 0.020786, 0.289284, 0.179831, 0.010918, 0.173467, 0.0]
|
||||||
|
self.default_pos = [0.171642, -0.028, 0.165, 0.179831, 0.010918, 0.173467, 0.0]
|
||||||
|
self.joint_position_relative_bounds = self.config.joint_position_relative_bounds
|
||||||
|
self.previous_ee_position = np.array([0.0, 0.0, 0.0])
|
||||||
|
self.current_state = np.array([0.0, 0.0, 0.0, 0.0])
|
||||||
|
|
||||||
|
# self.data_rows = []
|
||||||
|
# self.csv_filename = f"arm_poses_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
|
||||||
|
# Register the signal handler
|
||||||
|
import signal
|
||||||
|
signal.signal(signal.SIGINT, self.signal_handler)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def motor_features(self) -> dict:
|
||||||
|
# Get the base features from parent class
|
||||||
|
base_features = super().motor_features
|
||||||
|
|
||||||
|
# Modify or add new features
|
||||||
|
base_features["action"]["dtype"] = "float64" # Change dtype
|
||||||
|
# Or completely redefine the features
|
||||||
|
return {
|
||||||
|
"action": {
|
||||||
|
"dtype": "float64",
|
||||||
|
"shape": (4,), # Change shape
|
||||||
|
"names": ["joint1", "joint2"], # New names
|
||||||
|
},
|
||||||
|
"observation.state": {
|
||||||
|
"dtype": "float64",
|
||||||
|
"shape": (7,),
|
||||||
|
"names": ["joint1", "joint2", "velocity1", "velocity2"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
# Signal handler for graceful shutdown
|
||||||
|
def signal_handler(self, sig, frame):
|
||||||
|
print('\nSaving data and exiting...')
|
||||||
|
import sys
|
||||||
|
# self.save_to_csv()
|
||||||
|
self.disable_robot()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
def disable_robot(self):
|
||||||
|
self.piper.DisableArm(7)
|
||||||
|
self.piper.GripperCtrl(0,1000,0x02, 0)
|
||||||
|
|
||||||
|
# Function to save data to CSV
|
||||||
|
def save_to_csv(self):
|
||||||
|
import csv
|
||||||
|
with open(self.csv_filename, 'w', newline='') as csvfile:
|
||||||
|
writer = csv.writer(csvfile)
|
||||||
|
# Write header
|
||||||
|
writer.writerow(['Timestamp', 'X', 'Y', 'Z', 'RX', 'RY', 'RZ', 'grippers_angle', 'grippers_effort'])
|
||||||
|
# Write all stored data
|
||||||
|
writer.writerows(self.data_rows)
|
||||||
|
print(f"Data saved to {self.csv_filename}")
|
||||||
|
|
||||||
|
def startup_robot(self, piper:C_PiperInterface):
|
||||||
|
'''
|
||||||
|
enable robot and check enable status, try 5s, if enable timeout, exit program
|
||||||
|
'''
|
||||||
|
enable_flag = False
|
||||||
|
# 设置超时时间(秒)
|
||||||
|
timeout = 5
|
||||||
|
# 记录进入循环前的时间
|
||||||
|
start_time = time.time()
|
||||||
|
elapsed_time_flag = False
|
||||||
|
while not (enable_flag):
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
print("--------------------")
|
||||||
|
enable_flag = piper.GetArmLowSpdInfoMsgs().motor_1.foc_status.driver_enable_status and \
|
||||||
|
piper.GetArmLowSpdInfoMsgs().motor_2.foc_status.driver_enable_status and \
|
||||||
|
piper.GetArmLowSpdInfoMsgs().motor_3.foc_status.driver_enable_status and \
|
||||||
|
piper.GetArmLowSpdInfoMsgs().motor_4.foc_status.driver_enable_status and \
|
||||||
|
piper.GetArmLowSpdInfoMsgs().motor_5.foc_status.driver_enable_status and \
|
||||||
|
piper.GetArmLowSpdInfoMsgs().motor_6.foc_status.driver_enable_status
|
||||||
|
print("enable status:",enable_flag)
|
||||||
|
piper.EnableArm(7)
|
||||||
|
piper.GripperCtrl(0,1000,0x01, 0)
|
||||||
|
print("--------------------")
|
||||||
|
# check if timeout
|
||||||
|
if elapsed_time > timeout:
|
||||||
|
print("enable timeout....")
|
||||||
|
elapsed_time_flag = True
|
||||||
|
enable_flag = False
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
pass
|
||||||
|
if not elapsed_time_flag:
|
||||||
|
return enable_flag
|
||||||
|
else:
|
||||||
|
print("enable timeout, exit program")
|
||||||
|
raise RuntimeError("Failed to enable robot motors within timeout period")
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
self.is_connected = self.startup_robot(self.piper)
|
||||||
|
|
||||||
|
for name in self.cameras:
|
||||||
|
self.cameras[name].connect()
|
||||||
|
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||||
|
|
||||||
|
if not self.is_connected:
|
||||||
|
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||||
|
raise ConnectionError()
|
||||||
|
|
||||||
|
self.move_to_home()
|
||||||
|
|
||||||
|
# Used when connecting to robot
|
||||||
|
def move_to_home(self) -> None:
|
||||||
|
count = 0
|
||||||
|
while True:
|
||||||
|
if(count == 0):
|
||||||
|
print("1-----------")
|
||||||
|
action = [0.07,0,0.22,0,0.08,0,0.08]
|
||||||
|
# elif(count == 400):
|
||||||
|
# print("2-----------")
|
||||||
|
# action = [0.15,0.0,0.35,0.08,0.08,0.025,0.0] # 0.08 is maximum gripper position
|
||||||
|
elif(count == 300):
|
||||||
|
print("2-----------")
|
||||||
|
action = self.default_pos
|
||||||
|
count += 1
|
||||||
|
before_write_t = time.perf_counter()
|
||||||
|
state = self.get_state()
|
||||||
|
state = state["state"]
|
||||||
|
# state[3:6] = self.euler_filter.rectify(state[3:6])
|
||||||
|
self.send_action(action)
|
||||||
|
# self.rate.sleep(time.perf_counter() - before_write_t)
|
||||||
|
if count > 600:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Used when returning to home after finishing a demo
|
||||||
|
def move_to_home_2(self):
|
||||||
|
count = 0
|
||||||
|
while True:
|
||||||
|
if count <= 100:
|
||||||
|
action = self.default_pos
|
||||||
|
action[6] = 0.08
|
||||||
|
elif count > 100:
|
||||||
|
action = self.default_pos
|
||||||
|
action[6] = 0.0
|
||||||
|
count += 1
|
||||||
|
before_write_t = time.perf_counter()
|
||||||
|
state = self.get_state()
|
||||||
|
state = state["state"]
|
||||||
|
# state[3:6] = self.euler_filter.rectify(state[3:6])
|
||||||
|
self.send_action(action)
|
||||||
|
# self.rate.sleep(time.perf_counter() - before_write_t)
|
||||||
|
if count > 300:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def teleop_step(
|
||||||
|
self, record_data=False
|
||||||
|
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||||
|
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
||||||
|
if not self.is_connected:
|
||||||
|
raise ConnectionError()
|
||||||
|
|
||||||
|
before_read_t = time.perf_counter()
|
||||||
|
state = self.get_state()
|
||||||
|
self.get_intervention_start()
|
||||||
|
state = state["state"]
|
||||||
|
# state[3:6] = self.euler_filter.rectify(state[3:6])
|
||||||
|
# get relative action from joystick
|
||||||
|
action = self.teleop.action()
|
||||||
|
# print(action)
|
||||||
|
# Convert action to numpy array first
|
||||||
|
action = np.array(action, dtype=np.float32)
|
||||||
|
action_record = deepcopy(action[:2])
|
||||||
|
# action_record = np.concatenate([action[:3], [action[-1]]])
|
||||||
|
|
||||||
|
for i in range(self.action_repeat):
|
||||||
|
action[:2] = state[:2] + action_record*(i+1)/self.action_repeat
|
||||||
|
self.send_action(action)
|
||||||
|
busy_wait(0.03)
|
||||||
|
|
||||||
|
if self.teleop.home:
|
||||||
|
self.move_to_home_2()
|
||||||
|
if self.state_keys is None:
|
||||||
|
self.state_keys = list(state)
|
||||||
|
|
||||||
|
if not record_data:
|
||||||
|
return
|
||||||
|
# action_record[3:6] -= state[3:6] # just get delta orientation
|
||||||
|
# it has to be done after send_action
|
||||||
|
# state[:2] -= self.default_pos[:2]
|
||||||
|
state = torch.as_tensor(state).to(torch.float32)
|
||||||
|
action_record = torch.as_tensor(action_record).to(torch.float32)
|
||||||
|
# print(action_record)
|
||||||
|
|
||||||
|
# Capture images from cameras
|
||||||
|
images = {}
|
||||||
|
for name in self.cameras:
|
||||||
|
before_camread_t = time.perf_counter()
|
||||||
|
images[name] = self.cameras[name].async_read()
|
||||||
|
images[name] = torch.from_numpy(images[name])
|
||||||
|
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||||
|
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||||
|
# Populate output dictionnaries
|
||||||
|
obs_dict, action_dict = {}, {}
|
||||||
|
obs_dict["observation.state"] = state
|
||||||
|
# obs_dict.update({"gripper_effort": self.get_gripper_effort()})
|
||||||
|
action_dict["action"] = action_record
|
||||||
|
for name in self.cameras:
|
||||||
|
obs_dict[f"observation.images.{name}"] = images[name]
|
||||||
|
|
||||||
|
return obs_dict, action_dict
|
||||||
|
|
||||||
|
def get_state(self) -> dict:
|
||||||
|
end_effector_pose = self.piper.GetArmEndPoseMsgs()
|
||||||
|
gripper_pose = self.piper.GetArmGripperMsgs()
|
||||||
|
self.previous_ee_position = self.current_state[:3]
|
||||||
|
|
||||||
|
# Convert to float32 numpy array
|
||||||
|
self.current_state = np.array([
|
||||||
|
end_effector_pose.end_pose.X_axis,
|
||||||
|
end_effector_pose.end_pose.Y_axis,
|
||||||
|
end_effector_pose.end_pose.Z_axis,
|
||||||
|
gripper_pose.gripper_state.grippers_angle
|
||||||
|
], dtype=np.float32) / self.state_scaling_factor
|
||||||
|
# add velocity to state
|
||||||
|
velocity = (self.current_state[:3] - self.previous_ee_position) * self.fps
|
||||||
|
state = np.concatenate([self.current_state, velocity])
|
||||||
|
|
||||||
|
# print(f"state: {state}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_gripper_state(self) -> float:
|
||||||
|
gripper_pose = self.piper.GetArmGripperMsgs()
|
||||||
|
return np.array([gripper_pose.gripper_state.grippers_angle, gripper_pose.gripper_state.grippers_effort])
|
||||||
|
|
||||||
|
def get_intervention_start(self) -> bool:
|
||||||
|
intervention_start, success, estop = self.teleop.get_intervention_start()
|
||||||
|
if estop:
|
||||||
|
self.disable_robot()
|
||||||
|
return intervention_start, success
|
||||||
|
|
||||||
|
# def get_ee_pos(self) -> list[float]:
|
||||||
|
# end_effector_pose = self.piper.GetArmEndPoseMsgs()
|
||||||
|
# gripper_pose = self.piper.GetArmGripperMsgs()
|
||||||
|
# return [end_effector_pose.end_pose.X_axis,end_effector_pose.end_pose.Y_axis,end_effector_pose.end_pose.Z_axis,gripper_pose.gripper_state.grippers_angle]
|
||||||
|
|
||||||
|
def capture_observation(self) -> dict:
|
||||||
|
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
||||||
|
before_read_t = time.perf_counter()
|
||||||
|
state = self.get_state()
|
||||||
|
self.get_intervention_start()
|
||||||
|
# state = state["state"]
|
||||||
|
# state["state"][3:6] = self.euler_filter.rectify(state["state"][3:6])
|
||||||
|
# state["state"][:2] -= self.default_pos[:2]
|
||||||
|
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||||
|
|
||||||
|
if self.state_keys is None:
|
||||||
|
self.state_keys = list(state)
|
||||||
|
|
||||||
|
state = torch.as_tensor(np.array(list(state.values())).astype(np.float32))
|
||||||
|
state = state.squeeze(0)
|
||||||
|
|
||||||
|
# Capture images from cameras
|
||||||
|
images = {}
|
||||||
|
for name in self.cameras:
|
||||||
|
before_camread_t = time.perf_counter()
|
||||||
|
images[name] = self.cameras[name].async_read()
|
||||||
|
images[name] = torch.from_numpy(images[name])
|
||||||
|
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||||
|
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||||
|
# Populate output dictionnaries
|
||||||
|
obs_dict = {}
|
||||||
|
obs_dict["observation.state"] = state
|
||||||
|
for name in self.cameras:
|
||||||
|
obs_dict[f"observation.images.{name}"] = images[name]
|
||||||
|
# obs_dict.update({"gripper_effort": self.get_gripper_effort()})
|
||||||
|
return obs_dict
|
||||||
|
|
||||||
|
def send_action(self, action: list[float]) -> None:
|
||||||
|
if not self.is_connected:
|
||||||
|
raise ConnectionError()
|
||||||
|
# check if action is tensor and if it is, convert it to list
|
||||||
|
if isinstance(action, torch.Tensor):
|
||||||
|
action = action.tolist()
|
||||||
|
# clip rz value to be between -pi/2 and pi/2 for safety
|
||||||
|
# action[5] = max(-np.pi/2, min(np.pi/2, action[5]))
|
||||||
|
X = round(action[0]*self.state_scaling_factor)
|
||||||
|
Y = round(action[1]*self.state_scaling_factor)
|
||||||
|
Z = round(self.default_pos[2]*self.state_scaling_factor)
|
||||||
|
RX = round(self.default_pos[3]*self.state_scaling_factor)
|
||||||
|
RY = round(self.default_pos[4]*self.state_scaling_factor)
|
||||||
|
RZ = round(self.default_pos[5]*self.state_scaling_factor)
|
||||||
|
Gripper = round(action[-1]*self.state_scaling_factor)
|
||||||
|
|
||||||
|
self.piper.MotionCtrl_2(0x01, 0x00, 30, 0x00)
|
||||||
|
self.piper.EndPoseCtrl(X, Y, Z, RX, RY, RZ)
|
||||||
|
self.piper.GripperCtrl(abs(Gripper), 1000, 0x01, 0)
|
||||||
|
self.piper.MotionCtrl_2(0x01, 0x00, 30, 0x00)
|
||||||
|
# It is needed to give time for the CAN bus communication to complete
|
||||||
|
busy_wait(0.01)
|
||||||
|
|
||||||
|
# TODO(aliberts): return action_sent when motion is limited
|
||||||
|
return torch.tensor(action)
|
||||||
|
|
||||||
|
def print_logs(self) -> None:
|
||||||
|
pass
|
||||||
|
# TODO(aliberts): move robot-specific logs logic here
|
||||||
|
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
if self.teleop is not None:
|
||||||
|
self.teleop.close()
|
||||||
|
|
||||||
|
# if len(self.cameras) > 0:
|
||||||
|
# for cam in self.cameras.values():
|
||||||
|
# cam.disconnect()
|
||||||
|
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.disconnect()
|
Loading…
Reference in New Issue