mirror of https://github.com/fan-ziqi/rl_sar.git
fix: 1. change joint_state_subscriber to joint_subscribers
2. rename tauEst to tau_est
This commit is contained in:
parent
da85366ac7
commit
87855a052d
|
@ -7,10 +7,10 @@
|
|||
#include <ros/ros.h>
|
||||
#include <sensor_msgs/Joy.h>
|
||||
#include <gazebo_msgs/ModelStates.h>
|
||||
#include <sensor_msgs/JointState.h>
|
||||
#include "std_srvs/Empty.h"
|
||||
#include <geometry_msgs/Twist.h>
|
||||
#include "robot_msgs/MotorCommand.h"
|
||||
#include "robot_msgs/MotorState.h"
|
||||
#include <csignal>
|
||||
#include <gazebo_msgs/SetModelState.h>
|
||||
|
||||
|
@ -54,27 +54,25 @@ private:
|
|||
geometry_msgs::Twist cmd_vel;
|
||||
sensor_msgs::Joy joy_msg;
|
||||
ros::Subscriber model_state_subscriber;
|
||||
ros::Subscriber joint_state_subscriber;
|
||||
ros::Subscriber cmd_vel_subscriber;
|
||||
ros::Subscriber joy_subscriber;
|
||||
ros::ServiceClient gazebo_set_model_state_client;
|
||||
ros::ServiceClient gazebo_pause_physics_client;
|
||||
ros::ServiceClient gazebo_unpause_physics_client;
|
||||
std::map<std::string, ros::Publisher> joint_publishers;
|
||||
std::map<std::string, ros::Subscriber> joint_subscribers;
|
||||
std::vector<robot_msgs::MotorCommand> joint_publishers_commands;
|
||||
void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||
void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||
void JointStatesCallback(const robot_msgs::MotorState::ConstPtr &msg, const std::string &joint_name);
|
||||
void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
||||
void JoyCallback(const sensor_msgs::Joy::ConstPtr &msg);
|
||||
|
||||
// others
|
||||
std::string gazebo_model_name;
|
||||
int motiontime = 0;
|
||||
std::map<std::string, size_t> sorted_to_original_index;
|
||||
std::vector<double> mapped_joint_positions;
|
||||
std::vector<double> mapped_joint_velocities;
|
||||
std::vector<double> mapped_joint_efforts;
|
||||
void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
|
||||
std::map<std::string, double> joint_positions;
|
||||
std::map<std::string, double> joint_velocities;
|
||||
std::map<std::string, double> joint_efforts;
|
||||
};
|
||||
|
||||
#endif // RL_SIM_HPP
|
||||
|
|
|
@ -45,7 +45,7 @@ struct RobotState
|
|||
std::vector<T> q = std::vector<T>(32, 0.0);
|
||||
std::vector<T> dq = std::vector<T>(32, 0.0);
|
||||
std::vector<T> ddq = std::vector<T>(32, 0.0);
|
||||
std::vector<T> tauEst = std::vector<T>(32, 0.0);
|
||||
std::vector<T> tau_est = std::vector<T>(32, 0.0);
|
||||
std::vector<T> cur = std::vector<T>(32, 0.0);
|
||||
} motor_state;
|
||||
};
|
||||
|
|
Binary file not shown.
|
@ -41,7 +41,7 @@ class RobotState:
|
|||
self.q = [0.0] * 32
|
||||
self.dq = [0.0] * 32
|
||||
self.ddq = [0.0] * 32
|
||||
self.tauEst = [0.0] * 32
|
||||
self.tau_est = [0.0] * 32
|
||||
self.cur = [0.0] * 32
|
||||
|
||||
class STATE(Enum):
|
||||
|
|
|
@ -4,11 +4,9 @@ import torch
|
|||
import threading
|
||||
import time
|
||||
import rospy
|
||||
import numpy as np
|
||||
from gazebo_msgs.msg import ModelStates
|
||||
from sensor_msgs.msg import JointState
|
||||
from geometry_msgs.msg import Twist, Pose
|
||||
from robot_msgs.msg import MotorCommand
|
||||
from robot_msgs.msg import MotorState, MotorCommand
|
||||
from gazebo_msgs.srv import SetModelState, SetModelStateRequest
|
||||
from std_srvs.srv import Empty
|
||||
|
||||
|
@ -42,22 +40,13 @@ class RL_Sim(RL):
|
|||
if len(self.params.observations_history) != 0:
|
||||
self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history))
|
||||
|
||||
# Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
|
||||
# the mapping table is established according to the order defined in the YAML file
|
||||
sorted_joint_controller_names = sorted(self.params.joint_controller_names)
|
||||
self.sorted_to_original_index = {}
|
||||
for i in range(len(self.params.joint_controller_names)):
|
||||
self.sorted_to_original_index[sorted_joint_controller_names[i]] = i
|
||||
self.mapped_joint_positions = [0.0] * self.params.num_of_dofs
|
||||
self.mapped_joint_velocities = [0.0] * self.params.num_of_dofs
|
||||
self.mapped_joint_efforts = [0.0] * self.params.num_of_dofs
|
||||
|
||||
# init
|
||||
torch.set_grad_enabled(False)
|
||||
self.joint_publishers_commands = [MotorCommand() for _ in range(self.params.num_of_dofs)]
|
||||
self.InitObservations()
|
||||
self.InitOutputs()
|
||||
self.InitControl()
|
||||
self.running_state = STATE.STATE_RL_RUNNING
|
||||
|
||||
# model
|
||||
model_path = os.path.join(os.path.dirname(__file__), f"../models/{self.robot_name}/{self.params.model_name}")
|
||||
|
@ -67,14 +56,29 @@ class RL_Sim(RL):
|
|||
self.ros_namespace = rospy.get_param("ros_namespace", "")
|
||||
self.joint_publishers = {}
|
||||
for i in range(self.params.num_of_dofs):
|
||||
topic_name = f"{self.ros_namespace}{self.params.joint_controller_names[i]}/command"
|
||||
joint_name = self.params.joint_controller_names[i]
|
||||
topic_name = f"{self.ros_namespace}{joint_name}/command"
|
||||
self.joint_publishers[self.params.joint_controller_names[i]] = rospy.Publisher(topic_name, MotorCommand, queue_size=10)
|
||||
|
||||
# subscriber
|
||||
self.cmd_vel_subscriber = rospy.Subscriber("/cmd_vel", Twist, self.CmdvelCallback, queue_size=10)
|
||||
self.model_state_subscriber = rospy.Subscriber("/gazebo/model_states", ModelStates, self.ModelStatesCallback, queue_size=10)
|
||||
joint_states_topic = f"{self.ros_namespace}joint_states"
|
||||
self.joint_state_subscriber = rospy.Subscriber(joint_states_topic, JointState, self.JointStatesCallback, queue_size=10)
|
||||
self.joint_subscribers = {}
|
||||
self.joint_positions = {}
|
||||
self.joint_velocities = {}
|
||||
self.joint_efforts = {}
|
||||
for i in range(self.params.num_of_dofs):
|
||||
joint_name = self.params.joint_controller_names[i]
|
||||
topic_name = f"{self.ros_namespace}{joint_name}/state"
|
||||
self.joint_subscribers[joint_name] = rospy.Subscriber(
|
||||
topic_name,
|
||||
MotorState,
|
||||
lambda msg, name=joint_name: self.JointStatesCallback(msg, name),
|
||||
queue_size=10
|
||||
)
|
||||
self.joint_positions[joint_name] = 0.0
|
||||
self.joint_velocities[joint_name] = 0.0
|
||||
self.joint_efforts[joint_name] = 0.0
|
||||
|
||||
# service
|
||||
self.gazebo_model_name = rospy.get_param("gazebo_model_name", "")
|
||||
|
@ -120,9 +124,9 @@ class RL_Sim(RL):
|
|||
# state.imu.accelerometer
|
||||
|
||||
for i in range(self.params.num_of_dofs):
|
||||
state.motor_state.q[i] = self.mapped_joint_positions[i]
|
||||
state.motor_state.dq[i] = self.mapped_joint_velocities[i]
|
||||
state.motor_state.tauEst[i] = self.mapped_joint_efforts[i]
|
||||
state.motor_state.q[i] = self.joint_positions[self.params.joint_controller_names[i]]
|
||||
state.motor_state.dq[i] = self.joint_velocities[self.params.joint_controller_names[i]]
|
||||
state.motor_state.tau_est[i] = self.joint_efforts[self.params.joint_controller_names[i]]
|
||||
|
||||
def SetCommand(self, command):
|
||||
for i in range(self.params.num_of_dofs):
|
||||
|
@ -165,14 +169,10 @@ class RL_Sim(RL):
|
|||
def CmdvelCallback(self, msg):
|
||||
self.cmd_vel = msg
|
||||
|
||||
def MapData(self, source_data, target_data):
|
||||
for i in range(len(source_data)):
|
||||
target_data[i] = source_data[self.sorted_to_original_index[self.params.joint_controller_names[i]]]
|
||||
|
||||
def JointStatesCallback(self, msg):
|
||||
self.MapData(msg.position, self.mapped_joint_positions)
|
||||
self.MapData(msg.velocity, self.mapped_joint_velocities)
|
||||
self.MapData(msg.effort, self.mapped_joint_efforts)
|
||||
def JointStatesCallback(self, msg, joint_name):
|
||||
self.joint_positions[joint_name] = msg.q
|
||||
self.joint_velocities[joint_name] = msg.dq
|
||||
self.joint_efforts[joint_name] = msg.tau_est
|
||||
|
||||
def RunModel(self):
|
||||
if self.running_state == STATE.STATE_RL_RUNNING and self.simulation_running:
|
||||
|
@ -199,7 +199,9 @@ class RL_Sim(RL):
|
|||
self.output_dof_pos = self.ComputePosition(self.obs.actions)
|
||||
|
||||
if CSV_LOGGER:
|
||||
tau_est = torch.tensor(self.mapped_joint_efforts).unsqueeze(0)
|
||||
tau_est = torch.zeros((1, self.params.num_of_dofs))
|
||||
for i in range(self.params.num_of_dofs):
|
||||
tau_est[0, i] = self.joint_efforts[self.params.joint_controller_names[i]]
|
||||
self.CSVLogger(self.output_torques, tau_est, self.obs.dof_pos, self.output_dof_pos, self.obs.dof_vel)
|
||||
|
||||
def Forward(self):
|
||||
|
|
|
@ -116,7 +116,7 @@ void RL_Real::GetState(RobotState<double> *state)
|
|||
{
|
||||
state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q;
|
||||
state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq;
|
||||
state->motor_state.tauEst[i] = this->unitree_low_state.motorState[state_mapping[i]].tauEst;
|
||||
state->motor_state.tau_est[i] = this->unitree_low_state.motorState[state_mapping[i]].tauEst;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -179,7 +179,7 @@ void RL_Real::RunModel()
|
|||
this->output_dof_pos = this->ComputePosition(this->obs.actions);
|
||||
|
||||
#ifdef CSV_LOGGER
|
||||
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0);
|
||||
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tau_est).unsqueeze(0);
|
||||
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -122,7 +122,7 @@ void RL_Real::GetState(RobotState<double> *state)
|
|||
{
|
||||
state->motor_state.q[i] = this->unitree_low_state.motor_state()[state_mapping[i]].q();
|
||||
state->motor_state.dq[i] = this->unitree_low_state.motor_state()[state_mapping[i]].dq();
|
||||
state->motor_state.tauEst[i] = this->unitree_low_state.motor_state()[state_mapping[i]].tau_est();
|
||||
state->motor_state.tau_est[i] = this->unitree_low_state.motor_state()[state_mapping[i]].tau_est();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -184,7 +184,7 @@ void RL_Real::RunModel()
|
|||
this->output_dof_pos = this->ComputePosition(this->obs.actions);
|
||||
|
||||
#ifdef CSV_LOGGER
|
||||
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0);
|
||||
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tau_est).unsqueeze(0);
|
||||
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -18,17 +18,6 @@ RL_Sim::RL_Sim()
|
|||
observation = "ang_vel_world";
|
||||
}
|
||||
}
|
||||
// Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
|
||||
// the mapping table is established according to the order defined in the YAML file
|
||||
std::vector<std::string> sorted_joint_controller_names = this->params.joint_controller_names;
|
||||
std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end());
|
||||
for (size_t i = 0; i < this->params.joint_controller_names.size(); ++i)
|
||||
{
|
||||
this->sorted_to_original_index[sorted_joint_controller_names[i]] = i;
|
||||
}
|
||||
this->mapped_joint_positions = std::vector<double>(this->params.num_of_dofs, 0.0);
|
||||
this->mapped_joint_velocities = std::vector<double>(this->params.num_of_dofs, 0.0);
|
||||
this->mapped_joint_efforts = std::vector<double>(this->params.num_of_dofs, 0.0);
|
||||
|
||||
// init rl
|
||||
torch::autograd::GradMode::set_enabled(false);
|
||||
|
@ -52,15 +41,32 @@ RL_Sim::RL_Sim()
|
|||
for (int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
// joint need to rename as xxx_joint
|
||||
this->joint_publishers[this->params.joint_controller_names[i]] =
|
||||
nh.advertise<robot_msgs::MotorCommand>(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
|
||||
const std::string &joint_name = this->params.joint_controller_names[i];
|
||||
const std::string topic_name = this->ros_namespace + joint_name + "/command";
|
||||
this->joint_publishers[joint_name] =
|
||||
nh.advertise<robot_msgs::MotorCommand>(topic_name, 10);
|
||||
}
|
||||
|
||||
// subscriber
|
||||
this->cmd_vel_subscriber = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this);
|
||||
this->joy_subscriber = nh.subscribe<sensor_msgs::Joy>("/joy", 10, &RL_Sim::JoyCallback, this);
|
||||
this->model_state_subscriber = nh.subscribe<gazebo_msgs::ModelStates>("/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this);
|
||||
this->joint_state_subscriber = nh.subscribe<sensor_msgs::JointState>(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this);
|
||||
for (int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
// joint need to rename as xxx_joint
|
||||
const std::string &joint_name = this->params.joint_controller_names[i];
|
||||
const std::string topic_name = this->ros_namespace + joint_name + "/state";
|
||||
this->joint_subscribers[joint_name] =
|
||||
nh.subscribe<robot_msgs::MotorState>(topic_name, 10,
|
||||
[this, joint_name](const robot_msgs::MotorState::ConstPtr &msg)
|
||||
{
|
||||
this->JointStatesCallback(msg, joint_name);
|
||||
}
|
||||
);
|
||||
this->joint_positions[joint_name] = 0.0;
|
||||
this->joint_velocities[joint_name] = 0.0;
|
||||
this->joint_efforts[joint_name] = 0.0;
|
||||
}
|
||||
|
||||
// service
|
||||
nh.param<std::string>("gazebo_model_name", this->gazebo_model_name, "");
|
||||
|
@ -130,9 +136,9 @@ void RL_Sim::GetState(RobotState<double> *state)
|
|||
|
||||
for (int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
state->motor_state.q[i] = this->mapped_joint_positions[i];
|
||||
state->motor_state.dq[i] = this->mapped_joint_velocities[i];
|
||||
state->motor_state.tauEst[i] = this->mapped_joint_efforts[i];
|
||||
state->motor_state.q[i] = this->joint_positions[this->params.joint_controller_names[i]];
|
||||
state->motor_state.dq[i] = this->joint_velocities[this->params.joint_controller_names[i]];
|
||||
state->motor_state.tau_est[i] = this->joint_efforts[this->params.joint_controller_names[i]];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -240,19 +246,11 @@ void RL_Sim::JoyCallback(const sensor_msgs::Joy::ConstPtr &msg)
|
|||
this->control.yaw = this->joy_msg.axes[3] * 1.5; // Rx
|
||||
}
|
||||
|
||||
void RL_Sim::MapData(const std::vector<double> &source_data, std::vector<double> &target_data)
|
||||
void RL_Sim::JointStatesCallback(const robot_msgs::MotorState::ConstPtr &msg, const std::string &joint_name)
|
||||
{
|
||||
for (size_t i = 0; i < source_data.size(); ++i)
|
||||
{
|
||||
target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]];
|
||||
}
|
||||
}
|
||||
|
||||
void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
||||
{
|
||||
MapData(msg->position, this->mapped_joint_positions);
|
||||
MapData(msg->velocity, this->mapped_joint_velocities);
|
||||
MapData(msg->effort, this->mapped_joint_efforts);
|
||||
this->joint_positions[joint_name] = msg->q;
|
||||
this->joint_velocities[joint_name] = msg->dq;
|
||||
this->joint_efforts[joint_name] = msg->tau_est;
|
||||
}
|
||||
|
||||
void RL_Sim::RunModel()
|
||||
|
@ -286,7 +284,11 @@ void RL_Sim::RunModel()
|
|||
this->output_dof_pos = this->ComputePosition(this->obs.actions);
|
||||
|
||||
#ifdef CSV_LOGGER
|
||||
torch::Tensor tau_est = torch::tensor(this->mapped_joint_efforts).unsqueeze(0);
|
||||
torch::Tensor tau_est = torch::zeros({1, this->params.num_of_dofs});
|
||||
for (int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
tau_est[0][i] = this->joint_efforts[this->params.joint_controller_names[i]];
|
||||
}
|
||||
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
|
||||
#endif
|
||||
}
|
||||
|
@ -330,7 +332,7 @@ void RL_Sim::Plot()
|
|||
{
|
||||
this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
|
||||
this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
|
||||
this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]);
|
||||
this->plot_real_joint_pos[i].push_back(this->joint_positions[this->params.joint_controller_names[i]]);
|
||||
this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q);
|
||||
plt::subplot(4, 3, i + 1);
|
||||
plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r");
|
||||
|
|
|
@ -111,7 +111,7 @@ namespace robot_joint_controller
|
|||
lastCommand.dq = 0;
|
||||
lastState.dq = 0;
|
||||
lastCommand.tau = 0;
|
||||
lastState.tauEst = 0;
|
||||
lastState.tau_est = 0;
|
||||
command.initRT(lastCommand);
|
||||
|
||||
pid_controller_.reset();
|
||||
|
@ -158,15 +158,15 @@ namespace robot_joint_controller
|
|||
|
||||
lastState.q = currentPos;
|
||||
lastState.dq = currentVel;
|
||||
// lastState.tauEst = calcTorque;
|
||||
lastState.tauEst = joint.getEffort();
|
||||
// lastState.tau_est = calcTorque;
|
||||
lastState.tau_est = joint.getEffort();
|
||||
|
||||
// publish state
|
||||
if (controller_state_publisher_ && controller_state_publisher_->trylock())
|
||||
{
|
||||
controller_state_publisher_->msg_.q = lastState.q;
|
||||
controller_state_publisher_->msg_.dq = lastState.dq;
|
||||
controller_state_publisher_->msg_.tauEst = lastState.tauEst;
|
||||
controller_state_publisher_->msg_.tau_est = lastState.tau_est;
|
||||
controller_state_publisher_->unlockAndPublish();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
float32 q # motor current position (rad)
|
||||
float32 dq # motor current speed (rad/s)
|
||||
float32 ddq # motor current speed (rad/s)
|
||||
float32 tauEst # current estimated output torque (N*m)
|
||||
float32 tau_est # current estimated output torque (N*m)
|
||||
float32 cur # current estimated output cur (N*m)
|
Loading…
Reference in New Issue