fix: 1. change joint_state_subscriber to joint_subscribers

2. rename tauEst to tau_est
This commit is contained in:
fan-ziqi 2025-01-03 17:20:18 +08:00
parent da85366ac7
commit 87855a052d
10 changed files with 80 additions and 78 deletions

View File

@ -7,10 +7,10 @@
#include <ros/ros.h>
#include <sensor_msgs/Joy.h>
#include <gazebo_msgs/ModelStates.h>
#include <sensor_msgs/JointState.h>
#include "std_srvs/Empty.h"
#include <geometry_msgs/Twist.h>
#include "robot_msgs/MotorCommand.h"
#include "robot_msgs/MotorState.h"
#include <csignal>
#include <gazebo_msgs/SetModelState.h>
@ -54,27 +54,25 @@ private:
geometry_msgs::Twist cmd_vel;
sensor_msgs::Joy joy_msg;
ros::Subscriber model_state_subscriber;
ros::Subscriber joint_state_subscriber;
ros::Subscriber cmd_vel_subscriber;
ros::Subscriber joy_subscriber;
ros::ServiceClient gazebo_set_model_state_client;
ros::ServiceClient gazebo_pause_physics_client;
ros::ServiceClient gazebo_unpause_physics_client;
std::map<std::string, ros::Publisher> joint_publishers;
std::map<std::string, ros::Subscriber> joint_subscribers;
std::vector<robot_msgs::MotorCommand> joint_publishers_commands;
void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void JointStatesCallback(const robot_msgs::MotorState::ConstPtr &msg, const std::string &joint_name);
void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
void JoyCallback(const sensor_msgs::Joy::ConstPtr &msg);
// others
std::string gazebo_model_name;
int motiontime = 0;
std::map<std::string, size_t> sorted_to_original_index;
std::vector<double> mapped_joint_positions;
std::vector<double> mapped_joint_velocities;
std::vector<double> mapped_joint_efforts;
void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
std::map<std::string, double> joint_positions;
std::map<std::string, double> joint_velocities;
std::map<std::string, double> joint_efforts;
};
#endif // RL_SIM_HPP

View File

@ -45,7 +45,7 @@ struct RobotState
std::vector<T> q = std::vector<T>(32, 0.0);
std::vector<T> dq = std::vector<T>(32, 0.0);
std::vector<T> ddq = std::vector<T>(32, 0.0);
std::vector<T> tauEst = std::vector<T>(32, 0.0);
std::vector<T> tau_est = std::vector<T>(32, 0.0);
std::vector<T> cur = std::vector<T>(32, 0.0);
} motor_state;
};

View File

@ -41,7 +41,7 @@ class RobotState:
self.q = [0.0] * 32
self.dq = [0.0] * 32
self.ddq = [0.0] * 32
self.tauEst = [0.0] * 32
self.tau_est = [0.0] * 32
self.cur = [0.0] * 32
class STATE(Enum):

View File

