mirror of https://github.com/fan-ziqi/rl_sar.git
fix: move rl modules
This commit is contained in:
parent
37fdcf6113
commit
d09bbe3096
|
@ -12,7 +12,6 @@
|
||||||
class Unitree_RL : public Model
|
class Unitree_RL : public Model
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
Unitree_RL();
|
Unitree_RL();
|
||||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||||
|
@ -25,7 +24,6 @@ public:
|
||||||
torch::Tensor history_obs;
|
torch::Tensor history_obs;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
std::string ros_namespace;
|
std::string ros_namespace;
|
||||||
|
|
||||||
std::vector<std::string> torque_command_topics;
|
std::vector<std::string> torque_command_topics;
|
||||||
|
@ -51,6 +49,9 @@ private:
|
||||||
|
|
||||||
std::chrono::high_resolution_clock::time_point start_time;
|
std::chrono::high_resolution_clock::time_point start_time;
|
||||||
|
|
||||||
|
// other rl module
|
||||||
|
torch::jit::script::Module encoder;
|
||||||
|
torch::jit::script::Module vq;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -1,13 +1,5 @@
|
||||||
#include "model.hpp"
|
#include "model.hpp"
|
||||||
|
|
||||||
void Model::init_models(std::string actor_path, std::string encoder_path, std::string vq_path)
|
|
||||||
{
|
|
||||||
this->actor = torch::jit::load(actor_path);
|
|
||||||
this->encoder = torch::jit::load(encoder_path);
|
|
||||||
this->vq = torch::jit::load(vq_path);
|
|
||||||
this->init_observations();
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
|
torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
|
||||||
{
|
{
|
||||||
c10::IntArrayRef shape = q.sizes();
|
c10::IntArrayRef shape = q.sizes();
|
||||||
|
|
|
@ -47,13 +47,10 @@ public:
|
||||||
torch::Tensor compute_torques(torch::Tensor actions);
|
torch::Tensor compute_torques(torch::Tensor actions);
|
||||||
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
||||||
void init_observations();
|
void init_observations();
|
||||||
void init_models(std::string actor_path, std::string encoder_path, std::string vq_path);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// rl module
|
// rl module
|
||||||
torch::jit::script::Module actor;
|
torch::jit::script::Module actor;
|
||||||
torch::jit::script::Module encoder;
|
|
||||||
torch::jit::script::Module vq;
|
|
||||||
// observation buffer
|
// observation buffer
|
||||||
torch::Tensor lin_vel;
|
torch::Tensor lin_vel;
|
||||||
torch::Tensor ang_vel;
|
torch::Tensor ang_vel;
|
||||||
|
@ -63,7 +60,6 @@ protected:
|
||||||
torch::Tensor dof_pos;
|
torch::Tensor dof_pos;
|
||||||
torch::Tensor dof_vel;
|
torch::Tensor dof_vel;
|
||||||
torch::Tensor actions;
|
torch::Tensor actions;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // MODEL_HPP
|
#endif // MODEL_HPP
|
|
@ -29,7 +29,11 @@ Unitree_RL::Unitree_RL()
|
||||||
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
|
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
|
||||||
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
|
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
|
||||||
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
|
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
|
||||||
this->init_models(actor_path, encoder_path, vq_path);
|
|
||||||
|
this->actor = torch::jit::load(actor_path);
|
||||||
|
this->encoder = torch::jit::load(encoder_path);
|
||||||
|
this->vq = torch::jit::load(vq_path);
|
||||||
|
this->init_observations();
|
||||||
|
|
||||||
this->params.num_observations = 45;
|
this->params.num_observations = 45;
|
||||||
this->params.clip_obs = 100.0;
|
this->params.clip_obs = 100.0;
|
||||||
|
|
Loading…
Reference in New Issue