Merge branch 'localization' of github.com:YoontaeCho/unitree_rl_gym into navigation

This commit is contained in:
HYUNHONOH98 2025-02-20 15:45:16 +09:00
commit 6bfb0a9cd7
11 changed files with 1815 additions and 28 deletions

View File

View File

View File

@ -0,0 +1,258 @@
import time
import math
import numpy as np
import matplotlib.pyplot as plt
def linear_map(val, in_min, in_max, out_min, out_max):
"""Linearly map val from [in_min, in_max] to [out_min, out_max]."""
return out_min + (val - in_min) * (out_max - out_min) / (in_max - in_min)
def quaternion_multiply(q1, q2):
# q = [w, x, y, z]
w1, x1, y1, z1 = q1
w2, x2, y2, z2 = q2
w = w1*w2 - x1*x2 - y1*y2 - z1*z2
x = w1*x2 + x1*w2 + y1*z2 - z1*y2
y = w1*y2 - x1*z2 + y1*w2 + z1*x2
z = w1*z2 + x1*y2 - y1*x2 + z1*w2
return np.array([w, x, y, z], dtype=np.float32)
def quaternion_rotate(q, v):
"""Rotate vector v by quaternion q."""
q_conj = np.array([q[0], -q[1], -q[2], -q[3]], dtype=np.float32)
v_q = np.concatenate(([0.0], v))
rotated = quaternion_multiply(quaternion_multiply(q, v_q), q_conj)
return rotated[1:]
def yaw_to_quaternion(yaw):
"""Convert yaw angle (radians) to a quaternion (w, x, y, z)."""
half_yaw = yaw / 2.0
return np.array([np.cos(half_yaw), 0.0, 0.0, np.sin(half_yaw)], dtype=np.float32)
def combine_frame_transforms(pos, quat, rel_pos, rel_quat):
"""
Combine two transforms:
T_new = T * T_rel
where T is given by (pos, quat) and T_rel by (rel_pos, rel_quat).
"""
new_pos = pos + quaternion_rotate(quat, rel_pos)
new_quat = quaternion_multiply(quat, rel_quat)
return new_pos, new_quat
# ----------------------
# StepCommand Class
# ----------------------
class StepCommand:
def __init__(self, current_left_pose, current_right_pose):
"""
Initialize with the current foot poses.
Each pose is a 7-dimensional vector: [x, y, z, qw, qx, qy, qz].
Both next_ctarget_left and next_ctarget_right are initialized to these values.
Also, store the maximum ranges for x, y, and theta.
- x_range: (-0.2, 0.2)
- y_range: (0.2, 0.4)
- theta_range: (-0.3, 0.3)
"""
self.next_ctarget_left = current_left_pose.copy()
self.next_ctarget_right = current_right_pose.copy()
self.next_ctime_left = 0.4
self.next_ctime_right = 0.4
self.delta_ctime = 0.4 # Fixed time delta for a new step
self.max_range = {
'x_range': (-0.2, 0.2),
'y_range': (0.2, 0.4),
'theta_range': (-0.3, 0.3)
}
def compute_relstep_left(self, lx, ly, rx):
"""
Compute the left foot relative step based on remote controller inputs.
Mapping:
- x: map ly in [-1,1] to self.max_range['x_range'].
- y: baseline for left is self.max_range['y_range'][0]. If lx > 0,
add an offset mapping lx in [0,1] to [0, self.max_range['y_range'][1]-self.max_range['y_range'][0]].
- z: fixed at 0.
- rotation: map rx in [-1,1] to self.max_range['theta_range'] and convert to quaternion.
"""
delta_x = linear_map(ly, -1, 1, self.max_range['x_range'][0], self.max_range['x_range'][1])
baseline_left = self.max_range['y_range'][0]
extra_y = linear_map(lx, 0, 1, 0, self.max_range['y_range'][1] - self.max_range['y_range'][0]) if lx > 0 else 0.0
delta_y = baseline_left + extra_y
delta_z = 0.0
theta = linear_map(rx, -1, 1, self.max_range['theta_range'][0], self.max_range['theta_range'][1])
q = yaw_to_quaternion(theta)
return np.array([delta_x, delta_y, delta_z, q[0], q[1], q[2], q[3]], dtype=np.float32)
def compute_relstep_right(self, lx, ly, rx):
"""
Compute the right foot relative step based on remote controller inputs.
Mapping:
- x: map ly in [-1,1] to self.max_range['x_range'].
- y: baseline for right is the negative of self.max_range['y_range'][0]. If lx < 0,
add an offset mapping lx in [-1,0] to [- (self.max_range['y_range'][1]-self.max_range['y_range'][0]), 0].
- z: fixed at 0.
- rotation: map rx in [-1,1] to self.max_range['theta_range'] and convert to quaternion.
"""
delta_x = linear_map(ly, -1, 1, self.max_range['x_range'][0], self.max_range['x_range'][1])
baseline_right = -self.max_range['y_range'][0]
extra_y = linear_map(lx, -1, 0, -(self.max_range['y_range'][1] - self.max_range['y_range'][0]), 0) if lx < 0 else 0.0
delta_y = baseline_right + extra_y
delta_z = 0.0
theta = linear_map(rx, -1, 1, self.max_range['theta_range'][0], self.max_range['theta_range'][1])
q = yaw_to_quaternion(theta)
return np.array([delta_x, delta_y, delta_z, q[0], q[1], q[2], q[3]], dtype=np.float32)
def get_next_ctarget(self, remote_controller, count):
"""
Given the remote controller inputs and elapsed time (count),
compute relative step commands for left and right feet and update
the outdated targets accordingly.
Update procedure:
- When the left foot is due (count > next_ctime_left), update it by combining
the right foot target with the left relative step.
- Similarly, when the right foot is due (count > next_ctime_right), update it using
the left foot target and the right relative step.
Returns:
A concatenated 14-dimensional vector:
[left_foot_target (7D), right_foot_target (7D)]
"""
lx = remote_controller.lx
ly = remote_controller.ly
rx = remote_controller.rx
# Compute relative steps using the internal methods.
relstep_left = self.compute_relstep_left(lx, ly, rx)
relstep_right = self.compute_relstep_right(lx, ly, rx)
# from icecream import ic
# Update left foot target if its scheduled time has elapsed.
if count > self.next_ctime_left:
self.next_ctime_left = self.next_ctime_right + self.delta_ctime
new_pos, new_quat = combine_frame_transforms(
self.next_ctarget_right[:3],
self.next_ctarget_right[3:7],
relstep_left[:3],
relstep_left[3:7],
)
self.next_ctarget_left[:3] = new_pos
self.next_ctarget_left[3:7] = new_quat
# Update right foot target if its scheduled time has elapsed.
if count > self.next_ctime_right:
self.next_ctime_right = self.next_ctime_left + self.delta_ctime
new_pos, new_quat = combine_frame_transforms(
self.next_ctarget_left[:3],
self.next_ctarget_left[3:7],
relstep_right[:3],
relstep_right[3:7],
)
self.next_ctarget_right[:3] = new_pos
self.next_ctarget_right[3:7] = new_quat
# Return the concatenated target: left (7D) followed by right (7D).
return (self.next_ctarget_left, self.next_ctarget_right,
(self.next_ctime_left - count),
(self.next_ctime_right - count))
# For testing purposes, we define a dummy remote controller that mimics the attributes lx, ly, and rx.
class DummyRemoteController:
def __init__(self, lx=0.0, ly=0.0, rx=0.0):
self.lx = lx # lateral command input in range [-1,1]
self.ly = ly # forward/backward command input in range [-1,1]
self.rx = rx # yaw command input in range [-1,1]
if __name__ == "__main__":
# Initial foot poses (7D each): [x, y, z, qw, qx, qy, qz]
current_left_pose = np.array([0.0, 0.2, 0.0, 1.0, 0.0, 0.0, 0.0], dtype=np.float32)
current_right_pose = np.array([0.0, -0.2, 0.0, 1.0, 0.0, 0.0, 0.0], dtype=np.float32)
# Create an instance of StepCommand with the initial poses.
step_command = StepCommand(current_left_pose, current_right_pose)
# Create a dummy remote controller.
dummy_remote = DummyRemoteController()
# Set up matplotlib for interactive plotting.
plt.ion()
fig, ax = plt.subplots()
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_title("Footstep Target Visualization")
print("Starting test. Press Ctrl+C to exit.")
start_time = time.time()
try:
while True:
elapsed = time.time() - start_time
# For demonstration, vary the controller inputs over time:
# - ly oscillates between -1 and 1 (forward/backward)
# - lx oscillates between -1 and 1 (lateral left/right)
# - rx is held at 0 (no yaw command)
# dummy_remote.ly = math.sin(elapsed) # forward/backward command
# dummy_remote.lx = math.cos(elapsed) # lateral command
dummy_remote.ly = 0.0
dummy_remote.lx = 0.0
dummy_remote.rx = 1. # no yaw
# Get the current footstep target (14-dimensional)
ctarget = step_command.get_next_ctarget(dummy_remote, elapsed)
print("Time: {:.2f} s, ctarget: {}".format(elapsed, ctarget))
# Extract left foot and right foot positions:
# Left foot: indices 0:7 (position: [0:3], quaternion: [3:7])
left_pos = ctarget[0:3] # [x, y, z]
left_quat = ctarget[3:7] # [qw, qx, qy, qz]
# Right foot: indices 7:14 (position: [7:10], quaternion: [10:14])
right_pos = ctarget[7:10]
right_quat = ctarget[10:14]
# For visualization, we use only the x and y components.
left_x, left_y = left_pos[0], left_pos[1]
right_x, right_y = right_pos[0], right_pos[1]
# Assuming rotation only about z, compute yaw angle from quaternion:
# yaw = 2 * atan2(qz, qw)
left_yaw = 2 * math.atan2(left_quat[3], left_quat[0])
right_yaw = 2 * math.atan2(right_quat[3], right_quat[0])
# Clear and redraw the plot.
ax.cla()
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_title("Footstep Target Visualization")
# Plot the left and right foot positions.
ax.plot(left_x, left_y, 'bo', label='Left Foot')
ax.plot(right_x, right_y, 'ro', label='Right Foot')
# Draw an arrow for each foot to indicate orientation.
arrow_length = 0.1
ax.arrow(left_x, left_y,
arrow_length * math.cos(left_yaw),
arrow_length * math.sin(left_yaw),
head_width=0.03, head_length=0.03, fc='b', ec='b')
ax.arrow(right_x, right_y,
arrow_length * math.cos(right_yaw),
arrow_length * math.sin(right_yaw),
head_width=0.03, head_length=0.03, fc='r', ec='r')
ax.legend()
plt.pause(0.001)
time.sleep(0.1)
except KeyboardInterrupt:
print("Test terminated by user.")
finally:
plt.ioff()
plt.show()

