From 4cb1657df5d967afcc4ca5f0d59c6a8f334b8990 Mon Sep 17 00:00:00 2001 From: Huang Zhenbiao Date: Mon, 30 Sep 2024 11:59:35 +0800 Subject: [PATCH] achieved basic control on go1 --- .../ocs2_quadruped_controller/CMakeLists.txt | 1 + .../control/CtrlComponent.h | 4 ++ .../control/TargetManager.h | 57 ++++++++++++++++ .../src/Ocs2QuadrupedController.cpp | 67 +++++++++---------- .../src/Ocs2QuadrupedController.h | 3 - .../src/SafetyChecker.h | 11 ++- .../src/control/TargetManager.cpp | 56 ++++++++++++++++ 7 files changed, 156 insertions(+), 43 deletions(-) create mode 100644 controllers/ocs2_quadruped_controller/include/ocs2_quadruped_controller/control/TargetManager.h create mode 100644 controllers/ocs2_quadruped_controller/src/control/TargetManager.cpp diff --git a/controllers/ocs2_quadruped_controller/CMakeLists.txt b/controllers/ocs2_quadruped_controller/CMakeLists.txt index d678dfa..49e0f55 100644 --- a/controllers/ocs2_quadruped_controller/CMakeLists.txt +++ b/controllers/ocs2_quadruped_controller/CMakeLists.txt @@ -52,6 +52,7 @@ add_library(${PROJECT_NAME} SHARED src/interface/LeggedInterface.cpp src/control/GaitManager.cpp + src/control/TargetManager.cpp ) target_include_directories(${PROJECT_NAME} diff --git a/controllers/ocs2_quadruped_controller/include/ocs2_quadruped_controller/control/CtrlComponent.h b/controllers/ocs2_quadruped_controller/include/ocs2_quadruped_controller/control/CtrlComponent.h index e1c5fc7..308876c 100644 --- a/controllers/ocs2_quadruped_controller/include/ocs2_quadruped_controller/control/CtrlComponent.h +++ b/controllers/ocs2_quadruped_controller/include/ocs2_quadruped_controller/control/CtrlComponent.h @@ -9,7 +9,9 @@ #include #include #include +#include +#include "TargetManager.h" #include "ocs2_quadruped_controller/estimator/StateEstimateBase.h" struct CtrlComponent { @@ -39,9 +41,11 @@ struct CtrlComponent { foot_force_state_interface_; control_input_msgs::msg::Inputs control_inputs_; + ocs2::SystemObservation observation_; int frequency_{}; std::shared_ptr estimator_; + std::shared_ptr target_manager_; CtrlComponent() { } diff --git a/controllers/ocs2_quadruped_controller/include/ocs2_quadruped_controller/control/TargetManager.h b/controllers/ocs2_quadruped_controller/include/ocs2_quadruped_controller/control/TargetManager.h new file mode 100644 index 0000000..9de9ebb --- /dev/null +++ b/controllers/ocs2_quadruped_controller/include/ocs2_quadruped_controller/control/TargetManager.h @@ -0,0 +1,57 @@ +// +// Created by tlab-uav on 24-9-30. +// + +#ifndef TARGETMANAGER_H +#define TARGETMANAGER_H + + +#include +#include +#include + +struct CtrlComponent; + +namespace ocs2::legged_robot { + class TargetManager { + public: + TargetManager(CtrlComponent &ctrl_component, + const std::shared_ptr &referenceManagerPtr, + const std::string& task_file, + const std::string& reference_file); + + ~TargetManager() = default; + + void update(); + + private: + TargetTrajectories targetPoseToTargetTrajectories(const vector_t &targetPose, + const SystemObservation &observation, + const scalar_t &targetReachingTime) { + // desired time trajectory + const scalar_array_t timeTrajectory{observation.time, targetReachingTime}; + + // desired state trajectory + vector_t currentPose = observation.state.segment<6>(6); + currentPose(2) = command_height_; + currentPose(4) = 0; + currentPose(5) = 0; + vector_array_t stateTrajectory(2, vector_t::Zero(observation.state.size())); + stateTrajectory[0] << vector_t::Zero(6), currentPose, default_joint_state_; + stateTrajectory[1] << vector_t::Zero(6), targetPose, default_joint_state_; + + // desired input trajectory (just right dimensions, they are not used) + const vector_array_t inputTrajectory(2, vector_t::Zero(observation.input.size())); + + return {timeTrajectory, stateTrajectory, inputTrajectory}; + } + CtrlComponent &ctrl_component_; + std::shared_ptr referenceManagerPtr_; + + vector_t default_joint_state_{}; + scalar_t command_height_{}; + scalar_t time_to_target_{}; + }; +} + +#endif //TARGETMANAGER_H diff --git a/controllers/ocs2_quadruped_controller/src/Ocs2QuadrupedController.cpp b/controllers/ocs2_quadruped_controller/src/Ocs2QuadrupedController.cpp index f7a9fff..6698413 100644 --- a/controllers/ocs2_quadruped_controller/src/Ocs2QuadrupedController.cpp +++ b/controllers/ocs2_quadruped_controller/src/Ocs2QuadrupedController.cpp @@ -58,8 +58,11 @@ namespace ocs2::legged_robot { // State Estimate updateStateEstimation(time, period); + // Compute target trajectory + ctrl_comp_.target_manager_->update(); + // Update the current state of the system - mpc_mrt_interface_->setCurrentObservation(current_observation_); + mpc_mrt_interface_->setCurrentObservation(ctrl_comp_.observation_); // Load the latest MPC policy mpc_mrt_interface_->updatePolicy(); @@ -67,24 +70,26 @@ namespace ocs2::legged_robot { // Evaluate the current policy vector_t optimized_state, optimized_input; size_t planned_mode = 0; // The mode that is active at the time the policy is evaluated at. - mpc_mrt_interface_->evaluatePolicy(current_observation_.time, current_observation_.state, optimized_state, + mpc_mrt_interface_->evaluatePolicy(ctrl_comp_.observation_.time, ctrl_comp_.observation_.state, optimized_state, optimized_input, planned_mode); // Whole body control - current_observation_.input = optimized_input; + ctrl_comp_.observation_.input = optimized_input; wbc_timer_.startTimer(); - vector_t x = wbc_->update(optimized_state, optimized_input, measured_rbd_state_, planned_mode, period.seconds()); + vector_t x = wbc_->update(optimized_state, optimized_input, measured_rbd_state_, planned_mode, + period.seconds()); wbc_timer_.endTimer(); vector_t torque = x.tail(12); - vector_t pos_des = centroidal_model::getJointAngles(optimized_state, legged_interface_->getCentroidalModelInfo()); + vector_t pos_des = centroidal_model::getJointAngles(optimized_state, + legged_interface_->getCentroidalModelInfo()); vector_t vel_des = centroidal_model::getJointVelocities(optimized_input, - legged_interface_->getCentroidalModelInfo()); + legged_interface_->getCentroidalModelInfo()); // Safety check, if failed, stop the controller - if (!safety_checker_->check(current_observation_, optimized_state, optimized_input)) { + if (!safety_checker_->check(ctrl_comp_.observation_, optimized_state, optimized_input)) { RCLCPP_ERROR(get_node()->get_logger(), "[Legged Controller] Safety check failed, stopping the controller."); for (int i = 0; i < joint_names_.size(); i++) { ctrl_comp_.joint_torque_command_interface_[i].get().set_value(0); @@ -101,10 +106,10 @@ namespace ocs2::legged_robot { ctrl_comp_.joint_position_command_interface_[i].get().set_value(pos_des(i)); ctrl_comp_.joint_velocity_command_interface_[i].get().set_value(vel_des(i)); ctrl_comp_.joint_kp_command_interface_[i].get().set_value(0.0); - ctrl_comp_.joint_kd_command_interface_[i].get().set_value(1.0); + ctrl_comp_.joint_kd_command_interface_[i].get().set_value(6.0); } - observation_publisher_->publish(ros_msg_conversions::createObservationMsg(current_observation_)); + observation_publisher_->publish(ros_msg_conversions::createObservationMsg(ctrl_comp_.observation_)); return controller_interface::return_type::OK; } @@ -190,7 +195,6 @@ namespace ocs2::legged_robot { // assign command interfaces for (auto &interface: command_interfaces_) { std::string interface_name = interface.get_interface_name(); - std::cout << "interface_name: " << interface.get_prefix_name() << std::endl; if (const size_t pos = interface_name.find('/'); pos != std::string::npos) { command_interface_map_[interface_name.substr(pos + 1)]->push_back(interface); } else { @@ -211,16 +215,16 @@ namespace ocs2::legged_robot { if (mpc_running_ == false) { // Initial state - current_observation_.state.setZero(static_cast(legged_interface_->getCentroidalModelInfo().stateDim)); + ctrl_comp_.observation_.state.setZero(static_cast(legged_interface_->getCentroidalModelInfo().stateDim)); updateStateEstimation(get_node()->now(), rclcpp::Duration(0, 200000)); - current_observation_.input.setZero(static_cast(legged_interface_->getCentroidalModelInfo().inputDim)); - current_observation_.mode = STANCE; + ctrl_comp_.observation_.input.setZero(static_cast(legged_interface_->getCentroidalModelInfo().inputDim)); + ctrl_comp_.observation_.mode = STANCE; - const TargetTrajectories target_trajectories({current_observation_.time}, {current_observation_.state}, - {current_observation_.input}); + const TargetTrajectories target_trajectories({ctrl_comp_.observation_.time}, {ctrl_comp_.observation_.state}, + {ctrl_comp_.observation_.input}); // Set the first observation and command and wait for optimization to finish - mpc_mrt_interface_->setCurrentObservation(current_observation_); + mpc_mrt_interface_->setCurrentObservation(ctrl_comp_.observation_); mpc_mrt_interface_->getReferenceManager().setTargetTrajectories(target_trajectories); RCLCPP_INFO(get_node()->get_logger(), "Waiting for the initial policy ..."); while (!mpc_mrt_interface_->initialPolicyReceived()) { @@ -268,22 +272,17 @@ namespace ocs2::legged_robot { rbd_conversions_ = std::make_shared(legged_interface_->getPinocchioInterface(), legged_interface_->getCentroidalModelInfo()); - const std::string robotName = "legged_robot"; - - // Todo Handle Gait Receive. - // Gait receiver + // Initialize the reference manager const auto gait_manager_ptr = std::make_shared( ctrl_comp_, legged_interface_->getSwitchedModelReferenceManagerPtr()-> getGaitSchedule()); gait_manager_ptr->init(gait_file_); - - // Todo Here maybe the reason of the nullPointer. - // ROS ReferenceManager - const auto rosReferenceManagerPtr = std::make_shared( - robotName, legged_interface_->getReferenceManagerPtr()); - rosReferenceManagerPtr->subscribe(get_node()); mpc_->getSolverPtr()->addSynchronizedModule(gait_manager_ptr); - mpc_->getSolverPtr()->setReferenceManager(rosReferenceManagerPtr); + mpc_->getSolverPtr()->setReferenceManager(legged_interface_->getReferenceManagerPtr()); + + ctrl_comp_.target_manager_ = std::make_shared(ctrl_comp_, + legged_interface_->getReferenceManagerPtr(), + task_file_, reference_file_); } void Ocs2QuadrupedController::setupMrt() { @@ -319,17 +318,17 @@ namespace ocs2::legged_robot { legged_interface_->getCentroidalModelInfo(), *eeKinematicsPtr_, ctrl_comp_, this->get_node()); dynamic_cast(*ctrl_comp_.estimator_).loadSettings(task_file_, verbose_); - current_observation_.time = 0; + ctrl_comp_.observation_.time = 0; } void Ocs2QuadrupedController::updateStateEstimation(const rclcpp::Time &time, const rclcpp::Duration &period) { measured_rbd_state_ = ctrl_comp_.estimator_->update(time, period); - current_observation_.time += period.seconds(); - const scalar_t yaw_last = current_observation_.state(9); - current_observation_.state = rbd_conversions_->computeCentroidalStateFromRbdModel(measured_rbd_state_); - current_observation_.state(9) = yaw_last + angles::shortest_angular_distance( - yaw_last, current_observation_.state(9)); - current_observation_.mode = ctrl_comp_.estimator_->getMode(); + ctrl_comp_.observation_.time += period.seconds(); + const scalar_t yaw_last = ctrl_comp_.observation_.state(9); + ctrl_comp_.observation_.state = rbd_conversions_->computeCentroidalStateFromRbdModel(measured_rbd_state_); + ctrl_comp_.observation_.state(9) = yaw_last + angles::shortest_angular_distance( + yaw_last, ctrl_comp_.observation_.state(9)); + ctrl_comp_.observation_.mode = ctrl_comp_.estimator_->getMode(); } } diff --git a/controllers/ocs2_quadruped_controller/src/Ocs2QuadrupedController.h b/controllers/ocs2_quadruped_controller/src/Ocs2QuadrupedController.h index 87d9a10..95c8d9f 100644 --- a/controllers/ocs2_quadruped_controller/src/Ocs2QuadrupedController.h +++ b/controllers/ocs2_quadruped_controller/src/Ocs2QuadrupedController.h @@ -103,9 +103,6 @@ namespace ocs2::legged_robot { rclcpp::Subscription::SharedPtr control_input_subscription_; rclcpp::Publisher::SharedPtr observation_publisher_; - - SystemObservation current_observation_; - std::string task_file_; std::string urdf_file_; std::string reference_file_; diff --git a/controllers/ocs2_quadruped_controller/src/SafetyChecker.h b/controllers/ocs2_quadruped_controller/src/SafetyChecker.h index 6b4e1cd..4233fb7 100644 --- a/controllers/ocs2_quadruped_controller/src/SafetyChecker.h +++ b/controllers/ocs2_quadruped_controller/src/SafetyChecker.h @@ -5,7 +5,6 @@ #pragma once #include -#include #include namespace ocs2::legged_robot { @@ -14,15 +13,15 @@ namespace ocs2::legged_robot { explicit SafetyChecker(const CentroidalModelInfo &info) : info_(info) { } - bool check(const SystemObservation &observation, const vector_t & /*optimized_state*/, - const vector_t & /*optimized_input*/) { + [[nodiscard]] bool check(const SystemObservation &observation, const vector_t & /*optimized_state*/, + const vector_t & /*optimized_input*/) const { return checkOrientation(observation); } protected: - bool checkOrientation(const SystemObservation &observation) { - vector_t pose = centroidal_model::getBasePose(observation.state, info_); - if (pose(5) > M_PI_2 || pose(5) < -M_PI_2) { + [[nodiscard]] bool checkOrientation(const SystemObservation &observation) const { + if (vector_t pose = centroidal_model::getBasePose(observation.state, info_); + pose(5) > M_PI_2 || pose(5) < -M_PI_2) { std::cerr << "[SafetyChecker] Orientation safety check failed!" << std::endl; return false; } diff --git a/controllers/ocs2_quadruped_controller/src/control/TargetManager.cpp b/controllers/ocs2_quadruped_controller/src/control/TargetManager.cpp new file mode 100644 index 0000000..bdddf12 --- /dev/null +++ b/controllers/ocs2_quadruped_controller/src/control/TargetManager.cpp @@ -0,0 +1,56 @@ +// +// Created by tlab-uav on 24-9-30. +// + +#include "ocs2_quadruped_controller/control/TargetManager.h" + +#include +#include + +#include "ocs2_quadruped_controller/control/CtrlComponent.h" + +namespace ocs2::legged_robot { + TargetManager::TargetManager(CtrlComponent &ctrl_component, + const std::shared_ptr &referenceManagerPtr, + const std::string &task_file, + const std::string &reference_file) + : ctrl_component_(ctrl_component), + referenceManagerPtr_(referenceManagerPtr) { + default_joint_state_ = vector_t::Zero(12); + loadData::loadCppDataType(reference_file, "comHeight", command_height_); + loadData::loadEigenMatrix(reference_file, "defaultJointState", default_joint_state_); + loadData::loadCppDataType(task_file, "mpc.timeHorizon", time_to_target_); + } + + void TargetManager::update() { + vector_t cmdGoal = vector_t::Zero(6); + cmdGoal[0] = ctrl_component_.control_inputs_.ly; + cmdGoal[1] = -ctrl_component_.control_inputs_.lx; + cmdGoal[2] = ctrl_component_.control_inputs_.ry; + cmdGoal[3] = -ctrl_component_.control_inputs_.rx; + + const vector_t currentPose = ctrl_component_.observation_.state.segment<6>(6); + const Eigen::Matrix zyx = currentPose.tail(3); + vector_t cmdVelRot = getRotationMatrixFromZyxEulerAngles(zyx) * cmdGoal.head(3); + + const scalar_t timeToTarget = time_to_target_; + const vector_t targetPose = [&]() { + vector_t target(6); + target(0) = currentPose(0) + cmdVelRot(0) * timeToTarget; + target(1) = currentPose(1) + cmdVelRot(1) * timeToTarget; + target(2) = command_height_; + target(3) = currentPose(3) + cmdGoal(3) * timeToTarget; + target(4) = 0; + target(5) = 0; + return target; + }(); + + const scalar_t targetReachingTime = ctrl_component_.observation_.time + timeToTarget; + auto trajectories = + targetPoseToTargetTrajectories(targetPose, ctrl_component_.observation_, targetReachingTime); + trajectories.stateTrajectory[0].head(3) = cmdVelRot; + trajectories.stateTrajectory[1].head(3) = cmdVelRot; + + referenceManagerPtr_->setTargetTrajectories(std::move(trajectories)); + } +}