modified observation buffer

This commit is contained in:
Zhenbiao Huang 2024-10-06 20:19:36 +08:00
parent f8b3efdbdb
commit bc8b3ded39
8 changed files with 197 additions and 69 deletions

View File

@ -33,9 +33,12 @@ endforeach ()
# Shared library target for this package: the controller, the observation
# buffer helper and one source file per FSM state.
add_library(${PROJECT_NAME} SHARED
src/LeggedGymController.cpp
# Common helpers
src/common/ObservationBuffer.cpp
# Finite-state-machine states
src/FSM/StatePassive.cpp
src/FSM/StateFixedStand.cpp
src/FSM/StateFixedDown.cpp
src/FSM/StateRL.cpp
)
target_include_directories(${PROJECT_NAME}
PUBLIC

View File

@ -0,0 +1,54 @@
//
// Created by biao on 24-10-6.
//
#ifndef STATERL_H
#define STATERL_H
#include <common/ObservationBuffer.h>
#include <torch/script.h>
#include "FSMState.h"
/**
 * @brief FSM state that owns a TorchScript module and the helpers used to
 *        build an observation tensor and evaluate the module on it.
 *
 * Scalar parameters carry default member initializers so that reading them
 * before they have been configured is well defined. The constructor does not
 * assign them, so the real values are expected to be set elsewhere.
 */
class StateRL final : public FSMState {
public:
    explicit StateRL(CtrlComponent &ctrl_component);

    void enter() override;

    void run() override;

    void exit() override;

    FSMStateName checkChange() override;

private:
    /**
     * @brief Build the observation tensor and clamp it to [-clip_obs_, clip_obs_].
     */
    torch::Tensor computeObservation();

    /**
     * @brief Forward the RL model to get the action
     */
    torch::Tensor forward();

    // Parameters
    // Zero / false are placeholders only; they make unconfigured reads deterministic.
    double linear_vel_scale_{0.0};
    double angular_vel_scale_{0.0};
    double clip_obs_{0.0};
    // Optional action limits; forward() skips clamping while either tensor is empty.
    torch::Tensor clip_actions_upper_;
    torch::Tensor clip_actions_lower_;
    bool use_history_{false};

    // history buffer
    std::shared_ptr<ObservationBuffer> history_obs_buf_;
    torch::Tensor history_obs_;

    // rl module
    torch::jit::script::Module model;

    // output buffer
    torch::Tensor output_torques;
    torch::Tensor output_dof_pos;
};
#endif //STATERL_H

View File

@ -11,6 +11,7 @@ enum class FSMStateName {
PASSIVE,
FIXEDDOWN,
FIXEDSTAND,
RL,
};
enum class FSMMode {

View File

@ -1,45 +0,0 @@
#include "observation_buffer.hpp"
// Default constructor: zero every size field so a default-constructed buffer
// never exposes indeterminate values; obs_buf stays an undefined tensor.
ObservationBuffer::ObservationBuffer()
    : num_envs(0),
      num_obs(0),
      include_history_steps(0),
      num_obs_total(0) {}
ObservationBuffer::ObservationBuffer(int num_envs,
int num_obs,
int include_history_steps)
: num_envs(num_envs),
num_obs(num_obs),
include_history_steps(include_history_steps)
{
num_obs_total = num_obs * include_history_steps;
obs_buf = torch::zeros({num_envs, num_obs_total}, torch::dtype(torch::kFloat32));
}
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
{
std::vector<torch::indexing::TensorIndex> indices;
for (int idx : reset_idxs) {
indices.push_back(torch::indexing::Slice(idx));
}
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));
}
// Pushes `new_obs` into the buffer: every stored step moves one slot towards
// the front (dropping the first slot) and `new_obs` lands in the last slot.
void ObservationBuffer::insert(torch::Tensor new_obs)
{
    const int total_width = num_obs * include_history_steps;
    const int kept_width = num_obs * (include_history_steps - 1);

    // Clone first: source and destination column ranges overlap.
    const torch::Tensor tail = obs_buf.slice(1, num_obs, total_width).clone();
    obs_buf.slice(1, 0, kept_width).copy_(tail);

    // The freshest observation occupies the trailing num_obs columns.
    obs_buf.slice(1, -num_obs).copy_(new_obs);
}
// Gathers the requested history steps and concatenates them along the last
// axis. An id of 0 addresses the last slot of the buffer (the one insert()
// writes to); larger ids address slots further towards the front. The ids are
// visited back to front, so the output order is the reverse of `obs_ids`.
torch::Tensor ObservationBuffer::get_obs_vec(std::vector<int> obs_ids)
{
    using torch::indexing::None;
    using torch::indexing::Slice;

    std::vector<torch::Tensor> pieces;
    for (auto it = obs_ids.rbegin(); it != obs_ids.rend(); ++it)
    {
        const int slot = include_history_steps - *it - 1;
        const int first_col = slot * num_obs;
        const int last_col = (slot + 1) * num_obs;
        pieces.push_back(obs_buf.index({Slice(None), Slice(first_col, last_col)}));
    }
    return torch::cat(pieces, -1);
}