View File

@ -0,0 +1,226 @@
import numpy as np
from geometry_msgs.msg import Vector3, Quaternion
from typing import Optional, Tuple
def to_array(v):
if isinstance(v, Vector3):
return np.array([v.x, v.y, v.z], dtype=np.float32)
elif isinstance(v, Quaternion):
return np.array([v.x, v.y, v.z, v.w], dtype=np.float32)
def normalize(x: np.ndarray, eps: float = 1e-9) -> np.ndarray:
"""Normalizes a given input tensor to unit length.
Args:
x: Input tensor of shape (N, dims).
eps: A small value to avoid division by zero. Defaults to 1e-9.
Returns:
Normalized tensor of shape (N, dims).
"""
return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True).clip(min=eps, max=None)
def yaw_quat(quat: np.ndarray) -> np.ndarray:
"""Extract the yaw component of a quaternion.
Args:
quat: The orientation in (w, x, y, z). Shape is (..., 4)
Returns:
A quaternion with only yaw component.
"""
shape = quat.shape
quat_yaw = quat.copy().reshape(-1, 4)
qw = quat_yaw[:, 0]
qx = quat_yaw[:, 1]
qy = quat_yaw[:, 2]
qz = quat_yaw[:, 3]
yaw = np.arctan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz))
quat_yaw[:] = 0.0
quat_yaw[:, 3] = np.sin(yaw / 2)
quat_yaw[:, 0] = np.cos(yaw / 2)
quat_yaw = normalize(quat_yaw)
return quat_yaw.reshape(shape)
def quat_conjugate(q: np.ndarray) -> np.ndarray:
"""Computes the conjugate of a quaternion.
Args:
q: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
Returns:
The conjugate quaternion in (w, x, y, z). Shape is (..., 4).
"""
shape = q.shape
q = q.reshape(-1, 4)
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=-1).reshape(shape)
def quat_inv(q: np.ndarray) -> np.ndarray:
"""Compute the inverse of a quaternion.
Args:
q: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
Returns:
The inverse quaternion in (w, x, y, z). Shape is (N, 4).
"""
return normalize(quat_conjugate(q))
def quat_mul(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
"""Multiply two quaternions together.
Args:
q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
q2: The second quaternion in (w, x, y, z). Shape is (..., 4).
Returns:
The product of the two quaternions in (w, x, y, z). Shape is (..., 4).
Raises:
ValueError: Input shapes of ``q1`` and ``q2`` are not matching.
"""
# check input is correct
if q1.shape != q2.shape:
msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}."
raise ValueError(msg)
# reshape to (N, 4) for multiplication
shape = q1.shape
q1 = q1.reshape(-1, 4)
q2 = q2.reshape(-1, 4)
# extract components from quaternions
w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
# perform multiplication
ww = (z1 + x1) * (x2 + y2)
yy = (w1 - y1) * (w2 + z2)
zz = (w1 + y1) * (w2 - z2)
xx = ww + yy + zz
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
w = qq - ww + (z1 - y1) * (y2 - z2)
x = qq - xx + (x1 + w1) * (x2 + w2)
y = qq - yy + (w1 - x1) * (y2 + z2)
z = qq - zz + (z1 + y1) * (w2 - x2)
return np.stack([w, x, y, z], axis=-1).reshape(shape)
def quat_apply(quat: np.ndarray, vec: np.ndarray) -> np.ndarray:
"""Apply a quaternion rotation to a vector.
Args:
quat: The quaternion in (w, x, y, z). Shape is (..., 4).
vec: The vector in (x, y, z). Shape is (..., 3).
Returns:
The rotated vector in (x, y, z). Shape is (..., 3).
"""
# store shape
shape = vec.shape
# reshape to (N, 3) for multiplication
quat = quat.reshape(-1, 4)
vec = vec.reshape(-1, 3)
# extract components from quaternions
xyz = quat[:, 1:]
t = np.cross(xyz, vec, axis=-1) * 2
return (vec + quat[:, 0:1] * t + np.cross(xyz, t, axis=-1)).reshape(shape)
def subtract_frame_transforms(
t01: np.ndarray, q01: np.ndarray,
t02: Optional[np.ndarray] = None,
q02: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, np.ndarray]:
r"""Subtract transformations between two reference frames into a stationary frame.
It performs the following transformation operation: :math:`T_{12} = T_{01}^{-1} \times T_{02}`,
where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B.
Args:
t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3).
q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3).
Defaults to None, in which case the position is assumed to be zero.
q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
Defaults to None, in which case the orientation is assumed to be identity.
Returns:
A tuple containing the position and orientation of frame 2 w.r.t. frame 1.
Shape of the tensors are (N, 3) and (N, 4) respectively.
"""
# compute orientation
q10 = quat_inv(q01)
if q02 is not None:
q12 = quat_mul(q10, q02)
else:
q12 = q10
# compute translation
if t02 is not None:
t12 = quat_apply(q10, t02 - t01)
else:
t12 = quat_apply(q10, -t01)
return t12, q12
def compute_pose_error(t01: np.ndarray,
q01: np.ndarray,
t02: np.ndarray,
q02: np.ndarray,
return_type='axa') -> Tuple[np.ndarray, np.ndarray]:
q10 = quat_inv(q01)
quat_error = quat_mul(q02, q10)
pos_error = t02-t01
if return_type == 'axa':
quat_error = axis_angle_from_quat(quat_error)
return pos_error, quat_error
def axis_angle_from_quat(quat: np.ndarray, eps: float = 1.0e-6) -> np.ndarray:
"""Convert rotations given as quaternions to axis/angle.
Args:
quat: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
eps: The tolerance for Taylor approximation. Defaults to 1.0e-6.
Returns:
Rotations given as a vector in axis angle form. Shape is (..., 3).
The vector's magnitude is the angle turned anti-clockwise in radians around the vector's direction.
Reference:
https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L526-L554
"""
# Modified to take in quat as [q_w, q_x, q_y, q_z]
# Quaternion is [q_w, q_x, q_y, q_z] = [cos(theta/2), n_x * sin(theta/2), n_y * sin(theta/2), n_z * sin(theta/2)]
# Axis-angle is [a_x, a_y, a_z] = [theta * n_x, theta * n_y, theta * n_z]
# Thus, axis-angle is [q_x, q_y, q_z] / (sin(theta/2) / theta)
# When theta = 0, (sin(theta/2) / theta) is undefined
# However, as theta --> 0, we can use the Taylor approximation 1/2 - theta^2 / 48
quat = quat * (1.0 - 2.0 * (quat[..., 0:1] < 0.0))
mag = np.linalg.norm(quat[..., 1:], axis=-1)
half_angle = np.arctan2(mag, quat[..., 0])
angle = 2.0 * half_angle
# check whether to apply Taylor approximation
sin_half_angles_over_angles = np.where(
np.abs(angle) > eps, np.sin(half_angle) / angle, 0.5 - angle * angle / 48
)
return quat[..., 1:4] / sin_half_angles_over_angles[..., None]
def wrap_to_pi(angles: np.ndarray) -> np.ndarray:
r"""Wraps input angles (in radians) to the range :math:`[-\pi, \pi]`.
This function wraps angles in radians to the range :math:`[-\pi, \pi]`, such that
:math:`\pi` maps to :math:`\pi`, and :math:`-\pi` maps to :math:`-\pi`. In general,
odd positive multiples of :math:`\pi` are mapped to :math:`\pi`, and odd negative
multiples of :math:`\pi` are mapped to :math:`-\pi`.
The function behaves similar to MATLAB's `wrapToPi <https://www.mathworks.com/help/map/ref/wraptopi.html>`_
function.
Args:
angles: Input angles of any shape.
Returns:
Angles in the range :math:`[-\pi, \pi]`.
"""
# wrap to [0, 2*pi)
wrapped_angle = (angles + np.pi) % (2 * np.pi)
# map to [-pi, pi]
# we check for zero in wrapped angle to make it go to pi when input angle is odd multiple of pi
return np.where((wrapped_angle == 0) & (angles > 0), np.pi, wrapped_angle - np.pi)

