can run if remove pytorch support.
This commit is contained in:
parent
b064925c4b
commit
f8b3efdbdb
|
@ -1,4 +1,5 @@
|
||||||
cmake_minimum_required(VERSION 3.8)
|
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
project(legged_gym_controller)
|
project(legged_gym_controller)
|
||||||
|
|
||||||
if (CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
if (CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||||
|
@ -10,25 +11,11 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
find_package(ament_cmake REQUIRED)
|
find_package(ament_cmake REQUIRED)
|
||||||
|
|
||||||
# rl_sdk library
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
|
||||||
add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}")
|
|
||||||
|
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
||||||
|
|
||||||
add_library(rl_sdk library/rl_sdk/rl_sdk.cpp)
|
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
||||||
target_include_directories(rl_sdk
|
|
||||||
PUBLIC
|
|
||||||
library/rl_sdk)
|
|
||||||
target_link_libraries(rl_sdk "${TORCH_LIBRARIES}" Python3::Python Python3::Module)
|
|
||||||
set_property(TARGET rl_sdk PROPERTY CXX_STANDARD 14)
|
|
||||||
find_package(Python3 COMPONENTS NumPy)
|
find_package(Python3 COMPONENTS NumPy)
|
||||||
if (Python3_NumPy_FOUND)
|
|
||||||
target_link_libraries(rl_sdk Python3::NumPy)
|
|
||||||
else ()
|
|
||||||
target_compile_definitions(rl_sdk WITHOUT_NUMPY)
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
set(dependencies
|
set(dependencies
|
||||||
pluginlib
|
pluginlib
|
||||||
|
@ -43,6 +30,44 @@ foreach (Dependency IN ITEMS ${dependencies})
|
||||||
find_package(${Dependency} REQUIRED)
|
find_package(${Dependency} REQUIRED)
|
||||||
endforeach ()
|
endforeach ()
|
||||||
|
|
||||||
|
add_library(${PROJECT_NAME} SHARED
|
||||||
|
src/LeggedGymController.cpp
|
||||||
|
|
||||||
|
src/FSM/StatePassive.cpp
|
||||||
|
src/FSM/StateFixedStand.cpp
|
||||||
|
src/FSM/StateFixedDown.cpp
|
||||||
|
)
|
||||||
|
target_include_directories(${PROJECT_NAME}
|
||||||
|
PUBLIC
|
||||||
|
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>"
|
||||||
|
"$<INSTALL_INTERFACE:include/${PROJECT_NAME}>"
|
||||||
|
PRIVATE
|
||||||
|
src)
|
||||||
|
target_link_libraries(${PROJECT_NAME} PUBLIC
|
||||||
|
"${TORCH_LIBRARIES}"
|
||||||
|
)
|
||||||
|
ament_target_dependencies(
|
||||||
|
${PROJECT_NAME} PUBLIC
|
||||||
|
${dependencies}
|
||||||
|
)
|
||||||
|
|
||||||
|
pluginlib_export_plugin_description_file(controller_interface legged_gym_controller.xml)
|
||||||
|
|
||||||
|
install(
|
||||||
|
DIRECTORY include/
|
||||||
|
DESTINATION include/${PROJECT_NAME}
|
||||||
|
)
|
||||||
|
|
||||||
|
install(
|
||||||
|
TARGETS ${PROJECT_NAME}
|
||||||
|
EXPORT export_${PROJECT_NAME}
|
||||||
|
ARCHIVE DESTINATION lib/${PROJECT_NAME}
|
||||||
|
LIBRARY DESTINATION lib/${PROJECT_NAME}
|
||||||
|
RUNTIME DESTINATION bin
|
||||||
|
)
|
||||||
|
|
||||||
|
ament_export_dependencies(${dependencies})
|
||||||
|
ament_export_targets(export_${PROJECT_NAME} HAS_LIBRARY_TARGET)
|
||||||
|
|
||||||
if (BUILD_TESTING)
|
if (BUILD_TESTING)
|
||||||
find_package(ament_lint_auto REQUIRED)
|
find_package(ament_lint_auto REQUIRED)
|
||||||
|
|
|
@ -8,13 +8,12 @@ Tested environment:
|
||||||
|
|
||||||
|
|
||||||
## 2. Build
|
## 2. Build
|
||||||
|
|
||||||
### 2.1 Installing libtorch
|
### 2.1 Installing libtorch
|
||||||
```bash
|
```bash
|
||||||
cd ~/CLionProjects/
|
cd ~/CLionProjects/
|
||||||
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcpu.zip
|
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
|
||||||
unzip libtorch-cxx11-abi-shared-with-deps-2.0.1+cpu.zip -d ./
|
unzip libtorch-shared-with-deps-latest.zip
|
||||||
rm libtorch-cxx11-abi-shared-with-deps-2.0.1+cpu.zip
|
rm -rf libtorch-shared-with-deps-latest.zip
|
||||||
echo 'export Torch_DIR=~/CLionProjects/libtorch' >> ~/.bashrc
|
echo 'export Torch_DIR=~/CLionProjects/libtorch' >> ~/.bashrc
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -23,3 +22,10 @@ echo 'export Torch_DIR=~/CLionProjects/libtorch' >> ~/.bashrc
|
||||||
cd ~/ros2_ws
|
cd ~/ros2_ws
|
||||||
colcon build --packages-up-to legged_gym_controller
|
colcon build --packages-up-to legged_gym_controller
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 3. Launch
|
||||||
|
* Unitree A1 Robot
|
||||||
|
```bash
|
||||||
|
source ~/ros2_ws/install/setup.bash
|
||||||
|
ros2 launch a1_description rl_control.launch.py
|
||||||
|
```
|
|
@ -0,0 +1,40 @@
|
||||||
|
//
|
||||||
|
// Created by tlab-uav on 24-9-6.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef FSMSTATE_H
|
||||||
|
#define FSMSTATE_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <rclcpp/time.hpp>
|
||||||
|
|
||||||
|
#include "legged_gym_controller/common/enumClass.h"
|
||||||
|
#include "legged_gym_controller/control/CtrlComponent.h"
|
||||||
|
|
||||||
|
class FSMState {
|
||||||
|
public:
|
||||||
|
virtual ~FSMState() = default;
|
||||||
|
|
||||||
|
FSMState(const FSMStateName &state_name, std::string state_name_string, CtrlComponent &ctrl_comp)
|
||||||
|
: state_name(state_name),
|
||||||
|
state_name_string(std::move(state_name_string)),
|
||||||
|
ctrl_comp_(ctrl_comp) {
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void enter() = 0;
|
||||||
|
|
||||||
|
virtual void run() = 0;
|
||||||
|
|
||||||
|
virtual void exit() = 0;
|
||||||
|
|
||||||
|
virtual FSMStateName checkChange() { return FSMStateName::INVALID; }
|
||||||
|
|
||||||
|
FSMStateName state_name;
|
||||||
|
std::string state_name_string;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
CtrlComponent &ctrl_comp_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif //FSMSTATE_H
|
|
@ -0,0 +1,37 @@
|
||||||
|
//
|
||||||
|
// Created by tlab-uav on 24-9-11.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef STATEFIXEDDOWN_H
|
||||||
|
#define STATEFIXEDDOWN_H
|
||||||
|
|
||||||
|
#include "FSMState.h"
|
||||||
|
|
||||||
|
class StateFixedDown final : public FSMState {
|
||||||
|
public:
|
||||||
|
explicit StateFixedDown(CtrlComponent &ctrlComp);
|
||||||
|
|
||||||
|
void enter() override;
|
||||||
|
|
||||||
|
void run() override;
|
||||||
|
|
||||||
|
void exit() override;
|
||||||
|
|
||||||
|
FSMStateName checkChange() override;
|
||||||
|
private:
|
||||||
|
double target_pos_[12] = {
|
||||||
|
0.0473455, 1.22187, -2.44375, -0.0473455,
|
||||||
|
1.22187, -2.44375, 0.0473455, 1.22187,
|
||||||
|
-2.44375, -0.0473455, 1.22187, -2.44375
|
||||||
|
};
|
||||||
|
|
||||||
|
double start_pos_[12] = {};
|
||||||
|
rclcpp::Time start_time_;
|
||||||
|
|
||||||
|
double duration_ = 600; // steps
|
||||||
|
double percent_ = 0; //%
|
||||||
|
double phase = 0.0;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif //STATEFIXEDDOWN_H
|
|
@ -0,0 +1,39 @@
|
||||||
|
//
|
||||||
|
// Created by biao on 24-9-10.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef STATEFIXEDSTAND_H
|
||||||
|
#define STATEFIXEDSTAND_H
|
||||||
|
|
||||||
|
#include "FSMState.h"
|
||||||
|
|
||||||
|
class StateFixedStand final : public FSMState {
|
||||||
|
public:
|
||||||
|
explicit StateFixedStand(CtrlComponent &ctrlComp);
|
||||||
|
|
||||||
|
void enter() override;
|
||||||
|
|
||||||
|
void run() override;
|
||||||
|
|
||||||
|
void exit() override;
|
||||||
|
|
||||||
|
FSMStateName checkChange() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
double target_pos_[12] = {
|
||||||
|
0.00571868, 0.608813, -1.21763,
|
||||||
|
-0.00571868, 0.608813, -1.21763,
|
||||||
|
0.00571868, 0.608813, -1.21763,
|
||||||
|
-0.00571868, 0.608813, -1.21763
|
||||||
|
};
|
||||||
|
|
||||||
|
double start_pos_[12] = {};
|
||||||
|
rclcpp::Time start_time_;
|
||||||
|
|
||||||
|
double duration_ = 600; // steps
|
||||||
|
double percent_ = 0; //%
|
||||||
|
double phase = 0.0;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif //STATEFIXEDSTAND_H
|
|
@ -0,0 +1,24 @@
|
||||||
|
//
|
||||||
|
// Created by tlab-uav on 24-9-6.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef STATEPASSIVE_H
|
||||||
|
#define STATEPASSIVE_H
|
||||||
|
#include "FSMState.h"
|
||||||
|
|
||||||
|
|
||||||
|
class StatePassive final : public FSMState {
|
||||||
|
public:
|
||||||
|
explicit StatePassive(CtrlComponent &ctrlComp);
|
||||||
|
|
||||||
|
void enter() override;
|
||||||
|
|
||||||
|
void run() override;
|
||||||
|
|
||||||
|
void exit() override;
|
||||||
|
|
||||||
|
FSMStateName checkChange() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif //STATEPASSIVE_H
|
|
@ -0,0 +1,21 @@
|
||||||
|
//
|
||||||
|
// Created by tlab-uav on 24-9-6.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef ENUMCLASS_H
|
||||||
|
#define ENUMCLASS_H
|
||||||
|
|
||||||
|
enum class FSMStateName {
|
||||||
|
// EXIT,
|
||||||
|
INVALID,
|
||||||
|
PASSIVE,
|
||||||
|
FIXEDDOWN,
|
||||||
|
FIXEDSTAND,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class FSMMode {
|
||||||
|
NORMAL,
|
||||||
|
CHANGE
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif //ENUMCLASS_H
|
|
@ -0,0 +1,61 @@
|
||||||
|
//
|
||||||
|
// Created by biao on 24-9-10.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef INTERFACE_H
|
||||||
|
#define INTERFACE_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <hardware_interface/loaned_command_interface.hpp>
|
||||||
|
#include <hardware_interface/loaned_state_interface.hpp>
|
||||||
|
#include <control_input_msgs/msg/inputs.hpp>
|
||||||
|
|
||||||
|
struct CtrlComponent {
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedCommandInterface> >
|
||||||
|
joint_torque_command_interface_;
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedCommandInterface> >
|
||||||
|
joint_position_command_interface_;
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedCommandInterface> >
|
||||||
|
joint_velocity_command_interface_;
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedCommandInterface> >
|
||||||
|
joint_kp_command_interface_;
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedCommandInterface> >
|
||||||
|
joint_kd_command_interface_;
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedStateInterface> >
|
||||||
|
joint_effort_state_interface_;
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedStateInterface> >
|
||||||
|
joint_position_state_interface_;
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedStateInterface> >
|
||||||
|
joint_velocity_state_interface_;
|
||||||
|
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedStateInterface> >
|
||||||
|
imu_state_interface_;
|
||||||
|
|
||||||
|
std::vector<std::reference_wrapper<hardware_interface::LoanedStateInterface> >
|
||||||
|
foot_force_state_interface_;
|
||||||
|
|
||||||
|
control_input_msgs::msg::Inputs control_inputs_;
|
||||||
|
int frequency_{};
|
||||||
|
|
||||||
|
CtrlComponent() {
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
joint_torque_command_interface_.clear();
|
||||||
|
joint_position_command_interface_.clear();
|
||||||
|
joint_velocity_command_interface_.clear();
|
||||||
|
joint_kd_command_interface_.clear();
|
||||||
|
joint_kp_command_interface_.clear();
|
||||||
|
|
||||||
|
joint_effort_state_interface_.clear();
|
||||||
|
joint_position_state_interface_.clear();
|
||||||
|
joint_velocity_state_interface_.clear();
|
||||||
|
|
||||||
|
imu_state_interface_.clear();
|
||||||
|
foot_force_state_interface_.clear();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif //INTERFACE_H
|
|
@ -0,0 +1,9 @@
|
||||||
|
<library path="legged_gym_controller">
|
||||||
|
<class name="legged_gym_controller/LeggedGymController"
|
||||||
|
type="legged_gym_controller::LeggedGymController"
|
||||||
|
base_class_type="controller_interface::ControllerInterface">
|
||||||
|
<description>
|
||||||
|
Quadruped Controller used RL-model trained in Legged Gym.
|
||||||
|
</description>
|
||||||
|
</class>
|
||||||
|
</library>
|
|
@ -0,0 +1,45 @@
|
||||||
|
#include "observation_buffer.hpp"
|
||||||
|
|
||||||
|
ObservationBuffer::ObservationBuffer() {}
|
||||||
|
|
||||||
|
ObservationBuffer::ObservationBuffer(int num_envs,
|
||||||
|
int num_obs,
|
||||||
|
int include_history_steps)
|
||||||
|
: num_envs(num_envs),
|
||||||
|
num_obs(num_obs),
|
||||||
|
include_history_steps(include_history_steps)
|
||||||
|
{
|
||||||
|
num_obs_total = num_obs * include_history_steps;
|
||||||
|
obs_buf = torch::zeros({num_envs, num_obs_total}, torch::dtype(torch::kFloat32));
|
||||||
|
}
|
||||||
|
|
||||||
|
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
|
||||||
|
{
|
||||||
|
std::vector<torch::indexing::TensorIndex> indices;
|
||||||
|
for (int idx : reset_idxs) {
|
||||||
|
indices.push_back(torch::indexing::Slice(idx));
|
||||||
|
}
|
||||||
|
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));
|
||||||
|
}
|
||||||
|
|
||||||
|
void ObservationBuffer::insert(torch::Tensor new_obs)
|
||||||
|
{
|
||||||
|
// Shift observations back.
|
||||||
|
torch::Tensor shifted_obs = obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(num_obs, num_obs * include_history_steps)}).clone();
|
||||||
|
obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(0, num_obs * (include_history_steps - 1))}) = shifted_obs;
|
||||||
|
|
||||||
|
// Add new observation.
|
||||||
|
obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs, torch::indexing::None)}) = new_obs;
|
||||||
|
}
|
||||||
|
|
||||||
|
torch::Tensor ObservationBuffer::get_obs_vec(std::vector<int> obs_ids)
|
||||||
|
{
|
||||||
|
std::vector<torch::Tensor> obs;
|
||||||
|
for (int i = obs_ids.size() - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
int obs_id = obs_ids[i];
|
||||||
|
int slice_idx = include_history_steps - obs_id - 1;
|
||||||
|
obs.push_back(obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(slice_idx * num_obs, (slice_idx + 1) * num_obs)}));
|
||||||
|
}
|
||||||
|
return cat(obs, -1);
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
#ifndef OBSERVATION_BUFFER_HPP
|
||||||
|
#define OBSERVATION_BUFFER_HPP
|
||||||
|
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
class ObservationBuffer {
|
||||||
|
public:
|
||||||
|
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
||||||
|
ObservationBuffer();
|
||||||
|
|
||||||
|
void reset(std::vector<int> reset_idxs, torch::Tensor new_obs);
|
||||||
|
void insert(torch::Tensor new_obs);
|
||||||
|
torch::Tensor get_obs_vec(std::vector<int> obs_ids);
|
||||||
|
|
||||||
|
private:
|
||||||
|
int num_envs;
|
||||||
|
int num_obs;
|
||||||
|
int include_history_steps;
|
||||||
|
int num_obs_total;
|
||||||
|
torch::Tensor obs_buf;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // OBSERVATION_BUFFER_HPP
|
|
@ -374,55 +374,55 @@ std::vector<T> ReadVectorFromYaml(const YAML::Node& node, const std::string& fra
|
||||||
|
|
||||||
void RL::ReadYaml(std::string robot_name)
|
void RL::ReadYaml(std::string robot_name)
|
||||||
{
|
{
|
||||||
// The config file is located at "rl_sar/src/rl_sar/models/<robot_name>/config.yaml"
|
// // The config file is located at "rl_sar/src/rl_sar/models/<robot_name>/config.yaml"
|
||||||
std::string config_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/config.yaml";
|
// std::string config_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/config.yaml";
|
||||||
YAML::Node config;
|
// YAML::Node config;
|
||||||
try
|
// try
|
||||||
{
|
// {
|
||||||
config = YAML::LoadFile(config_path)[robot_name];
|
// config = YAML::LoadFile(config_path)[robot_name];
|
||||||
} catch(YAML::BadFile &e)
|
// } catch(YAML::BadFile &e)
|
||||||
{
|
// {
|
||||||
std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
|
// std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
|
||||||
return;
|
// return;
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
this->params.model_name = config["model_name"].as<std::string>();
|
// this->params.model_name = config["model_name"].as<std::string>();
|
||||||
this->params.framework = config["framework"].as<std::string>();
|
// this->params.framework = config["framework"].as<std::string>();
|
||||||
int rows = config["rows"].as<int>();
|
// int rows = config["rows"].as<int>();
|
||||||
int cols = config["cols"].as<int>();
|
// int cols = config["cols"].as<int>();
|
||||||
this->params.use_history = config["use_history"].as<bool>();
|
// this->params.use_history = config["use_history"].as<bool>();
|
||||||
this->params.dt = config["dt"].as<double>();
|
// this->params.dt = config["dt"].as<double>();
|
||||||
this->params.decimation = config["decimation"].as<int>();
|
// this->params.decimation = config["decimation"].as<int>();
|
||||||
this->params.num_observations = config["num_observations"].as<int>();
|
// this->params.num_observations = config["num_observations"].as<int>();
|
||||||
this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]);
|
// this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]);
|
||||||
this->params.clip_obs = config["clip_obs"].as<double>();
|
// this->params.clip_obs = config["clip_obs"].as<double>();
|
||||||
if(config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull())
|
// if(config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull())
|
||||||
{
|
// {
|
||||||
this->params.clip_actions_upper = torch::tensor({}).view({1, -1});
|
// this->params.clip_actions_upper = torch::tensor({}).view({1, -1});
|
||||||
this->params.clip_actions_lower = torch::tensor({}).view({1, -1});
|
// this->params.clip_actions_lower = torch::tensor({}).view({1, -1});
|
||||||
}
|
// }
|
||||||
else
|
// else
|
||||||
{
|
// {
|
||||||
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1});
|
// this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1});
|
// this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
}
|
// }
|
||||||
this->params.action_scale = config["action_scale"].as<double>();
|
// this->params.action_scale = config["action_scale"].as<double>();
|
||||||
this->params.hip_scale_reduction = config["hip_scale_reduction"].as<double>();
|
// this->params.hip_scale_reduction = config["hip_scale_reduction"].as<double>();
|
||||||
this->params.hip_scale_reduction_indices = ReadVectorFromYaml<int>(config["hip_scale_reduction_indices"]);
|
// this->params.hip_scale_reduction_indices = ReadVectorFromYaml<int>(config["hip_scale_reduction_indices"]);
|
||||||
this->params.num_of_dofs = config["num_of_dofs"].as<int>();
|
// this->params.num_of_dofs = config["num_of_dofs"].as<int>();
|
||||||
this->params.lin_vel_scale = config["lin_vel_scale"].as<double>();
|
// this->params.lin_vel_scale = config["lin_vel_scale"].as<double>();
|
||||||
this->params.ang_vel_scale = config["ang_vel_scale"].as<double>();
|
// this->params.ang_vel_scale = config["ang_vel_scale"].as<double>();
|
||||||
this->params.dof_pos_scale = config["dof_pos_scale"].as<double>();
|
// this->params.dof_pos_scale = config["dof_pos_scale"].as<double>();
|
||||||
this->params.dof_vel_scale = config["dof_vel_scale"].as<double>();
|
// this->params.dof_vel_scale = config["dof_vel_scale"].as<double>();
|
||||||
// this->params.commands_scale = torch::tensor(ReadVectorFromYaml<double>(config["commands_scale"])).view({1, -1});
|
// // this->params.commands_scale = torch::tensor(ReadVectorFromYaml<double>(config["commands_scale"])).view({1, -1});
|
||||||
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
|
// this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
|
||||||
this->params.rl_kp = torch::tensor(ReadVectorFromYaml<double>(config["rl_kp"], this->params.framework, rows, cols)).view({1, -1});
|
// this->params.rl_kp = torch::tensor(ReadVectorFromYaml<double>(config["rl_kp"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
this->params.rl_kd = torch::tensor(ReadVectorFromYaml<double>(config["rl_kd"], this->params.framework, rows, cols)).view({1, -1});
|
// this->params.rl_kd = torch::tensor(ReadVectorFromYaml<double>(config["rl_kd"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
this->params.fixed_kp = torch::tensor(ReadVectorFromYaml<double>(config["fixed_kp"], this->params.framework, rows, cols)).view({1, -1});
|
// this->params.fixed_kp = torch::tensor(ReadVectorFromYaml<double>(config["fixed_kp"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
this->params.fixed_kd = torch::tensor(ReadVectorFromYaml<double>(config["fixed_kd"], this->params.framework, rows, cols)).view({1, -1});
|
// this->params.fixed_kd = torch::tensor(ReadVectorFromYaml<double>(config["fixed_kd"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
this->params.torque_limits = torch::tensor(ReadVectorFromYaml<double>(config["torque_limits"], this->params.framework, rows, cols)).view({1, -1});
|
// this->params.torque_limits = torch::tensor(ReadVectorFromYaml<double>(config["torque_limits"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml<double>(config["default_dof_pos"], this->params.framework, rows, cols)).view({1, -1});
|
// this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml<double>(config["default_dof_pos"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
this->params.joint_controller_names = ReadVectorFromYaml<std::string>(config["joint_controller_names"], this->params.framework, rows, cols);
|
// this->params.joint_controller_names = ReadVectorFromYaml<std::string>(config["joint_controller_names"], this->params.framework, rows, cols);
|
||||||
}
|
}
|
||||||
|
|
||||||
void RL::CSVInit(std::string robot_name)
|
void RL::CSVInit(std::string robot_name)
|
||||||
|
|
|
@ -2,11 +2,7 @@
|
||||||
#define RL_SDK_HPP
|
#define RL_SDK_HPP
|
||||||
|
|
||||||
#include <torch/script.h>
|
#include <torch/script.h>
|
||||||
#include <iostream>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unistd.h>
|
|
||||||
|
|
||||||
#include <yaml-cpp/yaml.h>
|
|
||||||
|
|
||||||
namespace LOGGER {
|
namespace LOGGER {
|
||||||
const char* const INFO = "\033[0;37m[INFO]\033[0m ";
|
const char* const INFO = "\033[0;37m[INFO]\033[0m ";
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
//
|
||||||
|
// Created by tlab-uav on 24-9-11.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "legged_gym_controller/FSM/StateFixedDown.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
StateFixedDown::StateFixedDown(CtrlComponent &ctrlComp): FSMState(
|
||||||
|
FSMStateName::FIXEDDOWN, "fixed down", ctrlComp) {
|
||||||
|
duration_ = ctrl_comp_.frequency_ * 1.2;
|
||||||
|
}
|
||||||
|
|
||||||
|
void StateFixedDown::enter() {
|
||||||
|
for (int i = 0; i < 12; i++) {
|
||||||
|
start_pos_[i] = ctrl_comp_.joint_position_state_interface_[i].get().get_value();
|
||||||
|
}
|
||||||
|
ctrl_comp_.control_inputs_.command = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void StateFixedDown::run() {
|
||||||
|
percent_ += 1 / duration_;
|
||||||
|
phase = std::tanh(percent_);
|
||||||
|
for (int i = 0; i < 12; i++) {
|
||||||
|
ctrl_comp_.joint_position_command_interface_[i].get().set_value(
|
||||||
|
phase * target_pos_[i] + (1 - phase) * start_pos_[i]);
|
||||||
|
ctrl_comp_.joint_velocity_command_interface_[i].get().set_value(0);
|
||||||
|
ctrl_comp_.joint_torque_command_interface_[i].get().set_value(0);
|
||||||
|
ctrl_comp_.joint_kp_command_interface_[i].get().set_value(30.0);
|
||||||
|
ctrl_comp_.joint_kd_command_interface_[i].get().set_value(1.5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void StateFixedDown::exit() {
|
||||||
|
percent_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
FSMStateName StateFixedDown::checkChange() {
|
||||||
|
if (percent_ < 1.5) {
|
||||||
|
return FSMStateName::FIXEDDOWN;
|
||||||
|
}
|
||||||
|
switch (ctrl_comp_.control_inputs_.command) {
|
||||||
|
case 1:
|
||||||
|
return FSMStateName::PASSIVE;
|
||||||
|
case 2:
|
||||||
|
return FSMStateName::FIXEDSTAND;
|
||||||
|
default:
|
||||||
|
return FSMStateName::FIXEDDOWN;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
//
|
||||||
|
// Created by biao on 24-9-10.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "legged_gym_controller/FSM/StateFixedStand.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
StateFixedStand::StateFixedStand(CtrlComponent &ctrlComp): FSMState(
|
||||||
|
FSMStateName::FIXEDSTAND, "fixed stand", ctrlComp) {
|
||||||
|
duration_ = ctrl_comp_.frequency_ * 1.2;
|
||||||
|
}
|
||||||
|
|
||||||
|
void StateFixedStand::enter() {
|
||||||
|
for (int i = 0; i < 12; i++) {
|
||||||
|
start_pos_[i] = ctrl_comp_.joint_position_state_interface_[i].get().get_value();
|
||||||
|
}
|
||||||
|
ctrl_comp_.control_inputs_.command = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void StateFixedStand::run() {
|
||||||
|
percent_ += 1 / duration_;
|
||||||
|
phase = std::tanh(percent_);
|
||||||
|
for (int i = 0; i < 12; i++) {
|
||||||
|
ctrl_comp_.joint_position_command_interface_[i].get().set_value(
|
||||||
|
phase * target_pos_[i] + (1 - phase) * start_pos_[i]);
|
||||||
|
ctrl_comp_.joint_velocity_command_interface_[i].get().set_value(0);
|
||||||
|
ctrl_comp_.joint_torque_command_interface_[i].get().set_value(0);
|
||||||
|
ctrl_comp_.joint_kp_command_interface_[i].get().set_value(
|
||||||
|
phase * 60.0 + (1 - phase) * 20.0);
|
||||||
|
ctrl_comp_.joint_kd_command_interface_[i].get().set_value(3.5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void StateFixedStand::exit() {
|
||||||
|
percent_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
FSMStateName StateFixedStand::checkChange() {
|
||||||
|
if (percent_ < 1.5) {
|
||||||
|
return FSMStateName::FIXEDSTAND;
|
||||||
|
}
|
||||||
|
switch (ctrl_comp_.control_inputs_.command) {
|
||||||
|
case 1:
|
||||||
|
return FSMStateName::PASSIVE;
|
||||||
|
case 2:
|
||||||
|
return FSMStateName::FIXEDDOWN;
|
||||||
|
default:
|
||||||
|
return FSMStateName::FIXEDSTAND;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
//
|
||||||
|
// Created by tlab-uav on 24-9-6.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "legged_gym_controller/FSM/StatePassive.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
StatePassive::StatePassive(CtrlComponent &ctrlComp) : FSMState(
|
||||||
|
FSMStateName::PASSIVE, "passive", ctrlComp) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void StatePassive::enter() {
|
||||||
|
for (auto i: ctrl_comp_.joint_torque_command_interface_) {
|
||||||
|
i.get().set_value(0);
|
||||||
|
}
|
||||||
|
for (auto i: ctrl_comp_.joint_position_command_interface_) {
|
||||||
|
i.get().set_value(0);
|
||||||
|
}
|
||||||
|
for (auto i: ctrl_comp_.joint_velocity_command_interface_) {
|
||||||
|
i.get().set_value(0);
|
||||||
|
}
|
||||||
|
for (auto i: ctrl_comp_.joint_kp_command_interface_) {
|
||||||
|
i.get().set_value(0);
|
||||||
|
}
|
||||||
|
for (auto i: ctrl_comp_.joint_kd_command_interface_) {
|
||||||
|
i.get().set_value(1);
|
||||||
|
}
|
||||||
|
ctrl_comp_.control_inputs_.command = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void StatePassive::run() {
|
||||||
|
}
|
||||||
|
|
||||||
|
void StatePassive::exit() {
|
||||||
|
}
|
||||||
|
|
||||||
|
FSMStateName StatePassive::checkChange() {
|
||||||
|
if (ctrl_comp_.control_inputs_.command == 2) {
|
||||||
|
return FSMStateName::FIXEDDOWN;
|
||||||
|
}
|
||||||
|
return FSMStateName::PASSIVE;
|
||||||
|
}
|
|
@ -0,0 +1,182 @@
|
||||||
|
//
|
||||||
|
// Created by tlab-uav on 24-10-4.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "LeggedGymController.h"
|
||||||
|
|
||||||
|
namespace legged_gym_controller {
|
||||||
|
using config_type = controller_interface::interface_configuration_type;
|
||||||
|
|
||||||
|
controller_interface::InterfaceConfiguration LeggedGymController::command_interface_configuration() const {
|
||||||
|
controller_interface::InterfaceConfiguration conf = {config_type::INDIVIDUAL, {}};
|
||||||
|
|
||||||
|
conf.names.reserve(joint_names_.size() * command_interface_types_.size());
|
||||||
|
for (const auto &joint_name: joint_names_) {
|
||||||
|
for (const auto &interface_type: command_interface_types_) {
|
||||||
|
conf.names.push_back(joint_name + "/" += interface_type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
controller_interface::InterfaceConfiguration LeggedGymController::state_interface_configuration() const {
|
||||||
|
controller_interface::InterfaceConfiguration conf = {config_type::INDIVIDUAL, {}};
|
||||||
|
|
||||||
|
conf.names.reserve(joint_names_.size() * state_interface_types_.size());
|
||||||
|
for (const auto &joint_name: joint_names_) {
|
||||||
|
for (const auto &interface_type: state_interface_types_) {
|
||||||
|
conf.names.push_back(joint_name + "/" += interface_type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto &interface_type: imu_interface_types_) {
|
||||||
|
conf.names.push_back(imu_name_ + "/" += interface_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto &interface_type: foot_force_interface_types_) {
|
||||||
|
conf.names.push_back(foot_force_name_ + "/" += interface_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
controller_interface::return_type LeggedGymController::
|
||||||
|
update(const rclcpp::Time & /*time*/, const rclcpp::Duration & /*period*/) {
|
||||||
|
if (mode_ == FSMMode::NORMAL) {
|
||||||
|
current_state_->run();
|
||||||
|
next_state_name_ = current_state_->checkChange();
|
||||||
|
if (next_state_name_ != current_state_->state_name) {
|
||||||
|
mode_ = FSMMode::CHANGE;
|
||||||
|
next_state_ = getNextState(next_state_name_);
|
||||||
|
RCLCPP_INFO(get_node()->get_logger(), "Switched from %s to %s",
|
||||||
|
current_state_->state_name_string.c_str(), next_state_->state_name_string.c_str());
|
||||||
|
}
|
||||||
|
} else if (mode_ == FSMMode::CHANGE) {
|
||||||
|
current_state_->exit();
|
||||||
|
current_state_ = next_state_;
|
||||||
|
|
||||||
|
current_state_->enter();
|
||||||
|
mode_ = FSMMode::NORMAL;
|
||||||
|
}
|
||||||
|
|
||||||
|
return controller_interface::return_type::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
controller_interface::CallbackReturn LeggedGymController::on_init() {
|
||||||
|
try {
|
||||||
|
joint_names_ = auto_declare<std::vector<std::string> >("joints", joint_names_);
|
||||||
|
feet_names_ = auto_declare<std::vector<std::string> >("feet_names", feet_names_);
|
||||||
|
command_interface_types_ =
|
||||||
|
auto_declare<std::vector<std::string> >("command_interfaces", command_interface_types_);
|
||||||
|
state_interface_types_ =
|
||||||
|
auto_declare<std::vector<std::string> >("state_interfaces", state_interface_types_);
|
||||||
|
|
||||||
|
// imu sensor
|
||||||
|
imu_name_ = auto_declare<std::string>("imu_name", imu_name_);
|
||||||
|
imu_interface_types_ = auto_declare<std::vector<std::string> >("imu_interfaces", state_interface_types_);
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
fprintf(stderr, "Exception thrown during init stage with message: %s \n", e.what());
|
||||||
|
return controller_interface::CallbackReturn::ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
return CallbackReturn::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
controller_interface::CallbackReturn LeggedGymController::on_configure(
|
||||||
|
const rclcpp_lifecycle::State & /*previous_state*/) {
|
||||||
|
control_input_subscription_ = get_node()->create_subscription<control_input_msgs::msg::Inputs>(
|
||||||
|
"/control_input", 10, [this](const control_input_msgs::msg::Inputs::SharedPtr msg) {
|
||||||
|
// Handle message
|
||||||
|
ctrl_comp_.control_inputs_.command = msg->command;
|
||||||
|
ctrl_comp_.control_inputs_.lx = msg->lx;
|
||||||
|
ctrl_comp_.control_inputs_.ly = msg->ly;
|
||||||
|
ctrl_comp_.control_inputs_.rx = msg->rx;
|
||||||
|
ctrl_comp_.control_inputs_.ry = msg->ry;
|
||||||
|
});
|
||||||
|
|
||||||
|
get_node()->get_parameter("update_rate", ctrl_comp_.frequency_);
|
||||||
|
RCLCPP_INFO(get_node()->get_logger(), "Controller Manager Update Rate: %d Hz", ctrl_comp_.frequency_);
|
||||||
|
|
||||||
|
return CallbackReturn::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
controller_interface::CallbackReturn LeggedGymController::on_activate(
|
||||||
|
const rclcpp_lifecycle::State & /*previous_state*/) {
|
||||||
|
// clear out vectors in case of restart
|
||||||
|
ctrl_comp_.clear();
|
||||||
|
|
||||||
|
// assign command interfaces
|
||||||
|
for (auto &interface: command_interfaces_) {
|
||||||
|
std::string interface_name = interface.get_interface_name();
|
||||||
|
if (const size_t pos = interface_name.find('/'); pos != std::string::npos) {
|
||||||
|
command_interface_map_[interface_name.substr(pos + 1)]->push_back(interface);
|
||||||
|
} else {
|
||||||
|
command_interface_map_[interface_name]->push_back(interface);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign state interfaces
|
||||||
|
for (auto &interface: state_interfaces_) {
|
||||||
|
if (interface.get_prefix_name() == imu_name_) {
|
||||||
|
ctrl_comp_.imu_state_interface_.emplace_back(interface);
|
||||||
|
} else if (interface.get_prefix_name() == foot_force_name_) {
|
||||||
|
ctrl_comp_.foot_force_state_interface_.emplace_back(interface);
|
||||||
|
} else {
|
||||||
|
state_interface_map_[interface.get_interface_name()]->push_back(interface);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create FSM List
|
||||||
|
state_list_.passive = std::make_shared<StatePassive>(ctrl_comp_);
|
||||||
|
state_list_.fixedDown = std::make_shared<StateFixedDown>(ctrl_comp_);
|
||||||
|
state_list_.fixedStand = std::make_shared<StateFixedStand>(ctrl_comp_);
|
||||||
|
|
||||||
|
// Initialize FSM
|
||||||
|
current_state_ = state_list_.passive;
|
||||||
|
current_state_->enter();
|
||||||
|
next_state_ = current_state_;
|
||||||
|
next_state_name_ = current_state_->state_name;
|
||||||
|
mode_ = FSMMode::NORMAL;
|
||||||
|
|
||||||
|
return CallbackReturn::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
controller_interface::CallbackReturn LeggedGymController::on_deactivate(
|
||||||
|
const rclcpp_lifecycle::State & /*previous_state*/) {
|
||||||
|
release_interfaces();
|
||||||
|
return CallbackReturn::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
controller_interface::CallbackReturn
|
||||||
|
LeggedGymController::on_cleanup(const rclcpp_lifecycle::State &previous_state) {
|
||||||
|
return ControllerInterface::on_cleanup(previous_state);
|
||||||
|
}
|
||||||
|
|
||||||
|
controller_interface::CallbackReturn
|
||||||
|
LeggedGymController::on_shutdown(const rclcpp_lifecycle::State &previous_state) {
|
||||||
|
return ControllerInterface::on_shutdown(previous_state);
|
||||||
|
}
|
||||||
|
|
||||||
|
controller_interface::CallbackReturn LeggedGymController::on_error(const rclcpp_lifecycle::State &previous_state) {
|
||||||
|
return ControllerInterface::on_error(previous_state);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<FSMState> LeggedGymController::getNextState(FSMStateName stateName) const {
|
||||||
|
switch (stateName) {
|
||||||
|
case FSMStateName::INVALID:
|
||||||
|
return state_list_.invalid;
|
||||||
|
case FSMStateName::PASSIVE:
|
||||||
|
return state_list_.passive;
|
||||||
|
case FSMStateName::FIXEDDOWN:
|
||||||
|
return state_list_.fixedDown;
|
||||||
|
case FSMStateName::FIXEDSTAND:
|
||||||
|
return state_list_.fixedStand;
|
||||||
|
default:
|
||||||
|
return state_list_.invalid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "pluginlib/class_list_macros.hpp"
|
||||||
|
PLUGINLIB_EXPORT_CLASS(legged_gym_controller::LeggedGymController, controller_interface::ControllerInterface);
|
|
@ -0,0 +1,107 @@
|
||||||
|
//
|
||||||
|
// Created by tlab-uav on 24-10-4.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef LEGGEDGYMCONTROLLER_H
|
||||||
|
#define LEGGEDGYMCONTROLLER_H
|
||||||
|
#include <controller_interface/controller_interface.hpp>
|
||||||
|
#include "legged_gym_controller/FSM/StateFixedStand.h"
|
||||||
|
#include "legged_gym_controller/FSM/StateFixedDown.h"
|
||||||
|
#include "legged_gym_controller/FSM/StatePassive.h"
|
||||||
|
#include "legged_gym_controller/control/CtrlComponent.h"
|
||||||
|
|
||||||
|
namespace legged_gym_controller {
|
||||||
|
struct FSMStateList {
|
||||||
|
std::shared_ptr<FSMState> invalid;
|
||||||
|
std::shared_ptr<StatePassive> passive;
|
||||||
|
std::shared_ptr<StateFixedDown> fixedDown;
|
||||||
|
std::shared_ptr<StateFixedStand> fixedStand;
|
||||||
|
};
|
||||||
|
|
||||||
|
class LeggedGymController final : public controller_interface::ControllerInterface {
|
||||||
|
public:
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
LeggedGymController() = default;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::InterfaceConfiguration command_interface_configuration() const override;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::InterfaceConfiguration state_interface_configuration() const override;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::return_type update(
|
||||||
|
const rclcpp::Time &time, const rclcpp::Duration &period) override;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::CallbackReturn on_init() override;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::CallbackReturn on_configure(
|
||||||
|
const rclcpp_lifecycle::State &previous_state) override;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::CallbackReturn on_activate(
|
||||||
|
const rclcpp_lifecycle::State &previous_state) override;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::CallbackReturn on_deactivate(
|
||||||
|
const rclcpp_lifecycle::State &previous_state) override;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::CallbackReturn on_cleanup(
|
||||||
|
const rclcpp_lifecycle::State &previous_state) override;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::CallbackReturn on_shutdown(
|
||||||
|
const rclcpp_lifecycle::State &previous_state) override;
|
||||||
|
|
||||||
|
CONTROLLER_INTERFACE_PUBLIC
|
||||||
|
controller_interface::CallbackReturn on_error(
|
||||||
|
const rclcpp_lifecycle::State &previous_state) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<FSMState> getNextState(FSMStateName stateName) const;
|
||||||
|
|
||||||
|
CtrlComponent ctrl_comp_;
|
||||||
|
std::vector<std::string> joint_names_;
|
||||||
|
std::vector<std::string> feet_names_;
|
||||||
|
std::vector<std::string> command_interface_types_;
|
||||||
|
std::vector<std::string> state_interface_types_;
|
||||||
|
|
||||||
|
// IMU Sensor
|
||||||
|
std::string imu_name_;
|
||||||
|
std::vector<std::string> imu_interface_types_;
|
||||||
|
// Foot Force Sensor
|
||||||
|
std::string foot_force_name_;
|
||||||
|
std::vector<std::string> foot_force_interface_types_;
|
||||||
|
|
||||||
|
std::unordered_map<
|
||||||
|
std::string, std::vector<std::reference_wrapper<hardware_interface::LoanedCommandInterface> > *>
|
||||||
|
command_interface_map_ = {
|
||||||
|
{"effort", &ctrl_comp_.joint_torque_command_interface_},
|
||||||
|
{"position", &ctrl_comp_.joint_position_command_interface_},
|
||||||
|
{"velocity", &ctrl_comp_.joint_velocity_command_interface_},
|
||||||
|
{"kp", &ctrl_comp_.joint_kp_command_interface_},
|
||||||
|
{"kd", &ctrl_comp_.joint_kd_command_interface_}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::unordered_map<
|
||||||
|
std::string, std::vector<std::reference_wrapper<hardware_interface::LoanedStateInterface> > *>
|
||||||
|
state_interface_map_ = {
|
||||||
|
{"position", &ctrl_comp_.joint_position_state_interface_},
|
||||||
|
{"effort", &ctrl_comp_.joint_effort_state_interface_},
|
||||||
|
{"velocity", &ctrl_comp_.joint_velocity_state_interface_}
|
||||||
|
};
|
||||||
|
|
||||||
|
rclcpp::Subscription<control_input_msgs::msg::Inputs>::SharedPtr control_input_subscription_;
|
||||||
|
|
||||||
|
FSMMode mode_ = FSMMode::NORMAL;
|
||||||
|
std::string state_name_;
|
||||||
|
FSMStateName next_state_name_ = FSMStateName::INVALID;
|
||||||
|
FSMStateList state_list_;
|
||||||
|
std::shared_ptr<FSMState> current_state_;
|
||||||
|
std::shared_ptr<FSMState> next_state_;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
#endif //LEGGEDGYMCONTROLLER_H
|
|
@ -1,467 +0,0 @@
|
||||||
#include "legged_gym_controller/rl_sdk.hpp"
|
|
||||||
|
|
||||||
/* You may need to override this Forward() function
|
|
||||||
torch::Tensor RL_XXX::Forward()
|
|
||||||
{
|
|
||||||
torch::autograd::GradMode::set_enabled(false);
|
|
||||||
torch::Tensor clamped_obs = this->ComputeObservation();
|
|
||||||
torch::Tensor actions = this->model.forward({clamped_obs}).toTensor();
|
|
||||||
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
|
|
||||||
return clamped_actions;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
torch::Tensor RL::ComputeObservation()
|
|
||||||
{
|
|
||||||
std::vector<torch::Tensor> obs_list;
|
|
||||||
|
|
||||||
for(const std::string& observation : this->params.observations)
|
|
||||||
{
|
|
||||||
if(observation == "lin_vel")
|
|
||||||
{
|
|
||||||
obs_list.push_back(this->obs.lin_vel * this->params.lin_vel_scale);
|
|
||||||
}
|
|
||||||
else if(observation == "ang_vel")
|
|
||||||
{
|
|
||||||
// obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); // TODO is QuatRotateInverse necessery?
|
|
||||||
obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale);
|
|
||||||
}
|
|
||||||
else if(observation == "gravity_vec")
|
|
||||||
{
|
|
||||||
obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework));
|
|
||||||
}
|
|
||||||
else if(observation == "commands")
|
|
||||||
{
|
|
||||||
obs_list.push_back(this->obs.commands * this->params.commands_scale);
|
|
||||||
}
|
|
||||||
else if(observation == "dof_pos")
|
|
||||||
{
|
|
||||||
obs_list.push_back((this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale);
|
|
||||||
}
|
|
||||||
else if(observation == "dof_vel")
|
|
||||||
{
|
|
||||||
obs_list.push_back(this->obs.dof_vel * this->params.dof_vel_scale);
|
|
||||||
}
|
|
||||||
else if(observation == "actions")
|
|
||||||
{
|
|
||||||
obs_list.push_back(this->obs.actions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor obs = torch::cat(obs_list, 1);
|
|
||||||
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
|
||||||
return clamped_obs;
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL::InitObservations()
|
|
||||||
{
|
|
||||||
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
|
||||||
this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
|
||||||
this->obs.gravity_vec = torch::tensor({{0.0, 0.0, -1.0}});
|
|
||||||
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
|
|
||||||
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
|
|
||||||
this->obs.dof_pos = this->params.default_dof_pos;
|
|
||||||
this->obs.dof_vel = torch::zeros({1, this->params.num_of_dofs});
|
|
||||||
this->obs.actions = torch::zeros({1, this->params.num_of_dofs});
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL::InitOutputs()
|
|
||||||
{
|
|
||||||
this->output_torques = torch::zeros({1, this->params.num_of_dofs});
|
|
||||||
this->output_dof_pos = this->params.default_dof_pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL::InitControl()
|
|
||||||
{
|
|
||||||
this->control.control_state = STATE_WAITING;
|
|
||||||
this->control.x = 0.0;
|
|
||||||
this->control.y = 0.0;
|
|
||||||
this->control.yaw = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor RL::ComputeTorques(torch::Tensor actions)
|
|
||||||
{
|
|
||||||
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
|
||||||
torch::Tensor output_torques = this->params.rl_kp * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.rl_kd * this->obs.dof_vel;
|
|
||||||
return output_torques;
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor RL::ComputePosition(torch::Tensor actions)
|
|
||||||
{
|
|
||||||
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
|
||||||
return actions_scaled + this->params.default_dof_pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string& framework)
|
|
||||||
{
|
|
||||||
torch::Tensor q_w;
|
|
||||||
torch::Tensor q_vec;
|
|
||||||
if(framework == "isaacsim")
|
|
||||||
{
|
|
||||||
q_w = q.index({torch::indexing::Slice(), 0});
|
|
||||||
q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(1, 4)});
|
|
||||||
}
|
|
||||||
else if(framework == "isaacgym")
|
|
||||||
{
|
|
||||||
q_w = q.index({torch::indexing::Slice(), 3});
|
|
||||||
q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)});
|
|
||||||
}
|
|
||||||
c10::IntArrayRef shape = q.sizes();
|
|
||||||
|
|
||||||
torch::Tensor a = v * (2.0 * torch::pow(q_w, 2) - 1.0).unsqueeze(-1);
|
|
||||||
torch::Tensor b = torch::cross(q_vec, v, -1) * q_w.unsqueeze(-1) * 2.0;
|
|
||||||
torch::Tensor c = q_vec * torch::bmm(q_vec.view({shape[0], 1, 3}), v.view({shape[0], 3, 1})).squeeze(-1) * 2.0;
|
|
||||||
return a - b + c;
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL::StateController(const RobotState<double> *state, RobotCommand<double> *command)
|
|
||||||
{
|
|
||||||
static RobotState<double> start_state;
|
|
||||||
static RobotState<double> now_state;
|
|
||||||
static float getup_percent = 0.0;
|
|
||||||
static float getdown_percent = 0.0;
|
|
||||||
|
|
||||||
// waiting
|
|
||||||
if(this->running_state == STATE_WAITING)
|
|
||||||
{
|
|
||||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
|
||||||
{
|
|
||||||
command->motor_command.q[i] = state->motor_state.q[i];
|
|
||||||
}
|
|
||||||
if(this->control.control_state == STATE_POS_GETUP)
|
|
||||||
{
|
|
||||||
this->control.control_state = STATE_WAITING;
|
|
||||||
getup_percent = 0.0;
|
|
||||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
|
||||||
{
|
|
||||||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
|
||||||
start_state.motor_state.q[i] = now_state.motor_state.q[i];
|
|
||||||
}
|
|
||||||
this->running_state = STATE_POS_GETUP;
|
|
||||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// stand up (position control)
|
|
||||||
else if(this->running_state == STATE_POS_GETUP)
|
|
||||||
{
|
|
||||||
if(getup_percent < 1.0)
|
|
||||||
{
|
|
||||||
getup_percent += 1 / 500.0;
|
|
||||||
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
|
|
||||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
|
||||||
{
|
|
||||||
command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item<double>();
|
|
||||||
command->motor_command.dq[i] = 0;
|
|
||||||
command->motor_command.kp[i] = this->params.fixed_kp[0][i].item<double>();
|
|
||||||
command->motor_command.kd[i] = this->params.fixed_kd[0][i].item<double>();
|
|
||||||
command->motor_command.tau[i] = 0;
|
|
||||||
}
|
|
||||||
std::cout << "\r" << std::flush << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << std::flush;
|
|
||||||
}
|
|
||||||
if(this->control.control_state == STATE_RL_INIT)
|
|
||||||
{
|
|
||||||
this->control.control_state = STATE_WAITING;
|
|
||||||
this->running_state = STATE_RL_INIT;
|
|
||||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl;
|
|
||||||
}
|
|
||||||
else if(this->control.control_state == STATE_POS_GETDOWN)
|
|
||||||
{
|
|
||||||
this->control.control_state = STATE_WAITING;
|
|
||||||
getdown_percent = 0.0;
|
|
||||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
|
||||||
{
|
|
||||||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
|
||||||
}
|
|
||||||
this->running_state = STATE_POS_GETDOWN;
|
|
||||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// init obs and start rl loop
|
|
||||||
else if(this->running_state == STATE_RL_INIT)
|
|
||||||
{
|
|
||||||
if(getup_percent == 1)
|
|
||||||
{
|
|
||||||
this->InitObservations();
|
|
||||||
this->InitOutputs();
|
|
||||||
this->InitControl();
|
|
||||||
this->running_state = STATE_RL_RUNNING;
|
|
||||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_RUNNING" << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// rl loop
|
|
||||||
else if(this->running_state == STATE_RL_RUNNING)
|
|
||||||
{
|
|
||||||
std::cout << "\r" << std::flush << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << std::flush;
|
|
||||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
|
||||||
{
|
|
||||||
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
|
|
||||||
command->motor_command.dq[i] = 0;
|
|
||||||
command->motor_command.kp[i] = this->params.rl_kp[0][i].item<double>();
|
|
||||||
command->motor_command.kd[i] = this->params.rl_kd[0][i].item<double>();
|
|
||||||
command->motor_command.tau[i] = 0;
|
|
||||||
}
|
|
||||||
if(this->control.control_state == STATE_POS_GETDOWN)
|
|
||||||
{
|
|
||||||
this->control.control_state = STATE_WAITING;
|
|
||||||
getdown_percent = 0.0;
|
|
||||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
|
||||||
{
|
|
||||||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
|
||||||
}
|
|
||||||
this->running_state = STATE_POS_GETDOWN;
|
|
||||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
|
|
||||||
}
|
|
||||||
else if(this->control.control_state == STATE_POS_GETUP)
|
|
||||||
{
|
|
||||||
this->control.control_state = STATE_WAITING;
|
|
||||||
getup_percent = 0.0;
|
|
||||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
|
||||||
{
|
|
||||||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
|
||||||
}
|
|
||||||
this->running_state = STATE_POS_GETUP;
|
|
||||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// get down (position control)
|
|
||||||
else if(this->running_state == STATE_POS_GETDOWN)
|
|
||||||
{
|
|
||||||
if(getdown_percent < 1.0)
|
|
||||||
{
|
|
||||||
getdown_percent += 1 / 500.0;
|
|
||||||
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
|
|
||||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
|
||||||
{
|
|
||||||
command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i];
|
|
||||||
command->motor_command.dq[i] = 0;
|
|
||||||
command->motor_command.kp[i] = this->params.fixed_kp[0][i].item<double>();
|
|
||||||
command->motor_command.kd[i] = this->params.fixed_kd[0][i].item<double>();
|
|
||||||
command->motor_command.tau[i] = 0;
|
|
||||||
}
|
|
||||||
std::cout << "\r" << std::flush << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << std::flush;
|
|
||||||
}
|
|
||||||
if(getdown_percent == 1)
|
|
||||||
{
|
|
||||||
this->InitObservations();
|
|
||||||
this->InitOutputs();
|
|
||||||
this->InitControl();
|
|
||||||
this->running_state = STATE_WAITING;
|
|
||||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_WAITING" << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL::TorqueProtect(torch::Tensor origin_output_torques)
|
|
||||||
{
|
|
||||||
std::vector<int> out_of_range_indices;
|
|
||||||
std::vector<double> out_of_range_values;
|
|
||||||
for(int i = 0; i < origin_output_torques.size(1); ++i)
|
|
||||||
{
|
|
||||||
double torque_value = origin_output_torques[0][i].item<double>();
|
|
||||||
double limit_lower = -this->params.torque_limits[0][i].item<double>();
|
|
||||||
double limit_upper = this->params.torque_limits[0][i].item<double>();
|
|
||||||
|
|
||||||
if(torque_value < limit_lower || torque_value > limit_upper)
|
|
||||||
{
|
|
||||||
out_of_range_indices.push_back(i);
|
|
||||||
out_of_range_values.push_back(torque_value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if(!out_of_range_indices.empty())
|
|
||||||
{
|
|
||||||
for(int i = 0; i < out_of_range_indices.size(); ++i)
|
|
||||||
{
|
|
||||||
int index = out_of_range_indices[i];
|
|
||||||
double value = out_of_range_values[i];
|
|
||||||
double limit_lower = -this->params.torque_limits[0][index].item<double>();
|
|
||||||
double limit_upper = this->params.torque_limits[0][index].item<double>();
|
|
||||||
|
|
||||||
std::cout << LOGGER::WARNING << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
|
|
||||||
}
|
|
||||||
// Just a reminder, no protection
|
|
||||||
// this->control.control_state = STATE_POS_GETDOWN;
|
|
||||||
// std::cout << LOGGER::INFO << "Switching to STATE_POS_GETDOWN"<< std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#include <termios.h>
|
|
||||||
#include <sys/ioctl.h>
|
|
||||||
static bool kbhit()
|
|
||||||
{
|
|
||||||
termios term;
|
|
||||||
tcgetattr(0, &term);
|
|
||||||
|
|
||||||
termios term2 = term;
|
|
||||||
term2.c_lflag &= ~ICANON;
|
|
||||||
tcsetattr(0, TCSANOW, &term2);
|
|
||||||
|
|
||||||
int byteswaiting;
|
|
||||||
ioctl(0, FIONREAD, &byteswaiting);
|
|
||||||
|
|
||||||
tcsetattr(0, TCSANOW, &term);
|
|
||||||
|
|
||||||
return byteswaiting > 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL::KeyboardInterface()
|
|
||||||
{
|
|
||||||
if(kbhit())
|
|
||||||
{
|
|
||||||
int c = fgetc(stdin);
|
|
||||||
switch(c)
|
|
||||||
{
|
|
||||||
case '0': this->control.control_state = STATE_POS_GETUP; break;
|
|
||||||
case 'p': this->control.control_state = STATE_RL_INIT; break;
|
|
||||||
case '1': this->control.control_state = STATE_POS_GETDOWN; break;
|
|
||||||
case 'q': break;
|
|
||||||
case 'w': this->control.x += 0.1; break;
|
|
||||||
case 's': this->control.x -= 0.1; break;
|
|
||||||
case 'a': this->control.yaw += 0.1; break;
|
|
||||||
case 'd': this->control.yaw -= 0.1; break;
|
|
||||||
case 'i': break;
|
|
||||||
case 'k': break;
|
|
||||||
case 'j': this->control.y += 0.1; break;
|
|
||||||
case 'l': this->control.y -= 0.1; break;
|
|
||||||
case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break;
|
|
||||||
case 'r': this->control.control_state = STATE_RESET_SIMULATION; break;
|
|
||||||
case '\n': this->control.control_state = STATE_TOGGLE_SIMULATION; break;
|
|
||||||
default: break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
std::vector<T> ReadVectorFromYaml(const YAML::Node& node)
|
|
||||||
{
|
|
||||||
std::vector<T> values;
|
|
||||||
for(const auto& val : node)
|
|
||||||
{
|
|
||||||
values.push_back(val.as<T>());
|
|
||||||
}
|
|
||||||
return values;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
std::vector<T> ReadVectorFromYaml(const YAML::Node& node, const std::string& framework, const int& rows, const int& cols)
|
|
||||||
{
|
|
||||||
std::vector<T> values;
|
|
||||||
for(const auto& val : node)
|
|
||||||
{
|
|
||||||
values.push_back(val.as<T>());
|
|
||||||
}
|
|
||||||
|
|
||||||
if(framework == "isaacsim")
|
|
||||||
{
|
|
||||||
std::vector<T> transposed_values(cols * rows);
|
|
||||||
for(int r = 0; r < rows; ++r)
|
|
||||||
{
|
|
||||||
for(int c = 0; c < cols; ++c)
|
|
||||||
{
|
|
||||||
transposed_values[c * rows + r] = values[r * cols + c];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return transposed_values;
|
|
||||||
}
|
|
||||||
else if(framework == "isaacgym")
|
|
||||||
{
|
|
||||||
return values;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
throw std::invalid_argument("Unsupported framework: " + framework);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL::ReadYaml(std::string robot_name)
|
|
||||||
{
|
|
||||||
// The config file is located at "rl_sar/src/rl_sar/models/<robot_name>/config.yaml"
|
|
||||||
std::string config_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/config.yaml";
|
|
||||||
YAML::Node config;
|
|
||||||
try
|
|
||||||
{
|
|
||||||
config = YAML::LoadFile(config_path)[robot_name];
|
|
||||||
} catch(YAML::BadFile &e)
|
|
||||||
{
|
|
||||||
std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
this->params.model_name = config["model_name"].as<std::string>();
|
|
||||||
this->params.framework = config["framework"].as<std::string>();
|
|
||||||
int rows = config["rows"].as<int>();
|
|
||||||
int cols = config["cols"].as<int>();
|
|
||||||
this->params.use_history = config["use_history"].as<bool>();
|
|
||||||
this->params.dt = config["dt"].as<double>();
|
|
||||||
this->params.decimation = config["decimation"].as<int>();
|
|
||||||
this->params.num_observations = config["num_observations"].as<int>();
|
|
||||||
this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]);
|
|
||||||
this->params.clip_obs = config["clip_obs"].as<double>();
|
|
||||||
if(config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull())
|
|
||||||
{
|
|
||||||
this->params.clip_actions_upper = torch::tensor({}).view({1, -1});
|
|
||||||
this->params.clip_actions_lower = torch::tensor({}).view({1, -1});
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1});
|
|
||||||
this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1});
|
|
||||||
}
|
|
||||||
this->params.action_scale = config["action_scale"].as<double>();
|
|
||||||
this->params.hip_scale_reduction = config["hip_scale_reduction"].as<double>();
|
|
||||||
this->params.hip_scale_reduction_indices = ReadVectorFromYaml<int>(config["hip_scale_reduction_indices"]);
|
|
||||||
this->params.num_of_dofs = config["num_of_dofs"].as<int>();
|
|
||||||
this->params.lin_vel_scale = config["lin_vel_scale"].as<double>();
|
|
||||||
this->params.ang_vel_scale = config["ang_vel_scale"].as<double>();
|
|
||||||
this->params.dof_pos_scale = config["dof_pos_scale"].as<double>();
|
|
||||||
this->params.dof_vel_scale = config["dof_vel_scale"].as<double>();
|
|
||||||
// this->params.commands_scale = torch::tensor(ReadVectorFromYaml<double>(config["commands_scale"])).view({1, -1});
|
|
||||||
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
|
|
||||||
this->params.rl_kp = torch::tensor(ReadVectorFromYaml<double>(config["rl_kp"], this->params.framework, rows, cols)).view({1, -1});
|
|
||||||
this->params.rl_kd = torch::tensor(ReadVectorFromYaml<double>(config["rl_kd"], this->params.framework, rows, cols)).view({1, -1});
|
|
||||||
this->params.fixed_kp = torch::tensor(ReadVectorFromYaml<double>(config["fixed_kp"], this->params.framework, rows, cols)).view({1, -1});
|
|
||||||
this->params.fixed_kd = torch::tensor(ReadVectorFromYaml<double>(config["fixed_kd"], this->params.framework, rows, cols)).view({1, -1});
|
|
||||||
this->params.torque_limits = torch::tensor(ReadVectorFromYaml<double>(config["torque_limits"], this->params.framework, rows, cols)).view({1, -1});
|
|
||||||
this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml<double>(config["default_dof_pos"], this->params.framework, rows, cols)).view({1, -1});
|
|
||||||
this->params.joint_controller_names = ReadVectorFromYaml<std::string>(config["joint_controller_names"], this->params.framework, rows, cols);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL::CSVInit(std::string robot_name)
|
|
||||||
{
|
|
||||||
csv_filename = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/motor";
|
|
||||||
|
|
||||||
// Uncomment these lines if need timestamp for file name
|
|
||||||
// auto now = std::chrono::system_clock::now();
|
|
||||||
// std::time_t now_c = std::chrono::system_clock::to_time_t(now);
|
|
||||||
// std::stringstream ss;
|
|
||||||
// ss << std::put_time(std::localtime(&now_c), "%Y%m%d%H%M%S");
|
|
||||||
// std::string timestamp = ss.str();
|
|
||||||
// csv_filename += "_" + timestamp;
|
|
||||||
|
|
||||||
csv_filename += ".csv";
|
|
||||||
std::ofstream file(csv_filename.c_str());
|
|
||||||
|
|
||||||
for(int i = 0; i < 12; ++i) {file << "tau_cal_" << i << ",";}
|
|
||||||
for(int i = 0; i < 12; ++i) {file << "tau_est_" << i << ",";}
|
|
||||||
for(int i = 0; i < 12; ++i) {file << "joint_pos_" << i << ",";}
|
|
||||||
for(int i = 0; i < 12; ++i) {file << "joint_pos_target_" << i << ",";}
|
|
||||||
for(int i = 0; i < 12; ++i) {file << "joint_vel_" << i << ",";}
|
|
||||||
|
|
||||||
file << std::endl;
|
|
||||||
|
|
||||||
file.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
void RL::CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel)
|
|
||||||
{
|
|
||||||
std::ofstream file(csv_filename.c_str(), std::ios_base::app);
|
|
||||||
|
|
||||||
for(int i = 0; i < 12; ++i) {file << torque[0][i].item<double>() << ",";}
|
|
||||||
for(int i = 0; i < 12; ++i) {file << tau_est[0][i].item<double>() << ",";}
|
|
||||||
for(int i = 0; i < 12; ++i) {file << joint_pos[0][i].item<double>() << ",";}
|
|
||||||
for(int i = 0; i < 12; ++i) {file << joint_pos_target[0][i].item<double>() << ",";}
|
|
||||||
for(int i = 0; i < 12; ++i) {file << joint_vel[0][i].item<double>() << ",";}
|
|
||||||
|
|
||||||
file << std::endl;
|
|
||||||
|
|
||||||
file.close();
|
|
||||||
}
|
|
|
@ -31,4 +31,8 @@ ros2 launch a1_description visualize.launch.py
|
||||||
source ~/ros2_ws/install/setup.bash
|
source ~/ros2_ws/install/setup.bash
|
||||||
ros2 launch a1_description ocs2_control.launch.py
|
ros2 launch a1_description ocs2_control.launch.py
|
||||||
```
|
```
|
||||||
|
* Legged Gym Controller
|
||||||
|
```bash
|
||||||
|
source ~/ros2_ws/install/setup.bash
|
||||||
|
ros2 launch a1_description rl_control.launch.py
|
||||||
|
```
|
||||||
|
|
|
@ -16,6 +16,9 @@ controller_manager:
|
||||||
ocs2_quadruped_controller:
|
ocs2_quadruped_controller:
|
||||||
type: ocs2_quadruped_controller/Ocs2QuadrupedController
|
type: ocs2_quadruped_controller/Ocs2QuadrupedController
|
||||||
|
|
||||||
|
legged_gym_controller:
|
||||||
|
type: legged_gym_controller/LeggedGymController
|
||||||
|
|
||||||
imu_sensor_broadcaster:
|
imu_sensor_broadcaster:
|
||||||
ros__parameters:
|
ros__parameters:
|
||||||
sensor_name: "imu_sensor"
|
sensor_name: "imu_sensor"
|
||||||
|
@ -128,4 +131,54 @@ ocs2_quadruped_controller:
|
||||||
- FL
|
- FL
|
||||||
- RL
|
- RL
|
||||||
- FR
|
- FR
|
||||||
- RR
|
- RR
|
||||||
|
|
||||||
|
legged_gym_controller:
|
||||||
|
ros__parameters:
|
||||||
|
update_rate: 500 # Hz
|
||||||
|
joints:
|
||||||
|
- FR_hip_joint
|
||||||
|
- FR_thigh_joint
|
||||||
|
- FR_calf_joint
|
||||||
|
- FL_hip_joint
|
||||||
|
- FL_thigh_joint
|
||||||
|
- FL_calf_joint
|
||||||
|
- RR_hip_joint
|
||||||
|
- RR_thigh_joint
|
||||||
|
- RR_calf_joint
|
||||||
|
- RL_hip_joint
|
||||||
|
- RL_thigh_joint
|
||||||
|
- RL_calf_joint
|
||||||
|
|
||||||
|
command_interfaces:
|
||||||
|
- effort
|
||||||
|
- position
|
||||||
|
- velocity
|
||||||
|
- kp
|
||||||
|
- kd
|
||||||
|
|
||||||
|
state_interfaces:
|
||||||
|
- effort
|
||||||
|
- position
|
||||||
|
- velocity
|
||||||
|
|
||||||
|
feet_names:
|
||||||
|
- FR_foot
|
||||||
|
- FL_foot
|
||||||
|
- RR_foot
|
||||||
|
- RL_foot
|
||||||
|
|
||||||
|
imu_name: "imu_sensor"
|
||||||
|
base_name: "base"
|
||||||
|
|
||||||
|
imu_interfaces:
|
||||||
|
- orientation.w
|
||||||
|
- orientation.x
|
||||||
|
- orientation.y
|
||||||
|
- orientation.z
|
||||||
|
- angular_velocity.x
|
||||||
|
- angular_velocity.y
|
||||||
|
- angular_velocity.z
|
||||||
|
- linear_acceleration.x
|
||||||
|
- linear_acceleration.y
|
||||||
|
- linear_acceleration.z
|
|
@ -0,0 +1,112 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import xacro
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument, OpaqueFunction, IncludeLaunchDescription, RegisterEventHandler
|
||||||
|
from launch.event_handlers import OnProcessExit
|
||||||
|
from launch.substitutions import PathJoinSubstitution
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
from launch_ros.substitutions import FindPackageShare
|
||||||
|
|
||||||
|
package_description = "a1_description"
|
||||||
|
|
||||||
|
|
||||||
|
def process_xacro(context):
|
||||||
|
robot_type_value = context.launch_configurations['robot_type']
|
||||||
|
pkg_path = os.path.join(get_package_share_directory(package_description))
|
||||||
|
xacro_file = os.path.join(pkg_path, 'xacro', 'robot.xacro')
|
||||||
|
robot_description_config = xacro.process_file(xacro_file, mappings={'robot_type': robot_type_value})
|
||||||
|
return (robot_description_config.toxml(), robot_type_value)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_setup(context, *args, **kwargs):
|
||||||
|
(robot_description, robot_type) = process_xacro(context)
|
||||||
|
robot_controllers = PathJoinSubstitution(
|
||||||
|
[
|
||||||
|
FindPackageShare(package_description),
|
||||||
|
"config",
|
||||||
|
"robot_control.yaml",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
robot_state_publisher = Node(
|
||||||
|
package='robot_state_publisher',
|
||||||
|
executable='robot_state_publisher',
|
||||||
|
name='robot_state_publisher',
|
||||||
|
parameters=[
|
||||||
|
{
|
||||||
|
'publish_frequency': 20.0,
|
||||||
|
'use_tf_static': True,
|
||||||
|
'robot_description': robot_description,
|
||||||
|
'ignore_timestamp': True
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
controller_manager = Node(
|
||||||
|
package="controller_manager",
|
||||||
|
executable="ros2_control_node",
|
||||||
|
parameters=[robot_controllers],
|
||||||
|
output="both",
|
||||||
|
)
|
||||||
|
|
||||||
|
joint_state_publisher = Node(
|
||||||
|
package="controller_manager",
|
||||||
|
executable="spawner",
|
||||||
|
arguments=["joint_state_broadcaster",
|
||||||
|
"--controller-manager", "/controller_manager"],
|
||||||
|
)
|
||||||
|
|
||||||
|
imu_sensor_broadcaster = Node(
|
||||||
|
package="controller_manager",
|
||||||
|
executable="spawner",
|
||||||
|
arguments=["imu_sensor_broadcaster",
|
||||||
|
"--controller-manager", "/controller_manager"],
|
||||||
|
)
|
||||||
|
|
||||||
|
controller = Node(
|
||||||
|
package="controller_manager",
|
||||||
|
executable="spawner",
|
||||||
|
arguments=["legged_gym_controller", "--controller-manager", "/controller_manager"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
robot_state_publisher,
|
||||||
|
controller_manager,
|
||||||
|
joint_state_publisher,
|
||||||
|
RegisterEventHandler(
|
||||||
|
event_handler=OnProcessExit(
|
||||||
|
target_action=joint_state_publisher,
|
||||||
|
on_exit=[imu_sensor_broadcaster],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
RegisterEventHandler(
|
||||||
|
event_handler=OnProcessExit(
|
||||||
|
target_action=imu_sensor_broadcaster,
|
||||||
|
on_exit=[controller],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
|
||||||
|
robot_type_arg = DeclareLaunchArgument(
|
||||||
|
'robot_type',
|
||||||
|
default_value='a1',
|
||||||
|
description='Type of the robot'
|
||||||
|
)
|
||||||
|
|
||||||
|
rviz_config_file = os.path.join(get_package_share_directory(package_description), "config", "visualize_urdf.rviz")
|
||||||
|
|
||||||
|
return LaunchDescription([
|
||||||
|
robot_type_arg,
|
||||||
|
OpaqueFunction(function=launch_setup),
|
||||||
|
Node(
|
||||||
|
package='rviz2',
|
||||||
|
executable='rviz2',
|
||||||
|
name='rviz_ocs2',
|
||||||
|
output='screen',
|
||||||
|
arguments=["-d", rviz_config_file]
|
||||||
|
)
|
||||||
|
])
|
Loading…
Reference in New Issue