support policy learning with rospy
This commit is contained in:
parent
69105aa8d0
commit
84a01f065b
|
@ -0,0 +1,62 @@
|
|||
from unitree_go.msg import LowCmd as LowCmdGo
|
||||
from unitree_hg.msg import LowCmd as LowCmdHG
|
||||
from typing import Union
|
||||
|
||||
|
||||
class MotorMode:
|
||||
PR = 0 # Series Control for Pitch/Roll Joints
|
||||
AB = 1 # Parallel Control for A/B Joints
|
||||
|
||||
|
||||
def create_damping_cmd(cmd: Union[LowCmdGo, LowCmdHG]):
|
||||
size = len(cmd.motor_cmd)
|
||||
for i in range(size):
|
||||
cmd.motor_cmd[i].q = 0.0
|
||||
cmd.motor_cmd[i].dq = 0.0
|
||||
cmd.motor_cmd[i].kp = 0.0
|
||||
cmd.motor_cmd[i].kd = 8.0
|
||||
cmd.motor_cmd[i].tau = 0.0
|
||||
|
||||
|
||||
def create_zero_cmd(cmd: Union[LowCmdGo, LowCmdHG]):
|
||||
size = len(cmd.motor_cmd)
|
||||
for i in range(size):
|
||||
cmd.motor_cmd[i].q = 0.0
|
||||
cmd.motor_cmd[i].dq = 0.0
|
||||
cmd.motor_cmd[i].kp = 0.0
|
||||
cmd.motor_cmd[i].kd = 0.0
|
||||
cmd.motor_cmd[i].tau = 0.0
|
||||
|
||||
|
||||
def init_cmd_hg(cmd: LowCmdHG, mode_machine: int, mode_pr: int):
|
||||
cmd.mode_machine = mode_machine
|
||||
cmd.mode_pr = mode_pr
|
||||
size = len(cmd.motor_cmd)
|
||||
print(size)
|
||||
for i in range(size):
|
||||
cmd.motor_cmd[i].mode = 1
|
||||
cmd.motor_cmd[i].q = 0.0
|
||||
cmd.motor_cmd[i].dq = 0.0
|
||||
cmd.motor_cmd[i].kp = 0.0
|
||||
cmd.motor_cmd[i].kd = 0.0
|
||||
cmd.motor_cmd[i].tau = 0.0
|
||||
|
||||
|
||||
def init_cmd_go(cmd: LowCmdGo, weak_motor: list):
|
||||
cmd.head[0] = 0xFE
|
||||
cmd.head[1] = 0xEF
|
||||
cmd.level_flag = 0xFF
|
||||
cmd.gpio = 0
|
||||
PosStopF = 2.146e9
|
||||
VelStopF = 16000.0
|
||||
size = len(cmd.motor_cmd)
|
||||
for i in range(size):
|
||||
if i in weak_motor:
|
||||
cmd.motor_cmd[i].mode = 1
|
||||
else:
|
||||
cmd.motor_cmd[i].mode = 0x0A
|
||||
cmd.motor_cmd[i].q = PosStopF
|
||||
cmd.motor_cmd[i].dq = VelStopF
|
||||
cmd.motor_cmd[i].kp = 0.0
|
||||
cmd.motor_cmd[i].kd = 0.0
|
||||
cmd.motor_cmd[i].tau = 0.0
|
|
@ -0,0 +1,232 @@
|
|||
import struct
|
||||
import ctypes
|
||||
import os
|
||||
import platform
|
||||
from unitree_hg.msg import LowCmd as LowCmdHG, LowState as LowStateHG
|
||||
from unitree_go.msg import LowCmd as LowCmdGo, LowState as LowStateGo
|
||||
|
||||
class Singleton:
|
||||
__instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls.__instance is None:
|
||||
cls.__instance = super(Singleton, cls).__new__(cls)
|
||||
return cls.__instance
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
class CRC(Singleton):
|
||||
def __init__(self):
|
||||
#4 bytes aligned, little-endian format.
|
||||
#size 812
|
||||
self.__packFmtLowCmd = '<4B4IH2x' + 'B3x5f3I' * 20 + '4B' + '55Bx2I'
|
||||
#size 1180
|
||||
self.__packFmtLowState = '<4B4IH2x' + '13fb3x' + 'B3x7fb3x3I' * 20 + '4BiH4b15H' + '8hI41B3xf2b2x2f4h2I'
|
||||
#size 1004
|
||||
self.__packFmtHGLowCmd = '<2B2x' + 'B3x5fI' * 35 + '5I'
|
||||
#size 2092
|
||||
self.__packFmtHGLowState = '<2I2B2xI' + '13fh2x' + 'B3x4f2hf7I' * 35 + '40B5I'
|
||||
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.platform = platform.system()
|
||||
if self.platform == "Linux":
|
||||
if platform.machine()=="x86_64":
|
||||
self.crc_lib = ctypes.CDLL(script_dir + '/lib/crc_amd64.so')
|
||||
elif platform.machine()=="aarch64":
|
||||
self.crc_lib = ctypes.CDLL(script_dir + '/lib/crc_aarch64.so')
|
||||
|
||||
self.crc_lib.crc32_core.argtypes = (ctypes.POINTER(ctypes.c_uint32), ctypes.c_uint32)
|
||||
self.crc_lib.crc32_core.restype = ctypes.c_uint32
|
||||
|
||||
def Crc(self, msg):
|
||||
if type(msg) == LowCmdGo:
|
||||
return self.__Crc32(self.__PackLowCmd(msg))
|
||||
elif type(msg) == LowStateGo:
|
||||
return self.__Crc32(self.__PackLowState(msg))
|
||||
if type(msg) == LowCmdHG:
|
||||
return self.__Crc32(self.__PackHGLowCmd(msg))
|
||||
elif type(msg) == LowStateHG:
|
||||
return self.__Crc32(self.__PackHGLowState(msg))
|
||||
else:
|
||||
raise TypeError('unknown message type to crc')
|
||||
|
||||
def __PackLowCmd(self, cmd: LowCmdGo):
|
||||
origData = []
|
||||
origData.extend(cmd.head)
|
||||
origData.append(cmd.level_flag)
|
||||
origData.append(cmd.frame_reserve)
|
||||
origData.extend(cmd.sn)
|
||||
origData.extend(cmd.version)
|
||||
origData.append(cmd.bandwidth)
|
||||
|
||||
for i in range(20):
|
||||
origData.append(cmd.motor_cmd[i].mode)
|
||||
origData.append(cmd.motor_cmd[i].q)
|
||||
origData.append(cmd.motor_cmd[i].dq)
|
||||
origData.append(cmd.motor_cmd[i].tau)
|
||||
origData.append(cmd.motor_cmd[i].kp)
|
||||
origData.append(cmd.motor_cmd[i].kd)
|
||||
origData.extend(cmd.motor_cmd[i].reserve)
|
||||
|
||||
origData.append(cmd.bms_cmd.off)
|
||||
origData.extend(cmd.bms_cmd.reserve)
|
||||
|
||||
origData.extend(cmd.wireless_remote)
|
||||
origData.extend(cmd.led)
|
||||
origData.extend(cmd.fan)
|
||||
origData.append(cmd.gpio)
|
||||
origData.append(cmd.reserve)
|
||||
origData.append(cmd.crc)
|
||||
|
||||
return self.__Trans(struct.pack(self.__packFmtLowCmd, *origData))
|
||||
|
||||
def __PackLowState(self, state: LowStateGo):
|
||||
origData = []
|
||||
origData.extend(state.head)
|
||||
origData.append(state.level_flag)
|
||||
origData.append(state.frame_reserve)
|
||||
origData.extend(state.sn)
|
||||
origData.extend(state.version)
|
||||
origData.append(state.bandwidth)
|
||||
|
||||
origData.extend(state.imu_state.quaternion)
|
||||
origData.extend(state.imu_state.gyroscope)
|
||||
origData.extend(state.imu_state.accelerometer)
|
||||
origData.extend(state.imu_state.rpy)
|
||||
origData.append(state.imu_state.temperature)
|
||||
|
||||
for i in range(20):
|
||||
origData.append(state.motor_state[i].mode)
|
||||
origData.append(state.motor_state[i].q)
|
||||
origData.append(state.motor_state[i].dq)
|
||||
origData.append(state.motor_state[i].ddq)
|
||||
origData.append(state.motor_state[i].tau_est)
|
||||
origData.append(state.motor_state[i].q_raw)
|
||||
origData.append(state.motor_state[i].dq_raw)
|
||||
origData.append(state.motor_state[i].ddq_raw)
|
||||
origData.append(state.motor_state[i].temperature)
|
||||
origData.append(state.motor_state[i].lost)
|
||||
origData.extend(state.motor_state[i].reserve)
|
||||
|
||||
origData.append(state.bms_state.version_high)
|
||||
origData.append(state.bms_state.version_low)
|
||||
origData.append(state.bms_state.status)
|
||||
origData.append(state.bms_state.soc)
|
||||
origData.append(state.bms_state.current)
|
||||
origData.append(state.bms_state.cycle)
|
||||
origData.extend(state.bms_state.bq_ntc)
|
||||
origData.extend(state.bms_state.mcu_ntc)
|
||||
origData.extend(state.bms_state.cell_vol)
|
||||
|
||||
origData.extend(state.foot_force)
|
||||
origData.extend(state.foot_force_est)
|
||||
origData.append(state.tick)
|
||||
origData.extend(state.wireless_remote)
|
||||
origData.append(state.bit_flag)
|
||||
origData.append(state.adc_reel)
|
||||
origData.append(state.temperature_ntc1)
|
||||
origData.append(state.temperature_ntc2)
|
||||
origData.append(state.power_v)
|
||||
origData.append(state.power_a)
|
||||
origData.extend(state.fan_frequency)
|
||||
origData.append(state.reserve)
|
||||
origData.append(state.crc)
|
||||
|
||||
return self.__Trans(struct.pack(self.__packFmtLowState, *origData))
|
||||
|
||||
def __PackHGLowCmd(self, cmd: LowCmdHG):
|
||||
origData = []
|
||||
origData.append(cmd.mode_pr)
|
||||
origData.append(cmd.mode_machine)
|
||||
|
||||
for i in range(35):
|
||||
origData.append(cmd.motor_cmd[i].mode)
|
||||
origData.append(cmd.motor_cmd[i].q)
|
||||
origData.append(cmd.motor_cmd[i].dq)
|
||||
origData.append(cmd.motor_cmd[i].tau)
|
||||
origData.append(cmd.motor_cmd[i].kp)
|
||||
origData.append(cmd.motor_cmd[i].kd)
|
||||
origData.append(cmd.motor_cmd[i].reserve)
|
||||
|
||||
origData.extend(cmd.reserve)
|
||||
origData.append(cmd.crc)
|
||||
|
||||
return self.__Trans(struct.pack(self.__packFmtHGLowCmd, *origData))
|
||||
|
||||
def __PackHGLowState(self, state: LowStateHG):
|
||||
origData = []
|
||||
origData.extend(state.version)
|
||||
origData.append(state.mode_pr)
|
||||
origData.append(state.mode_machine)
|
||||
origData.append(state.tick)
|
||||
|
||||
origData.extend(state.imu_state.quaternion)
|
||||
origData.extend(state.imu_state.gyroscope)
|
||||
origData.extend(state.imu_state.accelerometer)
|
||||
origData.extend(state.imu_state.rpy)
|
||||
origData.append(state.imu_state.temperature)
|
||||
|
||||
for i in range(35):
|
||||
origData.append(state.motor_state[i].mode)
|
||||
origData.append(state.motor_state[i].q)
|
||||
origData.append(state.motor_state[i].dq)
|
||||
origData.append(state.motor_state[i].ddq)
|
||||
origData.append(state.motor_state[i].tau_est)
|
||||
origData.extend(state.motor_state[i].temperature)
|
||||
origData.append(state.motor_state[i].vol)
|
||||
origData.extend(state.motor_state[i].sensor)
|
||||
origData.append(state.motor_state[i].motorstate)
|
||||
origData.extend(state.motor_state[i].reserve)
|
||||
|
||||
origData.extend(state.wireless_remote)
|
||||
origData.extend(state.reserve)
|
||||
origData.append(state.crc)
|
||||
|
||||
return self.__Trans(struct.pack(self.__packFmtHGLowState, *origData))
|
||||
|
||||
def __Trans(self, packData):
|
||||
calcData = []
|
||||
calcLen = ((len(packData)>>2)-1)
|
||||
|
||||
for i in range(calcLen):
|
||||
d = ((packData[i*4+3] << 24) | (packData[i*4+2] << 16) | (packData[i*4+1] << 8) | (packData[i*4]))
|
||||
calcData.append(d)
|
||||
|
||||
return calcData
|
||||
|
||||
def _crc_py(self, data):
|
||||
bit = 0
|
||||
crc = 0xFFFFFFFF
|
||||
polynomial = 0x04c11db7
|
||||
|
||||
for i in range(len(data)):
|
||||
bit = 1 << 31
|
||||
current = data[i]
|
||||
|
||||
for b in range(32):
|
||||
if crc & 0x80000000:
|
||||
crc = (crc << 1) & 0xFFFFFFFF
|
||||
crc ^= polynomial
|
||||
else:
|
||||
crc = (crc << 1) & 0xFFFFFFFF
|
||||
|
||||
if current & bit:
|
||||
crc ^= polynomial
|
||||
|
||||
bit >>= 1
|
||||
|
||||
return crc
|
||||
|
||||
def _crc_ctypes(self, data):
|
||||
uint32_array = (ctypes.c_uint32 * len(data))(*data)
|
||||
length = len(data)
|
||||
crc=self.crc_lib.crc32_core(uint32_array, length)
|
||||
return crc
|
||||
|
||||
def __Crc32(self, data):
|
||||
if self.platform == "Linux":
|
||||
return self._crc_ctypes(data)
|
||||
else:
|
||||
return self._crc_py(data)
|
|
@ -0,0 +1,42 @@
|
|||
#
|
||||
control_dt: 0.02
|
||||
|
||||
msg_type: "hg" # "hg" or "go"
|
||||
imu_type: "pelvis" # "torso" or "pelvis"
|
||||
|
||||
lowcmd_topic: "rt/lowcmd"
|
||||
lowstate_topic: "rt/lowstate"
|
||||
|
||||
policy_path: "{LEGGED_GYM_ROOT_DIR}/deploy/pre_train/g1/motion.pt"
|
||||
|
||||
leg_joint2motor_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
||||
kps: [100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40]
|
||||
kds: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2]
|
||||
default_angles: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0,
|
||||
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0]
|
||||
|
||||
arm_waist_joint2motor_idx: [12, 13, 14,
|
||||
15, 16, 17, 18, 19, 20, 21,
|
||||
22, 23, 24, 25, 26, 27, 28]
|
||||
|
||||
arm_waist_kps: [300, 300, 300,
|
||||
100, 100, 50, 50, 20, 20, 20,
|
||||
100, 100, 50, 50, 20, 20, 20]
|
||||
|
||||
arm_waist_kds: [3, 3, 3,
|
||||
2, 2, 2, 2, 1, 1, 1,
|
||||
2, 2, 2, 2, 1, 1, 1]
|
||||
|
||||
arm_waist_target: [ 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0]
|
||||
|
||||
ang_vel_scale: 0.25
|
||||
dof_pos_scale: 1.0
|
||||
dof_vel_scale: 0.05
|
||||
action_scale: 0.25
|
||||
cmd_scale: [2.0, 2.0, 0.25]
|
||||
num_actions: 12
|
||||
num_obs: 47
|
||||
|
||||
max_cmd: [0.8, 0.5, 1.57]
|
|
@ -0,0 +1,265 @@
|
|||
from legged_gym import LEGGED_GYM_ROOT_DIR
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
|
||||
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelFactoryInitialize
|
||||
from unitree_sdk2py.core.channel import ChannelSubscriber, ChannelFactoryInitialize
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_, unitree_hg_msg_dds__LowState_
|
||||
from unitree_sdk2py.idl.default import unitree_go_msg_dds__LowCmd_, unitree_go_msg_dds__LowState_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as LowCmdHG
|
||||
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_ as LowCmdGo
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowState_ as LowStateHG
|
||||
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowState_ as LowStateGo
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
from common.command_helper import create_damping_cmd, create_zero_cmd, init_cmd_hg, init_cmd_go, MotorMode
|
||||
from common.rotation_helper import get_gravity_orientation, transform_imu_data
|
||||
from common.remote_controller import RemoteController, KeyMap
|
||||
from config import Config
|
||||
|
||||
|
||||
class Controller:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
self.remote_controller = RemoteController()
|
||||
|
||||
# Initialize the policy network
|
||||
self.policy = torch.jit.load(config.policy_path)
|
||||
# Initializing process variables
|
||||
self.qj = np.zeros(config.num_actions, dtype=np.float32)
|
||||
self.dqj = np.zeros(config.num_actions, dtype=np.float32)
|
||||
self.action = np.zeros(config.num_actions, dtype=np.float32)
|
||||
self.target_dof_pos = config.default_angles.copy()
|
||||
self.obs = np.zeros(config.num_obs, dtype=np.float32)
|
||||
self.cmd = np.array([0.0, 0, 0])
|
||||
self.counter = 0
|
||||
|
||||
if config.msg_type == "hg":
|
||||
# g1 and h1_2 use the hg msg type
|
||||
self.low_cmd = unitree_hg_msg_dds__LowCmd_()
|
||||
self.low_state = unitree_hg_msg_dds__LowState_()
|
||||
self.mode_pr_ = MotorMode.PR
|
||||
self.mode_machine_ = 0
|
||||
|
||||
self.lowcmd_publisher_ = ChannelPublisher(config.lowcmd_topic, LowCmdHG)
|
||||
self.lowcmd_publisher_.Init()
|
||||
|
||||
self.lowstate_subscriber = ChannelSubscriber(config.lowstate_topic, LowStateHG)
|
||||
self.lowstate_subscriber.Init(self.LowStateHgHandler, 10)
|
||||
|
||||
elif config.msg_type == "go":
|
||||
# h1 uses the go msg type
|
||||
self.low_cmd = unitree_go_msg_dds__LowCmd_()
|
||||
self.low_state = unitree_go_msg_dds__LowState_()
|
||||
|
||||
self.lowcmd_publisher_ = ChannelPublisher(config.lowcmd_topic, LowCmdGo)
|
||||
self.lowcmd_publisher_.Init()
|
||||
|
||||
self.lowstate_subscriber = ChannelSubscriber(config.lowstate_topic, LowStateGo)
|
||||
self.lowstate_subscriber.Init(self.LowStateGoHandler, 10)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid msg_type")
|
||||
|
||||
# wait for the subscriber to receive data
|
||||
self.wait_for_low_state()
|
||||
|
||||
# Initialize the command msg
|
||||
if config.msg_type == "hg":
|
||||
init_cmd_hg(self.low_cmd, self.mode_machine_, self.mode_pr_)
|
||||
elif config.msg_type == "go":
|
||||
init_cmd_go(self.low_cmd, weak_motor=self.config.weak_motor)
|
||||
|
||||
def LowStateHgHandler(self, msg: LowStateHG):
|
||||
self.low_state = msg
|
||||
self.mode_machine_ = self.low_state.mode_machine
|
||||
self.remote_controller.set(self.low_state.wireless_remote)
|
||||
|
||||
def LowStateGoHandler(self, msg: LowStateGo):
|
||||
self.low_state = msg
|
||||
self.remote_controller.set(self.low_state.wireless_remote)
|
||||
|
||||
def send_cmd(self, cmd: Union[LowCmdGo, LowCmdHG]):
|
||||
cmd.crc = CRC().Crc(cmd)
|
||||
self.lowcmd_publisher_.Write(cmd)
|
||||
|
||||
def wait_for_low_state(self):
|
||||
while self.low_state.tick == 0:
|
||||
time.sleep(self.config.control_dt)
|
||||
print("Successfully connected to the robot.")
|
||||
|
||||
def zero_torque_state(self):
|
||||
print("Enter zero torque state.")
|
||||
print("Waiting for the start signal...")
|
||||
while self.remote_controller.button[KeyMap.start] != 1:
|
||||
create_zero_cmd(self.low_cmd)
|
||||
self.send_cmd(self.low_cmd)
|
||||
time.sleep(self.config.control_dt)
|
||||
|
||||
def move_to_default_pos(self):
|
||||
print("Moving to default pos.")
|
||||
# move time 2s
|
||||
total_time = 2
|
||||
num_step = int(total_time / self.config.control_dt)
|
||||
|
||||
dof_idx = self.config.leg_joint2motor_idx + self.config.arm_waist_joint2motor_idx
|
||||
kps = self.config.kps + self.config.arm_waist_kps
|
||||
kds = self.config.kds + self.config.arm_waist_kds
|
||||
default_pos = np.concatenate((self.config.default_angles, self.config.arm_waist_target), axis=0)
|
||||
dof_size = len(dof_idx)
|
||||
|
||||
# record the current pos
|
||||
init_dof_pos = np.zeros(dof_size, dtype=np.float32)
|
||||
for i in range(dof_size):
|
||||
init_dof_pos[i] = self.low_state.motor_state[dof_idx[i]].q
|
||||
|
||||
# move to default pos
|
||||
for i in range(num_step):
|
||||
alpha = i / num_step
|
||||
for j in range(dof_size):
|
||||
motor_idx = dof_idx[j]
|
||||
target_pos = default_pos[j]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = init_dof_pos[j] * (1 - alpha) + target_pos * alpha
|
||||
self.low_cmd.motor_cmd[motor_idx].qd = 0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = kps[j]
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = kds[j]
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0
|
||||
self.send_cmd(self.low_cmd)
|
||||
time.sleep(self.config.control_dt)
|
||||
|
||||
def default_pos_state(self):
|
||||
print("Enter default pos state.")
|
||||
print("Waiting for the Button A signal...")
|
||||
while self.remote_controller.button[KeyMap.A] != 1:
|
||||
for i in range(len(self.config.leg_joint2motor_idx)):
|
||||
motor_idx = self.config.leg_joint2motor_idx[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = self.config.default_angles[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].qd = 0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0
|
||||
for i in range(len(self.config.arm_waist_joint2motor_idx)):
|
||||
motor_idx = self.config.arm_waist_joint2motor_idx[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = self.config.arm_waist_target[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].qd = 0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = self.config.arm_waist_kps[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = self.config.arm_waist_kds[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0
|
||||
self.send_cmd(self.low_cmd)
|
||||
time.sleep(self.config.control_dt)
|
||||
|
||||
def run(self):
|
||||
self.counter += 1
|
||||
# Get the current joint position and velocity
|
||||
for i in range(len(self.config.leg_joint2motor_idx)):
|
||||
self.qj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].q
|
||||
self.dqj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].dq
|
||||
|
||||
# imu_state quaternion: w, x, y, z
|
||||
quat = self.low_state.imu_state.quaternion
|
||||
ang_vel = np.array([self.low_state.imu_state.gyroscope], dtype=np.float32)
|
||||
|
||||
if self.config.imu_type == "torso":
|
||||
# h1 and h1_2 imu is on the torso
|
||||
# imu data needs to be transformed to the pelvis frame
|
||||
waist_yaw = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].q
|
||||
waist_yaw_omega = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq
|
||||
quat, ang_vel = transform_imu_data(waist_yaw=waist_yaw, waist_yaw_omega=waist_yaw_omega, imu_quat=quat, imu_omega=ang_vel)
|
||||
|
||||
# create observation
|
||||
gravity_orientation = get_gravity_orientation(quat)
|
||||
qj_obs = self.qj.copy()
|
||||
dqj_obs = self.dqj.copy()
|
||||
qj_obs = (qj_obs - self.config.default_angles) * self.config.dof_pos_scale
|
||||
dqj_obs = dqj_obs * self.config.dof_vel_scale
|
||||
ang_vel = ang_vel * self.config.ang_vel_scale
|
||||
period = 0.8
|
||||
count = self.counter * self.config.control_dt
|
||||
phase = count % period / period
|
||||
sin_phase = np.sin(2 * np.pi * phase)
|
||||
cos_phase = np.cos(2 * np.pi * phase)
|
||||
|
||||
self.cmd[0] = self.remote_controller.ly
|
||||
self.cmd[1] = self.remote_controller.lx * -1
|
||||
self.cmd[2] = self.remote_controller.rx * -1
|
||||
|
||||
num_actions = self.config.num_actions
|
||||
self.obs[:3] = ang_vel
|
||||
self.obs[3:6] = gravity_orientation
|
||||
self.obs[6:9] = self.cmd * self.config.cmd_scale * self.config.max_cmd
|
||||
self.obs[9 : 9 + num_actions] = qj_obs
|
||||
self.obs[9 + num_actions : 9 + num_actions * 2] = dqj_obs
|
||||
self.obs[9 + num_actions * 2 : 9 + num_actions * 3] = self.action
|
||||
self.obs[9 + num_actions * 3] = sin_phase
|
||||
self.obs[9 + num_actions * 3 + 1] = cos_phase
|
||||
|
||||
# Get the action from the policy network
|
||||
obs_tensor = torch.from_numpy(self.obs).unsqueeze(0)
|
||||
self.action = self.policy(obs_tensor).detach().numpy().squeeze()
|
||||
|
||||
# transform action to target_dof_pos
|
||||
target_dof_pos = self.config.default_angles + self.action * self.config.action_scale
|
||||
|
||||
# Build low cmd
|
||||
for i in range(len(self.config.leg_joint2motor_idx)):
|
||||
motor_idx = self.config.leg_joint2motor_idx[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = target_dof_pos[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].qd = 0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
for i in range(len(self.config.arm_waist_joint2motor_idx)):
|
||||
motor_idx = self.config.arm_waist_joint2motor_idx[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = self.config.arm_waist_target[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].qd = 0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = self.config.arm_waist_kps[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = self.config.arm_waist_kds[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# send the command
|
||||
self.send_cmd(self.low_cmd)
|
||||
|
||||
time.sleep(self.config.control_dt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("net", type=str, help="network interface")
|
||||
parser.add_argument("config", type=str, help="config file name in the configs folder", default="g1.yaml")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load config
|
||||
config_path = f"{LEGGED_GYM_ROOT_DIR}/deploy/deploy_real/configs/{args.config}"
|
||||
config = Config(config_path)
|
||||
|
||||
# Initialize DDS communication
|
||||
ChannelFactoryInitialize(0, args.net)
|
||||
|
||||
controller = Controller(config)
|
||||
|
||||
# Enter the zero torque state, press the start key to continue executing
|
||||
controller.zero_torque_state()
|
||||
|
||||
# Move to the default position
|
||||
controller.move_to_default_pos()
|
||||
|
||||
# Enter the default position state, press the A key to continue executing
|
||||
controller.default_pos_state()
|
||||
|
||||
while True:
|
||||
try:
|
||||
controller.run()
|
||||
# Press the select key to exit
|
||||
if controller.remote_controller.button[KeyMap.select] == 1:
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
# Enter the damping state
|
||||
create_damping_cmd(controller.low_cmd)
|
||||
controller.send_cmd(controller.low_cmd)
|
||||
print("Exit")
|
|
@ -0,0 +1,322 @@
|
|||
from legged_gym import LEGGED_GYM_ROOT_DIR
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
|
||||
import rclpy as rp
|
||||
from unitree_hg.msg import LowCmd as LowCmdHG, LowState as LowStateHG
|
||||
from unitree_go.msg import LowCmd as LowCmdGo, LowState as LowStateGo
|
||||
from common.command_helper_ros import create_damping_cmd, create_zero_cmd, init_cmd_hg, init_cmd_go, MotorMode
|
||||
from common.rotation_helper import get_gravity_orientation, transform_imu_data
|
||||
from common.remote_controller import RemoteController, KeyMap
|
||||
from config import Config
|
||||
from common.crc import CRC
|
||||
|
||||
from enum import Enum
|
||||
|
||||
class Mode(Enum):
|
||||
wait = 0
|
||||
zero_torque = 1
|
||||
default_pos = 2
|
||||
damping = 3
|
||||
policy = 4
|
||||
null = 5
|
||||
|
||||
class Controller:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
self.remote_controller = RemoteController()
|
||||
|
||||
# Initialize the policy network
|
||||
self.policy = torch.jit.load(config.policy_path)
|
||||
# Initializing process variables
|
||||
self.qj = np.zeros(config.num_actions, dtype=np.float32)
|
||||
self.dqj = np.zeros(config.num_actions, dtype=np.float32)
|
||||
self.action = np.zeros(config.num_actions, dtype=np.float32)
|
||||
self.target_dof_pos = config.default_angles.copy()
|
||||
self.obs = np.zeros(config.num_obs, dtype=np.float32)
|
||||
self.cmd = np.array([0.0, 0, 0])
|
||||
self.counter = 0
|
||||
|
||||
rp.init()
|
||||
self._node = rp.create_node("low_level_cmd_sender")
|
||||
|
||||
if config.msg_type == "hg":
|
||||
# g1 and h1_2 use the hg msg type
|
||||
|
||||
self.low_cmd = LowCmdHG()
|
||||
self.low_state = LowStateHG()
|
||||
|
||||
self.lowcmd_publisher_ = self._node.create_publisher(LowCmdHG,
|
||||
'lowcmd', 10)
|
||||
self.lowstate_subscriber = self._node.create_subscription(LowStateHG,
|
||||
'lowstate', self.LowStateHgHandler, 10)
|
||||
self.mode_pr_ = MotorMode.PR
|
||||
self.mode_machine_ = 0
|
||||
|
||||
# self.lowcmd_publisher_ = ChannelPublisher(config.lowcmd_topic, LowCmdHG)
|
||||
# self.lowcmd_publisher_.Init()
|
||||
|
||||
# self.lowstate_subscriber = ChannelSubscriber(config.lowstate_topic, LowStateHG)
|
||||
# self.lowstate_subscriber.Init(self.LowStateHgHandler, 10)
|
||||
|
||||
elif config.msg_type == "go":
|
||||
raise ValueError(f"{config.msg_type} is not implemented yet.")
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid msg_type")
|
||||
|
||||
# wait for the subscriber to receive data
|
||||
# self.wait_for_low_state()
|
||||
|
||||
# Initialize the command msg
|
||||
if config.msg_type == "hg":
|
||||
init_cmd_hg(self.low_cmd, self.mode_machine_, self.mode_pr_)
|
||||
elif config.msg_type == "go":
|
||||
init_cmd_go(self.low_cmd, weak_motor=self.config.weak_motor)
|
||||
|
||||
self.mode = Mode.wait
|
||||
self._mode_change = True
|
||||
self._timer = self._node.create_timer(self.config.control_dt, self.run_wrapper)
|
||||
self._terminate = False
|
||||
try:
|
||||
rp.spin(self._node)
|
||||
except KeyboardInterrupt:
|
||||
print("KeyboardInterrupt")
|
||||
finally:
|
||||
self._node.destroy_timer(self._timer)
|
||||
create_damping_cmd(self.low_cmd)
|
||||
self.send_cmd(self.low_cmd)
|
||||
self._node.destroy_node()
|
||||
rp.shutdown()
|
||||
print("Exit")
|
||||
|
||||
def LowStateHgHandler(self, msg: LowStateHG):
|
||||
self.low_state = msg
|
||||
self.mode_machine_ = self.low_state.mode_machine
|
||||
self.remote_controller.set(self.low_state.wireless_remote)
|
||||
|
||||
def LowStateGoHandler(self, msg: LowStateGo):
|
||||
self.low_state = msg
|
||||
self.remote_controller.set(self.low_state.wireless_remote)
|
||||
|
||||
def send_cmd(self, cmd: Union[LowCmdGo, LowCmdHG]):
|
||||
cmd.mode_machine = self.mode_machine_
|
||||
cmd.crc = CRC().Crc(cmd)
|
||||
size = len(cmd.motor_cmd)
|
||||
# print(cmd.mode_machine)
|
||||
# for i in range(size):
|
||||
# print(i, cmd.motor_cmd[i].q,
|
||||
# cmd.motor_cmd[i].dq,
|
||||
# cmd.motor_cmd[i].kp,
|
||||
# cmd.motor_cmd[i].kd,
|
||||
# cmd.motor_cmd[i].tau)
|
||||
self.lowcmd_publisher_.publish(cmd)
|
||||
|
||||
def wait_for_low_state(self):
|
||||
while self.low_state.crc == 0:
|
||||
print(self.low_state)
|
||||
time.sleep(self.config.control_dt)
|
||||
print("Successfully connected to the robot.")
|
||||
|
||||
def zero_torque_state(self):
|
||||
if self.remote_controller.button[KeyMap.start] == 1:
|
||||
self._mode_change = True
|
||||
self.mode = Mode.default_pos
|
||||
else:
|
||||
create_zero_cmd(self.low_cmd)
|
||||
self.send_cmd(self.low_cmd)
|
||||
|
||||
def prepare_default_pos(self):
|
||||
# move time 2s
|
||||
total_time = 2
|
||||
self.counter = 0
|
||||
self._num_step = int(total_time / self.config.control_dt)
|
||||
|
||||
dof_idx = self.config.leg_joint2motor_idx + self.config.arm_waist_joint2motor_idx
|
||||
kps = self.config.kps + self.config.arm_waist_kps
|
||||
kds = self.config.kds + self.config.arm_waist_kds
|
||||
self._kps = [float(kp) for kp in kps]
|
||||
self._kds = [float(kd) for kd in kds]
|
||||
self._default_pos = np.concatenate((self.config.default_angles, self.config.arm_waist_target), axis=0)
|
||||
self._dof_size = len(dof_idx)
|
||||
self._dof_idx = dof_idx
|
||||
|
||||
# record the current pos
|
||||
self._init_dof_pos = np.zeros(self._dof_size,
|
||||
dtype=np.float32)
|
||||
for i in range(self._dof_size):
|
||||
self._init_dof_pos[i] = self.low_state.motor_state[dof_idx[i]].q
|
||||
|
||||
def move_to_default_pos(self):
|
||||
# move to default pos
|
||||
if self.counter < self._num_step:
|
||||
alpha = self.counter / self._num_step
|
||||
for j in range(self._dof_size):
|
||||
motor_idx = self._dof_idx[j]
|
||||
target_pos = self._default_pos[j]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = (self._init_dof_pos[j] *
|
||||
(1 - alpha) + target_pos * alpha)
|
||||
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = self._kps[j]
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = self._kds[j]
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
|
||||
self.send_cmd(self.low_cmd)
|
||||
self.counter += 1
|
||||
else:
|
||||
self._mode_change = True
|
||||
self.mode = Mode.damping
|
||||
|
||||
def default_pos_state(self):
|
||||
if self.remote_controller.button[KeyMap.A] != 1:
|
||||
for i in range(len(self.config.leg_joint2motor_idx)):
|
||||
motor_idx = self.config.leg_joint2motor_idx[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = float(self.config.default_angles[i])
|
||||
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = self._kps[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = self._kds[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
|
||||
for i in range(len(self.config.arm_waist_joint2motor_idx)):
|
||||
motor_idx = self.config.arm_waist_joint2motor_idx[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = float(self.config.arm_waist_target[i])
|
||||
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = self._kps[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = self._kds[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
|
||||
self.send_cmd(self.low_cmd)
|
||||
else:
|
||||
self._mode_change = True
|
||||
self.mode = Mode.policy
|
||||
|
||||
def run_policy(self):
|
||||
if self.remote_controller.button[KeyMap.select] == 1:
|
||||
self._mode_change = True
|
||||
self.mode = Mode.null
|
||||
return
|
||||
self.counter += 1
|
||||
# Get the current joint position and velocity
|
||||
for i in range(len(self.config.leg_joint2motor_idx)):
|
||||
self.qj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].q
|
||||
self.dqj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].dq
|
||||
|
||||
# imu_state quaternion: w, x, y, z
|
||||
quat = self.low_state.imu_state.quaternion
|
||||
ang_vel = np.array([self.low_state.imu_state.gyroscope], dtype=np.float32)
|
||||
|
||||
if self.config.imu_type == "torso":
|
||||
# h1 and h1_2 imu is on the torso
|
||||
# imu data needs to be transformed to the pelvis frame
|
||||
waist_yaw = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].q
|
||||
waist_yaw_omega = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq
|
||||
quat, ang_vel = transform_imu_data(waist_yaw=waist_yaw, waist_yaw_omega=waist_yaw_omega, imu_quat=quat, imu_omega=ang_vel)
|
||||
|
||||
# create observation
|
||||
gravity_orientation = get_gravity_orientation(quat)
|
||||
qj_obs = self.qj.copy()
|
||||
dqj_obs = self.dqj.copy()
|
||||
qj_obs = (qj_obs - self.config.default_angles) * self.config.dof_pos_scale
|
||||
dqj_obs = dqj_obs * self.config.dof_vel_scale
|
||||
ang_vel = ang_vel * self.config.ang_vel_scale
|
||||
period = 0.8
|
||||
count = self.counter * self.config.control_dt
|
||||
phase = count % period / period
|
||||
sin_phase = np.sin(2 * np.pi * phase)
|
||||
cos_phase = np.cos(2 * np.pi * phase)
|
||||
|
||||
self.cmd[0] = self.remote_controller.ly
|
||||
self.cmd[1] = self.remote_controller.lx * -1
|
||||
self.cmd[2] = self.remote_controller.rx * -1
|
||||
# print(self.remote_controller.ly,
|
||||
# self.remote_controller.lx,
|
||||
# self.remote_controller.rx)
|
||||
# self.cmd[0] = 0.0
|
||||
# self.cmd[1] = 0.0
|
||||
# self.cmd[2] = 0.0
|
||||
|
||||
num_actions = self.config.num_actions
|
||||
self.obs[:3] = ang_vel
|
||||
self.obs[3:6] = gravity_orientation
|
||||
self.obs[6:9] = self.cmd * self.config.cmd_scale * self.config.max_cmd
|
||||
self.obs[9 : 9 + num_actions] = qj_obs
|
||||
self.obs[9 + num_actions : 9 + num_actions * 2] = dqj_obs
|
||||
self.obs[9 + num_actions * 2 : 9 + num_actions * 3] = self.action
|
||||
self.obs[9 + num_actions * 3] = sin_phase
|
||||
self.obs[9 + num_actions * 3 + 1] = cos_phase
|
||||
|
||||
# Get the action from the policy network
|
||||
obs_tensor = torch.from_numpy(self.obs).unsqueeze(0)
|
||||
self.action = self.policy(obs_tensor).detach().numpy().squeeze()
|
||||
|
||||
# transform action to target_dof_pos
|
||||
target_dof_pos = self.config.default_angles + self.action * self.config.action_scale
|
||||
|
||||
# Build low cmd
|
||||
for i in range(len(self.config.leg_joint2motor_idx)):
|
||||
motor_idx = self.config.leg_joint2motor_idx[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = float(target_dof_pos[i])
|
||||
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = float(self.config.kps[i])
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = float(self.config.kds[i])
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
|
||||
|
||||
for i in range(len(self.config.arm_waist_joint2motor_idx)):
|
||||
motor_idx = self.config.arm_waist_joint2motor_idx[i]
|
||||
self.low_cmd.motor_cmd[motor_idx].q = float(self.config.arm_waist_target[i])
|
||||
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
|
||||
self.low_cmd.motor_cmd[motor_idx].kp = float(self.config.arm_waist_kps[i])
|
||||
self.low_cmd.motor_cmd[motor_idx].kd = float(self.config.arm_waist_kds[i])
|
||||
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
|
||||
|
||||
# send the command
|
||||
self.send_cmd(self.low_cmd)
|
||||
|
||||
def run_wrapper(self):
|
||||
# print("hello", self.mode,
|
||||
# self.mode == Mode.zero_torque)
|
||||
if self.mode == Mode.wait:
|
||||
if self.low_state.crc != 0:
|
||||
self.mode = Mode.zero_torque
|
||||
self.low_cmd.mode_machine = self.mode_machine_
|
||||
print("Successfully connected to the robot.")
|
||||
elif self.mode == Mode.zero_torque:
|
||||
if self._mode_change:
|
||||
print("Enter zero torque state.")
|
||||
print("Waiting for the start signal...")
|
||||
self._mode_change = False
|
||||
self.zero_torque_state()
|
||||
elif self.mode == Mode.default_pos:
|
||||
if self._mode_change:
|
||||
print("Moving to default pos.")
|
||||
self._mode_change = False
|
||||
self.prepare_default_pos()
|
||||
self.move_to_default_pos()
|
||||
elif self.mode == Mode.damping:
|
||||
if self._mode_change:
|
||||
print("Enter default pos state.")
|
||||
print("Waiting for the Button A signal...")
|
||||
self._mode_change = False
|
||||
self.default_pos_state()
|
||||
elif self.mode == Mode.policy:
|
||||
if self._mode_change:
|
||||
print("Run policy.")
|
||||
self._mode_change = False
|
||||
self.counter = 0
|
||||
self.run_policy()
|
||||
elif self.mode == Mode.null:
|
||||
self._terminate = True
|
||||
|
||||
# time.sleep(self.config.control_dt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("config", type=str, help="config file name in the configs folder", default="g1.yaml")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load config
|
||||
config_path = f"{LEGGED_GYM_ROOT_DIR}/deploy/deploy_real/configs/{args.config}"
|
||||
config = Config(config_path)
|
||||
|
||||
controller = Controller(config)
|
2
setup.py
2
setup.py
|
@ -8,4 +8,4 @@ setup(name='unitree_rl_gym',
|
|||
packages=find_packages(),
|
||||
author_email='support@unitree.com',
|
||||
description='Template RL environments for Unitree Robots',
|
||||
install_requires=['isaacgym', 'rsl-rl', 'matplotlib', 'numpy==1.20', 'tensorboard', 'mujoco==3.2.3', 'pyyaml'])
|
||||
install_requires=['rsl-rl', 'matplotlib','tensorboard', 'pyyaml'])
|
||||
|
|
Loading…
Reference in New Issue