diff --git a/src/rl_sar/include/rl_sim.hpp b/src/rl_sar/include/rl_sim.hpp index 5606d55..be5ce13 100644 --- a/src/rl_sar/include/rl_sim.hpp +++ b/src/rl_sar/include/rl_sim.hpp @@ -51,12 +51,12 @@ private: geometry_msgs::Twist vel; geometry_msgs::Pose pose; geometry_msgs::Twist cmd_vel; - std::vector torque_command_topics; ros::Subscriber model_state_subscriber; ros::Subscriber joint_state_subscriber; ros::Subscriber cmd_vel_subscriber; - std::map torque_publishers; ros::ServiceClient gazebo_reset_client; + std::map joint_publishers; + std::vector joint_publishers_commands; void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); @@ -64,7 +64,6 @@ private: // others int motiontime = 0; std::map sorted_to_original_index; - std::vector motor_commands; std::vector mapped_joint_positions; std::vector mapped_joint_velocities; std::vector mapped_joint_efforts; diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.cpp b/src/rl_sar/library/rl_sdk/rl_sdk.cpp index 215610a..89dfc25 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.cpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.cpp @@ -38,14 +38,14 @@ void RL::InitObservations() this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}}); this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}}); this->obs.dof_pos = this->params.default_dof_pos; - this->obs.dof_vel = torch::zeros({1, params.num_of_dofs}); - this->obs.actions = torch::zeros({1, params.num_of_dofs}); + this->obs.dof_vel = torch::zeros({1, this->params.num_of_dofs}); + this->obs.actions = torch::zeros({1, this->params.num_of_dofs}); } void RL::InitOutputs() { - this->output_torques = torch::zeros({1, params.num_of_dofs}); - this->output_dof_pos = params.default_dof_pos; + this->output_torques = torch::zeros({1, this->params.num_of_dofs}); + this->output_dof_pos = this->params.default_dof_pos; } void RL::InitControl() @@ -88,104 +88,104 @@ void RL::StateController(const RobotState *state, RobotCommand * static float getdown_percent = 0.0; // waiting - if(running_state == STATE_WAITING) + if(this->running_state == STATE_WAITING) { - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { command->motor_command.q[i] = state->motor_state.q[i]; } - if(control.control_state == STATE_POS_GETUP) + if(this->control.control_state == STATE_POS_GETUP) { - control.control_state = STATE_WAITING; + this->control.control_state = STATE_WAITING; getup_percent = 0.0; - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { now_state.motor_state.q[i] = state->motor_state.q[i]; start_state.motor_state.q[i] = now_state.motor_state.q[i]; } - running_state = STATE_POS_GETUP; + this->running_state = STATE_POS_GETUP; } } // stand up (position control) - else if(running_state == STATE_POS_GETUP) + else if(this->running_state == STATE_POS_GETUP) { if(getup_percent < 1.0) { getup_percent += 1 / 1000.0; getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent; - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { - command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * params.default_dof_pos[0][i].item(); + command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item(); command->motor_command.dq[i] = 0; - command->motor_command.kp[i] = params.fixed_kp[0][i].item(); - command->motor_command.kd[i] = params.fixed_kd[0][i].item(); + command->motor_command.kp[i] = this->params.fixed_kp[0][i].item(); + command->motor_command.kd[i] = this->params.fixed_kd[0][i].item(); command->motor_command.tau[i] = 0; } std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r"; } - if(control.control_state == STATE_RL_INIT) + if(this->control.control_state == STATE_RL_INIT) { std::cout << std::endl; - control.control_state = STATE_WAITING; - running_state = STATE_RL_INIT; + this->control.control_state = STATE_WAITING; + this->running_state = STATE_RL_INIT; } - else if(control.control_state == STATE_POS_GETDOWN) + else if(this->control.control_state == STATE_POS_GETDOWN) { - control.control_state = STATE_WAITING; + this->control.control_state = STATE_WAITING; getdown_percent = 0.0; - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { now_state.motor_state.q[i] = state->motor_state.q[i]; } - running_state = STATE_POS_GETDOWN; + this->running_state = STATE_POS_GETDOWN; } } // init obs and start rl loop - else if(running_state == STATE_RL_INIT) + else if(this->running_state == STATE_RL_INIT) { if(getup_percent == 1) { - running_state = STATE_RL_RUNNING; + this->running_state = STATE_RL_RUNNING; this->InitObservations(); this->InitOutputs(); this->InitControl(); } } // rl loop - else if(running_state == STATE_RL_RUNNING) + else if(this->running_state == STATE_RL_RUNNING) { - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { - command->motor_command.q[i] = output_dof_pos[0][i].item(); + command->motor_command.q[i] = this->output_dof_pos[0][i].item(); command->motor_command.dq[i] = 0; - command->motor_command.kp[i] = params.rl_kp[0][i].item(); - command->motor_command.kd[i] = params.rl_kd[0][i].item(); + command->motor_command.kp[i] = this->params.rl_kp[0][i].item(); + command->motor_command.kd[i] = this->params.rl_kd[0][i].item(); command->motor_command.tau[i] = 0; } - if(control.control_state == STATE_POS_GETDOWN) + if(this->control.control_state == STATE_POS_GETDOWN) { - control.control_state = STATE_WAITING; + this->control.control_state = STATE_WAITING; getdown_percent = 0.0; - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { now_state.motor_state.q[i] = state->motor_state.q[i]; } - running_state = STATE_POS_GETDOWN; + this->running_state = STATE_POS_GETDOWN; } } // get down (position control) - else if(running_state == STATE_POS_GETDOWN) + else if(this->running_state == STATE_POS_GETDOWN) { if(getdown_percent < 1.0) { getdown_percent += 1 / 1000.0; getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent; - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i]; command->motor_command.dq[i] = 0; - command->motor_command.kp[i] = params.fixed_kp[0][i].item(); - command->motor_command.kd[i] = params.fixed_kd[0][i].item(); + command->motor_command.kp[i] = this->params.fixed_kp[0][i].item(); + command->motor_command.kd[i] = this->params.fixed_kd[0][i].item(); command->motor_command.tau[i] = 0; } std::cout << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << "%\r"; @@ -193,7 +193,7 @@ void RL::StateController(const RobotState *state, RobotCommand * if(getdown_percent == 1) { std::cout << std::endl; - running_state = STATE_WAITING; + this->running_state = STATE_WAITING; this->InitObservations(); this->InitOutputs(); this->InitControl(); @@ -229,7 +229,7 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques) std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl; std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl; } - control.control_state = STATE_POS_GETDOWN; + this->control.control_state = STATE_POS_GETDOWN; } } @@ -254,9 +254,9 @@ static bool kbhit() void RL::KeyboardInterface() { - if(running_state == STATE_RL_RUNNING) + if(this->running_state == STATE_RL_RUNNING) { - std::cout << LOGGER::INFO << "RL Controller x:" << control.x << " y:" << control.y << " yaw:" << control.yaw << " \r"; + std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r"; } if(kbhit()) @@ -264,20 +264,20 @@ void RL::KeyboardInterface() int c = fgetc(stdin); switch(c) { - case '0': control.control_state = STATE_POS_GETUP; break; - case 'p': control.control_state = STATE_RL_INIT; break; - case '1': control.control_state = STATE_POS_GETDOWN; break; + case '0': this->control.control_state = STATE_POS_GETUP; break; + case 'p': this->control.control_state = STATE_RL_INIT; break; + case '1': this->control.control_state = STATE_POS_GETDOWN; break; case 'q': break; - case 'w': control.x += 0.1; break; - case 's': control.x -= 0.1; break; - case 'a': control.yaw += 0.1; break; - case 'd': control.yaw -= 0.1; break; + case 'w': this->control.x += 0.1; break; + case 's': this->control.x -= 0.1; break; + case 'a': this->control.yaw += 0.1; break; + case 'd': this->control.yaw -= 0.1; break; case 'i': break; case 'k': break; - case 'j': control.y += 0.1; break; - case 'l': control.y -= 0.1; break; - case ' ': control.x = 0; control.y = 0; control.yaw = 0; break; - case 'r': control.control_state = STATE_RESET_SIMULATION; break; + case 'j': this->control.y += 0.1; break; + case 'l': this->control.y -= 0.1; break; + case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break; + case 'r': this->control.control_state = STATE_RESET_SIMULATION; break; default: break; } } diff --git a/src/rl_sar/src/rl_real_a1.cpp b/src/rl_sar/src/rl_real_a1.cpp index 5b3d8fe..6b589ab 100644 --- a/src/rl_sar/src/rl_real_a1.cpp +++ b/src/rl_sar/src/rl_real_a1.cpp @@ -8,13 +8,13 @@ RL_Real rl_sar; RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_udp(UNITREE_LEGGED_SDK::LOWLEVEL) { // read params from yaml - robot_name = "a1"; - ReadYaml(robot_name); + this->robot_name = "a1"; + this->ReadYaml(this->robot_name); // history this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); - unitree_udp.InitCmdData(unitree_low_command); + this->unitree_udp.InitCmdData(this->unitree_low_command); // init torch::autograd::GradMode::set_enabled(false); @@ -23,117 +23,117 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u this->InitControl(); // model - std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name; + std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + this->robot_name + "/" + this->params.model_name; this->model = torch::jit::load(model_path); // loop - loop_keyboard = std::make_shared("loop_keyboard", 0.05 , boost::bind(&RL_Real::KeyboardInterface, this)); - loop_control = std::make_shared("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl , this)); - loop_udpSend = std::make_shared("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend , this)); - loop_udpRecv = std::make_shared("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv , this)); - loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel , this)); - loop_keyboard->start(); - loop_udpSend->start(); - loop_udpRecv->start(); - loop_control->start(); - loop_rl->start(); + this->loop_keyboard = std::make_shared("loop_keyboard", 0.05 , boost::bind(&RL_Real::KeyboardInterface, this)); + this->loop_control = std::make_shared("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl , this)); + this->loop_udpSend = std::make_shared("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend , this)); + this->loop_udpRecv = std::make_shared("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv , this)); + this->loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel , this)); + this->loop_keyboard->start(); + this->loop_udpSend->start(); + this->loop_udpRecv->start(); + this->loop_control->start(); + this->loop_rl->start(); #ifdef PLOT - plot_t = std::vector(plot_size, 0); - plot_real_joint_pos.resize(params.num_of_dofs); - plot_target_joint_pos.resize(params.num_of_dofs); - for(auto& vector : plot_real_joint_pos) { vector = std::vector(plot_size, 0); } - for(auto& vector : plot_target_joint_pos) { vector = std::vector(plot_size, 0); } - loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this)); - loop_plot->start(); + this->plot_t = std::vector(this->plot_size, 0); + this->plot_real_joint_pos.resize(this->params.num_of_dofs); + this->plot_target_joint_pos.resize(this->params.num_of_dofs); + for(auto& vector : this->plot_real_joint_pos) { vector = std::vector(this->plot_size, 0); } + for(auto& vector : this->plot_target_joint_pos) { vector = std::vector(this->plot_size, 0); } + this->loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this)); + this->loop_plot->start(); #endif #ifdef CSV_LOGGER - CSVInit(robot_name); + this->CSVInit(this->robot_name); #endif } RL_Real::~RL_Real() { - loop_keyboard->shutdown(); - loop_udpSend->shutdown(); - loop_udpRecv->shutdown(); - loop_control->shutdown(); - loop_rl->shutdown(); + this->loop_keyboard->shutdown(); + this->loop_udpSend->shutdown(); + this->loop_udpRecv->shutdown(); + this->loop_control->shutdown(); + this->loop_rl->shutdown(); #ifdef PLOT - loop_plot->shutdown(); + this->loop_plot->shutdown(); #endif std::cout << LOGGER::INFO << "RL_Real exit" << std::endl; } void RL_Real::GetState(RobotState *state) { - unitree_udp.GetRecv(unitree_low_state); - memcpy(&unitree_joy, unitree_low_state.wirelessRemote, 40); + this->unitree_udp.GetRecv(this->unitree_low_state); + memcpy(&this->unitree_joy, this->unitree_low_state.wirelessRemote, 40); - if((int)unitree_joy.btn.components.R2 == 1) + if((int)this->unitree_joy.btn.components.R2 == 1) { - control.control_state = STATE_POS_GETUP; + this->control.control_state = STATE_POS_GETUP; } - else if((int)unitree_joy.btn.components.R1 == 1) + else if((int)this->unitree_joy.btn.components.R1 == 1) { - control.control_state = STATE_RL_INIT; + this->control.control_state = STATE_RL_INIT; } - else if((int)unitree_joy.btn.components.L2 == 1) + else if((int)this->unitree_joy.btn.components.L2 == 1) { - control.control_state = STATE_POS_GETDOWN; + this->control.control_state = STATE_POS_GETDOWN; } - state->imu.quaternion[3] = unitree_low_state.imu.quaternion[0]; // w - state->imu.quaternion[0] = unitree_low_state.imu.quaternion[1]; // x - state->imu.quaternion[1] = unitree_low_state.imu.quaternion[2]; // y - state->imu.quaternion[2] = unitree_low_state.imu.quaternion[3]; // z + state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[0]; // w + state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[1]; // x + state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[2]; // y + state->imu.quaternion[2] = this->unitree_low_state.imu.quaternion[3]; // z for(int i = 0; i < 3; ++i) { - state->imu.gyroscope[i] = unitree_low_state.imu.gyroscope[i]; + state->imu.gyroscope[i] = this->unitree_low_state.imu.gyroscope[i]; } - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { - state->motor_state.q[i] = unitree_low_state.motorState[state_mapping[i]].q; - state->motor_state.dq[i] = unitree_low_state.motorState[state_mapping[i]].dq; - state->motor_state.tauEst[i] = unitree_low_state.motorState[state_mapping[i]].tauEst; + state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q; + state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq; + state->motor_state.tauEst[i] = this->unitree_low_state.motorState[state_mapping[i]].tauEst; } } void RL_Real::SetCommand(const RobotCommand *command) { - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { - unitree_low_command.motorCmd[i].mode = 0x0A; - unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]]; - unitree_low_command.motorCmd[i].dq = command->motor_command.dq[command_mapping[i]]; - unitree_low_command.motorCmd[i].Kp = command->motor_command.kp[command_mapping[i]]; - unitree_low_command.motorCmd[i].Kd = command->motor_command.kd[command_mapping[i]]; - unitree_low_command.motorCmd[i].tau = command->motor_command.tau[command_mapping[i]]; + this->unitree_low_command.motorCmd[i].mode = 0x0A; + this->unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]]; + this->unitree_low_command.motorCmd[i].dq = command->motor_command.dq[command_mapping[i]]; + this->unitree_low_command.motorCmd[i].Kp = command->motor_command.kp[command_mapping[i]]; + this->unitree_low_command.motorCmd[i].Kd = command->motor_command.kd[command_mapping[i]]; + this->unitree_low_command.motorCmd[i].tau = command->motor_command.tau[command_mapping[i]]; } - unitree_safe.PowerProtect(unitree_low_command, unitree_low_state, 6); - // unitree_safe.PositionProtect(unitree_low_command, unitree_low_state); - unitree_udp.SetSend(unitree_low_command); + this->unitree_safe.PowerProtect(this->unitree_low_command, this->unitree_low_state, 6); + // this->unitree_safe.PositionProtect(this->unitree_low_command, this->unitree_low_state); + this->unitree_udp.SetSend(this->unitree_low_command); } void RL_Real::RobotControl() { - motiontime++; + this->motiontime++; - GetState(&robot_state); - StateController(&robot_state, &robot_command); - SetCommand(&robot_command); + this->GetState(&this->robot_state); + this->StateController(&this->robot_state, &this->robot_command); + this->SetCommand(&this->robot_command); } void RL_Real::RunModel() { - if(running_state == STATE_RL_RUNNING) + if(this->running_state == STATE_RL_RUNNING) { - this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0); - this->obs.commands = torch::tensor({{unitree_joy.ly, -unitree_joy.rx, -unitree_joy.lx}}); - this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0); - this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0); - this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0); + this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0); + this->obs.commands = torch::tensor({{this->unitree_joy.ly, -this->unitree_joy.rx, -this->unitree_joy.lx}}); + this->obs.base_quat = torch::tensor(this->robot_state.imu.quaternion).unsqueeze(0); + this->obs.dof_pos = torch::tensor(this->robot_state.motor_state.q).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0); + this->obs.dof_vel = torch::tensor(this->robot_state.motor_state.dq).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0); torch::Tensor clamped_actions = this->Forward(); @@ -146,14 +146,14 @@ void RL_Real::RunModel() torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions); - TorqueProtect(origin_output_torques); + this->TorqueProtect(origin_output_torques); - output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits); - output_dof_pos = this->ComputePosition(this->obs.actions); + this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits); + this->output_dof_pos = this->ComputePosition(this->obs.actions); #ifdef CSV_LOGGER - torch::Tensor tau_est = torch::tensor(robot_state.motor_state.tauEst).unsqueeze(0); - CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel); + torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0); + this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel); #endif } } @@ -178,10 +178,10 @@ torch::Tensor RL_Real::Forward() torch::Tensor clamped_obs = this->ComputeObservation(); - history_obs_buf.insert(clamped_obs); - history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); + this->history_obs_buf.insert(clamped_obs); + this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); - torch::Tensor actions = this->model.forward({history_obs}).toTensor(); + torch::Tensor actions = this->model.forward({this->history_obs}).toTensor(); torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); @@ -190,20 +190,20 @@ torch::Tensor RL_Real::Forward() void RL_Real::Plot() { - plot_t.erase(plot_t.begin()); - plot_t.push_back(motiontime); + this->plot_t.erase(this->plot_t.begin()); + this->plot_t.push_back(this->motiontime); plt::cla(); plt::clf(); - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { - plot_real_joint_pos[i].erase(plot_real_joint_pos[i].begin()); - plot_target_joint_pos[i].erase(plot_target_joint_pos[i].begin()); - plot_real_joint_pos[i].push_back(unitree_low_state.motorState[i].q); - plot_target_joint_pos[i].push_back(unitree_low_command.motorCmd[i].q); + this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin()); + this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin()); + this->plot_real_joint_pos[i].push_back(this->unitree_low_state.motorState[i].q); + this->plot_target_joint_pos[i].push_back(this->unitree_low_command.motorCmd[i].q); plt::subplot(4, 3, i + 1); - plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); - plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); - plt::xlim(plot_t.front(), plot_t.back()); + plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r"); + plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b"); + plt::xlim(this->plot_t.front(), this->plot_t.back()); } // plt::legend(); plt::pause(0.0001); diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 821199c..c01604b 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -8,180 +8,180 @@ RL_Sim::RL_Sim() ros::NodeHandle nh; // read params from yaml - nh.param("robot_name", robot_name, ""); - ReadYaml(robot_name); + nh.param("robot_name", this->robot_name, ""); + this->ReadYaml(this->robot_name); // history - nh.param("use_history", use_history, ""); - if(use_history) + nh.param("use_history", this->use_history, ""); + if(this->use_history) { this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); } // Due to the fact that the robot_state_publisher sorts the joint names alphabetically, // the mapping table is established according to the order defined in the YAML file - std::vector sorted_joint_controller_names = params.joint_controller_names; + std::vector sorted_joint_controller_names = this->params.joint_controller_names; std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end()); - for(size_t i = 0; i < params.joint_controller_names.size(); ++i) + for(size_t i = 0; i < this->params.joint_controller_names.size(); ++i) { - sorted_to_original_index[sorted_joint_controller_names[i]] = i; + this->sorted_to_original_index[sorted_joint_controller_names[i]] = i; } - mapped_joint_positions = std::vector(params.num_of_dofs, 0.0); - mapped_joint_velocities = std::vector(params.num_of_dofs, 0.0); - mapped_joint_efforts = std::vector(params.num_of_dofs, 0.0); + this->mapped_joint_positions = std::vector(this->params.num_of_dofs, 0.0); + this->mapped_joint_velocities = std::vector(this->params.num_of_dofs, 0.0); + this->mapped_joint_efforts = std::vector(this->params.num_of_dofs, 0.0); // init torch::autograd::GradMode::set_enabled(false); - motor_commands.resize(params.num_of_dofs); + this->joint_publishers_commands.resize(this->params.num_of_dofs); this->InitObservations(); this->InitOutputs(); this->InitControl(); // model - std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name; + std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + this->robot_name + "/" + this->params.model_name; this->model = torch::jit::load(model_path); // publisher - nh.param("ros_namespace", ros_namespace, ""); - for (int i = 0; i < params.num_of_dofs; ++i) + nh.param("ros_namespace", this->ros_namespace, ""); + for (int i = 0; i < this->params.num_of_dofs; ++i) { // joint need to rename as xxx_joint - torque_publishers[params.joint_controller_names[i]] = nh.advertise( - ros_namespace + params.joint_controller_names[i] + "/command", 10); + this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise( + this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10); } // subscriber - cmd_vel_subscriber = nh.subscribe("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this); - model_state_subscriber = nh.subscribe("/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this); - joint_state_subscriber = nh.subscribe(ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this); + this->cmd_vel_subscriber = nh.subscribe("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this); + this->model_state_subscriber = nh.subscribe("/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this); + this->joint_state_subscriber = nh.subscribe(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this); // service - gazebo_reset_client = nh.serviceClient("/gazebo/reset_simulation"); + this->gazebo_reset_client = nh.serviceClient("/gazebo/reset_simulation"); // loop - loop_keyboard = std::make_shared("loop_keyboard", 0.05 , boost::bind(&RL_Sim::KeyboardInterface, this)); - loop_control = std::make_shared("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl , this)); - loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel , this)); - loop_keyboard->start(); - loop_control->start(); - loop_rl->start(); + this->loop_keyboard = std::make_shared("loop_keyboard", 0.05 , boost::bind(&RL_Sim::KeyboardInterface, this)); + this->loop_control = std::make_shared("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl , this)); + this->loop_rl = std::make_shared("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel , this)); + this->loop_keyboard->start(); + this->loop_control->start(); + this->loop_rl->start(); #ifdef PLOT - plot_t = std::vector(plot_size, 0); - plot_real_joint_pos.resize(params.num_of_dofs); - plot_target_joint_pos.resize(params.num_of_dofs); - for(auto& vector : plot_real_joint_pos) { vector = std::vector(plot_size, 0); } - for(auto& vector : plot_target_joint_pos) { vector = std::vector(plot_size, 0); } - loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this)); - loop_plot->start(); + plot_t = std::vector(this->plot_size, 0); + this->plot_real_joint_pos.resize(this->params.num_of_dofs); + this->plot_target_joint_pos.resize(this->params.num_of_dofs); + for(auto& vector : this->plot_real_joint_pos) { vector = std::vector(this->plot_size, 0); } + for(auto& vector : this->plot_target_joint_pos) { vector = std::vector(this->plot_size, 0); } + this->loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this)); + this->loop_plot->start(); #endif #ifdef CSV_LOGGER - CSVInit(robot_name); + this->CSVInit(this->robot_name); #endif } RL_Sim::~RL_Sim() { - loop_keyboard->shutdown(); - loop_control->shutdown(); - loop_rl->shutdown(); + this->loop_keyboard->shutdown(); + this->loop_control->shutdown(); + this->loop_rl->shutdown(); #ifdef PLOT - loop_plot->shutdown(); + this->loop_plot->shutdown(); #endif std::cout << LOGGER::INFO << "RL_Sim exit" << std::endl; } void RL_Sim::GetState(RobotState *state) { - state->imu.quaternion[3] = pose.orientation.w; - state->imu.quaternion[0] = pose.orientation.x; - state->imu.quaternion[1] = pose.orientation.y; - state->imu.quaternion[2] = pose.orientation.z; + state->imu.quaternion[3] = this->pose.orientation.w; + state->imu.quaternion[0] = this->pose.orientation.x; + state->imu.quaternion[1] = this->pose.orientation.y; + state->imu.quaternion[2] = this->pose.orientation.z; - state->imu.gyroscope[0] = vel.angular.x; - state->imu.gyroscope[1] = vel.angular.y; - state->imu.gyroscope[2] = vel.angular.z; + state->imu.gyroscope[0] = this->vel.angular.x; + state->imu.gyroscope[1] = this->vel.angular.y; + state->imu.gyroscope[2] = this->vel.angular.z; // state->imu.accelerometer - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { - state->motor_state.q[i] = mapped_joint_positions[i]; - state->motor_state.dq[i] = mapped_joint_velocities[i]; - state->motor_state.tauEst[i] = mapped_joint_efforts[i]; + state->motor_state.q[i] = this->mapped_joint_positions[i]; + state->motor_state.dq[i] = this->mapped_joint_velocities[i]; + state->motor_state.tauEst[i] = this->mapped_joint_efforts[i]; } } void RL_Sim::SetCommand(const RobotCommand *command) { - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { - motor_commands[i].q = command->motor_command.q[i]; - motor_commands[i].dq = command->motor_command.dq[i]; - motor_commands[i].kp = command->motor_command.kp[i]; - motor_commands[i].kd = command->motor_command.kd[i]; - motor_commands[i].tau = command->motor_command.tau[i]; + this->joint_publishers_commands[i].q = command->motor_command.q[i]; + this->joint_publishers_commands[i].dq = command->motor_command.dq[i]; + this->joint_publishers_commands[i].kp = command->motor_command.kp[i]; + this->joint_publishers_commands[i].kd = command->motor_command.kd[i]; + this->joint_publishers_commands[i].tau = command->motor_command.tau[i]; } - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { - torque_publishers[params.joint_controller_names[i]].publish(motor_commands[i]); + this->joint_publishers[this->params.joint_controller_names[i]].publish(this->joint_publishers_commands[i]); } } void RL_Sim::RobotControl() { - motiontime++; + this->motiontime++; - if(control.control_state == STATE_RESET_SIMULATION) + if(this->control.control_state == STATE_RESET_SIMULATION) { - control.control_state = STATE_WAITING; + this->control.control_state = STATE_WAITING; std_srvs::Empty srv; - gazebo_reset_client.call(srv); + this->gazebo_reset_client.call(srv); } - GetState(&robot_state); - StateController(&robot_state, &robot_command); - SetCommand(&robot_command); + this->GetState(&this->robot_state); + this->StateController(&this->robot_state, &this->robot_command); + this->SetCommand(&this->robot_command); } void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg) { - vel = msg->twist[2]; - pose = msg->pose[2]; + this->vel = msg->twist[2]; + this->pose = msg->pose[2]; } void RL_Sim::CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg) { - cmd_vel = *msg; + this->cmd_vel = *msg; } void RL_Sim::MapData(const std::vector& source_data, std::vector& target_data) { for(size_t i = 0; i < source_data.size(); ++i) { - target_data[i] = source_data[sorted_to_original_index[params.joint_controller_names[i]]]; + target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]]; } } void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg) { - MapData(msg->position, mapped_joint_positions); - MapData(msg->velocity, mapped_joint_velocities); - MapData(msg->effort, mapped_joint_efforts); + MapData(msg->position, this->mapped_joint_positions); + MapData(msg->velocity, this->mapped_joint_velocities); + MapData(msg->effort, this->mapped_joint_efforts); } void RL_Sim::RunModel() { - if(running_state == STATE_RL_RUNNING) + if(this->running_state == STATE_RL_RUNNING) { - // this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}}); - this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0); - // this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}}); - this->obs.commands = torch::tensor({{control.x, control.y, control.yaw}}); - this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0); - this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0); - this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0); + // this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}}); + this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0); + // this->obs.commands = torch::tensor({{this->cmd_vel.linear.x, this->cmd_vel.linear.y, this->cmd_vel.angular.z}}); + this->obs.commands = torch::tensor({{this->control.x, this->control.y, this->control.yaw}}); + this->obs.base_quat = torch::tensor(this->robot_state.imu.quaternion).unsqueeze(0); + this->obs.dof_pos = torch::tensor(this->robot_state.motor_state.q).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0); + this->obs.dof_vel = torch::tensor(this->robot_state.motor_state.dq).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0); torch::Tensor clamped_actions = this->Forward(); @@ -194,14 +194,14 @@ void RL_Sim::RunModel() torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions); - TorqueProtect(origin_output_torques); + this->TorqueProtect(origin_output_torques); - output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits); - output_dof_pos = this->ComputePosition(this->obs.actions); + this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits); + this->output_dof_pos = this->ComputePosition(this->obs.actions); #ifdef CSV_LOGGER - torch::Tensor tau_est = torch::tensor(mapped_joint_efforts).unsqueeze(0); - CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel); + torch::Tensor tau_est = torch::tensor(this->mapped_joint_efforts).unsqueeze(0); + this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel); #endif } } @@ -229,11 +229,11 @@ torch::Tensor RL_Sim::Forward() torch::Tensor actions; - if(use_history) + if(this->use_history) { - history_obs_buf.insert(clamped_obs); - history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); - actions = this->model.forward({history_obs}).toTensor(); + this->history_obs_buf.insert(clamped_obs); + this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); + actions = this->model.forward({this->history_obs}).toTensor(); } else { @@ -247,20 +247,20 @@ torch::Tensor RL_Sim::Forward() void RL_Sim::Plot() { - plot_t.erase(plot_t.begin()); - plot_t.push_back(motiontime); + this->plot_t.erase(this->plot_t.begin()); + this->plot_t.push_back(this->motiontime); plt::cla(); plt::clf(); - for(int i = 0; i < params.num_of_dofs; ++i) + for(int i = 0; i < this->params.num_of_dofs; ++i) { - plot_real_joint_pos[i].erase(plot_real_joint_pos[i].begin()); - plot_target_joint_pos[i].erase(plot_target_joint_pos[i].begin()); - plot_real_joint_pos[i].push_back(mapped_joint_positions[i]); - plot_target_joint_pos[i].push_back(motor_commands[i].q); + this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin()); + this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin()); + this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]); + this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q); plt::subplot(4, 3, i+1); - plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); - plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); - plt::xlim(plot_t.front(), plot_t.back()); + plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r"); + plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b"); + plt::xlim(this->plot_t.front(), this->plot_t.back()); } // plt::legend(); plt::pause(0.0001);