// File: rl_sar/src/unitree_rl/include/unitree_rl.hpp
// (Captured from a repository web view: 57 lines, 1.6 KiB, C++; blame date 2024-03-06 17:24:39 +08:00.)
#ifndef UNITREE_RL
#define UNITREE_RL
#include <chrono>
#include <map>
#include <string>
#include <vector>

#include <ros/ros.h>
#include <gazebo_msgs/ModelStates.h>
#include <geometry_msgs/Pose.h>
#include <geometry_msgs/Twist.h>
#include <sensor_msgs/JointState.h>
#include "unitree_legged_msgs/MotorCmd.h"

#include "../library/model/model.hpp"
#include "../library/observation_buffer/observation_buffer.hpp"
class Unitree_RL : public Model
{
public:
2024-03-06 18:24:05 +08:00
Unitree_RL();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
void runModel(const ros::TimerEvent &event);
torch::Tensor forward() override;
torch::Tensor compute_observation() override;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
ObservationBuffer history_obs_buf;
torch::Tensor history_obs;
2024-03-06 17:24:39 +08:00
private:
2024-03-06 18:24:05 +08:00
std::string ros_namespace;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
std::vector<std::string> torque_command_topics;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
ros::Subscriber model_state_subscriber_;
ros::Subscriber joint_state_subscriber_;
ros::Subscriber cmd_vel_subscriber_;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
std::map<std::string, ros::Publisher> torque_publishers;
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
geometry_msgs::Twist vel;
geometry_msgs::Pose pose;
geometry_msgs::Twist cmd_vel;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
std::vector<std::string> joint_names;
std::vector<double> joint_positions;
std::vector<double> joint_velocities;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
torch::Tensor torques;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
ros::Timer timer;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
std::chrono::high_resolution_clock::time_point start_time;
2024-03-06 17:24:39 +08:00
2024-03-06 18:24:05 +08:00
// other rl module
torch::jit::script::Module encoder;
torch::jit::script::Module vq;
2024-03-06 17:24:39 +08:00
};
#endif