feat: add LOGGER

This commit is contained in:
fan-ziqi 2024-05-27 21:33:07 +08:00
parent e84f2d85bc
commit 3a9886e6ea
5 changed files with 36 additions and 31 deletions

View File

@ -69,12 +69,12 @@ set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)
add_executable(rl_sim src/rl_sim.cpp )
target_link_libraries(rl_sim
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp
)
add_executable(rl_real_a1 src/rl_real_a1.cpp )
target_link_libraries(rl_real_a1
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp
)

View File

@ -3,34 +3,28 @@
/* You may need to override this ComputeObservation() function
torch::Tensor RL::ComputeObservation()
{
torch::Tensor obs = torch::cat({(this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
torch::Tensor obs = torch::cat({this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions},
1);
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]);
return obs;
this->obs.actions
},1);
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs;
}
*/
/* You may need to override this Forward() function
torch::Tensor RL::Forward()
{
torch::Tensor obs = this->ComputeObservation();
torch::autograd::GradMode::set_enabled(false);
torch::Tensor actor_input = torch::cat({obs}, 1);
torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions = this->actor.forward({actor_input}).toTensor();
torch::Tensor actions = this->model.forward({clamped_obs}).toTensor();
this->obs.actions = actions;
torch::Tensor clamped_actions = torch::clamp(actions, -this->params.clip_actions, this->params.clip_actions);
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
return clamped_actions;
}
@ -122,7 +116,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
printf("Getting up %.3f%%\r", getup_percent*100.0);
std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r";
}
if(keyboard.keyboard_state == STATE_RL_INIT)
{
@ -155,8 +149,6 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
// rl loop
else if(running_state == STATE_RL_RUNNING)
{
std::cout << "[RL Controller] x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw << " \r";
for(int i = 0; i < params.num_of_dofs; ++i)
{
command->motor_command.q[i] = output_dof_pos[0][i].item<double>();
@ -191,7 +183,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
printf("Getting down %.3f%%\r", getdown_percent*100.0);
std::cout << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << "%\r";
}
if(getdown_percent == 1)
{
@ -222,16 +214,16 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
}
if(!out_of_range_indices.empty())
{
std::cout << "Error: origin_output_torques is out of range at indices: ";
for(int i = 0; i < out_of_range_indices.size(); ++i)
{
std::cout << out_of_range_indices[i] << " (value: " << out_of_range_values[i] << ")";
if(i < out_of_range_indices.size() - 1)
{
std::cout << ", ";
}
int index = out_of_range_indices[i];
double value = out_of_range_values[i];
double limit_lower = -this->params.torque_limits[0][index].item<double>();
double limit_upper = this->params.torque_limits[0][index].item<double>();
std::cout << LOGGER::ERROR << "Torque(" << i+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
}
std::cout << std::endl;
keyboard.keyboard_state = STATE_POS_GETDOWN;
}
}
@ -257,6 +249,11 @@ static bool kbhit()
void RL::RunKeyboard()
{
if(running_state == STATE_RL_RUNNING)
{
std::cout << LOGGER::INFO << "RL Controller x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw << " \r";
}
if(kbhit())
{
int c = fgetc(stdin);
@ -301,7 +298,7 @@ void RL::ReadYaml(std::string robot_name)
} catch(YAML::BadFile &e)
{
std::cout << "The file '" << CONFIG_PATH << "' does not exist" << std::endl;
std::cout << LOGGER::ERROR << "The file '" << CONFIG_PATH << "' does not exist" << std::endl;
return;
}

View File

@ -9,6 +9,13 @@
#include <yaml-cpp/yaml.h>
#define CONFIG_PATH CMAKE_CURRENT_SOURCE_DIR "/config.yaml"
namespace LOGGER {
const char* const INFO = "\033[0;37m[INFO]\033[0m ";
const char* const WARNING = "\033[0;33m[WARNING]\033[0m ";
const char* const ERROR = "\033[0;31m[ERROR]\033[0m ";
const char* const DEBUG = "\033[0;32m[DEBUG]\033[0m ";
}
template<typename T>
struct RobotCommand
{
@ -102,6 +109,7 @@ class RL
{
public:
RL(){};
~RL(){};
ModelParams params;
Observations obs;

View File

@ -64,7 +64,7 @@ RL_Real::~RL_Real()
#ifdef PLOT
loop_plot->shutdown();
#endif
printf("exit\n");
std::cout << LOGGER::INFO << "RL_Real exit" << std::endl;
}
void RL_Real::GetState(RobotState<double> *state)

View File

@ -90,7 +90,7 @@ RL_Sim::~RL_Sim()
#ifdef PLOT
loop_plot->shutdown();
#endif
printf("exit\n");
std::cout << LOGGER::INFO << "RL_Sim exit" << std::endl;
}
void RL_Sim::GetState(RobotState<double> *state)