mirror of https://github.com/fan-ziqi/rl_sar.git
fix: move rl modules
This commit is contained in:
parent
37fdcf6113
commit
d09bbe3096
|
@ -12,45 +12,46 @@
|
|||
class Unitree_RL : public Model
|
||||
{
|
||||
public:
|
||||
Unitree_RL();
|
||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
||||
void runModel(const ros::TimerEvent &event);
|
||||
torch::Tensor forward() override;
|
||||
torch::Tensor compute_observation() override;
|
||||
|
||||
Unitree_RL();
|
||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr& msg);
|
||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr& msg);
|
||||
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr& msg);
|
||||
void runModel(const ros::TimerEvent& event);
|
||||
torch::Tensor forward() override;
|
||||
torch::Tensor compute_observation() override;
|
||||
|
||||
ObservationBuffer history_obs_buf;
|
||||
torch::Tensor history_obs;
|
||||
ObservationBuffer history_obs_buf;
|
||||
torch::Tensor history_obs;
|
||||
|
||||
private:
|
||||
std::string ros_namespace;
|
||||
|
||||
std::string ros_namespace;
|
||||
std::vector<std::string> torque_command_topics;
|
||||
|
||||
std::vector<std::string> torque_command_topics;
|
||||
ros::Subscriber model_state_subscriber_;
|
||||
ros::Subscriber joint_state_subscriber_;
|
||||
ros::Subscriber cmd_vel_subscriber_;
|
||||
|
||||
ros::Subscriber model_state_subscriber_;
|
||||
ros::Subscriber joint_state_subscriber_;
|
||||
ros::Subscriber cmd_vel_subscriber_;
|
||||
std::map<std::string, ros::Publisher> torque_publishers;
|
||||
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
|
||||
|
||||
std::map<std::string, ros::Publisher> torque_publishers;
|
||||
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
|
||||
geometry_msgs::Twist vel;
|
||||
geometry_msgs::Pose pose;
|
||||
geometry_msgs::Twist cmd_vel;
|
||||
|
||||
geometry_msgs::Twist vel;
|
||||
geometry_msgs::Pose pose;
|
||||
geometry_msgs::Twist cmd_vel;
|
||||
std::vector<std::string> joint_names;
|
||||
std::vector<double> joint_positions;
|
||||
std::vector<double> joint_velocities;
|
||||
|
||||
std::vector<std::string> joint_names;
|
||||
std::vector<double> joint_positions;
|
||||
std::vector<double> joint_velocities;
|
||||
torch::Tensor torques;
|
||||
|
||||
torch::Tensor torques;
|
||||
ros::Timer timer;
|
||||
|
||||
ros::Timer timer;
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time;
|
||||
std::chrono::high_resolution_clock::time_point start_time;
|
||||
|
||||
// other rl module
|
||||
torch::jit::script::Module encoder;
|
||||
torch::jit::script::Module vq;
|
||||
};
|
||||
|
||||
#endif
|
|
@ -1,13 +1,5 @@
|
|||
#include "model.hpp"
|
||||
|
||||
void Model::init_models(std::string actor_path, std::string encoder_path, std::string vq_path)
|
||||
{
|
||||
this->actor = torch::jit::load(actor_path);
|
||||
this->encoder = torch::jit::load(encoder_path);
|
||||
this->vq = torch::jit::load(vq_path);
|
||||
this->init_observations();
|
||||
}
|
||||
|
||||
torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
|
||||
{
|
||||
c10::IntArrayRef shape = q.sizes();
|
||||
|
|
|
@ -47,13 +47,10 @@ public:
|
|||
torch::Tensor compute_torques(torch::Tensor actions);
|
||||
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
||||
void init_observations();
|
||||
void init_models(std::string actor_path, std::string encoder_path, std::string vq_path);
|
||||
|
||||
protected:
|
||||
// rl module
|
||||
torch::jit::script::Module actor;
|
||||
torch::jit::script::Module encoder;
|
||||
torch::jit::script::Module vq;
|
||||
// observation buffer
|
||||
torch::Tensor lin_vel;
|
||||
torch::Tensor ang_vel;
|
||||
|
@ -63,7 +60,6 @@ protected:
|
|||
torch::Tensor dof_pos;
|
||||
torch::Tensor dof_vel;
|
||||
torch::Tensor actions;
|
||||
|
||||
};
|
||||
|
||||
#endif // MODEL_HPP
|
|
@ -29,7 +29,11 @@ Unitree_RL::Unitree_RL()
|
|||
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
|
||||
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
|
||||
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
|
||||
this->init_models(actor_path, encoder_path, vq_path);
|
||||
|
||||
this->actor = torch::jit::load(actor_path);
|
||||
this->encoder = torch::jit::load(encoder_path);
|
||||
this->vq = torch::jit::load(vq_path);
|
||||
this->init_observations();
|
||||
|
||||
this->params.num_observations = 45;
|
||||
this->params.clip_obs = 100.0;
|
||||
|
|
Loading…
Reference in New Issue