modified observation buffer
This commit is contained in:
parent
f8b3efdbdb
commit
bc8b3ded39
|
@ -33,9 +33,12 @@ endforeach ()
|
|||
add_library(${PROJECT_NAME} SHARED
|
||||
src/LeggedGymController.cpp
|
||||
|
||||
src/common/ObservationBuffer.cpp
|
||||
|
||||
src/FSM/StatePassive.cpp
|
||||
src/FSM/StateFixedStand.cpp
|
||||
src/FSM/StateFixedDown.cpp
|
||||
src/FSM/StateRL.cpp
|
||||
)
|
||||
target_include_directories(${PROJECT_NAME}
|
||||
PUBLIC
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
//
|
||||
// Created by biao on 24-10-6.
|
||||
//
|
||||
|
||||
#ifndef STATERL_H
|
||||
#define STATERL_H
|
||||
|
||||
#include <common/ObservationBuffer.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
#include "FSMState.h"
|
||||
|
||||
|
||||
class StateRL final : public FSMState {
|
||||
public:
|
||||
explicit StateRL(CtrlComponent &ctrl_component);
|
||||
|
||||
void enter() override;
|
||||
|
||||
void run() override;
|
||||
|
||||
void exit() override;
|
||||
|
||||
FSMStateName checkChange() override;
|
||||
|
||||
private:
|
||||
torch::Tensor computeObservation();
|
||||
|
||||
/**
|
||||
* @brief Forward the RL model to get the action
|
||||
*/
|
||||
torch::Tensor forward();
|
||||
|
||||
// Parameters
|
||||
double linear_vel_scale_;
|
||||
double angular_vel_scale_;
|
||||
double clip_obs_;
|
||||
torch::Tensor clip_actions_upper_;
|
||||
torch::Tensor clip_actions_lower_;
|
||||
bool use_history_;
|
||||
|
||||
// history buffer
|
||||
std::shared_ptr<ObservationBuffer> history_obs_buf_;
|
||||
torch::Tensor history_obs_;
|
||||
|
||||
// rl module
|
||||
torch::jit::script::Module model;
|
||||
// output buffer
|
||||
torch::Tensor output_torques;
|
||||
torch::Tensor output_dof_pos;
|
||||
};
|
||||
|
||||
|
||||
#endif //STATERL_H
|
|
@ -11,6 +11,7 @@ enum class FSMStateName {
|
|||
PASSIVE,
|
||||
FIXEDDOWN,
|
||||
FIXEDSTAND,
|
||||
RL,
|
||||
};
|
||||
|
||||
enum class FSMMode {
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
#include "observation_buffer.hpp"
|
||||
|
||||
ObservationBuffer::ObservationBuffer() {}
|
||||
|
||||
ObservationBuffer::ObservationBuffer(int num_envs,
|
||||
int num_obs,
|
||||
int include_history_steps)
|
||||
: num_envs(num_envs),
|
||||
num_obs(num_obs),
|
||||
include_history_steps(include_history_steps)
|
||||
{
|
||||
num_obs_total = num_obs * include_history_steps;
|
||||
obs_buf = torch::zeros({num_envs, num_obs_total}, torch::dtype(torch::kFloat32));
|
||||
}
|
||||
|
||||
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
|
||||
{
|
||||
std::vector<torch::indexing::TensorIndex> indices;
|
||||
for (int idx : reset_idxs) {
|
||||
indices.push_back(torch::indexing::Slice(idx));
|
||||
}
|
||||
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));
|
||||
}
|
||||
|
||||
void ObservationBuffer::insert(torch::Tensor new_obs)
|
||||
{
|
||||
// Shift observations back.
|
||||
torch::Tensor shifted_obs = obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(num_obs, num_obs * include_history_steps)}).clone();
|
||||
obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(0, num_obs * (include_history_steps - 1))}) = shifted_obs;
|
||||
|
||||
// Add new observation.
|
||||
obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs, torch::indexing::None)}) = new_obs;
|
||||
}
|
||||
|
||||
torch::Tensor ObservationBuffer::get_obs_vec(std::vector<int> obs_ids)
|
||||
{
|
||||
std::vector<torch::Tensor> obs;
|
||||
for (int i = obs_ids.size() - 1; i >= 0; --i)
|
||||
{
|
||||
int obs_id = obs_ids[i];
|
||||
int slice_idx = include_history_steps - obs_id - 1;
|
||||
obs.push_back(obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(slice_idx * num_obs, (slice_idx + 1) * num_obs)}));
|
||||
}
|
||||
return cat(obs, -1);
|
||||
}
|
|
@ -1,24 +0,0 @@
|
|||
#ifndef OBSERVATION_BUFFER_HPP
|
||||
#define OBSERVATION_BUFFER_HPP
|
||||
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
class ObservationBuffer {
|
||||
public:
|
||||
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
||||
ObservationBuffer();
|
||||
|
||||
void reset(std::vector<int> reset_idxs, torch::Tensor new_obs);
|
||||
void insert(torch::Tensor new_obs);
|
||||
torch::Tensor get_obs_vec(std::vector<int> obs_ids);
|
||||
|
||||
private:
|
||||
int num_envs;
|
||||
int num_obs;
|
||||
int include_history_steps;
|
||||
int num_obs_total;
|
||||
torch::Tensor obs_buf;
|
||||
};
|
||||
|
||||
#endif // OBSERVATION_BUFFER_HPP
|
|
@ -0,0 +1,56 @@
|
|||
//
|
||||
// Created by biao on 24-10-6.
|
||||
//
|
||||
|
||||
#include "legged_gym_controller/FSM/StateRL.h"
|
||||
|
||||
StateRL::StateRL(CtrlComponent &ctrl_component) : FSMState(
|
||||
FSMStateName::RL, "rl", ctrl_component) {
|
||||
}
|
||||
|
||||
void StateRL::enter() {
|
||||
}
|
||||
|
||||
void StateRL::run() {
|
||||
}
|
||||
|
||||
void StateRL::exit() {
|
||||
}
|
||||
|
||||
FSMStateName StateRL::checkChange() {
|
||||
switch (ctrl_comp_.control_inputs_.command) {
|
||||
case 1:
|
||||
return FSMStateName::PASSIVE;
|
||||
case 2:
|
||||
return FSMStateName::FIXEDDOWN;
|
||||
default:
|
||||
return FSMStateName::RL;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor StateRL::computeObservation() {
|
||||
std::vector<torch::Tensor> obs_list;
|
||||
|
||||
const torch::Tensor obs = cat(obs_list, 1);
|
||||
torch::Tensor clamped_obs = clamp(obs, -clip_obs_, clip_obs_);
|
||||
return clamped_obs;
|
||||
}
|
||||
|
||||
torch::Tensor StateRL::forward() {
|
||||
torch::autograd::GradMode::set_enabled(false);
|
||||
torch::Tensor clamped_obs = computeObservation();
|
||||
torch::Tensor actions;
|
||||
|
||||
if (use_history_) {
|
||||
history_obs_buf_->insert(clamped_obs);
|
||||
history_obs_ = history_obs_buf_->getObsVec({0, 1, 2, 3, 4, 5});
|
||||
actions = model.forward({history_obs_}).toTensor();
|
||||
} else {
|
||||
actions = model.forward({clamped_obs}).toTensor();
|
||||
}
|
||||
|
||||
if (clip_actions_upper_.numel() != 0 && clip_actions_lower_.numel() != 0) {
|
||||
return clamp(actions, clip_actions_lower_, clip_actions_upper_);
|
||||
}
|
||||
return actions;
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
//
|
||||
// Created by biao on 24-10-6.
|
||||
//
|
||||
|
||||
#include "ObservationBuffer.h"
|
||||
|
||||
ObservationBuffer::ObservationBuffer(int num_envs,
|
||||
const int num_obs,
|
||||
const int include_history_steps)
|
||||
: num_envs_(num_envs),
|
||||
num_obs_(num_obs),
|
||||
include_history_steps_(include_history_steps) {
|
||||
num_obs_total_ = num_obs * include_history_steps;
|
||||
obs_buffer_ = torch::zeros({num_envs, num_obs_total_}, dtype(torch::kFloat32));
|
||||
}
|
||||
|
||||
void ObservationBuffer::reset(const std::vector<int> &reset_index, const torch::Tensor &new_obs) {
|
||||
std::vector<torch::indexing::TensorIndex> indices;
|
||||
for (int index: reset_index) {
|
||||
indices.emplace_back(torch::indexing::Slice(index));
|
||||
}
|
||||
obs_buffer_.index_put_(indices, new_obs.repeat({1, include_history_steps_}));
|
||||
}
|
||||
|
||||
void ObservationBuffer::insert(const torch::Tensor &new_obs) {
|
||||
// Shift observations back.
|
||||
const torch::Tensor shifted_obs = obs_buffer_.index({
|
||||
torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(num_obs_, num_obs_ * include_history_steps_)
|
||||
}).clone();
|
||||
obs_buffer_.index({
|
||||
torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(0, num_obs_ * (include_history_steps_ - 1))
|
||||
}) = shifted_obs;
|
||||
|
||||
// Add new observation.
|
||||
obs_buffer_.index({
|
||||
torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs_, torch::indexing::None)
|
||||
}) = new_obs;
|
||||
}
|
||||
|
||||
torch::Tensor ObservationBuffer::getObsVec(const std::vector<int> &obs_ids) const {
|
||||
std::vector<torch::Tensor> obs;
|
||||
for (int i = obs_ids.size() - 1; i >= 0; --i) {
|
||||
const int obs_id = obs_ids[i];
|
||||
const int slice_idx = include_history_steps_ - obs_id - 1;
|
||||
obs.push_back(obs_buffer_.index({
|
||||
torch::indexing::Slice(torch::indexing::None),
|
||||
torch::indexing::Slice(slice_idx * num_obs_, (slice_idx + 1) * num_obs_)
|
||||
}));
|
||||
}
|
||||
return cat(obs, -1);
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
//
|
||||
// Created by biao on 24-10-6.
|
||||
//
|
||||
|
||||
#ifndef OBSERVATIONBUFFER_H
|
||||
#define OBSERVATIONBUFFER_H
|
||||
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
class ObservationBuffer {
|
||||
public:
|
||||
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
||||
|
||||
~ObservationBuffer() = default;
|
||||
|
||||
void reset(const std::vector<int>& reset_index, const torch::Tensor &new_obs);
|
||||
|
||||
void insert(const torch::Tensor &new_obs);
|
||||
|
||||
[[nodiscard]] torch::Tensor getObsVec(const std::vector<int> &obs_ids) const;
|
||||
|
||||
private:
|
||||
int num_envs_;
|
||||
int num_obs_;
|
||||
int include_history_steps_;
|
||||
int num_obs_total_;
|
||||
torch::Tensor obs_buffer_;
|
||||
};
|
||||
|
||||
|
||||
#endif //OBSERVATIONBUFFER_H
|
Loading…
Reference in New Issue