current problem is from wbc

This commit is contained in:
Huang Zhenbiao 2024-09-27 17:40:22 +08:00
parent d4f17e631e
commit ee4118b4e5
9 changed files with 57 additions and 55 deletions

View File

@ -16,8 +16,8 @@
namespace ocs2::legged_robot { namespace ocs2::legged_robot {
class KalmanFilterEstimate final : public StateEstimateBase { class KalmanFilterEstimate final : public StateEstimateBase {
public: public:
KalmanFilterEstimate(PinocchioInterface pinocchioInterface, CentroidalModelInfo info, KalmanFilterEstimate(PinocchioInterface pinocchio_interface, CentroidalModelInfo info,
const PinocchioEndEffectorKinematics &eeKinematics, const PinocchioEndEffectorKinematics &ee_kinematics,
CtrlComponent &ctrl_component, CtrlComponent &ctrl_component,
const rclcpp_lifecycle::LifecycleNode::SharedPtr &node); const rclcpp_lifecycle::LifecycleNode::SharedPtr &node);

View File

@ -23,8 +23,8 @@ namespace ocs2::legged_robot {
public: public:
virtual ~StateEstimateBase() = default; virtual ~StateEstimateBase() = default;
StateEstimateBase(PinocchioInterface pinocchioInterface, CentroidalModelInfo info, StateEstimateBase(PinocchioInterface pinocchio_interface, CentroidalModelInfo info,
const PinocchioEndEffectorKinematics &eeKinematics, const PinocchioEndEffectorKinematics &ee_kinematics,
CtrlComponent &ctrl_component, CtrlComponent &ctrl_component,
rclcpp_lifecycle::LifecycleNode::SharedPtr node); rclcpp_lifecycle::LifecycleNode::SharedPtr node);
@ -47,12 +47,12 @@ namespace ocs2::legged_robot {
CtrlComponent &ctrl_component_; CtrlComponent &ctrl_component_;
PinocchioInterface pinocchioInterface_; PinocchioInterface pinocchio_interface_;
CentroidalModelInfo info_; CentroidalModelInfo info_;
std::unique_ptr<PinocchioEndEffectorKinematics> eeKinematics_; std::unique_ptr<PinocchioEndEffectorKinematics> ee_kinematics_;
vector3_t zyxOffset_ = vector3_t::Zero(); vector3_t zyx_offset_ = vector3_t::Zero();
vector_t rbdState_; vector_t rbd_state_;
contact_flag_t contact_flag_{}; contact_flag_t contact_flag_{};
Eigen::Quaternion<scalar_t> quat_; Eigen::Quaternion<scalar_t> quat_;
vector3_t angular_vel_local_, linear_accel_local_; vector3_t angular_vel_local_, linear_accel_local_;

View File

@ -6,7 +6,7 @@
#include "WbcBase.h" #include "WbcBase.h"
namespace ocs2::legged_robot { namespace ocs2::legged_robot {
class HierarchicalWbc : public WbcBase { class HierarchicalWbc final : public WbcBase {
public: public:
using WbcBase::WbcBase; using WbcBase::WbcBase;

View File

@ -12,7 +12,6 @@
#include <utility> #include <utility>
namespace ocs2::legged_robot { namespace ocs2::legged_robot {
using namespace ocs2;
class Task { class Task {
public: public:

View File

@ -11,6 +11,7 @@
#include <ocs2_quadruped_controller/estimator/LinearKalmanFilter.h> #include <ocs2_quadruped_controller/estimator/LinearKalmanFilter.h>
#include <ocs2_quadruped_controller/wbc/WeightedWbc.h> #include <ocs2_quadruped_controller/wbc/WeightedWbc.h>
#include <ocs2_ros_interfaces/synchronized_module/RosReferenceManager.h> #include <ocs2_ros_interfaces/synchronized_module/RosReferenceManager.h>
#include <ocs2_ros_interfaces/common/RosMsgConversions.h>
#include <ocs2_sqp/SqpMpc.h> #include <ocs2_sqp/SqpMpc.h>
#include <angles/angles.h> #include <angles/angles.h>
#include <ocs2_quadruped_controller/control/GaitManager.h> #include <ocs2_quadruped_controller/control/GaitManager.h>
@ -64,42 +65,43 @@ namespace ocs2::legged_robot {
mpc_mrt_interface_->updatePolicy(); mpc_mrt_interface_->updatePolicy();
// Evaluate the current policy // Evaluate the current policy
vector_t optimizedState, optimizedInput; vector_t optimized_state, optimized_input;
size_t plannedMode = 0; // The mode that is active at the time the policy is evaluated at. size_t planned_mode = 0; // The mode that is active at the time the policy is evaluated at.
mpc_mrt_interface_->evaluatePolicy(current_observation_.time, current_observation_.state, optimizedState, mpc_mrt_interface_->evaluatePolicy(current_observation_.time, current_observation_.state, optimized_state,
optimizedInput, plannedMode); optimized_input, planned_mode);
// Whole body control // Whole body control
current_observation_.input = optimizedInput; current_observation_.input = optimized_input;
wbc_timer_.startTimer(); wbc_timer_.startTimer();
vector_t x = wbc_->update(optimizedState, optimizedInput, measuredRbdState_, plannedMode, period.seconds()); vector_t x = wbc_->update(optimized_state, optimized_input, measuredRbdState_, planned_mode, period.seconds());
wbc_timer_.endTimer(); wbc_timer_.endTimer();
vector_t torque = x.tail(12); vector_t torque = x.tail(12);
vector_t posDes = centroidal_model::getJointAngles(optimizedState, legged_interface_->getCentroidalModelInfo()); vector_t pos_des = centroidal_model::getJointAngles(optimized_state, legged_interface_->getCentroidalModelInfo());
vector_t velDes = centroidal_model::getJointVelocities(optimizedInput, vector_t vel_des = centroidal_model::getJointVelocities(optimized_input,
legged_interface_->getCentroidalModelInfo()); legged_interface_->getCentroidalModelInfo());
// Safety check, if failed, stop the controller // Safety check, if failed, stop the controller
if (!safety_checker_->check(current_observation_, optimizedState, optimizedInput)) { if (!safety_checker_->check(current_observation_, optimized_state, optimized_input)) {
RCLCPP_ERROR(get_node()->get_logger(), "[Legged Controller] Safety check failed, stopping the controller."); RCLCPP_ERROR(get_node()->get_logger(), "[Legged Controller] Safety check failed, stopping the controller.");
} }
for (int i = 0; i < joint_names_.size(); i++) { for (int i = 0; i < joint_names_.size(); i++) {
ctrl_comp_.joint_torque_command_interface_[i].get().set_value(torque(i)); ctrl_comp_.joint_torque_command_interface_[i].get().set_value(torque(i));
ctrl_comp_.joint_position_command_interface_[i].get().set_value(posDes(i)); ctrl_comp_.joint_position_command_interface_[i].get().set_value(pos_des(i));
ctrl_comp_.joint_velocity_command_interface_[i].get().set_value(velDes(i)); ctrl_comp_.joint_velocity_command_interface_[i].get().set_value(vel_des(i));
ctrl_comp_.joint_kp_command_interface_[i].get().set_value(0.0); ctrl_comp_.joint_kp_command_interface_[i].get().set_value(0.0);
ctrl_comp_.joint_kd_command_interface_[i].get().set_value(3.0); ctrl_comp_.joint_kd_command_interface_[i].get().set_value(1.0);
} }
observation_publisher_->publish(ros_msg_conversions::createObservationMsg(current_observation_));
return controller_interface::return_type::OK; return controller_interface::return_type::OK;
} }
controller_interface::CallbackReturn Ocs2QuadrupedController::on_init() { controller_interface::CallbackReturn Ocs2QuadrupedController::on_init() {
// Initialize OCS2 // Initialize OCS2
urdf_file_ = auto_declare<std::string>("urdf_file", urdf_file_); urdf_file_ = auto_declare<std::string>("urdf_file", urdf_file_);
task_file_ = auto_declare<std::string>("task_file", task_file_); task_file_ = auto_declare<std::string>("task_file", task_file_);
@ -162,6 +164,9 @@ namespace ocs2::legged_robot {
ctrl_comp_.control_inputs_.ry = msg->ry; ctrl_comp_.control_inputs_.ry = msg->ry;
}); });
observation_publisher_ = get_node()->create_publisher<ocs2_msgs::msg::MpcObservation>(
"legged_robot_mpc_observation", 10);
get_node()->get_parameter("update_rate", ctrl_comp_.frequency_); get_node()->get_parameter("update_rate", ctrl_comp_.frequency_);
RCLCPP_INFO(get_node()->get_logger(), "Controller Manager Update Rate: %d Hz", ctrl_comp_.frequency_); RCLCPP_INFO(get_node()->get_logger(), "Controller Manager Update Rate: %d Hz", ctrl_comp_.frequency_);
@ -209,9 +214,7 @@ namespace ocs2::legged_robot {
mpc_mrt_interface_->getReferenceManager().setTargetTrajectories(target_trajectories); mpc_mrt_interface_->getReferenceManager().setTargetTrajectories(target_trajectories);
RCLCPP_INFO(get_node()->get_logger(), "Waiting for the initial policy ..."); RCLCPP_INFO(get_node()->get_logger(), "Waiting for the initial policy ...");
while (!mpc_mrt_interface_->initialPolicyReceived()) { while (!mpc_mrt_interface_->initialPolicyReceived()) {
std::cout<<"Waiting for the initial policy ..."<<std::endl;
mpc_mrt_interface_->advanceMpc(); mpc_mrt_interface_->advanceMpc();
std::cout<<"Advance MPC"<<std::endl;
rclcpp::WallRate(legged_interface_->mpcSettings().mrtDesiredFrequency_).sleep(); rclcpp::WallRate(legged_interface_->mpcSettings().mrtDesiredFrequency_).sleep();
} }
RCLCPP_INFO(get_node()->get_logger(), "Initial policy has been received."); RCLCPP_INFO(get_node()->get_logger(), "Initial policy has been received.");
@ -253,7 +256,7 @@ namespace ocs2::legged_robot {
legged_interface_->getOptimalControlProblem(), legged_interface_->getOptimalControlProblem(),
legged_interface_->getInitializer()); legged_interface_->getInitializer());
rbd_conversions_ = std::make_shared<CentroidalModelRbdConversions>(legged_interface_->getPinocchioInterface(), rbd_conversions_ = std::make_shared<CentroidalModelRbdConversions>(legged_interface_->getPinocchioInterface(),
legged_interface_->getCentroidalModelInfo()); legged_interface_->getCentroidalModelInfo());
const std::string robotName = "legged_robot"; const std::string robotName = "legged_robot";
@ -268,10 +271,9 @@ namespace ocs2::legged_robot {
// ROS ReferenceManager // ROS ReferenceManager
const auto rosReferenceManagerPtr = std::make_shared<RosReferenceManager>( const auto rosReferenceManagerPtr = std::make_shared<RosReferenceManager>(
robotName, legged_interface_->getReferenceManagerPtr()); robotName, legged_interface_->getReferenceManagerPtr());
// rosReferenceManagerPtr->subscribe(this->get_node()); rosReferenceManagerPtr->subscribe(get_node());
mpc_->getSolverPtr()->addSynchronizedModule(gait_manager_ptr); mpc_->getSolverPtr()->addSynchronizedModule(gait_manager_ptr);
mpc_->getSolverPtr()->setReferenceManager(rosReferenceManagerPtr); mpc_->getSolverPtr()->setReferenceManager(rosReferenceManagerPtr);
// observationPublisher_ = nh.advertise<ocs2_msgs::msg::mpc_observation>(robotName + "_mpc_observation", 1);
} }
void Ocs2QuadrupedController::setupMrt() { void Ocs2QuadrupedController::setupMrt() {
@ -316,7 +318,7 @@ namespace ocs2::legged_robot {
const scalar_t yaw_last = current_observation_.state(9); const scalar_t yaw_last = current_observation_.state(9);
current_observation_.state = rbd_conversions_->computeCentroidalStateFromRbdModel(measuredRbdState_); current_observation_.state = rbd_conversions_->computeCentroidalStateFromRbdModel(measuredRbdState_);
current_observation_.state(9) = yaw_last + angles::shortest_angular_distance( current_observation_.state(9) = yaw_last + angles::shortest_angular_distance(
yaw_last, current_observation_.state(9)); yaw_last, current_observation_.state(9));
current_observation_.mode = ctrl_comp_.estimator_->getMode(); current_observation_.mode = ctrl_comp_.estimator_->getMode();
} }
} }

View File

@ -12,6 +12,7 @@
#include <ocs2_quadruped_controller/estimator/StateEstimateBase.h> #include <ocs2_quadruped_controller/estimator/StateEstimateBase.h>
#include <ocs2_quadruped_controller/interface/LeggedInterface.h> #include <ocs2_quadruped_controller/interface/LeggedInterface.h>
#include <ocs2_quadruped_controller/wbc/WbcBase.h> #include <ocs2_quadruped_controller/wbc/WbcBase.h>
#include <ocs2_msgs/msg/mpc_observation.hpp>
#include "SafetyChecker.h" #include "SafetyChecker.h"
#include "ocs2_quadruped_controller/control/CtrlComponent.h" #include "ocs2_quadruped_controller/control/CtrlComponent.h"
@ -100,6 +101,7 @@ namespace ocs2::legged_robot {
std::vector<std::string> foot_force_interface_types_; std::vector<std::string> foot_force_interface_types_;
rclcpp::Subscription<control_input_msgs::msg::Inputs>::SharedPtr control_input_subscription_; rclcpp::Subscription<control_input_msgs::msg::Inputs>::SharedPtr control_input_subscription_;
rclcpp::Publisher<ocs2_msgs::msg::MpcObservation>::SharedPtr observation_publisher_;
SystemObservation current_observation_; SystemObservation current_observation_;

View File

@ -41,7 +41,6 @@ namespace ocs2::legged_robot {
} }
void GaitManager::getTargetGait() { void GaitManager::getTargetGait() {
std::cout << "ctrl_component_.control_inputs_.command: " << ctrl_component_.control_inputs_.command << std::endl;
if (ctrl_component_.control_inputs_.command == 0) return; if (ctrl_component_.control_inputs_.command == 0) return;
target_gait_ = gait_list_[ctrl_component_.control_inputs_.command - 1]; target_gait_ = gait_list_[ctrl_component_.control_inputs_.command - 1];
RCLCPP_INFO(rclcpp::get_logger("GaitManager"), "Switch to gait: %s", RCLCPP_INFO(rclcpp::get_logger("GaitManager"), "Switch to gait: %s",

View File

@ -13,11 +13,11 @@
#include <ocs2_robotic_tools/common/RotationTransforms.h> #include <ocs2_robotic_tools/common/RotationTransforms.h>
namespace ocs2::legged_robot { namespace ocs2::legged_robot {
KalmanFilterEstimate::KalmanFilterEstimate(PinocchioInterface pinocchioInterface, CentroidalModelInfo info, KalmanFilterEstimate::KalmanFilterEstimate(PinocchioInterface pinocchio_interface, CentroidalModelInfo info,
const PinocchioEndEffectorKinematics &eeKinematics, const PinocchioEndEffectorKinematics &ee_kinematics,
CtrlComponent &ctrl_component, CtrlComponent &ctrl_component,
const rclcpp_lifecycle::LifecycleNode::SharedPtr &node) const rclcpp_lifecycle::LifecycleNode::SharedPtr &node)
: StateEstimateBase(std::move(pinocchioInterface), std::move(info), eeKinematics, ctrl_component, : StateEstimateBase(std::move(pinocchio_interface), std::move(info), ee_kinematics, ctrl_component,
node), node),
numContacts_(info_.numThreeDofContacts + info_.numSixDofContacts), numContacts_(info_.numThreeDofContacts + info_.numSixDofContacts),
dimContacts_(3 * numContacts_), dimContacts_(3 * numContacts_),
@ -45,7 +45,7 @@ namespace ocs2::legged_robot {
r_.setIdentity(numObserve_, numObserve_); r_.setIdentity(numObserve_, numObserve_);
feetHeights_.setZero(numContacts_); feetHeights_.setZero(numContacts_);
eeKinematics_->setPinocchioInterface(pinocchioInterface_); ee_kinematics_->setPinocchioInterface(pinocchio_interface_);
} }
vector_t KalmanFilterEstimate::update(const rclcpp::Time &time, const rclcpp::Duration &period) { vector_t KalmanFilterEstimate::update(const rclcpp::Time &time, const rclcpp::Duration &period) {
@ -61,28 +61,28 @@ namespace ocs2::legged_robot {
q_.block(3, 3, 3, 3) = (dt * 9.81f / 20.f) * matrix3_t::Identity(); q_.block(3, 3, 3, 3) = (dt * 9.81f / 20.f) * matrix3_t::Identity();
q_.block(6, 6, dimContacts_, dimContacts_) = dt * matrix_t::Identity(dimContacts_, dimContacts_); q_.block(6, 6, dimContacts_, dimContacts_) = dt * matrix_t::Identity(dimContacts_, dimContacts_);
const auto &model = pinocchioInterface_.getModel(); const auto &model = pinocchio_interface_.getModel();
auto &data = pinocchioInterface_.getData(); auto &data = pinocchio_interface_.getData();
size_t actuatedDofNum = info_.actuatedDofNum; size_t actuatedDofNum = info_.actuatedDofNum;
vector_t qPino(info_.generalizedCoordinatesNum); vector_t qPino(info_.generalizedCoordinatesNum);
vector_t vPino(info_.generalizedCoordinatesNum); vector_t vPino(info_.generalizedCoordinatesNum);
qPino.setZero(); qPino.setZero();
qPino.segment<3>(3) = rbdState_.head<3>(); // Only set orientation, let position in origin. qPino.segment<3>(3) = rbd_state_.head<3>(); // Only set orientation, let position in origin.
qPino.tail(actuatedDofNum) = rbdState_.segment(6, actuatedDofNum); qPino.tail(actuatedDofNum) = rbd_state_.segment(6, actuatedDofNum);
vPino.setZero(); vPino.setZero();
vPino.segment<3>(3) = getEulerAnglesZyxDerivativesFromGlobalAngularVelocity<scalar_t>( vPino.segment<3>(3) = getEulerAnglesZyxDerivativesFromGlobalAngularVelocity<scalar_t>(
qPino.segment<3>(3), qPino.segment<3>(3),
rbdState_.segment<3>(info_.generalizedCoordinatesNum)); rbd_state_.segment<3>(info_.generalizedCoordinatesNum));
// Only set angular velocity, let linear velocity be zero // Only set angular velocity, let linear velocity be zero
vPino.tail(actuatedDofNum) = rbdState_.segment(6 + info_.generalizedCoordinatesNum, actuatedDofNum); vPino.tail(actuatedDofNum) = rbd_state_.segment(6 + info_.generalizedCoordinatesNum, actuatedDofNum);
forwardKinematics(model, data, qPino, vPino); forwardKinematics(model, data, qPino, vPino);
updateFramePlacements(model, data); updateFramePlacements(model, data);
const auto eePos = eeKinematics_->getPosition(vector_t()); const auto eePos = ee_kinematics_->getPosition(vector_t());
const auto eeVel = eeKinematics_->getVelocity(vector_t(), vector_t()); const auto eeVel = ee_kinematics_->getVelocity(vector_t(), vector_t());
matrix_t q = matrix_t::Identity(numState_, numState_); matrix_t q = matrix_t::Identity(numState_, numState_);
q.block(0, 0, 3, 3) = q_.block(0, 0, 3, 3) * imuProcessNoisePosition_; q.block(0, 0, 3, 3) = q_.block(0, 0, 3, 3) * imuProcessNoisePosition_;
@ -154,7 +154,7 @@ namespace ocs2::legged_robot {
odom.child_frame_id = "base"; odom.child_frame_id = "base";
publishMsgs(odom); publishMsgs(odom);
return rbdState_; return rbd_state_;
} }
nav_msgs::msg::Odometry KalmanFilterEstimate::getOdomMsg() { nav_msgs::msg::Odometry KalmanFilterEstimate::getOdomMsg() {

View File

@ -13,15 +13,15 @@
namespace ocs2::legged_robot { namespace ocs2::legged_robot {
StateEstimateBase::StateEstimateBase(PinocchioInterface pinocchioInterface, CentroidalModelInfo info, StateEstimateBase::StateEstimateBase(PinocchioInterface pinocchio_interface, CentroidalModelInfo info,
const PinocchioEndEffectorKinematics &eeKinematics, const PinocchioEndEffectorKinematics &ee_kinematics,
CtrlComponent &ctrl_component, CtrlComponent &ctrl_component,
rclcpp_lifecycle::LifecycleNode::SharedPtr node) rclcpp_lifecycle::LifecycleNode::SharedPtr node)
: ctrl_component_(ctrl_component), : ctrl_component_(ctrl_component),
pinocchioInterface_(std::move(pinocchioInterface)), pinocchio_interface_(std::move(pinocchio_interface)),
info_(std::move(info)), info_(std::move(info)),
eeKinematics_(eeKinematics.clone()), ee_kinematics_(ee_kinematics.clone()),
rbdState_(vector_t::Zero(2 * info_.generalizedCoordinatesNum)), node_(std::move(node)) { rbd_state_(vector_t::Zero(2 * info_.generalizedCoordinatesNum)), node_(std::move(node)) {
odom_pub_ = node_->create_publisher<nav_msgs::msg::Odometry>("odom", 10); odom_pub_ = node_->create_publisher<nav_msgs::msg::Odometry>("odom", 10);
pose_pub_ = node_->create_publisher<geometry_msgs::msg::PoseWithCovarianceStamped>("pose", 10); pose_pub_ = node_->create_publisher<geometry_msgs::msg::PoseWithCovarianceStamped>("pose", 10);
} }
@ -35,8 +35,8 @@ namespace ocs2::legged_robot {
joint_vel(i) = ctrl_component_.joint_velocity_state_interface_[i].get().get_value(); joint_vel(i) = ctrl_component_.joint_velocity_state_interface_[i].get().get_value();
} }
rbdState_.segment(6, info_.actuatedDofNum) = joint_pos; rbd_state_.segment(6, info_.actuatedDofNum) = joint_pos;
rbdState_.segment(6 + info_.generalizedCoordinatesNum, info_.actuatedDofNum) = joint_vel; rbd_state_.segment(6 + info_.generalizedCoordinatesNum, info_.actuatedDofNum) = joint_vel;
} }
void StateEstimateBase::updateContact() { void StateEstimateBase::updateContact() {
@ -70,20 +70,20 @@ namespace ocs2::legged_robot {
// angularVelCovariance_ = angularVelCovariance; // angularVelCovariance_ = angularVelCovariance;
// linearAccelCovariance_ = linearAccelCovariance; // linearAccelCovariance_ = linearAccelCovariance;
const vector3_t zyx = quatToZyx(quat_) - zyxOffset_; const vector3_t zyx = quatToZyx(quat_) - zyx_offset_;
const vector3_t angularVelGlobal = getGlobalAngularVelocityFromEulerAnglesZyxDerivatives<scalar_t>( const vector3_t angularVelGlobal = getGlobalAngularVelocityFromEulerAnglesZyxDerivatives<scalar_t>(
zyx, getEulerAnglesZyxDerivativesFromLocalAngularVelocity<scalar_t>(quatToZyx(quat_), angular_vel_local_)); zyx, getEulerAnglesZyxDerivativesFromLocalAngularVelocity<scalar_t>(quatToZyx(quat_), angular_vel_local_));
updateAngular(zyx, angularVelGlobal); updateAngular(zyx, angularVelGlobal);
} }
void StateEstimateBase::updateAngular(const vector3_t &zyx, const vector_t &angularVel) { void StateEstimateBase::updateAngular(const vector3_t &zyx, const vector_t &angularVel) {
rbdState_.segment<3>(0) = zyx; rbd_state_.segment<3>(0) = zyx;
rbdState_.segment<3>(info_.generalizedCoordinatesNum) = angularVel; rbd_state_.segment<3>(info_.generalizedCoordinatesNum) = angularVel;
} }
void StateEstimateBase::updateLinear(const vector_t &pos, const vector_t &linearVel) { void StateEstimateBase::updateLinear(const vector_t &pos, const vector_t &linearVel) {
rbdState_.segment<3>(3) = pos; rbd_state_.segment<3>(3) = pos;
rbdState_.segment<3>(info_.generalizedCoordinatesNum + 3) = linearVel; rbd_state_.segment<3>(info_.generalizedCoordinatesNum + 3) = linearVel;
} }
void StateEstimateBase::publishMsgs(const nav_msgs::msg::Odometry &odom) const { void StateEstimateBase::publishMsgs(const nav_msgs::msg::Odometry &odom) const {