fix freq bug

This commit is contained in:
fan-ziqi 2024-03-15 20:27:06 +08:00
parent 220347de1d
commit 8d2ab423b5
6 changed files with 33 additions and 38 deletions

View File

@ -52,8 +52,6 @@ public:
std::shared_ptr<LoopFunc> loop_rl; std::shared_ptr<LoopFunc> loop_rl;
float _percent; float _percent;
// float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6,
// 0.0, 0.8, -1.6, 0.0, 0.8, -1.6};
float _startPos[12]; float _startPos[12];
int init_state = STATE_WAITING; int init_state = STATE_WAITING;
@ -63,21 +61,7 @@ private:
std::vector<double> joint_positions; std::vector<double> joint_positions;
std::vector<double> joint_velocities; std::vector<double> joint_velocities;
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); int dof_mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
int dof_mapping[13] = {3, 4, 5,
0, 1, 2,
9, 10, 11,
6, 7, 8};
float Kp[13] = {20, 10, 10,
20, 10, 10,
20, 10, 10,
20, 10, 10};
float Kd[13] = {1.0, 0.5, 0.5,
1.0, 0.5, 0.5,
1.0, 0.5, 0.5,
1.0, 0.5, 0.5};
torch::Tensor target_dof_pos;
torch::Tensor compute_pos(torch::Tensor actions);
std::chrono::high_resolution_clock::time_point start_time; std::chrono::high_resolution_clock::time_point start_time;

View File

@ -34,7 +34,7 @@
<!-- Load joint controller configurations from YAML file to parameter server --> <!-- Load joint controller configurations from YAML file to parameter server -->
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/> <rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
<rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam> <!-- <rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam> -->
<!-- load the controllers --> <!-- load the controllers -->
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false" <node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"

View File

@ -37,6 +37,18 @@ torch::Tensor RL::compute_torques(torch::Tensor actions)
return clamped; return clamped;
} }
torch::Tensor RL::compute_pos(torch::Tensor actions)
{
torch::Tensor actions_scaled = actions * this->params.action_scale;
int indices[] = {0, 3, 6, 9};
for (int i : indices)
{
actions_scaled[0][i] *= this->params.hip_scale_reduction;
}
return actions_scaled + this->params.default_dof_pos;
}
/* You may need to override this compute_observation() function /* You may need to override this compute_observation() function
torch::Tensor RL::compute_observation() torch::Tensor RL::compute_observation()
{ {

View File

@ -44,8 +44,8 @@ public:
virtual torch::Tensor forward() = 0; virtual torch::Tensor forward() = 0;
virtual torch::Tensor compute_observation() = 0; virtual torch::Tensor compute_observation() = 0;
torch::Tensor compute_torques(torch::Tensor actions); torch::Tensor compute_torques(torch::Tensor actions);
torch::Tensor compute_pos(torch::Tensor actions);
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v); torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
void init_observations(); void init_observations();
@ -61,6 +61,9 @@ protected:
torch::Tensor dof_pos; torch::Tensor dof_pos;
torch::Tensor dof_vel; torch::Tensor dof_vel;
torch::Tensor actions; torch::Tensor actions;
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
torch::Tensor target_dof_pos;
}; };
#endif // RL_HPP #endif // RL_HPP

View File

@ -103,8 +103,8 @@ void RL_Real::RobotControl()
cmd.motorCmd[i].mode = 0x0A; cmd.motorCmd[i].mode = 0x0A;
cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item<double>(); cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item<double>();
cmd.motorCmd[i].dq = 0; cmd.motorCmd[i].dq = 0;
cmd.motorCmd[i].Kp = 15; cmd.motorCmd[i].Kp = params.stiffness;
cmd.motorCmd[i].Kd = 1.5; cmd.motorCmd[i].Kd = params.damping;
cmd.motorCmd[i].tau = 0; cmd.motorCmd[i].tau = 0;
} }
#endif #endif
@ -181,24 +181,12 @@ RL_Real::~RL_Real()
printf("shutdown\n"); printf("shutdown\n");
} }
torch::Tensor RL_Real::compute_pos(torch::Tensor actions)
{
torch::Tensor actions_scaled = actions * this->params.action_scale;
int indices[] = {0, 3, 6, 9};
for (int i : indices)
{
actions_scaled[0][i] *= this->params.hip_scale_reduction;
}
return actions_scaled + this->params.default_dof_pos;
}
void RL_Real::runModel() void RL_Real::runModel()
{ {
if(init_state == STATE_RL_START) if(init_state == STATE_RL_START)
{ {
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count(); auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
// std::cout << "Execution time: " << duration << " microseconds" << std::endl; std::cout << "Execution time: " << duration << " microseconds" << std::endl;
start_time = std::chrono::high_resolution_clock::now(); start_time = std::chrono::high_resolution_clock::now();
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]); // printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);

View File

@ -49,10 +49,12 @@ RL_Sim::RL_Sim()
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
target_dof_pos = params.default_dof_pos;
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>( cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this); "/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
timer = nh.createTimer(ros::Duration(0.005), &RL_Sim::runModel, this); timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this);
ros_namespace = "/a1_gazebo/"; ros_namespace = "/a1_gazebo/";
@ -115,11 +117,17 @@ void RL_Sim::runModel(const ros::TimerEvent &event)
torch::Tensor actions = this->forward(); torch::Tensor actions = this->forward();
torques = this->compute_torques(actions); torques = this->compute_torques(actions);
target_dof_pos = this->compute_pos(actions);
for (int i = 0; i < 12; ++i) for (int i = 0; i < 12; ++i)
{ {
torque_commands[i].tau = torques[0][i].item<double>();
torque_commands[i].mode = 0x0A; torque_commands[i].mode = 0x0A;
// torque_commands[i].tau = torques[0][i].item<double>();
torque_commands[i].tau = 0;
torque_commands[i].q = target_dof_pos[0][i].item<double>();
torque_commands[i].dq = 0;
torque_commands[i].Kp = params.stiffness;
torque_commands[i].Kd = params.damping;
torque_publishers[joint_names[i]].publish(torque_commands[i]); torque_publishers[joint_names[i]].publish(torque_commands[i]);
} }
@ -133,8 +141,8 @@ torch::Tensor RL_Sim::compute_observation()
this->obs.commands * this->params.commands_scale, this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale, this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions}, this->obs.actions
1); },1);
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return obs; return obs;
} }