From 9020bb81eef006c39efe1a90b60935289fecf4ff Mon Sep 17 00:00:00 2001 From: fan-ziqi Date: Fri, 15 Mar 2024 21:14:52 +0800 Subject: [PATCH] fix: fix format --- src/rl_sar/include/rl_real.hpp | 1 + src/rl_sar/include/rl_sim.hpp | 11 +++++---- src/rl_sar/library/rl/rl.hpp | 14 ++++++++---- src/rl_sar/src/rl_real.cpp | 9 ++++---- src/rl_sar/src/rl_sim.cpp | 42 +++++++++++++++++++++------------- 5 files changed, 47 insertions(+), 30 deletions(-) diff --git a/src/rl_sar/include/rl_real.hpp b/src/rl_sar/include/rl_real.hpp index c96ba5b..173b7dc 100644 --- a/src/rl_sar/include/rl_real.hpp +++ b/src/rl_sar/include/rl_real.hpp @@ -28,6 +28,7 @@ class RL_Real : public RL public: RL_Real(); ~RL_Real(); + void runModel(); torch::Tensor forward() override; torch::Tensor compute_observation() override; diff --git a/src/rl_sar/include/rl_sim.hpp b/src/rl_sar/include/rl_sim.hpp index 01197fd..bbec0d8 100644 --- a/src/rl_sar/include/rl_sim.hpp +++ b/src/rl_sar/include/rl_sim.hpp @@ -1,21 +1,24 @@ #ifndef RL_SIM_HPP #define RL_SIM_HPP +#include "../library/rl/rl.hpp" +#include "../library/observation_buffer/observation_buffer.hpp" #include #include #include #include -#include "../library/rl/rl.hpp" -#include "../library/observation_buffer/observation_buffer.hpp" #include "unitree_legged_msgs/MotorCmd.h" +#include class RL_Sim : public RL { public: RL_Sim(); + void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); + void runModel(const ros::TimerEvent &event); torch::Tensor forward() override; torch::Tensor compute_observation() override; @@ -24,8 +27,6 @@ public: torch::Tensor history_obs; private: - std::string ros_namespace; - std::vector torque_command_topics; ros::Subscriber model_state_subscriber_; @@ -33,7 +34,7 @@ private: ros::Subscriber cmd_vel_subscriber_; std::map torque_publishers; - std::vector torque_commands; + std::vector motor_commands; geometry_msgs::Twist vel; geometry_msgs::Pose pose; diff --git a/src/rl_sar/library/rl/rl.hpp b/src/rl_sar/library/rl/rl.hpp index b661f20..b82f8b5 100644 --- a/src/rl_sar/library/rl/rl.hpp +++ b/src/rl_sar/library/rl/rl.hpp @@ -5,7 +5,8 @@ #include #include -struct ModelParams { +struct ModelParams +{ int num_observations; float damping; float stiffness; @@ -25,7 +26,8 @@ struct ModelParams { torch::Tensor default_dof_pos; }; -struct Observations { +struct Observations +{ torch::Tensor lin_vel; torch::Tensor ang_vel; torch::Tensor gravity_vec; @@ -36,9 +38,11 @@ struct Observations { torch::Tensor actions; }; -class RL { +class RL +{ public: RL(){}; + ModelParams params; Observations obs; @@ -61,8 +65,8 @@ protected: torch::Tensor dof_pos; torch::Tensor dof_vel; torch::Tensor actions; - - torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); + // output buffer + torch::Tensor torques; torch::Tensor target_dof_pos; }; diff --git a/src/rl_sar/src/rl_real.cpp b/src/rl_sar/src/rl_real.cpp index f02885a..d44af05 100644 --- a/src/rl_sar/src/rl_real.cpp +++ b/src/rl_sar/src/rl_real.cpp @@ -159,6 +159,7 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); + torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); target_dof_pos = params.default_dof_pos; // InitEnvironment(); @@ -185,15 +186,15 @@ void RL_Real::runModel() { if(init_state == STATE_RL_START) { - auto duration = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count(); - std::cout << "Execution time: " << duration << " microseconds" << std::endl; - start_time = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count(); + // std::cout << "Execution time: " << duration << " microseconds" << std::endl; + // start_time = std::chrono::high_resolution_clock::now(); // printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]); // printf("%f, %f, %f, %f\n", state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]); // printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q, state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q, state.motorState[RL_0].q, state.motorState[RL_1].q, state.motorState[RL_2].q, state.motorState[RR_0].q, state.motorState[RR_1].q, state.motorState[RR_2].q); // printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq); - + this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}}); this->obs.commands = torch::tensor({{_keyData.ly, -_keyData.rx, -_keyData.lx}}); this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}}); diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 502f0af..171a57f 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -1,5 +1,6 @@ #include "../include/rl_sim.hpp" -#include + +RL_Sim rl_sar; RL_Sim::RL_Sim() { @@ -8,7 +9,7 @@ RL_Sim::RL_Sim() cmd_vel = geometry_msgs::Twist(); - torque_commands.resize(12); + motor_commands.resize(12); std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt"; std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt"; @@ -49,14 +50,14 @@ RL_Sim::RL_Sim() this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); + torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); target_dof_pos = params.default_dof_pos; - cmd_vel_subscriber_ = nh.subscribe( - "/cmd_vel", 10, &RL_Sim::cmdvelCallback, this); + cmd_vel_subscriber_ = nh.subscribe("/cmd_vel", 10, &RL_Sim::cmdvelCallback, this); timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this); - ros_namespace = "/a1_gazebo/"; + std::string ros_namespace = "/a1_gazebo/"; joint_names = { "FL_hip_joint", "FL_thigh_joint", "FL_calf_joint", @@ -98,9 +99,9 @@ void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg) void RL_Sim::runModel(const ros::TimerEvent &event) { - auto duration = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count(); + // auto duration = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count(); // std::cout << "Execution time: " << duration << " microseconds" << std::endl; - start_time = std::chrono::high_resolution_clock::now(); + // start_time = std::chrono::high_resolution_clock::now(); this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}}); this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}}); @@ -121,15 +122,15 @@ void RL_Sim::runModel(const ros::TimerEvent &event) for (int i = 0; i < 12; ++i) { - torque_commands[i].mode = 0x0A; - // torque_commands[i].tau = torques[0][i].item(); - torque_commands[i].tau = 0; - torque_commands[i].q = target_dof_pos[0][i].item(); - torque_commands[i].dq = 0; - torque_commands[i].Kp = params.stiffness; - torque_commands[i].Kd = params.damping; + motor_commands[i].mode = 0x0A; + // motor_commands[i].tau = torques[0][i].item(); + motor_commands[i].tau = 0; + motor_commands[i].q = target_dof_pos[0][i].item(); + motor_commands[i].dq = 0; + motor_commands[i].Kp = params.stiffness; + motor_commands[i].Kd = params.damping; - torque_publishers[joint_names[i]].publish(torque_commands[i]); + torque_publishers[joint_names[i]].publish(motor_commands[i]); } } @@ -168,10 +169,19 @@ torch::Tensor RL_Sim::forward() return clamped; } +void signalHandler(int signum) +{ + ros::shutdown(); + exit(0); +} + int main(int argc, char **argv) { + signal(SIGINT, signalHandler); + ros::init(argc, argv, "rl_sar"); - RL_Sim rl_sar; + ros::spin(); + return 0; }