mirror of https://github.com/fan-ziqi/rl_sar.git
fix freq bug
This commit is contained in:
parent
220347de1d
commit
8d2ab423b5
|
@ -52,8 +52,6 @@ public:
|
||||||
std::shared_ptr<LoopFunc> loop_rl;
|
std::shared_ptr<LoopFunc> loop_rl;
|
||||||
|
|
||||||
float _percent;
|
float _percent;
|
||||||
// float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6,
|
|
||||||
// 0.0, 0.8, -1.6, 0.0, 0.8, -1.6};
|
|
||||||
float _startPos[12];
|
float _startPos[12];
|
||||||
|
|
||||||
int init_state = STATE_WAITING;
|
int init_state = STATE_WAITING;
|
||||||
|
@ -63,21 +61,7 @@ private:
|
||||||
std::vector<double> joint_positions;
|
std::vector<double> joint_positions;
|
||||||
std::vector<double> joint_velocities;
|
std::vector<double> joint_velocities;
|
||||||
|
|
||||||
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
int dof_mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
|
||||||
int dof_mapping[13] = {3, 4, 5,
|
|
||||||
0, 1, 2,
|
|
||||||
9, 10, 11,
|
|
||||||
6, 7, 8};
|
|
||||||
float Kp[13] = {20, 10, 10,
|
|
||||||
20, 10, 10,
|
|
||||||
20, 10, 10,
|
|
||||||
20, 10, 10};
|
|
||||||
float Kd[13] = {1.0, 0.5, 0.5,
|
|
||||||
1.0, 0.5, 0.5,
|
|
||||||
1.0, 0.5, 0.5,
|
|
||||||
1.0, 0.5, 0.5};
|
|
||||||
torch::Tensor target_dof_pos;
|
|
||||||
torch::Tensor compute_pos(torch::Tensor actions);
|
|
||||||
|
|
||||||
std::chrono::high_resolution_clock::time_point start_time;
|
std::chrono::high_resolution_clock::time_point start_time;
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@
|
||||||
<!-- Load joint controller configurations from YAML file to parameter server -->
|
<!-- Load joint controller configurations from YAML file to parameter server -->
|
||||||
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
|
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
|
||||||
|
|
||||||
<rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam>
|
<!-- <rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam> -->
|
||||||
|
|
||||||
<!-- load the controllers -->
|
<!-- load the controllers -->
|
||||||
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
|
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
|
||||||
|
|
|
@ -37,6 +37,18 @@ torch::Tensor RL::compute_torques(torch::Tensor actions)
|
||||||
return clamped;
|
return clamped;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
torch::Tensor RL::compute_pos(torch::Tensor actions)
|
||||||
|
{
|
||||||
|
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
||||||
|
int indices[] = {0, 3, 6, 9};
|
||||||
|
for (int i : indices)
|
||||||
|
{
|
||||||
|
actions_scaled[0][i] *= this->params.hip_scale_reduction;
|
||||||
|
}
|
||||||
|
|
||||||
|
return actions_scaled + this->params.default_dof_pos;
|
||||||
|
}
|
||||||
|
|
||||||
/* You may need to override this compute_observation() function
|
/* You may need to override this compute_observation() function
|
||||||
torch::Tensor RL::compute_observation()
|
torch::Tensor RL::compute_observation()
|
||||||
{
|
{
|
||||||
|
|
|
@ -44,8 +44,8 @@ public:
|
||||||
|
|
||||||
virtual torch::Tensor forward() = 0;
|
virtual torch::Tensor forward() = 0;
|
||||||
virtual torch::Tensor compute_observation() = 0;
|
virtual torch::Tensor compute_observation() = 0;
|
||||||
|
|
||||||
torch::Tensor compute_torques(torch::Tensor actions);
|
torch::Tensor compute_torques(torch::Tensor actions);
|
||||||
|
torch::Tensor compute_pos(torch::Tensor actions);
|
||||||
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
||||||
void init_observations();
|
void init_observations();
|
||||||
|
|
||||||
|
@ -61,6 +61,9 @@ protected:
|
||||||
torch::Tensor dof_pos;
|
torch::Tensor dof_pos;
|
||||||
torch::Tensor dof_vel;
|
torch::Tensor dof_vel;
|
||||||
torch::Tensor actions;
|
torch::Tensor actions;
|
||||||
|
|
||||||
|
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||||
|
torch::Tensor target_dof_pos;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // RL_HPP
|
#endif // RL_HPP
|
|
@ -103,8 +103,8 @@ void RL_Real::RobotControl()
|
||||||
cmd.motorCmd[i].mode = 0x0A;
|
cmd.motorCmd[i].mode = 0x0A;
|
||||||
cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item<double>();
|
cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item<double>();
|
||||||
cmd.motorCmd[i].dq = 0;
|
cmd.motorCmd[i].dq = 0;
|
||||||
cmd.motorCmd[i].Kp = 15;
|
cmd.motorCmd[i].Kp = params.stiffness;
|
||||||
cmd.motorCmd[i].Kd = 1.5;
|
cmd.motorCmd[i].Kd = params.damping;
|
||||||
cmd.motorCmd[i].tau = 0;
|
cmd.motorCmd[i].tau = 0;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -181,24 +181,12 @@ RL_Real::~RL_Real()
|
||||||
printf("shutdown\n");
|
printf("shutdown\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor RL_Real::compute_pos(torch::Tensor actions)
|
|
||||||
{
|
|
||||||
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
|
||||||
int indices[] = {0, 3, 6, 9};
|
|
||||||
for (int i : indices)
|
|
||||||
{
|
|
||||||
actions_scaled[0][i] *= this->params.hip_scale_reduction;
|
|
||||||
}
|
|
||||||
|
|
||||||
return actions_scaled + this->params.default_dof_pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL_Real::runModel()
|
void RL_Real::runModel()
|
||||||
{
|
{
|
||||||
if(init_state == STATE_RL_START)
|
if(init_state == STATE_RL_START)
|
||||||
{
|
{
|
||||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||||
start_time = std::chrono::high_resolution_clock::now();
|
start_time = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
|
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
|
||||||
|
|
|
@ -49,10 +49,12 @@ RL_Sim::RL_Sim()
|
||||||
|
|
||||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||||
|
|
||||||
|
target_dof_pos = params.default_dof_pos;
|
||||||
|
|
||||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
|
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
|
||||||
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
||||||
|
|
||||||
timer = nh.createTimer(ros::Duration(0.005), &RL_Sim::runModel, this);
|
timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this);
|
||||||
|
|
||||||
ros_namespace = "/a1_gazebo/";
|
ros_namespace = "/a1_gazebo/";
|
||||||
|
|
||||||
|
@ -115,11 +117,17 @@ void RL_Sim::runModel(const ros::TimerEvent &event)
|
||||||
|
|
||||||
torch::Tensor actions = this->forward();
|
torch::Tensor actions = this->forward();
|
||||||
torques = this->compute_torques(actions);
|
torques = this->compute_torques(actions);
|
||||||
|
target_dof_pos = this->compute_pos(actions);
|
||||||
|
|
||||||
for (int i = 0; i < 12; ++i)
|
for (int i = 0; i < 12; ++i)
|
||||||
{
|
{
|
||||||
torque_commands[i].tau = torques[0][i].item<double>();
|
|
||||||
torque_commands[i].mode = 0x0A;
|
torque_commands[i].mode = 0x0A;
|
||||||
|
// torque_commands[i].tau = torques[0][i].item<double>();
|
||||||
|
torque_commands[i].tau = 0;
|
||||||
|
torque_commands[i].q = target_dof_pos[0][i].item<double>();
|
||||||
|
torque_commands[i].dq = 0;
|
||||||
|
torque_commands[i].Kp = params.stiffness;
|
||||||
|
torque_commands[i].Kd = params.damping;
|
||||||
|
|
||||||
torque_publishers[joint_names[i]].publish(torque_commands[i]);
|
torque_publishers[joint_names[i]].publish(torque_commands[i]);
|
||||||
}
|
}
|
||||||
|
@ -133,8 +141,8 @@ torch::Tensor RL_Sim::compute_observation()
|
||||||
this->obs.commands * this->params.commands_scale,
|
this->obs.commands * this->params.commands_scale,
|
||||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||||
this->obs.actions},
|
this->obs.actions
|
||||||
1);
|
},1);
|
||||||
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
||||||
return obs;
|
return obs;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue