unitree_rl_gym/deploy/deploy_real/deploy_real_ros_eetrack.py

385 lines
15 KiB
Python
Raw Normal View History

2025-02-14 14:11:26 +08:00
from legged_gym import LEGGED_GYM_ROOT_DIR
from typing import Union
import numpy as np
import time
import torch
import rclpy as rp
from unitree_hg.msg import LowCmd as LowCmdHG, LowState as LowStateHG
from unitree_go.msg import LowCmd as LowCmdGo, LowState as LowStateGo
from tf2_ros import TransformException
from tf2_ros.buffer import Buffer
from tf2_ros.transform_listener import TransformListener
from common.command_helper_ros import create_damping_cmd, create_zero_cmd, init_cmd_hg, init_cmd_go, MotorMode
from common.rotation_helper import get_gravity_orientation, transform_imu_data
from common.remote_controller import RemoteController, KeyMap
from config import Config
from common.crc import CRC
from enum import Enum
class Mode(Enum):
wait = 0
zero_torque = 1
default_pos = 2
damping = 3
policy = 4
null = 5
def body_pose_axa(frame:str):
""" --> tf does not exist """
try:
t = tf_buffer.lookup_transform(
to_frame_rel,
from_frame_rel,
rclpy.time.Time())
except TransformException as ex:
print(f'Could not transform {to_frame_rel} to {from_frame_rel}: {ex}')
return (np.zeros(3), np.zeros(3))
txn = t.transform.translation
rxn = t.transform.rotation
xyz = [txn.x, txn.y, txn.z]
quat_wxyz = [rxn.w, rxn.x, rxn.y, rxn.z]
xyz = np.array(xyz)
axa = axa_from_quat(quat_wxyz)
return (xyz, axa)
def make_obs(config,
low_state,
quat,
last_action):
# observation terms (order preserved)
# NOTE(ycho): dummy value
base_lin_vel = np.zeros(3)
base_ang_vel = np.array([self.low_state.imu_state.gyroscope], dtype=np.float32)
# FIXME(ycho): check if the convention "q_base^{-1} @ g" holds.
projected_gravity = get_gravity_orientation(quat)
fp_l = body_pose_axa('left_ankle_roll_link')
fp_r = body_pose_axa('right_ankle_roll_link')
foot_pose = np.concatenate([fp_l[0], fp_r[0], fp_l[1], fp_r[1]])
hp_l = body_pose_axa('left_hand_palm_link')
hp_r = body_pose_axa('right_hand_palm_link')
hand_pose = np.concatenate([hp_l[0], hp_r[0], hp_l[1], hp_r[1]])
projected_com = _
projected_zmp = _ # IMPOSSIBLE
joint_pos = []
joint_vel = []
for i in range(len(config.leg_joint2motor_idx)):
joint_pos[i] = low_state.motor_state[config.leg_joint2motor_idx[i]].q
joint_vel[i] = low_state.motor_state[config.leg_joint2motor_idx[i]].dq
actions = last_action
hands_command = _ # goal
right_arm_com = _
left_arm_com = _
pelvis_height = _
class Controller:
def __init__(self, config: Config) -> None:
self.config = config
self.remote_controller = RemoteController()
# Initialize the policy network
self.policy = torch.jit.load(config.policy_path)
# Initializing process variables
self.qj = np.zeros(config.num_actions, dtype=np.float32)
self.dqj = np.zeros(config.num_actions, dtype=np.float32)
self.action = np.zeros(config.num_actions, dtype=np.float32)
self.target_dof_pos = config.default_angles.copy()
self.obs = np.zeros(config.num_obs, dtype=np.float32)
self.cmd = np.array([0.0, 0, 0])
self.counter = 0
rp.init()
self._node = rp.create_node("low_level_cmd_sender")
if config.msg_type == "hg":
# g1 and h1_2 use the hg msg type
self.low_cmd = LowCmdHG()
self.low_state = LowStateHG()
self.lowcmd_publisher_ = self._node.create_publisher(LowCmdHG,
'lowcmd', 10)
self.lowstate_subscriber = self._node.create_subscription(LowStateHG,
'lowstate', self.LowStateHgHandler, 10)
self.mode_pr_ = MotorMode.PR
self.mode_machine_ = 0
# self.lowcmd_publisher_ = ChannelPublisher(config.lowcmd_topic, LowCmdHG)
# self.lowcmd_publisher_.Init()
# self.lowstate_subscriber = ChannelSubscriber(config.lowstate_topic, LowStateHG)
# self.lowstate_subscriber.Init(self.LowStateHgHandler, 10)
elif config.msg_type == "go":
raise ValueError(f"{config.msg_type} is not implemented yet.")
else:
raise ValueError("Invalid msg_type")
# wait for the subscriber to receive data
# self.wait_for_low_state()
# Initialize the command msg
if config.msg_type == "hg":
init_cmd_hg(self.low_cmd, self.mode_machine_, self.mode_pr_)
elif config.msg_type == "go":
init_cmd_go(self.low_cmd, weak_motor=self.config.weak_motor)
self.mode = Mode.wait
self._mode_change = True
self._timer = self._node.create_timer(self.config.control_dt, self.run_wrapper)
self._terminate = False
try:
rp.spin(self._node)
except KeyboardInterrupt:
print("KeyboardInterrupt")
finally:
self._node.destroy_timer(self._timer)
create_damping_cmd(self.low_cmd)
self.send_cmd(self.low_cmd)
self._node.destroy_node()
rp.shutdown()
print("Exit")
def LowStateHgHandler(self, msg: LowStateHG):
self.low_state = msg
self.mode_machine_ = self.low_state.mode_machine
self.remote_controller.set(self.low_state.wireless_remote)
def LowStateGoHandler(self, msg: LowStateGo):
self.low_state = msg
self.remote_controller.set(self.low_state.wireless_remote)
def send_cmd(self, cmd: Union[LowCmdGo, LowCmdHG]):
cmd.mode_machine = self.mode_machine_
cmd.crc = CRC().Crc(cmd)
size = len(cmd.motor_cmd)
# print(cmd.mode_machine)
# for i in range(size):
# print(i, cmd.motor_cmd[i].q,
# cmd.motor_cmd[i].dq,
# cmd.motor_cmd[i].kp,
# cmd.motor_cmd[i].kd,
# cmd.motor_cmd[i].tau)
self.lowcmd_publisher_.publish(cmd)
def wait_for_low_state(self):
while self.low_state.crc == 0:
print(self.low_state)
time.sleep(self.config.control_dt)
print("Successfully connected to the robot.")
def zero_torque_state(self):
if self.remote_controller.button[KeyMap.start] == 1:
self._mode_change = True
self.mode = Mode.default_pos
else:
create_zero_cmd(self.low_cmd)
self.send_cmd(self.low_cmd)
def prepare_default_pos(self):
# move time 2s
total_time = 2
self.counter = 0
self._num_step = int(total_time / self.config.control_dt)
dof_idx = self.config.leg_joint2motor_idx + self.config.arm_waist_joint2motor_idx
kps = self.config.kps + self.config.arm_waist_kps
kds = self.config.kds + self.config.arm_waist_kds
self._kps = [float(kp) for kp in kps]
self._kds = [float(kd) for kd in kds]
self._default_pos = np.concatenate((self.config.default_angles, self.config.arm_waist_target), axis=0)
self._dof_size = len(dof_idx)
self._dof_idx = dof_idx
# record the current pos
self._init_dof_pos = np.zeros(self._dof_size,
dtype=np.float32)
for i in range(self._dof_size):
self._init_dof_pos[i] = self.low_state.motor_state[dof_idx[i]].q
def move_to_default_pos(self):
# move to default pos
if self.counter < self._num_step:
alpha = self.counter / self._num_step
for j in range(self._dof_size):
motor_idx = self._dof_idx[j]
target_pos = self._default_pos[j]
self.low_cmd.motor_cmd[motor_idx].q = (self._init_dof_pos[j] *
(1 - alpha) + target_pos * alpha)
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
self.low_cmd.motor_cmd[motor_idx].kp = self._kps[j]
self.low_cmd.motor_cmd[motor_idx].kd = self._kds[j]
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
self.send_cmd(self.low_cmd)
self.counter += 1
else:
self._mode_change = True
self.mode = Mode.damping
def default_pos_state(self):
if self.remote_controller.button[KeyMap.A] != 1:
for i in range(len(self.config.leg_joint2motor_idx)):
motor_idx = self.config.leg_joint2motor_idx[i]
self.low_cmd.motor_cmd[motor_idx].q = float(self.config.default_angles[i])
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
self.low_cmd.motor_cmd[motor_idx].kp = self._kps[i]
self.low_cmd.motor_cmd[motor_idx].kd = self._kds[i]
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
for i in range(len(self.config.arm_waist_joint2motor_idx)):
motor_idx = self.config.arm_waist_joint2motor_idx[i]
self.low_cmd.motor_cmd[motor_idx].q = float(self.config.arm_waist_target[i])
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
self.low_cmd.motor_cmd[motor_idx].kp = self._kps[i]
self.low_cmd.motor_cmd[motor_idx].kd = self._kds[i]
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
self.send_cmd(self.low_cmd)
else:
self._mode_change = True
self.mode = Mode.policy
def run_policy(self):
if self.remote_controller.button[KeyMap.select] == 1:
self._mode_change = True
self.mode = Mode.null
return
self.counter += 1
# Get the current joint position and velocity
for i in range(len(self.config.leg_joint2motor_idx)):
self.qj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].q
self.dqj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].dq
# imu_state quaternion: w, x, y, z
quat = self.low_state.imu_state.quaternion
ang_vel = np.array([self.low_state.imu_state.gyroscope], dtype=np.float32)
if self.config.imu_type == "torso":
# h1 and h1_2 imu is on the torso
# imu data needs to be transformed to the pelvis frame
waist_yaw = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].q
waist_yaw_omega = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq
quat, ang_vel = transform_imu_data(waist_yaw=waist_yaw, waist_yaw_omega=waist_yaw_omega, imu_quat=quat, imu_omega=ang_vel)
# create observation
gravity_orientation = get_gravity_orientation(quat)
qj_obs = self.qj.copy()
dqj_obs = self.dqj.copy()
qj_obs = (qj_obs - self.config.default_angles) * self.config.dof_pos_scale
dqj_obs = dqj_obs * self.config.dof_vel_scale
ang_vel = ang_vel * self.config.ang_vel_scale
period = 0.8
count = self.counter * self.config.control_dt
phase = count % period / period
sin_phase = np.sin(2 * np.pi * phase)
cos_phase = np.cos(2 * np.pi * phase)
self.cmd[0] = self.remote_controller.ly
self.cmd[1] = self.remote_controller.lx * -1
self.cmd[2] = self.remote_controller.rx * -1
# print(self.remote_controller.ly,
# self.remote_controller.lx,
# self.remote_controller.rx)
# self.cmd[0] = 0.0
# self.cmd[1] = 0.0
# self.cmd[2] = 0.0
num_actions = self.config.num_actions
self.obs[:3] = ang_vel
self.obs[3:6] = gravity_orientation
self.obs[6:9] = self.cmd * self.config.cmd_scale * self.config.max_cmd
self.obs[9 : 9 + num_actions] = qj_obs
self.obs[9 + num_actions : 9 + num_actions * 2] = dqj_obs
self.obs[9 + num_actions * 2 : 9 + num_actions * 3] = self.action
self.obs[9 + num_actions * 3] = sin_phase
self.obs[9 + num_actions * 3 + 1] = cos_phase
# Get the action from the policy network
obs_tensor = torch.from_numpy(self.obs).unsqueeze(0)
self.action = self.policy(obs_tensor).detach().numpy().squeeze()
# transform action to target_dof_pos
target_dof_pos = self.config.default_angles + self.action * self.config.action_scale
# Build low cmd
for i in range(len(self.config.leg_joint2motor_idx)):
motor_idx = self.config.leg_joint2motor_idx[i]
self.low_cmd.motor_cmd[motor_idx].q = float(target_dof_pos[i])
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
self.low_cmd.motor_cmd[motor_idx].kp = float(self.config.kps[i])
self.low_cmd.motor_cmd[motor_idx].kd = float(self.config.kds[i])
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
for i in range(len(self.config.arm_waist_joint2motor_idx)):
motor_idx = self.config.arm_waist_joint2motor_idx[i]
self.low_cmd.motor_cmd[motor_idx].q = float(self.config.arm_waist_target[i])
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
self.low_cmd.motor_cmd[motor_idx].kp = float(self.config.arm_waist_kps[i])
self.low_cmd.motor_cmd[motor_idx].kd = float(self.config.arm_waist_kds[i])
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
# send the command
self.send_cmd(self.low_cmd)
def run_wrapper(self):
# print("hello", self.mode,
# self.mode == Mode.zero_torque)
if self.mode == Mode.wait:
if self.low_state.crc != 0:
self.mode = Mode.zero_torque
self.low_cmd.mode_machine = self.mode_machine_
print("Successfully connected to the robot.")
elif self.mode == Mode.zero_torque:
if self._mode_change:
print("Enter zero torque state.")
print("Waiting for the start signal...")
self._mode_change = False
self.zero_torque_state()
elif self.mode == Mode.default_pos:
if self._mode_change:
print("Moving to default pos.")
self._mode_change = False
self.prepare_default_pos()
self.move_to_default_pos()
elif self.mode == Mode.damping:
if self._mode_change:
print("Enter default pos state.")
print("Waiting for the Button A signal...")
self._mode_change = False
self.default_pos_state()
elif self.mode == Mode.policy:
if self._mode_change:
print("Run policy.")
self._mode_change = False
self.counter = 0
self.run_policy()
elif self.mode == Mode.null:
self._terminate = True
# time.sleep(self.config.control_dt)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("config", type=str, help="config file name in the configs folder", default="g1.yaml")
args = parser.parse_args()
# Load config
config_path = f"{LEGGED_GYM_ROOT_DIR}/deploy/deploy_real/configs/{args.config}"
config = Config(config_path)
controller = Controller(config)