View File

@ -7,12 +7,35 @@ imu_type: "pelvis" # "torso" or "pelvis"
lowcmd_topic: "rt/lowcmd"
lowstate_topic: "rt/lowstate"
policy_path: "{LEGGED_GYM_ROOT_DIR}/deploy/pre_train/g1/policy.pt"
policy_path: "{LEGGED_GYM_ROOT_DIR}/deploy/pre_train/g1/step_policy_v2.pt"
leg_joint2motor_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
joint2motor_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28]
# kps: [100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40]
# kds: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2]
# default_angles: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0,
# -0.1, 0.0, 0.0, 0.3, -0.2, 0.0]
kps: [
100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40,
150, 150, 150,
100, 100, 50, 50, 20, 20, 20,
100, 100, 50, 50, 20, 20, 20
]
kds: [
2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2,
3, 3, 3,
2, 2, 2, 2, 1, 1, 1,
2, 2, 2, 2, 1, 1, 1
]
default_angles: [
-0.2, 0.0, 0.0, 0.42, -0.23, 0.0,
-0.2, 0.0, 0.0, 0.42, -0.23, 0.0,
0., 0., 0.,
0.35, 0.16, 0., 0.87, 0., 0., 0.,
0.35, -0.16, 0., 0.87, 0., 0., 0.,
]
raw_joint_order: [
'left_hip_pitch_joint',
@ -45,29 +68,6 @@ raw_joint_order: [
'right_wrist_pitch_joint',
'right_wrist_yaw_joint'
]
# kps: [100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40]
# kds: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2]
# default_angles: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0,
# -0.1, 0.0, 0.0, 0.3, -0.2, 0.0]
kps: [
100, 100, 100, 150, 40, 40, 100, 100, 100, 150, 40, 40,
150, 150, 150,
100, 100, 50, 50, 20, 20, 20,
100, 100, 50, 50, 20, 20, 20
]
kds: [
2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2,
3, 3, 3,
2, 2, 2, 2, 1, 1, 1,
2, 2, 2, 2, 1, 1, 1
]
default_angles: [
-0.2, 0.0, 0.0, 0.42, -0.23, 0.0,
-0.2, 0.0, 0.0, 0.42, -0.23, 0.0,
0., 0., 0.,
0.35, 0.16, 0., 0.87, 0., 0., 0.,
0.35, -0.16, 0., 0.87, 0., 0., 0.,
]
# arm_waist_joint2motor_idx: [12, 13, 14,
# 15, 16, 17, 18, 19, 20, 21,
@ -97,7 +97,7 @@ cmd_scale: [0.0, 0.0, 0.0]
# num_actions: 12
num_actions: 29
# num_obs: 47
num_obs: 96
num_obs: 131
# max_cmd: [0.8, 0.5, 1.57]
max_cmd: [1.0, 1.0, 1.0]

View File

@ -0,0 +1,556 @@
from legged_gym import LEGGED_GYM_ROOT_DIR
from typing import Union
import numpy as np
import time
import torch
import rclpy as rp
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelFactoryInitialize
from unitree_sdk2py.core.channel import ChannelSubscriber, ChannelFactoryInitialize
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_, unitree_hg_msg_dds__LowState_
from unitree_sdk2py.idl.default import unitree_go_msg_dds__LowCmd_, unitree_go_msg_dds__LowState_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as LowCmdHG
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_ as LowCmdGo
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowState_ as LowStateHG
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowState_ as LowStateGo
from unitree_sdk2py.utils.crc import CRC
from common.command_helper import create_damping_cmd, create_zero_cmd, init_cmd_hg, init_cmd_go, MotorMode
from common.rotation_helper import get_gravity_orientation, transform_imu_data
from common.remote_controller import RemoteController, KeyMap
from common.step_command import StepCommand
from common.utils import (to_array, normalize, yaw_quat,
axis_angle_from_quat,
subtract_frame_transforms,
wrap_to_pi,
compute_pose_error
)
from config import Config
from tf2_ros.buffer import Buffer
from tf2_ros.transform_listener import TransformListener
from tf2_ros import TransformBroadcaster, TransformStamped, StaticTransformBroadcaster
isaaclab_joint_order = [
'left_hip_pitch_joint',
'right_hip_pitch_joint',
'waist_yaw_joint',
'left_hip_roll_joint',
'right_hip_roll_joint',
'waist_roll_joint',
'left_hip_yaw_joint',
'right_hip_yaw_joint',
'waist_pitch_joint',
'left_knee_joint',
'right_knee_joint',
'left_shoulder_pitch_joint',
'right_shoulder_pitch_joint',
'left_ankle_pitch_joint',
'right_ankle_pitch_joint',
'left_shoulder_roll_joint',
'right_shoulder_roll_joint',
'left_ankle_roll_joint',
'right_ankle_roll_joint',
'left_shoulder_yaw_joint',
'right_shoulder_yaw_joint',
'left_elbow_joint',
'right_elbow_joint',
'left_wrist_roll_joint',
'right_wrist_roll_joint',
'left_wrist_pitch_joint',
'right_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_wrist_yaw_joint'
]
raw_joint_order = [
'left_hip_pitch_joint',
'left_hip_roll_joint',
'left_hip_yaw_joint',
'left_knee_joint',
'left_ankle_pitch_joint',
'left_ankle_roll_joint',
'right_hip_pitch_joint',
'right_hip_roll_joint',
'right_hip_yaw_joint',
'right_knee_joint',
'right_ankle_pitch_joint',
'right_ankle_roll_joint',
'waist_yaw_joint',
'waist_roll_joint',
'waist_pitch_joint',
'left_shoulder_pitch_joint',
'left_shoulder_roll_joint',
'left_shoulder_yaw_joint',
'left_elbow_joint',
'left_wrist_roll_joint',
'left_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_shoulder_pitch_joint',
'right_shoulder_roll_joint',
'right_shoulder_yaw_joint',
'right_elbow_joint',
'right_wrist_roll_joint',
'right_wrist_pitch_joint',
'right_wrist_yaw_joint'
]
# Create a mapping tensor
# mapping_tensor = torch.zeros((len(sim_b_joints), len(sim_a_joints)), device=env.device)
mapping_tensor = torch.zeros((len(raw_joint_order), len(isaaclab_joint_order)))
# Fill the mapping tensor
for b_idx, b_joint in enumerate(raw_joint_order):
if b_joint in isaaclab_joint_order:
a_idx = isaaclab_joint_order.index(b_joint)
mapping_tensor[a_idx, b_idx] = 1.0
class Controller:
def __init__(self, config: Config) -> None:
self.config = config
self.remote_controller = RemoteController()
# Initialize the policy network
self.policy = torch.jit.load(config.policy_path)
# Initializing process variables
self.qj = np.zeros(config.num_actions, dtype=np.float32)
self.dqj = np.zeros(config.num_actions, dtype=np.float32)
self.action = np.zeros(config.num_actions, dtype=np.float32)
self.target_dof_pos = config.default_angles.copy()
self.obs = np.zeros(config.num_obs, dtype=np.float32)
self.cmd = np.array([0.0, 0, 0])
self.counter = 0
rp.init()
self._node = rp.create_node("low_level_cmd_sender")
self.tf_buffer = Buffer()
self.tf_listener = TransformListener(self.tf_buffer, self._node)
self.tf_broadcaster = TransformBroadcaster(self._node)
self._step_command = None
self._saved = False
self._cur_time = None
if config.msg_type == "hg":
# g1 and h1_2 use the hg msg type
self.low_cmd = unitree_hg_msg_dds__LowCmd_()
self.low_state = unitree_hg_msg_dds__LowState_()
self.mode_pr_ = MotorMode.PR
self.mode_machine_ = 0
self.lowcmd_publisher_ = ChannelPublisher(config.lowcmd_topic, LowCmdHG)
self.lowcmd_publisher_.Init()
self.lowstate_subscriber = ChannelSubscriber(config.lowstate_topic, LowStateHG)
self.lowstate_subscriber.Init(self.LowStateHgHandler, 10)
elif config.msg_type == "go":
# h1 uses the go msg type
self.low_cmd = unitree_go_msg_dds__LowCmd_()
self.low_state = unitree_go_msg_dds__LowState_()
self.lowcmd_publisher_ = ChannelPublisher(config.lowcmd_topic, LowCmdGo)
self.lowcmd_publisher_.Init()
self.lowstate_subscriber = ChannelSubscriber(config.lowstate_topic, LowStateGo)
self.lowstate_subscriber.Init(self.LowStateGoHandler, 10)
else:
raise ValueError("Invalid msg_type")
# wait for the subscriber to receive data
self.wait_for_low_state()
# Initialize the command msg
if config.msg_type == "hg":
init_cmd_hg(self.low_cmd, self.mode_machine_, self.mode_pr_)
elif config.msg_type == "go":
init_cmd_go(self.low_cmd, weak_motor=self.config.weak_motor)
def LowStateHgHandler(self, msg: LowStateHG):
self.low_state = msg
self.mode_machine_ = self.low_state.mode_machine
self.remote_controller.set(self.low_state.wireless_remote)
def LowStateGoHandler(self, msg: LowStateGo):
self.low_state = msg
self.remote_controller.set(self.low_state.wireless_remote)
def send_cmd(self, cmd: Union[LowCmdGo, LowCmdHG]):
cmd.crc = CRC().Crc(cmd)
self.lowcmd_publisher_.Write(cmd)
def wait_for_low_state(self):
while self.low_state.tick == 0:
time.sleep(self.config.control_dt)
print("Successfully connected to the robot.")
def zero_torque_state(self):
print("Enter zero torque state.")
print("Waiting for the start signal...")
while self.remote_controller.button[KeyMap.start] != 1:
create_zero_cmd(self.low_cmd)
self.send_cmd(self.low_cmd)
time.sleep(self.config.control_dt)
def move_to_default_pos(self):
print("Moving to default pos.")
# move time 2s
total_time = 2
num_step = int(total_time / self.config.control_dt)
# dof_idx = self.config.leg_joint2motor_idx + self.config.arm_waist_joint2motor_idx
# kps = self.config.kps + self.config.arm_waist_kps
# kds = self.config.kds + self.config.arm_waist_kds
# default_pos = np.concatenate((self.config.default_angles, self.config.arm_waist_target), axis=0)
dof_idx = self.config.joint2motor_idx
kps = self.config.kps
kds = self.config.kds
default_pos = self.config.default_angles
dof_size = len(dof_idx)
# record the current pos
init_dof_pos = np.zeros(dof_size, dtype=np.float32)
for i in range(dof_size):
init_dof_pos[i] = self.low_state.motor_state[dof_idx[i]].q
# move to default pos
for i in range(num_step):
alpha = i / num_step
for j in range(dof_size):
motor_idx = dof_idx[j]
target_pos = default_pos[j]
self.low_cmd.motor_cmd[motor_idx].q = init_dof_pos[j] * (1 - alpha) + target_pos * alpha
self.low_cmd.motor_cmd[motor_idx].qd = 0
self.low_cmd.motor_cmd[motor_idx].kp = kps[j]
self.low_cmd.motor_cmd[motor_idx].kd = kds[j]
self.low_cmd.motor_cmd[motor_idx].tau = 0
self.send_cmd(self.low_cmd)
time.sleep(self.config.control_dt)
def default_pos_state(self):
print("Enter default pos state.")
print("Waiting for the Button A signal...")
while self.remote_controller.button[KeyMap.A] != 1:
# for i in range(len(self.config.leg_joint2motor_idx)):
# motor_idx = self.config.leg_joint2motor_idx[i]
# self.low_cmd.motor_cmd[motor_idx].q = self.config.default_angles[i]
# self.low_cmd.motor_cmd[motor_idx].qd = 0
# self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
# self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
# self.low_cmd.motor_cmd[motor_idx].tau = 0
# for i in range(len(self.config.arm_waist_joint2motor_idx)):
# motor_idx = self.config.arm_waist_joint2motor_idx[i]
# self.low_cmd.motor_cmd[motor_idx].q = self.config.arm_waist_target[i]
# self.low_cmd.motor_cmd[motor_idx].qd = 0
# self.low_cmd.motor_cmd[motor_idx].kp = self.config.arm_waist_kps[i]
# self.low_cmd.motor_cmd[motor_idx].kd = self.config.arm_waist_kds[i]
# self.low_cmd.motor_cmd[motor_idx].tau = 0
for i in range(len(self.config.joint2motor_idx)):
motor_idx = self.config.joint2motor_idx[i]
self.low_cmd.motor_cmd[motor_idx].q = self.config.default_angles[i]
self.low_cmd.motor_cmd[motor_idx].qd = 0
self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
self.low_cmd.motor_cmd[motor_idx].tau = 0
self.send_cmd(self.low_cmd)
time.sleep(self.config.control_dt)
def tf_to_pose(self, tf, order='xyzw'):
pos = to_array(tf.transform.translation)
quat = to_array(tf.transform.rotation)
if order == 'wxyz':
quat = np.roll(quat, 1, axis=-1)
return np.concatenate((pos, quat), axis=0)
def publish_step_command(self, next_ctarget_left, next_ctarget_right):
left_tf = TransformStamped()
left_tf.header.stamp = self._node.get_clock().now().to_msg()
left_tf.header.frame_id = 'world'
left_tf.child_frame_id = 'left_ctarget'
left_tf.transform.translation.x = float(next_ctarget_left[0])
left_tf.transform.translation.y = float(next_ctarget_left[1])
left_tf.transform.translation.z = float(next_ctarget_left[2])
left_tf.transform.rotation.x = float(next_ctarget_left[4])
left_tf.transform.rotation.y = float(next_ctarget_left[5])
left_tf.transform.rotation.z = float(next_ctarget_left[6])
left_tf.transform.rotation.w = float(next_ctarget_left[3])
right_tf = TransformStamped()
right_tf.header.stamp = left_tf.header.stamp
right_tf.header.frame_id = 'world'
right_tf.child_frame_id = 'right_ctarget'
right_tf.transform.translation.x = float(next_ctarget_right[0])
right_tf.transform.translation.y = float(next_ctarget_right[1])
right_tf.transform.translation.z = float(next_ctarget_right[2])
right_tf.transform.rotation.x = float(next_ctarget_right[4])
right_tf.transform.rotation.y = float(next_ctarget_right[5])
right_tf.transform.rotation.z = float(next_ctarget_right[6])
right_tf.transform.rotation.w = float(next_ctarget_right[3])
self.tf_broadcaster.sendTransform(left_tf)
self.tf_broadcaster.sendTransform(right_tf)
def get_command(self, pelvis_w,
foot_left_b,
foot_right_b,
ctarget_left_w,
ctarget_right_w):
ctarget_left_b_pos, ctarget_left_b_quat = subtract_frame_transforms(pelvis_w[:3],
pelvis_w[3:7],
ctarget_left_w[:3],
ctarget_left_w[3:7])
ctarget_right_b_pos, ctarget_right_b_quat = subtract_frame_transforms(pelvis_w[:3],
pelvis_w[3:7],
ctarget_right_w[:3],
ctarget_right_w[3:7])
pos_delta_left, axa_delta_left = compute_pose_error(foot_left_b[:3],
foot_left_b[3:7],
ctarget_left_b_pos,
ctarget_left_b_quat)
pos_delta_right, axa_delta_right = compute_pose_error(foot_right_b[:3],
foot_right_b[3:7],
ctarget_right_b_pos,
ctarget_right_b_quat)
return np.concatenate((pos_delta_left, axa_delta_left, pos_delta_right, axa_delta_right), axis=0)
def run_wrapper(self):
t = time.time()
if self._cur_time is None:
self._cur_time = 0.0
if not self._saved:
while True:
rp.spin_once(self._node)
try:
current_left_tf = self.tf_buffer.lookup_transform(
"world",
"left_ankle_roll_link",
rp.time.Time(),
rp.duration.Duration(seconds=0.02))
break
except Exception as ex:
print(ex)
time.sleep(0.05)
if t- self._cur_time > self.config.control_dt:
self.run()
self._cur_time = t
time.sleep(self.config.control_dt/10)
def run(self):
if self._step_command is None:
current_left_tf = self.tf_buffer.lookup_transform(
"world",
"left_ankle_roll_link",
rp.time.Time(),
rp.duration.Duration(seconds=0.02))
current_left_pose = self.tf_to_pose(current_left_tf, 'wxyz')
current_left_pose[2] = 0.0
current_left_pose[3:7] = yaw_quat(current_left_pose[3:7])
current_right_tf = self.tf_buffer.lookup_transform(
"world",
"right_ankle_roll_link",
rp.time.Time(),
rp.duration.Duration(seconds=0.02))
current_right_pose = self.tf_to_pose(current_right_tf, 'wxyz')
current_right_pose[2] = 0.0
current_right_pose[3:7] = yaw_quat(current_right_pose[3:7])
self._step_command = StepCommand(current_left_pose, current_right_pose)
self.counter += 1
next_ctarget = self._step_command.get_next_ctarget(
self.remote_controller,
self.counter * self.config.control_dt)
next_ctarget_left, next_ctarget_right, dt_left, dt_right = next_ctarget
self.publish_step_command(next_ctarget_left, next_ctarget_right)
# Get the current joint position and velocity
# for i in range(len(self.config.leg_joint2motor_idx)):
# self.qj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].q
# self.dqj[i] = self.low_state.motor_state[self.config.leg_joint2motor_idx[i]].dq
for i, motor_idx in enumerate(self.config.joint2motor_idx):
self.qj[i] = self.low_state.motor_state[motor_idx].q
self.dqj[i] = self.low_state.motor_state[motor_idx].dq
# imu_state quaternion: w, x, y, z
quat = self.low_state.imu_state.quaternion
ang_vel = np.array([self.low_state.imu_state.gyroscope], dtype=np.float32)
if self.config.imu_type == "torso":
# h1 and h1_2 imu is on the torso
# imu data needs to be transformed to the pelvis frame
# waist_yaw = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].q
# waist_yaw_omega = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq
waist_yaw = self.low_state.motor_state[self.config.joint2motor_idx[12]].q
waist_yaw_omega = self.low_state.motor_state[self.config.joint2motor_idx[12]].dq
quat, ang_vel = transform_imu_data(waist_yaw=waist_yaw, waist_yaw_omega=waist_yaw_omega, imu_quat=quat, imu_omega=ang_vel)
# create observation
gravity_orientation = get_gravity_orientation(quat)
qj_obs = self.qj.copy()
dqj_obs = self.dqj.copy()
qj_obs = (qj_obs - self.config.default_angles) * self.config.dof_pos_scale
dqj_obs = dqj_obs * self.config.dof_vel_scale
ang_vel = ang_vel * self.config.ang_vel_scale
# foot pose
left_foot_from_base_tf = self.tf_buffer.lookup_transform(
"pelvis",
"left_ankle_roll_link",
rp.time.Time())
right_foot_from_base_tf = self.tf_buffer.lookup_transform(
"pelvis",
"right_ankle_roll_link",
rp.time.Time())
print(left_foot_from_base_tf, right_foot_from_base_tf)
lf_b = self.tf_to_pose(left_foot_from_base_tf, 'wxyz')
rf_b = self.tf_to_pose(right_foot_from_base_tf, 'wxyz')
left_foot_axa = wrap_to_pi(axis_angle_from_quat(lf_b[3:7]))
right_foot_axa = wrap_to_pi(axis_angle_from_quat(rf_b[3:7]))
rel_foot = np.concatenate((lf_b[:3],
rf_b[:3],
left_foot_axa,
right_foot_axa), axis=0)
# hand pose
left_hand_from_base_tf = self.tf_buffer.lookup_transform(
"pelvis",
"left_rubber_hand",
rp.time.Time())
right_hand_from_base_tf = self.tf_buffer.lookup_transform(
"pelvis",
"right_rubber_hand",
rp.time.Time())
left_hand_from_base = self.tf_to_pose(left_hand_from_base_tf, 'wxyz')
right_hand_from_base = self.tf_to_pose(right_hand_from_base_tf, 'wxyz')
left_hand_axa = wrap_to_pi(axis_angle_from_quat(left_hand_from_base[3:7]))
right_hand_axa = wrap_to_pi(axis_angle_from_quat(right_hand_from_base[3:7]))
rel_hand = np.concatenate((left_hand_from_base[:3],
right_hand_from_base[:3],
left_hand_axa,
right_hand_axa), axis=0)
# foot command
base_pose_w = self.tf_to_pose(self.tf_buffer.lookup_transform(
"world", "pelvis",
rp.time.Time()), 'wxyz')
step_command = self.get_command(base_pose_w,
lf_b,
rf_b,
next_ctarget_left,
next_ctarget_right)
step_command = np.concatenate((step_command,
np.asarray([dt_left, dt_right])), axis=0)
num_actions = self.config.num_actions
self.obs[:3] = ang_vel
self.obs[3:6] = gravity_orientation
# self.obs[6:9] = self.cmd * self.config.cmd_scale * self.config.max_cmd
self.obs[6:18] = rel_foot
self.obs[18:30] = rel_hand
self.obs[30 : 30 + num_actions] = qj_obs
self.obs[30 + num_actions : 30 + num_actions * 2] = dqj_obs
self.obs[30 + num_actions * 2 : 30 + num_actions * 3] = self.action
self.obs[30 + num_actions * 3 : 30 + num_actions * 3 + 14] = step_command
# self.obs[9 + num_actions * 3] = sin_phase
# self.obs[9 + num_actions * 3 + 1] = cos_phase
# Get the action from the policy network
obs_tensor = torch.from_numpy(self.obs).unsqueeze(0)
# Reorder the observations
obs_tensor[..., 30:30+num_actions] = obs_tensor[..., 30:30+num_actions] @ mapping_tensor.transpose(0, 1)
obs_tensor[..., 30 + num_actions : 30 + num_actions * 2] = obs_tensor[..., 30 + num_actions : 30 + num_actions * 2] @ mapping_tensor.transpose(0, 1)
obs_tensor[..., 30 + num_actions * 2 : 30 + num_actions * 3] = obs_tensor[..., 30 + num_actions * 2 : 30 + num_actions * 3] @ mapping_tensor.transpose(0, 1)
if not self._saved:
torch.save(obs_tensor, "obs.pt")
self._saved = True
self.action = self.policy(obs_tensor).detach().numpy().squeeze()
# Reorder the actions
self.action = self.action @ mapping_tensor.detach().cpu().numpy()
# transform action to target_dof_pos
target_dof_pos = self.config.default_angles + self.action * self.config.action_scale
# Build low cmd
# for i in range(len(self.config.leg_joint2motor_idx)):
# motor_idx = self.config.leg_joint2motor_idx[i]
# self.low_cmd.motor_cmd[motor_idx].q = target_dof_pos[i]
# self.low_cmd.motor_cmd[motor_idx].qd = 0
# self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
# self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
# self.low_cmd.motor_cmd[motor_idx].tau = 0
# for i in range(len(self.config.arm_waist_joint2motor_idx)):
# motor_idx = self.config.arm_waist_joint2motor_idx[i]
# self.low_cmd.motor_cmd[motor_idx].q = self.config.arm_waist_target[i]
# self.low_cmd.motor_cmd[motor_idx].qd = 0
# self.low_cmd.motor_cmd[motor_idx].kp = self.config.arm_waist_kps[i]
# self.low_cmd.motor_cmd[motor_idx].kd = self.config.arm_waist_kds[i]
# self.low_cmd.motor_cmd[motor_idx].tau = 0
if False:
for i, motor_idx in enumerate(self.config.joint2motor_idx):
self.low_cmd.motor_cmd[motor_idx].q = target_dof_pos[i]
self.low_cmd.motor_cmd[motor_idx].qd = 0
self.low_cmd.motor_cmd[motor_idx].kp = self.config.kps[i]
self.low_cmd.motor_cmd[motor_idx].kd = self.config.kds[i]
self.low_cmd.motor_cmd[motor_idx].tau = 0
# send the command
self.send_cmd(self.low_cmd)
def clear(self):
self._node.destroy_node()
rp.shutdown()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("net", type=str, help="network interface")
parser.add_argument("config", type=str, help="config file name in the configs folder", default="g1.yaml")
args = parser.parse_args()
# Load config
config_path = f"{LEGGED_GYM_ROOT_DIR}/deploy/deploy_real/configs/{args.config}"
config = Config(config_path)
# Initialize DDS communication
ChannelFactoryInitialize(0, args.net)
controller = Controller(config)
# Enter the zero torque state, press the start key to continue executing
controller.zero_torque_state()
# Move to the default position
controller.move_to_default_pos()
# Enter the default position state, press the A key to continue executing
controller.default_pos_state()
while True:
try:
controller.run_wrapper()
# Press the select key to exit
if controller.remote_controller.button[KeyMap.select] == 1:
break
except KeyboardInterrupt:
break
# Enter the damping state
create_damping_cmd(controller.low_cmd)
controller.send_cmd(controller.low_cmd)
controller.clear()
print("Exit")

View File

@ -0,0 +1,552 @@
from legged_gym import LEGGED_GYM_ROOT_DIR
from typing import Union
import numpy as np
import time
import torch
import rclpy as rp
from unitree_hg.msg import LowCmd as LowCmdHG, LowState as LowStateHG
from unitree_go.msg import LowCmd as LowCmdGo, LowState as LowStateGo
from common.command_helper_ros import create_damping_cmd, create_zero_cmd, init_cmd_hg, init_cmd_go, MotorMode
from common.rotation_helper import get_gravity_orientation, transform_imu_data
from common.remote_controller import RemoteController, KeyMap
from config import Config
from common.crc import CRC
from enum import Enum
from common.step_command import StepCommand
from common.utils import (to_array, normalize, yaw_quat,
axis_angle_from_quat,
subtract_frame_transforms,
wrap_to_pi,
compute_pose_error
)
from config import Config
from tf2_ros.buffer import Buffer
from tf2_ros.transform_listener import TransformListener
from tf2_ros import TransformBroadcaster, TransformStamped, StaticTransformBroadcaster
class Mode(Enum):
wait = 0
zero_torque = 1
default_pos = 2
damping = 3
policy = 4
null = 5
isaaclab_joint_order = [
'left_hip_pitch_joint',
'right_hip_pitch_joint',
'waist_yaw_joint',
'left_hip_roll_joint',
'right_hip_roll_joint',
'waist_roll_joint',
'left_hip_yaw_joint',
'right_hip_yaw_joint',
'waist_pitch_joint',
'left_knee_joint',
'right_knee_joint',
'left_shoulder_pitch_joint',
'right_shoulder_pitch_joint',
'left_ankle_pitch_joint',
'right_ankle_pitch_joint',
'left_shoulder_roll_joint',
'right_shoulder_roll_joint',
'left_ankle_roll_joint',
'right_ankle_roll_joint',
'left_shoulder_yaw_joint',
'right_shoulder_yaw_joint',
'left_elbow_joint',
'right_elbow_joint',
'left_wrist_roll_joint',
'right_wrist_roll_joint',
'left_wrist_pitch_joint',
'right_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_wrist_yaw_joint'
]
raw_joint_order = [
'left_hip_pitch_joint',
'left_hip_roll_joint',
'left_hip_yaw_joint',
'left_knee_joint',
'left_ankle_pitch_joint',
'left_ankle_roll_joint',
'right_hip_pitch_joint',
'right_hip_roll_joint',
'right_hip_yaw_joint',
'right_knee_joint',
'right_ankle_pitch_joint',
'right_ankle_roll_joint',
'waist_yaw_joint',
'waist_roll_joint',
'waist_pitch_joint',
'left_shoulder_pitch_joint',
'left_shoulder_roll_joint',
'left_shoulder_yaw_joint',
'left_elbow_joint',
'left_wrist_roll_joint',
'left_wrist_pitch_joint',
'left_wrist_yaw_joint',
'right_shoulder_pitch_joint',
'right_shoulder_roll_joint',
'right_shoulder_yaw_joint',
'right_elbow_joint',
'right_wrist_roll_joint',
'right_wrist_pitch_joint',
'right_wrist_yaw_joint'
]
# Create a mapping tensor
mapping_tensor = torch.zeros((len(raw_joint_order), len(isaaclab_joint_order)))
# Fill the mapping tensor
for b_idx, b_joint in enumerate(raw_joint_order):
if b_joint in isaaclab_joint_order:
a_idx = isaaclab_joint_order.index(b_joint)
# if 'shoulder' in b_joint or 'elbow' in b_joint or 'wrist' in b_joint:
# mapping_tensor[a_idx, b_idx] = 0.1
# else:
mapping_tensor[a_idx, b_idx] = 1.0
mask = torch.ones(len(isaaclab_joint_order))
for b_idx, b_joint in enumerate(isaaclab_joint_order):
if 'shoulder' in b_joint or 'elbow' in b_joint or 'wrist' in b_joint:
mask[b_idx] = 0
class Controller:
def __init__(self, config: Config) -> None:
self.config = config
self.remote_controller = RemoteController()
# Initialize the policy network
self.policy = torch.jit.load(config.policy_path)
# Initializing process variables
self.qj = np.zeros(config.num_actions, dtype=np.float32)
self.dqj = np.zeros(config.num_actions, dtype=np.float32)
self.action = np.zeros(config.num_actions, dtype=np.float32)
self.target_dof_pos = config.default_angles.copy()
self.obs = np.zeros(config.num_obs, dtype=np.float32)
self.cmd = np.array([0.0, 0, 0])
self.counter = 0
rp.init()
self._node = rp.create_node("low_level_cmd_sender")
self.tf_buffer = Buffer()
self.tf_listener = TransformListener(self.tf_buffer, self._node)
self.tf_broadcaster = TransformBroadcaster(self._node)
self._step_command = None
self._saved = False
self._cur_time = None
if config.msg_type == "hg":
# g1 and h1_2 use the hg msg type
self.low_cmd = LowCmdHG()
self.low_state = LowStateHG()
self.lowcmd_publisher_ = self._node.create_publisher(LowCmdHG,
'lowcmd', 10)
self.lowstate_subscriber = self._node.create_subscription(LowStateHG,
'lowstate', self.LowStateHgHandler, 10)
self.mode_pr_ = MotorMode.PR
self.mode_machine_ = 0
# self.lowcmd_publisher_ = ChannelPublisher(config.lowcmd_topic, LowCmdHG)
# self.lowcmd_publisher_.Init()
# self.lowstate_subscriber = ChannelSubscriber(config.lowstate_topic, LowStateHG)
# self.lowstate_subscriber.Init(self.LowStateHgHandler, 10)
elif config.msg_type == "go":
raise ValueError(f"{config.msg_type} is not implemented yet.")
else:
raise ValueError("Invalid msg_type")
# wait for the subscriber to receive data
# self.wait_for_low_state()
# Initialize the command msg
if config.msg_type == "hg":
init_cmd_hg(self.low_cmd, self.mode_machine_, self.mode_pr_)
elif config.msg_type == "go":
init_cmd_go(self.low_cmd, weak_motor=self.config.weak_motor)
self.mode = Mode.wait
self._mode_change = True
self._timer = self._node.create_timer(self.config.control_dt, self.run_wrapper)
self._terminate = False
self._obs_buf = []
try:
rp.spin(self._node)
except KeyboardInterrupt:
print("KeyboardInterrupt")
finally:
self._node.destroy_timer(self._timer)
create_damping_cmd(self.low_cmd)
self.send_cmd(self.low_cmd)
self._node.destroy_node()
rp.shutdown()
torch.save(torch.cat(self._obs_buf, dim=0), "obs6.pt")
print("Exit")
def LowStateHgHandler(self, msg: LowStateHG):
self.low_state = msg
self.mode_machine_ = self.low_state.mode_machine
self.remote_controller.set(self.low_state.wireless_remote)
def LowStateGoHandler(self, msg: LowStateGo):
self.low_state = msg
self.remote_controller.set(self.low_state.wireless_remote)
def send_cmd(self, cmd: Union[LowCmdGo, LowCmdHG]):
cmd.mode_machine = self.mode_machine_
cmd.crc = CRC().Crc(cmd)
size = len(cmd.motor_cmd)
# print(cmd.mode_machine)
# for i in range(size):
# print(i, cmd.motor_cmd[i].q,
# cmd.motor_cmd[i].dq,
# cmd.motor_cmd[i].kp,
# cmd.motor_cmd[i].kd,
# cmd.motor_cmd[i].tau)
self.lowcmd_publisher_.publish(cmd)
def wait_for_low_state(self):
while self.low_state.crc == 0:
print(self.low_state)
time.sleep(self.config.control_dt)
print("Successfully connected to the robot.")
def zero_torque_state(self):
if self.remote_controller.button[KeyMap.start] == 1:
self._mode_change = True
self.mode = Mode.default_pos
else:
create_zero_cmd(self.low_cmd)
self.send_cmd(self.low_cmd)
def prepare_default_pos(self):
# move time 2s
total_time = 2
self.counter = 0
self._num_step = int(total_time / self.config.control_dt)
dof_idx = self.config.joint2motor_idx
kps = self.config.kps
kds = self.config.kds
self._kps = [float(kp) for kp in kps]
self._kds = [float(kd) for kd in kds]
self._default_pos = np.asarray(self.config.default_angles)
# np.concatenate((self.config.default_angles), axis=0)
self._dof_size = len(dof_idx)
self._dof_idx = dof_idx
# record the current pos
self._init_dof_pos = np.zeros(self._dof_size,
dtype=np.float32)
for i in range(self._dof_size):
self._init_dof_pos[i] = self.low_state.motor_state[dof_idx[i]].q
def move_to_default_pos(self):
# move to default pos
if self.counter < self._num_step:
alpha = self.counter / self._num_step
for j in range(self._dof_size):
motor_idx = self._dof_idx[j]
target_pos = self._default_pos[j]
self.low_cmd.motor_cmd[motor_idx].q = (self._init_dof_pos[j] *
(1 - alpha) + target_pos * alpha)
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
self.low_cmd.motor_cmd[motor_idx].kp = self._kps[j]
self.low_cmd.motor_cmd[motor_idx].kd = self._kds[j]
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
self.send_cmd(self.low_cmd)
self.counter += 1
else:
self._mode_change = True
self.mode = Mode.damping
def default_pos_state(self):
if self.remote_controller.button[KeyMap.A] != 1:
for i in range(len(self.config.joint2motor_idx)):
motor_idx = self.config.joint2motor_idx[i]
self.low_cmd.motor_cmd[motor_idx].q = float(self.config.default_angles[i])
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
self.low_cmd.motor_cmd[motor_idx].kp = self._kps[i]
self.low_cmd.motor_cmd[motor_idx].kd = self._kds[i]
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
self.send_cmd(self.low_cmd)
else:
self._mode_change = True
self.mode = Mode.policy
def tf_to_pose(self, tf, order='xyzw'):
pos = to_array(tf.transform.translation)
quat = to_array(tf.transform.rotation)
if order == 'wxyz':
quat = np.roll(quat, 1, axis=-1)
return np.concatenate((pos, quat), axis=0)
def publish_step_command(self, next_ctarget_left, next_ctarget_right):
left_tf = TransformStamped()
left_tf.header.stamp = self._node.get_clock().now().to_msg()
left_tf.header.frame_id = 'world'
left_tf.child_frame_id = 'left_ctarget'
left_tf.transform.translation.x = float(next_ctarget_left[0])
left_tf.transform.translation.y = float(next_ctarget_left[1])
left_tf.transform.translation.z = float(next_ctarget_left[2])
left_tf.transform.rotation.x = float(next_ctarget_left[4])
left_tf.transform.rotation.y = float(next_ctarget_left[5])
left_tf.transform.rotation.z = float(next_ctarget_left[6])
left_tf.transform.rotation.w = float(next_ctarget_left[3])
right_tf = TransformStamped()
right_tf.header.stamp = left_tf.header.stamp
right_tf.header.frame_id = 'world'
right_tf.child_frame_id = 'right_ctarget'
right_tf.transform.translation.x = float(next_ctarget_right[0])
right_tf.transform.translation.y = float(next_ctarget_right[1])
right_tf.transform.translation.z = float(next_ctarget_right[2])
right_tf.transform.rotation.x = float(next_ctarget_right[4])
right_tf.transform.rotation.y = float(next_ctarget_right[5])
right_tf.transform.rotation.z = float(next_ctarget_right[6])
right_tf.transform.rotation.w = float(next_ctarget_right[3])
self.tf_broadcaster.sendTransform(left_tf)
self.tf_broadcaster.sendTransform(right_tf)
def get_command(self, pelvis_w,
foot_left_b,
foot_right_b,
ctarget_left_w,
ctarget_right_w):
ctarget_left_b_pos, ctarget_left_b_quat = subtract_frame_transforms(pelvis_w[:3],
pelvis_w[3:7],
ctarget_left_w[:3],
ctarget_left_w[3:7])
ctarget_right_b_pos, ctarget_right_b_quat = subtract_frame_transforms(pelvis_w[:3],
pelvis_w[3:7],
ctarget_right_w[:3],
ctarget_right_w[3:7])
pos_delta_left, axa_delta_left = compute_pose_error(foot_left_b[:3],
foot_left_b[3:7],
ctarget_left_b_pos,
ctarget_left_b_quat)
pos_delta_right, axa_delta_right = compute_pose_error(foot_right_b[:3],
foot_right_b[3:7],
ctarget_right_b_pos,
ctarget_right_b_quat)
return np.concatenate((pos_delta_left, axa_delta_left, pos_delta_right, axa_delta_right), axis=0)
def run_policy(self):
if self._step_command is None:
current_left_tf = self.tf_buffer.lookup_transform(
"world",
"left_ankle_roll_link",
rp.time.Time())
# rp.duration.Duration(seconds=0.02))
current_left_pose = self.tf_to_pose(current_left_tf, 'wxyz')
current_left_pose[2] = 0.0
current_left_pose[3:7] = yaw_quat(current_left_pose[3:7])
current_right_tf = self.tf_buffer.lookup_transform(
"world",
"right_ankle_roll_link",
rp.time.Time())
# rp.duration.Duration(seconds=0.02))
current_right_pose = self.tf_to_pose(current_right_tf, 'wxyz')
current_right_pose[2] = 0.0
current_right_pose[3:7] = yaw_quat(current_right_pose[3:7])
self._step_command = StepCommand(current_left_pose, current_right_pose)
if self.remote_controller.button[KeyMap.select] == 1:
self._mode_change = True
self.mode = Mode.null
return
self.counter += 1
# Get the current joint position and velocity
next_ctarget = self._step_command.get_next_ctarget(
self.remote_controller,
self.counter * self.config.control_dt)
print(next_ctarget)
next_ctarget_left, next_ctarget_right, dt_left, dt_right = next_ctarget
self.publish_step_command(next_ctarget_left, next_ctarget_right)
for i in range(len(self.config.joint2motor_idx)):
self.qj[i] = self.low_state.motor_state[self.config.joint2motor_idx[i]].q
self.dqj[i] = self.low_state.motor_state[self.config.joint2motor_idx[i]].dq
# imu_state quaternion: w, x, y, z
quat = self.low_state.imu_state.quaternion
ang_vel = np.array([self.low_state.imu_state.gyroscope], dtype=np.float32)
if self.config.imu_type == "torso":
# h1 and h1_2 imu is on the torso
# imu data needs to be transformed to the pelvis frame
waist_yaw = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].q
waist_yaw_omega = self.low_state.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq
quat, ang_vel = transform_imu_data(waist_yaw=waist_yaw, waist_yaw_omega=waist_yaw_omega, imu_quat=quat, imu_omega=ang_vel)
# create observation
gravity_orientation = get_gravity_orientation(quat)
qj_obs = self.qj.copy()
dqj_obs = self.dqj.copy()
qj_obs = (qj_obs - self.config.default_angles) * self.config.dof_pos_scale
dqj_obs = dqj_obs * self.config.dof_vel_scale
ang_vel = ang_vel * self.config.ang_vel_scale
# foot pose
left_foot_from_base_tf = self.tf_buffer.lookup_transform(
"pelvis",
"left_ankle_roll_link",
rp.time.Time())
right_foot_from_base_tf = self.tf_buffer.lookup_transform(
"pelvis",
"right_ankle_roll_link",
rp.time.Time())
lf_b = self.tf_to_pose(left_foot_from_base_tf, 'wxyz')
rf_b = self.tf_to_pose(right_foot_from_base_tf, 'wxyz')
left_foot_axa = wrap_to_pi(axis_angle_from_quat(lf_b[3:7]))
right_foot_axa = wrap_to_pi(axis_angle_from_quat(rf_b[3:7]))
rel_foot = np.concatenate((lf_b[:3],
rf_b[:3],
left_foot_axa,
right_foot_axa), axis=0)
# hand pose
left_hand_from_base_tf = self.tf_buffer.lookup_transform(
"pelvis",
"left_rubber_hand",
rp.time.Time())
right_hand_from_base_tf = self.tf_buffer.lookup_transform(
"pelvis",
"right_rubber_hand",
rp.time.Time())
lh_b = self.tf_to_pose(left_hand_from_base_tf, 'wxyz')
rh_b = self.tf_to_pose(right_hand_from_base_tf, 'wxyz')
left_hand_axa = wrap_to_pi(axis_angle_from_quat(lh_b[3:7]))
right_hand_axa = wrap_to_pi(axis_angle_from_quat(rh_b[3:7]))
rel_hand = np.concatenate((lh_b[:3],
rh_b[:3],
left_hand_axa,
right_hand_axa), axis=0)
# foot command
base_pose_w = self.tf_to_pose(self.tf_buffer.lookup_transform(
"world", "pelvis",
rp.time.Time()), 'wxyz')
dt_left = dt_right = 0.0
step_command = self.get_command(base_pose_w,
lf_b,
rf_b,
next_ctarget_left,
next_ctarget_right)
step_command = np.concatenate((step_command,
np.asarray([dt_left, dt_right])), axis=0)
num_actions = self.config.num_actions
self.obs[:3] = ang_vel
self.obs[3:6] = gravity_orientation
self.obs[6:18] = rel_foot
self.obs[18:30] = rel_hand
self.obs[30 : 30 + num_actions] = qj_obs
self.obs[30 + num_actions : 30 + num_actions * 2] = dqj_obs
self.obs[30 + num_actions * 2 : 30 + num_actions * 3] = self.action
self.obs[30 + num_actions * 3 : 30 + num_actions * 3 + 14] = step_command
# Get the action from the policy network
obs_tensor = torch.from_numpy(self.obs).unsqueeze(0)
obs_tensor[..., 30:30+num_actions] = obs_tensor[..., 30:30+num_actions] @ mapping_tensor.transpose(0, 1)
obs_tensor[..., 30 + num_actions : 30 + num_actions * 2] = obs_tensor[..., 30 + num_actions : 30 + num_actions * 2] @ mapping_tensor.transpose(0, 1)
obs_tensor[..., 30 + num_actions * 2 : 30 + num_actions * 3] = obs_tensor[..., 30 + num_actions * 2 : 30 + num_actions * 3] @ mapping_tensor.transpose(0, 1)
# if not self._saved:
# torch.save(obs_tensor, "obs.pt")
# self._saved = True
self._obs_buf.append(obs_tensor.clone())
self.action = self.policy(obs_tensor).detach().numpy().squeeze()
# self.action = self.action * mask.numpy()
# Reorder the actions
self.action = self.action @ mapping_tensor.detach().cpu().numpy()
# transform action to target_dof_pos
target_dof_pos = self.config.default_angles + self.action * self.config.action_scale *0.8
# Build low cmd
if True:
for i, motor_idx in enumerate(self.config.joint2motor_idx):
self.low_cmd.motor_cmd[motor_idx].q = float(target_dof_pos[i])
self.low_cmd.motor_cmd[motor_idx].dq = 0.0
self.low_cmd.motor_cmd[motor_idx].kp = float(self.config.kps[i])
self.low_cmd.motor_cmd[motor_idx].kd = float(self.config.kds[i])
self.low_cmd.motor_cmd[motor_idx].tau = 0.0
# send the command
self.send_cmd(self.low_cmd)
def run_wrapper(self):
# print("hello", self.mode,
# self.mode == Mode.zero_torque)
if self.mode == Mode.wait:
if self.low_state.crc != 0:
self.mode = Mode.zero_torque
self.low_cmd.mode_machine = self.mode_machine_
print("Successfully connected to the robot.")
elif self.mode == Mode.zero_torque:
if self._mode_change:
print("Enter zero torque state.")
print("Waiting for the start signal...")
self._mode_change = False
self.zero_torque_state()
elif self.mode == Mode.default_pos:
if self._mode_change:
print("Moving to default pos.")
self._mode_change = False
self.prepare_default_pos()
self.move_to_default_pos()
elif self.mode == Mode.damping:
if self._mode_change:
print("Enter default pos state.")
print("Waiting for the Button A signal...")
try:
current_left_tf = self.tf_buffer.lookup_transform(
"world",
"left_ankle_roll_link",
rp.time.Time())
self._mode_change = False
except Exception as ex:
print(ex)
self.default_pos_state()
elif self.mode == Mode.policy:
if self._mode_change:
print("Run policy.")
self._mode_change = False
self.counter = 0
self.run_policy()
elif self.mode == Mode.null:
self._terminate = True
# time.sleep(self.config.control_dt)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("config", type=str, help="config file name in the configs folder", default="g1.yaml")
args = parser.parse_args()
# Load config
config_path = f"{LEGGED_GYM_ROOT_DIR}/deploy/deploy_real/configs/{args.config}"
config = Config(config_path)
controller = Controller(config)

