add estimator
This commit is contained in:
parent
07e3cf086e
commit
0ff1b344e7
|
@ -32,3 +32,9 @@ colcon build --packages-up-to rl_quadruped_controller
|
||||||
source ~/ros2_ws/install/setup.bash
|
source ~/ros2_ws/install/setup.bash
|
||||||
ros2 launch a1_description rl_control.launch.py
|
ros2 launch a1_description rl_control.launch.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
* Unitree Go2 Robot
|
||||||
|
```bash
|
||||||
|
source ~/ros2_ws/install/setup.bash
|
||||||
|
ros2 launch go2_description rl_control.launch.py
|
||||||
|
```
|
|
@ -45,14 +45,20 @@ StateRL::StateRL(CtrlComponent &ctrl_component, const std::string &config_path)
|
||||||
loadYaml(config_path);
|
loadYaml(config_path);
|
||||||
|
|
||||||
// history
|
// history
|
||||||
if (!params_.observations_history.empty())
|
if (!params_.observations_history.empty()) {
|
||||||
{
|
history_obs_buf_ = std::make_shared<ObservationBuffer>(1, params_.num_observations,
|
||||||
history_obs_buf_ = std::make_shared<ObservationBuffer>(1, params_.num_observations, params_.observations_history.size());
|
params_.observations_history.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
model_ = torch::jit::load(config_path + "/" + params_.model_name);
|
model_ = torch::jit::load(config_path + "/" + params_.model_name);
|
||||||
std::cout << "Model loaded: " << config_path + "/" + params_.model_name << std::endl;
|
std::cout << "Model loaded: " << config_path + "/" + params_.model_name << std::endl;
|
||||||
|
|
||||||
|
|
||||||
|
// for (const auto ¶m: model_.parameters()) {
|
||||||
|
// std::cout << "Parameter dtype: " << param.dtype() << std::endl;
|
||||||
|
// }
|
||||||
|
|
||||||
|
|
||||||
rl_thread_ = std::thread([&] {
|
rl_thread_ = std::thread([&] {
|
||||||
while (true) {
|
while (true) {
|
||||||
try {
|
try {
|
||||||
|
@ -137,6 +143,8 @@ torch::Tensor StateRL::computeObservation() {
|
||||||
}
|
}
|
||||||
|
|
||||||
const torch::Tensor obs = cat(obs_list, 1);
|
const torch::Tensor obs = cat(obs_list, 1);
|
||||||
|
|
||||||
|
std::cout << "Observation: " << obs << std::endl;
|
||||||
torch::Tensor clamped_obs = clamp(obs, -params_.clip_obs, params_.clip_obs);
|
torch::Tensor clamped_obs = clamp(obs, -params_.clip_obs, params_.clip_obs);
|
||||||
return clamped_obs;
|
return clamped_obs;
|
||||||
}
|
}
|
||||||
|
@ -156,12 +164,9 @@ void StateRL::loadYaml(const std::string &config_path) {
|
||||||
params_.framework = config["framework"].as<std::string>();
|
params_.framework = config["framework"].as<std::string>();
|
||||||
const int rows = config["rows"].as<int>();
|
const int rows = config["rows"].as<int>();
|
||||||
const int cols = config["cols"].as<int>();
|
const int cols = config["cols"].as<int>();
|
||||||
if (config["observations_history"].IsNull())
|
if (config["observations_history"].IsNull()) {
|
||||||
{
|
|
||||||
params_.observations_history = {};
|
params_.observations_history = {};
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
params_.observations_history = ReadVectorFromYaml<int>(config["observations_history"]);
|
params_.observations_history = ReadVectorFromYaml<int>(config["observations_history"]);
|
||||||
}
|
}
|
||||||
params_.decimation = config["decimation"].as<int>();
|
params_.decimation = config["decimation"].as<int>();
|
||||||
|
@ -226,14 +231,11 @@ torch::Tensor StateRL::forward() {
|
||||||
torch::Tensor clamped_obs = computeObservation();
|
torch::Tensor clamped_obs = computeObservation();
|
||||||
torch::Tensor actions;
|
torch::Tensor actions;
|
||||||
|
|
||||||
if (!params_.observations_history.empty())
|
if (!params_.observations_history.empty()) {
|
||||||
{
|
|
||||||
history_obs_buf_->insert(clamped_obs);
|
history_obs_buf_->insert(clamped_obs);
|
||||||
history_obs_ = history_obs_buf_->getObsVec(params_.observations_history);
|
history_obs_ = history_obs_buf_->getObsVec(params_.observations_history);
|
||||||
actions = model_.forward({history_obs_}).toTensor();
|
actions = model_.forward({history_obs_}).toTensor();
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
actions = model_.forward({clamped_obs}).toTensor();
|
actions = model_.forward({clamped_obs}).toTensor();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -278,7 +280,8 @@ void StateRL::getState() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void StateRL::runModel() {
|
void StateRL::runModel() {
|
||||||
obs_.lin_vel = torch::tensor(ctrl_comp_.estimator_->getVelocity().data()).unsqueeze(0);
|
obs_.lin_vel = torch::from_blob(ctrl_comp_.estimator_->getVelocity().data(), {3}, torch::kDouble).clone().
|
||||||
|
to(torch::kFloat).unsqueeze(0);
|
||||||
obs_.ang_vel = torch::tensor(robot_state_.imu.gyroscope).unsqueeze(0);
|
obs_.ang_vel = torch::tensor(robot_state_.imu.gyroscope).unsqueeze(0);
|
||||||
obs_.commands = torch::tensor({{control_.x, control_.y, control_.yaw}});
|
obs_.commands = torch::tensor({{control_.x, control_.y, control_.yaw}});
|
||||||
obs_.base_quat = torch::tensor(robot_state_.imu.quaternion).unsqueeze(0);
|
obs_.base_quat = torch::tensor(robot_state_.imu.quaternion).unsqueeze(0);
|
||||||
|
|
|
@ -93,7 +93,7 @@ namespace rl_quadruped_controller {
|
||||||
// foot_force_sensor
|
// foot_force_sensor
|
||||||
foot_force_name_ = auto_declare<std::string>("foot_force_name", foot_force_name_);
|
foot_force_name_ = auto_declare<std::string>("foot_force_name", foot_force_name_);
|
||||||
foot_force_interface_types_ =
|
foot_force_interface_types_ =
|
||||||
auto_declare<std::vector<std::string> >("foot_force_interfaces", state_interface_types_);
|
auto_declare<std::vector<std::string> >("foot_force_interfaces", foot_force_interface_types_);
|
||||||
|
|
||||||
// rl config folder
|
// rl config folder
|
||||||
rl_config_folder_ = auto_declare<std::string>("config_folder", rl_config_folder_);
|
rl_config_folder_ = auto_declare<std::string>("config_folder", rl_config_folder_);
|
||||||
|
|
|
@ -81,7 +81,7 @@ namespace rl_quadruped_controller {
|
||||||
std::vector<std::string> imu_interface_types_;
|
std::vector<std::string> imu_interface_types_;
|
||||||
// Foot Force Sensor
|
// Foot Force Sensor
|
||||||
std::string foot_force_name_ = "foot_force";
|
std::string foot_force_name_ = "foot_force";
|
||||||
std::vector<std::string> foot_force_interface_types_ = {"force", "torque"};
|
std::vector<std::string> foot_force_interface_types_ = {"FL", "FR", "RL", "RR"};
|
||||||
|
|
||||||
std::string rl_config_folder_;
|
std::string rl_config_folder_;
|
||||||
|
|
||||||
|
|
|
@ -217,10 +217,10 @@ void Estimator::update() {
|
||||||
Ppriori * C.transpose() * SR * STC * Ppriori.transpose();
|
Ppriori * C.transpose() * SR * STC * Ppriori.transpose();
|
||||||
|
|
||||||
// // Using low pass filter to smooth the velocity
|
// // Using low pass filter to smooth the velocity
|
||||||
low_pass_filters_[0]->addValue(x_hat_(3));
|
// low_pass_filters_[0]->addValue(x_hat_(3));
|
||||||
low_pass_filters_[1]->addValue(x_hat_(4));
|
// low_pass_filters_[1]->addValue(x_hat_(4));
|
||||||
low_pass_filters_[2]->addValue(x_hat_(5));
|
// low_pass_filters_[2]->addValue(x_hat_(5));
|
||||||
x_hat_(3) = low_pass_filters_[0]->getValue();
|
// x_hat_(3) = low_pass_filters_[0]->getValue();
|
||||||
x_hat_(4) = low_pass_filters_[1]->getValue();
|
// x_hat_(4) = low_pass_filters_[1]->getValue();
|
||||||
x_hat_(5) = low_pass_filters_[2]->getValue();
|
// x_hat_(5) = low_pass_filters_[2]->getValue();
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,10 +137,10 @@ Estimator::Estimator(CtrlComponent &ctrl_component) : ctrl_component_(ctrl_compo
|
||||||
QInit_ = Qdig.asDiagonal();
|
QInit_ = Qdig.asDiagonal();
|
||||||
QInit_ += B * Cu * B.transpose();
|
QInit_ += B * Cu * B.transpose();
|
||||||
|
|
||||||
// low_pass_filters_.resize(3);
|
low_pass_filters_.resize(3);
|
||||||
// low_pass_filters_[0] = std::make_shared<LowPassFilter>(dt_, 3.0);
|
low_pass_filters_[0] = std::make_shared<LowPassFilter>(dt_, 3.0);
|
||||||
// low_pass_filters_[1] = std::make_shared<LowPassFilter>(dt_, 3.0);
|
low_pass_filters_[1] = std::make_shared<LowPassFilter>(dt_, 3.0);
|
||||||
// low_pass_filters_[2] = std::make_shared<LowPassFilter>(dt_, 3.0);
|
low_pass_filters_[2] = std::make_shared<LowPassFilter>(dt_, 3.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
double Estimator::getYaw() const {
|
double Estimator::getYaw() const {
|
||||||
|
|
|
@ -100,7 +100,7 @@ unitree_guide_controller:
|
||||||
rl_quadruped_controller:
|
rl_quadruped_controller:
|
||||||
ros__parameters:
|
ros__parameters:
|
||||||
update_rate: 200 # Hz
|
update_rate: 200 # Hz
|
||||||
config_folder: "/home/biao/ros2_ws/install/a1_description/share/a1_description/config/issacgym"
|
config_folder: "/home/tlab-uav/ros2_ws/install/a1_description/share/a1_description/config/issacgym"
|
||||||
command_prefix: "leg_pd_controller"
|
command_prefix: "leg_pd_controller"
|
||||||
joints:
|
joints:
|
||||||
- FL_hip_joint
|
- FL_hip_joint
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Controller Manager configuration
|
# Controller Manager configuration
|
||||||
controller_manager:
|
controller_manager:
|
||||||
ros__parameters:
|
ros__parameters:
|
||||||
update_rate: 500 # Hz
|
update_rate: 1000 # Hz
|
||||||
|
|
||||||
# Define the available controllers
|
# Define the available controllers
|
||||||
joint_state_broadcaster:
|
joint_state_broadcaster:
|
||||||
|
@ -135,7 +135,7 @@ ocs2_quadruped_controller:
|
||||||
|
|
||||||
rl_quadruped_controller:
|
rl_quadruped_controller:
|
||||||
ros__parameters:
|
ros__parameters:
|
||||||
update_rate: 100 # Hz
|
update_rate: 200 # Hz
|
||||||
joints:
|
joints:
|
||||||
- FL_hip_joint
|
- FL_hip_joint
|
||||||
- FL_thigh_joint
|
- FL_thigh_joint
|
||||||
|
@ -163,10 +163,10 @@ rl_quadruped_controller:
|
||||||
- velocity
|
- velocity
|
||||||
|
|
||||||
feet_names:
|
feet_names:
|
||||||
- FR_foot
|
|
||||||
- FL_foot
|
- FL_foot
|
||||||
- RR_foot
|
- FR_foot
|
||||||
- RL_foot
|
- RL_foot
|
||||||
|
- RR_foot
|
||||||
|
|
||||||
imu_name: "imu_sensor"
|
imu_name: "imu_sensor"
|
||||||
base_name: "base"
|
base_name: "base"
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
model_name: "himloco.pt"
|
model_name: "policy.pt"
|
||||||
framework: "isaacgym"
|
framework: "isaacgym"
|
||||||
rows: 4
|
rows: 4
|
||||||
cols: 3
|
cols: 3
|
||||||
decimation: 4
|
decimation: 4
|
||||||
num_observations: 48
|
num_observations: 48
|
||||||
observations: ["lin_vel", "ang_vel", "gravity_vec", "dof_pos", "dof_vel", "actions"]
|
observations: ["lin_vel", "ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||||
observations_history: [5, 4, 3, 2, 1, 0]
|
#observations_history: [6, 5, 4, 3, 2, 1, 0]
|
||||||
clip_obs: 100.0
|
clip_obs: 100.0
|
||||||
clip_actions_lower: [-100, -100, -100,
|
clip_actions_lower: [-100, -100, -100,
|
||||||
-100, -100, -100,
|
-100, -100, -100,
|
||||||
|
@ -15,14 +15,14 @@ clip_actions_upper: [100, 100, 100,
|
||||||
100, 100, 100,
|
100, 100, 100,
|
||||||
100, 100, 100,
|
100, 100, 100,
|
||||||
100, 100, 100]
|
100, 100, 100]
|
||||||
rl_kp: [40, 40, 40,
|
rl_kp: [20, 20, 20,
|
||||||
40, 40, 40,
|
20, 20, 20,
|
||||||
40, 40, 40,
|
20, 20, 20,
|
||||||
40, 40, 40]
|
20, 20, 20]
|
||||||
rl_kd: [1, 1, 1,
|
rl_kd: [0.5, 0.5, 0.5,
|
||||||
1, 1, 1,
|
0.5, 0.5, 0.5,
|
||||||
1, 1, 1,
|
0.5, 0.5, 0.5,
|
||||||
1, 1, 1]
|
0.5, 0.5, 0.5]
|
||||||
fixed_kp: [60, 60, 60,
|
fixed_kp: [60, 60, 60,
|
||||||
60, 60, 60,
|
60, 60, 60,
|
||||||
60, 60, 60,
|
60, 60, 60,
|
||||||
|
@ -35,11 +35,14 @@ hip_scale_reduction: 1.0
|
||||||
hip_scale_reduction_indices: [0, 3, 6, 9]
|
hip_scale_reduction_indices: [0, 3, 6, 9]
|
||||||
num_of_dofs: 12
|
num_of_dofs: 12
|
||||||
action_scale: 0.25
|
action_scale: 0.25
|
||||||
|
|
||||||
lin_vel_scale: 2.0
|
lin_vel_scale: 2.0
|
||||||
ang_vel_scale: 0.25
|
ang_vel_scale: 0.25
|
||||||
dof_pos_scale: 1.0
|
dof_pos_scale: 1.0
|
||||||
dof_vel_scale: 0.05
|
dof_vel_scale: 0.05
|
||||||
|
|
||||||
commands_scale: [2.0, 2.0, 0.25]
|
commands_scale: [2.0, 2.0, 0.25]
|
||||||
|
|
||||||
torque_limits: [33.5, 33.5, 33.5,
|
torque_limits: [33.5, 33.5, 33.5,
|
||||||
33.5, 33.5, 33.5,
|
33.5, 33.5, 33.5,
|
||||||
33.5, 33.5, 33.5,
|
33.5, 33.5, 33.5,
|
||||||
|
|
|
@ -161,7 +161,14 @@ rl_quadruped_controller:
|
||||||
- position
|
- position
|
||||||
- velocity
|
- velocity
|
||||||
|
|
||||||
|
feet_names:
|
||||||
|
- FL_foot
|
||||||
|
- FR_foot
|
||||||
|
- RL_foot
|
||||||
|
- RR_foot
|
||||||
|
|
||||||
imu_name: "imu_sensor"
|
imu_name: "imu_sensor"
|
||||||
|
base_name: "base"
|
||||||
|
|
||||||
imu_interfaces:
|
imu_interfaces:
|
||||||
- orientation.w
|
- orientation.w
|
||||||
|
|
Loading…
Reference in New Issue