fix: add this to class member

This commit is contained in:
fan-ziqi 2024-05-29 12:56:38 +08:00
parent be79e7f64a
commit 8327b52d01
4 changed files with 235 additions and 236 deletions

View File

@ -51,12 +51,12 @@ private:
geometry_msgs::Twist vel; geometry_msgs::Twist vel;
geometry_msgs::Pose pose; geometry_msgs::Pose pose;
geometry_msgs::Twist cmd_vel; geometry_msgs::Twist cmd_vel;
std::vector<std::string> torque_command_topics;
ros::Subscriber model_state_subscriber; ros::Subscriber model_state_subscriber;
ros::Subscriber joint_state_subscriber; ros::Subscriber joint_state_subscriber;
ros::Subscriber cmd_vel_subscriber; ros::Subscriber cmd_vel_subscriber;
std::map<std::string, ros::Publisher> torque_publishers;
ros::ServiceClient gazebo_reset_client; ros::ServiceClient gazebo_reset_client;
std::map<std::string, ros::Publisher> joint_publishers;
std::vector<robot_msgs::MotorCommand> joint_publishers_commands;
void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
@ -64,7 +64,6 @@ private:
// others // others
int motiontime = 0; int motiontime = 0;
std::map<std::string, size_t> sorted_to_original_index; std::map<std::string, size_t> sorted_to_original_index;
std::vector<robot_msgs::MotorCommand> motor_commands;
std::vector<double> mapped_joint_positions; std::vector<double> mapped_joint_positions;
std::vector<double> mapped_joint_velocities; std::vector<double> mapped_joint_velocities;
std::vector<double> mapped_joint_efforts; std::vector<double> mapped_joint_efforts;

View File