View File

@ -1,24 +0,0 @@
#ifndef OBSERVATION_BUFFER_HPP
#define OBSERVATION_BUFFER_HPP
#include <torch/torch.h>
#include <vector>
// Fixed-length history of observations stored side by side in one 2-D tensor.
class ObservationBuffer {
public:
// Allocates a zero-filled float32 buffer of shape
// (num_envs, num_obs * include_history_steps).
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
// Default constructor (defined out of line); performs no allocation.
ObservationBuffer();
// Writes new_obs, repeated include_history_steps times along dim 1, into the
// buffer at the positions derived from reset_idxs.
void reset(std::vector<int> reset_idxs, torch::Tensor new_obs);
// Shifts the stored history by one slot and writes new_obs into the last slot.
void insert(torch::Tensor new_obs);
// Returns the slots selected by obs_ids concatenated along the last axis;
// id 0 is the last (most recently written) slot.
torch::Tensor get_obs_vec(std::vector<int> obs_ids);
private:
int num_envs;               // number of rows in obs_buf
int num_obs;                // width of a single observation
int include_history_steps;  // number of observation slots per row
int num_obs_total;          // num_obs * include_history_steps
torch::Tensor obs_buf;      // the history storage itself
};
#endif // OBSERVATION_BUFFER_HPP

View File

@ -0,0 +1,56 @@
//
// Created by biao on 24-10-6.
//
#include "legged_gym_controller/FSM/StateRL.h"
// Forwards FSMStateName::RL, the label "rl" and the control component to the
// FSMState base constructor. No StateRL member is assigned here.
StateRL::StateRL(CtrlComponent &ctrl_component) : FSMState(
FSMStateName::RL, "rl", ctrl_component) {
}
// Placeholder: the body is empty, so nothing is set up here yet.
void StateRL::enter() {
}
// Placeholder: the body is empty; forward() is not called from here yet.
void StateRL::run() {
}
// Placeholder: the body is empty, so nothing is torn down here yet.
void StateRL::exit() {
}
// Maps the current control command to the next FSM state:
// command 1 -> PASSIVE, command 2 -> FIXEDDOWN, anything else stays in RL.
FSMStateName StateRL::checkChange() {
    const auto &requested = ctrl_comp_.control_inputs_.command;
    if (requested == 1) {
        return FSMStateName::PASSIVE;
    }
    if (requested == 2) {
        return FSMStateName::FIXEDDOWN;
    }
    return FSMStateName::RL;
}
// Concatenates the observation terms along dim 1 and clamps the result to
// [-clip_obs_, clip_obs_].
// NOTE(review): no term is ever appended to the list in this function, so
// cat() is handed an empty list as written.
torch::Tensor StateRL::computeObservation() {
    const std::vector<torch::Tensor> terms{};
    const torch::Tensor stacked = torch::cat(terms, 1);
    return torch::clamp(stacked, -clip_obs_, clip_obs_);
}
// Runs the TorchScript module on the current observation (or on the stacked
// history when use_history_ is set) and returns the resulting actions, clamped
// to the configured limits when both limit tensors are non-empty.
torch::Tensor StateRL::forward() {
    // Gradient tracking is switched off; it is not re-enabled on return.
    torch::autograd::GradMode::set_enabled(false);

    const torch::Tensor current_obs = computeObservation();

    // Pick what the policy sees: the single clamped frame or the history stack.
    torch::Tensor policy_input = current_obs;
    if (use_history_) {
        history_obs_buf_->insert(current_obs);
        history_obs_ = history_obs_buf_->getObsVec({0, 1, 2, 3, 4, 5});
        policy_input = history_obs_;
    }

    const torch::Tensor actions = model.forward({policy_input}).toTensor();

    // Clamp only when both bounds were provided.
    const bool has_limits = clip_actions_upper_.numel() != 0 && clip_actions_lower_.numel() != 0;
    if (!has_limits) {
        return actions;
    }
    return torch::clamp(actions, clip_actions_lower_, clip_actions_upper_);
}

View File

@ -0,0 +1,51 @@
//
// Created by biao on 24-10-6.
//
#include "ObservationBuffer.h"
ObservationBuffer::ObservationBuffer(int num_envs,
const int num_obs,
const int include_history_steps)
: num_envs_(num_envs),
num_obs_(num_obs),
include_history_steps_(include_history_steps) {
num_obs_total_ = num_obs * include_history_steps;
obs_buffer_ = torch::zeros({num_envs, num_obs_total_}, dtype(torch::kFloat32));
}
/**
 * @brief Overwrite the full history of selected rows with one observation.
 *
 * Every history slot of each row listed in @p reset_index is filled with
 * @p new_obs (tiled include_history_steps_ times along the feature axis).
 *
 * Rows are selected with an int64 index tensor. One Slice(i) per entry would
 * instead mean "rows i..end" and would address one *dimension* per entry,
 * which is not a row selection.
 *
 * @param reset_index Row indices to reset; an empty list is a no-op.
 * @param new_obs     Replacement observation(s); after tiling it must be
 *                    broadcastable to (reset_index.size(), num_obs_total_).
 */
