From 1d6ecb2771384e9a22fe483b94ed8f87b9d7bbd1 Mon Sep 17 00:00:00 2001 From: fan-ziqi Date: Fri, 24 May 2024 11:40:03 +0800 Subject: [PATCH 1/2] only sim --- src/rl_sar/CMakeLists.txt | 21 ++- src/rl_sar/config.yaml | 27 +++ src/rl_sar/include/rl_real_a1.hpp | 8 - src/rl_sar/include/rl_real_cyberdog.hpp | 28 --- src/rl_sar/include/rl_sim.hpp | 5 + src/rl_sar/launch/start_a1.launch | 2 + src/rl_sar/library/rl_sdk/rl_sdk.cpp | 218 +++++++++++++++++++++--- src/rl_sar/library/rl_sdk/rl_sdk.hpp | 87 ++++++++-- src/rl_sar/src/rl_real_cyberdog.cpp | 49 +----- src/rl_sar/src/rl_sim.cpp | 110 ++++++++---- 10 files changed, 397 insertions(+), 158 deletions(-) diff --git a/src/rl_sar/CMakeLists.txt b/src/rl_sar/CMakeLists.txt index 1a088e8..f2c93c4 100644 --- a/src/rl_sar/CMakeLists.txt +++ b/src/rl_sar/CMakeLists.txt @@ -75,15 +75,14 @@ target_link_libraries(rl_sim rl_sdk observation_buffer yaml-cpp ) -add_executable(rl_real_a1 src/rl_real_a1.cpp ) -target_link_libraries(rl_real_a1 - ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" - rl_sdk observation_buffer yaml-cpp -) +# add_executable(rl_real_a1 src/rl_real_a1.cpp ) +# target_link_libraries(rl_real_a1 +# ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" +# rl_sdk observation_buffer yaml-cpp +# ) - -add_executable(rl_real_cyberdog src/rl_real_cyberdog.cpp ) -target_link_libraries(rl_real_cyberdog - ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" - rl_sdk observation_buffer cyberdog_motor_sdk yaml-cpp -) \ No newline at end of file +# add_executable(rl_real_cyberdog src/rl_real_cyberdog.cpp ) +# target_link_libraries(rl_real_cyberdog +# ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" +# rl_sdk observation_buffer cyberdog_motor_sdk yaml-cpp +# ) \ No newline at end of file diff --git a/src/rl_sar/config.yaml b/src/rl_sar/config.yaml index f6fe237..19c85c4 100644 --- a/src/rl_sar/config.yaml +++ b/src/rl_sar/config.yaml @@ -107,3 +107,30 @@ lite3_wheel: "FR_hip_joint", "FR_thigh_joint", "FR_calf_joint", "RL_hip_joint", "RL_thigh_joint", "RL_calf_joint", "RR_hip_joint", "RR_thigh_joint", "RR_calf_joint"] + +gr1t1: + model_name: "model_4000_jit.pt" + num_observations: 39 + clip_obs: 100.0 + clip_actions: 100.0 + # damping: 0.5 + # stiffness: 20.0 + p_gains: [57.0, 43.0, 114.0, 114.0, 15.3, + 57.0, 43.0, 114.0, 114.0, 15.3] + d_gains: [5.7, 4.3, 11.4, 11.4, 1.5, + 5.7, 4.3, 11.4, 11.4, 1.5] + action_scale: 1.0 + hip_scale_reduction: 1.0 + hip_scale_reduction_indices: [] + num_of_dofs: 10 + lin_vel_scale: 2.0 + ang_vel_scale: 0.25 + dof_pos_scale: 1.0 + dof_vel_scale: 0.05 + commands_scale: [2.0, 2.0, 0.25] + torque_limits: [100.0, 100.0, 100.0, 100.0, 100.0, + 100.0, 100.0, 100.0, 100.0, 100.0] + default_dof_pos: [0.0, 0.0, -0.2618, 0.5236, -0.2618, + 0.0, 0.0, -0.2618, 0.5236, -0.2618] + joint_names: ["l_hip_roll_joint", "l_hip_yaw_joint", "l_hip_pitch_joint", "l_knee_pitch_joint", "l_ankle_pitch_joint", + "r_hip_roll_joint", "r_hip_yaw_joint", "r_hip_pitch_joint", "r_knee_pitch_joint", "r_ankle_pitch_joint"] \ No newline at end of file diff --git a/src/rl_sar/include/rl_real_a1.hpp b/src/rl_sar/include/rl_real_a1.hpp index 3df2625..d004a49 100644 --- a/src/rl_sar/include/rl_real_a1.hpp +++ b/src/rl_sar/include/rl_real_a1.hpp @@ -8,14 +8,6 @@ #include // #include -enum RobotState { - STATE_WAITING = 0, - STATE_POS_GETUP, - STATE_RL_INIT, - STATE_RL_RUNNING, - STATE_POS_GETDOWN, -}; - class RL_Real : public RL { public: diff --git a/src/rl_sar/include/rl_real_cyberdog.hpp b/src/rl_sar/include/rl_real_cyberdog.hpp index 3dd7458..31744fe 100644 --- a/src/rl_sar/include/rl_real_cyberdog.hpp +++ b/src/rl_sar/include/rl_real_cyberdog.hpp @@ -8,29 +8,10 @@ #include #include // #include -#include -#include -#include using CyberdogData = Robot_Data; using CyberdogCmd = Motor_Cmd; -enum RobotState { - STATE_WAITING = 0, - STATE_POS_GETUP, - STATE_RL_INIT, - STATE_RL_RUNNING, - STATE_POS_GETDOWN, -}; - -struct KeyBoard -{ - RobotState robot_state; - float x = 0; - float y = 0; - float yaw = 0; -}; - class RL_Real : public RL, public CustomInterface { public: @@ -56,20 +37,11 @@ public: std::shared_ptr loop_rl; std::shared_ptr loop_plot; - float getup_percent = 0.0; - float getdown_percent = 0.0; - float start_pos[12]; - float now_pos[12]; - - int robot_state = STATE_WAITING; - const int plot_size = 100; std::vector plot_t; std::vector> plot_real_joint_pos, plot_target_joint_pos; void Plot(); - void run_keyboard(); - KeyBoard keyboard; std::thread _keyboardThread; private: std::vector joint_names; diff --git a/src/rl_sar/include/rl_sim.hpp b/src/rl_sar/include/rl_sim.hpp index 0b691a0..8c4f579 100644 --- a/src/rl_sar/include/rl_sim.hpp +++ b/src/rl_sar/include/rl_sim.hpp @@ -29,6 +29,9 @@ public: torch::Tensor Forward() override; torch::Tensor ComputeObservation() override; + void GetState(RobotState *state) override; + void SetCommand(const RobotCommand *command) override; + ObservationBuffer history_obs_buf; torch::Tensor history_obs; @@ -42,6 +45,8 @@ public: std::vector plot_t; std::vector> plot_real_joint_pos, plot_target_joint_pos; void Plot(); + + std::thread _keyboardThread; private: std::string ros_namespace; diff --git a/src/rl_sar/launch/start_a1.launch b/src/rl_sar/launch/start_a1.launch index b5db455..6adb284 100644 --- a/src/rl_sar/launch/start_a1.launch +++ b/src/rl_sar/launch/start_a1.launch @@ -51,4 +51,6 @@ + + diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.cpp b/src/rl_sar/library/rl_sdk/rl_sdk.cpp index 1e2bcba..419b54b 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.cpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.cpp @@ -1,5 +1,8 @@ #include "rl_sdk.hpp" +#include +#include + template std::vector ReadVectorFromYaml(const YAML::Node& node) { @@ -26,26 +29,26 @@ void RL::ReadYaml(std::string robot_name) this->params.model_name = config["model_name"].as(); this->params.num_observations = config["num_observations"].as(); - this->params.clip_obs = config["clip_obs"].as(); - this->params.clip_actions = config["clip_actions"].as(); - this->params.action_scale = config["action_scale"].as(); - this->params.hip_scale_reduction = config["hip_scale_reduction"].as(); + this->params.clip_obs = config["clip_obs"].as(); + this->params.clip_actions = config["clip_actions"].as(); + this->params.action_scale = config["action_scale"].as(); + this->params.hip_scale_reduction = config["hip_scale_reduction"].as(); this->params.hip_scale_reduction_indices = ReadVectorFromYaml(config["hip_scale_reduction_indices"]); this->params.num_of_dofs = config["num_of_dofs"].as(); - this->params.lin_vel_scale = config["lin_vel_scale"].as(); - this->params.ang_vel_scale = config["ang_vel_scale"].as(); - this->params.dof_pos_scale = config["dof_pos_scale"].as(); - this->params.dof_vel_scale = config["dof_vel_scale"].as(); - // this->params.commands_scale = torch::tensor(ReadVectorFromYaml(config["commands_scale"])).view({1, -1}); + this->params.lin_vel_scale = config["lin_vel_scale"].as(); + this->params.ang_vel_scale = config["ang_vel_scale"].as(); + this->params.dof_pos_scale = config["dof_pos_scale"].as(); + this->params.dof_vel_scale = config["dof_vel_scale"].as(); + // this->params.commands_scale = torch::tensor(ReadVectorFromYaml(config["commands_scale"])).view({1, -1}); this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); - // this->params.damping = config["damping"].as(); - // this->params.stiffness = config["stiffness"].as(); + // this->params.damping = config["damping"].as(); + // this->params.stiffness = config["stiffness"].as(); // this->params.d_gains = torch::ones(12) * this->params.damping; // this->params.p_gains = torch::ones(12) * this->params.stiffness; - this->params.p_gains = torch::tensor(ReadVectorFromYaml(config["p_gains"])).view({1, -1}); - this->params.d_gains = torch::tensor(ReadVectorFromYaml(config["d_gains"])).view({1, -1}); - this->params.torque_limits = torch::tensor(ReadVectorFromYaml(config["torque_limits"])).view({1, -1}); - this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml(config["default_dof_pos"])).view({1, -1}); + this->params.p_gains = torch::tensor(ReadVectorFromYaml(config["p_gains"])).view({1, -1}); + this->params.d_gains = torch::tensor(ReadVectorFromYaml(config["d_gains"])).view({1, -1}); + this->params.torque_limits = torch::tensor(ReadVectorFromYaml(config["torque_limits"])).view({1, -1}); + this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml(config["default_dof_pos"])).view({1, -1}); this->params.joint_names = ReadVectorFromYaml(config["joint_names"]); } @@ -110,14 +113,22 @@ void RL::InitObservations() this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}}); this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}}); this->obs.dof_pos = this->params.default_dof_pos; - this->obs.dof_vel = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); - this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); + this->obs.dof_vel = torch::zeros({1, params.num_of_dofs}); + this->obs.actions = torch::zeros({1, params.num_of_dofs}); } void RL::InitOutputs() { - output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); - output_dof_pos = params.default_dof_pos; + this->output_torques = torch::zeros({1, params.num_of_dofs}); + this->output_dof_pos = params.default_dof_pos; +} + +void RL::InitKeyboard() +{ + this->keyboard.keyboard_state = STATE_WAITING; + this->keyboard.x = 0.0; + this->keyboard.y = 0.0; + this->keyboard.yaw = 0.0; } torch::Tensor RL::ComputeTorques(torch::Tensor actions) @@ -169,3 +180,172 @@ torch::Tensor RL::Forward() return clamped; } */ + +static bool kbhit() +{ + termios term; + tcgetattr(0, &term); + + termios term2 = term; + term2.c_lflag &= ~ICANON; + tcsetattr(0, TCSANOW, &term2); + + int byteswaiting; + ioctl(0, FIONREAD, &byteswaiting); + + tcsetattr(0, TCSANOW, &term); + + return byteswaiting > 0; +} + +void RL::run_keyboard() +{ + int c; + // Check for keyboard input + while(true) + { + if(kbhit()) + { + c = fgetc(stdin); + switch(c) + { + case '0': keyboard.keyboard_state = STATE_POS_GETUP; break; + case 'p': keyboard.keyboard_state = STATE_RL_INIT; break; + case '1': keyboard.keyboard_state = STATE_POS_GETDOWN; break; + case 'q': break; + case 'w': keyboard.x += 0.1; break; + case 's': keyboard.x -= 0.1; break; + case 'a': keyboard.yaw += 0.1; break; + case 'd': keyboard.yaw -= 0.1; break; + case 'i': break; + case 'k': break; + case 'j': keyboard.y += 0.1; break; + case 'l': keyboard.y -= 0.1; break; + case ' ': keyboard.x = 0; keyboard.y = 0; keyboard.yaw = 0; break; + default: break; + } + } + usleep(10000); + } +} + +void RL::StateController(const RobotState *state, RobotCommand *command) +{ + // waiting + if(running_state == STATE_WAITING) + { + for(int i = 0; i < params.num_of_dofs; ++i) + { + command->motor_command.q[i] = state->motor_state.q[i]; + } + if(keyboard.keyboard_state == STATE_POS_GETUP) + { + keyboard.keyboard_state = STATE_WAITING; + getup_percent = 0.0; + for(int i = 0; i < params.num_of_dofs; ++i) + { + now_pos[i] = state->motor_state.q[i]; + start_pos[i] = now_pos[i]; + } + running_state = STATE_POS_GETUP; + } + } + // stand up (position control) + else if(running_state == STATE_POS_GETUP) + { + if(getup_percent != 1) + { + getup_percent += 1 / 1000.0; + getup_percent = getup_percent > 1 ? 1 : getup_percent; + for(int i = 0; i < params.num_of_dofs; ++i) + { + command->motor_command.q[i] = (1 - getup_percent) * now_pos[i] + getup_percent * params.default_dof_pos[0][i].item(); + command->motor_command.dq[i] = 0; + command->motor_command.kp[i] = 200; + command->motor_command.kd[i] = 10; + command->motor_command.tau[i] = 0; + } + printf("getting up %.3f%%\r", getup_percent*100.0); + } + if(keyboard.keyboard_state == STATE_RL_INIT) + { + keyboard.keyboard_state = STATE_WAITING; + running_state = STATE_RL_INIT; + } + else if(keyboard.keyboard_state == STATE_POS_GETDOWN) + { + keyboard.keyboard_state = STATE_WAITING; + getdown_percent = 0.0; + for(int i = 0; i < params.num_of_dofs; ++i) + { + now_pos[i] = state->motor_state.q[i]; + } + running_state = STATE_POS_GETDOWN; + } + } + // init obs and start rl loop + else if(running_state == STATE_RL_INIT) + { + if(getup_percent == 1) + { + running_state = STATE_RL_RUNNING; + this->InitObservations(); + this->InitOutputs(); + this->InitKeyboard(); + // printf("\nstart rl loop\n"); + // loop_rl->start(); + } + } + // rl loop + else if(running_state == STATE_RL_RUNNING) + { + for(int i = 0; i < params.num_of_dofs; ++i) + { + command->motor_command.q[i] = output_dof_pos[0][i].item(); + command->motor_command.dq[i] = 0; + // command->motor_command.kp[i] = params.stiffness; + // command->motor_command.kd[i] = params.damping; + command->motor_command.kp[i] = params.p_gains[0][i].item(); + command->motor_command.kd[i] = params.d_gains[0][i].item(); + // command->motor_command.tau[i] = output_torques[0][i].item(); + command->motor_command.tau[i] = 0; + } + if(keyboard.keyboard_state == STATE_POS_GETDOWN) + { + keyboard.keyboard_state = STATE_WAITING; + getdown_percent = 0.0; + for(int i = 0; i < params.num_of_dofs; ++i) + { + now_pos[i] = state->motor_state.q[i]; + } + running_state = STATE_POS_GETDOWN; + } + } + // get down (position control) + else if(running_state == STATE_POS_GETDOWN) + { + if(getdown_percent != 1) + { + getdown_percent += 1 / 1000.0; + getdown_percent = getdown_percent > 1 ? 1 : getdown_percent; + for(int i = 0; i < params.num_of_dofs; ++i) + { + command->motor_command.q[i] = (1 - getdown_percent) * now_pos[i] + getdown_percent * start_pos[i]; + command->motor_command.dq[i] = 0; + command->motor_command.kp[i] = 200; + command->motor_command.kd[i] = 10; + command->motor_command.tau[i] = 0; + } + printf("getting down %.3f%%\r", getdown_percent*100.0); + } + if(getdown_percent == 1) + { + running_state = STATE_WAITING; + this->InitObservations(); + this->InitOutputs(); + this->InitKeyboard(); + // printf("\nstop rl loop\n"); + // loop_rl->shutdown(); + } + } +} \ No newline at end of file diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.hpp b/src/rl_sar/library/rl_sdk/rl_sdk.hpp index a0326d8..6ab4c60 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.hpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.hpp @@ -11,22 +11,71 @@ namespace plt = matplotlibcpp; #include #define CONFIG_PATH CMAKE_CURRENT_SOURCE_DIR "/config.yaml" +template +struct RobotCommand +{ + struct MotorCommand + { + std::vector q = std::vector(32, 0.0); + std::vector dq = std::vector(32, 0.0); + std::vector tau = std::vector(32, 0.0); + std::vector kp = std::vector(32, 0.0); + std::vector kd = std::vector(32, 0.0); + } motor_command; +}; + +template +struct RobotState +{ + struct IMU + { + T quaternion[4] = {1.0, 0.0, 0.0, 0.0}; // w, x, y, z + T gyroscope[3] = {0.0, 0.0, 0.0}; + T accelerometer[3] = {0.0, 0.0, 0.0}; + } imu; + + struct MotorState + { + std::vector q = std::vector(32, 0.0); + std::vector dq = std::vector(32, 0.0); + std::vector ddq = std::vector(32, 0.0); + std::vector tauEst = std::vector(32, 0.0); + std::vector cur = std::vector(32, 0.0); + } motor_state; +}; + +enum STATE { + STATE_WAITING = 0, + STATE_POS_GETUP, + STATE_RL_INIT, + STATE_RL_RUNNING, + STATE_POS_GETDOWN, +}; + +struct KeyBoard +{ + STATE keyboard_state; + double x = 0.0; + double y = 0.0; + double yaw = 0.0; +}; + struct ModelParams { std::string model_name; int num_observations; - float damping; - float stiffness; - float action_scale; - float hip_scale_reduction; + double damping; + double stiffness; + double action_scale; + double hip_scale_reduction; std::vector hip_scale_reduction_indices; int num_of_dofs; - float lin_vel_scale; - float ang_vel_scale; - float dof_pos_scale; - float dof_vel_scale; - float clip_obs; - float clip_actions; + double lin_vel_scale; + double ang_vel_scale; + double dof_pos_scale; + double dof_vel_scale; + double clip_obs; + double clip_actions; torch::Tensor torque_limits; torch::Tensor d_gains; torch::Tensor p_gains; @@ -62,10 +111,26 @@ public: torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v); void InitObservations(); void InitOutputs(); + void InitKeyboard(); void ReadYaml(std::string robot_name); std::string csv_filename; void CSVInit(std::string robot_name); void CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel); + void run_keyboard(); + + float getup_percent = 0.0; + float getdown_percent = 0.0; + std::vector start_pos; + std::vector now_pos; + + int running_state = STATE_WAITING; + + RobotState robot_state; + RobotCommand robot_command; + + virtual void GetState(RobotState *state) = 0; + virtual void SetCommand(const RobotCommand *command) = 0; + void StateController(const RobotState *state, RobotCommand *command); protected: // rl module @@ -82,6 +147,8 @@ protected: // output buffer torch::Tensor output_torques; torch::Tensor output_dof_pos; + // keyboard + KeyBoard keyboard; }; #endif // RL_SDK_HPP \ No newline at end of file diff --git a/src/rl_sar/src/rl_real_cyberdog.cpp b/src/rl_sar/src/rl_real_cyberdog.cpp index 407e8e9..f218e23 100644 --- a/src/rl_sar/src/rl_real_cyberdog.cpp +++ b/src/rl_sar/src/rl_real_cyberdog.cpp @@ -3,7 +3,7 @@ #define ROBOT_NAME "cyberdog1" // #define PLOT -#define CSV_LOGGER +// #define CSV_LOGGER RL_Real rl_sar; @@ -202,53 +202,6 @@ void RL_Real::UserCode() motor_cmd = cyberdogCmd; } -static bool kbhit() -{ - termios term; - tcgetattr(0, &term); - - termios term2 = term; - term2.c_lflag &= ~ICANON; - tcsetattr(0, TCSANOW, &term2); - - int byteswaiting; - ioctl(0, FIONREAD, &byteswaiting); - - tcsetattr(0, TCSANOW, &term); - - return byteswaiting > 0; -} -void RL_Real::run_keyboard() -{ - int c; - // Check for keyboard input - while(true) - { - if(kbhit()) - { - c = fgetc(stdin); - switch(c) - { - case '0': keyboard.robot_state = STATE_POS_GETUP; break; - case 'p': keyboard.robot_state = STATE_RL_INIT; break; - case '1': keyboard.robot_state = STATE_POS_GETDOWN; break; - case 'q': break; - case 'w': keyboard.x += 0.5; break; - case 's': keyboard.x -= 0.5; break; - case 'a': keyboard.yaw += 0.5; break; - case 'd': keyboard.yaw -= 0.5; break; - case 'i': break; - case 'k': break; - case 'j': keyboard.y += 0.5; break; - case 'l': keyboard.y -= 0.5; break; - case ' ': keyboard.x = 0; keyboard.y = 0; keyboard.yaw = 0; break; - default: break; - } - } - usleep(10000); - } -} - void RL_Real::RunModel() { if(robot_state == STATE_RL_RUNNING) diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 95aba5c..003bd82 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -1,10 +1,11 @@ #include "../include/rl_sim.hpp" -#define ROBOT_NAME "a1" +// #define ROBOT_NAME "a1" +#define ROBOT_NAME "gr1t1" // #define PLOT // #define CSV_LOGGER -#define USE_HISTORY +// #define USE_HISTORY RL_Sim::RL_Sim() { @@ -27,12 +28,15 @@ RL_Sim::RL_Sim() cmd_vel = geometry_msgs::Twist(); motor_commands.resize(params.num_of_dofs); + start_pos.resize(params.num_of_dofs); + now_pos.resize(params.num_of_dofs); std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/" + this->params.model_name; this->model = torch::jit::load(model_path); this->InitObservations(); this->InitOutputs(); + this->InitKeyboard(); #ifdef USE_HISTORY this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); @@ -73,6 +77,7 @@ RL_Sim::RL_Sim() loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this)); loop_plot->start(); #endif + _keyboardThread = std::thread(&RL_Sim::run_keyboard, this); #ifdef CSV_LOGGER CSVInit(ROBOT_NAME); @@ -89,24 +94,57 @@ RL_Sim::~RL_Sim() printf("exit\n"); } -void RL_Sim::RobotControl() +void RL_Sim::GetState(RobotState *state) { - motiontime++; - for (int i = 0; i < params.num_of_dofs; ++i) - { - motor_commands[i].q = output_dof_pos[0][i].item(); - motor_commands[i].dq = 0; - // motor_commands[i].Kp = params.stiffness; - // motor_commands[i].Kd = params.damping; - motor_commands[i].kp = params.p_gains[0][i].item(); - motor_commands[i].kd = params.d_gains[0][i].item(); - // motor_commands[i].tau = output_torques[0][i].item(); - motor_commands[i].tau = 0; + state->imu.quaternion[0] = pose.orientation.w; + state->imu.quaternion[1] = pose.orientation.x; + state->imu.quaternion[2] = pose.orientation.y; + state->imu.quaternion[3] = pose.orientation.z; + state->imu.gyroscope[0] = vel.angular.x; + state->imu.gyroscope[1] = vel.angular.y; + state->imu.gyroscope[2] = vel.angular.z; + + // state->imu.accelerometer + + for(int i = 0; i < params.num_of_dofs; ++i) + { + state->motor_state.q[i] = joint_positions[i]; + state->motor_state.dq[i] = joint_velocities[i]; + state->motor_state.tauEst[i] = joint_efforts[i]; + } +} + +void RL_Sim::SetCommand(const RobotCommand *command) +{ + for(int i = 0; i < params.num_of_dofs; ++i) + { + motor_commands[i].q = command->motor_command.q[i]; + motor_commands[i].dq = command->motor_command.dq[i]; + motor_commands[i].kp = command->motor_command.kp[i]; + motor_commands[i].kd = command->motor_command.kd[i]; + motor_commands[i].tau = command->motor_command.tau[i]; + } + + for(int i = 0; i < params.num_of_dofs; ++i) + { torque_publishers[params.joint_names[i]].publish(motor_commands[i]); } } +void RL_Sim::RobotControl() +{ + std::cout << "running_state " << keyboard.keyboard_state + << " x" << keyboard.x << " y" << keyboard.y << " yaw" << keyboard.yaw + << " \r"; + + motiontime++; + + GetState(&robot_state); + StateController(&robot_state, &robot_command); + SetCommand(&robot_command); +} + void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg) { vel = msg->twist[2]; @@ -135,27 +173,31 @@ void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg) void RL_Sim::RunModel() { - // this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}}); - this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}}); - this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}}); - this->obs.base_quat = torch::tensor({{pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w}}); - this->obs.dof_pos = torch::tensor(joint_positions).unsqueeze(0); - this->obs.dof_vel = torch::tensor(joint_velocities).unsqueeze(0); - - torch::Tensor actions = this->Forward(); - - for (int i : this->params.hip_scale_reduction_indices) + if(running_state == STATE_RL_RUNNING) { - actions[0][i] *= this->params.hip_scale_reduction; - } + // this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}}); + this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}}); + // this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}}); + this->obs.commands = torch::tensor({{keyboard.x, keyboard.y, keyboard.yaw}}); + this->obs.base_quat = torch::tensor({{pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w}}); + this->obs.dof_pos = torch::tensor(joint_positions).unsqueeze(0); + this->obs.dof_vel = torch::tensor(joint_velocities).unsqueeze(0); - output_torques = this->ComputeTorques(actions); - output_dof_pos = this->ComputePosition(actions); + torch::Tensor clamped_actions = this->Forward(); + + for (int i : this->params.hip_scale_reduction_indices) + { + clamped_actions[0][i] *= this->params.hip_scale_reduction; + } + + output_torques = this->ComputeTorques(clamped_actions); + output_dof_pos = this->ComputePosition(clamped_actions); #ifdef CSV_LOGGER - torch::Tensor tau_est = torch::tensor(joint_efforts).unsqueeze(0); - CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel); + torch::Tensor tau_est = torch::tensor(joint_efforts).unsqueeze(0); + CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel); #endif + } } torch::Tensor RL_Sim::ComputeObservation() @@ -181,13 +223,13 @@ torch::Tensor RL_Sim::Forward() history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); torch::Tensor action = this->model.forward({history_obs}).toTensor(); #else - torch::Tensor action = this->model.forward({obs}).toTensor(); + torch::Tensor actions = this->model.forward({obs}).toTensor(); #endif - this->obs.actions = action; - torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions); + this->obs.actions = actions; + torch::Tensor clamped_actions = torch::clamp(actions, -this->params.clip_actions, this->params.clip_actions); - return clamped; + return clamped_actions; } void RL_Sim::Plot() From 75a943e9d8b4f409c556b070f99c7fef198d11a6 Mon Sep 17 00:00:00 2001 From: fan-ziqi Date: Fri, 24 May 2024 12:03:27 +0800 Subject: [PATCH 2/2] move matplotlib --- .gitignore | 7 ++++++- src/rl_sar/include/rl_real_a1.hpp | 3 +++ src/rl_sar/include/rl_real_cyberdog.hpp | 3 +++ src/rl_sar/include/rl_sim.hpp | 3 +++ src/rl_sar/library/rl_sdk/rl_sdk.hpp | 4 +--- 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 16071d3..efc853a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,9 @@ devel logs .catkin_tools .vscode -*.csv \ No newline at end of file +*.csv +*lite3* +*fldlar* +.cache +.json +# *gr1t1* \ No newline at end of file diff --git a/src/rl_sar/include/rl_real_a1.hpp b/src/rl_sar/include/rl_real_a1.hpp index d004a49..5390148 100644 --- a/src/rl_sar/include/rl_real_a1.hpp +++ b/src/rl_sar/include/rl_real_a1.hpp @@ -8,6 +8,9 @@ #include // #include +#include "matplotlibcpp.h" +namespace plt = matplotlibcpp; + class RL_Real : public RL { public: diff --git a/src/rl_sar/include/rl_real_cyberdog.hpp b/src/rl_sar/include/rl_real_cyberdog.hpp index 31744fe..0d3b0a6 100644 --- a/src/rl_sar/include/rl_real_cyberdog.hpp +++ b/src/rl_sar/include/rl_real_cyberdog.hpp @@ -9,6 +9,9 @@ #include // #include +#include "matplotlibcpp.h" +namespace plt = matplotlibcpp; + using CyberdogData = Robot_Data; using CyberdogCmd = Motor_Cmd; diff --git a/src/rl_sar/include/rl_sim.hpp b/src/rl_sar/include/rl_sim.hpp index 8c4f579..ed936d7 100644 --- a/src/rl_sar/include/rl_sim.hpp +++ b/src/rl_sar/include/rl_sim.hpp @@ -14,6 +14,9 @@ // #include "robot_msgs/RobotState.h" // #include "robot_msgs/RobotCommand.h" +#include "matplotlibcpp.h" +namespace plt = matplotlibcpp; + class RL_Sim : public RL { public: diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.hpp b/src/rl_sar/library/rl_sdk/rl_sdk.hpp index 6ab4c60..6b91ed7 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.hpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.hpp @@ -4,9 +4,7 @@ #include #include #include - -#include "matplotlibcpp.h" -namespace plt = matplotlibcpp; +#include #include #define CONFIG_PATH CMAKE_CURRENT_SOURCE_DIR "/config.yaml"