mirror of https://github.com/fan-ziqi/rl_sar.git
feat: add AttitudeProtect
This commit is contained in:
parent
44f8f5324f
commit
2d4c0d23d0
|
@ -294,6 +294,55 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
|
|||
}
|
||||
}
|
||||
|
||||
void RL::AttitudeProtect(const std::vector<double> &quaternion, float pitch_threshold, float roll_threshold)
|
||||
{
|
||||
float rad2deg = 57.2958;
|
||||
float w, x, y, z;
|
||||
|
||||
if (this->params.framework == "isaacgym")
|
||||
{
|
||||
w = quaternion[3];
|
||||
x = quaternion[0];
|
||||
y = quaternion[1];
|
||||
z = quaternion[2];
|
||||
}
|
||||
else if (this->params.framework == "isaacsim")
|
||||
{
|
||||
w = quaternion[0];
|
||||
x = quaternion[1];
|
||||
y = quaternion[2];
|
||||
z = quaternion[3];
|
||||
}
|
||||
|
||||
// Calculate roll (rotation around the X-axis)
|
||||
float sinr_cosp = 2 * (w * x + y * z);
|
||||
float cosr_cosp = 1 - 2 * (x * x + y * y);
|
||||
float roll = std::atan2(sinr_cosp, cosr_cosp) * rad2deg;
|
||||
|
||||
// Calculate pitch (rotation around the Y-axis)
|
||||
float sinp = 2 * (w * y - z * x);
|
||||
float pitch;
|
||||
if (std::fabs(sinp) >= 1)
|
||||
{
|
||||
pitch = std::copysign(90.0, sinp); // Clamp to avoid out-of-range values
|
||||
}
|
||||
else
|
||||
{
|
||||
pitch = std::asin(sinp) * rad2deg;
|
||||
}
|
||||
|
||||
if (std::fabs(roll) > roll_threshold)
|
||||
{
|
||||
// this->control.control_state = STATE_POS_GETDOWN;
|
||||
std::cout << LOGGER::WARNING << "Roll exceeds " << roll_threshold << " degrees. Current: " << roll << " degrees." << std::endl;
|
||||
}
|
||||
if (std::fabs(pitch) > pitch_threshold)
|
||||
{
|
||||
// this->control.control_state = STATE_POS_GETDOWN;
|
||||
std::cout << LOGGER::WARNING << "Pitch exceeds " << pitch_threshold << " degrees. Current: " << pitch << " degrees." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#include <termios.h>
|
||||
#include <sys/ioctl.h>
|
||||
static bool kbhit()
|
||||
|
|
|
@ -160,6 +160,7 @@ public:
|
|||
|
||||
// protect func
|
||||
void TorqueProtect(torch::Tensor origin_output_torques);
|
||||
void AttitudeProtect(const std::vector<double> &quaternion, float pitch_threshold, float roll_threshold);
|
||||
|
||||
protected:
|
||||
// rl module
|
||||
|
|
|
@ -173,6 +173,7 @@ void RL_Real::RunModel()
|
|||
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
|
||||
|
||||
this->TorqueProtect(origin_output_torques);
|
||||
this->AttitudeProtect(this->robot_state.imu.quaternion, 60.0f, 60.0f);
|
||||
|
||||
this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
|
||||
this->output_dof_pos = this->ComputePosition(this->obs.actions);
|
||||
|
|
|
@ -178,6 +178,7 @@ void RL_Real::RunModel()
|
|||
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
|
||||
|
||||
this->TorqueProtect(origin_output_torques);
|
||||
this->AttitudeProtect(this->robot_state.imu.quaternion, 60.0f, 60.0f);
|
||||
|
||||
this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
|
||||
this->output_dof_pos = this->ComputePosition(this->obs.actions);
|
||||
|
|
Loading…
Reference in New Issue