add estimator

This commit is contained in:
Huang Zhenbiao 2024-10-15 22:40:15 +08:00
parent 07e3cf086e
commit 0ff1b344e7
10 changed files with 61 additions and 42 deletions

View File

@ -31,4 +31,10 @@ colcon build --packages-up-to rl_quadruped_controller
```bash
source ~/ros2_ws/install/setup.bash
ros2 launch a1_description rl_control.launch.py
```
* Unitree Go2 Robot
```bash
source ~/ros2_ws/install/setup.bash
ros2 launch go2_description rl_control.launch.py
```

View File

@ -45,14 +45,20 @@ StateRL::StateRL(CtrlComponent &ctrl_component, const std::string &config_path)
loadYaml(config_path);
// history
if (!params_.observations_history.empty())
{
history_obs_buf_ = std::make_shared<ObservationBuffer>(1, params_.num_observations, params_.observations_history.size());
if (!params_.observations_history.empty()) {
history_obs_buf_ = std::make_shared<ObservationBuffer>(1, params_.num_observations,
params_.observations_history.size());
}
model_ = torch::jit::load(config_path + "/" + params_.model_name);
std::cout << "Model loaded: " << config_path + "/" + params_.model_name << std::endl;
// for (const auto &param: model_.parameters()) {
// std::cout << "Parameter dtype: " << param.dtype() << std::endl;
// }
rl_thread_ = std::thread([&] {
while (true) {
try {
@ -137,6 +143,8 @@ torch::Tensor StateRL::computeObservation() {
}
const torch::Tensor obs = cat(obs_list, 1);
std::cout << "Observation: " << obs << std::endl;
torch::Tensor clamped_obs = clamp(obs, -params_.clip_obs, params_.clip_obs);
return clamped_obs;
}
@ -156,12 +164,9 @@ void StateRL::loadYaml(const std::string &config_path) {
params_.framework = config["framework"].as<std::string>();
const int rows = config["rows"].as<int>();
const int cols = config["cols"].as<int>();
if (config["observations_history"].IsNull())
{
if (config["observations_history"].IsNull()) {
params_.observations_history = {};
}
else
{
} else {
params_.observations_history = ReadVectorFromYaml<int>(config["observations_history"]);
}
params_.decimation = config["decimation"].as<int>();
@ -226,14 +231,11 @@ torch::Tensor StateRL::forward() {
torch::Tensor clamped_obs = computeObservation();
torch::Tensor actions;
if (!params_.observations_history.empty())
{
if (!params_.observations_history.empty()) {
history_obs_buf_->insert(clamped_obs);
history_obs_ = history_obs_buf_->getObsVec(params_.observations_history);
actions = model_.forward({history_obs_}).toTensor();
}
else
{
} else {
actions = model_.forward({clamped_obs}).toTensor();
}
@ -278,7 +280,8 @@ void StateRL::getState() {
}
void StateRL::runModel() {
obs_.lin_vel = torch::tensor(ctrl_comp_.estimator_->getVelocity().data()).unsqueeze(0);
obs_.lin_vel = torch::from_blob(ctrl_comp_.estimator_->getVelocity().data(), {3}, torch::kDouble).clone().
to(torch::kFloat).unsqueeze(0);
obs_.ang_vel = torch::tensor(robot_state_.imu.gyroscope).unsqueeze(0);
obs_.commands = torch::tensor({{control_.x, control_.y, control_.yaw}});
obs_.base_quat = torch::tensor(robot_state_.imu.quaternion).unsqueeze(0);

View File

@ -93,7 +93,7 @@ namespace rl_quadruped_controller {
// foot_force_sensor
foot_force_name_ = auto_declare<std::string>("foot_force_name", foot_force_name_);
foot_force_interface_types_ =
auto_declare<std::vector<std::string> >("foot_force_interfaces", state_interface_types_);
auto_declare<std::vector<std::string> >("foot_force_interfaces", foot_force_interface_types_);
// rl config folder
rl_config_folder_ = auto_declare<std::string>("config_folder", rl_config_folder_);

View File

@ -81,7 +81,7 @@ namespace rl_quadruped_controller {
std::vector<std::string> imu_interface_types_;
// Foot Force Sensor
std::string foot_force_name_ = "foot_force";
std::vector<std::string> foot_force_interface_types_ = {"force", "torque"};
std::vector<std::string> foot_force_interface_types_ = {"FL", "FR", "RL", "RR"};
std::string rl_config_folder_;

View File

@ -217,10 +217,10 @@ void Estimator::update() {
Ppriori * C.transpose() * SR * STC * Ppriori.transpose();
// // Using low pass filter to smooth the velocity
low_pass_filters_[0]->addValue(x_hat_(3));
low_pass_filters_[1]->addValue(x_hat_(4));
low_pass_filters_[2]->addValue(x_hat_(5));
x_hat_(3) = low_pass_filters_[0]->getValue();
x_hat_(4) = low_pass_filters_[1]->getValue();
x_hat_(5) = low_pass_filters_[2]->getValue();
// low_pass_filters_[0]->addValue(x_hat_(3));
// low_pass_filters_[1]->addValue(x_hat_(4));
// low_pass_filters_[2]->addValue(x_hat_(5));
// x_hat_(3) = low_pass_filters_[0]->getValue();
// x_hat_(4) = low_pass_filters_[1]->getValue();
// x_hat_(5) = low_pass_filters_[2]->getValue();
}

View File

@ -137,10 +137,10 @@ Estimator::Estimator(CtrlComponent &ctrl_component) : ctrl_component_(ctrl_compo
QInit_ = Qdig.asDiagonal();
QInit_ += B * Cu * B.transpose();
// low_pass_filters_.resize(3);
// low_pass_filters_[0] = std::make_shared<LowPassFilter>(dt_, 3.0);
// low_pass_filters_[1] = std::make_shared<LowPassFilter>(dt_, 3.0);
// low_pass_filters_[2] = std::make_shared<LowPassFilter>(dt_, 3.0);
low_pass_filters_.resize(3);
low_pass_filters_[0] = std::make_shared<LowPassFilter>(dt_, 3.0);
low_pass_filters_[1] = std::make_shared<LowPassFilter>(dt_, 3.0);
low_pass_filters_[2] = std::make_shared<LowPassFilter>(dt_, 3.0);
}
double Estimator::getYaw() const {

View File

@ -100,7 +100,7 @@ unitree_guide_controller:
rl_quadruped_controller:
ros__parameters:
update_rate: 200 # Hz
config_folder: "/home/biao/ros2_ws/install/a1_description/share/a1_description/config/issacgym"
config_folder: "/home/tlab-uav/ros2_ws/install/a1_description/share/a1_description/config/issacgym"
command_prefix: "leg_pd_controller"
joints:
- FL_hip_joint

View File

@ -1,7 +1,7 @@
# Controller Manager configuration
controller_manager:
ros__parameters:
update_rate: 500 # Hz
update_rate: 1000 # Hz
# Define the available controllers
joint_state_broadcaster:
@ -135,7 +135,7 @@ ocs2_quadruped_controller:
rl_quadruped_controller:
ros__parameters:
update_rate: 100 # Hz
update_rate: 200 # Hz
joints:
- FL_hip_joint
- FL_thigh_joint
@ -163,10 +163,10 @@ rl_quadruped_controller:
- velocity
feet_names:
- FR_foot
- FL_foot
- RR_foot
- FR_foot
- RL_foot
- RR_foot
imu_name: "imu_sensor"
base_name: "base"

View File

@ -1,11 +1,11 @@
model_name: "himloco.pt"
model_name: "policy.pt"
framework: "isaacgym"
rows: 4
cols: 3
decimation: 4
num_observations: 48
observations: ["lin_vel", "ang_vel", "gravity_vec", "dof_pos", "dof_vel", "actions"]
observations_history: [5, 4, 3, 2, 1, 0]
observations: ["lin_vel", "ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
#observations_history: [6, 5, 4, 3, 2, 1, 0]
clip_obs: 100.0
clip_actions_lower: [-100, -100, -100,
-100, -100, -100,
@ -15,14 +15,14 @@ clip_actions_upper: [100, 100, 100,
100, 100, 100,
100, 100, 100,
100, 100, 100]
rl_kp: [40, 40, 40,
40, 40, 40,
40, 40, 40,
40, 40, 40]
rl_kd: [1, 1, 1,
1, 1, 1,
1, 1, 1,
1, 1, 1]
rl_kp: [20, 20, 20,
20, 20, 20,
20, 20, 20,
20, 20, 20]
rl_kd: [0.5, 0.5, 0.5,
0.5, 0.5, 0.5,
0.5, 0.5, 0.5,
0.5, 0.5, 0.5]
fixed_kp: [60, 60, 60,
60, 60, 60,
60, 60, 60,
@ -35,11 +35,14 @@ hip_scale_reduction: 1.0
hip_scale_reduction_indices: [0, 3, 6, 9]
num_of_dofs: 12
action_scale: 0.25
lin_vel_scale: 2.0
ang_vel_scale: 0.25
dof_pos_scale: 1.0
dof_vel_scale: 0.05
commands_scale: [2.0, 2.0, 0.25]
torque_limits: [33.5, 33.5, 33.5,
33.5, 33.5, 33.5,
33.5, 33.5, 33.5,

View File

@ -161,7 +161,14 @@ rl_quadruped_controller:
- position
- velocity
feet_names:
- FL_foot
- FR_foot
- RL_foot
- RR_foot
imu_name: "imu_sensor"
base_name: "base"
imu_interfaces:
- orientation.w