diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.cpp b/src/rl_sar/library/rl_sdk/rl_sdk.cpp index fb5aee9..039c87d 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.cpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.cpp @@ -21,9 +21,19 @@ torch::Tensor RL::ComputeObservation() { obs_list.push_back(this->obs.lin_vel * this->params.lin_vel_scale); } - else if (observation == "ang_vel") + /* + The first argument of the QuatRotateInverse function is the quaternion representing the robot's orientation, and the second argument is in the world coordinate system. The function outputs the value of the second argument in the body coordinate system. + In IsaacGym, the coordinate system for angular velocity is in the world coordinate system. During training, the angular velocity in the observation uses QuatRotateInverse to transform the coordinate system to the body coordinate system. + In Gazebo, the coordinate system for angular velocity is also in the world coordinate system, so QuatRotateInverse is needed to transform the coordinate system to the body coordinate system. + In some real robots like Unitree, if the coordinate system for the angular velocity is already in the body coordinate system, no transformation is necessary. + Forgetting to perform the transformation or performing it multiple times may cause controller crashes when the rotation reaches 180 degrees. + */ + else if (observation == "ang_vel_body") + { + obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); + } + else if (observation == "ang_vel_world") { - // obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); // TODO is QuatRotateInverse necessery? obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale); } else if (observation == "gravity_vec") diff --git a/src/rl_sar/scripts/rl_sdk.py b/src/rl_sar/scripts/rl_sdk.py index 5f25a0b..6c2a642 100644 --- a/src/rl_sar/scripts/rl_sdk.py +++ b/src/rl_sar/scripts/rl_sdk.py @@ -138,10 +138,18 @@ class RL: def ComputeObservation(self): obs_list = [] for observation in self.params.observations: + """ + The first argument of the QuatRotateInverse function is the quaternion representing the robot's orientation, and the second argument is in the world coordinate system. The function outputs the value of the second argument in the body coordinate system. + In IsaacGym, the coordinate system for angular velocity is in the world coordinate system. During training, the angular velocity in the observation uses QuatRotateInverse to transform the coordinate system to the body coordinate system. + In Gazebo, the coordinate system for angular velocity is also in the world coordinate system, so QuatRotateInverse is needed to transform the coordinate system to the body coordinate system. + In some real robots like Unitree, if the coordinate system for the angular velocity is already in the body coordinate system, no transformation is necessary. + Forgetting to perform the transformation or performing it multiple times may cause controller crashes when the rotation reaches 180 degrees. + """ if observation == "lin_vel": obs_list.append(self.obs.lin_vel * self.params.lin_vel_scale) - elif observation == "ang_vel": - # obs_list.append(self.obs.ang_vel * self.params.ang_vel_scale) # TODO is QuatRotateInverse necessery? + elif observation == "ang_vel_body": + obs_list.append(self.obs.ang_vel * self.params.ang_vel_scale) + elif observation == "ang_vel_world": obs_list.append(self.QuatRotateInverse(self.obs.base_quat, self.obs.ang_vel, self.params.framework) * self.params.ang_vel_scale) elif observation == "gravity_vec": obs_list.append(self.QuatRotateInverse(self.obs.base_quat, self.obs.gravity_vec, self.params.framework)) diff --git a/src/rl_sar/scripts/rl_sim.py b/src/rl_sar/scripts/rl_sim.py index 3c61883..bcc4192 100644 --- a/src/rl_sar/scripts/rl_sim.py +++ b/src/rl_sar/scripts/rl_sim.py @@ -34,6 +34,9 @@ class RL_Sim(RL): # read params from yaml self.robot_name = rospy.get_param("robot_name", "") self.ReadYaml(self.robot_name) + for i in range(len(self.params.observations)): + if self.params.observations[i] == "ang_vel": + self.params.observations[i] = "ang_vel_world" # history if self.params.observations_history is None: diff --git a/src/rl_sar/src/rl_real_a1.cpp b/src/rl_sar/src/rl_real_a1.cpp index 7fa86c5..66420bb 100644 --- a/src/rl_sar/src/rl_real_a1.cpp +++ b/src/rl_sar/src/rl_real_a1.cpp @@ -10,6 +10,14 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u // read params from yaml this->robot_name = "a1_isaacgym"; this->ReadYaml(this->robot_name); + for (std::string &observation : this->params.observations) + { + // In Unitree A1, the coordinate system for angular velocity is in the body coordinate system. + if (observation == "ang_vel") + { + observation = "ang_vel_body"; + } + } // history if (!this->params.observations_history.empty()) diff --git a/src/rl_sar/src/rl_real_go2.cpp b/src/rl_sar/src/rl_real_go2.cpp index d6b7719..fa8db36 100644 --- a/src/rl_sar/src/rl_real_go2.cpp +++ b/src/rl_sar/src/rl_real_go2.cpp @@ -10,6 +10,14 @@ void RL_Real::RL_Real() // read params from yaml this->robot_name = "go2_isaacgym"; this->ReadYaml(this->robot_name); + for (std::string &observation : this->params.observations) + { + // In Unitree Go2, the coordinate system for angular velocity is in the body coordinate system. + if (observation == "ang_vel") + { + observation = "ang_vel_body"; + } + } // history if (!this->params.observations_history.empty()) diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 4f2d06d..6242fdb 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -10,6 +10,14 @@ RL_Sim::RL_Sim() // read params from yaml nh.param("robot_name", this->robot_name, ""); this->ReadYaml(this->robot_name); + for (std::string &observation : this->params.observations) + { + // In Gazebo, the coordinate system for angular velocity is in the world coordinate system. + if (observation == "ang_vel") + { + observation = "ang_vel_world"; + } + } // history if (!this->params.observations_history.empty())