@ -38,14 +38,14 @@ void RL::InitObservations()
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}}); this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}}); this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
this->obs.dof_pos = this->params.default_dof_pos; this->obs.dof_pos = this->params.default_dof_pos;
this->obs.dof_vel = torch::zeros({1, params.num_of_dofs}); this->obs.dof_vel = torch::zeros({1, this->params.num_of_dofs});
this->obs.actions = torch::zeros({1, params.num_of_dofs}); this->obs.actions = torch::zeros({1, this->params.num_of_dofs});
} }
void RL::InitOutputs() void RL::InitOutputs()
{ {
this->output_torques = torch::zeros({1, params.num_of_dofs}); this->output_torques = torch::zeros({1, this->params.num_of_dofs});
this->output_dof_pos = params.default_dof_pos; this->output_dof_pos = this->params.default_dof_pos;
} }
void RL::InitControl() void RL::InitControl()
@ -88,104 +88,104 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
static float getdown_percent = 0.0; static float getdown_percent = 0.0;
// waiting // waiting
if(running_state == STATE_WAITING) if(this->running_state == STATE_WAITING)
{ {
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = state->motor_state.q[i]; command->motor_command.q[i] = state->motor_state.q[i];
} }
if(control.control_state == STATE_POS_GETUP) if(this->control.control_state == STATE_POS_GETUP)
{ {
control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
getup_percent = 0.0; getup_percent = 0.0;
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
now_state.motor_state.q[i] = state->motor_state.q[i]; now_state.motor_state.q[i] = state->motor_state.q[i];
start_state.motor_state.q[i] = now_state.motor_state.q[i]; start_state.motor_state.q[i] = now_state.motor_state.q[i];
} }
running_state = STATE_POS_GETUP; this->running_state = STATE_POS_GETUP;
} }
} }
// stand up (position control) // stand up (position control)
else if(running_state == STATE_POS_GETUP) else if(this->running_state == STATE_POS_GETUP)
{ {
if(getup_percent < 1.0) if(getup_percent < 1.0)
{ {
getup_percent += 1 / 1000.0; getup_percent += 1 / 1000.0;
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent; getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * params.default_dof_pos[0][i].item<double>(); command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0; command->motor_command.dq[i] = 0;
command->motor_command.kp[i] = params.fixed_kp[0][i].item<double>(); command->motor_command.kp[i] = this->params.fixed_kp[0][i].item<double>();
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>(); command->motor_command.kd[i] = this->params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0; command->motor_command.tau[i] = 0;
} }
std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r"; std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r";
} }
if(control.control_state == STATE_RL_INIT) if(this->control.control_state == STATE_RL_INIT)
{ {
std::cout << std::endl; std::cout << std::endl;
control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
running_state = STATE_RL_INIT; this->running_state = STATE_RL_INIT;
} }
else if(control.control_state == STATE_POS_GETDOWN) else if(this->control.control_state == STATE_POS_GETDOWN)
{ {
control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
getdown_percent = 0.0; getdown_percent = 0.0;
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
now_state.motor_state.q[i] = state->motor_state.q[i]; now_state.motor_state.q[i] = state->motor_state.q[i];
} }
running_state = STATE_POS_GETDOWN; this->running_state = STATE_POS_GETDOWN;
} }
} }
// init obs and start rl loop // init obs and start rl loop
else if(running_state == STATE_RL_INIT) else if(this->running_state == STATE_RL_INIT)
{ {
if(getup_percent == 1) if(getup_percent == 1)
{ {
running_state = STATE_RL_RUNNING; this->running_state = STATE_RL_RUNNING;
this->InitObservations(); this->InitObservations();
this->InitOutputs(); this->InitOutputs();
this->InitControl(); this->InitControl();
} }
} }
// rl loop // rl loop
else if(running_state == STATE_RL_RUNNING) else if(this->running_state == STATE_RL_RUNNING)
{ {
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = output_dof_pos[0][i].item<double>(); command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0; command->motor_command.dq[i] = 0;
command->motor_command.kp[i] = params.rl_kp[0][i].item<double>(); command->motor_command.kp[i] = this->params.rl_kp[0][i].item<double>();
command->motor_command.kd[i] = params.rl_kd[0][i].item<double>(); command->motor_command.kd[i] = this->params.rl_kd[0][i].item<double>();
command->motor_command.tau[i] = 0; command->motor_command.tau[i] = 0;
} }
if(control.control_state == STATE_POS_GETDOWN) if(this->control.control_state == STATE_POS_GETDOWN)
{ {
control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
getdown_percent = 0.0; getdown_percent = 0.0;
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
now_state.motor_state.q[i] = state->motor_state.q[i]; now_state.motor_state.q[i] = state->motor_state.q[i];
} }
running_state = STATE_POS_GETDOWN; this->running_state = STATE_POS_GETDOWN;
} }
} }
// get down (position control) // get down (position control)
else if(running_state == STATE_POS_GETDOWN) else if(this->running_state == STATE_POS_GETDOWN)
{ {
if(getdown_percent < 1.0) if(getdown_percent < 1.0)
{ {
getdown_percent += 1 / 1000.0; getdown_percent += 1 / 1000.0;
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent; getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i]; command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i];
command->motor_command.dq[i] = 0; command->motor_command.dq[i] = 0;
command->motor_command.kp[i] = params.fixed_kp[0][i].item<double>(); command->motor_command.kp[i] = this->params.fixed_kp[0][i].item<double>();
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>(); command->motor_command.kd[i] = this->params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0; command->motor_command.tau[i] = 0;
} }
std::cout << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << "%\r"; std::cout << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << "%\r";
@ -193,7 +193,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
if(getdown_percent == 1) if(getdown_percent == 1)
{ {
std::cout << std::endl; std::cout << std::endl;
running_state = STATE_WAITING; this->running_state = STATE_WAITING;
this->InitObservations(); this->InitObservations();
this->InitOutputs(); this->InitOutputs();
this->InitControl(); this->InitControl();
@ -229,7 +229,7 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl; std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl; std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
} }
control.control_state = STATE_POS_GETDOWN; this->control.control_state = STATE_POS_GETDOWN;
} }
} }
@ -254,9 +254,9 @@ static bool kbhit()
void RL::KeyboardInterface() void RL::KeyboardInterface()
{ {
if(running_state == STATE_RL_RUNNING) if(this->running_state == STATE_RL_RUNNING)
{ {
std::cout << LOGGER::INFO << "RL Controller x:" << control.x << " y:" << control.y << " yaw:" << control.yaw << " \r"; std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r";
} }
if(kbhit()) if(kbhit())
@ -264,20 +264,20 @@ void RL::KeyboardInterface()
int c = fgetc(stdin); int c = fgetc(stdin);
switch(c) switch(c)
{ {
case '0': control.control_state = STATE_POS_GETUP; break; case '0': this->control.control_state = STATE_POS_GETUP; break;
case 'p': control.control_state = STATE_RL_INIT; break; case 'p': this->control.control_state = STATE_RL_INIT; break;
case '1': control.control_state = STATE_POS_GETDOWN; break; case '1': this->control.control_state = STATE_POS_GETDOWN; break;
case 'q': break; case 'q': break;
case 'w': control.x += 0.1; break; case 'w': this->control.x += 0.1; break;
case 's': control.x -= 0.1; break; case 's': this->control.x -= 0.1; break;
case 'a': control.yaw += 0.1; break; case 'a': this->control.yaw += 0.1; break;
case 'd': control.yaw -= 0.1; break; case 'd': this->control.yaw -= 0.1; break;
case 'i': break; case 'i': break;
case 'k': break; case 'k': break;
case 'j': control.y += 0.1; break; case 'j': this->control.y += 0.1; break;
case 'l': control.y -= 0.1; break; case 'l': this->control.y -= 0.1; break;
case ' ': control.x = 0; control.y = 0; control.yaw = 0; break; case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break;
case 'r': control.control_state = STATE_RESET_SIMULATION; break; case 'r': this->control.control_state = STATE_RESET_SIMULATION; break;
default: break; default: break;
} }
} }

