2024-03-06 17:24:39 +08:00
|
|
|
#ifndef UNITREE_RL
|
|
|
|
#define UNITREE_RL
|
|
|
|
|
|
|
|
#include <ros/ros.h>
|
|
|
|
#include <gazebo_msgs/ModelStates.h>
|
|
|
|
#include <sensor_msgs/JointState.h>
|
|
|
|
#include <geometry_msgs/Twist.h>
|
2024-03-14 12:42:39 +08:00
|
|
|
#include "../library/model/model.hpp"
|
|
|
|
#include "../library/observation_buffer/observation_buffer.hpp"
|
2024-03-06 17:24:39 +08:00
|
|
|
#include "unitree_legged_msgs/MotorCmd.h"
|
|
|
|
|
|
|
|
class Unitree_RL : public Model
|
|
|
|
{
|
|
|
|
public:
|
2024-03-06 18:24:05 +08:00
|
|
|
Unitree_RL();
|
|
|
|
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
|
|
|
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
|
|
|
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
|
|
|
void runModel(const ros::TimerEvent &event);
|
|
|
|
torch::Tensor forward() override;
|
|
|
|
torch::Tensor compute_observation() override;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
ObservationBuffer history_obs_buf;
|
|
|
|
torch::Tensor history_obs;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
|
|
|
private:
|
2024-03-06 18:24:05 +08:00
|
|
|
std::string ros_namespace;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
std::vector<std::string> torque_command_topics;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
ros::Subscriber model_state_subscriber_;
|
|
|
|
ros::Subscriber joint_state_subscriber_;
|
|
|
|
ros::Subscriber cmd_vel_subscriber_;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
std::map<std::string, ros::Publisher> torque_publishers;
|
|
|
|
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
geometry_msgs::Twist vel;
|
|
|
|
geometry_msgs::Pose pose;
|
|
|
|
geometry_msgs::Twist cmd_vel;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
std::vector<std::string> joint_names;
|
|
|
|
std::vector<double> joint_positions;
|
|
|
|
std::vector<double> joint_velocities;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
torch::Tensor torques;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
ros::Timer timer;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
std::chrono::high_resolution_clock::time_point start_time;
|
2024-03-06 17:24:39 +08:00
|
|
|
|
2024-03-06 18:24:05 +08:00
|
|
|
// other rl module
|
|
|
|
torch::jit::script::Module encoder;
|
|
|
|
torch::jit::script::Module vq;
|
2024-03-06 17:24:39 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
#endif
|