@ -4,11 +4,9 @@ import torch
import threading
import time
import rospy
import numpy as np
from gazebo_msgs.msg import ModelStates
from sensor_msgs.msg import JointState
from geometry_msgs.msg import Twist, Pose
from robot_msgs.msg import MotorCommand
from robot_msgs.msg import MotorState, MotorCommand
from gazebo_msgs.srv import SetModelState, SetModelStateRequest
from std_srvs.srv import Empty
@ -42,22 +40,13 @@ class RL_Sim(RL):
if len(self.params.observations_history) != 0:
self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history))
# Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
# the mapping table is established according to the order defined in the YAML file
sorted_joint_controller_names = sorted(self.params.joint_controller_names)
self.sorted_to_original_index = {}
for i in range(len(self.params.joint_controller_names)):
self.sorted_to_original_index[sorted_joint_controller_names[i]] = i
self.mapped_joint_positions = [0.0] * self.params.num_of_dofs
self.mapped_joint_velocities = [0.0] * self.params.num_of_dofs
self.mapped_joint_efforts = [0.0] * self.params.num_of_dofs
# init
torch.set_grad_enabled(False)
self.joint_publishers_commands = [MotorCommand() for _ in range(self.params.num_of_dofs)]
self.InitObservations()
self.InitOutputs()
self.InitControl()
self.running_state = STATE.STATE_RL_RUNNING
# model
model_path = os.path.join(os.path.dirname(__file__), f"../models/{self.robot_name}/{self.params.model_name}")
@ -67,14 +56,29 @@ class RL_Sim(RL):
self.ros_namespace = rospy.get_param("ros_namespace", "")
self.joint_publishers = {}
for i in range(self.params.num_of_dofs):
topic_name = f"{self.ros_namespace}{self.params.joint_controller_names[i]}/command"
joint_name = self.params.joint_controller_names[i]
topic_name = f"{self.ros_namespace}{joint_name}/command"
self.joint_publishers[self.params.joint_controller_names[i]] = rospy.Publisher(topic_name, MotorCommand, queue_size=10)
# subscriber
self.cmd_vel_subscriber = rospy.Subscriber("/cmd_vel", Twist, self.CmdvelCallback, queue_size=10)
self.model_state_subscriber = rospy.Subscriber("/gazebo/model_states", ModelStates, self.ModelStatesCallback, queue_size=10)
joint_states_topic = f"{self.ros_namespace}joint_states"
self.joint_state_subscriber = rospy.Subscriber(joint_states_topic, JointState, self.JointStatesCallback, queue_size=10)
self.joint_subscribers = {}
self.joint_positions = {}
self.joint_velocities = {}
self.joint_efforts = {}
for i in range(self.params.num_of_dofs):
joint_name = self.params.joint_controller_names[i]
topic_name = f"{self.ros_namespace}{joint_name}/state"
self.joint_subscribers[joint_name] = rospy.Subscriber(
topic_name,
MotorState,
lambda msg, name=joint_name: self.JointStatesCallback(msg, name),
queue_size=10
)
self.joint_positions[joint_name] = 0.0
self.joint_velocities[joint_name] = 0.0
self.joint_efforts[joint_name] = 0.0
# service
self.gazebo_model_name = rospy.get_param("gazebo_model_name", "")
@ -120,9 +124,9 @@ class RL_Sim(RL):
# state.imu.accelerometer
for i in range(self.params.num_of_dofs):
state.motor_state.q[i] = self.mapped_joint_positions[i]
state.motor_state.dq[i] = self.mapped_joint_velocities[i]
state.motor_state.tauEst[i] = self.mapped_joint_efforts[i]
state.motor_state.q[i] = self.joint_positions[self.params.joint_controller_names[i]]
state.motor_state.dq[i] = self.joint_velocities[self.params.joint_controller_names[i]]
state.motor_state.tau_est[i] = self.joint_efforts[self.params.joint_controller_names[i]]
def SetCommand(self, command):
for i in range(self.params.num_of_dofs):
@ -165,14 +169,10 @@ class RL_Sim(RL):
def CmdvelCallback(self, msg):
self.cmd_vel = msg
def MapData(self, source_data, target_data):
for i in range(len(source_data)):
target_data[i] = source_data[self.sorted_to_original_index[self.params.joint_controller_names[i]]]
def JointStatesCallback(self, msg):
self.MapData(msg.position, self.mapped_joint_positions)
self.MapData(msg.velocity, self.mapped_joint_velocities)
self.MapData(msg.effort, self.mapped_joint_efforts)
def JointStatesCallback(self, msg, joint_name):
self.joint_positions[joint_name] = msg.q
self.joint_velocities[joint_name] = msg.dq
self.joint_efforts[joint_name] = msg.tau_est
def RunModel(self):
if self.running_state == STATE.STATE_RL_RUNNING and self.simulation_running:
@ -199,7 +199,9 @@ class RL_Sim(RL):
self.output_dof_pos = self.ComputePosition(self.obs.actions)
if CSV_LOGGER:
tau_est = torch.tensor(self.mapped_joint_efforts).unsqueeze(0)
tau_est = torch.zeros((1, self.params.num_of_dofs))
for i in range(self.params.num_of_dofs):
tau_est[0, i] = self.joint_efforts[self.params.joint_controller_names[i]]
self.CSVLogger(self.output_torques, tau_est, self.obs.dof_pos, self.output_dof_pos, self.obs.dof_vel)
def Forward(self):

