modified observation buffer
This commit is contained in:
parent
f8b3efdbdb
commit
bc8b3ded39
|
@ -33,9 +33,12 @@ endforeach ()
|
||||||
add_library(${PROJECT_NAME} SHARED
|
add_library(${PROJECT_NAME} SHARED
|
||||||
src/LeggedGymController.cpp
|
src/LeggedGymController.cpp
|
||||||
|
|
||||||
|
src/common/ObservationBuffer.cpp
|
||||||
|
|
||||||
src/FSM/StatePassive.cpp
|
src/FSM/StatePassive.cpp
|
||||||
src/FSM/StateFixedStand.cpp
|
src/FSM/StateFixedStand.cpp
|
||||||
src/FSM/StateFixedDown.cpp
|
src/FSM/StateFixedDown.cpp
|
||||||
|
src/FSM/StateRL.cpp
|
||||||
)
|
)
|
||||||
target_include_directories(${PROJECT_NAME}
|
target_include_directories(${PROJECT_NAME}
|
||||||
PUBLIC
|
PUBLIC
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
//
|
||||||
|
// Created by biao on 24-10-6.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef STATERL_H
|
||||||
|
#define STATERL_H
|
||||||
|
|
||||||
|
#include <limits>
#include <memory>

#include <common/ObservationBuffer.h>
#include <torch/script.h>

#include "FSMState.h"
|
||||||
|
|
||||||
|
|
||||||
|
class StateRL final : public FSMState {
|
||||||
|
public:
|
||||||
|
explicit StateRL(CtrlComponent &ctrl_component);
|
||||||
|
|
||||||
|
void enter() override;
|
||||||
|
|
||||||
|
void run() override;
|
||||||
|
|
||||||
|
void exit() override;
|
||||||
|
|
||||||
|
FSMStateName checkChange() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
torch::Tensor computeObservation();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Forward the RL model to get the action
|
||||||
|
*/
|
||||||
|
torch::Tensor forward();
|
||||||
|
|
||||||
|
// Parameters
|
||||||
|
double linear_vel_scale_;
|
||||||
|
double angular_vel_scale_;
|
||||||
|
double clip_obs_;
|
||||||
|
torch::Tensor clip_actions_upper_;
|
||||||
|
torch::Tensor clip_actions_lower_;
|
||||||
|
bool use_history_;
|
||||||
|
|
||||||
|
// history buffer
|
||||||
|
std::shared_ptr<ObservationBuffer> history_obs_buf_;
|
||||||
|
torch::Tensor history_obs_;
|
||||||
|
|
||||||
|
// rl module
|
||||||
|
torch::jit::script::Module model;
|
||||||
|
// output buffer
|
||||||
|
torch::Tensor output_torques;
|
||||||
|
torch::Tensor output_dof_pos;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif //STATERL_H
|
|
@ -11,6 +11,7 @@ enum class FSMStateName {
|
||||||
PASSIVE,
|
PASSIVE,
|
||||||
FIXEDDOWN,
|
FIXEDDOWN,
|
||||||
FIXEDSTAND,
|
FIXEDSTAND,
|
||||||
|
RL,
|
||||||
};
|
};
|
||||||
|
|
||||||
enum class FSMMode {
|
enum class FSMMode {
|
||||||
|
|
|
@ -1,45 +0,0 @@
|
||||||
#include "observation_buffer.hpp"
|
|
||||||
|
|
||||||
// Default-construct an empty buffer: every dimension is zero and obs_buf is
// left as an undefined tensor, so no member is ever read uninitialised.
ObservationBuffer::ObservationBuffer()
    : num_envs(0),
      num_obs(0),
      include_history_steps(0),
      num_obs_total(0) {}
|
|
||||||
|
|
||||||
ObservationBuffer::ObservationBuffer(int num_envs,
|
|
||||||
int num_obs,
|
|
||||||
int include_history_steps)
|
|
||||||
: num_envs(num_envs),
|
|
||||||
num_obs(num_obs),
|
|
||||||
include_history_steps(include_history_steps)
|
|
||||||
{
|
|
||||||
num_obs_total = num_obs * include_history_steps;
|
|
||||||
obs_buf = torch::zeros({num_envs, num_obs_total}, torch::dtype(torch::kFloat32));
|
|
||||||
}
|
|
||||||
|
|
||||||
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
|
|
||||||
{
|
|
||||||
std::vector<torch::indexing::TensorIndex> indices;
|
|
||||||
for (int idx : reset_idxs) {
|
|
||||||
indices.push_back(torch::indexing::Slice(idx));
|
|
||||||
}
|
|
||||||
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Push new_obs into the history: every stored frame moves num_obs columns
// towards column 0 (dropping the frame at the front) and new_obs is written
// into the last num_obs columns.
void ObservationBuffer::insert(torch::Tensor new_obs)
{
    using torch::indexing::None;
    using torch::indexing::Slice;

    const int history_width = num_obs * include_history_steps;

    // Source and destination column ranges overlap, so copy from a snapshot.
    const torch::Tensor newer_frames =
        obs_buf.index({Slice(None), Slice(num_obs, history_width)}).clone();
    obs_buf.index({Slice(None), Slice(0, history_width - num_obs)}).copy_(newer_frames);

    // The newest frame goes at the end of each row.
    obs_buf.index({Slice(None), Slice(-num_obs, None)}).copy_(new_obs);
}
|
|
||||||
|
|
||||||
// Gather the frames named by obs_ids and concatenate them along the last dim.
//
// Id k maps to the frame starting at column
// (include_history_steps - k - 1) * num_obs, so id 0 is the last (most recently
// inserted) frame. obs_ids is walked back to front, so the output holds the
// frames in the reverse of the order given.
torch::Tensor ObservationBuffer::get_obs_vec(std::vector<int> obs_ids)
{
    using torch::indexing::None;
    using torch::indexing::Slice;

    std::vector<torch::Tensor> frames;
    frames.reserve(obs_ids.size());
    for (auto id = obs_ids.rbegin(); id != obs_ids.rend(); ++id)
    {
        const int first_col = (include_history_steps - *id - 1) * num_obs;
        frames.push_back(obs_buf.index({Slice(None), Slice(first_col, first_col + num_obs)}));
    }
    return torch::cat(frames, -1);
}
|
|
|
@ -1,24 +0,0 @@
|
||||||
#ifndef OBSERVATION_BUFFER_HPP
|
|
||||||
#define OBSERVATION_BUFFER_HPP
|
|
||||||
|
|
||||||
#include <torch/torch.h>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
class ObservationBuffer {
|
|
||||||
public:
|
|
||||||
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
|
||||||
ObservationBuffer();
|
|
||||||
|
|
||||||
void reset(std::vector<int> reset_idxs, torch::Tensor new_obs);
|
|
||||||
void insert(torch::Tensor new_obs);
|
|
||||||
torch::Tensor get_obs_vec(std::vector<int> obs_ids);
|
|
||||||
|
|
||||||
private:
|
|
||||||
int num_envs;
|
|
||||||
int num_obs;
|
|
||||||
int include_history_steps;
|
|
||||||
int num_obs_total;
|
|
||||||
torch::Tensor obs_buf;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // OBSERVATION_BUFFER_HPP
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
//
|
||||||
|
// Created by biao on 24-10-6.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "legged_gym_controller/FSM/StateRL.h"
|
||||||
|
|
||||||
|
// Register this state with the FSM under FSMStateName::RL and the label "rl".
StateRL::StateRL(CtrlComponent &ctrl_component)
    : FSMState(FSMStateName::RL, "rl", ctrl_component) {
}
|
||||||
|
|
||||||
|
void StateRL::enter() {
    // No entry action is implemented yet.
}
|
||||||
|
|
||||||
|
void StateRL::run() {
    // No per-cycle action is implemented yet.
}
|
||||||
|
|
||||||
|
void StateRL::exit() {
    // No exit action is implemented yet.
}
|
||||||
|
|
||||||
|
// Choose the next FSM state from the current control command:
// 1 -> PASSIVE, 2 -> FIXEDDOWN, anything else stays in RL.
FSMStateName StateRL::checkChange() {
    const auto &command = ctrl_comp_.control_inputs_.command;

    if (command == 1) {
        return FSMStateName::PASSIVE;
    }
    if (command == 2) {
        return FSMStateName::FIXEDDOWN;
    }
    return FSMStateName::RL;
}
|
||||||
|
|
||||||
|
// Concatenate the observation terms along dim 1 and clamp the result to
// [-clip_obs_, clip_obs_].
// NOTE(review): no terms are appended to the list yet, so the concatenation
// runs on an empty list; presumably the terms are still to be added. Confirm.
torch::Tensor StateRL::computeObservation() {
    std::vector<torch::Tensor> obs_terms;

    return torch::clamp(torch::cat(obs_terms, 1), -clip_obs_, clip_obs_);
}
|
||||||
|
|
||||||
|
/**
 * @brief Forward the RL model to get the action
 *
 * Computes the clamped observation, optionally pushes it through the history
 * buffer, evaluates the script module and clamps the resulting actions when
 * both action bounds are non-empty.
 *
 * @return The action tensor produced by the model (clamped if bounds are set).
 * @throws c10::Error if use_history_ is set but no history buffer exists.
 */
torch::Tensor StateRL::forward() {
    // Inference only: disable autograd for the duration of this call. The
    // guard restores the previous grad mode on return instead of switching it
    // off for the whole thread.
    const torch::NoGradGuard no_grad;

    const torch::Tensor clamped_obs = computeObservation();
    torch::Tensor actions;

    if (use_history_) {
        TORCH_CHECK(history_obs_buf_ != nullptr,
                    "StateRL::forward: use_history_ is set but history_obs_buf_ is null");
        history_obs_buf_->insert(clamped_obs);
        // NOTE(review): requests history frames 0..5, so the buffer presumably
        // has to hold at least six frames. Confirm against its construction.
        history_obs_ = history_obs_buf_->getObsVec({0, 1, 2, 3, 4, 5});
        actions = model.forward({history_obs_}).toTensor();
    } else {
        actions = model.forward({clamped_obs}).toTensor();
    }

    // Action clipping is optional: it is skipped unless both bounds hold data.
    if (clip_actions_upper_.numel() != 0 && clip_actions_lower_.numel() != 0) {
        return torch::clamp(actions, clip_actions_lower_, clip_actions_upper_);
    }
    return actions;
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
//
|
||||||
|
// Created by biao on 24-10-6.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "ObservationBuffer.h"
|
||||||
|
|
||||||
|
ObservationBuffer::ObservationBuffer(int num_envs,
|
||||||
|
const int num_obs,
|
||||||
|
const int include_history_steps)
|
||||||
|
: num_envs_(num_envs),
|
||||||
|
num_obs_(num_obs),
|
||||||
|
include_history_steps_(include_history_steps) {
|
||||||
|
num_obs_total_ = num_obs * include_history_steps;
|
||||||
|
obs_buffer_ = torch::zeros({num_envs, num_obs_total_}, dtype(torch::kFloat32));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overwrite the history of the rows listed in reset_index.
//
// Each selected row is filled with new_obs repeated include_history_steps_
// times along dim 1, i.e. every history slot of that row holds new_obs.
// new_obs must be broadcastable to (reset_index.size(), num_obs_).
//
// The row ids are passed as a single integer index tensor. Using one Slice per
// id would index successive dimensions with "id:" instead of selecting rows.
void ObservationBuffer::reset(const std::vector<int> &reset_index, const torch::Tensor &new_obs) {
    // Nothing selected: leave the buffer untouched.
    if (reset_index.empty()) {
        return;
    }

    const std::vector<int64_t> rows(reset_index.begin(), reset_index.end());
    std::vector<torch::indexing::TensorIndex> indices;
    indices.emplace_back(torch::tensor(rows, torch::kLong));

    obs_buffer_.index_put_(indices, new_obs.repeat({1, include_history_steps_}));
}
|
||||||
|
|
||||||
|
// Push new_obs into the history: every stored frame moves num_obs_ columns
// towards column 0 (dropping the frame at the front) and new_obs is written
// into the last num_obs_ columns.
void ObservationBuffer::insert(const torch::Tensor &new_obs) {
    using torch::indexing::None;
    using torch::indexing::Slice;

    const int history_width = num_obs_ * include_history_steps_;

    // Source and destination column ranges overlap, so copy from a snapshot.
    const torch::Tensor newer_frames =
        obs_buffer_.index({Slice(None), Slice(num_obs_, history_width)}).clone();
    obs_buffer_.index({Slice(None), Slice(0, history_width - num_obs_)}).copy_(newer_frames);

    // The newest frame goes at the end of each row.
    obs_buffer_.index({Slice(None), Slice(-num_obs_, None)}).copy_(new_obs);
}
|
||||||
|
|
||||||
|
// Gather the frames named by obs_ids and concatenate them along the last dim.
//
// Id k maps to the frame starting at column
// (include_history_steps_ - k - 1) * num_obs_, so id 0 is the last (most
// recently inserted) frame. obs_ids is walked back to front, so the output
// holds the frames in the reverse of the order given.
torch::Tensor ObservationBuffer::getObsVec(const std::vector<int> &obs_ids) const {
    using torch::indexing::None;
    using torch::indexing::Slice;

    std::vector<torch::Tensor> frames;
    frames.reserve(obs_ids.size());
    for (auto id = obs_ids.rbegin(); id != obs_ids.rend(); ++id) {
        const int first_col = (include_history_steps_ - *id - 1) * num_obs_;
        frames.push_back(obs_buffer_.index({Slice(None), Slice(first_col, first_col + num_obs_)}));
    }
    return torch::cat(frames, -1);
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
//
|
||||||
|
// Created by biao on 24-10-6.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef OBSERVATIONBUFFER_H
|
||||||
|
#define OBSERVATIONBUFFER_H
|
||||||
|
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
class ObservationBuffer {
|
||||||
|
public:
|
||||||
|
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
||||||
|
|
||||||
|
~ObservationBuffer() = default;
|
||||||
|
|
||||||
|
void reset(const std::vector<int>& reset_index, const torch::Tensor &new_obs);
|
||||||
|
|
||||||
|
void insert(const torch::Tensor &new_obs);
|
||||||
|
|
||||||
|
[[nodiscard]] torch::Tensor getObsVec(const std::vector<int> &obs_ids) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int num_envs_;
|
||||||
|
int num_obs_;
|
||||||
|
int include_history_steps_;
|
||||||
|
int num_obs_total_;
|
||||||
|
torch::Tensor obs_buffer_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif //OBSERVATIONBUFFER_H
|
Loading…
Reference in New Issue