add the sim2sim code on mujoco
This commit is contained in:
parent
bd00c6a2a1
commit
fb7514ad38
51
README.md
51
README.md
|
@ -59,12 +59,57 @@ https://github.com/user-attachments/assets/98395d82-d3f6-4548-b6ee-8edfce70ac3e
|
|||
|
||||
2. H1
|
||||
|
||||
https://github.com/user-attachments/assets/a9475a63-ea06-4327-bfa6-6a0f8065fa1c
|
||||
https://github.com/user-attachments/assets/7762b4f9-1072-4794-8ef6-7dd253a7ad4c
|
||||
|
||||
3. H1-2
|
||||
|
||||
https://github.com/user-attachments/assets/d6cdee70-8f8a-4a50-b219-df31b269b083
|
||||
https://github.com/user-attachments/assets/695323a7-a2d9-445b-bda8-f1b697159c39
|
||||
|
||||
4. G1
|
||||
|
||||
https://github.com/user-attachments/assets/0b554137-76bc-43f9-97e1-dd704a33d6a9
|
||||
https://github.com/user-attachments/assets/6063c03e-1143-4c75-8fda-793c8615cb08
|
||||
|
||||
|
||||
### mujoco(sim2sim)
|
||||
|
||||
1. H1
|
||||
|
||||
Execute the following command in the project path:
|
||||
|
||||
```bash
|
||||
|
||||
python deploy/deploy_mujoco/deploy_mujoco.py g1.yaml
|
||||
|
||||
```
|
||||
|
||||
Then you can get the following effect:
|
||||
|
||||
https://github.com/user-attachments/assets/10a84f8d-c02f-41cb-b2fd-76a97951b2c3
|
||||
|
||||
2. H1_2
|
||||
|
||||
Execute the following command in the project path:
|
||||
|
||||
```bash
|
||||
|
||||
python deploy/deploy_mujoco/deploy_mujoco.py h1_2.yaml
|
||||
|
||||
```
|
||||
|
||||
Then you can get the following effect:
|
||||
|
||||
https://github.com/user-attachments/assets/fdd4f53d-3235-4978-a77f-1c71b32fb301
|
||||
|
||||
3. G1
|
||||
|
||||
Execute the following command in the project path:
|
||||
|
||||
```bash
|
||||
|
||||
python deploy/deploy_mujoco/deploy_mujoco.py g1.yaml
|
||||
|
||||
```
|
||||
|
||||
Then you can get the following effect:
|
||||
|
||||
https://github.com/user-attachments/assets/99b892c3-7886-49f4-a7f1-0420b51443dd
|
|
@ -0,0 +1,26 @@
|
|||
#
|
||||
policy_path: "{LEGGED_GYM_ROOT_DIR}/deploy/pre_train/g1/motion.pt"
|
||||
xml_path: "{LEGGED_GYM_ROOT_DIR}/resources/robots/g1_description/scene.xml"
|
||||
|
||||
# Total simulation time
|
||||
simulation_duration: 60.0
|
||||
# Simulation time step
|
||||
simulation_dt: 0.002
|
||||
# Controller update frequency (meets the requirement of simulation_dt * controll_decimation=0.02; 50Hz)
|
||||
control_decimation: 10
|
||||
|
||||
kps: [100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40]
|
||||
kds: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2]
|
||||
|
||||
default_angles: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0,
|
||||
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0]
|
||||
|
||||
ang_vel_scale: 0.25
|
||||
dof_pos_scale: 1.0
|
||||
dof_vel_scale: 0.05
|
||||
action_scale: 0.25
|
||||
cmd_scale: [2.0, 2.0, 0.25]
|
||||
num_actions: 12
|
||||
num_obs: 47
|
||||
|
||||
cmd_init: [0.5, 0, 0]
|
|
@ -0,0 +1,26 @@
|
|||
#
|
||||
policy_path: "{LEGGED_GYM_ROOT_DIR}/deploy/pre_train/h1/motion.pt"
|
||||
xml_path: "{LEGGED_GYM_ROOT_DIR}/resources/robots/h1/scene.xml"
|
||||
|
||||
# Total simulation time
|
||||
simulation_duration: 60.0
|
||||
# Simulation time step
|
||||
simulation_dt: 0.002
|
||||
# Controller update frequency (meets the requirement of simulation_dt * controll_decimation=0.02; 50Hz)
|
||||
control_decimation: 10
|
||||
|
||||
kps: [150, 150, 150, 200, 40, 150, 150, 150, 200, 40]
|
||||
kds: [2, 2, 2, 4, 2, 2, 2, 2, 4, 2]
|
||||
|
||||
default_angles: [0, 0.0, -0.1, 0.3, -0.2,
|
||||
0, 0.0, -0.1, 0.3, -0.2]
|
||||
|
||||
ang_vel_scale: 0.25
|
||||
dof_pos_scale: 1.0
|
||||
dof_vel_scale: 0.05
|
||||
action_scale: 0.25
|
||||
cmd_scale: [2.0, 2.0, 0.25]
|
||||
num_actions: 10
|
||||
num_obs: 41
|
||||
|
||||
cmd_init: [0.5, 0, 0]
|
|
@ -0,0 +1,26 @@
|
|||
#
|
||||
policy_path: "{LEGGED_GYM_ROOT_DIR}/deploy/pre_train/h1_2/motion.pt"
|
||||
xml_path: "{LEGGED_GYM_ROOT_DIR}/resources/robots/h1_2/scene.xml"
|
||||
|
||||
# Total simulation time
|
||||
simulation_duration: 60.0
|
||||
# Simulation time step
|
||||
simulation_dt: 0.002
|
||||
# Controller update frequency (meets the requirement of simulation_dt * controll_decimation=0.02; 50Hz)
|
||||
control_decimation: 10
|
||||
|
||||
kps: [200, 200, 200, 300, 40, 40, 200, 200, 200, 300, 40, 40]
|
||||
kds: [2.5, 2.5, 2.5, 4, 2, 2, 2.5, 2.5, 2.5, 4, 2, 2]
|
||||
|
||||
default_angles: [0, -0.16, 0.0, 0.36, -0.2, 0.0,
|
||||
0, -0.16, 0.0, 0.36, -0.2, 0.0]
|
||||
|
||||
ang_vel_scale: 0.25
|
||||
dof_pos_scale: 1.0
|
||||
dof_vel_scale: 0.05
|
||||
action_scale: 0.25
|
||||
cmd_scale: [2.0, 2.0, 0.25]
|
||||
num_actions: 12
|
||||
num_obs: 47
|
||||
|
||||
cmd_init: [0.5, 0, 0]
|
|
@ -0,0 +1,130 @@
|
|||
import time
|
||||
|
||||
import mujoco.viewer
|
||||
import mujoco
|
||||
import numpy as np
|
||||
from legged_gym import LEGGED_GYM_ROOT_DIR
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
|
||||
def get_gravity_orientation(quaternion):
|
||||
qw = quaternion[0]
|
||||
qx = quaternion[1]
|
||||
qy = quaternion[2]
|
||||
qz = quaternion[3]
|
||||
|
||||
gravity_orientation = np.zeros(3)
|
||||
|
||||
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||
|
||||
return gravity_orientation
|
||||
|
||||
|
||||
def pd_control(target_q, q, kp, target_dq, dq, kd):
|
||||
"""Calculates torques from position commands"""
|
||||
return (target_q - q) * kp + (target_dq - dq) * kd
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# get config file name from command line
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("config_file", type=str, help="config file name in the config folder")
|
||||
args = parser.parse_args()
|
||||
config_file = args.config_file
|
||||
with open(f"{LEGGED_GYM_ROOT_DIR}/deploy/deploy_mujoco/configs/{config_file}", "r") as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
policy_path = config["policy_path"].replace("{LEGGED_GYM_ROOT_DIR}", LEGGED_GYM_ROOT_DIR)
|
||||
xml_path = config["xml_path"].replace("{LEGGED_GYM_ROOT_DIR}", LEGGED_GYM_ROOT_DIR)
|
||||
|
||||
simulation_duration = config["simulation_duration"]
|
||||
simulation_dt = config["simulation_dt"]
|
||||
control_decimation = config["control_decimation"]
|
||||
|
||||
kps = np.array(config["kps"], dtype=np.float32)
|
||||
kds = np.array(config["kds"], dtype=np.float32)
|
||||
|
||||
default_angles = np.array(config["default_angles"], dtype=np.float32)
|
||||
|
||||
ang_vel_scale = config["ang_vel_scale"]
|
||||
dof_pos_scale = config["dof_pos_scale"]
|
||||
dof_vel_scale = config["dof_vel_scale"]
|
||||
action_scale = config["action_scale"]
|
||||
cmd_scale = np.array(config["cmd_scale"], dtype=np.float32)
|
||||
|
||||
num_actions = config["num_actions"]
|
||||
num_obs = config["num_obs"]
|
||||
|
||||
cmd = np.array(config["cmd_init"], dtype=np.float32)
|
||||
|
||||
# define context variables
|
||||
action = np.zeros(num_actions, dtype=np.float32)
|
||||
target_dof_pos = default_angles.copy()
|
||||
obs = np.zeros(num_obs, dtype=np.float32)
|
||||
|
||||
counter = 0
|
||||
|
||||
# Load robot model
|
||||
m = mujoco.MjModel.from_xml_path(xml_path)
|
||||
d = mujoco.MjData(m)
|
||||
m.opt.timestep = simulation_dt
|
||||
|
||||
# load policy
|
||||
policy = torch.jit.load(policy_path)
|
||||
|
||||
with mujoco.viewer.launch_passive(m, d) as viewer:
|
||||
# Close the viewer automatically after simulation_duration wall-seconds.
|
||||
start = time.time()
|
||||
while viewer.is_running() and time.time() - start < simulation_duration:
|
||||
step_start = time.time()
|
||||
tau = pd_control(target_dof_pos, d.qpos[7:], kps, np.zeros_like(kds), d.qvel[6:], kds)
|
||||
d.ctrl[:] = tau
|
||||
# mj_step can be replaced with code that also evaluates
|
||||
# a policy and applies a control signal before stepping the physics.
|
||||
mujoco.mj_step(m, d)
|
||||
|
||||
counter += 1
|
||||
if counter % control_decimation == 0:
|
||||
# Apply control signal here.
|
||||
|
||||
# create observation
|
||||
qj = d.qpos[7:]
|
||||
dqj = d.qvel[6:]
|
||||
quat = d.qpos[3:7]
|
||||
omega = d.qvel[3:6]
|
||||
|
||||
qj = (qj - default_angles) * dof_pos_scale
|
||||
dqj = dqj * dof_vel_scale
|
||||
gravity_orientation = get_gravity_orientation(quat)
|
||||
omega = omega * ang_vel_scale
|
||||
|
||||
period = 0.8
|
||||
count = counter * simulation_dt
|
||||
phase = count % period / period
|
||||
sin_phase = np.sin(2 * np.pi * phase)
|
||||
cos_phase = np.cos(2 * np.pi * phase)
|
||||
|
||||
obs[:3] = omega
|
||||
obs[3:6] = gravity_orientation
|
||||
obs[6:9] = cmd * cmd_scale
|
||||
obs[9 : 9 + num_actions] = qj
|
||||
obs[9 + num_actions : 9 + 2 * num_actions] = dqj
|
||||
obs[9 + 2 * num_actions : 9 + 3 * num_actions] = action
|
||||
obs[9 + 3 * num_actions : 9 + 3 * num_actions + 2] = np.array([sin_phase, cos_phase])
|
||||
obs_tensor = torch.from_numpy(obs).unsqueeze(0)
|
||||
# policy inference
|
||||
action = policy(obs_tensor).detach().numpy().squeeze()
|
||||
# transform action to target_dof_pos
|
||||
target_dof_pos = action * action_scale + default_angles
|
||||
|
||||
# Pick up changes to the physics state, apply perturbations, update options from GUI.
|
||||
viewer.sync()
|
||||
|
||||
# Rudimentary time keeping, will drift relative to wall clock.
|
||||
time_until_next_step = m.opt.timestep - (time.time() - step_start)
|
||||
if time_until_next_step > 0:
|
||||
time.sleep(time_until_next_step)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
2
setup.py
2
setup.py
|
@ -8,4 +8,4 @@ setup(name='unitree_rl_gym',
|
|||
packages=find_packages(),
|
||||
author_email='support@unitree.com',
|
||||
description='Template RL environments for Unitree Robots',
|
||||
install_requires=['isaacgym', 'rsl-rl', 'matplotlib', 'numpy==1.20', 'tensorboard'])
|
||||
install_requires=['isaacgym', 'rsl-rl', 'matplotlib', 'numpy==1.20', 'tensorboard', 'mujoco==3.2.3', 'pyyaml'])
|
||||
|
|
Loading…
Reference in New Issue