View File

@ -8,11 +8,49 @@ from rclpy.qos import QoSProfile
from unitree_hg.msg import LowState as LowStateHG
from tf2_ros.buffer import Buffer
from tf2_ros.transform_listener import TransformListener
from tf2_ros import TransformBroadcaster, TransformStamped
from tf2_ros import TransformBroadcaster, TransformStamped, StaticTransformBroadcaster
import numpy as np
import yaml
from geometry_msgs.msg import Vector3, Quaternion
from scipy.spatial.transform import Rotation as R
def index_map(k_to, k_from):
"""
Returns an index mapping from k_from to k_to.
Given k_to=a, k_from=b,
returns an index map "a_from_b" such that
array_a[a_from_b] = array_b
Missing values are set to -1.
"""
index_dict = {k: i for i, k in enumerate(k_to)} # O(len(k_from))
return [index_dict.get(k, -1) for k in k_from] # O(len(k_to))
def quat_rotate(q: np.ndarray, v: np.ndarray) -> np.ndarray:
"""Rotate a vector by a quaternion along the last dimension of q and v.
Args:
q: The quaternion in (w, x, y, z). Shape is (..., 4).
v: The vector in (x, y, z). Shape is (..., 3).
Returns:
The rotated vector in (x, y, z). Shape is (..., 3).
"""
q_w = q[..., 0]
q_vec = q[..., 1:]
a = v * (2.0 * q_w**2 - 1.0)[..., None]
b = np.cross(q_vec, v, axis=-1) * q_w[..., None] * 2.0
c = q_vec * np.einsum("...i,...i->...", q_vec, v)[..., None] * 2.0
return a + b + c
def to_array(v):
if isinstance(v, Vector3):
return np.array([v.x, v.y, v.z], dtype=np.float32)
elif isinstance(v, Quaternion):
return np.array([v.x, v.y, v.z, v.w], dtype=np.float32)
class PelvistoTrack(Node):
def __init__(self):
@ -22,9 +60,31 @@ class PelvistoTrack(Node):
self.tf_listener = TransformListener(self.tf_buffer, self)
self.tf_broadcaster = TransformBroadcaster(self)
self.timer = self.create_timer(0.05, self.on_timer)
self.low_state = LowStateHG()
self.low_state_subscriber = self.create_subscription(
LowStateHG,
'lowstate',
self.on_low_state,
10)
self.static_tf_broadcaster = StaticTransformBroadcaster(self)
# Timer for dynamic transform broadcasting (e.g., pelvis tracking)
self.timer = self.create_timer(0.01, self.on_timer)
# One-shot timer to check & publish the static transform after a short delay
self.static_tf_timer = self.create_timer(1.0, self.publish_static_tf)
def on_low_state(self,
msg: LowStateHG):
self.low_state = msg
def on_timer(self):
try:
self.tf_buffer.lookup_transform(
"world", "camera_init", rclpy.time.Time()
)
except Exception as ex:
print(f'Could not transform mid360_link_IMU to pelvis as world to camera_init is yet published: {ex}')
try:
t = TransformStamped()
@ -58,6 +118,79 @@ class PelvistoTrack(Node):
except Exception as ex:
print(f'Could not transform mid360_link_IMU to pelvis: {ex}')
def publish_static_tf(self):
"""Check if a static transform from 'world' to 'camera_init' exists.
If not, publish it using the parameter 'camera_init_z' for the z-value.
This method is designed to run only once.
"""
# Cancel the timer so this callback runs only one time.
if self.low_state.crc == 0:
return
self.static_tf_timer.cancel()
try:
# Try to look up an existing transform from "world" to "camera_init".
# Here, rclpy.time.Time() (i.e. time=0) means "the latest available".
self.tf_buffer.lookup_transform(
"world", "camera_init", rclpy.time.Time()
)
self.get_logger().info(
"Static transform from 'world' to 'camera_init' already exists. Not publishing a new one."
)
except Exception as ex:
# If the transform isn't found, declare (or get) the parameter for z and publish the static transform.
z_value, rot = self.lidar_height_rot(self.low_state)
static_tf = TransformStamped()
static_tf.header.stamp = self.get_clock().now().to_msg()
static_tf.header.frame_id = "world"
static_tf.child_frame_id = "camera_init"
# static_tf.child_frame_id = "pelvis"
static_tf.transform.translation.x = 0.0
static_tf.transform.translation.y = 0.0
static_tf.transform.translation.z = z_value
static_tf.transform.rotation.x = float(rot[0])
static_tf.transform.rotation.y = float(rot[1])
static_tf.transform.rotation.z = float(rot[2])
static_tf.transform.rotation.w = float(rot[3])
self.static_tf_broadcaster.sendTransform(static_tf)
self.get_logger().info(
f"Published static transform from 'world' to 'camera_init' with z = {z_value} quat = {rot}"
)
def lidar_height_rot(self, low_state: LowStateHG):
print(self.tf_buffer.lookup_transform('pelvis',
'left_ankle_roll_link', rclpy.time.Time()))
world_from_pelvis_quat = np.asarray(low_state.imu_state.quaternion,
dtype=np.float32)
pelvis_from_rf = self.tf_buffer.lookup_transform('pelvis',
'right_ankle_roll_link', rclpy.time.Time())
pelvis_from_lf = self.tf_buffer.lookup_transform('pelvis',
'left_ankle_roll_link', rclpy.time.Time())
xyz_rf = to_array(pelvis_from_rf.transform.translation)
xyz_lf = to_array(pelvis_from_rf.transform.translation)
pelvis_z_rf = -quat_rotate(
world_from_pelvis_quat, xyz_rf)[2] + 0.028531
pelvis_z_lf = -quat_rotate(
world_from_pelvis_quat, xyz_lf)[2] + 0.028531
# print(xyz_lf)
lidar_from_pelvis = self.tf_buffer.lookup_transform('pelvis',
'mid360_link_frame', rclpy.time.Time())
# print(to_array(lidar_from_pelvis.transform.rotation),
# world_from_pelvis_quat)
lidar_z_pevlis = quat_rotate(world_from_pelvis_quat,
to_array(lidar_from_pelvis.transform.translation))[2]
lidar_rot = (R.from_quat(np.roll(world_from_pelvis_quat, -1)) *
R.from_quat(to_array(lidar_from_pelvis.transform.rotation)))
return (0.5 * pelvis_z_lf + 0.5 * pelvis_z_rf + lidar_z_pevlis,
# lidar_rot.as_quat())
# np.roll(world_from_pelvis_quat, -1))
# to_array(lidar_from_pelvis.transform.rotation))
lidar_rot.as_quat())
def main():
rclpy.init()