void ObservationBuffer::reset(const std::vector<int> &reset_index, const torch::Tensor &new_obs) {
    // Nothing selected, nothing to overwrite.
    if (reset_index.empty()) {
        return;
    }

    const std::vector<int64_t> rows(reset_index.begin(), reset_index.end());
    const torch::Tensor row_index = torch::tensor(
        rows, torch::TensorOptions().dtype(torch::kLong).device(obs_buffer_.device()));

    std::vector<torch::indexing::TensorIndex> indices;
    indices.emplace_back(row_index);
    obs_buffer_.index_put_(indices, new_obs.repeat({1, include_history_steps_}));
}
// Pushes `new_obs` into the buffer: every stored step moves one slot towards
// the front (dropping the first slot) and `new_obs` lands in the last slot.
void ObservationBuffer::insert(const torch::Tensor &new_obs) {
    const int total_width = num_obs_ * include_history_steps_;
    const int kept_width = num_obs_ * (include_history_steps_ - 1);

    // Clone first: source and destination column ranges overlap.
    const torch::Tensor tail = obs_buffer_.slice(1, num_obs_, total_width).clone();
    obs_buffer_.slice(1, 0, kept_width).copy_(tail);

    // The freshest observation occupies the trailing num_obs_ columns.
    obs_buffer_.slice(1, -num_obs_).copy_(new_obs);
}
/**
 * @brief Gather selected history steps into one tensor.
 *
 * An id of 0 addresses the last slot of the buffer (the one insert() writes
 * to); larger ids address slots further towards the front. The ids are visited
 * back to front, so the output order is the reverse of @p obs_ids.
 *
 * @param obs_ids History step ids, each in [0, include_history_steps_).
 * @return The selected slots concatenated along the last axis.
 * @throws c10::Error if @p obs_ids is empty or contains an id outside the
 *         stored history. Without this check an out-of-range id yields a
 *         negative slot whose slice silently wraps onto a different frame.
 */
torch::Tensor ObservationBuffer::getObsVec(const std::vector<int> &obs_ids) const {
    TORCH_CHECK(!obs_ids.empty(), "ObservationBuffer::getObsVec: obs_ids must not be empty");

    std::vector<torch::Tensor> frames;
    frames.reserve(obs_ids.size());

    // Reverse iterators avoid the signed/unsigned round trip of `size() - 1`.
    for (auto it = obs_ids.rbegin(); it != obs_ids.rend(); ++it) {
        const int obs_id = *it;
        TORCH_CHECK(obs_id >= 0 && obs_id < include_history_steps_,
                    "ObservationBuffer::getObsVec: obs_id ", obs_id,
                    " is outside the stored history [0, ", include_history_steps_, ")");

        const int slot = include_history_steps_ - obs_id - 1;
        frames.push_back(obs_buffer_.index({
            torch::indexing::Slice(torch::indexing::None),
            torch::indexing::Slice(slot * num_obs_, (slot + 1) * num_obs_)
        }));
    }
    return torch::cat(frames, -1);
}

View File

@ -0,0 +1,32 @@
//
// Created by biao on 24-10-6.
//
#ifndef OBSERVATIONBUFFER_H
#define OBSERVATIONBUFFER_H
#include <torch/torch.h>
#include <vector>
/**
 * @brief Fixed-length history of observations stored side by side in one
 *        2-D tensor of shape (num_envs, num_obs * include_history_steps).
 *
 * The class owns no raw resource, so it declares no special member functions
 * (Rule of Zero). A user-declared destructor would suppress the implicit move
 * operations.
 */
class ObservationBuffer {
public:
    /**
     * @brief Allocate a zero-filled float32 buffer.
     * @param num_envs              Number of rows.
     * @param num_obs               Width of a single observation.
     * @param include_history_steps Number of observation slots per row.
     */
    ObservationBuffer(int num_envs, int num_obs, int include_history_steps);

    /**
     * @brief Write @p new_obs, repeated include_history_steps times along
     *        dim 1, into the buffer at the positions derived from @p reset_index.
     */
    void reset(const std::vector<int>& reset_index, const torch::Tensor &new_obs);

    /**
     * @brief Shift the stored history by one slot and write @p new_obs into
     *        the last slot.
     */
    void insert(const torch::Tensor &new_obs);

    /**
     * @brief Return the slots selected by @p obs_ids concatenated along the
     *        last axis; id 0 is the last (most recently written) slot.
     */
    [[nodiscard]] torch::Tensor getObsVec(const std::vector<int> &obs_ids) const;

private:
    int num_envs_;               // number of rows in obs_buffer_
    int num_obs_;                // width of a single observation
    int include_history_steps_;  // number of observation slots per row
    int num_obs_total_;          // num_obs_ * include_history_steps_
    torch::Tensor obs_buffer_;   // the history storage itself
};
#endif //OBSERVATIONBUFFER_H