Modified the G1 URDF with a rubber hand and a correct collision mesh; modified deploy_mujoco, deploy_real, and play to use the actions of all 29 joints of the G1; implemented the joint transform to handle the different joint orders used by IsaacLab and the low-level controller

This commit is contained in:
yoontae Cho 2025-02-02 17:05:43 +09:00
parent 757b051580
commit 69105aa8d0
11 changed files with 563 additions and 103 deletions

6
MUJOCO_LOG.TXT Normal file
View File

@ -0,0 +1,6 @@
Sat Jan 25 16:14:06 2025
WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.1650.
Sat Jan 25 16:23:57 2025
WARNING: Nan, Inf or huge value in QACC at DOF 18. The simulation is unstable. Time = 0.1150.

View File

@ -1,26 +1,54 @@
#
policy_path: "{LEGGED_GYM_ROOT_DIR}/deploy/pre_train/g1/motion.pt"
xml_path: "{LEGGED_GYM_ROOT_DIR}/resources/robots/g1_description/scene.xml"
# policy_path: "{LEGGED_GYM_ROOT_DIR}/deploy/pre_train/g1/motion.pt"
# policy_path: "{LEGGED_GYM_ROOT_DIR}/logs/g1/walk_with_dr_test_v21/exported/policy.pt"
policy_path: "{LEGGED_GYM_ROOT_DIR}/logs/g1/walk_with_dr_test_v22/exported/policy.pt"
# xml_path: "{LEGGED_GYM_ROOT_DIR}/resources/robots/g1_description/scene.xml"
xml_path: "{LEGGED_GYM_ROOT_DIR}/resources/robots/g1_description/g1_29dof_rev_1_0.xml"
# Total simulation time
simulation_duration: 60.0
# simulation_duration: 5.
# Simulation time step
simulation_dt: 0.002
# Controller update frequency (must satisfy simulation_dt * control_decimation = 0.02 s, i.e. 50 Hz)
control_decimation: 10
kps: [100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40]
kds: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2]
# kps: [100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40]
kps: [
100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40,
150, 150, 150,
100, 100, 50, 50, 20, 20, 20,
100, 100, 50, 50, 20, 20, 20
]
kds: [
2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2,
3, 3, 3,
2, 2, 2, 2, 1, 1, 1,
2, 2, 2, 2, 1, 1, 1
]
default_angles: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0,
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0]
default_angles: [
-0.2, 0.0, 0.0, 0.42, -0.23, 0.0,
-0.2, 0.0, 0.0, 0.42, -0.23, 0.0,
0., 0., 0.,
0.35, 0.16, 0., 0.87, 0., 0., 0.,
0.35, -0.16, 0., 0.87, 0., 0., 0.,
]
ang_vel_scale: 0.25
# ang_vel_scale: 0.25
ang_vel_scale: 1.0
dof_pos_scale: 1.0
dof_vel_scale: 0.05
action_scale: 0.25
cmd_scale: [2.0, 2.0, 0.25]
num_actions: 12
num_obs: 47
# dof_vel_scale: 0.05
dof_vel_scale: 1.0
# action_scale: 0.25
# action_scale: 1.0
action_scale: 0.5
# cmd_scale: [2.0, 2.0, 0.25]
cmd_scale: [1.0, 1.0, 1.0]
# cmd_scale: [0., 0., 0.]
# num_actions: 12
num_actions: 29
# num_obs: 47
num_obs: 96
cmd_init: [0.5, 0, 0]

View File

@ -30,6 +30,81 @@ def pd_control(target_q, q, kp, target_dq, dq, kd):
if __name__ == "__main__":
# get config file name from command line
isaaclab_joint_order = [
'left_hip_pitch_joint',
'right_hip_pitch_joint',
'waist_yaw_joint',
'left_hip_roll_joint',
'right_hip_roll_joint',
'waist_roll_joint',
'left_hip_yaw_joint',
'right_hip_yaw_joint',
'waist_pitch_joint',
'left_knee_joint',
'right_knee_joint',
'left_shoulder_pitch_joint',
'right_shoulder_pitch_joint',
'left_ankle_pitch_joint',
'right_ankle_pitch_joint',
'left_shoulder_roll_joint',
'right_shoulder_roll_joint',
'left_ankle_roll_joint',
'right_ankle_roll_joint',
'left_shoulder_yaw_joint',
'right_shoulder_yaw_joint',
'left_elbow_joint',
'right_elbow_joint',
'left_wrist_roll_joint',
'right_wrist_roll_joint',
'left_wrist_pitch_joint',
'right_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_wrist_yaw_joint'
]
raw_joint_order = [
'left_hip_pitch_joint',
'left_hip_roll_joint',
'left_hip_yaw_joint',
'left_knee_joint',
'left_ankle_pitch_joint',
'left_ankle_roll_joint',
'right_hip_pitch_joint',
'right_hip_roll_joint',
'right_hip_yaw_joint',
'right_knee_joint',
'right_ankle_pitch_joint',
'right_ankle_roll_joint',
'waist_yaw_joint',
'waist_roll_joint',
'waist_pitch_joint',
'left_shoulder_pitch_joint',
'left_shoulder_roll_joint',
'left_shoulder_yaw_joint',
'left_elbow_joint',
'left_wrist_roll_joint',
'left_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_shoulder_pitch_joint',
'right_shoulder_roll_joint',
'right_shoulder_yaw_joint',
'right_elbow_joint',
'right_wrist_roll_joint',
'right_wrist_pitch_joint',
'right_wrist_yaw_joint'
]
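# Create a 0/1 permutation matrix between the two joint orders:
# entry [a_idx, b_idx] is 1 when isaaclab_joint_order[a_idx] and raw_joint_order[b_idx] name the same joint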
# mapping_tensor = torch.zeros((len(sim_b_joints), len(sim_a_joints)), device=env.device)
mapping_tensor = torch.zeros((len(raw_joint_order), len(isaaclab_joint_order)))
# Fill the mapping tensor
for b_idx, b_joint in enumerate(raw_joint_order):
if b_joint in isaaclab_joint_order:
a_idx = isaaclab_joint_order.index(b_joint)
# mapping_tensor[b_idx, a_idx] = 1.0
mapping_tensor[a_idx, b_idx] = 1.0
import argparse
parser = argparse.ArgumentParser()
@ -81,6 +156,10 @@ if __name__ == "__main__":
start = time.time()
while viewer.is_running() and time.time() - start < simulation_duration:
step_start = time.time()
from icecream import ic
# ic(
# target_dof_pos, d.qpos[7:], kps, np.zeros_like(kds), d.qvel[6:], kds
# )
tau = pd_control(target_dof_pos, d.qpos[7:], kps, np.zeros_like(kds), d.qvel[6:], kds)
d.ctrl[:] = tau
# mj_step can be replaced with code that also evaluates
@ -114,12 +193,32 @@ if __name__ == "__main__":
obs[9 : 9 + num_actions] = qj
obs[9 + num_actions : 9 + 2 * num_actions] = dqj
obs[9 + 2 * num_actions : 9 + 3 * num_actions] = action
obs[9 + 3 * num_actions : 9 + 3 * num_actions + 2] = np.array([sin_phase, cos_phase])
# obs[9 + 3 * num_actions : 9 + 3 * num_actions + 2] = np.array([sin_phase, cos_phase])
obs_tensor = torch.from_numpy(obs).unsqueeze(0)
obs_tensor[..., 9:38] = obs_tensor[..., 9:38] @ mapping_tensor.transpose(0, 1)
obs_tensor[..., 38:67] = obs_tensor[..., 38:67] @ mapping_tensor.transpose(0, 1)
obs_tensor[..., 67:96] = obs_tensor[..., 67:96] @ mapping_tensor.transpose(0, 1)
# from icecream import ic
# ic(
# obs[..., :9],
# obs[..., 9:38],
# obs[..., 38:67],
# obs[..., 67:96],
# )
# policy inference
action = policy(obs_tensor).detach().numpy().squeeze()
# reordered_actions = action @ mapping_tensor.detach().cpu().numpy()
action = action @ mapping_tensor.detach().cpu().numpy()
# ic(
# action
# )
# action = 0.
# transform action to target_dof_pos
# target_dof_pos = action * action_scale + default_angles
target_dof_pos = action * action_scale + default_angles
# raise NotImplementedError
# Pick up changes to the physics state, apply perturbations, update options from GUI.
viewer.sync()

View File

@ -10,33 +10,59 @@ lowstate_topic: "rt/lowstate"
policy_path: "{LEGGED_GYM_ROOT_DIR}/deploy/pre_train/g1/motion.pt"
leg_joint2motor_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
kps: [100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40]
kds: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2]
default_angles: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0,
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0]
# kps: [100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40]
# kds: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2]
# default_angles: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0,
# -0.1, 0.0, 0.0, 0.3, -0.2, 0.0]
kps: [
100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40,
150, 150, 150,
100, 100, 50, 50, 20, 20, 20,
100, 100, 50, 50, 20, 20, 20
]
kds: [
2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2,
3, 3, 3,
2, 2, 2, 2, 1, 1, 1,
2, 2, 2, 2, 1, 1, 1
]
default_angles: [
-0.2, 0.0, 0.0, 0.42, -0.23, 0.0,
-0.2, 0.0, 0.0, 0.42, -0.23, 0.0,
0., 0., 0.,
0.35, 0.16, 0., 0.87, 0., 0., 0.,
0.35, -0.16, 0., 0.87, 0., 0., 0.,
]
arm_waist_joint2motor_idx: [12, 13, 14,
15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28]
# arm_waist_joint2motor_idx: [12, 13, 14,
# 15, 16, 17, 18, 19, 20, 21,
# 22, 23, 24, 25, 26, 27, 28]
arm_waist_kps: [300, 300, 300,
100, 100, 50, 50, 20, 20, 20,
100, 100, 50, 50, 20, 20, 20]
# arm_waist_kps: [300, 300, 300,
# 100, 100, 50, 50, 20, 20, 20,
# 100, 100, 50, 50, 20, 20, 20]
arm_waist_kds: [3, 3, 3,
2, 2, 2, 2, 1, 1, 1,
2, 2, 2, 2, 1, 1, 1]
# arm_waist_kds: [3, 3, 3,
# 2, 2, 2, 2, 1, 1, 1,
# 2, 2, 2, 2, 1, 1, 1]
arm_waist_target: [ 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0]
# arm_waist_target: [ 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0]
ang_vel_scale: 0.25
# ang_vel_scale: 0.25
ang_vel_scale: 1.0
dof_pos_scale: 1.0
dof_vel_scale: 0.05
action_scale: 0.25
cmd_scale: [2.0, 2.0, 0.25]
num_actions: 12
num_obs: 47
# dof_vel_scale: 0.05
dof_vel_scale: 1.0
# action_scale: 0.25
action_scale: 1.0
# cmd_scale: [2.0, 2.0, 0.25]
cmd_scale: [0.0, 0.0, 0.0]
# num_actions: 12
num_actions: 29
# num_obs: 47
num_obs: 96
max_cmd: [0.8, 0.5, 1.57]
# max_cmd: [0.8, 0.5, 1.57]
max_cmd: [1.0, 1.0, 1.0]

View File

@ -19,6 +19,79 @@ from common.rotation_helper import get_gravity_orientation, transform_imu_data
from common.remote_controller import RemoteController, KeyMap
from config import Config
isaaclab_joint_order = [
'left_hip_pitch_joint',
'right_hip_pitch_joint',
'waist_yaw_joint',
'left_hip_roll_joint',
'right_hip_roll_joint',
'waist_roll_joint',
'left_hip_yaw_joint',
'right_hip_yaw_joint',
'waist_pitch_joint',
'left_knee_joint',
'right_knee_joint',
'left_shoulder_pitch_joint',
'right_shoulder_pitch_joint',
'left_ankle_pitch_joint',
'right_ankle_pitch_joint',
'left_shoulder_roll_joint',
'right_shoulder_roll_joint',
'left_ankle_roll_joint',
'right_ankle_roll_joint',
'left_shoulder_yaw_joint',
'right_shoulder_yaw_joint',
'left_elbow_joint',
'right_elbow_joint',
'left_wrist_roll_joint',
'right_wrist_roll_joint',
'left_wrist_pitch_joint',
'right_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_wrist_yaw_joint'
]
raw_joint_order = [
'left_hip_pitch_joint',
'left_hip_roll_joint',
'left_hip_yaw_joint',
'left_knee_joint',
'left_ankle_pitch_joint',
'left_ankle_roll_joint',
'right_hip_pitch_joint',
'right_hip_roll_joint',
'right_hip_yaw_joint',
'right_knee_joint',
'right_ankle_pitch_joint',
'right_ankle_roll_joint',
'waist_yaw_joint',
'waist_roll_joint',
'waist_pitch_joint',
'left_shoulder_pitch_joint',
'left_shoulder_roll_joint',
'left_shoulder_yaw_joint',
'left_elbow_joint',
'left_wrist_roll_joint',
'left_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_shoulder_pitch_joint',
'right_shoulder_roll_joint',
'right_shoulder_yaw_joint',
'right_elbow_joint',
'right_wrist_roll_joint',
'right_wrist_pitch_joint',
'right_wrist_yaw_joint'
]
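# Create a 0/1 permutation matrix between the two joint orders:
# entry [a_idx, b_idx] is 1 when isaaclab_joint_order[a_idx] and raw_joint_order[b_idx] name the same joint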
# mapping_tensor = torch.zeros((len(sim_b_joints), len(sim_a_joints)), device=env.device)
mapping_tensor = torch.zeros((len(raw_joint_order), len(isaaclab_joint_order)))
# Fill the mapping tensor
for b_idx, b_joint in enumerate(raw_joint_order):
if b_joint in isaaclab_joint_order:
a_idx = isaaclab_joint_order.index(b_joint)
mapping_tensor[a_idx, b_idx] = 1.0
class Controller:
def __init__(self, config: Config) -> None:
@ -104,10 +177,14 @@ class Controller:
total_time = 2
num_step = int(total_time / self.config.control_dt)
dof_idx = self.config.leg_joint2motor_idx + self.config.arm_waist_joint2motor_idx
kps = self.config.kps + self.config.arm_waist_kps
kds = self.config.kds + self.config.arm_waist_kds
default_pos = np.concatenate((self.config.default_angles, self.config.arm_waist_target), axis=0)
# dof_idx = self.config.leg_joint2motor_idx + self.config.arm_waist_joint2motor_idx
# kps = self.config.kps + self.config.arm_waist_kps
# kds = self.config.kds + self.config.arm_waist_kds
# default_pos = np.concatenate((self.config.default_angles, self.config.arm_waist_target), axis=0)
dof_idx = self.config.joint2motor_idx
kps = self.config.kps
kds = self.config.kds
default_pos = self.config.default_angles
dof_size = len(dof_idx)
# record the current pos
@ -133,29 +210,40 @@ class Controller:
print("Enter default pos state.")
print("Waiting for the Button A signal...")
while self.remote_controller.button[KeyMap.A] != 1:
for i in range(len(self.config.leg_joint2motor_idx)):
motor_idx = self.config.leg_joint2motor_idx[i]
# for i in range(len(self.config.leg_joint2motor_idx)):
# motor_idx = self.config.leg_joint2motor_idx[i]
# self.low_cmd.motor_cmd[motor_idx].q = self.config.default_angles[i]
# self.low_cmd.motor_cmd[motor_idx].qd = 0
# self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
# self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
# self.low_cmd.motor_cmd[motor_idx].tau = 0
# for i in range(len(self.config.arm_waist_joint2motor_idx)):
# motor_idx = self.config.arm_waist_joint2motor_idx[i]
# self.low_cmd.motor_cmd[motor_idx].q = self.config.arm_waist_target[i]
# self.low_cmd.motor_cmd[motor_idx].qd = 0
# self.low_cmd.motor_cmd[motor_idx].kp = self.config.arm_waist_kps[i]
# self.low_cmd.motor_cmd[motor_idx].kd = self.config.arm_waist_kds[i]
# self.low_cmd.motor_cmd[motor_idx].tau = 0
for i in range(len(self.config.joint2motor_idx)):
motor_idx = self.config.joint2motor_idx[i]
self.low_cmd.motor_cmd[motor_idx].q = self.config.default_angles[i]
self.low_cmd.motor_cmd[motor_idx].qd = 0
self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
self.low_cmd.motor_cmd[motor_idx].tau = 0
for i in range(len(self.config.arm_waist_joint2motor_idx)):
motor_idx = self.config.arm_waist_joint2motor_idx[i]
self.low_cmd.motor_cmd[motor_idx].q = self.config.arm_waist_target[i]
self.low_cmd.motor_cmd[motor_idx].qd = 0
self.low_cmd.motor_cmd[motor_idx].kp = self.config.arm_waist_kps[i]
self.low_cmd.motor_cmd[motor_idx].kd = self.config.arm_waist_kds[i]
self.low_cmd.motor_cmd[motor_idx].tau = 0
self.send_cmd(self.low_cmd)
time.sleep(self.config.control_dt)
def run(self):
self.counter += 1
# Get the current joint position and velocity
for i in range(len(self.config.leg_joint2motor_idx)):
self.qj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].q
self.dqj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].dq
# for i in range(len(self.config.leg_joint2motor_idx)):
# self.qj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].q
# self.dqj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].dq
for i, motor_idx in enumerate(self.config.joint2motor_idx):
self.qj[i] = self.low_state.motor_state[motor_idx].q
self.dqj[i] = self.low_state.motor_state[motor_idx].dq
# imu_state quaternion: w, x, y, z
quat = self.low_state.imu_state.quaternion
@ -164,8 +252,11 @@ class Controller:
if self.config.imu_type == "torso":
# h1 and h1_2 imu is on the torso
# imu data needs to be transformed to the pelvis frame
waist_yaw = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].q
waist_yaw_omega = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq
# waist_yaw = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].q
# waist_yaw_omega = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq
waist_yaw = self.low_state.motor_state[self.config.joint2motor_idx[12]].q
waist_yaw_omega = self.low_state.motor_state[self.config.joint2motor_idx[12]].dq
quat, ang_vel = transform_imu_data(waist_yaw=waist_yaw, waist_yaw_omega=waist_yaw_omega, imu_quat=quat, imu_omega=ang_vel)
# create observation
@ -175,11 +266,11 @@ class Controller:
qj_obs = (qj_obs - self.config.default_angles) * self.config.dof_pos_scale
dqj_obs = dqj_obs * self.config.dof_vel_scale
ang_vel = ang_vel * self.config.ang_vel_scale
period = 0.8
count = self.counter * self.config.control_dt
phase = count % period / period
sin_phase = np.sin(2 * np.pi * phase)
cos_phase = np.cos(2 * np.pi * phase)
# period = 0.8
# count = self.counter * self.config.control_dt
# phase = count % period / period
# sin_phase = np.sin(2 * np.pi * phase)
# cos_phase = np.cos(2 * np.pi * phase)
self.cmd[0] = self.remote_controller.ly
self.cmd[1] = self.remote_controller.lx * -1
@ -192,32 +283,48 @@ class Controller:
self.obs[9 : 9 + num_actions] = qj_obs
self.obs[9 + num_actions : 9 + num_actions * 2] = dqj_obs
self.obs[9 + num_actions * 2 : 9 + num_actions * 3] = self.action
self.obs[9 + num_actions * 3] = sin_phase
self.obs[9 + num_actions * 3 + 1] = cos_phase
# self.obs[9 + num_actions * 3] = sin_phase
# self.obs[9 + num_actions * 3 + 1] = cos_phase
# Get the action from the policy network
obs_tensor = torch.from_numpy(self.obs).unsqueeze(0)
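# Reorder the joint position, joint velocity, and previous-action observations from the raw joint order into the IsaacLab joint order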
obs_tensor[..., 9:38] = obs_tensor[..., 9:38] @ mapping_tensor.transpose(0, 1)
obs_tensor[..., 38:67] = obs_tensor[..., 38:67] @ mapping_tensor.transpose(0, 1)
obs_tensor[..., 67:96] = obs_tensor[..., 67:96] @ mapping_tensor.transpose(0, 1)
self.action = self.policy(obs_tensor).detach().numpy().squeeze()
# Reorder the actions from the IsaacLab joint order back into the raw joint order
self.action = self.action @ mapping_tensor.detach().cpu().numpy()
# transform action to target_dof_pos
target_dof_pos = self.config.default_angles + self.action * self.config.action_scale
# Build low cmd
for i in range(len(self.config.leg_joint2motor_idx)):
motor_idx = self.config.leg_joint2motor_idx[i]
# for i in range(len(self.config.leg_joint2motor_idx)):
# motor_idx = self.config.leg_joint2motor_idx[i]
# self.low_cmd.motor_cmd[motor_idx].q = target_dof_pos[i]
# self.low_cmd.motor_cmd[motor_idx].qd = 0
# self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
# self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
# self.low_cmd.motor_cmd[motor_idx].tau = 0
# for i in range(len(self.config.arm_waist_joint2motor_idx)):
# motor_idx = self.config.arm_waist_joint2motor_idx[i]
# self.low_cmd.motor_cmd[motor_idx].q = self.config.arm_waist_target[i]
# self.low_cmd.motor_cmd[motor_idx].qd = 0
# self.low_cmd.motor_cmd[motor_idx].kp = self.config.arm_waist_kps[i]
# self.low_cmd.motor_cmd[motor_idx].kd = self.config.arm_waist_kds[i]
# self.low_cmd.motor_cmd[motor_idx].tau = 0
for i, motor_idx in enumerate(self.config.joint2motor_idx):
self.low_cmd.motor_cmd[motor_idx].q = target_dof_pos[i]
self.low_cmd.motor_cmd[motor_idx].qd = 0
self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
self.low_cmd.motor_cmd[motor_idx].tau = 0
for i in range(len(self.config.arm_waist_joint2motor_idx)):
motor_idx = self.config.arm_waist_joint2motor_idx[i]
self.low_cmd.motor_cmd[motor_idx].q = self.config.arm_waist_target[i]
self.low_cmd.motor_cmd[motor_idx].qd = 0
self.low_cmd.motor_cmd[motor_idx].kp = self.config.arm_waist_kps[i]
self.low_cmd.motor_cmd[motor_idx].kd = self.config.arm_waist_kds[i]
self.low_cmd.motor_cmd[motor_idx].tau = 0
# send the command
self.send_cmd(self.low_cmd)

View File

@ -41,7 +41,8 @@ class LeggedRobotCfg(BaseConfig):
max_curriculum = 1.
num_commands = 4 # default: lin_vel_x, lin_vel_y, ang_vel_yaw, heading (in heading mode ang_vel_yaw is recomputed from heading error)
resampling_time = 10. # time before command are changed[s]
heading_command = True # if true: compute ang vel command from heading error
# heading_command = True # if true: compute ang vel command from heading error
heading_command = False # if true: compute ang vel command from heading error
class ranges:
lin_vel_x = [-1.0, 1.0] # min max [m/s]
lin_vel_y = [-1.0, 1.0] # min max [m/s]
@ -86,7 +87,8 @@ class LeggedRobotCfg(BaseConfig):
linear_damping = 0.
max_angular_velocity = 1000.
max_linear_velocity = 1000.
armature = 0.
# armature = 0.
armature = 0.01
thickness = 0.01
class domain_rand:
@ -126,10 +128,13 @@ class LeggedRobotCfg(BaseConfig):
class normalization:
class obs_scales:
lin_vel = 2.0
ang_vel = 0.25
# lin_vel = 2.0
lin_vel = 1.0
# ang_vel = 0.25
ang_vel = 1.0
dof_pos = 1.0
dof_vel = 0.05
# dof_vel = 0.05
dof_vel = 1.0
height_measurements = 5.0
clip_observations = 100.
clip_actions = 100.

View File

@ -6,23 +6,43 @@ class G1RoughCfg( LeggedRobotCfg ):
default_joint_angles = { # = target angles [rad] when action = 0.0
'left_hip_yaw_joint' : 0. ,
'left_hip_roll_joint' : 0,
'left_hip_pitch_joint' : -0.1,
'left_knee_joint' : 0.3,
'left_ankle_pitch_joint' : -0.2,
'left_hip_pitch_joint' : -0.2,
'left_knee_joint' : 0.42,
'left_ankle_pitch_joint' : -0.23,
'left_ankle_roll_joint' : 0,
'right_hip_yaw_joint' : 0.,
'right_hip_roll_joint' : 0,
'right_hip_pitch_joint' : -0.1,
'right_knee_joint' : 0.3,
'right_ankle_pitch_joint': -0.2,
'right_hip_pitch_joint' : -0.2,
'right_knee_joint' : 0.42,
'right_ankle_pitch_joint': -0.23,
'right_ankle_roll_joint' : 0,
'torso_joint' : 0.
'left_elbow_joint': 0.87,
'right_elbow_joint': 0.87,
'left_shoulder_roll_joint': 0.16,
'left_shoulder_pitch_joint': 0.35,
'left_shoulder_yaw_joint': 0.,
'right_shoulder_roll_joint': -0.16,
'right_shoulder_pitch_joint': 0.35,
'right_shoulder_yaw_joint': 0.,
'waist_roll_joint' : 0,
'waist_pitch_joint' : 0,
'waist_yaw_joint' : 0,
'left_wrist_roll_joint' : 0,
'left_wrist_pitch_joint' : 0,
'left_wrist_yaw_joint' : 0,
'right_wrist_roll_joint' : 0,
'right_wrist_pitch_joint' : 0,
'right_wrist_yaw_joint' : 0,
}
class env(LeggedRobotCfg.env):
num_observations = 47
num_privileged_obs = 50
num_actions = 12
# num_observations = 47
num_observations = 96
# num_privileged_obs = 50
num_privileged_obs = 96
# num_actions = 12
num_actions = 29
class domain_rand(LeggedRobotCfg.domain_rand):
@ -44,24 +64,40 @@ class G1RoughCfg( LeggedRobotCfg ):
'hip_pitch': 100,
'knee': 150,
'ankle': 40,
'shoulder_pitch': 100,
'shoulder_roll': 100,
'shoulder_yaw': 50,
'elbow': 50,
'wrist': 20,
'waist': 150,
} # [N*m/rad]
damping = { 'hip_yaw': 2,
'hip_roll': 2,
'hip_pitch': 2,
'knee': 4,
'ankle': 2,
'shoulder_pitch': 2,
'shoulder_roll': 2,
'shoulder_yaw': 2,
'elbow': 2,
'wrist': 1,
'waist': 3,
} # [N*m/rad] # [N*m*s/rad]
# action scale: target angle = actionScale * action + defaultAngle
action_scale = 0.25
# action_scale = 0.25
action_scale = 0.5
# decimation: Number of control action updates @ sim DT per policy DT
decimation = 4
class asset( LeggedRobotCfg.asset ):
file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/g1_description/g1_12dof.urdf'
# file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/g1_description/g1_12dof.urdf'
file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/g1_description/g1_29dof_rev_1_0.urdf'
# file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/g1_description/g1_29dof_rev_1_0.urdf'
name = "g1"
foot_name = "ankle_roll"
penalize_contacts_on = ["hip", "knee"]
terminate_after_contacts_on = ["pelvis"]
# terminate_after_contacts_on = ["pelvis"]
terminate_after_contacts_on = []
self_collisions = 0 # 1 to disable, 0 to enable...bitwise filter
flip_visual_attachments = False
@ -91,8 +127,10 @@ class G1RoughCfg( LeggedRobotCfg ):
class G1RoughCfgPPO( LeggedRobotCfgPPO ):
class policy:
init_noise_std = 0.8
actor_hidden_dims = [32]
critic_hidden_dims = [32]
# actor_hidden_dims = [32]
actor_hidden_dims = [256, 128, 128]
# critic_hidden_dims = [32]
critic_hidden_dims = [256, 128, 128]
activation = 'elu' # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid
# only for 'ActorCriticRecurrent':
rnn_type = 'lstm'
@ -102,7 +140,8 @@ class G1RoughCfgPPO( LeggedRobotCfgPPO ):
class algorithm( LeggedRobotCfgPPO.algorithm ):
entropy_coef = 0.01
class runner( LeggedRobotCfgPPO.runner ):
policy_class_name = "ActorCriticRecurrent"
# policy_class_name = "ActorCriticRecurrent"
policy_class_name = "ActorCritic"
max_iterations = 10000
run_name = ''
experiment_name = 'g1'

View File

@ -27,7 +27,7 @@ class G1Robot(LeggedRobot):
noise_vec[9:9+self.num_actions] = noise_scales.dof_pos * noise_level * self.obs_scales.dof_pos
noise_vec[9+self.num_actions:9+2*self.num_actions] = noise_scales.dof_vel * noise_level * self.obs_scales.dof_vel
noise_vec[9+2*self.num_actions:9+3*self.num_actions] = 0. # previous actions
noise_vec[9+3*self.num_actions:9+3*self.num_actions+2] = 0. # sin/cos phase
# noise_vec[9+3*self.num_actions:9+3*self.num_actions+2] = 0. # sin/cos phase
return noise_vec
@ -68,26 +68,55 @@ class G1Robot(LeggedRobot):
def compute_observations(self):
""" Computes observations
"""
sin_phase = torch.sin(2 * np.pi * self.phase ).unsqueeze(1)
cos_phase = torch.cos(2 * np.pi * self.phase ).unsqueeze(1)
self.obs_buf = torch.cat(( self.base_ang_vel * self.obs_scales.ang_vel,
# sin_phase = torch.sin(2 * np.pi * self.phase ).unsqueeze(1)
# cos_phase = torch.cos(2 * np.pi * self.phase ).unsqueeze(1)
self.gym.refresh_rigid_body_state_tensor(self.sim)
self.pelvis_states = self.rigid_body_states_view[:, 0, :]
from icecream import ic
# ic(self.pelvis_states)
# ic(self.commands[:, :3] * self.commands_scale)
self.pelvis_ang_vel = quat_rotate_inverse(self.pelvis_states[..., 3:7], self.pelvis_states[:, 10:13])
self.projected_gravity = quat_rotate_inverse(self.pelvis_states[..., 3:7], self.gravity_vec)
# self.commands[..., :3] = 0.
# if self.episode_length_buf == 0:
# self.pelvis_ang_vel[..., :] = 0.
# self.projected_gravity[..., 0:2] = 0.
# self.projected_gravity[..., 2] = -1.
self.obs_buf = torch.cat((
# self.base_ang_vel * self.obs_scales.ang_vel,
self.pelvis_ang_vel* self.obs_scales.ang_vel,
self.projected_gravity,
self.commands[:, :3] * self.commands_scale,
(self.dof_pos - self.default_dof_pos) * self.obs_scales.dof_pos,
self.dof_vel * self.obs_scales.dof_vel,
self.actions,
sin_phase,
cos_phase
# sin_phase,
# cos_phase
),dim=-1)
self.privileged_obs_buf = torch.cat(( self.base_lin_vel * self.obs_scales.lin_vel,
from icecream import ic
# ic(
# self.episode_length_buf,
# self.base_ang_vel,
# self.projected_gravity,
# self.commands,
# (self.dof_pos - self.default_dof_pos) * self.obs_scales.dof_pos,
# self.dof_vel,
# self.actions,
# )
self.privileged_obs_buf = torch.cat((
# self.base_lin_vel * self.obs_scales.lin_vel,
self.base_ang_vel * self.obs_scales.ang_vel,
self.projected_gravity,
self.commands[:, :3] * self.commands_scale,
(self.dof_pos - self.default_dof_pos) * self.obs_scales.dof_pos,
self.dof_vel * self.obs_scales.dof_vel,
self.actions,
sin_phase,
cos_phase
# sin_phase,
# cos_phase
),dim=-1)
# add perceptive inputs if not blind
# add noise if needed

View File

@ -33,6 +33,80 @@ def play(args):
ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args, train_cfg=train_cfg)
policy = ppo_runner.get_inference_policy(device=env.device)
# Define the joint orders for sim A and sim B
sim_a_joints = [
'left_hip_pitch_joint',
'right_hip_pitch_joint',
'waist_yaw_joint',
'left_hip_roll_joint',
'right_hip_roll_joint',
'waist_roll_joint',
'left_hip_yaw_joint',
'right_hip_yaw_joint',
'waist_pitch_joint',
'left_knee_joint',
'right_knee_joint',
'left_shoulder_pitch_joint',
'right_shoulder_pitch_joint',
'left_ankle_pitch_joint',
'right_ankle_pitch_joint',
'left_shoulder_roll_joint',
'right_shoulder_roll_joint',
'left_ankle_roll_joint',
'right_ankle_roll_joint',
'left_shoulder_yaw_joint',
'right_shoulder_yaw_joint',
'left_elbow_joint',
'right_elbow_joint',
'left_wrist_roll_joint',
'right_wrist_roll_joint',
'left_wrist_pitch_joint',
'right_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_wrist_yaw_joint'
]
sim_b_joints = [
'left_hip_pitch_joint',
'left_hip_roll_joint',
'left_hip_yaw_joint',
'left_knee_joint',
'left_ankle_pitch_joint',
'left_ankle_roll_joint',
'right_hip_pitch_joint',
'right_hip_roll_joint',
'right_hip_yaw_joint',
'right_knee_joint',
'right_ankle_pitch_joint',
'right_ankle_roll_joint',
'waist_yaw_joint',
'waist_roll_joint',
'waist_pitch_joint',
'left_shoulder_pitch_joint',
'left_shoulder_roll_joint',
'left_shoulder_yaw_joint',
'left_elbow_joint',
'left_wrist_roll_joint',
'left_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_shoulder_pitch_joint',
'right_shoulder_roll_joint',
'right_shoulder_yaw_joint',
'right_elbow_joint',
'right_wrist_roll_joint',
'right_wrist_pitch_joint',
'right_wrist_yaw_joint'
]
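# Create a 0/1 permutation matrix between the two joint orders:
# entry [a_idx, b_idx] is 1 when sim_a_joints[a_idx] and sim_b_joints[b_idx] name the same joint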
mapping_tensor = torch.zeros((len(sim_b_joints), len(sim_a_joints)), device=env.device)
# Fill the mapping tensor
for b_idx, b_joint in enumerate(sim_b_joints):
if b_joint in sim_a_joints:
a_idx = sim_a_joints.index(b_joint)
# mapping_tensor[b_idx, a_idx] = 1.0
mapping_tensor[a_idx, b_idx] = 1.0
# export policy as a jit module (used to run it from C++)
if EXPORT_POLICY:
path = os.path.join(LEGGED_GYM_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'policies')
@ -40,11 +114,28 @@ def play(args):
print('Exported policy as jit script to: ', path)
for i in range(10*int(env.max_episode_length)):
obs[..., 9:38] = obs[..., 9:38] @ mapping_tensor.transpose(0, 1)
obs[..., 38:67] = obs[..., 38:67] @ mapping_tensor.transpose(0, 1)
obs[..., 67:96] = obs[..., 67:96] @ mapping_tensor.transpose(0, 1)
# from icecream import ic
# ic(
# obs[..., :9],
# obs[..., 9:38],
# obs[..., 38:67],
# obs[..., 67:96],
# )
actions = policy(obs.detach())
obs, _, rews, dones, infos = env.step(actions.detach())
# ic(
# actions
# )
reordered_actions = actions @ mapping_tensor
# obs, _, rews, dones, infos = env.step(actions.detach())
obs, _, rews, dones, infos = env.step(reordered_actions.detach())
if __name__ == '__main__':
EXPORT_POLICY = True
# EXPORT_POLICY = False
RECORD_FRAMES = False
MOVE_CAMERA = False
args = get_args()

View File

@ -1,3 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<robot name="g1_29dof_rev_1_0">
<mujoco>
<compiler meshdir="meshes" discardvisual="false"/>
@ -215,6 +216,12 @@
</material>
</visual>
<collision>
<origin xyz="0.03826199 0.0 -0.02540915" rpy="0 0 0"/>
<geometry>
<box size="0.20820672 0.07558269 0.02"/>
</geometry>
</collision>
<!-- <collision>
<origin xyz="-0.05 0.025 -0.03" rpy="0 0 0"/>
<geometry>
<sphere radius="0.005"/>
@ -237,7 +244,7 @@
<geometry>
<sphere radius="0.005"/>
</geometry>
</collision>
</collision> -->
</link>
<joint name="left_ankle_roll_joint" type="revolute">
<origin xyz="0 0 -0.017558" rpy="0 0 0"/>
@ -407,6 +414,12 @@
</material>
</visual>
<collision>
<origin xyz="0.03826199 0.0 -0.02540915" rpy="0 0 0"/>
<geometry>
<box size="0.20820672 0.07558269 0.02"/>
</geometry>
</collision>
<!-- <collision>
<origin xyz="-0.05 0.025 -0.03" rpy="0 0 0"/>
<geometry>
<sphere radius="0.005"/>
@ -429,7 +442,7 @@
<geometry>
<sphere radius="0.005"/>
</geometry>
</collision>
</collision> -->
</link>
<joint name="right_ankle_roll_joint" type="revolute">
<origin xyz="0 0 -0.017558" rpy="0 0 0"/>
@ -830,6 +843,12 @@
<color rgba="0.7 0.7 0.7 1"/>
</material>
</visual>
<collision>
<origin xyz="0.07 -0.01 0" rpy="0 0 -0.3"/>
<geometry>
<box size="0.12 0.03 0.08"/>
</geometry>
</collision>
</link>
<link name="right_shoulder_pitch_link">
<inertial>
@ -1054,5 +1073,11 @@
<color rgba="0.7 0.7 0.7 1"/>
</material>
</visual>
<collision>
<origin xyz="0.07 0.01 0" rpy="0 0 0.3"/>
<geometry>
<box size="0.12 0.03 0.08"/>
</geometry>
</collision>
</link>
</robot>

View File

@ -1,6 +1,11 @@
<mujoco model="g1_29dof_rev_1_0">
<compiler angle="radian" meshdir="meshes"/>
<statistic meansize="0.144785" extent="1.23314" center="0.025392 2.0634e-05 -0.245975"/>
<default>
<joint damping="0.001" armature="0.01" frictionloss="0.1"/>
</default>
<asset>
<mesh name="pelvis" file="pelvis.STL"/>
<mesh name="pelvis_contour_link" file="pelvis_contour_link.STL"/>