diff --git a/src/rl_sar/include/rl_sim.hpp b/src/rl_sar/include/rl_sim.hpp index 3951978..086e268 100644 --- a/src/rl_sar/include/rl_sim.hpp +++ b/src/rl_sar/include/rl_sim.hpp @@ -7,10 +7,10 @@ #include #include #include -#include #include "std_srvs/Empty.h" #include #include "robot_msgs/MotorCommand.h" +#include "robot_msgs/MotorState.h" #include #include @@ -54,27 +54,25 @@ private: geometry_msgs::Twist cmd_vel; sensor_msgs::Joy joy_msg; ros::Subscriber model_state_subscriber; - ros::Subscriber joint_state_subscriber; ros::Subscriber cmd_vel_subscriber; ros::Subscriber joy_subscriber; ros::ServiceClient gazebo_set_model_state_client; ros::ServiceClient gazebo_pause_physics_client; ros::ServiceClient gazebo_unpause_physics_client; std::map joint_publishers; + std::map joint_subscribers; std::vector joint_publishers_commands; void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); - void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); + void JointStatesCallback(const robot_msgs::MotorState::ConstPtr &msg, const std::string &joint_name); void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); void JoyCallback(const sensor_msgs::Joy::ConstPtr &msg); // others std::string gazebo_model_name; int motiontime = 0; - std::map sorted_to_original_index; - std::vector mapped_joint_positions; - std::vector mapped_joint_velocities; - std::vector mapped_joint_efforts; - void MapData(const std::vector &source_data, std::vector &target_data); + std::map joint_positions; + std::map joint_velocities; + std::map joint_efforts; }; #endif // RL_SIM_HPP diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.hpp b/src/rl_sar/library/rl_sdk/rl_sdk.hpp index b31bcfe..1afd0f6 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.hpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.hpp @@ -45,7 +45,7 @@ struct RobotState std::vector q = std::vector(32, 0.0); std::vector dq = std::vector(32, 0.0); std::vector ddq = std::vector(32, 0.0); - std::vector tauEst = std::vector(32, 0.0); + std::vector tau_est = std::vector(32, 0.0); std::vector cur = std::vector(32, 0.0); } motor_state; }; diff --git a/src/rl_sar/models/a1_isaacsim/policy.pt b/src/rl_sar/models/a1_isaacsim/policy.pt index 8dfe1eb..f047fa0 100644 Binary files a/src/rl_sar/models/a1_isaacsim/policy.pt and b/src/rl_sar/models/a1_isaacsim/policy.pt differ diff --git a/src/rl_sar/scripts/rl_sdk.py b/src/rl_sar/scripts/rl_sdk.py index 2761b30..ab426ff 100644 --- a/src/rl_sar/scripts/rl_sdk.py +++ b/src/rl_sar/scripts/rl_sdk.py @@ -41,7 +41,7 @@ class RobotState: self.q = [0.0] * 32 self.dq = [0.0] * 32 self.ddq = [0.0] * 32 - self.tauEst = [0.0] * 32 + self.tau_est = [0.0] * 32 self.cur = [0.0] * 32 class STATE(Enum): diff --git a/src/rl_sar/scripts/rl_sim.py b/src/rl_sar/scripts/rl_sim.py index fc88643..091cf08 100644 --- a/src/rl_sar/scripts/rl_sim.py +++ b/src/rl_sar/scripts/rl_sim.py @@ -4,11 +4,9 @@ import torch import threading import time import rospy -import numpy as np from gazebo_msgs.msg import ModelStates -from sensor_msgs.msg import JointState from geometry_msgs.msg import Twist, Pose -from robot_msgs.msg import MotorCommand +from robot_msgs.msg import MotorState, MotorCommand from gazebo_msgs.srv import SetModelState, SetModelStateRequest from std_srvs.srv import Empty @@ -42,22 +40,13 @@ class RL_Sim(RL): if len(self.params.observations_history) != 0: self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history)) - # Due to the fact that the robot_state_publisher sorts the joint names alphabetically, - # the mapping table is established according to the order defined in the YAML file - sorted_joint_controller_names = sorted(self.params.joint_controller_names) - self.sorted_to_original_index = {} - for i in range(len(self.params.joint_controller_names)): - self.sorted_to_original_index[sorted_joint_controller_names[i]] = i - self.mapped_joint_positions = [0.0] * self.params.num_of_dofs - self.mapped_joint_velocities = [0.0] * self.params.num_of_dofs - self.mapped_joint_efforts = [0.0] * self.params.num_of_dofs - # init torch.set_grad_enabled(False) self.joint_publishers_commands = [MotorCommand() for _ in range(self.params.num_of_dofs)] self.InitObservations() self.InitOutputs() self.InitControl() + self.running_state = STATE.STATE_RL_RUNNING # model model_path = os.path.join(os.path.dirname(__file__), f"../models/{self.robot_name}/{self.params.model_name}") @@ -67,14 +56,29 @@ class RL_Sim(RL): self.ros_namespace = rospy.get_param("ros_namespace", "") self.joint_publishers = {} for i in range(self.params.num_of_dofs): - topic_name = f"{self.ros_namespace}{self.params.joint_controller_names[i]}/command" + joint_name = self.params.joint_controller_names[i] + topic_name = f"{self.ros_namespace}{joint_name}/command" self.joint_publishers[self.params.joint_controller_names[i]] = rospy.Publisher(topic_name, MotorCommand, queue_size=10) # subscriber self.cmd_vel_subscriber = rospy.Subscriber("/cmd_vel", Twist, self.CmdvelCallback, queue_size=10) self.model_state_subscriber = rospy.Subscriber("/gazebo/model_states", ModelStates, self.ModelStatesCallback, queue_size=10) - joint_states_topic = f"{self.ros_namespace}joint_states" - self.joint_state_subscriber = rospy.Subscriber(joint_states_topic, JointState, self.JointStatesCallback, queue_size=10) + self.joint_subscribers = {} + self.joint_positions = {} + self.joint_velocities = {} + self.joint_efforts = {} + for i in range(self.params.num_of_dofs): + joint_name = self.params.joint_controller_names[i] + topic_name = f"{self.ros_namespace}{joint_name}/state" + self.joint_subscribers[joint_name] = rospy.Subscriber( + topic_name, + MotorState, + lambda msg, name=joint_name: self.JointStatesCallback(msg, name), + queue_size=10 + ) + self.joint_positions[joint_name] = 0.0 + self.joint_velocities[joint_name] = 0.0 + self.joint_efforts[joint_name] = 0.0 # service self.gazebo_model_name = rospy.get_param("gazebo_model_name", "") @@ -120,9 +124,9 @@ class RL_Sim(RL): # state.imu.accelerometer for i in range(self.params.num_of_dofs): - state.motor_state.q[i] = self.mapped_joint_positions[i] - state.motor_state.dq[i] = self.mapped_joint_velocities[i] - state.motor_state.tauEst[i] = self.mapped_joint_efforts[i] + state.motor_state.q[i] = self.joint_positions[self.params.joint_controller_names[i]] + state.motor_state.dq[i] = self.joint_velocities[self.params.joint_controller_names[i]] + state.motor_state.tau_est[i] = self.joint_efforts[self.params.joint_controller_names[i]] def SetCommand(self, command): for i in range(self.params.num_of_dofs): @@ -165,14 +169,10 @@ class RL_Sim(RL): def CmdvelCallback(self, msg): self.cmd_vel = msg - def MapData(self, source_data, target_data): - for i in range(len(source_data)): - target_data[i] = source_data[self.sorted_to_original_index[self.params.joint_controller_names[i]]] - - def JointStatesCallback(self, msg): - self.MapData(msg.position, self.mapped_joint_positions) - self.MapData(msg.velocity, self.mapped_joint_velocities) - self.MapData(msg.effort, self.mapped_joint_efforts) + def JointStatesCallback(self, msg, joint_name): + self.joint_positions[joint_name] = msg.q + self.joint_velocities[joint_name] = msg.dq + self.joint_efforts[joint_name] = msg.tau_est def RunModel(self): if self.running_state == STATE.STATE_RL_RUNNING and self.simulation_running: @@ -199,7 +199,9 @@ class RL_Sim(RL): self.output_dof_pos = self.ComputePosition(self.obs.actions) if CSV_LOGGER: - tau_est = torch.tensor(self.mapped_joint_efforts).unsqueeze(0) + tau_est = torch.zeros((1, self.params.num_of_dofs)) + for i in range(self.params.num_of_dofs): + tau_est[0, i] = self.joint_efforts[self.params.joint_controller_names[i]] self.CSVLogger(self.output_torques, tau_est, self.obs.dof_pos, self.output_dof_pos, self.obs.dof_vel) def Forward(self): diff --git a/src/rl_sar/src/rl_real_a1.cpp b/src/rl_sar/src/rl_real_a1.cpp index 8411f2d..b7e005a 100644 --- a/src/rl_sar/src/rl_real_a1.cpp +++ b/src/rl_sar/src/rl_real_a1.cpp @@ -116,7 +116,7 @@ void RL_Real::GetState(RobotState *state) { state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q; state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq; - state->motor_state.tauEst[i] = this->unitree_low_state.motorState[state_mapping[i]].tauEst; + state->motor_state.tau_est[i] = this->unitree_low_state.motorState[state_mapping[i]].tauEst; } } @@ -179,7 +179,7 @@ void RL_Real::RunModel() this->output_dof_pos = this->ComputePosition(this->obs.actions); #ifdef CSV_LOGGER - torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0); + torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tau_est).unsqueeze(0); this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel); #endif } diff --git a/src/rl_sar/src/rl_real_go2.cpp b/src/rl_sar/src/rl_real_go2.cpp index 66b929b..9a8fe5b 100644 --- a/src/rl_sar/src/rl_real_go2.cpp +++ b/src/rl_sar/src/rl_real_go2.cpp @@ -122,7 +122,7 @@ void RL_Real::GetState(RobotState *state) { state->motor_state.q[i] = this->unitree_low_state.motor_state()[state_mapping[i]].q(); state->motor_state.dq[i] = this->unitree_low_state.motor_state()[state_mapping[i]].dq(); - state->motor_state.tauEst[i] = this->unitree_low_state.motor_state()[state_mapping[i]].tau_est(); + state->motor_state.tau_est[i] = this->unitree_low_state.motor_state()[state_mapping[i]].tau_est(); } } @@ -184,7 +184,7 @@ void RL_Real::RunModel() this->output_dof_pos = this->ComputePosition(this->obs.actions); #ifdef CSV_LOGGER - torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0); + torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tau_est).unsqueeze(0); this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel); #endif } diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 0001475..db09a4b 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -18,17 +18,6 @@ RL_Sim::RL_Sim() observation = "ang_vel_world"; } } - // Due to the fact that the robot_state_publisher sorts the joint names alphabetically, - // the mapping table is established according to the order defined in the YAML file - std::vector sorted_joint_controller_names = this->params.joint_controller_names; - std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end()); - for (size_t i = 0; i < this->params.joint_controller_names.size(); ++i) - { - this->sorted_to_original_index[sorted_joint_controller_names[i]] = i; - } - this->mapped_joint_positions = std::vector(this->params.num_of_dofs, 0.0); - this->mapped_joint_velocities = std::vector(this->params.num_of_dofs, 0.0); - this->mapped_joint_efforts = std::vector(this->params.num_of_dofs, 0.0); // init rl torch::autograd::GradMode::set_enabled(false); @@ -52,15 +41,32 @@ RL_Sim::RL_Sim() for (int i = 0; i < this->params.num_of_dofs; ++i) { // joint need to rename as xxx_joint - this->joint_publishers[this->params.joint_controller_names[i]] = - nh.advertise(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10); + const std::string &joint_name = this->params.joint_controller_names[i]; + const std::string topic_name = this->ros_namespace + joint_name + "/command"; + this->joint_publishers[joint_name] = + nh.advertise(topic_name, 10); } // subscriber this->cmd_vel_subscriber = nh.subscribe("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this); this->joy_subscriber = nh.subscribe("/joy", 10, &RL_Sim::JoyCallback, this); this->model_state_subscriber = nh.subscribe("/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this); - this->joint_state_subscriber = nh.subscribe(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this); + for (int i = 0; i < this->params.num_of_dofs; ++i) + { + // joint need to rename as xxx_joint + const std::string &joint_name = this->params.joint_controller_names[i]; + const std::string topic_name = this->ros_namespace + joint_name + "/state"; + this->joint_subscribers[joint_name] = + nh.subscribe(topic_name, 10, + [this, joint_name](const robot_msgs::MotorState::ConstPtr &msg) + { + this->JointStatesCallback(msg, joint_name); + } + ); + this->joint_positions[joint_name] = 0.0; + this->joint_velocities[joint_name] = 0.0; + this->joint_efforts[joint_name] = 0.0; + } // service nh.param("gazebo_model_name", this->gazebo_model_name, ""); @@ -130,9 +136,9 @@ void RL_Sim::GetState(RobotState *state) for (int i = 0; i < this->params.num_of_dofs; ++i) { - state->motor_state.q[i] = this->mapped_joint_positions[i]; - state->motor_state.dq[i] = this->mapped_joint_velocities[i]; - state->motor_state.tauEst[i] = this->mapped_joint_efforts[i]; + state->motor_state.q[i] = this->joint_positions[this->params.joint_controller_names[i]]; + state->motor_state.dq[i] = this->joint_velocities[this->params.joint_controller_names[i]]; + state->motor_state.tau_est[i] = this->joint_efforts[this->params.joint_controller_names[i]]; } } @@ -240,19 +246,11 @@ void RL_Sim::JoyCallback(const sensor_msgs::Joy::ConstPtr &msg) this->control.yaw = this->joy_msg.axes[3] * 1.5; // Rx } -void RL_Sim::MapData(const std::vector &source_data, std::vector &target_data) +void RL_Sim::JointStatesCallback(const robot_msgs::MotorState::ConstPtr &msg, const std::string &joint_name) { - for (size_t i = 0; i < source_data.size(); ++i) - { - target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]]; - } -} - -void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg) -{ - MapData(msg->position, this->mapped_joint_positions); - MapData(msg->velocity, this->mapped_joint_velocities); - MapData(msg->effort, this->mapped_joint_efforts); + this->joint_positions[joint_name] = msg->q; + this->joint_velocities[joint_name] = msg->dq; + this->joint_efforts[joint_name] = msg->tau_est; } void RL_Sim::RunModel() @@ -286,7 +284,11 @@ void RL_Sim::RunModel() this->output_dof_pos = this->ComputePosition(this->obs.actions); #ifdef CSV_LOGGER - torch::Tensor tau_est = torch::tensor(this->mapped_joint_efforts).unsqueeze(0); + torch::Tensor tau_est = torch::zeros({1, this->params.num_of_dofs}); + for (int i = 0; i < this->params.num_of_dofs; ++i) + { + tau_est[0][i] = this->joint_efforts[this->params.joint_controller_names[i]]; + } this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel); #endif } @@ -330,7 +332,7 @@ void RL_Sim::Plot() { this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin()); this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin()); - this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]); + this->plot_real_joint_pos[i].push_back(this->joint_positions[this->params.joint_controller_names[i]]); this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q); plt::subplot(4, 3, i + 1); plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r"); diff --git a/src/robot_joint_controller/src/robot_joint_controller.cpp b/src/robot_joint_controller/src/robot_joint_controller.cpp index 41d652c..86e077e 100644 --- a/src/robot_joint_controller/src/robot_joint_controller.cpp +++ b/src/robot_joint_controller/src/robot_joint_controller.cpp @@ -111,7 +111,7 @@ namespace robot_joint_controller lastCommand.dq = 0; lastState.dq = 0; lastCommand.tau = 0; - lastState.tauEst = 0; + lastState.tau_est = 0; command.initRT(lastCommand); pid_controller_.reset(); @@ -158,15 +158,15 @@ namespace robot_joint_controller lastState.q = currentPos; lastState.dq = currentVel; - // lastState.tauEst = calcTorque; - lastState.tauEst = joint.getEffort(); + // lastState.tau_est = calcTorque; + lastState.tau_est = joint.getEffort(); // publish state if (controller_state_publisher_ && controller_state_publisher_->trylock()) { controller_state_publisher_->msg_.q = lastState.q; controller_state_publisher_->msg_.dq = lastState.dq; - controller_state_publisher_->msg_.tauEst = lastState.tauEst; + controller_state_publisher_->msg_.tau_est = lastState.tau_est; controller_state_publisher_->unlockAndPublish(); } } diff --git a/src/robot_msgs/msg/MotorState.msg b/src/robot_msgs/msg/MotorState.msg index fbc5477..ceeb066 100644 --- a/src/robot_msgs/msg/MotorState.msg +++ b/src/robot_msgs/msg/MotorState.msg @@ -1,5 +1,5 @@ float32 q # motor current position (rad) float32 dq # motor current speed (rad/s) float32 ddq # motor current speed (rad/s) -float32 tauEst # current estimated output torque (N*m) +float32 tau_est # current estimated output torque (N*m) float32 cur # current estimated output cur (N*m) \ No newline at end of file