diff --git a/src/rl_sar/CMakeLists.txt b/src/rl_sar/CMakeLists.txt index 141f094..1bb6ea1 100644 --- a/src/rl_sar/CMakeLists.txt +++ b/src/rl_sar/CMakeLists.txt @@ -1,15 +1,17 @@ cmake_minimum_required(VERSION 3.0.2) project(rl_sar) +add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}") + set(CMAKE_BUILD_TYPE Debug) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") - -add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}") - find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}") +find_package(gazebo REQUIRED) + find_package(catkin REQUIRED COMPONENTS controller_manager genmsg @@ -23,7 +25,6 @@ find_package(catkin REQUIRED COMPONENTS unitree_legged_msgs ) -find_package(gazebo REQUIRED) find_package(Python3 COMPONENTS Interpreter Development REQUIRED) catkin_package( @@ -43,8 +44,6 @@ include_directories( library/matplotlibcpp ) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}") - add_library(rl library/rl/rl.cpp library/rl/rl.hpp) target_link_libraries(rl "${TORCH_LIBRARIES}" Python3::Python Python3::Module) set_property(TARGET rl PROPERTY CXX_STANDARD 14) diff --git a/src/rl_sar/include/rl_real.hpp b/src/rl_sar/include/rl_real.hpp index 93351c4..f73259a 100644 --- a/src/rl_sar/include/rl_real.hpp +++ b/src/rl_sar/include/rl_real.hpp @@ -14,9 +14,9 @@ using namespace UNITREE_LEGGED_SDK; -enum InitState { +enum RobotState { STATE_WAITING = 0, - STATE_POS_INIT, + STATE_POS_START, STATE_RL_INIT, STATE_RL_START, STATE_POS_STOP, @@ -28,23 +28,23 @@ public: RL_Real(); ~RL_Real(); - void runModel(); - torch::Tensor forward() override; - torch::Tensor compute_observation() override; + void RunModel(); + torch::Tensor Forward() override; + torch::Tensor ComputeObservation() override; ObservationBuffer history_obs_buf; torch::Tensor history_obs; + int motiontime = 0; //udp - void UDPSend(); - void UDPRecv(); + void UDPSend(){udp.Send();} + void UDPRecv(){udp.Recv();} void RobotControl(); Safety safe; UDP udp; LowCmd cmd = {0}; LowState state = {0}; xRockerBtnDataStruct _keyData; - int motiontime = 0; std::shared_ptr loop_control; std::shared_ptr loop_udpSend; @@ -52,14 +52,16 @@ public: std::shared_ptr loop_rl; std::shared_ptr loop_plot; - float _percent; - float _startPos[12]; + float start_percent = 0.0; + float stop_percent = 0.0; + float start_pos[12]; + float stop_pos[12]; int robot_state = STATE_WAITING; - std::vector _t; - std::vector> _real_joint_pos, _target_joint_pos; - void plot(); + std::vector plot_t; + std::vector> plot_real_joint_pos, plot_target_joint_pos; + void Plot(); private: std::vector joint_names; std::vector joint_positions; diff --git a/src/rl_sar/include/rl_sim.hpp b/src/rl_sar/include/rl_sim.hpp index d765650..7a9fa8c 100644 --- a/src/rl_sar/include/rl_sim.hpp +++ b/src/rl_sar/include/rl_sim.hpp @@ -19,14 +19,14 @@ public: RL_Sim(); ~RL_Sim(); - void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); - void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); - void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); + void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); + void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); + void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); - void runModel(); + void RunModel(); void RobotControl(); - torch::Tensor forward() override; - torch::Tensor compute_observation() override; + torch::Tensor Forward() override; + torch::Tensor ComputeObservation() override; ObservationBuffer history_obs_buf; torch::Tensor history_obs; @@ -37,9 +37,9 @@ public: std::shared_ptr loop_rl; std::shared_ptr loop_plot; - std::vector _t; - std::vector> _real_joint_pos, _target_joint_pos; - void plot(); + std::vector plot_t; + std::vector> plot_real_joint_pos, plot_target_joint_pos; + void Plot(); private: std::vector torque_command_topics; @@ -58,8 +58,6 @@ private: std::vector joint_positions; std::vector joint_velocities; - torch::Tensor torques; - std::chrono::high_resolution_clock::time_point start_time; // other rl module diff --git a/src/rl_sar/library/rl/rl.cpp b/src/rl_sar/library/rl/rl.cpp index d1f6d03..26fa4d2 100644 --- a/src/rl_sar/library/rl/rl.cpp +++ b/src/rl_sar/library/rl/rl.cpp @@ -1,6 +1,6 @@ #include "rl.hpp" -torch::Tensor RL::quat_rotate_inverse(torch::Tensor q, torch::Tensor v) +torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v) { c10::IntArrayRef shape = q.sizes(); torch::Tensor q_w = q.index({torch::indexing::Slice(), -1}); @@ -11,7 +11,7 @@ torch::Tensor RL::quat_rotate_inverse(torch::Tensor q, torch::Tensor v) return a - b + c; } -void RL::init_observations() +void RL::InitObservations() { this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}}); this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}}); @@ -23,7 +23,7 @@ void RL::init_observations() this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); } -torch::Tensor RL::compute_torques(torch::Tensor actions) +torch::Tensor RL::ComputeTorques(torch::Tensor actions) { torch::Tensor actions_scaled = actions * this->params.action_scale; int indices[] = {0, 3, 6, 9}; @@ -32,12 +32,12 @@ torch::Tensor RL::compute_torques(torch::Tensor actions) actions_scaled[0][i] *= this->params.hip_scale_reduction; } - torch::Tensor torques = this->params.p_gains * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.d_gains * this->obs.dof_vel; - torch::Tensor clamped = torch::clamp(torques, -(this->params.torque_limits), this->params.torque_limits); + torch::Tensor output_torques = this->params.p_gains * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.d_gains * this->obs.dof_vel; + torch::Tensor clamped = torch::clamp(output_torques, -(this->params.torque_limits), this->params.torque_limits); return clamped; } -torch::Tensor RL::compute_pos(torch::Tensor actions) +torch::Tensor RL::ComputePosition(torch::Tensor actions) { torch::Tensor actions_scaled = actions * this->params.action_scale; int indices[] = {0, 3, 6, 9}; @@ -49,12 +49,12 @@ torch::Tensor RL::compute_pos(torch::Tensor actions) return actions_scaled + this->params.default_dof_pos; } -/* You may need to override this compute_observation() function -torch::Tensor RL::compute_observation() +/* You may need to override this ComputeObservation() function +torch::Tensor RL::ComputeObservation() { - torch::Tensor obs = torch::cat({(this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, - (this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale, - this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec), + torch::Tensor obs = torch::cat({(this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, + (this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale, + this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), this->obs.commands * this->params.commands_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, this->obs.dof_vel * this->params.dof_vel_scale, @@ -69,10 +69,10 @@ torch::Tensor RL::compute_observation() } */ -/* You may need to override this forward() function -torch::Tensor RL::forward() +/* You may need to override this Forward() function +torch::Tensor RL::Forward() { - torch::Tensor obs = this->compute_observation(); + torch::Tensor obs = this->ComputeObservation(); torch::Tensor actor_input = torch::cat({obs}, 1); diff --git a/src/rl_sar/library/rl/rl.hpp b/src/rl_sar/library/rl/rl.hpp index ffe2d58..eb6790c 100644 --- a/src/rl_sar/library/rl/rl.hpp +++ b/src/rl_sar/library/rl/rl.hpp @@ -49,12 +49,12 @@ public: ModelParams params; Observations obs; - virtual torch::Tensor forward() = 0; - virtual torch::Tensor compute_observation() = 0; - torch::Tensor compute_torques(torch::Tensor actions); - torch::Tensor compute_pos(torch::Tensor actions); - torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v); - void init_observations(); + virtual torch::Tensor Forward() = 0; + virtual torch::Tensor ComputeObservation() = 0; + torch::Tensor ComputeTorques(torch::Tensor actions); + torch::Tensor ComputePosition(torch::Tensor actions); + torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v); + void InitObservations(); protected: // rl module @@ -69,8 +69,8 @@ protected: torch::Tensor dof_vel; torch::Tensor actions; // output buffer - torch::Tensor torques; - torch::Tensor target_dof_pos; + torch::Tensor output_torques; + torch::Tensor output_dof_pos; }; #endif // RL_HPP \ No newline at end of file diff --git a/src/rl_sar/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/comm.h b/src/rl_sar/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/comm.h index 975c4ee..9bc2347 100644 --- a/src/rl_sar/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/comm.h +++ b/src/rl_sar/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/comm.h @@ -131,7 +131,7 @@ namespace UNITREE_LEGGED_SDK uint32_t SN; uint8_t bandWidth; uint8_t mode; // 0:idle, default stand 1:forced stand 2:walk continuously - float forwardSpeed; // speed of move forward or backward, scale: -1~1 + float forwardSpeed; // speed of move Forward or backward, scale: -1~1 float sideSpeed; // speed of move left or right, scale: -1~1 float rotateSpeed; // speed of spin left or right, scale: -1~1 float bodyHeight; // body height, scale: -1~1 diff --git a/src/rl_sar/src/rl_real.cpp b/src/rl_sar/src/rl_real.cpp index 1c9a467..c6e2841 100644 --- a/src/rl_sar/src/rl_real.cpp +++ b/src/rl_sar/src/rl_real.cpp @@ -1,20 +1,9 @@ #include "../include/rl_real.hpp" -// #define CONTROL_BY_TORQUE // #define PLOT RL_Real rl_sar; -void RL_Real::UDPRecv() -{ - udp.Recv(); -} - -void RL_Real::UDPSend() -{ - udp.Send(); -} - void RL_Real::RobotControl() { motiontime++; @@ -23,14 +12,28 @@ void RL_Real::RobotControl() memcpy(&_keyData, state.wirelessRemote, 40); // get joy button - if(robot_state < STATE_POS_INIT && (int)_keyData.btn.components.R2 == 1) + if(robot_state < STATE_POS_START && (int)_keyData.btn.components.R2 == 1) { - robot_state = STATE_POS_INIT; + start_percent = 0.0; + for(int i = 0; i < 12; ++i) + { + start_pos[i] = state.motorState[i].q; + } + robot_state = STATE_POS_START; } else if(robot_state < STATE_RL_INIT && (int)_keyData.btn.components.R1 == 1) { robot_state = STATE_RL_INIT; } + else if(robot_state == STATE_RL_START && (int)_keyData.btn.components.L2 == 1) + { + stop_percent = 0.0; + for(int i = 0; i < 12; ++i) + { + stop_pos[i] = state.motorState[i].q; + } + robot_state = STATE_POS_STOP; + } // wait for standup if(robot_state == STATE_WAITING) @@ -38,57 +41,68 @@ void RL_Real::RobotControl() for(int i = 0; i < 12; ++i) { cmd.motorCmd[i].q = state.motorState[i].q; - _startPos[i] = state.motorState[i].q; } } // standup (position control) - else if(robot_state == STATE_POS_INIT && _percent != 1) + else if(robot_state == STATE_POS_START && start_percent != 1) { - _percent += 1 / 1000.0; - _percent = _percent > 1 ? 1 : _percent; + start_percent += 1 / 1000.0; + start_percent = start_percent > 1 ? 1 : start_percent; for(int i = 0; i < 12; ++i) { cmd.motorCmd[i].mode = 0x0A; - cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * params.default_dof_pos[0][dof_mapping[i]].item(); + cmd.motorCmd[i].q = (1 - start_percent) * start_pos[i] + start_percent * params.default_dof_pos[0][dof_mapping[i]].item(); cmd.motorCmd[i].dq = 0; cmd.motorCmd[i].Kp = 50; cmd.motorCmd[i].Kd = 3; cmd.motorCmd[i].tau = 0; } - printf("initing %.3f%%\r", _percent*100.0); + printf("starting %.3f%%\r", start_percent*100.0); } // init obs and start rl loop - else if(robot_state == STATE_RL_INIT && _percent == 1) + else if(robot_state == STATE_RL_INIT && start_percent == 1) { robot_state = STATE_RL_START; - this->init_observations(); + this->InitObservations(); printf("\nstart rl loop\n"); loop_rl->start(); } // rl loop else if(robot_state == STATE_RL_START) { -#ifdef CONTROL_BY_TORQUE - for (int i = 0; i < 12; ++i) - { - cmd.motorCmd[i].mode = 0x0A; - cmd.motorCmd[i].q = 0; - cmd.motorCmd[i].dq = 0; - cmd.motorCmd[i].Kp = 0; - cmd.motorCmd[i].Kd = 0; - cmd.motorCmd[i].tau = torques[0][dof_mapping[i]].item(); - } -#else - for (int i = 0; i < 12; ++i) - { - cmd.motorCmd[i].mode = 0x0A; - cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item(); - cmd.motorCmd[i].dq = 0; - cmd.motorCmd[i].Kp = params.stiffness; - cmd.motorCmd[i].Kd = params.damping; - cmd.motorCmd[i].tau = 0; - } -#endif + for (int i = 0; i < 12; ++i) + { + cmd.motorCmd[i].mode = 0x0A; + cmd.motorCmd[i].q = output_dof_pos[0][dof_mapping[i]].item(); + cmd.motorCmd[i].dq = 0; + cmd.motorCmd[i].Kp = params.stiffness; + cmd.motorCmd[i].Kd = params.damping; + // cmd.motorCmd[i].tau = output_torques[0][dof_mapping[i]].item(); + cmd.motorCmd[i].tau = 0; + } + } + // move to start pos + else if(robot_state == STATE_POS_STOP && stop_percent != 1) + { + stop_percent += 1 / 1000.0; + stop_percent = stop_percent > 1 ? 1 : stop_percent; + for(int i = 0; i < 12; ++i) + { + cmd.motorCmd[i].mode = 0x0A; + cmd.motorCmd[i].q = (1 - stop_percent) * stop_pos[i] + stop_percent * start_pos[i]; + cmd.motorCmd[i].dq = 0; + cmd.motorCmd[i].Kp = 50; + cmd.motorCmd[i].Kd = 3; + cmd.motorCmd[i].tau = 0; + } + printf("stopping %.3f%%\r", stop_percent*100.0); + } + else if(robot_state == STATE_POS_STOP && stop_percent == 1) + { + robot_state = STATE_WAITING; + this->InitObservations(); + printf("\nstop rl loop\n"); + loop_rl->shutdown(); } safe.PowerProtect(cmd, state, 7); @@ -108,7 +122,7 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) this->actor = torch::jit::load(actor_path); this->encoder = torch::jit::load(encoder_path); this->vq = torch::jit::load(vq_path); - this->init_observations(); + this->InitObservations(); this->params.num_observations = 45; this->params.clip_obs = 100.0; @@ -139,22 +153,22 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); - torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); - target_dof_pos = params.default_dof_pos; - _real_joint_pos.resize(12); - _target_joint_pos.resize(12); + output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); + output_dof_pos = params.default_dof_pos; + plot_real_joint_pos.resize(12); + plot_target_joint_pos.resize(12); loop_control = std::make_shared("loop_control", 0.002, boost::bind(&RL_Real::RobotControl, this)); loop_udpSend = std::make_shared("loop_udpSend", 0.002, 3, boost::bind(&RL_Real::UDPSend, this)); loop_udpRecv = std::make_shared("loop_udpRecv", 0.002, 3, boost::bind(&RL_Real::UDPRecv, this)); - loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Real::runModel, this)); + loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel, this)); loop_udpSend->start(); loop_udpRecv->start(); loop_control->start(); #ifdef PLOT - loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Real::plot, this)); + loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this)); loop_plot->start(); #endif } @@ -171,25 +185,25 @@ RL_Real::~RL_Real() printf("exit\n"); } -void RL_Real::plot() +void RL_Real::Plot() { - _t.push_back(motiontime); + plot_t.push_back(motiontime); plt::cla(); plt::clf(); for(int i = 0; i < 12; ++i) { - _real_joint_pos[i].push_back(state.motorState[i].q); - _target_joint_pos[i].push_back(cmd.motorCmd[i].q); + plot_real_joint_pos[i].push_back(state.motorState[i].q); + plot_target_joint_pos[i].push_back(cmd.motorCmd[i].q); plt::subplot(4, 3, i+1); - plt::named_plot("_real_joint_pos", _t, _real_joint_pos[i], "r"); - plt::named_plot("_target_joint_pos", _t, _target_joint_pos[i], "b"); + plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); + plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); plt::xlim(motiontime-10000, motiontime); } // plt::legend(); plt::pause(0.0001); } -void RL_Real::runModel() +void RL_Real::RunModel() { if(robot_state == STATE_RL_START) { @@ -224,21 +238,19 @@ void RL_Real::runModel() state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq}}); - torch::Tensor actions = this->forward(); -#ifdef CONTROL_BY_TORQUE - torques = this->compute_torques(actions); -#else - target_dof_pos = this->compute_pos(actions); -#endif + torch::Tensor actions = this->Forward(); + + output_torques = this->ComputeTorques(actions); + output_dof_pos = this->ComputePosition(actions); } } -torch::Tensor RL_Real::compute_observation() +torch::Tensor RL_Real::ComputeObservation() { - torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, - this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, - this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec), + torch::Tensor obs = torch::cat({// (this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, + this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, + this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), this->obs.commands * this->params.commands_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, this->obs.dof_vel * this->params.dof_vel_scale, @@ -248,9 +260,9 @@ torch::Tensor RL_Real::compute_observation() return obs; } -torch::Tensor RL_Real::forward() +torch::Tensor RL_Real::Forward() { - torch::Tensor obs = this->compute_observation(); + torch::Tensor obs = this->ComputeObservation(); history_obs_buf.insert(obs); history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 0d7a258..1258372 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -8,31 +8,30 @@ void RL_Sim::RobotControl() for (int i = 0; i < 12; ++i) { motor_commands[i].mode = 0x0A; - // motor_commands[i].tau = torques[0][i].item(); - motor_commands[i].tau = 0; - motor_commands[i].q = target_dof_pos[0][i].item(); + motor_commands[i].q = output_dof_pos[0][i].item(); motor_commands[i].dq = 0; motor_commands[i].Kp = params.stiffness; motor_commands[i].Kd = params.damping; + // motor_commands[i].tau = output_torques[0][i].item(); + motor_commands[i].tau = 0; torque_publishers[joint_names[i]].publish(motor_commands[i]); } } - -void RL_Sim::plot() +void RL_Sim::Plot() { int dof_mapping[13] = {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9}; - _t.push_back(motiontime); + plot_t.push_back(motiontime); plt::cla(); plt::clf(); for(int i = 0; i < 12; ++i) { - _real_joint_pos[i].push_back(joint_positions[dof_mapping[i]]); - _target_joint_pos[i].push_back(motor_commands[i].q); + plot_real_joint_pos[i].push_back(joint_positions[dof_mapping[i]]); + plot_target_joint_pos[i].push_back(motor_commands[i].q); plt::subplot(4, 3, i+1); - plt::named_plot("_real_joint_pos", _t, _real_joint_pos[i], "r"); - plt::named_plot("_target_joint_pos", _t, _target_joint_pos[i], "b"); + plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); + plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); plt::xlim(motiontime-10000, motiontime); } // plt::legend(); @@ -55,7 +54,7 @@ RL_Sim::RL_Sim() this->actor = torch::jit::load(actor_path); this->encoder = torch::jit::load(encoder_path); this->vq = torch::jit::load(vq_path); - this->init_observations(); + this->InitObservations(); this->params.num_observations = 45; this->params.clip_obs = 100.0; @@ -87,14 +86,14 @@ RL_Sim::RL_Sim() this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); - torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); - target_dof_pos = params.default_dof_pos; + output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); + output_dof_pos = params.default_dof_pos; joint_positions = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; joint_velocities = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; - _real_joint_pos.resize(12); - _target_joint_pos.resize(12); + plot_real_joint_pos.resize(12); + plot_target_joint_pos.resize(12); - cmd_vel_subscriber_ = nh.subscribe("/cmd_vel", 10, &RL_Sim::cmdvelCallback, this); + cmd_vel_subscriber_ = nh.subscribe("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this); std::string ros_namespace = "/a1_gazebo/"; @@ -112,18 +111,18 @@ RL_Sim::RL_Sim() } model_state_subscriber_ = nh.subscribe( - "/gazebo/model_states", 10, &RL_Sim::modelStatesCallback, this); + "/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this); joint_state_subscriber_ = nh.subscribe( - "/a1_gazebo/joint_states", 10, &RL_Sim::jointStatesCallback, this); + "/a1_gazebo/joint_states", 10, &RL_Sim::JointStatesCallback, this); loop_control = std::make_shared("loop_control", 0.002, boost::bind(&RL_Sim::RobotControl, this)); - loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Sim::runModel, this)); + loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel, this)); loop_control->start(); loop_rl->start(); #ifdef PLOT - loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Sim::plot, this)); + loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this)); loop_plot->start(); #endif } @@ -138,25 +137,25 @@ RL_Sim::~RL_Sim() printf("exit\n"); } -void RL_Sim::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg) +void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg) { vel = msg->twist[2]; pose = msg->pose[2]; } -void RL_Sim::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg) +void RL_Sim::CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg) { cmd_vel = *msg; } -void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg) +void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg) { joint_positions = msg->position; joint_velocities = msg->velocity; } -void RL_Sim::runModel() +void RL_Sim::RunModel() { // auto duration = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count(); // std::cout << "Execution time: " << duration << " microseconds" << std::endl; @@ -190,16 +189,17 @@ void RL_Sim::runModel() joint_velocities[7], joint_velocities[8], joint_velocities[6], joint_velocities[10], joint_velocities[11], joint_velocities[9]}}); - torch::Tensor actions = this->forward(); - torques = this->compute_torques(actions); - target_dof_pos = this->compute_pos(actions); + torch::Tensor actions = this->Forward(); + + output_torques = this->ComputeTorques(actions); + output_dof_pos = this->ComputePosition(actions); } -torch::Tensor RL_Sim::compute_observation() +torch::Tensor RL_Sim::ComputeObservation() { - torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, - (this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale, - this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec), + torch::Tensor obs = torch::cat({// (this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, + (this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale, + this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), this->obs.commands * this->params.commands_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, this->obs.dof_vel * this->params.dof_vel_scale, @@ -209,9 +209,9 @@ torch::Tensor RL_Sim::compute_observation() return obs; } -torch::Tensor RL_Sim::forward() +torch::Tensor RL_Sim::Forward() { - torch::Tensor obs = this->compute_observation(); + torch::Tensor obs = this->ComputeObservation(); history_obs_buf.insert(obs); history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});