From 1b12b00fd1dd55ccf1fc7b9d63989dbdbca4b260 Mon Sep 17 00:00:00 2001 From: fan-ziqi Date: Fri, 22 Mar 2024 21:18:55 +0800 Subject: [PATCH] fix: organize code --- src/rl_sar/include/rl_real.hpp | 9 +- src/rl_sar/include/rl_sim.hpp | 2 + src/rl_sar/library/rl/rl.cpp | 14 +-- src/rl_sar/src/rl_real.cpp | 197 +++++++++++++++++---------------- src/rl_sar/src/rl_sim.cpp | 78 ++++++------- 5 files changed, 145 insertions(+), 155 deletions(-) diff --git a/src/rl_sar/include/rl_real.hpp b/src/rl_sar/include/rl_real.hpp index 9d2e6d8..1fc5690 100644 --- a/src/rl_sar/include/rl_real.hpp +++ b/src/rl_sar/include/rl_real.hpp @@ -68,14 +68,7 @@ private: std::vector joint_velocities; int dof_mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; - float Kp[13] = {20, 20, 20, //fr - 20, 20, 20, //fl - 20, 20, 20, //rr - 20, 20, 20};//rl - float Kd[13] = {0.5, 0.5, 0.5, - 0.5, 0.5, 0.5, - 0.5, 0.5, 0.5, - 0.5, 0.5, 0.5}; + int hip_scale_reduction_indices[] = {0, 3, 6, 9}; std::chrono::high_resolution_clock::time_point start_time; diff --git a/src/rl_sar/include/rl_sim.hpp b/src/rl_sar/include/rl_sim.hpp index 7a9fa8c..8d2d77c 100644 --- a/src/rl_sar/include/rl_sim.hpp +++ b/src/rl_sar/include/rl_sim.hpp @@ -58,6 +58,8 @@ private: std::vector joint_positions; std::vector joint_velocities; + int hip_scale_reduction_indices[] = {0, 3, 6, 9}; + std::chrono::high_resolution_clock::time_point start_time; // other rl module diff --git a/src/rl_sar/library/rl/rl.cpp b/src/rl_sar/library/rl/rl.cpp index 26fa4d2..38f0825 100644 --- a/src/rl_sar/library/rl/rl.cpp +++ b/src/rl_sar/library/rl/rl.cpp @@ -6,7 +6,7 @@ torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v) torch::Tensor q_w = q.index({torch::indexing::Slice(), -1}); torch::Tensor q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)}); torch::Tensor a = v * (2.0 * torch::pow(q_w, 2) - 1.0).unsqueeze(-1); - torch::Tensor b = torch::cross(q_vec, v, /*dim=*/-1) * q_w.unsqueeze(-1) * 2.0; + torch::Tensor b = torch::cross(q_vec, v, -1) * q_w.unsqueeze(-1) * 2.0; torch::Tensor c = q_vec * torch::bmm(q_vec.view({shape[0], 1, 3}), v.view({shape[0], 3, 1})).squeeze(-1) * 2.0; return a - b + c; } @@ -26,12 +26,6 @@ void RL::InitObservations() torch::Tensor RL::ComputeTorques(torch::Tensor actions) { torch::Tensor actions_scaled = actions * this->params.action_scale; - int indices[] = {0, 3, 6, 9}; - for (int i : indices) - { - actions_scaled[0][i] *= this->params.hip_scale_reduction; - } - torch::Tensor output_torques = this->params.p_gains * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.d_gains * this->obs.dof_vel; torch::Tensor clamped = torch::clamp(output_torques, -(this->params.torque_limits), this->params.torque_limits); return clamped; @@ -40,12 +34,6 @@ torch::Tensor RL::ComputeTorques(torch::Tensor actions) torch::Tensor RL::ComputePosition(torch::Tensor actions) { torch::Tensor actions_scaled = actions * this->params.action_scale; - int indices[] = {0, 3, 6, 9}; - for (int i : indices) - { - actions_scaled[0][i] *= this->params.hip_scale_reduction; - } - return actions_scaled + this->params.default_dof_pos; } diff --git a/src/rl_sar/src/rl_real.cpp b/src/rl_sar/src/rl_real.cpp index 64d2c12..0e51bc7 100644 --- a/src/rl_sar/src/rl_real.cpp +++ b/src/rl_sar/src/rl_real.cpp @@ -4,6 +4,82 @@ RL_Real rl_sar; +RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) +{ + udp.InitCmdData(cmd); + + start_time = std::chrono::high_resolution_clock::now(); + + std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt"; + std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt"; + std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt"; + + this->actor = torch::jit::load(actor_path); + this->encoder = torch::jit::load(encoder_path); + this->vq = torch::jit::load(vq_path); + this->InitObservations(); + + this->params.num_observations = 45; + this->params.clip_obs = 100.0; + this->params.clip_actions = 100.0; + this->params.damping = 0.5; + this->params.stiffness = 20; + this->params.d_gains = torch::ones(12) * this->params.damping; + this->params.p_gains = torch::ones(12) * this->params.stiffness; + this->params.action_scale = 0.25; + this->params.hip_scale_reduction = 0.5; + this->params.num_of_dofs = 12; + this->params.lin_vel_scale = 2.0; + this->params.ang_vel_scale = 0.25; + this->params.dof_pos_scale = 1.0; + this->params.dof_vel_scale = 0.05; + this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); + + this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, + 20.0, 55.0, 55.0, + 20.0, 55.0, 55.0, + 20.0, 55.0, 55.0}}); + + // hip, thigh, calf + this->params.default_dof_pos = torch::tensor({{ 0.1000, 0.8000, -1.5000, // FL + -0.1000, 0.8000, -1.5000, // FR + 0.1000, 1.0000, -1.5000, // RR + -0.1000, 1.0000, -1.5000}});// RL + + this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); + + output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); + output_dof_pos = params.default_dof_pos; + plot_real_joint_pos.resize(12); + plot_target_joint_pos.resize(12); + + loop_control = std::make_shared("loop_control", 0.002, boost::bind(&RL_Real::RobotControl, this)); + loop_udpSend = std::make_shared("loop_udpSend", 0.002, 3, boost::bind(&RL_Real::UDPSend, this)); + loop_udpRecv = std::make_shared("loop_udpRecv", 0.002, 3, boost::bind(&RL_Real::UDPRecv, this)); + loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel, this)); + + loop_udpSend->start(); + loop_udpRecv->start(); + loop_control->start(); + +#ifdef PLOT + loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this)); + loop_plot->start(); +#endif +} + +RL_Real::~RL_Real() +{ + loop_udpSend->shutdown(); + loop_udpRecv->shutdown(); + loop_control->shutdown(); + loop_rl->shutdown(); +#ifdef PLOT + loop_plot->shutdown(); +#endif + printf("exit\n"); +} + void RL_Real::RobotControl() { motiontime++; @@ -81,8 +157,6 @@ void RL_Real::RobotControl() // cmd.motorCmd[i].q = 0; cmd.motorCmd[i].q = output_dof_pos[0][dof_mapping[i]].item(); cmd.motorCmd[i].dq = 0; - // cmd.motorCmd[i].Kp = Kp[dof_mapping[i]]; - // cmd.motorCmd[i].Kd = Kd[dof_mapping[i]]; cmd.motorCmd[i].Kp = params.stiffness; cmd.motorCmd[i].Kd = params.damping; // cmd.motorCmd[i].tau = output_torques[0][dof_mapping[i]].item(); @@ -129,100 +203,6 @@ void RL_Real::RobotControl() udp.SetSend(cmd); } -RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) -{ - udp.InitCmdData(cmd); - - start_time = std::chrono::high_resolution_clock::now(); - - std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt"; - std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt"; - std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt"; - - this->actor = torch::jit::load(actor_path); - this->encoder = torch::jit::load(encoder_path); - this->vq = torch::jit::load(vq_path); - this->InitObservations(); - - this->params.num_observations = 45; - this->params.clip_obs = 100.0; - this->params.clip_actions = 100.0; - this->params.damping = 0.5; - this->params.stiffness = 20; - this->params.d_gains = torch::ones(12) * this->params.damping; - this->params.p_gains = torch::ones(12) * this->params.stiffness; - this->params.action_scale = 0.25; - this->params.hip_scale_reduction = 0.5; - this->params.num_of_dofs = 12; - this->params.lin_vel_scale = 2.0; - this->params.ang_vel_scale = 0.25; - this->params.dof_pos_scale = 1.0; - this->params.dof_vel_scale = 0.05; - this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); - - this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, - 20.0, 55.0, 55.0, - 20.0, 55.0, 55.0, - 20.0, 55.0, 55.0}}); - - // hip, thigh, calf - this->params.default_dof_pos = torch::tensor({{ 0.1000, 0.8000, -1.5000, // FL - -0.1000, 0.8000, -1.5000, // FR - 0.1000, 1.0000, -1.5000, // RR - -0.1000, 1.0000, -1.5000}});// RL - - this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); - - output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); - output_dof_pos = params.default_dof_pos; - plot_real_joint_pos.resize(12); - plot_target_joint_pos.resize(12); - - loop_control = std::make_shared("loop_control", 0.002, boost::bind(&RL_Real::RobotControl, this)); - loop_udpSend = std::make_shared("loop_udpSend", 0.002, 3, boost::bind(&RL_Real::UDPSend, this)); - loop_udpRecv = std::make_shared("loop_udpRecv", 0.002, 3, boost::bind(&RL_Real::UDPRecv, this)); - loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel, this)); - - loop_udpSend->start(); - loop_udpRecv->start(); - loop_control->start(); - -#ifdef PLOT - loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this)); - loop_plot->start(); -#endif -} - -RL_Real::~RL_Real() -{ - loop_udpSend->shutdown(); - loop_udpRecv->shutdown(); - loop_control->shutdown(); - loop_rl->shutdown(); -#ifdef PLOT - loop_plot->shutdown(); -#endif - printf("exit\n"); -} - -void RL_Real::Plot() -{ - plot_t.push_back(motiontime); - plt::cla(); - plt::clf(); - for(int i = 0; i < 12; ++i) - { - plot_real_joint_pos[i].push_back(state.motorState[i].q); - plot_target_joint_pos[i].push_back(cmd.motorCmd[i].q); - plt::subplot(4, 3, i+1); - plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); - plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); - plt::xlim(motiontime-10000, motiontime); - } - // plt::legend(); - plt::pause(0.0001); -} - void RL_Real::RunModel() { if(robot_state == STATE_RL_RUNNING) @@ -260,6 +240,11 @@ void RL_Real::RunModel() torch::Tensor actions = this->Forward(); + for (int i : hip_scale_reduction_indices) + { + actions[0][i] *= this->params.hip_scale_reduction; + } + output_torques = this->ComputeTorques(actions); output_dof_pos = this->ComputePosition(actions); } @@ -268,7 +253,7 @@ void RL_Real::RunModel() torch::Tensor RL_Real::ComputeObservation() { - torch::Tensor obs = torch::cat({// (this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, + torch::Tensor obs = torch::cat({// this->QuatRotateInverse(this->obs.base_quat, this->obs.lin_vel) * this->params.lin_vel_scale, this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), this->obs.commands * this->params.commands_scale, @@ -301,6 +286,24 @@ torch::Tensor RL_Real::Forward() return clamped; } +void RL_Real::Plot() +{ + plot_t.push_back(motiontime); + plt::cla(); + plt::clf(); + for(int i = 0; i < 12; ++i) + { + plot_real_joint_pos[i].push_back(state.motorState[i].q); + plot_target_joint_pos[i].push_back(cmd.motorCmd[i].q); + plt::subplot(4, 3, i+1); + plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); + plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); + plt::xlim(motiontime-10000, motiontime); + } + // plt::legend(); + plt::pause(0.0001); +} + void signalHandler(int signum) { exit(0); diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 1258372..323737c 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -2,42 +2,6 @@ // #define PLOT -void RL_Sim::RobotControl() -{ - motiontime++; - for (int i = 0; i < 12; ++i) - { - motor_commands[i].mode = 0x0A; - motor_commands[i].q = output_dof_pos[0][i].item(); - motor_commands[i].dq = 0; - motor_commands[i].Kp = params.stiffness; - motor_commands[i].Kd = params.damping; - // motor_commands[i].tau = output_torques[0][i].item(); - motor_commands[i].tau = 0; - - torque_publishers[joint_names[i]].publish(motor_commands[i]); - } -} - -void RL_Sim::Plot() -{ - int dof_mapping[13] = {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9}; - plot_t.push_back(motiontime); - plt::cla(); - plt::clf(); - for(int i = 0; i < 12; ++i) - { - plot_real_joint_pos[i].push_back(joint_positions[dof_mapping[i]]); - plot_target_joint_pos[i].push_back(motor_commands[i].q); - plt::subplot(4, 3, i+1); - plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); - plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); - plt::xlim(motiontime-10000, motiontime); - } - // plt::legend(); - plt::pause(0.0001); -} - RL_Sim::RL_Sim() { ros::NodeHandle nh; @@ -71,7 +35,6 @@ RL_Sim::RL_Sim() this->params.dof_pos_scale = 1.0; this->params.dof_vel_scale = 0.05; this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); - this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, 20.0, 55.0, 55.0, @@ -137,6 +100,23 @@ RL_Sim::~RL_Sim() printf("exit\n"); } +void RL_Sim::RobotControl() +{ + motiontime++; + for (int i = 0; i < 12; ++i) + { + motor_commands[i].mode = 0x0A; + motor_commands[i].q = output_dof_pos[0][i].item(); + motor_commands[i].dq = 0; + motor_commands[i].Kp = params.stiffness; + motor_commands[i].Kd = params.damping; + // motor_commands[i].tau = output_torques[0][i].item(); + motor_commands[i].tau = 0; + + torque_publishers[joint_names[i]].publish(motor_commands[i]); + } +} + void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg) { @@ -191,6 +171,11 @@ void RL_Sim::RunModel() torch::Tensor actions = this->Forward(); + for (int i : hip_scale_reduction_indices) + { + actions[0][i] *= this->params.hip_scale_reduction; + } + output_torques = this->ComputeTorques(actions); output_dof_pos = this->ComputePosition(actions); } @@ -230,6 +215,25 @@ torch::Tensor RL_Sim::Forward() return clamped; } +void RL_Sim::Plot() +{ + int dof_mapping[13] = {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9}; + plot_t.push_back(motiontime); + plt::cla(); + plt::clf(); + for(int i = 0; i < 12; ++i) + { + plot_real_joint_pos[i].push_back(joint_positions[dof_mapping[i]]); + plot_target_joint_pos[i].push_back(motor_commands[i].q); + plt::subplot(4, 3, i+1); + plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); + plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); + plt::xlim(motiontime-10000, motiontime); + } + // plt::legend(); + plt::pause(0.0001); +} + void signalHandler(int signum) { ros::shutdown();