feat: add CSV_LOGGER && format

This commit is contained in:
fan-ziqi 2024-04-15 13:11:33 +08:00
parent e79596c1f1
commit 956072401f
1 changed files with 21 additions and 10 deletions

View File

@ -3,6 +3,7 @@
#define ROBOT_NAME "a1" #define ROBOT_NAME "a1"
// #define PLOT // #define PLOT
#define CSV_LOGGER
RL_Real rl_sar; RL_Real rl_sar;
@ -42,6 +43,10 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this)); loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this));
loop_plot->start(); loop_plot->start();
#endif #endif
#ifdef CSV_LOGGER
CSVInit(ROBOT_NAME);
#endif
} }
RL_Real::~RL_Real() RL_Real::~RL_Real()
@ -97,7 +102,7 @@ void RL_Real::RobotControl()
cmd.motorCmd[i].Kd = 3; cmd.motorCmd[i].Kd = 3;
cmd.motorCmd[i].tau = 0; cmd.motorCmd[i].tau = 0;
} }
printf("getting up %.3f%%\r", getup_percent*100.0); printf("getting up %.3f%%\r", getup_percent * 100.0);
} }
if((int)_keyData.btn.components.R1 == 1) if((int)_keyData.btn.components.R1 == 1)
{ {
@ -167,7 +172,7 @@ void RL_Real::RobotControl()
cmd.motorCmd[i].Kd = 3; cmd.motorCmd[i].Kd = 3;
cmd.motorCmd[i].tau = 0; cmd.motorCmd[i].tau = 0;
} }
printf("getting down %.3f%%\r", getdown_percent*100.0); printf("getting down %.3f%%\r", getdown_percent * 100.0);
} }
if(getdown_percent == 1) if(getdown_percent == 1)
{ {
@ -227,8 +232,14 @@ void RL_Real::RunModel()
output_torques = this->ComputeTorques(actions); output_torques = this->ComputeTorques(actions);
output_dof_pos = this->ComputePosition(actions); output_dof_pos = this->ComputePosition(actions);
#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor({{state.motorState[FL_0].tauEst, state.motorState[FL_1].tauEst, state.motorState[FL_2].tauEst,
state.motorState[FR_0].tauEst, state.motorState[FR_1].tauEst, state.motorState[FR_2].tauEst,
state.motorState[RL_0].tauEst, state.motorState[RL_1].tauEst, state.motorState[RL_2].tauEst,
state.motorState[RR_0].tauEst, state.motorState[RR_1].tauEst, state.motorState[RR_2].tauEst}});
CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel);
#endif
} }
} }
torch::Tensor RL_Real::ComputeObservation() torch::Tensor RL_Real::ComputeObservation()
@ -275,10 +286,10 @@ void RL_Real::Plot()
{ {
plot_real_joint_pos[i].push_back(state.motorState[i].q); plot_real_joint_pos[i].push_back(state.motorState[i].q);
plot_target_joint_pos[i].push_back(cmd.motorCmd[i].q); plot_target_joint_pos[i].push_back(cmd.motorCmd[i].q);
plt::subplot(4, 3, i+1); plt::subplot(4, 3, i + 1);
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r");
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b");
plt::xlim(motiontime-10000, motiontime); plt::xlim(motiontime - 10000, motiontime);
} }
// plt::legend(); // plt::legend();
plt::pause(0.0001); plt::pause(0.0001);