From f8b3efdbdb287c2df2c6471c594767d1c4342d24 Mon Sep 17 00:00:00 2001 From: Huang Zhenbiao Date: Fri, 4 Oct 2024 20:14:34 +0800 Subject: [PATCH] can run if remove pytorch support. --- .../legged_gym_controller/CMakeLists.txt | 59 ++- controllers/legged_gym_controller/README.md | 14 +- .../legged_gym_controller/FSM/FSMState.h | 40 ++ .../FSM/StateFixedDown.h | 37 ++ .../FSM/StateFixedStand.h | 39 ++ .../legged_gym_controller/FSM/StatePassive.h | 24 + .../legged_gym_controller/common/enumClass.h | 21 + .../control/CtrlComponent.h | 61 +++ .../legged_gym_controller.xml | 9 + .../observation_buffer/observation_buffer.cpp | 45 ++ .../observation_buffer/observation_buffer.hpp | 24 + .../library/rl_sdk/rl_sdk.cpp | 98 ++-- .../library/rl_sdk/rl_sdk.hpp | 4 - .../src/FSM/StateFixedDown.cpp | 50 ++ .../src/FSM/StateFixedStand.cpp | 51 ++ .../src/FSM/StatePassive.cpp | 43 ++ .../src/LeggedGymController.cpp | 182 +++++++ .../src/LeggedGymController.h | 107 ++++ .../legged_gym_controller/src/rl_sdk.cpp | 467 ------------------ descriptions/unitree/a1_description/README.md | 6 +- .../a1_description/config/robot_control.yaml | 55 ++- .../launch/rl_control.launch.py | 112 +++++ 22 files changed, 1005 insertions(+), 543 deletions(-) create mode 100644 controllers/legged_gym_controller/include/legged_gym_controller/FSM/FSMState.h create mode 100644 controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateFixedDown.h create mode 100644 controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateFixedStand.h create mode 100644 controllers/legged_gym_controller/include/legged_gym_controller/FSM/StatePassive.h create mode 100644 controllers/legged_gym_controller/include/legged_gym_controller/common/enumClass.h create mode 100644 controllers/legged_gym_controller/include/legged_gym_controller/control/CtrlComponent.h create mode 100644 controllers/legged_gym_controller/legged_gym_controller.xml create mode 100644 controllers/legged_gym_controller/library/observation_buffer/observation_buffer.cpp create mode 100644 controllers/legged_gym_controller/library/observation_buffer/observation_buffer.hpp create mode 100644 controllers/legged_gym_controller/src/FSM/StateFixedDown.cpp create mode 100644 controllers/legged_gym_controller/src/FSM/StateFixedStand.cpp create mode 100644 controllers/legged_gym_controller/src/FSM/StatePassive.cpp create mode 100644 controllers/legged_gym_controller/src/LeggedGymController.cpp create mode 100644 controllers/legged_gym_controller/src/LeggedGymController.h delete mode 100644 controllers/legged_gym_controller/src/rl_sdk.cpp create mode 100644 descriptions/unitree/a1_description/launch/rl_control.launch.py diff --git a/controllers/legged_gym_controller/CMakeLists.txt b/controllers/legged_gym_controller/CMakeLists.txt index 63c1cba..1794551 100644 --- a/controllers/legged_gym_controller/CMakeLists.txt +++ b/controllers/legged_gym_controller/CMakeLists.txt @@ -1,4 +1,5 @@ -cmake_minimum_required(VERSION 3.8) +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +set(CMAKE_CXX_STANDARD 17) project(legged_gym_controller) if (CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") @@ -10,25 +11,11 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) find_package(ament_cmake REQUIRED) -# rl_sdk library -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") -add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}") - find_package(Torch REQUIRED) -find_package(Python3 COMPONENTS Interpreter Development REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") -add_library(rl_sdk library/rl_sdk/rl_sdk.cpp) -target_include_directories(rl_sdk - PUBLIC - library/rl_sdk) -target_link_libraries(rl_sdk "${TORCH_LIBRARIES}" Python3::Python Python3::Module) -set_property(TARGET rl_sdk PROPERTY CXX_STANDARD 14) +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) find_package(Python3 COMPONENTS NumPy) -if (Python3_NumPy_FOUND) - target_link_libraries(rl_sdk Python3::NumPy) -else () - target_compile_definitions(rl_sdk WITHOUT_NUMPY) -endif () set(dependencies pluginlib @@ -43,6 +30,44 @@ foreach (Dependency IN ITEMS ${dependencies}) find_package(${Dependency} REQUIRED) endforeach () +add_library(${PROJECT_NAME} SHARED + src/LeggedGymController.cpp + + src/FSM/StatePassive.cpp + src/FSM/StateFixedStand.cpp + src/FSM/StateFixedDown.cpp +) +target_include_directories(${PROJECT_NAME} + PUBLIC + "$" + "$" + PRIVATE + src) +target_link_libraries(${PROJECT_NAME} PUBLIC + "${TORCH_LIBRARIES}" +) +ament_target_dependencies( + ${PROJECT_NAME} PUBLIC + ${dependencies} +) + +pluginlib_export_plugin_description_file(controller_interface legged_gym_controller.xml) + +install( + DIRECTORY include/ + DESTINATION include/${PROJECT_NAME} +) + +install( + TARGETS ${PROJECT_NAME} + EXPORT export_${PROJECT_NAME} + ARCHIVE DESTINATION lib/${PROJECT_NAME} + LIBRARY DESTINATION lib/${PROJECT_NAME} + RUNTIME DESTINATION bin +) + +ament_export_dependencies(${dependencies}) +ament_export_targets(export_${PROJECT_NAME} HAS_LIBRARY_TARGET) if (BUILD_TESTING) find_package(ament_lint_auto REQUIRED) diff --git a/controllers/legged_gym_controller/README.md b/controllers/legged_gym_controller/README.md index d3ce62a..6913ede 100644 --- a/controllers/legged_gym_controller/README.md +++ b/controllers/legged_gym_controller/README.md @@ -8,13 +8,12 @@ Tested environment: ## 2. Build - ### 2.1 Installing libtorch ```bash cd ~/CLionProjects/ -wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcpu.zip -unzip libtorch-cxx11-abi-shared-with-deps-2.0.1+cpu.zip -d ./ -rm libtorch-cxx11-abi-shared-with-deps-2.0.1+cpu.zip +wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip +unzip libtorch-shared-with-deps-latest.zip +rm -rf libtorch-shared-with-deps-latest.zip echo 'export Torch_DIR=~/CLionProjects/libtorch' >> ~/.bashrc ``` @@ -23,3 +22,10 @@ echo 'export Torch_DIR=~/CLionProjects/libtorch' >> ~/.bashrc cd ~/ros2_ws colcon build --packages-up-to legged_gym_controller ``` + +## 3. Launch +* Unitree A1 Robot + ```bash + source ~/ros2_ws/install/setup.bash + ros2 launch a1_description rl_control.launch.py + ``` \ No newline at end of file diff --git a/controllers/legged_gym_controller/include/legged_gym_controller/FSM/FSMState.h b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/FSMState.h new file mode 100644 index 0000000..cfed755 --- /dev/null +++ b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/FSMState.h @@ -0,0 +1,40 @@ +// +// Created by tlab-uav on 24-9-6. +// + +#ifndef FSMSTATE_H +#define FSMSTATE_H + +#include +#include +#include + +#include "legged_gym_controller/common/enumClass.h" +#include "legged_gym_controller/control/CtrlComponent.h" + +class FSMState { +public: + virtual ~FSMState() = default; + + FSMState(const FSMStateName &state_name, std::string state_name_string, CtrlComponent &ctrl_comp) + : state_name(state_name), + state_name_string(std::move(state_name_string)), + ctrl_comp_(ctrl_comp) { + } + + virtual void enter() = 0; + + virtual void run() = 0; + + virtual void exit() = 0; + + virtual FSMStateName checkChange() { return FSMStateName::INVALID; } + + FSMStateName state_name; + std::string state_name_string; + +protected: + CtrlComponent &ctrl_comp_; +}; + +#endif //FSMSTATE_H diff --git a/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateFixedDown.h b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateFixedDown.h new file mode 100644 index 0000000..f46b42f --- /dev/null +++ b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateFixedDown.h @@ -0,0 +1,37 @@ +// +// Created by tlab-uav on 24-9-11. +// + +#ifndef STATEFIXEDDOWN_H +#define STATEFIXEDDOWN_H + +#include "FSMState.h" + +class StateFixedDown final : public FSMState { +public: + explicit StateFixedDown(CtrlComponent &ctrlComp); + + void enter() override; + + void run() override; + + void exit() override; + + FSMStateName checkChange() override; +private: + double target_pos_[12] = { + 0.0473455, 1.22187, -2.44375, -0.0473455, + 1.22187, -2.44375, 0.0473455, 1.22187, + -2.44375, -0.0473455, 1.22187, -2.44375 + }; + + double start_pos_[12] = {}; + rclcpp::Time start_time_; + + double duration_ = 600; // steps + double percent_ = 0; //% + double phase = 0.0; +}; + + +#endif //STATEFIXEDDOWN_H diff --git a/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateFixedStand.h b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateFixedStand.h new file mode 100644 index 0000000..473cbd5 --- /dev/null +++ b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateFixedStand.h @@ -0,0 +1,39 @@ +// +// Created by biao on 24-9-10. +// + +#ifndef STATEFIXEDSTAND_H +#define STATEFIXEDSTAND_H + +#include "FSMState.h" + +class StateFixedStand final : public FSMState { +public: + explicit StateFixedStand(CtrlComponent &ctrlComp); + + void enter() override; + + void run() override; + + void exit() override; + + FSMStateName checkChange() override; + +private: + double target_pos_[12] = { + 0.00571868, 0.608813, -1.21763, + -0.00571868, 0.608813, -1.21763, + 0.00571868, 0.608813, -1.21763, + -0.00571868, 0.608813, -1.21763 + }; + + double start_pos_[12] = {}; + rclcpp::Time start_time_; + + double duration_ = 600; // steps + double percent_ = 0; //% + double phase = 0.0; +}; + + +#endif //STATEFIXEDSTAND_H diff --git a/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StatePassive.h b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StatePassive.h new file mode 100644 index 0000000..9b186c4 --- /dev/null +++ b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StatePassive.h @@ -0,0 +1,24 @@ +// +// Created by tlab-uav on 24-9-6. +// + +#ifndef STATEPASSIVE_H +#define STATEPASSIVE_H +#include "FSMState.h" + + +class StatePassive final : public FSMState { +public: + explicit StatePassive(CtrlComponent &ctrlComp); + + void enter() override; + + void run() override; + + void exit() override; + + FSMStateName checkChange() override; +}; + + +#endif //STATEPASSIVE_H diff --git a/controllers/legged_gym_controller/include/legged_gym_controller/common/enumClass.h b/controllers/legged_gym_controller/include/legged_gym_controller/common/enumClass.h new file mode 100644 index 0000000..f942ecc --- /dev/null +++ b/controllers/legged_gym_controller/include/legged_gym_controller/common/enumClass.h @@ -0,0 +1,21 @@ +// +// Created by tlab-uav on 24-9-6. +// + +#ifndef ENUMCLASS_H +#define ENUMCLASS_H + +enum class FSMStateName { + // EXIT, + INVALID, + PASSIVE, + FIXEDDOWN, + FIXEDSTAND, +}; + +enum class FSMMode { + NORMAL, + CHANGE +}; + +#endif //ENUMCLASS_H diff --git a/controllers/legged_gym_controller/include/legged_gym_controller/control/CtrlComponent.h b/controllers/legged_gym_controller/include/legged_gym_controller/control/CtrlComponent.h new file mode 100644 index 0000000..7ca3092 --- /dev/null +++ b/controllers/legged_gym_controller/include/legged_gym_controller/control/CtrlComponent.h @@ -0,0 +1,61 @@ +// +// Created by biao on 24-9-10. +// + +#ifndef INTERFACE_H +#define INTERFACE_H + +#include +#include +#include +#include + +struct CtrlComponent { + std::vector > + joint_torque_command_interface_; + std::vector > + joint_position_command_interface_; + std::vector > + joint_velocity_command_interface_; + std::vector > + joint_kp_command_interface_; + std::vector > + joint_kd_command_interface_; + + + std::vector > + joint_effort_state_interface_; + std::vector > + joint_position_state_interface_; + std::vector > + joint_velocity_state_interface_; + + std::vector > + imu_state_interface_; + + std::vector > + foot_force_state_interface_; + + control_input_msgs::msg::Inputs control_inputs_; + int frequency_{}; + + CtrlComponent() { + } + + void clear() { + joint_torque_command_interface_.clear(); + joint_position_command_interface_.clear(); + joint_velocity_command_interface_.clear(); + joint_kd_command_interface_.clear(); + joint_kp_command_interface_.clear(); + + joint_effort_state_interface_.clear(); + joint_position_state_interface_.clear(); + joint_velocity_state_interface_.clear(); + + imu_state_interface_.clear(); + foot_force_state_interface_.clear(); + } +}; + +#endif //INTERFACE_H diff --git a/controllers/legged_gym_controller/legged_gym_controller.xml b/controllers/legged_gym_controller/legged_gym_controller.xml new file mode 100644 index 0000000..9c6aafb --- /dev/null +++ b/controllers/legged_gym_controller/legged_gym_controller.xml @@ -0,0 +1,9 @@ + + + + Quadruped Controller used RL-model trained in Legged Gym. + + + diff --git a/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.cpp b/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.cpp new file mode 100644 index 0000000..01d2cae --- /dev/null +++ b/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.cpp @@ -0,0 +1,45 @@ +#include "observation_buffer.hpp" + +ObservationBuffer::ObservationBuffer() {} + +ObservationBuffer::ObservationBuffer(int num_envs, + int num_obs, + int include_history_steps) + : num_envs(num_envs), + num_obs(num_obs), + include_history_steps(include_history_steps) +{ + num_obs_total = num_obs * include_history_steps; + obs_buf = torch::zeros({num_envs, num_obs_total}, torch::dtype(torch::kFloat32)); +} + +void ObservationBuffer::reset(std::vector reset_idxs, torch::Tensor new_obs) +{ + std::vector indices; + for (int idx : reset_idxs) { + indices.push_back(torch::indexing::Slice(idx)); + } + obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps})); +} + +void ObservationBuffer::insert(torch::Tensor new_obs) +{ + // Shift observations back. + torch::Tensor shifted_obs = obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(num_obs, num_obs * include_history_steps)}).clone(); + obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(0, num_obs * (include_history_steps - 1))}) = shifted_obs; + + // Add new observation. + obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs, torch::indexing::None)}) = new_obs; +} + +torch::Tensor ObservationBuffer::get_obs_vec(std::vector obs_ids) +{ + std::vector obs; + for (int i = obs_ids.size() - 1; i >= 0; --i) + { + int obs_id = obs_ids[i]; + int slice_idx = include_history_steps - obs_id - 1; + obs.push_back(obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(slice_idx * num_obs, (slice_idx + 1) * num_obs)})); + } + return cat(obs, -1); +} diff --git a/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.hpp b/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.hpp new file mode 100644 index 0000000..72be75c --- /dev/null +++ b/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.hpp @@ -0,0 +1,24 @@ +#ifndef OBSERVATION_BUFFER_HPP +#define OBSERVATION_BUFFER_HPP + +#include +#include + +class ObservationBuffer { +public: + ObservationBuffer(int num_envs, int num_obs, int include_history_steps); + ObservationBuffer(); + + void reset(std::vector reset_idxs, torch::Tensor new_obs); + void insert(torch::Tensor new_obs); + torch::Tensor get_obs_vec(std::vector obs_ids); + +private: + int num_envs; + int num_obs; + int include_history_steps; + int num_obs_total; + torch::Tensor obs_buf; +}; + +#endif // OBSERVATION_BUFFER_HPP \ No newline at end of file diff --git a/controllers/legged_gym_controller/library/rl_sdk/rl_sdk.cpp b/controllers/legged_gym_controller/library/rl_sdk/rl_sdk.cpp index 9529b2b..1a916df 100644 --- a/controllers/legged_gym_controller/library/rl_sdk/rl_sdk.cpp +++ b/controllers/legged_gym_controller/library/rl_sdk/rl_sdk.cpp @@ -374,55 +374,55 @@ std::vector ReadVectorFromYaml(const YAML::Node& node, const std::string& fra void RL::ReadYaml(std::string robot_name) { - // The config file is located at "rl_sar/src/rl_sar/models//config.yaml" - std::string config_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/config.yaml"; - YAML::Node config; - try - { - config = YAML::LoadFile(config_path)[robot_name]; - } catch(YAML::BadFile &e) - { - std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl; - return; - } - - this->params.model_name = config["model_name"].as(); - this->params.framework = config["framework"].as(); - int rows = config["rows"].as(); - int cols = config["cols"].as(); - this->params.use_history = config["use_history"].as(); - this->params.dt = config["dt"].as(); - this->params.decimation = config["decimation"].as(); - this->params.num_observations = config["num_observations"].as(); - this->params.observations = ReadVectorFromYaml(config["observations"]); - this->params.clip_obs = config["clip_obs"].as(); - if(config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull()) - { - this->params.clip_actions_upper = torch::tensor({}).view({1, -1}); - this->params.clip_actions_lower = torch::tensor({}).view({1, -1}); - } - else - { - this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1}); - this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1}); - } - this->params.action_scale = config["action_scale"].as(); - this->params.hip_scale_reduction = config["hip_scale_reduction"].as(); - this->params.hip_scale_reduction_indices = ReadVectorFromYaml(config["hip_scale_reduction_indices"]); - this->params.num_of_dofs = config["num_of_dofs"].as(); - this->params.lin_vel_scale = config["lin_vel_scale"].as(); - this->params.ang_vel_scale = config["ang_vel_scale"].as(); - this->params.dof_pos_scale = config["dof_pos_scale"].as(); - this->params.dof_vel_scale = config["dof_vel_scale"].as(); - // this->params.commands_scale = torch::tensor(ReadVectorFromYaml(config["commands_scale"])).view({1, -1}); - this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); - this->params.rl_kp = torch::tensor(ReadVectorFromYaml(config["rl_kp"], this->params.framework, rows, cols)).view({1, -1}); - this->params.rl_kd = torch::tensor(ReadVectorFromYaml(config["rl_kd"], this->params.framework, rows, cols)).view({1, -1}); - this->params.fixed_kp = torch::tensor(ReadVectorFromYaml(config["fixed_kp"], this->params.framework, rows, cols)).view({1, -1}); - this->params.fixed_kd = torch::tensor(ReadVectorFromYaml(config["fixed_kd"], this->params.framework, rows, cols)).view({1, -1}); - this->params.torque_limits = torch::tensor(ReadVectorFromYaml(config["torque_limits"], this->params.framework, rows, cols)).view({1, -1}); - this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml(config["default_dof_pos"], this->params.framework, rows, cols)).view({1, -1}); - this->params.joint_controller_names = ReadVectorFromYaml(config["joint_controller_names"], this->params.framework, rows, cols); + // // The config file is located at "rl_sar/src/rl_sar/models//config.yaml" + // std::string config_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/config.yaml"; + // YAML::Node config; + // try + // { + // config = YAML::LoadFile(config_path)[robot_name]; + // } catch(YAML::BadFile &e) + // { + // std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl; + // return; + // } + // + // this->params.model_name = config["model_name"].as(); + // this->params.framework = config["framework"].as(); + // int rows = config["rows"].as(); + // int cols = config["cols"].as(); + // this->params.use_history = config["use_history"].as(); + // this->params.dt = config["dt"].as(); + // this->params.decimation = config["decimation"].as(); + // this->params.num_observations = config["num_observations"].as(); + // this->params.observations = ReadVectorFromYaml(config["observations"]); + // this->params.clip_obs = config["clip_obs"].as(); + // if(config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull()) + // { + // this->params.clip_actions_upper = torch::tensor({}).view({1, -1}); + // this->params.clip_actions_lower = torch::tensor({}).view({1, -1}); + // } + // else + // { + // this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1}); + // this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1}); + // } + // this->params.action_scale = config["action_scale"].as(); + // this->params.hip_scale_reduction = config["hip_scale_reduction"].as(); + // this->params.hip_scale_reduction_indices = ReadVectorFromYaml(config["hip_scale_reduction_indices"]); + // this->params.num_of_dofs = config["num_of_dofs"].as(); + // this->params.lin_vel_scale = config["lin_vel_scale"].as(); + // this->params.ang_vel_scale = config["ang_vel_scale"].as(); + // this->params.dof_pos_scale = config["dof_pos_scale"].as(); + // this->params.dof_vel_scale = config["dof_vel_scale"].as(); + // // this->params.commands_scale = torch::tensor(ReadVectorFromYaml(config["commands_scale"])).view({1, -1}); + // this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); + // this->params.rl_kp = torch::tensor(ReadVectorFromYaml(config["rl_kp"], this->params.framework, rows, cols)).view({1, -1}); + // this->params.rl_kd = torch::tensor(ReadVectorFromYaml(config["rl_kd"], this->params.framework, rows, cols)).view({1, -1}); + // this->params.fixed_kp = torch::tensor(ReadVectorFromYaml(config["fixed_kp"], this->params.framework, rows, cols)).view({1, -1}); + // this->params.fixed_kd = torch::tensor(ReadVectorFromYaml(config["fixed_kd"], this->params.framework, rows, cols)).view({1, -1}); + // this->params.torque_limits = torch::tensor(ReadVectorFromYaml(config["torque_limits"], this->params.framework, rows, cols)).view({1, -1}); + // this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml(config["default_dof_pos"], this->params.framework, rows, cols)).view({1, -1}); + // this->params.joint_controller_names = ReadVectorFromYaml(config["joint_controller_names"], this->params.framework, rows, cols); } void RL::CSVInit(std::string robot_name) diff --git a/controllers/legged_gym_controller/library/rl_sdk/rl_sdk.hpp b/controllers/legged_gym_controller/library/rl_sdk/rl_sdk.hpp index 703770b..0194704 100644 --- a/controllers/legged_gym_controller/library/rl_sdk/rl_sdk.hpp +++ b/controllers/legged_gym_controller/library/rl_sdk/rl_sdk.hpp @@ -2,11 +2,7 @@ #define RL_SDK_HPP #include -#include #include -#include - -#include namespace LOGGER { const char* const INFO = "\033[0;37m[INFO]\033[0m "; diff --git a/controllers/legged_gym_controller/src/FSM/StateFixedDown.cpp b/controllers/legged_gym_controller/src/FSM/StateFixedDown.cpp new file mode 100644 index 0000000..8f10606 --- /dev/null +++ b/controllers/legged_gym_controller/src/FSM/StateFixedDown.cpp @@ -0,0 +1,50 @@ +// +// Created by tlab-uav on 24-9-11. +// + +#include "legged_gym_controller/FSM/StateFixedDown.h" + +#include + +StateFixedDown::StateFixedDown(CtrlComponent &ctrlComp): FSMState( + FSMStateName::FIXEDDOWN, "fixed down", ctrlComp) { + duration_ = ctrl_comp_.frequency_ * 1.2; +} + +void StateFixedDown::enter() { + for (int i = 0; i < 12; i++) { + start_pos_[i] = ctrl_comp_.joint_position_state_interface_[i].get().get_value(); + } + ctrl_comp_.control_inputs_.command = 0; +} + +void StateFixedDown::run() { + percent_ += 1 / duration_; + phase = std::tanh(percent_); + for (int i = 0; i < 12; i++) { + ctrl_comp_.joint_position_command_interface_[i].get().set_value( + phase * target_pos_[i] + (1 - phase) * start_pos_[i]); + ctrl_comp_.joint_velocity_command_interface_[i].get().set_value(0); + ctrl_comp_.joint_torque_command_interface_[i].get().set_value(0); + ctrl_comp_.joint_kp_command_interface_[i].get().set_value(30.0); + ctrl_comp_.joint_kd_command_interface_[i].get().set_value(1.5); + } +} + +void StateFixedDown::exit() { + percent_ = 0; +} + +FSMStateName StateFixedDown::checkChange() { + if (percent_ < 1.5) { + return FSMStateName::FIXEDDOWN; + } + switch (ctrl_comp_.control_inputs_.command) { + case 1: + return FSMStateName::PASSIVE; + case 2: + return FSMStateName::FIXEDSTAND; + default: + return FSMStateName::FIXEDDOWN; + } +} diff --git a/controllers/legged_gym_controller/src/FSM/StateFixedStand.cpp b/controllers/legged_gym_controller/src/FSM/StateFixedStand.cpp new file mode 100644 index 0000000..ef8b73e --- /dev/null +++ b/controllers/legged_gym_controller/src/FSM/StateFixedStand.cpp @@ -0,0 +1,51 @@ +// +// Created by biao on 24-9-10. +// + +#include "legged_gym_controller/FSM/StateFixedStand.h" + +#include + +StateFixedStand::StateFixedStand(CtrlComponent &ctrlComp): FSMState( + FSMStateName::FIXEDSTAND, "fixed stand", ctrlComp) { + duration_ = ctrl_comp_.frequency_ * 1.2; +} + +void StateFixedStand::enter() { + for (int i = 0; i < 12; i++) { + start_pos_[i] = ctrl_comp_.joint_position_state_interface_[i].get().get_value(); + } + ctrl_comp_.control_inputs_.command = 0; +} + +void StateFixedStand::run() { + percent_ += 1 / duration_; + phase = std::tanh(percent_); + for (int i = 0; i < 12; i++) { + ctrl_comp_.joint_position_command_interface_[i].get().set_value( + phase * target_pos_[i] + (1 - phase) * start_pos_[i]); + ctrl_comp_.joint_velocity_command_interface_[i].get().set_value(0); + ctrl_comp_.joint_torque_command_interface_[i].get().set_value(0); + ctrl_comp_.joint_kp_command_interface_[i].get().set_value( + phase * 60.0 + (1 - phase) * 20.0); + ctrl_comp_.joint_kd_command_interface_[i].get().set_value(3.5); + } +} + +void StateFixedStand::exit() { + percent_ = 0; +} + +FSMStateName StateFixedStand::checkChange() { + if (percent_ < 1.5) { + return FSMStateName::FIXEDSTAND; + } + switch (ctrl_comp_.control_inputs_.command) { + case 1: + return FSMStateName::PASSIVE; + case 2: + return FSMStateName::FIXEDDOWN; + default: + return FSMStateName::FIXEDSTAND; + } +} diff --git a/controllers/legged_gym_controller/src/FSM/StatePassive.cpp b/controllers/legged_gym_controller/src/FSM/StatePassive.cpp new file mode 100644 index 0000000..fe25be5 --- /dev/null +++ b/controllers/legged_gym_controller/src/FSM/StatePassive.cpp @@ -0,0 +1,43 @@ +// +// Created by tlab-uav on 24-9-6. +// + +#include "legged_gym_controller/FSM/StatePassive.h" + +#include + +StatePassive::StatePassive(CtrlComponent &ctrlComp) : FSMState( + FSMStateName::PASSIVE, "passive", ctrlComp) { +} + +void StatePassive::enter() { + for (auto i: ctrl_comp_.joint_torque_command_interface_) { + i.get().set_value(0); + } + for (auto i: ctrl_comp_.joint_position_command_interface_) { + i.get().set_value(0); + } + for (auto i: ctrl_comp_.joint_velocity_command_interface_) { + i.get().set_value(0); + } + for (auto i: ctrl_comp_.joint_kp_command_interface_) { + i.get().set_value(0); + } + for (auto i: ctrl_comp_.joint_kd_command_interface_) { + i.get().set_value(1); + } + ctrl_comp_.control_inputs_.command = 0; +} + +void StatePassive::run() { +} + +void StatePassive::exit() { +} + +FSMStateName StatePassive::checkChange() { + if (ctrl_comp_.control_inputs_.command == 2) { + return FSMStateName::FIXEDDOWN; + } + return FSMStateName::PASSIVE; +} diff --git a/controllers/legged_gym_controller/src/LeggedGymController.cpp b/controllers/legged_gym_controller/src/LeggedGymController.cpp new file mode 100644 index 0000000..a9cd93f --- /dev/null +++ b/controllers/legged_gym_controller/src/LeggedGymController.cpp @@ -0,0 +1,182 @@ +// +// Created by tlab-uav on 24-10-4. +// + +#include "LeggedGymController.h" + +namespace legged_gym_controller { + using config_type = controller_interface::interface_configuration_type; + + controller_interface::InterfaceConfiguration LeggedGymController::command_interface_configuration() const { + controller_interface::InterfaceConfiguration conf = {config_type::INDIVIDUAL, {}}; + + conf.names.reserve(joint_names_.size() * command_interface_types_.size()); + for (const auto &joint_name: joint_names_) { + for (const auto &interface_type: command_interface_types_) { + conf.names.push_back(joint_name + "/" += interface_type); + } + } + + return conf; + } + + controller_interface::InterfaceConfiguration LeggedGymController::state_interface_configuration() const { + controller_interface::InterfaceConfiguration conf = {config_type::INDIVIDUAL, {}}; + + conf.names.reserve(joint_names_.size() * state_interface_types_.size()); + for (const auto &joint_name: joint_names_) { + for (const auto &interface_type: state_interface_types_) { + conf.names.push_back(joint_name + "/" += interface_type); + } + } + + for (const auto &interface_type: imu_interface_types_) { + conf.names.push_back(imu_name_ + "/" += interface_type); + } + + for (const auto &interface_type: foot_force_interface_types_) { + conf.names.push_back(foot_force_name_ + "/" += interface_type); + } + + return conf; + } + + controller_interface::return_type LeggedGymController:: + update(const rclcpp::Time & /*time*/, const rclcpp::Duration & /*period*/) { + if (mode_ == FSMMode::NORMAL) { + current_state_->run(); + next_state_name_ = current_state_->checkChange(); + if (next_state_name_ != current_state_->state_name) { + mode_ = FSMMode::CHANGE; + next_state_ = getNextState(next_state_name_); + RCLCPP_INFO(get_node()->get_logger(), "Switched from %s to %s", + current_state_->state_name_string.c_str(), next_state_->state_name_string.c_str()); + } + } else if (mode_ == FSMMode::CHANGE) { + current_state_->exit(); + current_state_ = next_state_; + + current_state_->enter(); + mode_ = FSMMode::NORMAL; + } + + return controller_interface::return_type::OK; + } + + controller_interface::CallbackReturn LeggedGymController::on_init() { + try { + joint_names_ = auto_declare >("joints", joint_names_); + feet_names_ = auto_declare >("feet_names", feet_names_); + command_interface_types_ = + auto_declare >("command_interfaces", command_interface_types_); + state_interface_types_ = + auto_declare >("state_interfaces", state_interface_types_); + + // imu sensor + imu_name_ = auto_declare("imu_name", imu_name_); + imu_interface_types_ = auto_declare >("imu_interfaces", state_interface_types_); + } catch (const std::exception &e) { + fprintf(stderr, "Exception thrown during init stage with message: %s \n", e.what()); + return controller_interface::CallbackReturn::ERROR; + } + + return CallbackReturn::SUCCESS; + } + + controller_interface::CallbackReturn LeggedGymController::on_configure( + const rclcpp_lifecycle::State & /*previous_state*/) { + control_input_subscription_ = get_node()->create_subscription( + "/control_input", 10, [this](const control_input_msgs::msg::Inputs::SharedPtr msg) { + // Handle message + ctrl_comp_.control_inputs_.command = msg->command; + ctrl_comp_.control_inputs_.lx = msg->lx; + ctrl_comp_.control_inputs_.ly = msg->ly; + ctrl_comp_.control_inputs_.rx = msg->rx; + ctrl_comp_.control_inputs_.ry = msg->ry; + }); + + get_node()->get_parameter("update_rate", ctrl_comp_.frequency_); + RCLCPP_INFO(get_node()->get_logger(), "Controller Manager Update Rate: %d Hz", ctrl_comp_.frequency_); + + return CallbackReturn::SUCCESS; + } + + controller_interface::CallbackReturn LeggedGymController::on_activate( + const rclcpp_lifecycle::State & /*previous_state*/) { + // clear out vectors in case of restart + ctrl_comp_.clear(); + + // assign command interfaces + for (auto &interface: command_interfaces_) { + std::string interface_name = interface.get_interface_name(); + if (const size_t pos = interface_name.find('/'); pos != std::string::npos) { + command_interface_map_[interface_name.substr(pos + 1)]->push_back(interface); + } else { + command_interface_map_[interface_name]->push_back(interface); + } + } + + // assign state interfaces + for (auto &interface: state_interfaces_) { + if (interface.get_prefix_name() == imu_name_) { + ctrl_comp_.imu_state_interface_.emplace_back(interface); + } else if (interface.get_prefix_name() == foot_force_name_) { + ctrl_comp_.foot_force_state_interface_.emplace_back(interface); + } else { + state_interface_map_[interface.get_interface_name()]->push_back(interface); + } + } + + // Create FSM List + state_list_.passive = std::make_shared(ctrl_comp_); + state_list_.fixedDown = std::make_shared(ctrl_comp_); + state_list_.fixedStand = std::make_shared(ctrl_comp_); + + // Initialize FSM + current_state_ = state_list_.passive; + current_state_->enter(); + next_state_ = current_state_; + next_state_name_ = current_state_->state_name; + mode_ = FSMMode::NORMAL; + + return CallbackReturn::SUCCESS; + } + + controller_interface::CallbackReturn LeggedGymController::on_deactivate( + const rclcpp_lifecycle::State & /*previous_state*/) { + release_interfaces(); + return CallbackReturn::SUCCESS; + } + + controller_interface::CallbackReturn + LeggedGymController::on_cleanup(const rclcpp_lifecycle::State &previous_state) { + return ControllerInterface::on_cleanup(previous_state); + } + + controller_interface::CallbackReturn + LeggedGymController::on_shutdown(const rclcpp_lifecycle::State &previous_state) { + return ControllerInterface::on_shutdown(previous_state); + } + + controller_interface::CallbackReturn LeggedGymController::on_error(const rclcpp_lifecycle::State &previous_state) { + return ControllerInterface::on_error(previous_state); + } + + std::shared_ptr LeggedGymController::getNextState(FSMStateName stateName) const { + switch (stateName) { + case FSMStateName::INVALID: + return state_list_.invalid; + case FSMStateName::PASSIVE: + return state_list_.passive; + case FSMStateName::FIXEDDOWN: + return state_list_.fixedDown; + case FSMStateName::FIXEDSTAND: + return state_list_.fixedStand; + default: + return state_list_.invalid; + } + } +} + +#include "pluginlib/class_list_macros.hpp" +PLUGINLIB_EXPORT_CLASS(legged_gym_controller::LeggedGymController, controller_interface::ControllerInterface); diff --git a/controllers/legged_gym_controller/src/LeggedGymController.h b/controllers/legged_gym_controller/src/LeggedGymController.h new file mode 100644 index 0000000..ddee6ac --- /dev/null +++ b/controllers/legged_gym_controller/src/LeggedGymController.h @@ -0,0 +1,107 @@ +// +// Created by tlab-uav on 24-10-4. +// + +#ifndef LEGGEDGYMCONTROLLER_H +#define LEGGEDGYMCONTROLLER_H +#include +#include "legged_gym_controller/FSM/StateFixedStand.h" +#include "legged_gym_controller/FSM/StateFixedDown.h" +#include "legged_gym_controller/FSM/StatePassive.h" +#include "legged_gym_controller/control/CtrlComponent.h" + +namespace legged_gym_controller { + struct FSMStateList { + std::shared_ptr invalid; + std::shared_ptr passive; + std::shared_ptr fixedDown; + std::shared_ptr fixedStand; + }; + + class LeggedGymController final : public controller_interface::ControllerInterface { + public: + CONTROLLER_INTERFACE_PUBLIC + LeggedGymController() = default; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::InterfaceConfiguration command_interface_configuration() const override; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::InterfaceConfiguration state_interface_configuration() const override; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::return_type update( + const rclcpp::Time &time, const rclcpp::Duration &period) override; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::CallbackReturn on_init() override; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::CallbackReturn on_configure( + const rclcpp_lifecycle::State &previous_state) override; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::CallbackReturn on_activate( + const rclcpp_lifecycle::State &previous_state) override; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::CallbackReturn on_deactivate( + const rclcpp_lifecycle::State &previous_state) override; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::CallbackReturn on_cleanup( + const rclcpp_lifecycle::State &previous_state) override; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::CallbackReturn on_shutdown( + const rclcpp_lifecycle::State &previous_state) override; + + CONTROLLER_INTERFACE_PUBLIC + controller_interface::CallbackReturn on_error( + const rclcpp_lifecycle::State &previous_state) override; + + private: + std::shared_ptr getNextState(FSMStateName stateName) const; + + CtrlComponent ctrl_comp_; + std::vector joint_names_; + std::vector feet_names_; + std::vector command_interface_types_; + std::vector state_interface_types_; + + // IMU Sensor + std::string imu_name_; + std::vector imu_interface_types_; + // Foot Force Sensor + std::string foot_force_name_; + std::vector foot_force_interface_types_; + + std::unordered_map< + std::string, std::vector > *> + command_interface_map_ = { + {"effort", &ctrl_comp_.joint_torque_command_interface_}, + {"position", &ctrl_comp_.joint_position_command_interface_}, + {"velocity", &ctrl_comp_.joint_velocity_command_interface_}, + {"kp", &ctrl_comp_.joint_kp_command_interface_}, + {"kd", &ctrl_comp_.joint_kd_command_interface_} + }; + + std::unordered_map< + std::string, std::vector > *> + state_interface_map_ = { + {"position", &ctrl_comp_.joint_position_state_interface_}, + {"effort", &ctrl_comp_.joint_effort_state_interface_}, + {"velocity", &ctrl_comp_.joint_velocity_state_interface_} + }; + + rclcpp::Subscription::SharedPtr control_input_subscription_; + + FSMMode mode_ = FSMMode::NORMAL; + std::string state_name_; + FSMStateName next_state_name_ = FSMStateName::INVALID; + FSMStateList state_list_; + std::shared_ptr current_state_; + std::shared_ptr next_state_; + }; +} +#endif //LEGGEDGYMCONTROLLER_H diff --git a/controllers/legged_gym_controller/src/rl_sdk.cpp b/controllers/legged_gym_controller/src/rl_sdk.cpp deleted file mode 100644 index 0a24bb5..0000000 --- a/controllers/legged_gym_controller/src/rl_sdk.cpp +++ /dev/null @@ -1,467 +0,0 @@ -#include "legged_gym_controller/rl_sdk.hpp" - -/* You may need to override this Forward() function -torch::Tensor RL_XXX::Forward() -{ - torch::autograd::GradMode::set_enabled(false); - torch::Tensor clamped_obs = this->ComputeObservation(); - torch::Tensor actions = this->model.forward({clamped_obs}).toTensor(); - torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); - return clamped_actions; -} -*/ - -torch::Tensor RL::ComputeObservation() -{ - std::vector obs_list; - - for(const std::string& observation : this->params.observations) - { - if(observation == "lin_vel") - { - obs_list.push_back(this->obs.lin_vel * this->params.lin_vel_scale); - } - else if(observation == "ang_vel") - { - // obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); // TODO is QuatRotateInverse necessery? - obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale); - } - else if(observation == "gravity_vec") - { - obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework)); - } - else if(observation == "commands") - { - obs_list.push_back(this->obs.commands * this->params.commands_scale); - } - else if(observation == "dof_pos") - { - obs_list.push_back((this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale); - } - else if(observation == "dof_vel") - { - obs_list.push_back(this->obs.dof_vel * this->params.dof_vel_scale); - } - else if(observation == "actions") - { - obs_list.push_back(this->obs.actions); - } - } - - torch::Tensor obs = torch::cat(obs_list, 1); - torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); - return clamped_obs; -} - -void RL::InitObservations() -{ - this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}}); - this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}}); - this->obs.gravity_vec = torch::tensor({{0.0, 0.0, -1.0}}); - this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}}); - this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}}); - this->obs.dof_pos = this->params.default_dof_pos; - this->obs.dof_vel = torch::zeros({1, this->params.num_of_dofs}); - this->obs.actions = torch::zeros({1, this->params.num_of_dofs}); -} - -void RL::InitOutputs() -{ - this->output_torques = torch::zeros({1, this->params.num_of_dofs}); - this->output_dof_pos = this->params.default_dof_pos; -} - -void RL::InitControl() -{ - this->control.control_state = STATE_WAITING; - this->control.x = 0.0; - this->control.y = 0.0; - this->control.yaw = 0.0; -} - -torch::Tensor RL::ComputeTorques(torch::Tensor actions) -{ - torch::Tensor actions_scaled = actions * this->params.action_scale; - torch::Tensor output_torques = this->params.rl_kp * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.rl_kd * this->obs.dof_vel; - return output_torques; -} - -torch::Tensor RL::ComputePosition(torch::Tensor actions) -{ - torch::Tensor actions_scaled = actions * this->params.action_scale; - return actions_scaled + this->params.default_dof_pos; -} - -torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string& framework) -{ - torch::Tensor q_w; - torch::Tensor q_vec; - if(framework == "isaacsim") - { - q_w = q.index({torch::indexing::Slice(), 0}); - q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(1, 4)}); - } - else if(framework == "isaacgym") - { - q_w = q.index({torch::indexing::Slice(), 3}); - q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)}); - } - c10::IntArrayRef shape = q.sizes(); - - torch::Tensor a = v * (2.0 * torch::pow(q_w, 2) - 1.0).unsqueeze(-1); - torch::Tensor b = torch::cross(q_vec, v, -1) * q_w.unsqueeze(-1) * 2.0; - torch::Tensor c = q_vec * torch::bmm(q_vec.view({shape[0], 1, 3}), v.view({shape[0], 3, 1})).squeeze(-1) * 2.0; - return a - b + c; -} - -void RL::StateController(const RobotState *state, RobotCommand *command) -{ - static RobotState start_state; - static RobotState now_state; - static float getup_percent = 0.0; - static float getdown_percent = 0.0; - - // waiting - if(this->running_state == STATE_WAITING) - { - for(int i = 0; i < this->params.num_of_dofs; ++i) - { - command->motor_command.q[i] = state->motor_state.q[i]; - } - if(this->control.control_state == STATE_POS_GETUP) - { - this->control.control_state = STATE_WAITING; - getup_percent = 0.0; - for(int i = 0; i < this->params.num_of_dofs; ++i) - { - now_state.motor_state.q[i] = state->motor_state.q[i]; - start_state.motor_state.q[i] = now_state.motor_state.q[i]; - } - this->running_state = STATE_POS_GETUP; - std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl; - } - } - // stand up (position control) - else if(this->running_state == STATE_POS_GETUP) - { - if(getup_percent < 1.0) - { - getup_percent += 1 / 500.0; - getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent; - for(int i = 0; i < this->params.num_of_dofs; ++i) - { - command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item(); - command->motor_command.dq[i] = 0; - command->motor_command.kp[i] = this->params.fixed_kp[0][i].item(); - command->motor_command.kd[i] = this->params.fixed_kd[0][i].item(); - command->motor_command.tau[i] = 0; - } - std::cout << "\r" << std::flush << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << std::flush; - } - if(this->control.control_state == STATE_RL_INIT) - { - this->control.control_state = STATE_WAITING; - this->running_state = STATE_RL_INIT; - std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl; - } - else if(this->control.control_state == STATE_POS_GETDOWN) - { - this->control.control_state = STATE_WAITING; - getdown_percent = 0.0; - for(int i = 0; i < this->params.num_of_dofs; ++i) - { - now_state.motor_state.q[i] = state->motor_state.q[i]; - } - this->running_state = STATE_POS_GETDOWN; - std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl; - } - } - // init obs and start rl loop - else if(this->running_state == STATE_RL_INIT) - { - if(getup_percent == 1) - { - this->InitObservations(); - this->InitOutputs(); - this->InitControl(); - this->running_state = STATE_RL_RUNNING; - std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_RUNNING" << std::endl; - } - } - // rl loop - else if(this->running_state == STATE_RL_RUNNING) - { - std::cout << "\r" << std::flush << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << std::flush; - for(int i = 0; i < this->params.num_of_dofs; ++i) - { - command->motor_command.q[i] = this->output_dof_pos[0][i].item(); - command->motor_command.dq[i] = 0; - command->motor_command.kp[i] = this->params.rl_kp[0][i].item(); - command->motor_command.kd[i] = this->params.rl_kd[0][i].item(); - command->motor_command.tau[i] = 0; - } - if(this->control.control_state == STATE_POS_GETDOWN) - { - this->control.control_state = STATE_WAITING; - getdown_percent = 0.0; - for(int i = 0; i < this->params.num_of_dofs; ++i) - { - now_state.motor_state.q[i] = state->motor_state.q[i]; - } - this->running_state = STATE_POS_GETDOWN; - std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl; - } - else if(this->control.control_state == STATE_POS_GETUP) - { - this->control.control_state = STATE_WAITING; - getup_percent = 0.0; - for(int i = 0; i < this->params.num_of_dofs; ++i) - { - now_state.motor_state.q[i] = state->motor_state.q[i]; - } - this->running_state = STATE_POS_GETUP; - std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl; - } - } - // get down (position control) - else if(this->running_state == STATE_POS_GETDOWN) - { - if(getdown_percent < 1.0) - { - getdown_percent += 1 / 500.0; - getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent; - for(int i = 0; i < this->params.num_of_dofs; ++i) - { - command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i]; - command->motor_command.dq[i] = 0; - command->motor_command.kp[i] = this->params.fixed_kp[0][i].item(); - command->motor_command.kd[i] = this->params.fixed_kd[0][i].item(); - command->motor_command.tau[i] = 0; - } - std::cout << "\r" << std::flush << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << std::flush; - } - if(getdown_percent == 1) - { - this->InitObservations(); - this->InitOutputs(); - this->InitControl(); - this->running_state = STATE_WAITING; - std::cout << std::endl << LOGGER::INFO << "Switching to STATE_WAITING" << std::endl; - } - } -} - -void RL::TorqueProtect(torch::Tensor origin_output_torques) -{ - std::vector out_of_range_indices; - std::vector out_of_range_values; - for(int i = 0; i < origin_output_torques.size(1); ++i) - { - double torque_value = origin_output_torques[0][i].item(); - double limit_lower = -this->params.torque_limits[0][i].item(); - double limit_upper = this->params.torque_limits[0][i].item(); - - if(torque_value < limit_lower || torque_value > limit_upper) - { - out_of_range_indices.push_back(i); - out_of_range_values.push_back(torque_value); - } - } - if(!out_of_range_indices.empty()) - { - for(int i = 0; i < out_of_range_indices.size(); ++i) - { - int index = out_of_range_indices[i]; - double value = out_of_range_values[i]; - double limit_lower = -this->params.torque_limits[0][index].item(); - double limit_upper = this->params.torque_limits[0][index].item(); - - std::cout << LOGGER::WARNING << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl; - } - // Just a reminder, no protection - // this->control.control_state = STATE_POS_GETDOWN; - // std::cout << LOGGER::INFO << "Switching to STATE_POS_GETDOWN"<< std::endl; - } -} - -#include -#include -static bool kbhit() -{ - termios term; - tcgetattr(0, &term); - - termios term2 = term; - term2.c_lflag &= ~ICANON; - tcsetattr(0, TCSANOW, &term2); - - int byteswaiting; - ioctl(0, FIONREAD, &byteswaiting); - - tcsetattr(0, TCSANOW, &term); - - return byteswaiting > 0; -} - -void RL::KeyboardInterface() -{ - if(kbhit()) - { - int c = fgetc(stdin); - switch(c) - { - case '0': this->control.control_state = STATE_POS_GETUP; break; - case 'p': this->control.control_state = STATE_RL_INIT; break; - case '1': this->control.control_state = STATE_POS_GETDOWN; break; - case 'q': break; - case 'w': this->control.x += 0.1; break; - case 's': this->control.x -= 0.1; break; - case 'a': this->control.yaw += 0.1; break; - case 'd': this->control.yaw -= 0.1; break; - case 'i': break; - case 'k': break; - case 'j': this->control.y += 0.1; break; - case 'l': this->control.y -= 0.1; break; - case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break; - case 'r': this->control.control_state = STATE_RESET_SIMULATION; break; - case '\n': this->control.control_state = STATE_TOGGLE_SIMULATION; break; - default: break; - } - } -} - -template -std::vector ReadVectorFromYaml(const YAML::Node& node) -{ - std::vector values; - for(const auto& val : node) - { - values.push_back(val.as()); - } - return values; -} - -template -std::vector ReadVectorFromYaml(const YAML::Node& node, const std::string& framework, const int& rows, const int& cols) -{ - std::vector values; - for(const auto& val : node) - { - values.push_back(val.as()); - } - - if(framework == "isaacsim") - { - std::vector transposed_values(cols * rows); - for(int r = 0; r < rows; ++r) - { - for(int c = 0; c < cols; ++c) - { - transposed_values[c * rows + r] = values[r * cols + c]; - } - } - return transposed_values; - } - else if(framework == "isaacgym") - { - return values; - } - else - { - throw std::invalid_argument("Unsupported framework: " + framework); - } -} - -void RL::ReadYaml(std::string robot_name) -{ - // The config file is located at "rl_sar/src/rl_sar/models//config.yaml" - std::string config_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/config.yaml"; - YAML::Node config; - try - { - config = YAML::LoadFile(config_path)[robot_name]; - } catch(YAML::BadFile &e) - { - std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl; - return; - } - - this->params.model_name = config["model_name"].as(); - this->params.framework = config["framework"].as(); - int rows = config["rows"].as(); - int cols = config["cols"].as(); - this->params.use_history = config["use_history"].as(); - this->params.dt = config["dt"].as(); - this->params.decimation = config["decimation"].as(); - this->params.num_observations = config["num_observations"].as(); - this->params.observations = ReadVectorFromYaml(config["observations"]); - this->params.clip_obs = config["clip_obs"].as(); - if(config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull()) - { - this->params.clip_actions_upper = torch::tensor({}).view({1, -1}); - this->params.clip_actions_lower = torch::tensor({}).view({1, -1}); - } - else - { - this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1}); - this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1}); - } - this->params.action_scale = config["action_scale"].as(); - this->params.hip_scale_reduction = config["hip_scale_reduction"].as(); - this->params.hip_scale_reduction_indices = ReadVectorFromYaml(config["hip_scale_reduction_indices"]); - this->params.num_of_dofs = config["num_of_dofs"].as(); - this->params.lin_vel_scale = config["lin_vel_scale"].as(); - this->params.ang_vel_scale = config["ang_vel_scale"].as(); - this->params.dof_pos_scale = config["dof_pos_scale"].as(); - this->params.dof_vel_scale = config["dof_vel_scale"].as(); - // this->params.commands_scale = torch::tensor(ReadVectorFromYaml(config["commands_scale"])).view({1, -1}); - this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); - this->params.rl_kp = torch::tensor(ReadVectorFromYaml(config["rl_kp"], this->params.framework, rows, cols)).view({1, -1}); - this->params.rl_kd = torch::tensor(ReadVectorFromYaml(config["rl_kd"], this->params.framework, rows, cols)).view({1, -1}); - this->params.fixed_kp = torch::tensor(ReadVectorFromYaml(config["fixed_kp"], this->params.framework, rows, cols)).view({1, -1}); - this->params.fixed_kd = torch::tensor(ReadVectorFromYaml(config["fixed_kd"], this->params.framework, rows, cols)).view({1, -1}); - this->params.torque_limits = torch::tensor(ReadVectorFromYaml(config["torque_limits"], this->params.framework, rows, cols)).view({1, -1}); - this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml(config["default_dof_pos"], this->params.framework, rows, cols)).view({1, -1}); - this->params.joint_controller_names = ReadVectorFromYaml(config["joint_controller_names"], this->params.framework, rows, cols); -} - -void RL::CSVInit(std::string robot_name) -{ - csv_filename = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/motor"; - - // Uncomment these lines if need timestamp for file name - // auto now = std::chrono::system_clock::now(); - // std::time_t now_c = std::chrono::system_clock::to_time_t(now); - // std::stringstream ss; - // ss << std::put_time(std::localtime(&now_c), "%Y%m%d%H%M%S"); - // std::string timestamp = ss.str(); - // csv_filename += "_" + timestamp; - - csv_filename += ".csv"; - std::ofstream file(csv_filename.c_str()); - - for(int i = 0; i < 12; ++i) {file << "tau_cal_" << i << ",";} - for(int i = 0; i < 12; ++i) {file << "tau_est_" << i << ",";} - for(int i = 0; i < 12; ++i) {file << "joint_pos_" << i << ",";} - for(int i = 0; i < 12; ++i) {file << "joint_pos_target_" << i << ",";} - for(int i = 0; i < 12; ++i) {file << "joint_vel_" << i << ",";} - - file << std::endl; - - file.close(); -} - -void RL::CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel) -{ - std::ofstream file(csv_filename.c_str(), std::ios_base::app); - - for(int i = 0; i < 12; ++i) {file << torque[0][i].item() << ",";} - for(int i = 0; i < 12; ++i) {file << tau_est[0][i].item() << ",";} - for(int i = 0; i < 12; ++i) {file << joint_pos[0][i].item() << ",";} - for(int i = 0; i < 12; ++i) {file << joint_pos_target[0][i].item() << ",";} - for(int i = 0; i < 12; ++i) {file << joint_vel[0][i].item() << ",";} - - file << std::endl; - - file.close(); -} \ No newline at end of file diff --git a/descriptions/unitree/a1_description/README.md b/descriptions/unitree/a1_description/README.md index db82764..b72c13e 100644 --- a/descriptions/unitree/a1_description/README.md +++ b/descriptions/unitree/a1_description/README.md @@ -31,4 +31,8 @@ ros2 launch a1_description visualize.launch.py source ~/ros2_ws/install/setup.bash ros2 launch a1_description ocs2_control.launch.py ``` - +* Legged Gym Controller + ```bash + source ~/ros2_ws/install/setup.bash + ros2 launch a1_description rl_control.launch.py + ``` diff --git a/descriptions/unitree/a1_description/config/robot_control.yaml b/descriptions/unitree/a1_description/config/robot_control.yaml index 0ebfd6b..ab4d515 100644 --- a/descriptions/unitree/a1_description/config/robot_control.yaml +++ b/descriptions/unitree/a1_description/config/robot_control.yaml @@ -16,6 +16,9 @@ controller_manager: ocs2_quadruped_controller: type: ocs2_quadruped_controller/Ocs2QuadrupedController + legged_gym_controller: + type: legged_gym_controller/LeggedGymController + imu_sensor_broadcaster: ros__parameters: sensor_name: "imu_sensor" @@ -128,4 +131,54 @@ ocs2_quadruped_controller: - FL - RL - FR - - RR \ No newline at end of file + - RR + +legged_gym_controller: + ros__parameters: + update_rate: 500 # Hz + joints: + - FR_hip_joint + - FR_thigh_joint + - FR_calf_joint + - FL_hip_joint + - FL_thigh_joint + - FL_calf_joint + - RR_hip_joint + - RR_thigh_joint + - RR_calf_joint + - RL_hip_joint + - RL_thigh_joint + - RL_calf_joint + + command_interfaces: + - effort + - position + - velocity + - kp + - kd + + state_interfaces: + - effort + - position + - velocity + + feet_names: + - FR_foot + - FL_foot + - RR_foot + - RL_foot + + imu_name: "imu_sensor" + base_name: "base" + + imu_interfaces: + - orientation.w + - orientation.x + - orientation.y + - orientation.z + - angular_velocity.x + - angular_velocity.y + - angular_velocity.z + - linear_acceleration.x + - linear_acceleration.y + - linear_acceleration.z \ No newline at end of file diff --git a/descriptions/unitree/a1_description/launch/rl_control.launch.py b/descriptions/unitree/a1_description/launch/rl_control.launch.py new file mode 100644 index 0000000..2c6e0a9 --- /dev/null +++ b/descriptions/unitree/a1_description/launch/rl_control.launch.py @@ -0,0 +1,112 @@ +import os + +import xacro +from ament_index_python.packages import get_package_share_directory +from launch import LaunchDescription +from launch.actions import DeclareLaunchArgument, OpaqueFunction, IncludeLaunchDescription, RegisterEventHandler +from launch.event_handlers import OnProcessExit +from launch.substitutions import PathJoinSubstitution +from launch_ros.actions import Node +from launch_ros.substitutions import FindPackageShare + +package_description = "a1_description" + + +def process_xacro(context): + robot_type_value = context.launch_configurations['robot_type'] + pkg_path = os.path.join(get_package_share_directory(package_description)) + xacro_file = os.path.join(pkg_path, 'xacro', 'robot.xacro') + robot_description_config = xacro.process_file(xacro_file, mappings={'robot_type': robot_type_value}) + return (robot_description_config.toxml(), robot_type_value) + + +def launch_setup(context, *args, **kwargs): + (robot_description, robot_type) = process_xacro(context) + robot_controllers = PathJoinSubstitution( + [ + FindPackageShare(package_description), + "config", + "robot_control.yaml", + ] + ) + + robot_state_publisher = Node( + package='robot_state_publisher', + executable='robot_state_publisher', + name='robot_state_publisher', + parameters=[ + { + 'publish_frequency': 20.0, + 'use_tf_static': True, + 'robot_description': robot_description, + 'ignore_timestamp': True + } + ], + ) + + controller_manager = Node( + package="controller_manager", + executable="ros2_control_node", + parameters=[robot_controllers], + output="both", + ) + + joint_state_publisher = Node( + package="controller_manager", + executable="spawner", + arguments=["joint_state_broadcaster", + "--controller-manager", "/controller_manager"], + ) + + imu_sensor_broadcaster = Node( + package="controller_manager", + executable="spawner", + arguments=["imu_sensor_broadcaster", + "--controller-manager", "/controller_manager"], + ) + + controller = Node( + package="controller_manager", + executable="spawner", + arguments=["legged_gym_controller", "--controller-manager", "/controller_manager"], + ) + + return [ + robot_state_publisher, + controller_manager, + joint_state_publisher, + RegisterEventHandler( + event_handler=OnProcessExit( + target_action=joint_state_publisher, + on_exit=[imu_sensor_broadcaster], + ) + ), + RegisterEventHandler( + event_handler=OnProcessExit( + target_action=imu_sensor_broadcaster, + on_exit=[controller], + ) + ), + ] + + +def generate_launch_description(): + robot_type_arg = DeclareLaunchArgument( + 'robot_type', + default_value='a1', + description='Type of the robot' + ) + + rviz_config_file = os.path.join(get_package_share_directory(package_description), "config", "visualize_urdf.rviz") + + return LaunchDescription([ + robot_type_arg, + OpaqueFunction(function=launch_setup), + Node( + package='rviz2', + executable='rviz2', + name='rviz_ocs2', + output='screen', + arguments=["-d", rviz_config_file] + ) + ])