fix: organize code

This commit is contained in:
fan-ziqi 2024-03-22 21:18:55 +08:00
parent beb21ae9fa
commit 1b12b00fd1
5 changed files with 145 additions and 155 deletions

View File

@ -68,14 +68,7 @@ private:
std::vector<double> joint_velocities; std::vector<double> joint_velocities;
int dof_mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; int dof_mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
float Kp[13] = {20, 20, 20, //fr int hip_scale_reduction_indices[] = {0, 3, 6, 9};
20, 20, 20, //fl
20, 20, 20, //rr
20, 20, 20};//rl
float Kd[13] = {0.5, 0.5, 0.5,
0.5, 0.5, 0.5,
0.5, 0.5, 0.5,
0.5, 0.5, 0.5};
std::chrono::high_resolution_clock::time_point start_time; std::chrono::high_resolution_clock::time_point start_time;

View File

@ -58,6 +58,8 @@ private:
std::vector<double> joint_positions; std::vector<double> joint_positions;
std::vector<double> joint_velocities; std::vector<double> joint_velocities;
int hip_scale_reduction_indices[] = {0, 3, 6, 9};
std::chrono::high_resolution_clock::time_point start_time; std::chrono::high_resolution_clock::time_point start_time;
// other rl module // other rl module

View File

