fix imap, etc.

This commit is contained in:
Yoonyoung Cho 2025-02-14 22:29:55 +09:00
parent 5fc670e9e1
commit 12febcabac
2 changed files with 48 additions and 40 deletions

View File

@ -41,6 +41,7 @@ lab_joint: [
'right_wrist_yaw_joint' 'right_wrist_yaw_joint'
] ]
arm_joint: [ arm_joint: [
"left_shoulder_pitch_joint", "left_shoulder_pitch_joint",
"left_shoulder_roll_joint", "left_shoulder_roll_joint",

View File

@ -25,6 +25,7 @@ from yourdfpy import URDF
from math_utils import * from math_utils import *
import random as rd import random as rd
class Mode(Enum): class Mode(Enum):
wait = 0 wait = 0
zero_torque = 1 zero_torque = 1
@ -67,9 +68,10 @@ def axis_angle_from_quat(quat: np.ndarray, eps: float = 1.0e-6) -> np.ndarray:
) )
return quat[..., 1:4] / sin_half_angles_over_angles[..., None] return quat[..., 1:4] / sin_half_angles_over_angles[..., None]
def quat_from_angle_axis( def quat_from_angle_axis(
angle: torch.Tensor, angle: torch.Tensor,
axis: torch.Tensor=None) -> torch.Tensor: axis: torch.Tensor = None) -> torch.Tensor:
"""Convert rotations given as angle-axis to quaternions. """Convert rotations given as angle-axis to quaternions.
Args: Args:
@ -88,6 +90,7 @@ def quat_from_angle_axis(
w = theta.cos() w = theta.cos()
return normalize(torch.cat([w, xyz], dim=-1)) return normalize(torch.cat([w, xyz], dim=-1))
def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
"""Multiply two quaternions together. """Multiply two quaternions together.
@ -126,7 +129,6 @@ def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
return torch.stack([w, x, y, z], dim=-1).view(shape) return torch.stack([w, x, y, z], dim=-1).view(shape)
def quat_rotate(q: np.ndarray, v: np.ndarray) -> np.ndarray: def quat_rotate(q: np.ndarray, v: np.ndarray) -> np.ndarray:
"""Rotate a vector by a quaternion along the last dimension of q and v. """Rotate a vector by a quaternion along the last dimension of q and v.
@ -191,7 +193,7 @@ def body_pose(
xyz = np.array(xyz) xyz = np.array(xyz)
if rot_type == 'axa': if rot_type == 'axa':
axa = axis_angle_from_quat(quat_wxyz) axa = axis_angle_from_quat(quat_wxyz)
axa = (axa + np.pi) % (2 * np.pi) axa = (axa + np.pi) % (2 * np.pi) - np.pi
return (xyz, axa) return (xyz, axa)
elif rot_type == 'quat': elif rot_type == 'quat':
return (xyz, quat_wxyz) return (xyz, quat_wxyz)
@ -207,7 +209,8 @@ def compute_com(tf_buffer, body_frames: List[str]):
com_list = [] com_list = []
# bring default values # bring default values
com_data = extract_link_data('../../resources/robots/g1_description/g1_29dof_rev_1_0.xml') com_data = extract_link_data(
'../../resources/robots/g1_description/g1_29dof_rev_1_0.xml')
# iterate for frames # iterate for frames
for frame in body_frames: for frame in body_frames:
@ -239,19 +242,16 @@ def compute_com(tf_buffer, body_frames: List[str]):
def index_map(k_to, k_from): def index_map(k_to, k_from):
""" """
returns an index mapping from k_from to k_to; Returns an index mapping from k_from to k_to.
i.e. k_to[index_map] = k_from
Given k_to=a, k_from=b,
returns an index map "a_from_b" such that
array_a[a_from_b] = array_b
Missing values are set to -1. Missing values are set to -1.
""" """
out = [] index_dict = {k: i for i, k in enumerate(k_to)} # O(len(k_from))
for k in k_to: return [index_dict.get(k, -1) for k in k_from] # O(len(k_to))
try:
i = k_from.index(k)
except ValueError:
i = -1
out.append(i)
return out
def interpolate_position(pos1, pos2, n_segments): def interpolate_position(pos1, pos2, n_segments):
@ -260,6 +260,7 @@ def interpolate_position(pos1, pos2, n_segments):
interp_pos.append(pos2) interp_pos.append(pos2)
return interp_pos return interp_pos
class eetrack: class eetrack:
def __init__(self, root_state_w): def __init__(self, root_state_w):
self.eetrack_midpt = root_state_w.clone() self.eetrack_midpt = root_state_w.clone()
@ -296,19 +297,23 @@ class eetrack:
def create_direction(self): def create_direction(self):
angle_from_eetrack_line = torch.rand(1, device=self.device) * np.pi angle_from_eetrack_line = torch.rand(1, device=self.device) * np.pi
angle_from_xyplane_in_global_frame = torch.rand(1, device=self.device) * np.pi - np.pi/2 angle_from_xyplane_in_global_frame = torch.rand(
1, device=self.device) * np.pi - np.pi / 2
# For testing # For testing
angle_from_eetrack_line = torch.rand(1, device=self.device) * np.pi/2 angle_from_eetrack_line = torch.rand(1, device=self.device) * np.pi / 2
angle_from_xyplane_in_global_frame = torch.rand(1, device=self.device) * 0 angle_from_xyplane_in_global_frame = torch.rand(
1, device=self.device) * 0
roll = torch.zeros(1, device=self.device) roll = torch.zeros(1, device=self.device)
pitch = angle_from_xyplane_in_global_frame pitch = angle_from_xyplane_in_global_frame
yaw = angle_from_eetrack_line yaw = angle_from_eetrack_line
euler = torch.stack([roll, pitch, yaw], dim=1) euler = torch.stack([roll, pitch, yaw], dim=1)
quat = math_utils.quat_from_euler_xyz(euler[:,0], euler[:,1], euler[:,2]) quat = math_utils.quat_from_euler_xyz(
euler[:, 0], euler[:, 1], euler[:, 2])
return quat return quat
def create_subgoal(self): def create_subgoal(self):
eetrack_subgoals = interpolate_position(self.eetrack_start, self.eetrack_end, self.number_of_subgoals) eetrack_subgoals = interpolate_position(
self.eetrack_start, self.eetrack_end, self.number_of_subgoals)
eetrack_subgoals = [ eetrack_subgoals = [
( (
l.clone().to(self.device, dtype=torch.float32) l.clone().to(self.device, dtype=torch.float32)
@ -317,14 +322,15 @@ class eetrack:
) )
for l in eetrack_subgoals for l in eetrack_subgoals
] ]
eetrack_subgoals = torch.stack(eetrack_subgoals,axis=1) eetrack_subgoals = torch.stack(eetrack_subgoals, axis=1)
eetrack_ori = self.create_direction().unsqueeze(1).repeat(1, self.number_of_subgoals + 1, 1) eetrack_ori = self.create_direction().unsqueeze(
1).repeat(1, self.number_of_subgoals + 1, 1)
# welidng_subgoals -> Nenv x Npoints x (3 + 4) # welidng_subgoals -> Nenv x Npoints x (3 + 4)
return torch.cat([eetrack_subgoals, eetrack_ori], dim=2) return torch.cat([eetrack_subgoals, eetrack_ori], dim=2)
def update_command(self): def update_command(self):
time = self.init_time - rp.time.Time() time = self.init_time - rp.time.Time()
if (time>=0): if (time >= 0):
self.sg_idx = time / 0.1 + 1 self.sg_idx = time / 0.1 + 1
self.sg_idx.clamp_(0, self.number_of_subgoals + 1) self.sg_idx.clamp_(0, self.number_of_subgoals + 1)
self.next_command_s_left = self.eetrack_subgoal[self.sg_idx] self.next_command_s_left = self.eetrack_subgoal[self.sg_idx]
@ -332,7 +338,8 @@ class eetrack:
def get_command(self, root_state_w): def get_command(self, root_state_w):
self.update_command() self.update_command()
pos_hand_b_left, quat_hand_b_left = body_pose_axa("left_hand_palm_link") pos_hand_b_left, quat_hand_b_left = body_pose_axa(
"left_hand_palm_link")
lerp_command_w_left = self.next_command_s_left lerp_command_w_left = self.next_command_s_left
@ -705,7 +712,6 @@ class Controller:
i_pin = self.pin_from_mot[i_mot] i_pin = self.pin_from_mot[i_mot]
q_pin[i_pin] = self.low_state.motor_state[i_mot].q q_pin[i_pin] = self.low_state.motor_state[i_mot].q
d_quat = quat_from_angle_axis( d_quat = quat_from_angle_axis(
torch.from_numpy(_hands_command_[..., 3:]) torch.from_numpy(_hands_command_[..., 3:])
).detach().cpu().numpy() ).detach().cpu().numpy()
@ -715,7 +721,8 @@ class Controller:
source_quat = xyzw2wxyz(pin.Quaternion(source_pose.rotation).coeffs()) source_quat = xyzw2wxyz(pin.Quaternion(source_pose.rotation).coeffs())
target_xyz = source_xyz + _hands_command_[..., :3] target_xyz = source_xyz + _hands_command_[..., :3]
target_quat = quat_mul(torch.from_numpy(d_quat), target_quat = quat_mul(
torch.from_numpy(d_quat),
torch.from_numpy(source_quat)).detach().cpu().numpy() torch.from_numpy(source_quat)).detach().cpu().numpy()
target = np.concatenate([target_xyz, target_quat]) target = np.concatenate([target_xyz, target_quat])