#ifndef UNITREE_RL #define UNITREE_RL #include #include #include #include #include "../library/model/model.hpp" #include "../library/observation_buffer/observation_buffer.hpp" #include "unitree_legged_msgs/MotorCmd.h" class Unitree_RL : public Model { public: Unitree_RL(); void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); void runModel(const ros::TimerEvent &event); torch::Tensor forward() override; torch::Tensor compute_observation() override; ObservationBuffer history_obs_buf; torch::Tensor history_obs; private: std::string ros_namespace; std::vector torque_command_topics; ros::Subscriber model_state_subscriber_; ros::Subscriber joint_state_subscriber_; ros::Subscriber cmd_vel_subscriber_; std::map torque_publishers; std::vector torque_commands; geometry_msgs::Twist vel; geometry_msgs::Pose pose; geometry_msgs::Twist cmd_vel; std::vector joint_names; std::vector joint_positions; std::vector joint_velocities; torch::Tensor torques; ros::Timer timer; std::chrono::high_resolution_clock::time_point start_time; // other rl module torch::jit::script::Module encoder; torch::jit::script::Module vq; }; #endif