@ -6,7 +6,7 @@ torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v)
torch::Tensor q_w = q.index({torch::indexing::Slice(), -1}); torch::Tensor q_w = q.index({torch::indexing::Slice(), -1});
torch::Tensor q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)}); torch::Tensor q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)});
torch::Tensor a = v * (2.0 * torch::pow(q_w, 2) - 1.0).unsqueeze(-1); torch::Tensor a = v * (2.0 * torch::pow(q_w, 2) - 1.0).unsqueeze(-1);
torch::Tensor b = torch::cross(q_vec, v, /*dim=*/-1) * q_w.unsqueeze(-1) * 2.0; torch::Tensor b = torch::cross(q_vec, v, -1) * q_w.unsqueeze(-1) * 2.0;
torch::Tensor c = q_vec * torch::bmm(q_vec.view({shape[0], 1, 3}), v.view({shape[0], 3, 1})).squeeze(-1) * 2.0; torch::Tensor c = q_vec * torch::bmm(q_vec.view({shape[0], 1, 3}), v.view({shape[0], 3, 1})).squeeze(-1) * 2.0;
return a - b + c; return a - b + c;
} }
@ -26,12 +26,6 @@ void RL::InitObservations()
torch::Tensor RL::ComputeTorques(torch::Tensor actions) torch::Tensor RL::ComputeTorques(torch::Tensor actions)
{ {
torch::Tensor actions_scaled = actions * this->params.action_scale; torch::Tensor actions_scaled = actions * this->params.action_scale;
int indices[] = {0, 3, 6, 9};
for (int i : indices)
{
actions_scaled[0][i] *= this->params.hip_scale_reduction;
}
torch::Tensor output_torques = this->params.p_gains * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.d_gains * this->obs.dof_vel; torch::Tensor output_torques = this->params.p_gains * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.d_gains * this->obs.dof_vel;
torch::Tensor clamped = torch::clamp(output_torques, -(this->params.torque_limits), this->params.torque_limits); torch::Tensor clamped = torch::clamp(output_torques, -(this->params.torque_limits), this->params.torque_limits);
return clamped; return clamped;
@ -40,12 +34,6 @@ torch::Tensor RL::ComputeTorques(torch::Tensor actions)
torch::Tensor RL::ComputePosition(torch::Tensor actions) torch::Tensor RL::ComputePosition(torch::Tensor actions)
{ {
torch::Tensor actions_scaled = actions * this->params.action_scale; torch::Tensor actions_scaled = actions * this->params.action_scale;
int indices[] = {0, 3, 6, 9};
for (int i : indices)
{
actions_scaled[0][i] *= this->params.hip_scale_reduction;
}
return actions_scaled + this->params.default_dof_pos; return actions_scaled + this->params.default_dof_pos;
} }

View File

@ -4,6 +4,82 @@
RL_Real rl_sar; RL_Real rl_sar;
RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
{
udp.InitCmdData(cmd);
start_time = std::chrono::high_resolution_clock::now();
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
this->vq = torch::jit::load(vq_path);
this->InitObservations();
this->params.num_observations = 45;
this->params.clip_obs = 100.0;
this->params.clip_actions = 100.0;
this->params.damping = 0.5;
this->params.stiffness = 20;
this->params.d_gains = torch::ones(12) * this->params.damping;
this->params.p_gains = torch::ones(12) * this->params.stiffness;
this->params.action_scale = 0.25;
this->params.hip_scale_reduction = 0.5;
this->params.num_of_dofs = 12;
this->params.lin_vel_scale = 2.0;
this->params.ang_vel_scale = 0.25;
this->params.dof_pos_scale = 1.0;
this->params.dof_vel_scale = 0.05;
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0,
20.0, 55.0, 55.0,
20.0, 55.0, 55.0,
20.0, 55.0, 55.0}});
// hip, thigh, calf
this->params.default_dof_pos = torch::tensor({{ 0.1000, 0.8000, -1.5000, // FL
-0.1000, 0.8000, -1.5000, // FR
0.1000, 1.0000, -1.5000, // RR
-0.1000, 1.0000, -1.5000}});// RL
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
output_dof_pos = params.default_dof_pos;
plot_real_joint_pos.resize(12);
plot_target_joint_pos.resize(12);
loop_control = std::make_shared<LoopFunc>("loop_control", 0.002, boost::bind(&RL_Real::RobotControl, this));
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend", 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv", 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel, this));
loop_udpSend->start();
loop_udpRecv->start();
loop_control->start();
#ifdef PLOT
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this));
loop_plot->start();
#endif
}
RL_Real::~RL_Real()
{
loop_udpSend->shutdown();
loop_udpRecv->shutdown();
loop_control->shutdown();
loop_rl->shutdown();
#ifdef PLOT
loop_plot->shutdown();
#endif
printf("exit\n");
}
void RL_Real::RobotControl() void RL_Real::RobotControl()
{ {
motiontime++; motiontime++;
@ -81,8 +157,6 @@ void RL_Real::RobotControl()
// cmd.motorCmd[i].q = 0; // cmd.motorCmd[i].q = 0;
cmd.motorCmd[i].q = output_dof_pos[0][dof_mapping[i]].item<double>(); cmd.motorCmd[i].q = output_dof_pos[0][dof_mapping[i]].item<double>();
cmd.motorCmd[i].dq = 0; cmd.motorCmd[i].dq = 0;
// cmd.motorCmd[i].Kp = Kp[dof_mapping[i]];
// cmd.motorCmd[i].Kd = Kd[dof_mapping[i]];
cmd.motorCmd[i].Kp = params.stiffness; cmd.motorCmd[i].Kp = params.stiffness;
cmd.motorCmd[i].Kd = params.damping; cmd.motorCmd[i].Kd = params.damping;
// cmd.motorCmd[i].tau = output_torques[0][dof_mapping[i]].item<double>(); // cmd.motorCmd[i].tau = output_torques[0][dof_mapping[i]].item<double>();
@ -129,100 +203,6 @@ void RL_Real::RobotControl()
udp.SetSend(cmd); udp.SetSend(cmd);
} }
RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
{
udp.InitCmdData(cmd);
start_time = std::chrono::high_resolution_clock::now();
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
this->vq = torch::jit::load(vq_path);
this->InitObservations();
this->params.num_observations = 45;
this->params.clip_obs = 100.0;
this->params.clip_actions = 100.0;
this->params.damping = 0.5;
this->params.stiffness = 20;
this->params.d_gains = torch::ones(12) * this->params.damping;
this->params.p_gains = torch::ones(12) * this->params.stiffness;
this->params.action_scale = 0.25;
this->params.hip_scale_reduction = 0.5;
this->params.num_of_dofs = 12;
this->params.lin_vel_scale = 2.0;
this->params.ang_vel_scale = 0.25;
this->params.dof_pos_scale = 1.0;
this->params.dof_vel_scale = 0.05;
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0,
20.0, 55.0, 55.0,
20.0, 55.0, 55.0,
20.0, 55.0, 55.0}});
// hip, thigh, calf
this->params.default_dof_pos = torch::tensor({{ 0.1000, 0.8000, -1.5000, // FL
-0.1000, 0.8000, -1.5000, // FR
0.1000, 1.0000, -1.5000, // RR
-0.1000, 1.0000, -1.5000}});// RL
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
output_dof_pos = params.default_dof_pos;
plot_real_joint_pos.resize(12);
plot_target_joint_pos.resize(12);
loop_control = std::make_shared<LoopFunc>("loop_control", 0.002, boost::bind(&RL_Real::RobotControl, this));
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend", 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv", 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel, this));
loop_udpSend->start();
loop_udpRecv->start();
loop_control->start();
#ifdef PLOT
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this));
loop_plot->start();
#endif
}
RL_Real::~RL_Real()
{
loop_udpSend->shutdown();
loop_udpRecv->shutdown();
loop_control->shutdown();
loop_rl->shutdown();
#ifdef PLOT
loop_plot->shutdown();
#endif
printf("exit\n");
}
void RL_Real::Plot()
{
plot_t.push_back(motiontime);
plt::cla();
plt::clf();
for(int i = 0; i < 12; ++i)
{
plot_real_joint_pos[i].push_back(state.motorState[i].q);
plot_target_joint_pos[i].push_back(cmd.motorCmd[i].q);
plt::subplot(4, 3, i+1);
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r");
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b");
plt::xlim(motiontime-10000, motiontime);
}
// plt::legend();
plt::pause(0.0001);
}
void RL_Real::RunModel() void RL_Real::RunModel()
{ {
if(robot_state == STATE_RL_RUNNING) if(robot_state == STATE_RL_RUNNING)
@ -260,6 +240,11 @@ void RL_Real::RunModel()
torch::Tensor actions = this->Forward(); torch::Tensor actions = this->Forward();
for (int i : hip_scale_reduction_indices)
{
actions[0][i] *= this->params.hip_scale_reduction;
}
output_torques = this->ComputeTorques(actions); output_torques = this->ComputeTorques(actions);
output_dof_pos = this->ComputePosition(actions); output_dof_pos = this->ComputePosition(actions);
} }
@ -268,7 +253,7 @@ void RL_Real::RunModel()
torch::Tensor RL_Real::ComputeObservation() torch::Tensor RL_Real::ComputeObservation()
{ {
torch::Tensor obs = torch::cat({// (this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, torch::Tensor obs = torch::cat({// this->QuatRotateInverse(this->obs.base_quat, this->obs.lin_vel) * this->params.lin_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
this->obs.commands * this->params.commands_scale, this->obs.commands * this->params.commands_scale,
@ -301,6 +286,24 @@ torch::Tensor RL_Real::Forward()
return clamped; return clamped;
} }
void RL_Real::Plot()
{
plot_t.push_back(motiontime);
plt::cla();
plt::clf();
for(int i = 0; i < 12; ++i)
{
plot_real_joint_pos[i].push_back(state.motorState[i].q);
plot_target_joint_pos[i].push_back(cmd.motorCmd[i].q);
plt::subplot(4, 3, i+1);
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r");
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b");
plt::xlim(motiontime-10000, motiontime);
}
// plt::legend();
plt::pause(0.0001);
}
void signalHandler(int signum) void signalHandler(int signum)
{ {
exit(0); exit(0);

View File

@ -2,42 +2,6 @@
// #define PLOT // #define PLOT
void RL_Sim::RobotControl()
{
motiontime++;
for (int i = 0; i < 12; ++i)
{
motor_commands[i].mode = 0x0A;
motor_commands[i].q = output_dof_pos[0][i].item<double>();
motor_commands[i].dq = 0;
motor_commands[i].Kp = params.stiffness;
motor_commands[i].Kd = params.damping;
// motor_commands[i].tau = output_torques[0][i].item<double>();
motor_commands[i].tau = 0;
torque_publishers[joint_names[i]].publish(motor_commands[i]);
}
}
void RL_Sim::Plot()
{
int dof_mapping[13] = {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9};
plot_t.push_back(motiontime);
plt::cla();
plt::clf();
for(int i = 0; i < 12; ++i)
{
plot_real_joint_pos[i].push_back(joint_positions[dof_mapping[i]]);
plot_target_joint_pos[i].push_back(motor_commands[i].q);
plt::subplot(4, 3, i+1);
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r");
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b");
plt::xlim(motiontime-10000, motiontime);
}
// plt::legend();
plt::pause(0.0001);
}
RL_Sim::RL_Sim() RL_Sim::RL_Sim()
{ {
ros::NodeHandle nh; ros::NodeHandle nh;
@ -72,7 +36,6 @@ RL_Sim::RL_Sim()
this->params.dof_vel_scale = 0.05; this->params.dof_vel_scale = 0.05;
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0,
20.0, 55.0, 55.0, 20.0, 55.0, 55.0,
20.0, 55.0, 55.0, 20.0, 55.0, 55.0,
@ -137,6 +100,23 @@ RL_Sim::~RL_Sim()
printf("exit\n"); printf("exit\n");
} }
void RL_Sim::RobotControl()
{
motiontime++;
for (int i = 0; i < 12; ++i)
{
motor_commands[i].mode = 0x0A;
motor_commands[i].q = output_dof_pos[0][i].item<double>();
motor_commands[i].dq = 0;
motor_commands[i].Kp = params.stiffness;
motor_commands[i].Kd = params.damping;
// motor_commands[i].tau = output_torques[0][i].item<double>();
motor_commands[i].tau = 0;
torque_publishers[joint_names[i]].publish(motor_commands[i]);
}
}
void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg) void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
{ {
@ -191,6 +171,11 @@ void RL_Sim::RunModel()
torch::Tensor actions = this->Forward(); torch::Tensor actions = this->Forward();
for (int i : hip_scale_reduction_indices)
{
actions[0][i] *= this->params.hip_scale_reduction;
}
output_torques = this->ComputeTorques(actions); output_torques = this->ComputeTorques(actions);
output_dof_pos = this->ComputePosition(actions); output_dof_pos = this->ComputePosition(actions);
} }
@ -230,6 +215,25 @@ torch::Tensor RL_Sim::Forward()
return clamped; return clamped;
} }
void RL_Sim::Plot()
{
int dof_mapping[13] = {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9};
plot_t.push_back(motiontime);
plt::cla();
plt::clf();
for(int i = 0; i < 12; ++i)
{
plot_real_joint_pos[i].push_back(joint_positions[dof_mapping[i]]);
plot_target_joint_pos[i].push_back(motor_commands[i].q);
plt::subplot(4, 3, i+1);
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r");
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b");
plt::xlim(motiontime-10000, motiontime);
}
// plt::legend();
plt::pause(0.0001);
}
void signalHandler(int signum) void signalHandler(int signum)
{ {
ros::shutdown(); ros::shutdown();