From bc8b3ded3955ee77bbd3198caad89740be50d0ce Mon Sep 17 00:00:00 2001 From: Zhenbiao Huang Date: Sun, 6 Oct 2024 20:19:36 +0800 Subject: [PATCH] modified observation buffer --- .../legged_gym_controller/CMakeLists.txt | 3 + .../legged_gym_controller/FSM/StateRL.h | 54 ++++++++++++++++++ .../legged_gym_controller/common/enumClass.h | 1 + .../observation_buffer/observation_buffer.cpp | 45 --------------- .../observation_buffer/observation_buffer.hpp | 24 -------- .../legged_gym_controller/src/FSM/StateRL.cpp | 56 +++++++++++++++++++ .../src/common/ObservationBuffer.cpp | 51 +++++++++++++++++ .../src/common/ObservationBuffer.h | 32 +++++++++++ 8 files changed, 197 insertions(+), 69 deletions(-) create mode 100644 controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateRL.h delete mode 100644 controllers/legged_gym_controller/library/observation_buffer/observation_buffer.cpp delete mode 100644 controllers/legged_gym_controller/library/observation_buffer/observation_buffer.hpp create mode 100644 controllers/legged_gym_controller/src/FSM/StateRL.cpp create mode 100644 controllers/legged_gym_controller/src/common/ObservationBuffer.cpp create mode 100644 controllers/legged_gym_controller/src/common/ObservationBuffer.h diff --git a/controllers/legged_gym_controller/CMakeLists.txt b/controllers/legged_gym_controller/CMakeLists.txt index 1794551..7003b28 100644 --- a/controllers/legged_gym_controller/CMakeLists.txt +++ b/controllers/legged_gym_controller/CMakeLists.txt @@ -33,9 +33,12 @@ endforeach () add_library(${PROJECT_NAME} SHARED src/LeggedGymController.cpp + src/common/ObservationBuffer.cpp + src/FSM/StatePassive.cpp src/FSM/StateFixedStand.cpp src/FSM/StateFixedDown.cpp + src/FSM/StateRL.cpp ) target_include_directories(${PROJECT_NAME} PUBLIC diff --git a/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateRL.h b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateRL.h new file mode 100644 index 0000000..184a6d3 --- /dev/null +++ b/controllers/legged_gym_controller/include/legged_gym_controller/FSM/StateRL.h @@ -0,0 +1,54 @@ +// +// Created by biao on 24-10-6. +// + +#ifndef STATERL_H +#define STATERL_H + +#include +#include + +#include "FSMState.h" + + +class StateRL final : public FSMState { +public: + explicit StateRL(CtrlComponent &ctrl_component); + + void enter() override; + + void run() override; + + void exit() override; + + FSMStateName checkChange() override; + +private: + torch::Tensor computeObservation(); + + /** + * @brief Forward the RL model to get the action + */ + torch::Tensor forward(); + + // Parameters + double linear_vel_scale_; + double angular_vel_scale_; + double clip_obs_; + torch::Tensor clip_actions_upper_; + torch::Tensor clip_actions_lower_; + bool use_history_; + + // history buffer + std::shared_ptr history_obs_buf_; + torch::Tensor history_obs_; + + // rl module + torch::jit::script::Module model; + // output buffer + torch::Tensor output_torques; + torch::Tensor output_dof_pos; +}; + + +#endif //STATERL_H diff --git a/controllers/legged_gym_controller/include/legged_gym_controller/common/enumClass.h b/controllers/legged_gym_controller/include/legged_gym_controller/common/enumClass.h index f942ecc..4b2fc0e 100644 --- a/controllers/legged_gym_controller/include/legged_gym_controller/common/enumClass.h +++ b/controllers/legged_gym_controller/include/legged_gym_controller/common/enumClass.h @@ -11,6 +11,7 @@ enum class FSMStateName { PASSIVE, FIXEDDOWN, FIXEDSTAND, + RL, }; enum class FSMMode { diff --git a/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.cpp b/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.cpp deleted file mode 100644 index 01d2cae..0000000 --- a/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "observation_buffer.hpp" - -ObservationBuffer::ObservationBuffer() {} - -ObservationBuffer::ObservationBuffer(int num_envs, - int num_obs, - int include_history_steps) - : num_envs(num_envs), - num_obs(num_obs), - include_history_steps(include_history_steps) -{ - num_obs_total = num_obs * include_history_steps; - obs_buf = torch::zeros({num_envs, num_obs_total}, torch::dtype(torch::kFloat32)); -} - -void ObservationBuffer::reset(std::vector reset_idxs, torch::Tensor new_obs) -{ - std::vector indices; - for (int idx : reset_idxs) { - indices.push_back(torch::indexing::Slice(idx)); - } - obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps})); -} - -void ObservationBuffer::insert(torch::Tensor new_obs) -{ - // Shift observations back. - torch::Tensor shifted_obs = obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(num_obs, num_obs * include_history_steps)}).clone(); - obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(0, num_obs * (include_history_steps - 1))}) = shifted_obs; - - // Add new observation. - obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs, torch::indexing::None)}) = new_obs; -} - -torch::Tensor ObservationBuffer::get_obs_vec(std::vector obs_ids) -{ - std::vector obs; - for (int i = obs_ids.size() - 1; i >= 0; --i) - { - int obs_id = obs_ids[i]; - int slice_idx = include_history_steps - obs_id - 1; - obs.push_back(obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(slice_idx * num_obs, (slice_idx + 1) * num_obs)})); - } - return cat(obs, -1); -} diff --git a/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.hpp b/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.hpp deleted file mode 100644 index 72be75c..0000000 --- a/controllers/legged_gym_controller/library/observation_buffer/observation_buffer.hpp +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef OBSERVATION_BUFFER_HPP -#define OBSERVATION_BUFFER_HPP - -#include -#include - -class ObservationBuffer { -public: - ObservationBuffer(int num_envs, int num_obs, int include_history_steps); - ObservationBuffer(); - - void reset(std::vector reset_idxs, torch::Tensor new_obs); - void insert(torch::Tensor new_obs); - torch::Tensor get_obs_vec(std::vector obs_ids); - -private: - int num_envs; - int num_obs; - int include_history_steps; - int num_obs_total; - torch::Tensor obs_buf; -}; - -#endif // OBSERVATION_BUFFER_HPP \ No newline at end of file diff --git a/controllers/legged_gym_controller/src/FSM/StateRL.cpp b/controllers/legged_gym_controller/src/FSM/StateRL.cpp new file mode 100644 index 0000000..baa38b4 --- /dev/null +++ b/controllers/legged_gym_controller/src/FSM/StateRL.cpp @@ -0,0 +1,56 @@ +// +// Created by biao on 24-10-6. +// + +#include "legged_gym_controller/FSM/StateRL.h" + +StateRL::StateRL(CtrlComponent &ctrl_component) : FSMState( + FSMStateName::RL, "rl", ctrl_component) { +} + +void StateRL::enter() { +} + +void StateRL::run() { +} + +void StateRL::exit() { +} + +FSMStateName StateRL::checkChange() { + switch (ctrl_comp_.control_inputs_.command) { + case 1: + return FSMStateName::PASSIVE; + case 2: + return FSMStateName::FIXEDDOWN; + default: + return FSMStateName::RL; + } +} + +torch::Tensor StateRL::computeObservation() { + std::vector obs_list; + + const torch::Tensor obs = cat(obs_list, 1); + torch::Tensor clamped_obs = clamp(obs, -clip_obs_, clip_obs_); + return clamped_obs; +} + +torch::Tensor StateRL::forward() { + torch::autograd::GradMode::set_enabled(false); + torch::Tensor clamped_obs = computeObservation(); + torch::Tensor actions; + + if (use_history_) { + history_obs_buf_->insert(clamped_obs); + history_obs_ = history_obs_buf_->getObsVec({0, 1, 2, 3, 4, 5}); + actions = model.forward({history_obs_}).toTensor(); + } else { + actions = model.forward({clamped_obs}).toTensor(); + } + + if (clip_actions_upper_.numel() != 0 && clip_actions_lower_.numel() != 0) { + return clamp(actions, clip_actions_lower_, clip_actions_upper_); + } + return actions; +} diff --git a/controllers/legged_gym_controller/src/common/ObservationBuffer.cpp b/controllers/legged_gym_controller/src/common/ObservationBuffer.cpp new file mode 100644 index 0000000..4509942 --- /dev/null +++ b/controllers/legged_gym_controller/src/common/ObservationBuffer.cpp @@ -0,0 +1,51 @@ +// +// Created by biao on 24-10-6. +// + +#include "ObservationBuffer.h" + +ObservationBuffer::ObservationBuffer(int num_envs, + const int num_obs, + const int include_history_steps) + : num_envs_(num_envs), + num_obs_(num_obs), + include_history_steps_(include_history_steps) { + num_obs_total_ = num_obs * include_history_steps; + obs_buffer_ = torch::zeros({num_envs, num_obs_total_}, dtype(torch::kFloat32)); +} + +void ObservationBuffer::reset(const std::vector &reset_index, const torch::Tensor &new_obs) { + std::vector indices; + for (int index: reset_index) { + indices.emplace_back(torch::indexing::Slice(index)); + } + obs_buffer_.index_put_(indices, new_obs.repeat({1, include_history_steps_})); +} + +void ObservationBuffer::insert(const torch::Tensor &new_obs) { + // Shift observations back. + const torch::Tensor shifted_obs = obs_buffer_.index({ + torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(num_obs_, num_obs_ * include_history_steps_) + }).clone(); + obs_buffer_.index({ + torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(0, num_obs_ * (include_history_steps_ - 1)) + }) = shifted_obs; + + // Add new observation. + obs_buffer_.index({ + torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs_, torch::indexing::None) + }) = new_obs; +} + +torch::Tensor ObservationBuffer::getObsVec(const std::vector &obs_ids) const { + std::vector obs; + for (int i = obs_ids.size() - 1; i >= 0; --i) { + const int obs_id = obs_ids[i]; + const int slice_idx = include_history_steps_ - obs_id - 1; + obs.push_back(obs_buffer_.index({ + torch::indexing::Slice(torch::indexing::None), + torch::indexing::Slice(slice_idx * num_obs_, (slice_idx + 1) * num_obs_) + })); + } + return cat(obs, -1); +} diff --git a/controllers/legged_gym_controller/src/common/ObservationBuffer.h b/controllers/legged_gym_controller/src/common/ObservationBuffer.h new file mode 100644 index 0000000..ebb5581 --- /dev/null +++ b/controllers/legged_gym_controller/src/common/ObservationBuffer.h @@ -0,0 +1,32 @@ +// +// Created by biao on 24-10-6. +// + +#ifndef OBSERVATIONBUFFER_H +#define OBSERVATIONBUFFER_H + +#include +#include + +class ObservationBuffer { +public: + ObservationBuffer(int num_envs, int num_obs, int include_history_steps); + + ~ObservationBuffer() = default; + + void reset(const std::vector& reset_index, const torch::Tensor &new_obs); + + void insert(const torch::Tensor &new_obs); + + [[nodiscard]] torch::Tensor getObsVec(const std::vector &obs_ids) const; + +private: + int num_envs_; + int num_obs_; + int include_history_steps_; + int num_obs_total_; + torch::Tensor obs_buffer_; +}; + + +#endif //OBSERVATIONBUFFER_H