mirror of https://github.com/fan-ziqi/rl_sar.git
fix freq bug
This commit is contained in:
parent
220347de1d
commit
8d2ab423b5
|
@ -52,8 +52,6 @@ public:
|
|||
std::shared_ptr<LoopFunc> loop_rl;
|
||||
|
||||
float _percent;
|
||||
// float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6,
|
||||
// 0.0, 0.8, -1.6, 0.0, 0.8, -1.6};
|
||||
float _startPos[12];
|
||||
|
||||
int init_state = STATE_WAITING;
|
||||
|
@ -63,21 +61,7 @@ private:
|
|||
std::vector<double> joint_positions;
|
||||
std::vector<double> joint_velocities;
|
||||
|
||||
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
int dof_mapping[13] = {3, 4, 5,
|
||||
0, 1, 2,
|
||||
9, 10, 11,
|
||||
6, 7, 8};
|
||||
float Kp[13] = {20, 10, 10,
|
||||
20, 10, 10,
|
||||
20, 10, 10,
|
||||
20, 10, 10};
|
||||
float Kd[13] = {1.0, 0.5, 0.5,
|
||||
1.0, 0.5, 0.5,
|
||||
1.0, 0.5, 0.5,
|
||||
1.0, 0.5, 0.5};
|
||||
torch::Tensor target_dof_pos;
|
||||
torch::Tensor compute_pos(torch::Tensor actions);
|
||||
int dof_mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time;
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@
|
|||
<!-- Load joint controller configurations from YAML file to parameter server -->
|
||||
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
|
||||
|
||||
<rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam>
|
||||
<!-- <rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam> -->
|
||||
|
||||
<!-- load the controllers -->
|
||||
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
|
||||
|
|
|
@ -37,6 +37,18 @@ torch::Tensor RL::compute_torques(torch::Tensor actions)
|
|||
return clamped;
|
||||
}
|
||||
|
||||
torch::Tensor RL::compute_pos(torch::Tensor actions)
|
||||
{
|
||||
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
||||
int indices[] = {0, 3, 6, 9};
|
||||
for (int i : indices)
|
||||
{
|
||||
actions_scaled[0][i] *= this->params.hip_scale_reduction;
|
||||
}
|
||||
|
||||
return actions_scaled + this->params.default_dof_pos;
|
||||
}
|
||||
|
||||
/* You may need to override this compute_observation() function
|
||||
torch::Tensor RL::compute_observation()
|
||||
{
|
||||
|
|
|
@ -44,8 +44,8 @@ public:
|
|||
|
||||
virtual torch::Tensor forward() = 0;
|
||||
virtual torch::Tensor compute_observation() = 0;
|
||||
|
||||
torch::Tensor compute_torques(torch::Tensor actions);
|
||||
torch::Tensor compute_pos(torch::Tensor actions);
|
||||
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
||||
void init_observations();
|
||||
|
||||
|
@ -61,6 +61,9 @@ protected:
|
|||
torch::Tensor dof_pos;
|
||||
torch::Tensor dof_vel;
|
||||
torch::Tensor actions;
|
||||
|
||||
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
torch::Tensor target_dof_pos;
|
||||
};
|
||||
|
||||
#endif // RL_HPP
|
|
@ -103,8 +103,8 @@ void RL_Real::RobotControl()
|
|||
cmd.motorCmd[i].mode = 0x0A;
|
||||
cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item<double>();
|
||||
cmd.motorCmd[i].dq = 0;
|
||||
cmd.motorCmd[i].Kp = 15;
|
||||
cmd.motorCmd[i].Kd = 1.5;
|
||||
cmd.motorCmd[i].Kp = params.stiffness;
|
||||
cmd.motorCmd[i].Kd = params.damping;
|
||||
cmd.motorCmd[i].tau = 0;
|
||||
}
|
||||
#endif
|
||||
|
@ -181,24 +181,12 @@ RL_Real::~RL_Real()
|
|||
printf("shutdown\n");
|
||||
}
|
||||
|
||||
torch::Tensor RL_Real::compute_pos(torch::Tensor actions)
|
||||
{
|
||||
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
||||
int indices[] = {0, 3, 6, 9};
|
||||
for (int i : indices)
|
||||
{
|
||||
actions_scaled[0][i] *= this->params.hip_scale_reduction;
|
||||
}
|
||||
|
||||
return actions_scaled + this->params.default_dof_pos;
|
||||
}
|
||||
|
||||
void RL_Real::runModel()
|
||||
{
|
||||
if(init_state == STATE_RL_START)
|
||||
{
|
||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||
std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||
start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
|
||||
|
|
|
@ -49,10 +49,12 @@ RL_Sim::RL_Sim()
|
|||
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||
|
||||
target_dof_pos = params.default_dof_pos;
|
||||
|
||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
|
||||
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
||||
|
||||
timer = nh.createTimer(ros::Duration(0.005), &RL_Sim::runModel, this);
|
||||
timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this);
|
||||
|
||||
ros_namespace = "/a1_gazebo/";
|
||||
|
||||
|
@ -115,11 +117,17 @@ void RL_Sim::runModel(const ros::TimerEvent &event)
|
|||
|
||||
torch::Tensor actions = this->forward();
|
||||
torques = this->compute_torques(actions);
|
||||
target_dof_pos = this->compute_pos(actions);
|
||||
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
torque_commands[i].tau = torques[0][i].item<double>();
|
||||
torque_commands[i].mode = 0x0A;
|
||||
// torque_commands[i].tau = torques[0][i].item<double>();
|
||||
torque_commands[i].tau = 0;
|
||||
torque_commands[i].q = target_dof_pos[0][i].item<double>();
|
||||
torque_commands[i].dq = 0;
|
||||
torque_commands[i].Kp = params.stiffness;
|
||||
torque_commands[i].Kd = params.damping;
|
||||
|
||||
torque_publishers[joint_names[i]].publish(torque_commands[i]);
|
||||
}
|
||||
|
@ -133,8 +141,8 @@ torch::Tensor RL_Sim::compute_observation()
|
|||
this->obs.commands * this->params.commands_scale,
|
||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||
this->obs.actions},
|
||||
1);
|
||||
this->obs.actions
|
||||
},1);
|
||||
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
||||
return obs;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue