add estimator

This commit is contained in:
Huang Zhenbiao 2024-10-15 22:40:15 +08:00
parent 07e3cf086e
commit 0ff1b344e7
10 changed files with 61 additions and 42 deletions

View File

@ -32,3 +32,9 @@ colcon build --packages-up-to rl_quadruped_controller
source ~/ros2_ws/install/setup.bash source ~/ros2_ws/install/setup.bash
ros2 launch a1_description rl_control.launch.py ros2 launch a1_description rl_control.launch.py
``` ```
* Unitree Go2 Robot
```bash
source ~/ros2_ws/install/setup.bash
ros2 launch go2_description rl_control.launch.py
```

View File

@ -45,14 +45,20 @@ StateRL::StateRL(CtrlComponent &ctrl_component, const std::string &config_path)
loadYaml(config_path); loadYaml(config_path);
// history // history
if (!params_.observations_history.empty()) if (!params_.observations_history.empty()) {
{ history_obs_buf_ = std::make_shared<ObservationBuffer>(1, params_.num_observations,
history_obs_buf_ = std::make_shared<ObservationBuffer>(1, params_.num_observations, params_.observations_history.size()); params_.observations_history.size());
} }
model_ = torch::jit::load(config_path + "/" + params_.model_name); model_ = torch::jit::load(config_path + "/" + params_.model_name);
std::cout << "Model loaded: " << config_path + "/" + params_.model_name << std::endl; std::cout << "Model loaded: " << config_path + "/" + params_.model_name << std::endl;
// for (const auto &param: model_.parameters()) {
// std::cout << "Parameter dtype: " << param.dtype() << std::endl;
// }
rl_thread_ = std::thread([&] { rl_thread_ = std::thread([&] {
while (true) { while (true) {
try { try {
@ -137,6 +143,8 @@ torch::Tensor StateRL::computeObservation() {
} }
const torch::Tensor obs = cat(obs_list, 1); const torch::Tensor obs = cat(obs_list, 1);
std::cout << "Observation: " << obs << std::endl;
torch::Tensor clamped_obs = clamp(obs, -params_.clip_obs, params_.clip_obs); torch::Tensor clamped_obs = clamp(obs, -params_.clip_obs, params_.clip_obs);
return clamped_obs; return clamped_obs;
} }
@ -156,12 +164,9 @@ void StateRL::loadYaml(const std::string &config_path) {
params_.framework = config["framework"].as<std::string>(); params_.framework = config["framework"].as<std::string>();
const int rows = config["rows"].as<int>(); const int rows = config["rows"].as<int>();
const int cols = config["cols"].as<int>(); const int cols = config["cols"].as<int>();
if (config["observations_history"].IsNull()) if (config["observations_history"].IsNull()) {
{
params_.observations_history = {}; params_.observations_history = {};
} } else {
else
{
params_.observations_history = ReadVectorFromYaml<int>(config["observations_history"]); params_.observations_history = ReadVectorFromYaml<int>(config["observations_history"]);
} }
params_.decimation = config["decimation"].as<int>(); params_.decimation = config["decimation"].as<int>();
@ -226,14 +231,11 @@ torch::Tensor StateRL::forward() {
torch::Tensor clamped_obs = computeObservation(); torch::Tensor clamped_obs = computeObservation();
torch::Tensor actions; torch::Tensor actions;
if (!params_.observations_history.empty()) if (!params_.observations_history.empty()) {
{
history_obs_buf_->insert(clamped_obs); history_obs_buf_->insert(clamped_obs);
history_obs_ = history_obs_buf_->getObsVec(params_.observations_history); history_obs_ = history_obs_buf_->getObsVec(params_.observations_history);
actions = model_.forward({history_obs_}).toTensor(); actions = model_.forward({history_obs_}).toTensor();
} } else {
else
{
actions = model_.forward({clamped_obs}).toTensor(); actions = model_.forward({clamped_obs}).toTensor();
} }
@ -278,7 +280,8 @@ void StateRL::getState() {
} }
void StateRL::runModel() { void StateRL::runModel() {
obs_.lin_vel = torch::tensor(ctrl_comp_.estimator_->getVelocity().data()).unsqueeze(0); obs_.lin_vel = torch::from_blob(ctrl_comp_.estimator_->getVelocity().data(), {3}, torch::kDouble).clone().
to(torch::kFloat).unsqueeze(0);
obs_.ang_vel = torch::tensor(robot_state_.imu.gyroscope).unsqueeze(0); obs_.ang_vel = torch::tensor(robot_state_.imu.gyroscope).unsqueeze(0);
obs_.commands = torch::tensor({{control_.x, control_.y, control_.yaw}}); obs_.commands = torch::tensor({{control_.x, control_.y, control_.yaw}});
obs_.base_quat = torch::tensor(robot_state_.imu.quaternion).unsqueeze(0); obs_.base_quat = torch::tensor(robot_state_.imu.quaternion).unsqueeze(0);

View File

@ -93,7 +93,7 @@ namespace rl_quadruped_controller {
// foot_force_sensor // foot_force_sensor
foot_force_name_ = auto_declare<std::string>("foot_force_name", foot_force_name_); foot_force_name_ = auto_declare<std::string>("foot_force_name", foot_force_name_);
foot_force_interface_types_ = foot_force_interface_types_ =
auto_declare<std::vector<std::string> >("foot_force_interfaces", state_interface_types_); auto_declare<std::vector<std::string> >("foot_force_interfaces", foot_force_interface_types_);
// rl config folder // rl config folder
rl_config_folder_ = auto_declare<std::string>("config_folder", rl_config_folder_); rl_config_folder_ = auto_declare<std::string>("config_folder", rl_config_folder_);

View File

@ -81,7 +81,7 @@ namespace rl_quadruped_controller {
std::vector<std::string> imu_interface_types_; std::vector<std::string> imu_interface_types_;
// Foot Force Sensor // Foot Force Sensor
std::string foot_force_name_ = "foot_force"; std::string foot_force_name_ = "foot_force";
std::vector<std::string> foot_force_interface_types_ = {"force", "torque"}; std::vector<std::string> foot_force_interface_types_ = {"FL", "FR", "RL", "RR"};
std::string rl_config_folder_; std::string rl_config_folder_;

View File

@ -217,10 +217,10 @@ void Estimator::update() {
Ppriori * C.transpose() * SR * STC * Ppriori.transpose(); Ppriori * C.transpose() * SR * STC * Ppriori.transpose();
// // Using low pass filter to smooth the velocity // // Using low pass filter to smooth the velocity
low_pass_filters_[0]->addValue(x_hat_(3)); // low_pass_filters_[0]->addValue(x_hat_(3));
low_pass_filters_[1]->addValue(x_hat_(4)); // low_pass_filters_[1]->addValue(x_hat_(4));
low_pass_filters_[2]->addValue(x_hat_(5)); // low_pass_filters_[2]->addValue(x_hat_(5));
x_hat_(3) = low_pass_filters_[0]->getValue(); // x_hat_(3) = low_pass_filters_[0]->getValue();
x_hat_(4) = low_pass_filters_[1]->getValue(); // x_hat_(4) = low_pass_filters_[1]->getValue();
x_hat_(5) = low_pass_filters_[2]->getValue(); // x_hat_(5) = low_pass_filters_[2]->getValue();
} }

View File

@ -137,10 +137,10 @@ Estimator::Estimator(CtrlComponent &ctrl_component) : ctrl_component_(ctrl_compo
QInit_ = Qdig.asDiagonal(); QInit_ = Qdig.asDiagonal();
QInit_ += B * Cu * B.transpose(); QInit_ += B * Cu * B.transpose();
// low_pass_filters_.resize(3); low_pass_filters_.resize(3);
// low_pass_filters_[0] = std::make_shared<LowPassFilter>(dt_, 3.0); low_pass_filters_[0] = std::make_shared<LowPassFilter>(dt_, 3.0);
// low_pass_filters_[1] = std::make_shared<LowPassFilter>(dt_, 3.0); low_pass_filters_[1] = std::make_shared<LowPassFilter>(dt_, 3.0);
// low_pass_filters_[2] = std::make_shared<LowPassFilter>(dt_, 3.0); low_pass_filters_[2] = std::make_shared<LowPassFilter>(dt_, 3.0);
} }
double Estimator::getYaw() const { double Estimator::getYaw() const {

View File

@ -100,7 +100,7 @@ unitree_guide_controller:
rl_quadruped_controller: rl_quadruped_controller:
ros__parameters: ros__parameters:
update_rate: 200 # Hz update_rate: 200 # Hz
config_folder: "/home/biao/ros2_ws/install/a1_description/share/a1_description/config/issacgym" config_folder: "/home/tlab-uav/ros2_ws/install/a1_description/share/a1_description/config/issacgym"
command_prefix: "leg_pd_controller" command_prefix: "leg_pd_controller"
joints: joints:
- FL_hip_joint - FL_hip_joint

View File

@ -1,7 +1,7 @@
# Controller Manager configuration # Controller Manager configuration
controller_manager: controller_manager:
ros__parameters: ros__parameters:
update_rate: 500 # Hz update_rate: 1000 # Hz
# Define the available controllers # Define the available controllers
joint_state_broadcaster: joint_state_broadcaster:
@ -135,7 +135,7 @@ ocs2_quadruped_controller:
rl_quadruped_controller: rl_quadruped_controller:
ros__parameters: ros__parameters:
update_rate: 100 # Hz update_rate: 200 # Hz
joints: joints:
- FL_hip_joint - FL_hip_joint
- FL_thigh_joint - FL_thigh_joint
@ -163,10 +163,10 @@ rl_quadruped_controller:
- velocity - velocity
feet_names: feet_names:
- FR_foot
- FL_foot - FL_foot
- RR_foot - FR_foot
- RL_foot - RL_foot
- RR_foot
imu_name: "imu_sensor" imu_name: "imu_sensor"
base_name: "base" base_name: "base"

View File

@ -1,11 +1,11 @@
model_name: "himloco.pt" model_name: "policy.pt"
framework: "isaacgym" framework: "isaacgym"
rows: 4 rows: 4
cols: 3 cols: 3
decimation: 4 decimation: 4
num_observations: 48 num_observations: 48
observations: ["lin_vel", "ang_vel", "gravity_vec", "dof_pos", "dof_vel", "actions"] observations: ["lin_vel", "ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: [5, 4, 3, 2, 1, 0] #observations_history: [6, 5, 4, 3, 2, 1, 0]
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-100, -100, -100, clip_actions_lower: [-100, -100, -100,
-100, -100, -100, -100, -100, -100,
@ -15,14 +15,14 @@ clip_actions_upper: [100, 100, 100,
100, 100, 100, 100, 100, 100,
100, 100, 100, 100, 100, 100,
100, 100, 100] 100, 100, 100]
rl_kp: [40, 40, 40, rl_kp: [20, 20, 20,
40, 40, 40, 20, 20, 20,
40, 40, 40, 20, 20, 20,
40, 40, 40] 20, 20, 20]
rl_kd: [1, 1, 1, rl_kd: [0.5, 0.5, 0.5,
1, 1, 1, 0.5, 0.5, 0.5,
1, 1, 1, 0.5, 0.5, 0.5,
1, 1, 1] 0.5, 0.5, 0.5]
fixed_kp: [60, 60, 60, fixed_kp: [60, 60, 60,
60, 60, 60, 60, 60, 60,
60, 60, 60, 60, 60, 60,
@ -35,11 +35,14 @@ hip_scale_reduction: 1.0
hip_scale_reduction_indices: [0, 3, 6, 9] hip_scale_reduction_indices: [0, 3, 6, 9]
num_of_dofs: 12 num_of_dofs: 12
action_scale: 0.25 action_scale: 0.25
lin_vel_scale: 2.0 lin_vel_scale: 2.0
ang_vel_scale: 0.25 ang_vel_scale: 0.25
dof_pos_scale: 1.0 dof_pos_scale: 1.0
dof_vel_scale: 0.05 dof_vel_scale: 0.05
commands_scale: [2.0, 2.0, 0.25] commands_scale: [2.0, 2.0, 0.25]
torque_limits: [33.5, 33.5, 33.5, torque_limits: [33.5, 33.5, 33.5,
33.5, 33.5, 33.5, 33.5, 33.5, 33.5,
33.5, 33.5, 33.5, 33.5, 33.5, 33.5,

View File

@ -161,7 +161,14 @@ rl_quadruped_controller:
- position - position
- velocity - velocity
feet_names:
- FL_foot
- FR_foot
- RL_foot
- RR_foot
imu_name: "imu_sensor" imu_name: "imu_sensor"
base_name: "base"
imu_interfaces: imu_interfaces:
- orientation.w - orientation.w