View File

@ -8,13 +8,13 @@ RL_Real rl_sar;
RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_udp(UNITREE_LEGGED_SDK::LOWLEVEL) RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_udp(UNITREE_LEGGED_SDK::LOWLEVEL)
{ {
// read params from yaml // read params from yaml
robot_name = "a1"; this->robot_name = "a1";
ReadYaml(robot_name); this->ReadYaml(this->robot_name);
// history // history
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
unitree_udp.InitCmdData(unitree_low_command); this->unitree_udp.InitCmdData(this->unitree_low_command);
// init // init
torch::autograd::GradMode::set_enabled(false); torch::autograd::GradMode::set_enabled(false);
@ -23,117 +23,117 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->InitControl(); this->InitControl();
// model // model
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name; std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + this->robot_name + "/" + this->params.model_name;
this->model = torch::jit::load(model_path); this->model = torch::jit::load(model_path);
// loop // loop
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Real::KeyboardInterface, this)); this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Real::KeyboardInterface, this));
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl , this)); this->loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl , this));
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend , this)); this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend , this));
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv , this)); this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv , this));
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel , this)); this->loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel , this));
loop_keyboard->start(); this->loop_keyboard->start();
loop_udpSend->start(); this->loop_udpSend->start();
loop_udpRecv->start(); this->loop_udpRecv->start();
loop_control->start(); this->loop_control->start();
loop_rl->start(); this->loop_rl->start();
#ifdef PLOT #ifdef PLOT
plot_t = std::vector<int>(plot_size, 0); this->plot_t = std::vector<int>(this->plot_size, 0);
plot_real_joint_pos.resize(params.num_of_dofs); this->plot_real_joint_pos.resize(this->params.num_of_dofs);
plot_target_joint_pos.resize(params.num_of_dofs); this->plot_target_joint_pos.resize(this->params.num_of_dofs);
for(auto& vector : plot_real_joint_pos) { vector = std::vector<double>(plot_size, 0); } for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
for(auto& vector : plot_target_joint_pos) { vector = std::vector<double>(plot_size, 0); } for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this)); this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this));
loop_plot->start(); this->loop_plot->start();
#endif #endif
#ifdef CSV_LOGGER #ifdef CSV_LOGGER
CSVInit(robot_name); this->CSVInit(this->robot_name);
#endif #endif
} }
RL_Real::~RL_Real() RL_Real::~RL_Real()
{ {
loop_keyboard->shutdown(); this->loop_keyboard->shutdown();
loop_udpSend->shutdown(); this->loop_udpSend->shutdown();
loop_udpRecv->shutdown(); this->loop_udpRecv->shutdown();
loop_control->shutdown(); this->loop_control->shutdown();
loop_rl->shutdown(); this->loop_rl->shutdown();
#ifdef PLOT #ifdef PLOT
loop_plot->shutdown(); this->loop_plot->shutdown();
#endif #endif
std::cout << LOGGER::INFO << "RL_Real exit" << std::endl; std::cout << LOGGER::INFO << "RL_Real exit" << std::endl;
} }
void RL_Real::GetState(RobotState<double> *state) void RL_Real::GetState(RobotState<double> *state)
{ {
unitree_udp.GetRecv(unitree_low_state); this->unitree_udp.GetRecv(this->unitree_low_state);
memcpy(&unitree_joy, unitree_low_state.wirelessRemote, 40); memcpy(&this->unitree_joy, this->unitree_low_state.wirelessRemote, 40);
if((int)unitree_joy.btn.components.R2 == 1) if((int)this->unitree_joy.btn.components.R2 == 1)
{ {
control.control_state = STATE_POS_GETUP; this->control.control_state = STATE_POS_GETUP;
} }
else if((int)unitree_joy.btn.components.R1 == 1) else if((int)this->unitree_joy.btn.components.R1 == 1)
{ {
control.control_state = STATE_RL_INIT; this->control.control_state = STATE_RL_INIT;
} }
else if((int)unitree_joy.btn.components.L2 == 1) else if((int)this->unitree_joy.btn.components.L2 == 1)
{ {
control.control_state = STATE_POS_GETDOWN; this->control.control_state = STATE_POS_GETDOWN;
} }
state->imu.quaternion[3] = unitree_low_state.imu.quaternion[0]; // w state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[0]; // w
state->imu.quaternion[0] = unitree_low_state.imu.quaternion[1]; // x state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[1]; // x
state->imu.quaternion[1] = unitree_low_state.imu.quaternion[2]; // y state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[2]; // y
state->imu.quaternion[2] = unitree_low_state.imu.quaternion[3]; // z state->imu.quaternion[2] = this->unitree_low_state.imu.quaternion[3]; // z
for(int i = 0; i < 3; ++i) for(int i = 0; i < 3; ++i)
{ {
state->imu.gyroscope[i] = unitree_low_state.imu.gyroscope[i]; state->imu.gyroscope[i] = this->unitree_low_state.imu.gyroscope[i];
} }
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
state->motor_state.q[i] = unitree_low_state.motorState[state_mapping[i]].q; state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q;
state->motor_state.dq[i] = unitree_low_state.motorState[state_mapping[i]].dq; state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq;
state->motor_state.tauEst[i] = unitree_low_state.motorState[state_mapping[i]].tauEst; state->motor_state.tauEst[i] = this->unitree_low_state.motorState[state_mapping[i]].tauEst;
} }
} }
void RL_Real::SetCommand(const RobotCommand<double> *command) void RL_Real::SetCommand(const RobotCommand<double> *command)
{ {
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
unitree_low_command.motorCmd[i].mode = 0x0A; this->unitree_low_command.motorCmd[i].mode = 0x0A;
unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]]; this->unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]];
unitree_low_command.motorCmd[i].dq = command->motor_command.dq[command_mapping[i]]; this->unitree_low_command.motorCmd[i].dq = command->motor_command.dq[command_mapping[i]];
unitree_low_command.motorCmd[i].Kp = command->motor_command.kp[command_mapping[i]]; this->unitree_low_command.motorCmd[i].Kp = command->motor_command.kp[command_mapping[i]];
unitree_low_command.motorCmd[i].Kd = command->motor_command.kd[command_mapping[i]]; this->unitree_low_command.motorCmd[i].Kd = command->motor_command.kd[command_mapping[i]];
unitree_low_command.motorCmd[i].tau = command->motor_command.tau[command_mapping[i]]; this->unitree_low_command.motorCmd[i].tau = command->motor_command.tau[command_mapping[i]];
} }
unitree_safe.PowerProtect(unitree_low_command, unitree_low_state, 6); this->unitree_safe.PowerProtect(this->unitree_low_command, this->unitree_low_state, 6);
// unitree_safe.PositionProtect(unitree_low_command, unitree_low_state); // this->unitree_safe.PositionProtect(this->unitree_low_command, this->unitree_low_state);
unitree_udp.SetSend(unitree_low_command); this->unitree_udp.SetSend(this->unitree_low_command);
} }
void RL_Real::RobotControl() void RL_Real::RobotControl()
{ {
motiontime++; this->motiontime++;
GetState(&robot_state); this->GetState(&this->robot_state);
StateController(&robot_state, &robot_command); this->StateController(&this->robot_state, &this->robot_command);
SetCommand(&robot_command); this->SetCommand(&this->robot_command);
} }
void RL_Real::RunModel() void RL_Real::RunModel()
{ {
if(running_state == STATE_RL_RUNNING) if(this->running_state == STATE_RL_RUNNING)
{ {
this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0); this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
this->obs.commands = torch::tensor({{unitree_joy.ly, -unitree_joy.rx, -unitree_joy.lx}}); this->obs.commands = torch::tensor({{this->unitree_joy.ly, -this->unitree_joy.rx, -this->unitree_joy.lx}});
this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0); this->obs.base_quat = torch::tensor(this->robot_state.imu.quaternion).unsqueeze(0);
this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0); this->obs.dof_pos = torch::tensor(this->robot_state.motor_state.q).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0);
this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0); this->obs.dof_vel = torch::tensor(this->robot_state.motor_state.dq).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0);
torch::Tensor clamped_actions = this->Forward(); torch::Tensor clamped_actions = this->Forward();
@ -146,14 +146,14 @@ void RL_Real::RunModel()
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions); torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
TorqueProtect(origin_output_torques); this->TorqueProtect(origin_output_torques);
output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits); this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
output_dof_pos = this->ComputePosition(this->obs.actions); this->output_dof_pos = this->ComputePosition(this->obs.actions);
#ifdef CSV_LOGGER #ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor(robot_state.motor_state.tauEst).unsqueeze(0); torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0);
CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel); this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
#endif #endif
} }
} }
@ -178,10 +178,10 @@ torch::Tensor RL_Real::Forward()
torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor clamped_obs = this->ComputeObservation();
history_obs_buf.insert(clamped_obs); this->history_obs_buf.insert(clamped_obs);
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
torch::Tensor actions = this->model.forward({history_obs}).toTensor(); torch::Tensor actions = this->model.forward({this->history_obs}).toTensor();
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
@ -190,20 +190,20 @@ torch::Tensor RL_Real::Forward()
void RL_Real::Plot() void RL_Real::Plot()
{ {
plot_t.erase(plot_t.begin()); this->plot_t.erase(this->plot_t.begin());
plot_t.push_back(motiontime); this->plot_t.push_back(this->motiontime);
plt::cla(); plt::cla();
plt::clf(); plt::clf();
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
plot_real_joint_pos[i].erase(plot_real_joint_pos[i].begin()); this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
plot_target_joint_pos[i].erase(plot_target_joint_pos[i].begin()); this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
plot_real_joint_pos[i].push_back(unitree_low_state.motorState[i].q); this->plot_real_joint_pos[i].push_back(this->unitree_low_state.motorState[i].q);
plot_target_joint_pos[i].push_back(unitree_low_command.motorCmd[i].q); this->plot_target_joint_pos[i].push_back(this->unitree_low_command.motorCmd[i].q);
plt::subplot(4, 3, i + 1); plt::subplot(4, 3, i + 1);
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r");
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b");
plt::xlim(plot_t.front(), plot_t.back()); plt::xlim(this->plot_t.front(), this->plot_t.back());
} }
// plt::legend(); // plt::legend();
plt::pause(0.0001); plt::pause(0.0001);

