mirror of https://github.com/fan-ziqi/rl_sar.git
feat: add unitree_rl
This commit is contained in:
parent
010e7942d8
commit
0ffb6b2af7
|
@ -0,0 +1,5 @@
|
|||
*.pt
|
||||
build
|
||||
devel
|
||||
logs
|
||||
.catkin_tools
|
|
@ -0,0 +1,62 @@
|
|||
cmake_minimum_required(VERSION 3.0.2)
|
||||
project(unitree_rl)
|
||||
|
||||
set(EXTRA_LIBS -pthread libunitree_legged_sdk_amd64.so lcm)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
find_package(unitree_legged_sdk REQUIRED)
|
||||
|
||||
find_package(catkin REQUIRED COMPONENTS
|
||||
controller_manager
|
||||
genmsg
|
||||
joint_state_controller
|
||||
robot_state_publisher
|
||||
roscpp
|
||||
gazebo_ros
|
||||
std_msgs
|
||||
tf
|
||||
geometry_msgs
|
||||
unitree_legged_msgs
|
||||
)
|
||||
|
||||
find_package(gazebo REQUIRED)
|
||||
|
||||
catkin_package(
|
||||
CATKIN_DEPENDS
|
||||
unitree_legged_msgs
|
||||
)
|
||||
|
||||
message("-- CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
if("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "x86_64.*")
|
||||
set(ARCH amd64)
|
||||
else()
|
||||
set(ARCH arm64)
|
||||
endif()
|
||||
|
||||
set(EXTRA_LIBS -pthread ${unitree_legged_sdk_LIBRARIES})
|
||||
|
||||
include_directories(
|
||||
include
|
||||
${catkin_INCLUDE_DIRS}
|
||||
${unitree_legged_sdk_INCLUDE_DIRS}
|
||||
../unitree_controller/include
|
||||
|
||||
)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
|
||||
|
||||
add_library(model lib/model.cpp lib/model.hpp)
|
||||
target_link_libraries(model "${TORCH_LIBRARIES}")
|
||||
set_property(TARGET model PROPERTY CXX_STANDARD 14)
|
||||
|
||||
add_library(observation_buffer lib/observation_buffer.cpp lib/observation_buffer.hpp)
|
||||
target_link_libraries(observation_buffer "${TORCH_LIBRARIES}")
|
||||
set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)
|
||||
|
||||
add_executable(unitree_rl src/unitree_rl.cpp )
|
||||
target_link_libraries(${PROJECT_NAME}
|
||||
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
|
||||
model observation_buffer
|
||||
)
|
||||
# add_dependencies(${PROJECT_NAME} unitree_legged_msgs_gencpp)
|
|
@ -0,0 +1,56 @@
|
|||
#ifndef UNITREE_RL
|
||||
#define UNITREE_RL
|
||||
|
||||
#include <ros/ros.h>
|
||||
#include <gazebo_msgs/ModelStates.h>
|
||||
#include <sensor_msgs/JointState.h>
|
||||
#include <geometry_msgs/Twist.h>
|
||||
#include "../lib/model.cpp"
|
||||
#include "../lib/observation_buffer.hpp"
|
||||
#include "unitree_legged_msgs/MotorCmd.h"
|
||||
|
||||
class Unitree_RL : public Model
|
||||
{
|
||||
public:
|
||||
|
||||
Unitree_RL();
|
||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr& msg);
|
||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr& msg);
|
||||
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr& msg);
|
||||
void runModel(const ros::TimerEvent& event);
|
||||
torch::Tensor forward() override;
|
||||
torch::Tensor compute_observation() override;
|
||||
|
||||
ObservationBuffer history_obs_buf;
|
||||
torch::Tensor history_obs;
|
||||
|
||||
private:
|
||||
|
||||
std::string ros_namespace;
|
||||
|
||||
std::vector<std::string> torque_command_topics;
|
||||
|
||||
ros::Subscriber model_state_subscriber_;
|
||||
ros::Subscriber joint_state_subscriber_;
|
||||
ros::Subscriber cmd_vel_subscriber_;
|
||||
|
||||
std::map<std::string, ros::Publisher> torque_publishers;
|
||||
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
|
||||
|
||||
geometry_msgs::Twist vel;
|
||||
geometry_msgs::Pose pose;
|
||||
geometry_msgs::Twist cmd_vel;
|
||||
|
||||
std::vector<std::string> joint_names;
|
||||
std::vector<double> joint_positions;
|
||||
std::vector<double> joint_velocities;
|
||||
|
||||
torch::Tensor torques;
|
||||
|
||||
ros::Timer timer;
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time;
|
||||
|
||||
};
|
||||
|
||||
#endif
|
|
@ -0,0 +1,60 @@
|
|||
<launch>
|
||||
<arg name="wname" default="stairs"/>
|
||||
<arg name="rname" default="a1"/>
|
||||
<arg name="robot_path" value="(find $(arg rname)_description)"/>
|
||||
<arg name="dollar" value="$"/>
|
||||
|
||||
<arg name="paused" default="true"/>
|
||||
<arg name="use_sim_time" default="true"/>
|
||||
<arg name="gui" default="true"/>
|
||||
<arg name="headless" default="false"/>
|
||||
<arg name="debug" default="false"/>
|
||||
<!-- Debug mode will hung up the robot, use "true" or "false" to switch it. -->
|
||||
<arg name="user_debug" default="false"/>
|
||||
|
||||
<include file="$(find gazebo_ros)/launch/empty_world.launch">
|
||||
<arg name="world_name" value="$(find unitree_rl)/worlds/$(arg wname).world"/>
|
||||
<arg name="debug" value="$(arg debug)"/>
|
||||
<arg name="gui" value="$(arg gui)"/>
|
||||
<arg name="paused" value="$(arg paused)"/>
|
||||
<arg name="use_sim_time" value="$(arg use_sim_time)"/>
|
||||
<arg name="headless" value="$(arg headless)"/>
|
||||
</include>
|
||||
|
||||
<!-- Load the URDF into the ROS Parameter Server -->
|
||||
<param name="robot_description"
|
||||
command="$(find xacro)/xacro --inorder '$(arg dollar)$(arg robot_path)/xacro/robot.xacro'
|
||||
DEBUG:=$(arg user_debug)"/>
|
||||
|
||||
<!-- Run a python script to the send a service call to gazebo_ros to spawn a URDF robot -->
|
||||
<!-- Set trunk and joint positions at startup -->
|
||||
<node pkg="gazebo_ros" type="spawn_model" name="urdf_spawner" respawn="false" output="screen"
|
||||
args="-urdf -z 0.6 -model $(arg rname)_gazebo -param robot_description -unpause"/>
|
||||
|
||||
<!-- Load joint controller configurations from YAML file to parameter server -->
|
||||
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
|
||||
|
||||
<rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam>
|
||||
|
||||
<!-- load the controllers -->
|
||||
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
|
||||
output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller
|
||||
FL_hip_controller FL_thigh_controller FL_calf_controller
|
||||
FR_hip_controller FR_thigh_controller FR_calf_controller
|
||||
RL_hip_controller RL_thigh_controller RL_calf_controller
|
||||
RR_hip_controller RR_thigh_controller RR_calf_controller "/>
|
||||
|
||||
<!-- convert joint states to TF transforms for rviz, etc -->
|
||||
<node pkg="robot_state_publisher" type="robot_state_publisher" name="robot_state_publisher"
|
||||
respawn="false" output="screen">
|
||||
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
|
||||
</node>
|
||||
|
||||
<!-- <node pkg="unitree_gazebo" type="servo" name="servo" required="true" output="screen"/> -->
|
||||
|
||||
<!-- load the parameter unitree_controller -->
|
||||
<include file="$(find unitree_controller)/launch/set_ctrl.launch">
|
||||
<arg name="rname" value="$(arg rname)"/>
|
||||
</include>
|
||||
|
||||
</launch>
|
|
@ -0,0 +1,76 @@
|
|||
#include "model.hpp"
|
||||
|
||||
void Model::init_models(std::string actor_path, std::string encoder_path, std::string vq_path)
|
||||
{
|
||||
this->actor = torch::jit::load(actor_path);
|
||||
this->encoder = torch::jit::load(encoder_path);
|
||||
this->vq = torch::jit::load(vq_path);
|
||||
this->init_observations();
|
||||
}
|
||||
|
||||
torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
|
||||
{
|
||||
c10::IntArrayRef shape = q.sizes();
|
||||
torch::Tensor q_w = q.index({torch::indexing::Slice(), -1});
|
||||
torch::Tensor q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)});
|
||||
torch::Tensor a = v * (2.0 * torch::pow(q_w, 2) - 1.0).unsqueeze(-1);
|
||||
torch::Tensor b = torch::cross(q_vec, v, /*dim=*/-1) * q_w.unsqueeze(-1) * 2.0;
|
||||
torch::Tensor c = q_vec * torch::bmm(q_vec.view({shape[0], 1, 3}), v.view({shape[0], 3, 1})).squeeze(-1) * 2.0;
|
||||
return a - b + c;
|
||||
}
|
||||
|
||||
torch::Tensor Model::compute_torques(torch::Tensor actions)
|
||||
{
|
||||
actions *= this->params.action_scale;
|
||||
torch::Tensor torques = this->params.p_gains * (actions + this->params.default_dof_pos - this->obs.dof_pos) - this->params.d_gains * this->obs.dof_vel;
|
||||
torch::Tensor clamped = torch::clamp(torques, -(this->params.torque_limits), this->params.torque_limits);
|
||||
return clamped;
|
||||
}
|
||||
|
||||
torch::Tensor Model::compute_observation()
|
||||
{
|
||||
std::cout << "You may need to override this compute_observation() function" << std::endl;
|
||||
|
||||
torch::Tensor obs = torch::cat({(this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
||||
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||
this->obs.commands * this->params.commands_scale,
|
||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||
this->obs.actions},
|
||||
1);
|
||||
|
||||
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
||||
|
||||
// printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]);
|
||||
|
||||
return obs;
|
||||
}
|
||||
|
||||
void Model::init_observations()
|
||||
{
|
||||
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
||||
this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
||||
this->obs.gravity_vec = torch::tensor({{0.0, 0.0, -1.0}});
|
||||
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
|
||||
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
|
||||
this->obs.dof_pos = this->params.default_dof_pos;
|
||||
this->obs.dof_vel = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
}
|
||||
|
||||
torch::Tensor Model::forward()
|
||||
{
|
||||
std::cout << "You may need to override this forward() function" << std::endl;
|
||||
|
||||
torch::Tensor obs = this->compute_observation();
|
||||
|
||||
torch::Tensor actor_input = torch::cat({obs}, 1);
|
||||
|
||||
torch::Tensor action = this->actor.forward({actor_input}).toTensor();
|
||||
|
||||
this->obs.actions = action;
|
||||
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);
|
||||
|
||||
return clamped;
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
#ifndef MODEL_HPP
|
||||
#define MODEL_HPP
|
||||
|
||||
#include <torch/script.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
struct ModelParams {
|
||||
int num_observations;
|
||||
float damping;
|
||||
float stiffness;
|
||||
float action_scale;
|
||||
float num_of_dofs;
|
||||
float lin_vel_scale;
|
||||
float ang_vel_scale;
|
||||
float dof_pos_scale;
|
||||
float dof_vel_scale;
|
||||
float clip_obs;
|
||||
float clip_actions;
|
||||
torch::Tensor torque_limits;
|
||||
torch::Tensor d_gains;
|
||||
torch::Tensor p_gains;
|
||||
torch::Tensor commands_scale;
|
||||
torch::Tensor default_dof_pos;
|
||||
};
|
||||
|
||||
struct Observations {
|
||||
torch::Tensor lin_vel;
|
||||
torch::Tensor ang_vel;
|
||||
torch::Tensor gravity_vec;
|
||||
torch::Tensor commands;
|
||||
torch::Tensor base_quat;
|
||||
torch::Tensor dof_pos;
|
||||
torch::Tensor dof_vel;
|
||||
torch::Tensor actions;
|
||||
};
|
||||
|
||||
class Model {
|
||||
public:
|
||||
Model(){};
|
||||
ModelParams params;
|
||||
Observations obs;
|
||||
|
||||
virtual torch::Tensor forward();
|
||||
virtual torch::Tensor compute_observation();
|
||||
|
||||
torch::Tensor compute_torques(torch::Tensor actions);
|
||||
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
||||
void init_observations();
|
||||
void init_models(std::string actor_path, std::string encoder_path, std::string vq_path);
|
||||
|
||||
protected:
|
||||
// rl module
|
||||
torch::jit::script::Module actor;
|
||||
torch::jit::script::Module encoder;
|
||||
torch::jit::script::Module vq;
|
||||
// observation buffer
|
||||
torch::Tensor lin_vel;
|
||||
torch::Tensor ang_vel;
|
||||
torch::Tensor gravity_vec;
|
||||
torch::Tensor commands;
|
||||
torch::Tensor base_quat;
|
||||
torch::Tensor dof_pos;
|
||||
torch::Tensor dof_vel;
|
||||
torch::Tensor actions;
|
||||
|
||||
};
|
||||
|
||||
#endif // MODEL_HPP
|
|
@ -0,0 +1,45 @@
|
|||
#include "observation_buffer.hpp"
|
||||
|
||||
ObservationBuffer::ObservationBuffer() {}
|
||||
|
||||
ObservationBuffer::ObservationBuffer(int num_envs,
|
||||
int num_obs,
|
||||
int include_history_steps)
|
||||
: num_envs(num_envs),
|
||||
num_obs(num_obs),
|
||||
include_history_steps(include_history_steps)
|
||||
{
|
||||
num_obs_total = num_obs * include_history_steps;
|
||||
obs_buf = torch::zeros({num_envs, num_obs_total}, torch::dtype(torch::kFloat32));
|
||||
}
|
||||
|
||||
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
|
||||
{
|
||||
std::vector<torch::indexing::TensorIndex> indices;
|
||||
for (int idx : reset_idxs) {
|
||||
indices.push_back(torch::indexing::Slice(idx));
|
||||
}
|
||||
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));
|
||||
}
|
||||
|
||||
void ObservationBuffer::insert(torch::Tensor new_obs)
|
||||
{
|
||||
// Shift observations back.
|
||||
torch::Tensor shifted_obs = obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(num_obs, num_obs * include_history_steps)}).clone();
|
||||
obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(0, num_obs * (include_history_steps - 1))}) = shifted_obs;
|
||||
|
||||
// Add new observation.
|
||||
obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs, torch::indexing::None)}) = new_obs;
|
||||
}
|
||||
|
||||
torch::Tensor ObservationBuffer::get_obs_vec(std::vector<int> obs_ids)
|
||||
{
|
||||
std::vector<torch::Tensor> obs;
|
||||
for (int i = obs_ids.size() - 1; i >= 0; --i)
|
||||
{
|
||||
int obs_id = obs_ids[i];
|
||||
int slice_idx = include_history_steps - obs_id - 1;
|
||||
obs.push_back(obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(slice_idx * num_obs, (slice_idx + 1) * num_obs)}));
|
||||
}
|
||||
return torch::cat(obs, -1);
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
#ifndef OBSERVATION_BUFFER_HPP
|
||||
#define OBSERVATION_BUFFER_HPP
|
||||
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
class ObservationBuffer {
|
||||
public:
|
||||
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
||||
ObservationBuffer();
|
||||
|
||||
void reset(std::vector<int> reset_idxs, torch::Tensor new_obs);
|
||||
void insert(torch::Tensor new_obs);
|
||||
torch::Tensor get_obs_vec(std::vector<int> obs_ids);
|
||||
|
||||
private:
|
||||
int num_envs;
|
||||
int num_obs;
|
||||
int include_history_steps;
|
||||
int num_obs_total;
|
||||
torch::Tensor obs_buf;
|
||||
};
|
||||
|
||||
#endif // OBSERVATION_BUFFER_HPP
|
|
@ -0,0 +1,33 @@
|
|||
<?xml version="1.0"?>
|
||||
<package format="2">
|
||||
<name>unitree_rl</name>
|
||||
<version>0.0.0</version>
|
||||
<description>The unitree_rl package</description>
|
||||
|
||||
<maintainer email="fanziqi614@gmail.com">Ziqi Fan</maintainer>
|
||||
|
||||
<license>TODO</license>
|
||||
|
||||
<buildtool_depend>catkin</buildtool_depend>
|
||||
<buildtool_depend>genmsg</buildtool_depend>
|
||||
<build_depend>controller_manager</build_depend>
|
||||
<build_depend>joint_state_controller</build_depend>
|
||||
<build_depend>robot_state_publisher</build_depend>
|
||||
<build_depend>roscpp</build_depend>
|
||||
<build_depend>std_msgs</build_depend>
|
||||
<build_export_depend>controller_manager</build_export_depend>
|
||||
<build_export_depend>joint_state_controller</build_export_depend>
|
||||
<build_export_depend>robot_state_publisher</build_export_depend>
|
||||
<build_export_depend>roscpp</build_export_depend>
|
||||
<build_export_depend>std_msgs</build_export_depend>
|
||||
<exec_depend>controller_manager</exec_depend>
|
||||
<exec_depend>joint_state_controller</exec_depend>
|
||||
<exec_depend>robot_state_publisher</exec_depend>
|
||||
<exec_depend>roscpp</exec_depend>
|
||||
<exec_depend>std_msgs</exec_depend>
|
||||
<depend>unitree_legged_msgs</depend>
|
||||
|
||||
<export>
|
||||
|
||||
</export>
|
||||
</package>
|
|
@ -0,0 +1,164 @@
|
|||
#include "../include/unitree_rl.hpp"
|
||||
#include <ros/package.h>
|
||||
|
||||
Unitree_RL::Unitree_RL()
|
||||
{
|
||||
ros::NodeHandle nh;
|
||||
start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
cmd_vel = geometry_msgs::Twist();
|
||||
|
||||
torque_commands.resize(12);
|
||||
|
||||
ros_namespace = "/a1_gazebo/";
|
||||
|
||||
joint_names = {
|
||||
"FL_hip_joint", "FL_thigh_joint", "FL_calf_joint",
|
||||
"FR_hip_joint", "FR_thigh_joint", "FR_calf_joint",
|
||||
"RL_hip_joint", "RL_thigh_joint", "RL_calf_joint",
|
||||
"RR_hip_joint", "RR_thigh_joint", "RR_calf_joint",
|
||||
};
|
||||
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
torque_publishers[joint_names[i]] = nh.advertise<unitree_legged_msgs::MotorCmd>(
|
||||
ros_namespace + joint_names[i].substr(0, joint_names[i].size() - 6) + "_controller/command", 10);
|
||||
}
|
||||
|
||||
std::string package_name = "unitree_rl";
|
||||
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
|
||||
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
|
||||
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
|
||||
this->init_models(actor_path, encoder_path, vq_path);
|
||||
|
||||
this->params.num_observations = 45;
|
||||
this->params.clip_obs = 100.0;
|
||||
this->params.clip_actions = 100.0;
|
||||
this->params.damping = 0.5;
|
||||
this->params.stiffness = 20;
|
||||
this->params.d_gains = torch::ones(12) * this->params.damping;
|
||||
this->params.p_gains = torch::ones(12) * this->params.stiffness;
|
||||
this->params.action_scale = 0.25;
|
||||
this->params.num_of_dofs = 12;
|
||||
this->params.lin_vel_scale = 2.0;
|
||||
this->params.ang_vel_scale = 0.25;
|
||||
this->params.dof_pos_scale = 1.0;
|
||||
this->params.dof_vel_scale = 0.05;
|
||||
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
|
||||
|
||||
// hip, thigh, calf
|
||||
this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, // front left
|
||||
20.0, 55.0, 55.0, // front right
|
||||
20.0, 55.0, 55.0, // rear left
|
||||
20.0, 55.0, 55.0}}); // rear right
|
||||
|
||||
this->params.default_dof_pos = torch::tensor({{0.1000, 0.8000, -1.5000,
|
||||
-0.1000, 0.8000, -1.5000,
|
||||
0.1000, 1.0000, -1.5000,
|
||||
-0.1000, 1.0000, -1.5000}});
|
||||
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||
|
||||
// Create a subscriber object
|
||||
model_state_subscriber_ = nh.subscribe<gazebo_msgs::ModelStates>(
|
||||
"/gazebo/model_states", 10, &Unitree_RL::modelStatesCallback, this);
|
||||
|
||||
joint_state_subscriber_ = nh.subscribe<sensor_msgs::JointState>(
|
||||
"/a1_gazebo/joint_states", 10, &Unitree_RL::jointStatesCallback, this);
|
||||
|
||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
|
||||
"/cmd_vel", 10, &Unitree_RL::cmdvelCallback, this);
|
||||
|
||||
timer = nh.createTimer(ros::Duration(0.005), &Unitree_RL::runModel, this);
|
||||
}
|
||||
|
||||
void Unitree_RL::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
|
||||
{
|
||||
|
||||
vel = msg->twist[2];
|
||||
pose = msg->pose[2];
|
||||
}
|
||||
|
||||
void Unitree_RL::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
|
||||
{
|
||||
cmd_vel = *msg;
|
||||
}
|
||||
|
||||
void Unitree_RL::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
||||
{
|
||||
joint_positions = msg->position;
|
||||
joint_velocities = msg->velocity;
|
||||
}
|
||||
|
||||
void Unitree_RL::runModel(const ros::TimerEvent &event)
|
||||
{
|
||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||
start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
|
||||
this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}});
|
||||
this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
|
||||
this->obs.base_quat = torch::tensor({{pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w}});
|
||||
this->obs.dof_pos = torch::tensor({{joint_positions[1], joint_positions[2], joint_positions[0],
|
||||
joint_positions[4], joint_positions[5], joint_positions[3],
|
||||
joint_positions[7], joint_positions[8], joint_positions[6],
|
||||
joint_positions[10], joint_positions[11], joint_positions[9]}});
|
||||
this->obs.dof_vel = torch::tensor({{joint_velocities[1], joint_velocities[2], joint_velocities[0],
|
||||
joint_velocities[4], joint_velocities[5], joint_velocities[3],
|
||||
joint_velocities[7], joint_velocities[8], joint_velocities[6],
|
||||
joint_velocities[10], joint_velocities[11], joint_velocities[9]}});
|
||||
|
||||
torques = this->compute_torques(this->forward());
|
||||
|
||||
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
torque_commands[i].tau = torques[0][i].item<double>();
|
||||
torque_commands[i].mode = 0x0A;
|
||||
|
||||
torque_publishers[joint_names[i]].publish(torque_commands[i]);
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor Unitree_RL::compute_observation()
|
||||
{
|
||||
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
||||
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||
this->obs.commands * this->params.commands_scale,
|
||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||
this->obs.actions},
|
||||
1);
|
||||
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
||||
return obs;
|
||||
}
|
||||
|
||||
torch::Tensor Unitree_RL::forward()
|
||||
{
|
||||
torch::Tensor obs = this->compute_observation();
|
||||
|
||||
history_obs_buf.insert(obs);
|
||||
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
|
||||
torch::Tensor encoding = this->encoder.forward({history_obs}).toTensor();
|
||||
|
||||
torch::Tensor z = this->vq.forward({encoding}).toTensor();
|
||||
|
||||
torch::Tensor actor_input = torch::cat({obs, z}, 1);
|
||||
|
||||
torch::Tensor action = this->actor.forward({actor_input}).toTensor();
|
||||
|
||||
this->obs.actions = action;
|
||||
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);
|
||||
|
||||
return clamped;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
ros::init(argc, argv, "unitree_rl");
|
||||
Unitree_RL unitree_rl;
|
||||
ros::spin();
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,332 @@
|
|||
<?xml version="1.0" ?>
|
||||
<sdf version="1.5">
|
||||
<world name="default">
|
||||
|
||||
<physics type="ode">
|
||||
<max_step_size>0.0002</max_step_size>
|
||||
<real_time_factor>1</real_time_factor>
|
||||
<real_time_update_rate>5000</real_time_update_rate>
|
||||
<gravity>0 0 -9.81</gravity>
|
||||
<ode>
|
||||
<solver>
|
||||
<type>quick</type>
|
||||
<iters>50</iters>
|
||||
<sor>1.3</sor>
|
||||
</solver>
|
||||
<constraints>
|
||||
<cfm>0.0</cfm>
|
||||
<erp>0.2</erp>
|
||||
<contact_max_correcting_vel>10.0</contact_max_correcting_vel>
|
||||
<contact_surface_layer>0.001</contact_surface_layer>
|
||||
</constraints>
|
||||
</ode>
|
||||
</physics>
|
||||
|
||||
<scene>
|
||||
<sky>
|
||||
<clouds>
|
||||
<speed>12</speed>
|
||||
</clouds>
|
||||
</sky>
|
||||
</scene>
|
||||
<!-- A global light source -->
|
||||
<include>
|
||||
<uri>model://sun</uri>
|
||||
</include>
|
||||
<!-- A ground plane -->
|
||||
<include>
|
||||
<uri>model://ground_plane</uri>
|
||||
</include>
|
||||
<!-- environment blocks, obstacles or stairs -->
|
||||
<model name="static_environment">
|
||||
<static>true</static>
|
||||
<link name="static_box">
|
||||
<pose>-2 2 0.5 0 0 0</pose>
|
||||
<collision name="static_box_collision">
|
||||
<geometry>
|
||||
<box>
|
||||
<size>1 1 1</size>
|
||||
</box>
|
||||
</geometry>
|
||||
</collision>
|
||||
<visual name="static_box_visual">
|
||||
<geometry>
|
||||
<box>
|
||||
<size>1 1 1</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<ambient>0.2 0.2 0.2 1.0</ambient>
|
||||
<diffuse>.421 0.225 0.0 1.0</diffuse>
|
||||
</material>
|
||||
</visual>
|
||||
</link>
|
||||
|
||||
<link name='Stairs'>
|
||||
<pose>3 0 0 0 0 0</pose>
|
||||
<scale>1 1 1</scale>
|
||||
<visual name='Stairs_Visual_0'>
|
||||
<pose>-1.26 -0 0.075 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_0'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>-1.26 -0 0.075 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
<visual name='Stairs_Visual_1'>
|
||||
<pose>-0.98 -0 0.225 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_1'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>-0.98 -0 0.225 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
<visual name='Stairs_Visual_2'>
|
||||
<pose>-0.7 -0 0.375 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_2'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>-0.7 -0 0.375 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
<visual name='Stairs_Visual_3'>
|
||||
<pose>-0.42 -0 0.525 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_3'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>-0.42 -0 0.525 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
<visual name='Stairs_Visual_4'>
|
||||
<pose>-0.14 -0 0.675 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_4'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>-0.14 -0 0.675 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
<visual name='Stairs_Visual_5'>
|
||||
<pose>0.14 0 0.825 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_5'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>0.14 0 0.825 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
<visual name='Stairs_Visual_6'>
|
||||
<pose>0.42 0 0.975 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_6'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>0.42 0 0.975 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
<visual name='Stairs_Visual_7'>
|
||||
<pose>0.7 0 1.125 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_7'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>0.7 0 1.125 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
<visual name='Stairs_Visual_8'>
|
||||
<pose>0.98 0 1.275 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_8'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>0.98 0 1.275 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
<visual name='Stairs_Visual_9'>
|
||||
<pose>1.26 0 1.425 0 -0 1.5708</pose>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<material>
|
||||
<script>
|
||||
<uri>file://media/materials/scripts/gazebo.material</uri>
|
||||
<name>Gazebo/Grey</name>
|
||||
</script>
|
||||
<ambient>1 1 1 1</ambient>
|
||||
</material>
|
||||
<meta>
|
||||
<layer>0</layer>
|
||||
</meta>
|
||||
</visual>
|
||||
<collision name='Stairs_Collision_9'>
|
||||
<geometry>
|
||||
<box>
|
||||
<size>10 0.28 0.15</size>
|
||||
</box>
|
||||
</geometry>
|
||||
<pose>1.26 0 1.425 0 -0 1.5708</pose>
|
||||
</collision>
|
||||
</link>
|
||||
</model>
|
||||
|
||||
</world>
|
||||
</sdf>
|
Loading…
Reference in New Issue