diff --git a/src/unitree_rl/include/unitree_rl.hpp b/src/unitree_rl/include/unitree_rl.hpp index 60be355..e791078 100644 --- a/src/unitree_rl/include/unitree_rl.hpp +++ b/src/unitree_rl/include/unitree_rl.hpp @@ -12,45 +12,46 @@ class Unitree_RL : public Model { public: + Unitree_RL(); + void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); + void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); + void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); + void runModel(const ros::TimerEvent &event); + torch::Tensor forward() override; + torch::Tensor compute_observation() override; - Unitree_RL(); - void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr& msg); - void jointStatesCallback(const sensor_msgs::JointState::ConstPtr& msg); - void cmdvelCallback(const geometry_msgs::Twist::ConstPtr& msg); - void runModel(const ros::TimerEvent& event); - torch::Tensor forward() override; - torch::Tensor compute_observation() override; - - ObservationBuffer history_obs_buf; - torch::Tensor history_obs; + ObservationBuffer history_obs_buf; + torch::Tensor history_obs; private: + std::string ros_namespace; - std::string ros_namespace; + std::vector torque_command_topics; - std::vector torque_command_topics; + ros::Subscriber model_state_subscriber_; + ros::Subscriber joint_state_subscriber_; + ros::Subscriber cmd_vel_subscriber_; - ros::Subscriber model_state_subscriber_; - ros::Subscriber joint_state_subscriber_; - ros::Subscriber cmd_vel_subscriber_; + std::map torque_publishers; + std::vector torque_commands; - std::map torque_publishers; - std::vector torque_commands; + geometry_msgs::Twist vel; + geometry_msgs::Pose pose; + geometry_msgs::Twist cmd_vel; - geometry_msgs::Twist vel; - geometry_msgs::Pose pose; - geometry_msgs::Twist cmd_vel; + std::vector joint_names; + std::vector joint_positions; + std::vector joint_velocities; - std::vector joint_names; - std::vector joint_positions; - std::vector joint_velocities; + torch::Tensor torques; - torch::Tensor torques; + ros::Timer timer; - ros::Timer timer; - - std::chrono::high_resolution_clock::time_point start_time; + std::chrono::high_resolution_clock::time_point start_time; + // other rl module + torch::jit::script::Module encoder; + torch::jit::script::Module vq; }; #endif \ No newline at end of file diff --git a/src/unitree_rl/lib/model.cpp b/src/unitree_rl/lib/model.cpp index b476ba7..1578f08 100644 --- a/src/unitree_rl/lib/model.cpp +++ b/src/unitree_rl/lib/model.cpp @@ -1,13 +1,5 @@ #include "model.hpp" -void Model::init_models(std::string actor_path, std::string encoder_path, std::string vq_path) -{ - this->actor = torch::jit::load(actor_path); - this->encoder = torch::jit::load(encoder_path); - this->vq = torch::jit::load(vq_path); - this->init_observations(); -} - torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v) { c10::IntArrayRef shape = q.sizes(); diff --git a/src/unitree_rl/lib/model.hpp b/src/unitree_rl/lib/model.hpp index 1c20cb0..63606f5 100644 --- a/src/unitree_rl/lib/model.hpp +++ b/src/unitree_rl/lib/model.hpp @@ -47,13 +47,10 @@ public: torch::Tensor compute_torques(torch::Tensor actions); torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v); void init_observations(); - void init_models(std::string actor_path, std::string encoder_path, std::string vq_path); protected: // rl module torch::jit::script::Module actor; - torch::jit::script::Module encoder; - torch::jit::script::Module vq; // observation buffer torch::Tensor lin_vel; torch::Tensor ang_vel; @@ -63,7 +60,6 @@ protected: torch::Tensor dof_pos; torch::Tensor dof_vel; torch::Tensor actions; - }; #endif // MODEL_HPP \ No newline at end of file diff --git a/src/unitree_rl/src/unitree_rl.cpp b/src/unitree_rl/src/unitree_rl.cpp index 38f2652..6e67d09 100644 --- a/src/unitree_rl/src/unitree_rl.cpp +++ b/src/unitree_rl/src/unitree_rl.cpp @@ -29,7 +29,11 @@ Unitree_RL::Unitree_RL() std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt"; std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt"; std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt"; - this->init_models(actor_path, encoder_path, vq_path); + + this->actor = torch::jit::load(actor_path); + this->encoder = torch::jit::load(encoder_path); + this->vq = torch::jit::load(vq_path); + this->init_observations(); this->params.num_observations = 45; this->params.clip_obs = 100.0;