View File

@ -8,180 +8,180 @@ RL_Sim::RL_Sim()
ros::NodeHandle nh; ros::NodeHandle nh;
// read params from yaml // read params from yaml
nh.param<std::string>("robot_name", robot_name, ""); nh.param<std::string>("robot_name", this->robot_name, "");
ReadYaml(robot_name); this->ReadYaml(this->robot_name);
// history // history
nh.param<bool>("use_history", use_history, ""); nh.param<bool>("use_history", this->use_history, "");
if(use_history) if(this->use_history)
{ {
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
} }
// Due to the fact that the robot_state_publisher sorts the joint names alphabetically, // Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
// the mapping table is established according to the order defined in the YAML file // the mapping table is established according to the order defined in the YAML file
std::vector<std::string> sorted_joint_controller_names = params.joint_controller_names; std::vector<std::string> sorted_joint_controller_names = this->params.joint_controller_names;
std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end()); std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end());
for(size_t i = 0; i < params.joint_controller_names.size(); ++i) for(size_t i = 0; i < this->params.joint_controller_names.size(); ++i)
{ {
sorted_to_original_index[sorted_joint_controller_names[i]] = i; this->sorted_to_original_index[sorted_joint_controller_names[i]] = i;
} }
mapped_joint_positions = std::vector<double>(params.num_of_dofs, 0.0); this->mapped_joint_positions = std::vector<double>(this->params.num_of_dofs, 0.0);
mapped_joint_velocities = std::vector<double>(params.num_of_dofs, 0.0); this->mapped_joint_velocities = std::vector<double>(this->params.num_of_dofs, 0.0);
mapped_joint_efforts = std::vector<double>(params.num_of_dofs, 0.0); this->mapped_joint_efforts = std::vector<double>(this->params.num_of_dofs, 0.0);
// init // init
torch::autograd::GradMode::set_enabled(false); torch::autograd::GradMode::set_enabled(false);
motor_commands.resize(params.num_of_dofs); this->joint_publishers_commands.resize(this->params.num_of_dofs);
this->InitObservations(); this->InitObservations();
this->InitOutputs(); this->InitOutputs();
this->InitControl(); this->InitControl();
// model // model
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name; std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + this->robot_name + "/" + this->params.model_name;
this->model = torch::jit::load(model_path); this->model = torch::jit::load(model_path);
// publisher // publisher
nh.param<std::string>("ros_namespace", ros_namespace, ""); nh.param<std::string>("ros_namespace", this->ros_namespace, "");
for (int i = 0; i < params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
// joint need to rename as xxx_joint // joint need to rename as xxx_joint
torque_publishers[params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>( this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>(
ros_namespace + params.joint_controller_names[i] + "/command", 10); this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
} }
// subscriber // subscriber
cmd_vel_subscriber = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this); this->cmd_vel_subscriber = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this);
model_state_subscriber = nh.subscribe<gazebo_msgs::ModelStates>("/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this); this->model_state_subscriber = nh.subscribe<gazebo_msgs::ModelStates>("/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this);
joint_state_subscriber = nh.subscribe<sensor_msgs::JointState>(ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this); this->joint_state_subscriber = nh.subscribe<sensor_msgs::JointState>(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this);
// service // service
gazebo_reset_client = nh.serviceClient<std_srvs::Empty>("/gazebo/reset_simulation"); this->gazebo_reset_client = nh.serviceClient<std_srvs::Empty>("/gazebo/reset_simulation");
// loop // loop
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Sim::KeyboardInterface, this)); this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Sim::KeyboardInterface, this));
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl , this)); this->loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl , this));
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel , this)); this->loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel , this));
loop_keyboard->start(); this->loop_keyboard->start();
loop_control->start(); this->loop_control->start();
loop_rl->start(); this->loop_rl->start();
#ifdef PLOT #ifdef PLOT
plot_t = std::vector<int>(plot_size, 0); plot_t = std::vector<int>(this->plot_size, 0);
plot_real_joint_pos.resize(params.num_of_dofs); this->plot_real_joint_pos.resize(this->params.num_of_dofs);
plot_target_joint_pos.resize(params.num_of_dofs); this->plot_target_joint_pos.resize(this->params.num_of_dofs);
for(auto& vector : plot_real_joint_pos) { vector = std::vector<double>(plot_size, 0); } for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
for(auto& vector : plot_target_joint_pos) { vector = std::vector<double>(plot_size, 0); } for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this)); this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this));
loop_plot->start(); this->loop_plot->start();
#endif #endif
#ifdef CSV_LOGGER #ifdef CSV_LOGGER
CSVInit(robot_name); this->CSVInit(this->robot_name);
#endif #endif
} }
RL_Sim::~RL_Sim() RL_Sim::~RL_Sim()
{ {
loop_keyboard->shutdown(); this->loop_keyboard->shutdown();
loop_control->shutdown(); this->loop_control->shutdown();
loop_rl->shutdown(); this->loop_rl->shutdown();
#ifdef PLOT #ifdef PLOT
loop_plot->shutdown(); this->loop_plot->shutdown();
#endif #endif
std::cout << LOGGER::INFO << "RL_Sim exit" << std::endl; std::cout << LOGGER::INFO << "RL_Sim exit" << std::endl;
} }
void RL_Sim::GetState(RobotState<double> *state) void RL_Sim::GetState(RobotState<double> *state)
{ {
state->imu.quaternion[3] = pose.orientation.w; state->imu.quaternion[3] = this->pose.orientation.w;
state->imu.quaternion[0] = pose.orientation.x; state->imu.quaternion[0] = this->pose.orientation.x;
state->imu.quaternion[1] = pose.orientation.y; state->imu.quaternion[1] = this->pose.orientation.y;
state->imu.quaternion[2] = pose.orientation.z; state->imu.quaternion[2] = this->pose.orientation.z;
state->imu.gyroscope[0] = vel.angular.x; state->imu.gyroscope[0] = this->vel.angular.x;
state->imu.gyroscope[1] = vel.angular.y; state->imu.gyroscope[1] = this->vel.angular.y;
state->imu.gyroscope[2] = vel.angular.z; state->imu.gyroscope[2] = this->vel.angular.z;
// state->imu.accelerometer // state->imu.accelerometer
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
state->motor_state.q[i] = mapped_joint_positions[i]; state->motor_state.q[i] = this->mapped_joint_positions[i];
state->motor_state.dq[i] = mapped_joint_velocities[i]; state->motor_state.dq[i] = this->mapped_joint_velocities[i];
state->motor_state.tauEst[i] = mapped_joint_efforts[i]; state->motor_state.tauEst[i] = this->mapped_joint_efforts[i];
} }
} }
void RL_Sim::SetCommand(const RobotCommand<double> *command) void RL_Sim::SetCommand(const RobotCommand<double> *command)
{ {
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
motor_commands[i].q = command->motor_command.q[i]; this->joint_publishers_commands[i].q = command->motor_command.q[i];
motor_commands[i].dq = command->motor_command.dq[i]; this->joint_publishers_commands[i].dq = command->motor_command.dq[i];
motor_commands[i].kp = command->motor_command.kp[i]; this->joint_publishers_commands[i].kp = command->motor_command.kp[i];
motor_commands[i].kd = command->motor_command.kd[i]; this->joint_publishers_commands[i].kd = command->motor_command.kd[i];
motor_commands[i].tau = command->motor_command.tau[i]; this->joint_publishers_commands[i].tau = command->motor_command.tau[i];
} }
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
torque_publishers[params.joint_controller_names[i]].publish(motor_commands[i]); this->joint_publishers[this->params.joint_controller_names[i]].publish(this->joint_publishers_commands[i]);
} }
} }
void RL_Sim::RobotControl() void RL_Sim::RobotControl()
{ {
motiontime++; this->motiontime++;
if(control.control_state == STATE_RESET_SIMULATION) if(this->control.control_state == STATE_RESET_SIMULATION)
{ {
control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
std_srvs::Empty srv; std_srvs::Empty srv;
gazebo_reset_client.call(srv); this->gazebo_reset_client.call(srv);
} }
GetState(&robot_state); this->GetState(&this->robot_state);
StateController(&robot_state, &robot_command); this->StateController(&this->robot_state, &this->robot_command);
SetCommand(&robot_command); this->SetCommand(&this->robot_command);
} }
void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg) void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
{ {
vel = msg->twist[2]; this->vel = msg->twist[2];
pose = msg->pose[2]; this->pose = msg->pose[2];
} }
void RL_Sim::CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg) void RL_Sim::CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
{ {
cmd_vel = *msg; this->cmd_vel = *msg;
} }
void RL_Sim::MapData(const std::vector<double>& source_data, std::vector<double>& target_data) void RL_Sim::MapData(const std::vector<double>& source_data, std::vector<double>& target_data)
{ {
for(size_t i = 0; i < source_data.size(); ++i) for(size_t i = 0; i < source_data.size(); ++i)
{ {
target_data[i] = source_data[sorted_to_original_index[params.joint_controller_names[i]]]; target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]];
} }
} }
void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg) void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
{ {
MapData(msg->position, mapped_joint_positions); MapData(msg->position, this->mapped_joint_positions);
MapData(msg->velocity, mapped_joint_velocities); MapData(msg->velocity, this->mapped_joint_velocities);
MapData(msg->effort, mapped_joint_efforts); MapData(msg->effort, this->mapped_joint_efforts);
} }
void RL_Sim::RunModel() void RL_Sim::RunModel()
{ {
if(running_state == STATE_RL_RUNNING) if(this->running_state == STATE_RL_RUNNING)
{ {
// this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}}); // this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}});
this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0); this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}}); // this->obs.commands = torch::tensor({{this->cmd_vel.linear.x, this->cmd_vel.linear.y, this->cmd_vel.angular.z}});
this->obs.commands = torch::tensor({{control.x, control.y, control.yaw}}); this->obs.commands = torch::tensor({{this->control.x, this->control.y, this->control.yaw}});
this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0); this->obs.base_quat = torch::tensor(this->robot_state.imu.quaternion).unsqueeze(0);
this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0); this->obs.dof_pos = torch::tensor(this->robot_state.motor_state.q).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0);
this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0); this->obs.dof_vel = torch::tensor(this->robot_state.motor_state.dq).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0);
torch::Tensor clamped_actions = this->Forward(); torch::Tensor clamped_actions = this->Forward();
@ -194,14 +194,14 @@ void RL_Sim::RunModel()
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions); torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
TorqueProtect(origin_output_torques); this->TorqueProtect(origin_output_torques);
output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits); this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
output_dof_pos = this->ComputePosition(this->obs.actions); this->output_dof_pos = this->ComputePosition(this->obs.actions);
#ifdef CSV_LOGGER #ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor(mapped_joint_efforts).unsqueeze(0); torch::Tensor tau_est = torch::tensor(this->mapped_joint_efforts).unsqueeze(0);
CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel); this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
#endif #endif
} }
} }
@ -229,11 +229,11 @@ torch::Tensor RL_Sim::Forward()
torch::Tensor actions; torch::Tensor actions;
if(use_history) if(this->use_history)
{ {
history_obs_buf.insert(clamped_obs); this->history_obs_buf.insert(clamped_obs);
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
actions = this->model.forward({history_obs}).toTensor(); actions = this->model.forward({this->history_obs}).toTensor();
} }
else else
{ {
@ -247,20 +247,20 @@ torch::Tensor RL_Sim::Forward()
void RL_Sim::Plot() void RL_Sim::Plot()
{ {
plot_t.erase(plot_t.begin()); this->plot_t.erase(this->plot_t.begin());
plot_t.push_back(motiontime); this->plot_t.push_back(this->motiontime);
plt::cla(); plt::cla();
plt::clf(); plt::clf();
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
plot_real_joint_pos[i].erase(plot_real_joint_pos[i].begin()); this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
plot_target_joint_pos[i].erase(plot_target_joint_pos[i].begin()); this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
plot_real_joint_pos[i].push_back(mapped_joint_positions[i]); this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]);
plot_target_joint_pos[i].push_back(motor_commands[i].q); this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q);
plt::subplot(4, 3, i+1); plt::subplot(4, 3, i+1);
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r");
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b");
plt::xlim(plot_t.front(), plot_t.back()); plt::xlim(this->plot_t.front(), this->plot_t.back());
} }
// plt::legend(); // plt::legend();
plt::pause(0.0001); plt::pause(0.0001);