View File

@ -0,0 +1,50 @@
import rclpy
from rclpy.node import Node
from tf2_ros.buffer import Buffer
from tf2_ros.transform_listener import TransformListener
from tf2_ros import TransformBroadcaster, TransformStamped, StaticTransformBroadcaster
class Tester(Node):
def __init__(self):
super().__init__('tester')
self.tf_buffer = Buffer()
self.tf_listener = TransformListener(self.tf_buffer, self)
self.tf_broadcaster = TransformBroadcaster(self)
self.timer = self.create_timer(0.01, self.on_timer)
def on_timer(self):
try:
self.tf_buffer.lookup_transform(
"world", "camera_init", rclpy.time.Time()
)
except Exception as ex:
print(f'Could not transform mid360_link_IMU to pelvis as world to camera_init is yet published: {ex}')
return
try:
current_left_tf = self.tf_buffer.lookup_transform(
"world",
"left_ankle_roll_link",
rclpy.time.Time(),
rclpy.duration.Duration(seconds=0.02))
current_right_tf = self.tf_buffer.lookup_transform(
"world",
"right_ankle_roll_link",
rclpy.time.Time(),
rclpy.duration.Duration(seconds=0.02))
print(current_left_tf, current_right_tf)
except Exception as ex:
print(f'Could not transform mid360_link_IMU to pelvis: {ex}')
def main():
rclpy.init()
node = Tester()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,12 @@
#!/usr/bin/env python3
import torch
import torch as th
import numpy as np
policy = torch.jit.load('../pre_train/g1/policy_eetrack.pt')
obs=np.load('/tmp/eet5/obs002.npy')
print('obs', obs.shape)
act=policy(torch.from_numpy(obs))
act_sim=np.load('/tmp/eet5/act002.npy')
act_rec=act.detach().cpu().numpy()
delta= (act_sim - act_rec)
print(np.abs(delta).max())