mirror of https://github.com/fan-ziqi/rl_sar.git
feat: load all model as one file
This commit is contained in:
parent
e808a96f62
commit
706ddce2a2
|
@ -1,4 +1,5 @@
|
|||
a1:
|
||||
model_name: "model.pt"
|
||||
num_observations: 45
|
||||
clip_obs: 100.0
|
||||
clip_actions: 100.0
|
||||
|
@ -19,16 +20,17 @@ a1:
|
|||
ang_vel_scale: 0.25
|
||||
dof_pos_scale: 1.0
|
||||
dof_vel_scale: 0.05
|
||||
torque_limits: [20.0, 55.0, 55.0, # FL
|
||||
20.0, 55.0, 55.0, # FR
|
||||
20.0, 55.0, 55.0, # RL
|
||||
20.0, 55.0, 55.0] # RR
|
||||
torque_limits: [33.5, 33.5, 33.5, # FL
|
||||
33.5, 33.5, 33.5, # FR
|
||||
33.5, 33.5, 33.5, # RL
|
||||
33.5, 33.5, 33.5] # RR
|
||||
default_dof_pos: [ 0.1000, 0.8000, -1.5000, # FL
|
||||
-0.1000, 0.8000, -1.5000, # FR
|
||||
0.1000, 1.0000, -1.5000, # RL
|
||||
-0.1000, 1.0000, -1.5000] # RR
|
||||
|
||||
cyberdog1:
|
||||
model_name: "model.pt"
|
||||
num_observations: 45
|
||||
clip_obs: 100.0
|
||||
clip_actions: 100.0
|
||||
|
|
|
@ -71,10 +71,6 @@ private:
|
|||
int hip_scale_reduction_indices[4] = {0, 3, 6, 9};
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time;
|
||||
|
||||
// other rl module
|
||||
torch::jit::script::Module encoder;
|
||||
torch::jit::script::Module vq;
|
||||
};
|
||||
|
||||
#endif
|
|
@ -81,10 +81,6 @@ private:
|
|||
int hip_scale_reduction_indices[4] = {0, 3, 6, 9};
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time;
|
||||
|
||||
// other rl module
|
||||
torch::jit::script::Module encoder;
|
||||
torch::jit::script::Module vq;
|
||||
};
|
||||
|
||||
#endif
|
|
@ -64,10 +64,6 @@ private:
|
|||
int hip_scale_reduction_indices[4] = {0, 3, 6, 9};
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time;
|
||||
|
||||
// other rl module
|
||||
torch::jit::script::Module encoder;
|
||||
torch::jit::script::Module vq;
|
||||
};
|
||||
|
||||
#endif
|
|
@ -13,6 +13,7 @@ void RL::ReadYaml(std::string robot_name)
|
|||
return;
|
||||
}
|
||||
|
||||
this->params.model_name = config["model_name"].as<std::string>();
|
||||
this->params.num_observations = config["num_observations"].as<int>();
|
||||
this->params.clip_obs = config["clip_obs"].as<float>();
|
||||
this->params.clip_actions = config["clip_actions"].as<float>();
|
||||
|
|
|
@ -13,6 +13,7 @@ namespace plt = matplotlibcpp;
|
|||
|
||||
struct ModelParams
|
||||
{
|
||||
std::string model_name;
|
||||
int num_observations;
|
||||
float damping;
|
||||
float stiffness;
|
||||
|
@ -66,7 +67,7 @@ public:
|
|||
|
||||
protected:
|
||||
// rl module
|
||||
torch::jit::script::Module actor;
|
||||
torch::jit::script::Module model;
|
||||
// observation buffer
|
||||
torch::Tensor lin_vel;
|
||||
torch::Tensor ang_vel;
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,4 +1,4 @@
|
|||
#include "../include/rl_real.hpp"
|
||||
#include "../include/rl_real_a1.hpp"
|
||||
|
||||
#define ROBOT_NAME "a1"
|
||||
|
||||
|
@ -15,13 +15,9 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
|||
|
||||
start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/actor.pt";
|
||||
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/encoder.pt";
|
||||
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/vq_layer.pt";
|
||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/" + this->params.model_name;
|
||||
this->model = torch::jit::load(model_path);
|
||||
|
||||
this->actor = torch::jit::load(actor_path);
|
||||
this->encoder = torch::jit::load(encoder_path);
|
||||
this->vq = torch::jit::load(vq_path);
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
|
||||
|
@ -263,13 +259,7 @@ torch::Tensor RL_Real::Forward()
|
|||
history_obs_buf.insert(obs);
|
||||
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
|
||||
torch::Tensor encoding = this->encoder.forward({history_obs}).toTensor();
|
||||
|
||||
torch::Tensor z = this->vq.forward({encoding}).toTensor();
|
||||
|
||||
torch::Tensor actor_input = torch::cat({obs, z}, 1);
|
||||
|
||||
torch::Tensor action = this->actor.forward({actor_input}).toTensor();
|
||||
torch::Tensor action = this->model.forward({history_obs}).toTensor();
|
||||
|
||||
this->obs.actions = action;
|
||||
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);
|
||||
|
|
|
@ -13,13 +13,9 @@ RL_Real::RL_Real() : CustomInterface(500)
|
|||
|
||||
start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/actor.pt";
|
||||
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/encoder.pt";
|
||||
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/vq_layer.pt";
|
||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/" + this->params.model_name;
|
||||
this->model = torch::jit::load(model_path);
|
||||
|
||||
this->actor = torch::jit::load(actor_path);
|
||||
this->encoder = torch::jit::load(encoder_path);
|
||||
this->vq = torch::jit::load(vq_path);
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
|
||||
|
@ -310,13 +306,7 @@ torch::Tensor RL_Real::Forward()
|
|||
history_obs_buf.insert(obs);
|
||||
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
|
||||
torch::Tensor encoding = this->encoder.forward({history_obs}).toTensor();
|
||||
|
||||
torch::Tensor z = this->vq.forward({encoding}).toTensor();
|
||||
|
||||
torch::Tensor actor_input = torch::cat({obs, z}, 1);
|
||||
|
||||
torch::Tensor action = this->actor.forward({actor_input}).toTensor();
|
||||
torch::Tensor action = this->model.forward({history_obs}).toTensor();
|
||||
|
||||
this->obs.actions = action;
|
||||
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);
|
||||
|
|
|
@ -16,13 +16,9 @@ RL_Sim::RL_Sim()
|
|||
|
||||
motor_commands.resize(12);
|
||||
|
||||
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/actor.pt";
|
||||
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/encoder.pt";
|
||||
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/vq_layer.pt";
|
||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/" + this->params.model_name;
|
||||
this->model = torch::jit::load(model_path);
|
||||
|
||||
this->actor = torch::jit::load(actor_path);
|
||||
this->encoder = torch::jit::load(encoder_path);
|
||||
this->vq = torch::jit::load(vq_path);
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
|
||||
|
@ -194,13 +190,7 @@ torch::Tensor RL_Sim::Forward()
|
|||
history_obs_buf.insert(obs);
|
||||
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
|
||||
torch::Tensor encoding = this->encoder.forward({history_obs}).toTensor();
|
||||
|
||||
torch::Tensor z = this->vq.forward({encoding}).toTensor();
|
||||
|
||||
torch::Tensor actor_input = torch::cat({obs, z}, 1);
|
||||
|
||||
torch::Tensor action = this->actor.forward({actor_input}).toTensor();
|
||||
torch::Tensor action = this->model.forward({history_obs}).toTensor();
|
||||
|
||||
this->obs.actions = action;
|
||||
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);
|
||||
|
|
Loading…
Reference in New Issue