fix: move rl modules

This commit is contained in:
fan-ziqi 2024-03-06 18:24:05 +08:00
parent 37fdcf6113
commit d09bbe3096
4 changed files with 33 additions and 40 deletions

View File

@ -12,45 +12,46 @@
class Unitree_RL : public Model
{
public:
Unitree_RL();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
void runModel(const ros::TimerEvent &event);
torch::Tensor forward() override;
torch::Tensor compute_observation() override;
Unitree_RL();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr& msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr& msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr& msg);
void runModel(const ros::TimerEvent& event);
torch::Tensor forward() override;
torch::Tensor compute_observation() override;
ObservationBuffer history_obs_buf;
torch::Tensor history_obs;
ObservationBuffer history_obs_buf;
torch::Tensor history_obs;
private:
std::string ros_namespace;
std::string ros_namespace;
std::vector<std::string> torque_command_topics;
std::vector<std::string> torque_command_topics;
ros::Subscriber model_state_subscriber_;
ros::Subscriber joint_state_subscriber_;
ros::Subscriber cmd_vel_subscriber_;
ros::Subscriber model_state_subscriber_;
ros::Subscriber joint_state_subscriber_;
ros::Subscriber cmd_vel_subscriber_;
std::map<std::string, ros::Publisher> torque_publishers;
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
std::map<std::string, ros::Publisher> torque_publishers;
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
geometry_msgs::Twist vel;
geometry_msgs::Pose pose;
geometry_msgs::Twist cmd_vel;
geometry_msgs::Twist vel;
geometry_msgs::Pose pose;
geometry_msgs::Twist cmd_vel;
std::vector<std::string> joint_names;
std::vector<double> joint_positions;
std::vector<double> joint_velocities;
std::vector<std::string> joint_names;
std::vector<double> joint_positions;
std::vector<double> joint_velocities;
torch::Tensor torques;
torch::Tensor torques;
ros::Timer timer;
ros::Timer timer;
std::chrono::high_resolution_clock::time_point start_time;
std::chrono::high_resolution_clock::time_point start_time;
// other rl module
torch::jit::script::Module encoder;
torch::jit::script::Module vq;
};
#endif

View File

@ -1,13 +1,5 @@
#include "model.hpp"
void Model::init_models(std::string actor_path, std::string encoder_path, std::string vq_path)
{
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
this->vq = torch::jit::load(vq_path);
this->init_observations();
}
torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
{
c10::IntArrayRef shape = q.sizes();

View File

@ -47,13 +47,10 @@ public:
torch::Tensor compute_torques(torch::Tensor actions);
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
void init_observations();
void init_models(std::string actor_path, std::string encoder_path, std::string vq_path);
protected:
// rl module
torch::jit::script::Module actor;
torch::jit::script::Module encoder;
torch::jit::script::Module vq;
// observation buffer
torch::Tensor lin_vel;
torch::Tensor ang_vel;
@ -63,7 +60,6 @@ protected:
torch::Tensor dof_pos;
torch::Tensor dof_vel;
torch::Tensor actions;
};
#endif // MODEL_HPP

View File

@ -29,7 +29,11 @@ Unitree_RL::Unitree_RL()
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
this->init_models(actor_path, encoder_path, vq_path);
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
this->vq = torch::jit::load(vq_path);
this->init_observations();
this->params.num_observations = 45;
this->params.clip_obs = 100.0;