View File

@ -116,7 +116,7 @@ void RL_Real::GetState(RobotState<double> *state)
{
state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q;
state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq;
state->motor_state.tauEst[i] = this->unitree_low_state.motorState[state_mapping[i]].tauEst;
state->motor_state.tau_est[i] = this->unitree_low_state.motorState[state_mapping[i]].tauEst;
}
}
@ -179,7 +179,7 @@ void RL_Real::RunModel()
this->output_dof_pos = this->ComputePosition(this->obs.actions);
#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0);
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tau_est).unsqueeze(0);
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
#endif
}

View File

@ -122,7 +122,7 @@ void RL_Real::GetState(RobotState<double> *state)
{
state->motor_state.q[i] = this->unitree_low_state.motor_state()[state_mapping[i]].q();
state->motor_state.dq[i] = this->unitree_low_state.motor_state()[state_mapping[i]].dq();
state->motor_state.tauEst[i] = this->unitree_low_state.motor_state()[state_mapping[i]].tau_est();
state->motor_state.tau_est[i] = this->unitree_low_state.motor_state()[state_mapping[i]].tau_est();
}
}
@ -184,7 +184,7 @@ void RL_Real::RunModel()
this->output_dof_pos = this->ComputePosition(this->obs.actions);
#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0);
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tau_est).unsqueeze(0);
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
#endif
}

View File

