from legged_gym import LEGGED_GYM_ROOT_DIR from typing import Union import numpy as np import time import torch import rclpy as rp from unitree_hg.msg import LowCmd as LowCmdHG, LowState as LowStateHG from unitree_go.msg import LowCmd as LowCmdGo, LowState as LowStateGo from tf2_ros import TransformException from tf2_ros.buffer import Buffer from tf2_ros.transform_listener import TransformListener from common.command_helper_ros import create_damping_cmd, create_zero_cmd, init_cmd_hg, init_cmd_go, MotorMode from common.rotation_helper import get_gravity_orientation, transform_imu_data from common.remote_controller import RemoteController, KeyMap from config import Config from common.crc import CRC from enum import Enum class Mode(Enum): wait = 0 zero_torque = 1 default_pos = 2 damping = 3 policy = 4 null = 5 def axis_angle_from_quat(quat: np.ndarray, eps: float = 1.0e-6) -> np.ndarray: """Convert rotations given as quaternions to axis/angle. Args: quat: The quaternion orientation in (w, x, y, z). Shape is (..., 4). eps: The tolerance for Taylor approximation. Defaults to 1.0e-6. Returns: Rotations given as a vector in axis angle form. Shape is (..., 3). The vector's magnitude is the angle turned anti-clockwise in radians around the vector's direction. Reference: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L526-L554 """ # Modified to take in quat as [q_w, q_x, q_y, q_z] # Quaternion is [q_w, q_x, q_y, q_z] = [cos(theta/2), n_x * sin(theta/2), n_y * sin(theta/2), n_z * sin(theta/2)] # Axis-angle is [a_x, a_y, a_z] = [theta * n_x, theta * n_y, theta * n_z] # Thus, axis-angle is [q_x, q_y, q_z] / (sin(theta/2) / theta) # When theta = 0, (sin(theta/2) / theta) is undefined # However, as theta --> 0, we can use the Taylor approximation 1/2 - theta^2 / 48 quat = quat * (1.0 - 2.0 * (quat[..., 0:1] < 0.0)) mag = np.linalg.norm(quat[..., 1:], dim=-1) half_angle = np.arctan2(mag, quat[..., 0]) angle = 2.0 * half_angle # check whether to apply Taylor approximation sin_half_angles_over_angles = np.where( angle.abs() > eps, np.sin(half_angle) / angle, 0.5 - angle * angle / 48 ) return quat[..., 1:4] / sin_half_angles_over_angles.unsqueeze(-1) def body_pose_axa( tf_buffer, frame:str, ref_frame:str='pelvis', stamp=None): """ --> tf does not exist """ if stamp is None: stamp = rp.time.Time() try: # t = "ref{=pelvis}_from_frame" transform t = tf_buffer.lookup_transform( ref_frame, #to frame, #from stamp) except TransformException as ex: print(f'Could not transform {frame} to {ref_frame}: {ex}') return (np.zeros(3), np.zeros(3)) txn = t.transform.translation rxn = t.transform.rotation xyz = [txn.x, txn.y, txn.z] quat_wxyz = [rxn.w, rxn.x, rxn.y, rxn.z] xyz = np.array(xyz) axa = axis_angle_from_quat(quat_wxyz) axa = (axa + np.pi) % (2*np.pi) return (xyz, axa) def index_map(k_to, k_from): """ returns an index mapping from k_from to k_to; i.e. k_to[index_map] = k_from """ out = [] for k in k_to: out.append(k_from.index(k)) return out class Observation: def __init__(self, config, tf_buffer:Buffer): self.config = config self.num_lab_joint = len(config.lab_joint) self.tf_buffer = tf_buffer self.lab_from_mot = index_map(config.lab_joint, config.motor_joint) def __call__(self, low_state: LowStateHG, last_action: np.ndarray, hands_command: np.ndarray ): lab_from_mot = self.lab_from_mot num_lab_joint = self.num_lab_joint # observation terms (order preserved) # NOTE(ycho): dummy value # base_lin_vel = np.zeros(3) base_ang_vel = np.array([low_state.imu_state.gyroscope], dtype=np.float32) # FIXME(ycho): check if the convention "q_base^{-1} @ g" holds. quat = low_state.imu_state.quaternion projected_gravity = get_gravity_orientation(quat) fp_l = body_pose_axa(self.tf_buffer,'left_ankle_roll_link') fp_r = body_pose_axa(self.tf_buffer,'right_ankle_roll_link') foot_pose = np.concatenate([fp_l[0], fp_r[0], fp_l[1], fp_r[1]]) hp_l = body_pose_axa(self.tf_buffer,'left_hand_palm_link') hp_r = body_pose_axa(self.tf_buffer,'right_hand_palm_link') hand_pose = np.concatenate([hp_l[0], hp_r[0], hp_l[1], hp_r[1]]) projected_com = quat_rotate( quat_conjugate(quat), (com_pos - root_pos) ) # projected_zmp = _ # IMPOSSIBLE # Map `low_state` to index-mapped joint_{pos,vel} joint_pos = np.zeros(num_lab_joint, dtype=np.float32) joint_vel = np.zeros(num_lab_joint, dtype=np.float32) joint_pos[lab_from_mot] = [low_state.motor_state[i_mot].q for i_mot in range(len(lab_from_mot))] joint_vel[lab_from_mot] = [low_state.motor_state[i_mot].dq for i_mot in range(len(lab_from_mot))] actions = last_action hands_command = hands_command right_arm_com = _ left_arm_com = _ if True: # hack lf_from_pelvis = self.tf_buffer.lookup_transform( 'left_ankle_roll_link', #to 'pelvis' stamp=rp.time.Time() ) rf_from_pelvis = self.tf_buffer.lookup_transform( 'right_ankle_roll_link', #to 'pelvis' stamp=rp.time.Time() ) # NOTE(ycho): we assume at least one of the feet is on the ground # and use the higher of the two as the pelvis height. pelvis_height = max(lf_from_pelvis.transform.translation.z, rf_from_pelvis.transform.translation.z) pelvis_height = [pelvis_height] else: pelvis_height = np.abs(np.dot( projected_gravity, # world frame fp_l[0] ) ) return np.concatenate([ base_ang_vel, projected_gravity, foot_pose, hand_pose, projected_com, joint_pos, joint_vel, actions, hands_command, right_arm_com, left_arm_com, pelvis_height ], axis=-1) class Controller: def __init__(self, config: Config) -> None: self.config = config self.remote_controller = RemoteController() # Initialize the policy network self.policy = torch.jit.load(config.policy_path) # Initializing process variables self.qj = np.zeros(config.num_actions, dtype=np.float32) self.dqj = np.zeros(config.num_actions, dtype=np.float32) self.action = np.zeros(config.num_actions, dtype=np.float32) self.target_dof_pos = config.default_angles.copy() self.obs = np.zeros(config.num_obs, dtype=np.float32) self.cmd = np.array([0.0, 0, 0]) self.counter = 0 rp.init() self._node = rp.create_node("low_level_cmd_sender") if config.msg_type == "hg": # g1 and h1_2 use the hg msg type self.low_cmd = LowCmdHG() self.low_state = LowStateHG() self.lowcmd_publisher_ = self._node.create_publisher(LowCmdHG, 'lowcmd', 10) self.lowstate_subscriber = self._node.create_subscription(LowStateHG, 'lowstate', self.LowStateHgHandler, 10) self.mode_pr_ = MotorMode.PR self.mode_machine_ = 0 # self.lowcmd_publisher_ = ChannelPublisher(config.lowcmd_topic, LowCmdHG) # self.lowcmd_publisher_.Init() # self.lowstate_subscriber = ChannelSubscriber(config.lowstate_topic, LowStateHG) # self.lowstate_subscriber.Init(self.LowStateHgHandler, 10) elif config.msg_type == "go": raise ValueError(f"{config.msg_type} is not implemented yet.") else: raise ValueError("Invalid msg_type") # wait for the subscriber to receive data # self.wait_for_low_state() # Initialize the command msg if config.msg_type == "hg": init_cmd_hg(self.low_cmd, self.mode_machine_, self.mode_pr_) elif config.msg_type == "go": init_cmd_go(self.low_cmd, weak_motor=self.config.weak_motor) self.mode = Mode.wait self._mode_change = True self._timer = self._node.create_timer(self.config.control_dt, self.run_wrapper) self._terminate = False try: rp.spin(self._node) except KeyboardInterrupt: print("KeyboardInterrupt") finally: self._node.destroy_timer(self._timer) create_damping_cmd(self.low_cmd) self.send_cmd(self.low_cmd) self._node.destroy_node() rp.shutdown() print("Exit") def LowStateHgHandler(self, msg: LowStateHG): self.low_state = msg self.mode_machine_ = self.low_state.mode_machine self.remote_controller.set(self.low_state.wireless_remote) def LowStateGoHandler(self, msg: LowStateGo): self.low_state = msg self.remote_controller.set(self.low_state.wireless_remote) def send_cmd(self, cmd: Union[LowCmdGo, LowCmdHG]): cmd.mode_machine = self.mode_machine_ cmd.crc = CRC().Crc(cmd) size = len(cmd.motor_cmd) # print(cmd.mode_machine) # for i in range(size): # print(i, cmd.motor_cmd[i].q, # cmd.motor_cmd[i].dq, # cmd.motor_cmd[i].kp, # cmd.motor_cmd[i].kd, # cmd.motor_cmd[i].tau) self.lowcmd_publisher_.publish(cmd) def wait_for_low_state(self): while self.low_state.crc == 0: print(self.low_state) time.sleep(self.config.control_dt) print("Successfully connected to the robot.") def zero_torque_state(self): if self.remote_controller.button[KeyMap.start] == 1: self._mode_change = True self.mode = Mode.default_pos else: create_zero_cmd(self.low_cmd) self.send_cmd(self.low_cmd) def prepare_default_pos(self): # move time 2s total_time = 2 self.counter = 0 self._num_step = int(total_time / self.config.control_dt) dof_idx = self.config.leg_joint2motor_idx + self.config.arm_waist_joint2motor_idx kps = self.config.kps + self.config.arm_waist_kps kds = self.config.kds + self.config.arm_waist_kds self._kps = [float(kp) for kp in kps] self._kds = [float(kd) for kd in kds] self._default_pos = np.concatenate((self.config.default_angles, self.config.arm_waist_target), axis=0) self._dof_size = len(dof_idx) self._dof_idx = dof_idx # record the current pos self._init_dof_pos = np.zeros(self._dof_size, dtype=np.float32) for i in range(self._dof_size): self._init_dof_pos[i] = self.low_state.motor_state[dof_idx[i]].q def move_to_default_pos(self): # move to default pos if self.counter < self._num_step: alpha = self.counter / self._num_step for j in range(self._dof_size): motor_idx = self._dof_idx[j] target_pos = self._default_pos[j] self.low_cmd.motor_cmd[motor_idx].q = (self._init_dof_pos[j] * (1 - alpha) + target_pos * alpha) self.low_cmd.motor_cmd[motor_idx].dq = 0.0 self.low_cmd.motor_cmd[motor_idx].kp = self._kps[j] self.low_cmd.motor_cmd[motor_idx].kd = self._kds[j] self.low_cmd.motor_cmd[motor_idx].tau = 0.0 self.send_cmd(self.low_cmd) self.counter += 1 else: self._mode_change = True self.mode = Mode.damping def default_pos_state(self): if self.remote_controller.button[KeyMap.A] != 1: for i in range(len(self.config.leg_joint2motor_idx)): motor_idx = self.config.leg_joint2motor_idx[i] self.low_cmd.motor_cmd[motor_idx].q = float(self.config.default_angles[i]) self.low_cmd.motor_cmd[motor_idx].dq = 0.0 self.low_cmd.motor_cmd[motor_idx].kp = self._kps[i] self.low_cmd.motor_cmd[motor_idx].kd = self._kds[i] self.low_cmd.motor_cmd[motor_idx].tau = 0.0 for i in range(len(self.config.arm_waist_joint2motor_idx)): motor_idx = self.config.arm_waist_joint2motor_idx[i] self.low_cmd.motor_cmd[motor_idx].q = float(self.config.arm_waist_target[i]) self.low_cmd.motor_cmd[motor_idx].dq = 0.0 self.low_cmd.motor_cmd[motor_idx].kp = self._kps[i] self.low_cmd.motor_cmd[motor_idx].kd = self._kds[i] self.low_cmd.motor_cmd[motor_idx].tau = 0.0 self.send_cmd(self.low_cmd) else: self._mode_change = True self.mode = Mode.policy def run_policy(self): if self.remote_controller.button[KeyMap.select] == 1: self._mode_change = True self.mode = Mode.null return self.counter += 1 # Get the current joint position and velocity for i in range(len(self.config.leg_joint2motor_idx)): self.qj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].q self.dqj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].dq # imu_state quaternion: w, x, y, z quat = self.low_state.imu_state.quaternion ang_vel = np.array([self.low_state.imu_state.gyroscope], dtype=np.float32) if self.config.imu_type == "torso": # h1 and h1_2 imu is on the torso # imu data needs to be transformed to the pelvis frame waist_yaw = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].q waist_yaw_omega = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq quat, ang_vel = transform_imu_data(waist_yaw=waist_yaw, waist_yaw_omega=waist_yaw_omega, imu_quat=quat, imu_omega=ang_vel) # create observation gravity_orientation = get_gravity_orientation(quat) qj_obs = self.qj.copy() dqj_obs = self.dqj.copy() qj_obs = (qj_obs - self.config.default_angles) * self.config.dof_pos_scale dqj_obs = dqj_obs * self.config.dof_vel_scale ang_vel = ang_vel * self.config.ang_vel_scale period = 0.8 count = self.counter * self.config.control_dt phase = count % period / period sin_phase = np.sin(2 * np.pi * phase) cos_phase = np.cos(2 * np.pi * phase) self.cmd[0] = self.remote_controller.ly self.cmd[1] = self.remote_controller.lx * -1 self.cmd[2] = self.remote_controller.rx * -1 # print(self.remote_controller.ly, # self.remote_controller.lx, # self.remote_controller.rx) # self.cmd[0] = 0.0 # self.cmd[1] = 0.0 # self.cmd[2] = 0.0 num_actions = self.config.num_actions self.obs[:3] = ang_vel self.obs[3:6] = gravity_orientation self.obs[6:9] = self.cmd * self.config.cmd_scale * self.config.max_cmd self.obs[9 : 9 + num_actions] = qj_obs self.obs[9 + num_actions : 9 + num_actions * 2] = dqj_obs self.obs[9 + num_actions * 2 : 9 + num_actions * 3] = self.action self.obs[9 + num_actions * 3] = sin_phase self.obs[9 + num_actions * 3 + 1] = cos_phase # Get the action from the policy network obs_tensor = torch.from_numpy(self.obs).unsqueeze(0) self.action = self.policy(obs_tensor).detach().numpy().squeeze() # transform action to target_dof_pos target_dof_pos = self.config.default_angles + self.action * self.config.action_scale # Build low cmd for i in range(len(self.config.leg_joint2motor_idx)): motor_idx = self.config.leg_joint2motor_idx[i] self.low_cmd.motor_cmd[motor_idx].q = float(target_dof_pos[i]) self.low_cmd.motor_cmd[motor_idx].dq = 0.0 self.low_cmd.motor_cmd[motor_idx].kp = float(self.config.kps[i]) self.low_cmd.motor_cmd[motor_idx].kd = float(self.config.kds[i]) self.low_cmd.motor_cmd[motor_idx].tau = 0.0 for i in range(len(self.config.arm_waist_joint2motor_idx)): motor_idx = self.config.arm_waist_joint2motor_idx[i] self.low_cmd.motor_cmd[motor_idx].q = float(self.config.arm_waist_target[i]) self.low_cmd.motor_cmd[motor_idx].dq = 0.0 self.low_cmd.motor_cmd[motor_idx].kp = float(self.config.arm_waist_kps[i]) self.low_cmd.motor_cmd[motor_idx].kd = float(self.config.arm_waist_kds[i]) self.low_cmd.motor_cmd[motor_idx].tau = 0.0 # send the command self.send_cmd(self.low_cmd) def run_wrapper(self): # print("hello", self.mode, # self.mode == Mode.zero_torque) if self.mode == Mode.wait: if self.low_state.crc != 0: self.mode = Mode.zero_torque self.low_cmd.mode_machine = self.mode_machine_ print("Successfully connected to the robot.") elif self.mode == Mode.zero_torque: if self._mode_change: print("Enter zero torque state.") print("Waiting for the start signal...") self._mode_change = False self.zero_torque_state() elif self.mode == Mode.default_pos: if self._mode_change: print("Moving to default pos.") self._mode_change = False self.prepare_default_pos() self.move_to_default_pos() elif self.mode == Mode.damping: if self._mode_change: print("Enter default pos state.") print("Waiting for the Button A signal...") self._mode_change = False self.default_pos_state() elif self.mode == Mode.policy: if self._mode_change: print("Run policy.") self._mode_change = False self.counter = 0 self.run_policy() elif self.mode == Mode.null: self._terminate = True # time.sleep(self.config.control_dt) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("config", type=str, help="config file name in the configs folder", default="g1.yaml") args = parser.parse_args() # Load config config_path = f"{LEGGED_GYM_ROOT_DIR}/deploy/deploy_real/configs/{args.config}" config = Config(config_path) controller = Controller(config)