diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.cpp b/src/rl_sar/library/rl_sdk/rl_sdk.cpp index 039c87d..5cd311a 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.cpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.cpp @@ -294,6 +294,55 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques) } } +void RL::AttitudeProtect(const std::vector &quaternion, float pitch_threshold, float roll_threshold) +{ + float rad2deg = 57.2958; + float w, x, y, z; + + if (this->params.framework == "isaacgym") + { + w = quaternion[3]; + x = quaternion[0]; + y = quaternion[1]; + z = quaternion[2]; + } + else if (this->params.framework == "isaacsim") + { + w = quaternion[0]; + x = quaternion[1]; + y = quaternion[2]; + z = quaternion[3]; + } + + // Calculate roll (rotation around the X-axis) + float sinr_cosp = 2 * (w * x + y * z); + float cosr_cosp = 1 - 2 * (x * x + y * y); + float roll = std::atan2(sinr_cosp, cosr_cosp) * rad2deg; + + // Calculate pitch (rotation around the Y-axis) + float sinp = 2 * (w * y - z * x); + float pitch; + if (std::fabs(sinp) >= 1) + { + pitch = std::copysign(90.0, sinp); // Clamp to avoid out-of-range values + } + else + { + pitch = std::asin(sinp) * rad2deg; + } + + if (std::fabs(roll) > roll_threshold) + { + // this->control.control_state = STATE_POS_GETDOWN; + std::cout << LOGGER::WARNING << "Roll exceeds " << roll_threshold << " degrees. Current: " << roll << " degrees." << std::endl; + } + if (std::fabs(pitch) > pitch_threshold) + { + // this->control.control_state = STATE_POS_GETDOWN; + std::cout << LOGGER::WARNING << "Pitch exceeds " << pitch_threshold << " degrees. Current: " << pitch << " degrees." << std::endl; + } +} + #include #include static bool kbhit() diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.hpp b/src/rl_sar/library/rl_sdk/rl_sdk.hpp index 5f86f76..b31bcfe 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.hpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.hpp @@ -160,6 +160,7 @@ public: // protect func void TorqueProtect(torch::Tensor origin_output_torques); + void AttitudeProtect(const std::vector &quaternion, float pitch_threshold, float roll_threshold); protected: // rl module diff --git a/src/rl_sar/src/rl_real_a1.cpp b/src/rl_sar/src/rl_real_a1.cpp index e192670..8411f2d 100644 --- a/src/rl_sar/src/rl_real_a1.cpp +++ b/src/rl_sar/src/rl_real_a1.cpp @@ -173,6 +173,7 @@ void RL_Real::RunModel() torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions); this->TorqueProtect(origin_output_torques); + this->AttitudeProtect(this->robot_state.imu.quaternion, 60.0f, 60.0f); this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits); this->output_dof_pos = this->ComputePosition(this->obs.actions); diff --git a/src/rl_sar/src/rl_real_go2.cpp b/src/rl_sar/src/rl_real_go2.cpp index b245ad5..66b929b 100644 --- a/src/rl_sar/src/rl_real_go2.cpp +++ b/src/rl_sar/src/rl_real_go2.cpp @@ -178,6 +178,7 @@ void RL_Real::RunModel() torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions); this->TorqueProtect(origin_output_torques); + this->AttitudeProtect(this->robot_state.imu.quaternion, 60.0f, 60.0f); this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits); this->output_dof_pos = this->ComputePosition(this->obs.actions);