diff --git a/src/rl_sar/include/rl_real.hpp b/src/rl_sar/include/rl_real.hpp index 06469f6..c96ba5b 100644 --- a/src/rl_sar/include/rl_real.hpp +++ b/src/rl_sar/include/rl_real.hpp @@ -52,8 +52,6 @@ public: std::shared_ptr loop_rl; float _percent; - // float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6, - // 0.0, 0.8, -1.6, 0.0, 0.8, -1.6}; float _startPos[12]; int init_state = STATE_WAITING; @@ -63,21 +61,7 @@ private: std::vector joint_positions; std::vector joint_velocities; - torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); - int dof_mapping[13] = {3, 4, 5, - 0, 1, 2, - 9, 10, 11, - 6, 7, 8}; - float Kp[13] = {20, 10, 10, - 20, 10, 10, - 20, 10, 10, - 20, 10, 10}; - float Kd[13] = {1.0, 0.5, 0.5, - 1.0, 0.5, 0.5, - 1.0, 0.5, 0.5, - 1.0, 0.5, 0.5}; - torch::Tensor target_dof_pos; - torch::Tensor compute_pos(torch::Tensor actions); + int dof_mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; std::chrono::high_resolution_clock::time_point start_time; diff --git a/src/rl_sar/launch/start_env.launch b/src/rl_sar/launch/start_env.launch index 158ca39..bc881e2 100644 --- a/src/rl_sar/launch/start_env.launch +++ b/src/rl_sar/launch/start_env.launch @@ -34,7 +34,7 @@ - 5000 + params.action_scale; + int indices[] = {0, 3, 6, 9}; + for (int i : indices) + { + actions_scaled[0][i] *= this->params.hip_scale_reduction; + } + + return actions_scaled + this->params.default_dof_pos; +} + /* You may need to override this compute_observation() function torch::Tensor RL::compute_observation() { diff --git a/src/rl_sar/library/rl/rl.hpp b/src/rl_sar/library/rl/rl.hpp index 9bb57c1..b661f20 100644 --- a/src/rl_sar/library/rl/rl.hpp +++ b/src/rl_sar/library/rl/rl.hpp @@ -44,8 +44,8 @@ public: virtual torch::Tensor forward() = 0; virtual torch::Tensor compute_observation() = 0; - torch::Tensor compute_torques(torch::Tensor actions); + torch::Tensor compute_pos(torch::Tensor actions); torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v); void init_observations(); @@ -61,6 +61,9 @@ protected: torch::Tensor dof_pos; torch::Tensor dof_vel; torch::Tensor actions; + + torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); + torch::Tensor target_dof_pos; }; #endif // RL_HPP \ No newline at end of file diff --git a/src/rl_sar/src/rl_real.cpp b/src/rl_sar/src/rl_real.cpp index 2a5795d..f02885a 100644 --- a/src/rl_sar/src/rl_real.cpp +++ b/src/rl_sar/src/rl_real.cpp @@ -103,8 +103,8 @@ void RL_Real::RobotControl() cmd.motorCmd[i].mode = 0x0A; cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item(); cmd.motorCmd[i].dq = 0; - cmd.motorCmd[i].Kp = 15; - cmd.motorCmd[i].Kd = 1.5; + cmd.motorCmd[i].Kp = params.stiffness; + cmd.motorCmd[i].Kd = params.damping; cmd.motorCmd[i].tau = 0; } #endif @@ -181,24 +181,12 @@ RL_Real::~RL_Real() printf("shutdown\n"); } -torch::Tensor RL_Real::compute_pos(torch::Tensor actions) -{ - torch::Tensor actions_scaled = actions * this->params.action_scale; - int indices[] = {0, 3, 6, 9}; - for (int i : indices) - { - actions_scaled[0][i] *= this->params.hip_scale_reduction; - } - - return actions_scaled + this->params.default_dof_pos; -} - void RL_Real::runModel() { if(init_state == STATE_RL_START) { auto duration = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count(); - // std::cout << "Execution time: " << duration << " microseconds" << std::endl; + std::cout << "Execution time: " << duration << " microseconds" << std::endl; start_time = std::chrono::high_resolution_clock::now(); // printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]); diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index b14d335..502f0af 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -49,10 +49,12 @@ RL_Sim::RL_Sim() this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); + target_dof_pos = params.default_dof_pos; + cmd_vel_subscriber_ = nh.subscribe( "/cmd_vel", 10, &RL_Sim::cmdvelCallback, this); - timer = nh.createTimer(ros::Duration(0.005), &RL_Sim::runModel, this); + timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this); ros_namespace = "/a1_gazebo/"; @@ -115,11 +117,17 @@ void RL_Sim::runModel(const ros::TimerEvent &event) torch::Tensor actions = this->forward(); torques = this->compute_torques(actions); + target_dof_pos = this->compute_pos(actions); for (int i = 0; i < 12; ++i) { - torque_commands[i].tau = torques[0][i].item(); torque_commands[i].mode = 0x0A; + // torque_commands[i].tau = torques[0][i].item(); + torque_commands[i].tau = 0; + torque_commands[i].q = target_dof_pos[0][i].item(); + torque_commands[i].dq = 0; + torque_commands[i].Kp = params.stiffness; + torque_commands[i].Kd = params.damping; torque_publishers[joint_names[i]].publish(torque_commands[i]); } @@ -133,8 +141,8 @@ torch::Tensor RL_Sim::compute_observation() this->obs.commands * this->params.commands_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, this->obs.dof_vel * this->params.dof_vel_scale, - this->obs.actions}, - 1); + this->obs.actions + },1); obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); return obs; }