feat: load all models as one file

This commit is contained in:
fan-ziqi 2024-04-20 12:29:02 +08:00
parent e808a96f62
commit 706ddce2a2
16 changed files with 19 additions and 57 deletions

View File

@ -1,4 +1,5 @@
a1:
model_name: "model.pt"
num_observations: 45
clip_obs: 100.0
clip_actions: 100.0
@ -19,16 +20,17 @@ a1:
ang_vel_scale: 0.25
dof_pos_scale: 1.0
dof_vel_scale: 0.05
torque_limits: [20.0, 55.0, 55.0, # FL
20.0, 55.0, 55.0, # FR
20.0, 55.0, 55.0, # RL
20.0, 55.0, 55.0] # RR
torque_limits: [33.5, 33.5, 33.5, # FL
33.5, 33.5, 33.5, # FR
33.5, 33.5, 33.5, # RL
33.5, 33.5, 33.5] # RR
default_dof_pos: [ 0.1000, 0.8000, -1.5000, # FL
-0.1000, 0.8000, -1.5000, # FR
0.1000, 1.0000, -1.5000, # RL
-0.1000, 1.0000, -1.5000] # RR
cyberdog1:
model_name: "model.pt"
num_observations: 45
clip_obs: 100.0
clip_actions: 100.0

View File

@ -71,10 +71,6 @@ private:
int hip_scale_reduction_indices[4] = {0, 3, 6, 9};
std::chrono::high_resolution_clock::time_point start_time;
// other rl module
torch::jit::script::Module encoder;
torch::jit::script::Module vq;
};
#endif

View File

@ -81,10 +81,6 @@ private:
int hip_scale_reduction_indices[4] = {0, 3, 6, 9};
std::chrono::high_resolution_clock::time_point start_time;
// other rl module
torch::jit::script::Module encoder;
torch::jit::script::Module vq;
};
#endif

View File

@ -64,10 +64,6 @@ private:
int hip_scale_reduction_indices[4] = {0, 3, 6, 9};
std::chrono::high_resolution_clock::time_point start_time;
// other rl module
torch::jit::script::Module encoder;
torch::jit::script::Module vq;
};
#endif

View File

@ -13,6 +13,7 @@ void RL::ReadYaml(std::string robot_name)
return;
}
this->params.model_name = config["model_name"].as<std::string>();
this->params.num_observations = config["num_observations"].as<int>();
this->params.clip_obs = config["clip_obs"].as<float>();
this->params.clip_actions = config["clip_actions"].as<float>();

View File

@ -13,6 +13,7 @@ namespace plt = matplotlibcpp;
struct ModelParams
{
std::string model_name;
int num_observations;
float damping;
float stiffness;
@ -66,7 +67,7 @@ public:
protected:
// rl module
torch::jit::script::Module actor;
torch::jit::script::Module model;
// observation buffer
torch::Tensor lin_vel;
torch::Tensor ang_vel;

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,4 +1,4 @@
#include "../include/rl_real.hpp"
#include "../include/rl_real_a1.hpp"
#define ROBOT_NAME "a1"
@ -15,13 +15,9 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
start_time = std::chrono::high_resolution_clock::now();
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/encoder.pt";
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/vq_layer.pt";
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/" + this->params.model_name;
this->model = torch::jit::load(model_path);
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
this->vq = torch::jit::load(vq_path);
this->InitObservations();
this->InitOutputs();
@ -263,13 +259,7 @@ torch::Tensor RL_Real::Forward()
history_obs_buf.insert(obs);
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
torch::Tensor encoding = this->encoder.forward({history_obs}).toTensor();
torch::Tensor z = this->vq.forward({encoding}).toTensor();
torch::Tensor actor_input = torch::cat({obs, z}, 1);
torch::Tensor action = this->actor.forward({actor_input}).toTensor();
torch::Tensor action = this->model.forward({history_obs}).toTensor();
this->obs.actions = action;
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);

View File

@ -13,13 +13,9 @@ RL_Real::RL_Real() : CustomInterface(500)
start_time = std::chrono::high_resolution_clock::now();
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/encoder.pt";
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/vq_layer.pt";
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/" + this->params.model_name;
this->model = torch::jit::load(model_path);
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
this->vq = torch::jit::load(vq_path);
this->InitObservations();
this->InitOutputs();
@ -310,13 +306,7 @@ torch::Tensor RL_Real::Forward()
history_obs_buf.insert(obs);
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
torch::Tensor encoding = this->encoder.forward({history_obs}).toTensor();
torch::Tensor z = this->vq.forward({encoding}).toTensor();
torch::Tensor actor_input = torch::cat({obs, z}, 1);
torch::Tensor action = this->actor.forward({actor_input}).toTensor();
torch::Tensor action = this->model.forward({history_obs}).toTensor();
this->obs.actions = action;
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);

View File

@ -16,13 +16,9 @@ RL_Sim::RL_Sim()
motor_commands.resize(12);
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/encoder.pt";
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/vq_layer.pt";
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/" + this->params.model_name;
this->model = torch::jit::load(model_path);
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
this->vq = torch::jit::load(vq_path);
this->InitObservations();
this->InitOutputs();
@ -194,13 +190,7 @@ torch::Tensor RL_Sim::Forward()
history_obs_buf.insert(obs);
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
torch::Tensor encoding = this->encoder.forward({history_obs}).toTensor();
torch::Tensor z = this->vq.forward({encoding}).toTensor();
torch::Tensor actor_input = torch::cat({obs, z}, 1);
torch::Tensor action = this->actor.forward({actor_input}).toTensor();
torch::Tensor action = this->model.forward({history_obs}).toTensor();
this->obs.actions = action;
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);