fix: move rl modules

This commit is contained in:
fan-ziqi 2024-03-06 18:24:05 +08:00
parent 37fdcf6113
commit d09bbe3096
4 changed files with 33 additions and 40 deletions

View File

@ -12,12 +12,11 @@
class Unitree_RL : public Model class Unitree_RL : public Model
{ {
public: public:
Unitree_RL(); Unitree_RL();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr& msg); void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr& msg); void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr& msg); void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
void runModel(const ros::TimerEvent& event); void runModel(const ros::TimerEvent &event);
torch::Tensor forward() override; torch::Tensor forward() override;
torch::Tensor compute_observation() override; torch::Tensor compute_observation() override;
@ -25,7 +24,6 @@ public:
torch::Tensor history_obs; torch::Tensor history_obs;
private: private:
std::string ros_namespace; std::string ros_namespace;
std::vector<std::string> torque_command_topics; std::vector<std::string> torque_command_topics;
@ -51,6 +49,9 @@ private:
std::chrono::high_resolution_clock::time_point start_time; std::chrono::high_resolution_clock::time_point start_time;
// other rl module
torch::jit::script::Module encoder;
torch::jit::script::Module vq;
}; };
#endif #endif

View File

@ -1,13 +1,5 @@
#include "model.hpp" #include "model.hpp"
void Model::init_models(std::string actor_path, std::string encoder_path, std::string vq_path)
{
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
this->vq = torch::jit::load(vq_path);
this->init_observations();
}
torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v) torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
{ {
c10::IntArrayRef shape = q.sizes(); c10::IntArrayRef shape = q.sizes();

View File

@ -47,13 +47,10 @@ public:
torch::Tensor compute_torques(torch::Tensor actions); torch::Tensor compute_torques(torch::Tensor actions);
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v); torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
void init_observations(); void init_observations();
void init_models(std::string actor_path, std::string encoder_path, std::string vq_path);
protected: protected:
// rl module // rl module
torch::jit::script::Module actor; torch::jit::script::Module actor;
torch::jit::script::Module encoder;
torch::jit::script::Module vq;
// observation buffer // observation buffer
torch::Tensor lin_vel; torch::Tensor lin_vel;
torch::Tensor ang_vel; torch::Tensor ang_vel;
@ -63,7 +60,6 @@ protected:
torch::Tensor dof_pos; torch::Tensor dof_pos;
torch::Tensor dof_vel; torch::Tensor dof_vel;
torch::Tensor actions; torch::Tensor actions;
}; };
#endif // MODEL_HPP #endif // MODEL_HPP

View File

@ -29,7 +29,11 @@ Unitree_RL::Unitree_RL()
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt"; std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt"; std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt"; std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
this->init_models(actor_path, encoder_path, vq_path);
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
this->vq = torch::jit::load(vq_path);
this->init_observations();
this->params.num_observations = 45; this->params.num_observations = 45;
this->params.clip_obs = 100.0; this->params.clip_obs = 100.0;