@ -18,17 +18,6 @@ RL_Sim::RL_Sim()
observation = "ang_vel_world";
}
}
// Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
// the mapping table is established according to the order defined in the YAML file
std::vector<std::string> sorted_joint_controller_names = this->params.joint_controller_names;
std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end());
for (size_t i = 0; i < this->params.joint_controller_names.size(); ++i)
{
this->sorted_to_original_index[sorted_joint_controller_names[i]] = i;
}
this->mapped_joint_positions = std::vector<double>(this->params.num_of_dofs, 0.0);
this->mapped_joint_velocities = std::vector<double>(this->params.num_of_dofs, 0.0);
this->mapped_joint_efforts = std::vector<double>(this->params.num_of_dofs, 0.0);
// init rl
torch::autograd::GradMode::set_enabled(false);
@ -52,15 +41,32 @@ RL_Sim::RL_Sim()
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
// joint need to rename as xxx_joint
this->joint_publishers[this->params.joint_controller_names[i]] =
nh.advertise<robot_msgs::MotorCommand>(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
const std::string &joint_name = this->params.joint_controller_names[i];
const std::string topic_name = this->ros_namespace + joint_name + "/command";
this->joint_publishers[joint_name] =
nh.advertise<robot_msgs::MotorCommand>(topic_name, 10);
}
// subscriber
this->cmd_vel_subscriber = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this);
this->joy_subscriber = nh.subscribe<sensor_msgs::Joy>("/joy", 10, &RL_Sim::JoyCallback, this);
this->model_state_subscriber = nh.subscribe<gazebo_msgs::ModelStates>("/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this);
this->joint_state_subscriber = nh.subscribe<sensor_msgs::JointState>(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this);
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
// joint need to rename as xxx_joint
const std::string &joint_name = this->params.joint_controller_names[i];
const std::string topic_name = this->ros_namespace + joint_name + "/state";
this->joint_subscribers[joint_name] =
nh.subscribe<robot_msgs::MotorState>(topic_name, 10,
[this, joint_name](const robot_msgs::MotorState::ConstPtr &msg)
{
this->JointStatesCallback(msg, joint_name);
}
);
this->joint_positions[joint_name] = 0.0;
this->joint_velocities[joint_name] = 0.0;
this->joint_efforts[joint_name] = 0.0;
}
// service
nh.param<std::string>("gazebo_model_name", this->gazebo_model_name, "");
@ -130,9 +136,9 @@ void RL_Sim::GetState(RobotState<double> *state)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
state->motor_state.q[i] = this->mapped_joint_positions[i];
state->motor_state.dq[i] = this->mapped_joint_velocities[i];
state->motor_state.tauEst[i] = this->mapped_joint_efforts[i];
state->motor_state.q[i] = this->joint_positions[this->params.joint_controller_names[i]];
state->motor_state.dq[i] = this->joint_velocities[this->params.joint_controller_names[i]];
state->motor_state.tau_est[i] = this->joint_efforts[this->params.joint_controller_names[i]];
}
}
@ -240,19 +246,11 @@ void RL_Sim::JoyCallback(const sensor_msgs::Joy::ConstPtr &msg)
this->control.yaw = this->joy_msg.axes[3] * 1.5; // Rx
}
void RL_Sim::MapData(const std::vector<double> &source_data, std::vector<double> &target_data)
void RL_Sim::JointStatesCallback(const robot_msgs::MotorState::ConstPtr &msg, const std::string &joint_name)
{
for (size_t i = 0; i < source_data.size(); ++i)
{
target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]];
}
}
void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
{
MapData(msg->position, this->mapped_joint_positions);
MapData(msg->velocity, this->mapped_joint_velocities);
MapData(msg->effort, this->mapped_joint_efforts);
this->joint_positions[joint_name] = msg->q;
this->joint_velocities[joint_name] = msg->dq;
this->joint_efforts[joint_name] = msg->tau_est;
}
void RL_Sim::RunModel()
@ -286,7 +284,11 @@ void RL_Sim::RunModel()
this->output_dof_pos = this->ComputePosition(this->obs.actions);
#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor(this->mapped_joint_efforts).unsqueeze(0);
torch::Tensor tau_est = torch::zeros({1, this->params.num_of_dofs});
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
tau_est[0][i] = this->joint_efforts[this->params.joint_controller_names[i]];
}
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
#endif
}
@ -330,7 +332,7 @@ void RL_Sim::Plot()
{
this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]);
this->plot_real_joint_pos[i].push_back(this->joint_positions[this->params.joint_controller_names[i]]);
this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q);
plt::subplot(4, 3, i + 1);
plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r");

View File

@ -111,7 +111,7 @@ namespace robot_joint_controller
lastCommand.dq = 0;
lastState.dq = 0;
lastCommand.tau = 0;
lastState.tauEst = 0;
lastState.tau_est = 0;
command.initRT(lastCommand);
pid_controller_.reset();
@ -158,15 +158,15 @@ namespace robot_joint_controller
lastState.q = currentPos;
lastState.dq = currentVel;
// lastState.tauEst = calcTorque;
lastState.tauEst = joint.getEffort();
// lastState.tau_est = calcTorque;
lastState.tau_est = joint.getEffort();
// publish state
if (controller_state_publisher_ && controller_state_publisher_->trylock())
{
controller_state_publisher_->msg_.q = lastState.q;
controller_state_publisher_->msg_.dq = lastState.dq;
controller_state_publisher_->msg_.tauEst = lastState.tauEst;
controller_state_publisher_->msg_.tau_est = lastState.tau_est;
controller_state_publisher_->unlockAndPublish();
}
}

View File

@ -1,5 +1,5 @@
float32 q # motor current position (rad)
float32 dq # motor current speed (rad/s)
float32 ddq # motor current speed (rad/s)
float32 tauEst # current estimated output torque (N*m)
float32 tau_est # current estimated output torque (N*m)
float32 cur # current estimated output cur (N*m)