feat: add RL python sdk

This commit is contained in:
fan-ziqi 2024-06-29 11:38:07 +08:00
parent 9d2ed695cc
commit c153746bce
12 changed files with 689 additions and 48 deletions

2
.gitignore vendored
View File

@ -8,4 +8,4 @@ logs
*fldlar* *fldlar*
.cache .cache
*.json *.json
# *gr1t1* __pycache__

View File

@ -75,14 +75,22 @@ Before running, copy the trained pt model file to `rl_sar/src/rl_sar/models/YOUR
### Simulation ### Simulation
Open a new terminal, launch the gazebo simulation environment Open a terminal, launch the gazebo simulation environment
```bash ```bash
source devel/setup.bash source devel/setup.bash
roslaunch rl_sar gazebo_<ROBOT>.launch roslaunch rl_sar gazebo_<ROBOT>.launch
``` ```
Where \<ROBOT\> can be `a1` or `gr1t1`. Open a new terminal, launch the control program
```bash
source devel/setup.bash
(for cpp version) rosrun rl_sar rl_sim
(for python version) rosrun rl_sar rl_sim.py
```
Where \<ROBOT\> can be `a1` or `gr1t1` or `gr1t2`.
Control: Control:
* Press **\<Enter\>** to toggle simulation start/stop. * Press **\<Enter\>** to toggle simulation start/stop.

View File

@ -75,14 +75,22 @@ catkin build
### 仿真 ### 仿真
新建终端启动gazebo仿真环境 打开一个终端启动gazebo仿真环境
```bash ```bash
source devel/setup.bash source devel/setup.bash
roslaunch rl_sar gazebo_<ROBOT>.launch roslaunch rl_sar gazebo_<ROBOT>.launch
``` ```
其中 \<ROBOT\> 可以是 `a1``gr1t1`. 打开一个新终端,启动控制程序
```bash
source devel/setup.bash
(for cpp version) rosrun rl_sar rl_sim
(for python version) rosrun rl_sar rl_sim.py
```
其中 \<ROBOT\> 可以是 `a1`、`gr1t1` 或 `gr1t2`。
控制: 控制:

View File

@ -26,6 +26,7 @@ find_package(catkin REQUIRED COMPONENTS
geometry_msgs geometry_msgs
robot_msgs robot_msgs
robot_joint_controller robot_joint_controller
rospy
) )
find_package(Python3 COMPONENTS Interpreter Development REQUIRED) find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
@ -37,6 +38,7 @@ include_directories(${YAML_CPP_INCLUDE_DIR})
catkin_package( catkin_package(
CATKIN_DEPENDS CATKIN_DEPENDS
robot_joint_controller robot_joint_controller
rospy
) )
include_directories(library/unitree_legged_sdk_3.2/include) include_directories(library/unitree_legged_sdk_3.2/include)
@ -78,3 +80,9 @@ target_link_libraries(rl_real_a1
${catkin_LIBRARIES} ${EXTRA_LIBS} ${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp rl_sdk observation_buffer yaml-cpp
) )
# Install the Python RL scripts. rl_sim.py imports both rl_sdk and
# observation_buffer as sibling modules, so all three must be installed
# side by side or the installed rl_sim.py fails on import.
catkin_install_python(PROGRAMS
  scripts/rl_sim.py
  scripts/rl_sdk.py
  scripts/observation_buffer.py
  DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)

View File

@ -1,7 +1,7 @@
<launch> <launch>
<arg name="wname" default="stairs"/> <arg name="wname" default="stairs"/>
<arg name="rname" default="a1"/> <arg name="rname" default="a1"/>
<param name="robot_name" type="str" value="a1"/> <param name="robot_name" type="str" value="$(arg rname)"/>
<param name="use_history" type="bool" value="true"/> <param name="use_history" type="bool" value="true"/>
<param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/> <param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/>
<arg name="robot_path" value="(find $(arg rname)_description)"/> <arg name="robot_path" value="(find $(arg rname)_description)"/>
@ -37,8 +37,6 @@
<!-- Load joint controller configurations from YAML file to parameter server --> <!-- Load joint controller configurations from YAML file to parameter server -->
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/> <rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
<!-- <rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam> -->
<!-- load the controllers --> <!-- load the controllers -->
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false" <node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller
@ -53,6 +51,4 @@
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/> <remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
</node> </node>
<node pkg="rl_sar" type="rl_sim" name="rl_sim" output="screen"/>
</launch> </launch>

View File

@ -1,7 +1,7 @@
<launch> <launch>
<arg name="wname" default="stairs"/> <arg name="wname" default="stairs"/>
<arg name="rname" default="gr1t1"/> <arg name="rname" default="gr1t1"/>
<param name="robot_name" type="str" value="gr1t1"/> <param name="robot_name" type="str" value="$(arg rname)"/>
<param name="use_history" type="bool" value="false"/> <param name="use_history" type="bool" value="false"/>
<param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/> <param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/>
<arg name="robot_path" value="(find $(arg rname)_description)"/> <arg name="robot_path" value="(find $(arg rname)_description)"/>
@ -33,8 +33,6 @@
<!-- Load joint controller configurations from YAML file to parameter server --> <!-- Load joint controller configurations from YAML file to parameter server -->
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/> <rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
<!-- <rosparam param="/gr1t1_gazebo/joint_state_controller/publish_rate">5000</rosparam> -->
<!-- load the controllers --> <!-- load the controllers -->
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false" <node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller
@ -47,6 +45,4 @@
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/> <remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
</node> </node>
<node pkg="rl_sar" type="rl_sim" name="rl_sim" output="screen"/>
</launch> </launch>

View File

@ -1,31 +1,28 @@
#include "rl_sdk.hpp" #include "rl_sdk.hpp"
/* You may need to override this ComputeObservation() function /* You may need to override this ComputeObservation() function
torch::Tensor RL::ComputeObservation() torch::Tensor RL_XXX::ComputeObservation()
{ {
torch::Tensor obs = torch::cat({this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, torch::Tensor obs = torch::cat({
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
this->obs.commands * this->params.commands_scale, this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, this->obs.commands * this->params.commands_scale,
this->obs.dof_vel * this->params.dof_vel_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.actions this->obs.dof_vel * this->params.dof_vel_scale,
},1); this->obs.actions
},1);
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs; return clamped_obs;
} }
*/ */
/* You may need to override this Forward() function /* You may need to override this Forward() function
torch::Tensor RL::Forward() torch::Tensor RL_XXX::Forward()
{ {
torch::autograd::GradMode::set_enabled(false); torch::autograd::GradMode::set_enabled(false);
torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions = this->model.forward({clamped_obs}).toTensor(); torch::Tensor actions = this->model.forward({clamped_obs}).toTensor();
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
return clamped_actions; return clamped_actions;
} }
*/ */

View File

@ -25,6 +25,8 @@
<exec_depend>robot_state_publisher</exec_depend> <exec_depend>robot_state_publisher</exec_depend>
<exec_depend>roscpp</exec_depend> <exec_depend>roscpp</exec_depend>
<exec_depend>std_msgs</exec_depend> <exec_depend>std_msgs</exec_depend>
<build_depend>rospy</build_depend>
<exec_depend>rospy</exec_depend>
<depend>robot_msgs</depend> <depend>robot_msgs</depend>
<depend>robot_joint_controller</depend> <depend>robot_joint_controller</depend>

View File

@ -0,0 +1,37 @@
import torch
class ObservationBuffer:
    """Fixed-length rolling history of observations for a batch of environments.

    The buffer is a ``(num_envs, num_obs * include_history_steps)`` float
    tensor. Observations are stored oldest-first: the newest observation
    always occupies the last ``num_obs`` columns.
    """

    def __init__(self, num_envs, num_obs, include_history_steps):
        """Allocate a zero-filled history buffer.

        Arguments:
            num_envs: Number of rows (environments) in the buffer.
            num_obs: Width of a single observation.
            include_history_steps: Number of observations kept per row.
        """
        self.num_envs = num_envs
        self.num_obs = num_obs
        self.include_history_steps = include_history_steps
        self.num_obs_total = num_obs * include_history_steps
        self.obs_buf = torch.zeros(self.num_envs, self.num_obs_total, dtype=torch.float)

    def reset(self, reset_idxs, new_obs):
        """Fill every history slot of the selected rows with ``new_obs``."""
        self.obs_buf[reset_idxs] = new_obs.repeat(1, self.include_history_steps)

    def insert(self, new_obs):
        """Drop the oldest observation and append ``new_obs`` as the newest."""
        # Shift observations back. The source is cloned because it overlaps
        # the destination region of the same tensor.
        keep = self.num_obs * (self.include_history_steps - 1)
        self.obs_buf[:, :keep] = self.obs_buf[:, self.num_obs : self.num_obs * self.include_history_steps].clone()

        # Add new observation.
        self.obs_buf[:, -self.num_obs:] = new_obs

    def get_obs_vec(self, obs_ids):
        """Gets history of observations indexed by obs_ids.

        Arguments:
            obs_ids: An array of integers with which to index the desired
                observations, where 0 is the latest observation and
                include_history_steps - 1 is the oldest observation.

        Returns:
            The selected observations concatenated along the last dimension,
            ordered from oldest to newest.

        Raises:
            ValueError: If an id lies outside [0, include_history_steps).
                Such ids would otherwise produce a negative slice index and
                silently return an empty or unrelated slice.
        """
        obs = []
        for obs_id in sorted(obs_ids, reverse=True):
            if not 0 <= obs_id < self.include_history_steps:
                raise ValueError(
                    f"obs_id {obs_id} out of range [0, {self.include_history_steps})"
                )
            slice_idx = self.include_history_steps - obs_id - 1
            obs.append(self.obs_buf[:, slice_idx * self.num_obs : (slice_idx + 1) * self.num_obs])
        return torch.cat(obs, dim=-1)

View File

@ -0,0 +1,356 @@
import torch
import yaml
import os
from pynput import keyboard
from enum import Enum, auto
CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../config.yaml")
class LOGGER:
    """ANSI-coloured prefixes prepended to console messages."""

    INFO = "\033[0;37m[INFO]\033[0m "
    WARNING = "\033[0;33m[WARNING]\033[0m "
    ERROR = "\033[0;31m[ERROR]\033[0m "
    DEBUG = "\033[0;32m[DEBUG]\033[0m "
class RobotCommand:
    """Command sent to the robot: one set of per-motor target arrays."""

    def __init__(self):
        self.motor_command = self.MotorCommand()

    class MotorCommand:
        """Per-motor command arrays, each pre-sized to 32 entries."""

        def __init__(self):
            self.q = [0.0] * 32
            self.dq = [0.0] * 32
            self.tau = [0.0] * 32
            self.kp = [0.0] * 32
            self.kd = [0.0] * 32
class RobotState:
    """Measured robot state: IMU readings plus per-motor feedback arrays."""

    def __init__(self):
        self.imu = self.IMU()
        self.motor_state = self.MotorState()

    class IMU:
        """Orientation, angular velocity and acceleration readings."""

        def __init__(self):
            # NOTE(review): RL.QuatRotateInverse reads w from the LAST column
            # of the quaternion, which disagrees with the ordering stated
            # here - confirm which convention is intended.
            self.quaternion = [1.0, 0.0, 0.0, 0.0]  # w, x, y, z
            self.gyroscope = [0.0, 0.0, 0.0]
            self.accelerometer = [0.0, 0.0, 0.0]

    class MotorState:
        """Per-motor feedback arrays, each pre-sized to 32 entries."""

        def __init__(self):
            self.q = [0.0] * 32
            self.dq = [0.0] * 32
            self.ddq = [0.0] * 32
            self.tauEst = [0.0] * 32
            self.cur = [0.0] * 32
class STATE(Enum):
    """Controller states and keyboard-requested actions.

    The same enum is used for the running state of the controller and for the
    pending request stored in ``Control.control_state``.
    """

    STATE_WAITING = 0
    STATE_POS_GETUP = auto()
    STATE_RL_INIT = auto()
    STATE_RL_RUNNING = auto()
    STATE_POS_GETDOWN = auto()

    STATE_RESET_SIMULATION = auto()
    STATE_TOGGLE_SIMULATION = auto()
class Control:
    """Operator input: the pending state request and the velocity command."""

    def __init__(self):
        # Pending request; STATE_WAITING means "nothing requested".
        self.control_state = STATE.STATE_WAITING
        self.x = 0.0
        self.y = 0.0
        self.yaw = 0.0
class ModelParams:
    """Per-robot model and controller parameters.

    Every field starts as None and is filled in by ``RL.ReadYaml`` from the
    robot's section of the YAML configuration file.
    """

    def __init__(self):
        self.model_name = None
        self.dt = None
        self.decimation = None
        self.num_observations = None
        self.damping = None
        self.stiffness = None
        self.action_scale = None
        self.hip_scale_reduction = None
        self.hip_scale_reduction_indices = None
        self.clip_actions_upper = None
        self.clip_actions_lower = None
        self.num_of_dofs = None
        self.lin_vel_scale = None
        self.ang_vel_scale = None
        self.dof_pos_scale = None
        self.dof_vel_scale = None
        self.clip_obs = None
        self.torque_limits = None
        self.rl_kd = None
        self.rl_kp = None
        self.fixed_kp = None
        self.fixed_kd = None
        self.commands_scale = None
        self.default_dof_pos = None
        self.joint_controller_names = None
class Observations:
    """Raw observation terms fed to the policy.

    Every field starts as None; ``RL.InitObservations`` replaces them with
    tensors.
    """

    def __init__(self):
        self.lin_vel = None
        self.ang_vel = None
        self.gravity_vec = None
        self.commands = None
        self.base_quat = None
        self.dof_pos = None
        self.dof_vel = None
        self.actions = None
class RL:
    """Backend-independent base class for the RL controller.

    Holds the model parameters, observation terms, robot state/command
    containers and the state machine that moves between waiting, getting up,
    RL control and getting down. Subclasses supply the I/O with the robot or
    simulator and the model forward pass.
    """

    # Static variables
    start_state = RobotState()
    now_state = RobotState()
    getup_percent = 0.0
    getdown_percent = 0.0

    def __init__(self):
        ### public in cpp ###
        self.params = ModelParams()
        self.obs = Observations()

        self.robot_state = RobotState()
        self.robot_command = RobotCommand()

        # control
        self.control = Control()

        # others
        self.robot_name = ""
        self.running_state = STATE.STATE_RL_RUNNING  # default running_state set to STATE_RL_RUNNING
        self.simulation_running = False

        ### protected in cpp ###
        # rl module
        self.model = None
        self.walk_model = None
        self.stand_model = None

        # output buffer
        self.output_torques = torch.zeros(1, 32)
        self.output_dof_pos = torch.zeros(1, 32)

    def InitObservations(self):
        """Reset every observation term to its initial value.

        Requires ``params.default_dof_pos`` and ``params.num_of_dofs`` to be
        set (see ``ReadYaml``).
        """
        self.obs.lin_vel = torch.zeros(1, 3, dtype=torch.float)
        self.obs.ang_vel = torch.zeros(1, 3, dtype=torch.float)
        self.obs.gravity_vec = torch.tensor([[0.0, 0.0, -1.0]])
        self.obs.commands = torch.zeros(1, 3, dtype=torch.float)
        self.obs.base_quat = torch.zeros(1, 4, dtype=torch.float)
        self.obs.dof_pos = self.params.default_dof_pos
        self.obs.dof_vel = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float)
        self.obs.actions = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float)

    def InitOutputs(self):
        """Reset the output buffers: zero torques, default joint positions."""
        self.output_torques = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float)
        self.output_dof_pos = self.params.default_dof_pos

    def InitControl(self):
        """Clear any pending state request and zero the velocity command."""
        self.control.control_state = STATE.STATE_WAITING
        self.control.x = 0.0
        self.control.y = 0.0
        self.control.yaw = 0.0

    def ComputeTorques(self, actions):
        """Turn policy actions into joint torques.

        Computes ``rl_kp * (scaled_action + default_dof_pos - dof_pos)
        - rl_kd * dof_vel`` using the current observation terms.
        """
        actions_scaled = actions * self.params.action_scale
        output_torques = self.params.rl_kp * (actions_scaled + self.params.default_dof_pos - self.obs.dof_pos) - self.params.rl_kd * self.obs.dof_vel
        return output_torques

    def ComputePosition(self, actions):
        """Turn policy actions into joint position targets (scaled action + default pose)."""
        actions_scaled = actions * self.params.action_scale
        return actions_scaled + self.params.default_dof_pos

    def QuatRotateInverse(self, q, v):
        """Rotate ``v`` by the inverse of quaternion ``q``.

        Arguments:
            q: Tensor indexed as ``q[:, :3]`` (vector part) and ``q[:, -1]``
                (scalar part), i.e. the scalar is read from the last column.
            v: Tensor viewed as ``(q.shape[0], 3, 1)`` for the dot product.
        """
        shape = q.shape
        q_w = q[:, -1]
        q_vec = q[:, :3]
        a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
        b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
        c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
        return a - b + c

    def StateController(self, state, command):
        """Advance the controller state machine by one tick.

        Reads the measured ``state`` and the pending request in
        ``self.control.control_state``, fills ``command.motor_command`` for
        the current running state, and performs any requested transition.
        A consumed request is reset to STATE_WAITING.
        """
        # waiting
        if self.running_state == STATE.STATE_WAITING:
            # Hold the current pose: echo measured positions as targets.
            for i in range(self.params.num_of_dofs):
                command.motor_command.q[i] = state.motor_state.q[i]
            if self.control.control_state == STATE.STATE_POS_GETUP:
                self.control.control_state = STATE.STATE_WAITING
                self.getup_percent = 0.0
                # Remember the pose we start from; start_state is the target
                # of a later get-down.
                for i in range(self.params.num_of_dofs):
                    self.now_state.motor_state.q[i] = state.motor_state.q[i]
                    self.start_state.motor_state.q[i] = self.now_state.motor_state.q[i]
                self.running_state = STATE.STATE_POS_GETUP
                print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETUP")

        # stand up (position control)
        elif self.running_state == STATE.STATE_POS_GETUP:
            if self.getup_percent < 1.0:
                # 500 ticks to go from the recorded pose to the default pose.
                self.getup_percent += 1 / 500.0
                self.getup_percent = min(self.getup_percent, 1.0)
                for i in range(self.params.num_of_dofs):
                    command.motor_command.q[i] = (1 - self.getup_percent) * self.now_state.motor_state.q[i] + self.getup_percent * self.params.default_dof_pos[0][i].item()
                    command.motor_command.dq[i] = 0
                    command.motor_command.kp[i] = self.params.fixed_kp[0][i].item()
                    command.motor_command.kd[i] = self.params.fixed_kd[0][i].item()
                    command.motor_command.tau[i] = 0
                print("\r" + LOGGER.INFO + f"Getting up {self.getup_percent * 100.0:.1f}", end='', flush=True)
            if self.control.control_state == STATE.STATE_RL_INIT:
                self.control.control_state = STATE.STATE_WAITING
                self.running_state = STATE.STATE_RL_INIT
                print("\r\n" + LOGGER.INFO + "Switching to STATE_RL_INIT")
            elif self.control.control_state == STATE.STATE_POS_GETDOWN:
                self.control.control_state = STATE.STATE_WAITING
                self.getdown_percent = 0.0
                for i in range(self.params.num_of_dofs):
                    self.now_state.motor_state.q[i] = state.motor_state.q[i]
                self.running_state = STATE.STATE_POS_GETDOWN
                print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETDOWN")

        # init obs and start rl loop
        elif self.running_state == STATE.STATE_RL_INIT:
            # NOTE(review): nothing in this branch advances getup_percent, so
            # entering RL_INIT before the get-up has finished appears to
            # leave the controller waiting here - confirm this is intended.
            if self.getup_percent == 1:
                self.InitObservations()
                self.InitOutputs()
                self.InitControl()
                self.running_state = STATE.STATE_RL_RUNNING
                print("\r\n" + LOGGER.INFO + "Switching to STATE_RL_RUNNING")

        # rl loop
        # A fresh `if`, not `elif`: the tick that leaves STATE_RL_INIT also
        # runs the first RL step.
        if self.running_state == STATE.STATE_RL_RUNNING:
            print("\r" + LOGGER.INFO + f"RL Controller x: {self.control.x:.1f} y: {self.control.y:.1f} yaw: {self.control.yaw:.1f}", end='', flush=True)
            for i in range(self.params.num_of_dofs):
                command.motor_command.q[i] = self.output_dof_pos[0][i].item()
                command.motor_command.dq[i] = 0
                command.motor_command.kp[i] = self.params.rl_kp[0][i].item()
                command.motor_command.kd[i] = self.params.rl_kd[0][i].item()
                command.motor_command.tau[i] = 0
            if self.control.control_state == STATE.STATE_POS_GETDOWN:
                self.control.control_state = STATE.STATE_WAITING
                self.getdown_percent = 0.0
                for i in range(self.params.num_of_dofs):
                    self.now_state.motor_state.q[i] = state.motor_state.q[i]
                self.running_state = STATE.STATE_POS_GETDOWN
                print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETDOWN")
            elif self.control.control_state == STATE.STATE_POS_GETUP:
                self.control.control_state = STATE.STATE_WAITING
                self.getup_percent = 0.0
                for i in range(self.params.num_of_dofs):
                    self.now_state.motor_state.q[i] = state.motor_state.q[i]
                self.running_state = STATE.STATE_POS_GETUP
                print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETUP")

        # get down (position control)
        elif self.running_state == STATE.STATE_POS_GETDOWN:
            if self.getdown_percent < 1.0:
                # 500 ticks to go from the recorded pose back to start_state.
                self.getdown_percent += 1 / 500.0
                self.getdown_percent = min(1.0, self.getdown_percent)
                for i in range(self.params.num_of_dofs):
                    command.motor_command.q[i] = (1 - self.getdown_percent) * self.now_state.motor_state.q[i] + self.getdown_percent * self.start_state.motor_state.q[i]
                    command.motor_command.dq[i] = 0
                    command.motor_command.kp[i] = self.params.fixed_kp[0][i].item()
                    command.motor_command.kd[i] = self.params.fixed_kd[0][i].item()
                    command.motor_command.tau[i] = 0
                print("\r" + LOGGER.INFO + f"Getting down {self.getdown_percent * 100.0:.1f}", end='', flush=True)
            if self.getdown_percent == 1:
                self.InitObservations()
                self.InitOutputs()
                self.InitControl()
                self.running_state = STATE.STATE_WAITING
                print("\r\n" + LOGGER.INFO + "Switching to STATE_WAITING")

    def TorqueProtect(self, origin_output_torques):
        """Warn about torques beyond ``params.torque_limits`` and request a get-down.

        The torques themselves are not modified.
        """
        out_of_range_indices = []
        out_of_range_values = []
        for i in range(origin_output_torques.size(1)):
            torque_value = origin_output_torques[0][i].item()
            limit_lower = -self.params.torque_limits[0][i].item()
            limit_upper = self.params.torque_limits[0][i].item()
            if torque_value < limit_lower or torque_value > limit_upper:
                out_of_range_indices.append(i)
                out_of_range_values.append(torque_value)
        if out_of_range_indices:
            for i, index in enumerate(out_of_range_indices):
                value = out_of_range_values[i]
                limit_lower = -self.params.torque_limits[0][index].item()
                limit_upper = self.params.torque_limits[0][index].item()
                print(LOGGER.WARNING + f"Torque({index + 1})={value} out of range({limit_lower}, {limit_upper})")
            # Just a reminder, no protection
            self.control.control_state = STATE.STATE_POS_GETDOWN
            print(LOGGER.INFO + "Switching to STATE_POS_GETDOWN")

    def KeyboardInterface(self, key):
        """Key-press callback: map keys to state requests and velocity commands.

        '0' get up, 'p' start RL, '1' get down, 'w'/'s' x, 'j'/'l' y,
        'a'/'d' yaw, 'r' reset simulation, Enter toggle simulation,
        Space zero the velocity command.
        """
        try:
            if hasattr(key, 'char'):
                if key.char == '0':
                    self.control.control_state = STATE.STATE_POS_GETUP
                elif key.char == 'p':
                    self.control.control_state = STATE.STATE_RL_INIT
                elif key.char == '1':
                    self.control.control_state = STATE.STATE_POS_GETDOWN
                elif key.char == 'w':
                    self.control.x += 0.1
                elif key.char == 's':
                    self.control.x -= 0.1
                elif key.char == 'a':
                    self.control.yaw += 0.1
                elif key.char == 'd':
                    self.control.yaw -= 0.1
                elif key.char == 'j':
                    self.control.y += 0.1
                elif key.char == 'l':
                    self.control.y -= 0.1
                elif key.char == 'r':
                    self.control.control_state = STATE.STATE_RESET_SIMULATION
            else:
                if key == keyboard.Key.enter:
                    self.control.control_state = STATE.STATE_TOGGLE_SIMULATION
                elif key == keyboard.Key.space:
                    self.control.x = 0
                    self.control.y = 0
                    self.control.yaw = 0
        except AttributeError:
            # Best effort: ignore keys that lack the attributes read above.
            pass

    def ReadYaml(self, robot_name):
        """Load the ``robot_name`` section of CONFIG_PATH into ``self.params``.

        If the file does not exist an error is printed and the method returns
        with ``self.params`` left unchanged.
        """
        try:
            with open(CONFIG_PATH, 'r') as f:
                config = yaml.safe_load(f)[robot_name]
        except FileNotFoundError:
            # f-string so the actual path is shown rather than the literal
            # text "{CONFIG_PATH}".
            print(LOGGER.ERROR + f"The file '{CONFIG_PATH}' does not exist")
            return

        self.params.model_name = config["model_name"]
        self.params.dt = config["dt"]
        self.params.decimation = config["decimation"]
        self.params.num_observations = config["num_observations"]
        self.params.clip_obs = config["clip_obs"]
        self.params.action_scale = config["action_scale"]
        self.params.hip_scale_reduction = config["hip_scale_reduction"]
        self.params.hip_scale_reduction_indices = config["hip_scale_reduction_indices"]
        self.params.clip_actions_upper = torch.tensor(config["clip_actions_upper"]).view(1, -1)
        self.params.clip_actions_lower = torch.tensor(config["clip_actions_lower"]).view(1, -1)
        self.params.num_of_dofs = config["num_of_dofs"]
        self.params.lin_vel_scale = config["lin_vel_scale"]
        self.params.ang_vel_scale = config["ang_vel_scale"]
        self.params.dof_pos_scale = config["dof_pos_scale"]
        self.params.dof_vel_scale = config["dof_vel_scale"]
        # Commands are (x, y, yaw): two linear scales followed by one angular scale.
        self.params.commands_scale = torch.tensor([self.params.lin_vel_scale, self.params.lin_vel_scale, self.params.ang_vel_scale])
        self.params.rl_kp = torch.tensor(config["rl_kp"]).view(1, -1)
        self.params.rl_kd = torch.tensor(config["rl_kd"]).view(1, -1)
        self.params.fixed_kp = torch.tensor(config["fixed_kp"]).view(1, -1)
        self.params.fixed_kd = torch.tensor(config["fixed_kd"]).view(1, -1)
        self.params.torque_limits = torch.tensor(config["torque_limits"]).view(1, -1)
        self.params.default_dof_pos = torch.tensor(config["default_dof_pos"]).view(1, -1)
        self.params.joint_controller_names = config["joint_controller_names"]

View File

@ -0,0 +1,234 @@
import sys
import os
import torch
import threading
import time
import rospy
import numpy as np
from gazebo_msgs.msg import ModelStates
from sensor_msgs.msg import JointState
from geometry_msgs.msg import Twist, Pose
from robot_msgs.msg import MotorCommand
from gazebo_msgs.srv import SetModelState, SetModelStateRequest
from std_srvs.srv import Empty
# Make the sibling modules (rl_sdk, observation_buffer) importable.
# The first entry only works when the process is started from the catkin
# workspace root; it is kept for backward compatibility.
path = os.path.abspath(".")
sys.path.insert(0, path + "/src/rl_sar/scripts")
# Also add this script's own directory so the import does not depend on the
# current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from rl_sdk import *
from observation_buffer import *
class RL_Sim(RL):
    """Gazebo/ROS backend for the RL controller.

    Reads the robot state from ROS topics, publishes per-joint motor commands
    and runs the control and inference loops in two background threads.
    """

    def __init__(self):
        super().__init__()

        # member variables for RL_Sim
        self.vel = Twist()
        self.pose = Pose()
        self.cmd_vel = Twist()

        # start ros node
        rospy.init_node("rl_sim")

        # read params from yaml
        self.robot_name = rospy.get_param("robot_name", "")
        self.ReadYaml(self.robot_name)

        # history
        self.use_history = rospy.get_param("use_history", "")
        if self.use_history:
            self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, 6)

        # Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
        # the mapping table is established according to the order defined in the YAML file
        sorted_joint_controller_names = sorted(self.params.joint_controller_names)
        self.sorted_to_original_index = {}
        for i in range(len(self.params.joint_controller_names)):
            self.sorted_to_original_index[sorted_joint_controller_names[i]] = i
        self.mapped_joint_positions = [0.0] * self.params.num_of_dofs
        self.mapped_joint_velocities = [0.0] * self.params.num_of_dofs
        self.mapped_joint_efforts = [0.0] * self.params.num_of_dofs

        # init
        torch.set_grad_enabled(False)
        self.joint_publishers_commands = [MotorCommand() for _ in range(self.params.num_of_dofs)]
        self.InitObservations()
        self.InitOutputs()
        self.InitControl()

        # model
        model_path = os.path.join(os.path.dirname(__file__), f"../models/{self.robot_name}/{self.params.model_name}")
        self.model = torch.jit.load(model_path)

        # publisher
        self.ros_namespace = rospy.get_param("ros_namespace", "")
        self.joint_publishers = {}
        for i in range(self.params.num_of_dofs):
            topic_name = f"{self.ros_namespace}{self.params.joint_controller_names[i]}/command"
            self.joint_publishers[self.params.joint_controller_names[i]] = rospy.Publisher(topic_name, MotorCommand, queue_size=10)

        # subscriber
        self.cmd_vel_subscriber = rospy.Subscriber("/cmd_vel", Twist, self.CmdvelCallback, queue_size=10)
        self.model_state_subscriber = rospy.Subscriber("/gazebo/model_states", ModelStates, self.ModelStatesCallback, queue_size=10)
        joint_states_topic = f"{self.ros_namespace}joint_states"
        self.joint_state_subscriber = rospy.Subscriber(joint_states_topic, JointState, self.JointStatesCallback, queue_size=10)

        # service
        self.gazebo_set_model_state_client = rospy.ServiceProxy("/gazebo/set_model_state", SetModelState)
        self.gazebo_pause_physics_client = rospy.ServiceProxy("/gazebo/pause_physics", Empty)
        self.gazebo_unpause_physics_client = rospy.ServiceProxy("/gazebo/unpause_physics", Empty)

        # loops
        self.thread_control = threading.Thread(target=self.ThreadControl)
        self.thread_rl = threading.Thread(target=self.ThreadRL)
        self.thread_control.start()
        self.thread_rl.start()

        # keyboard
        self.listener_keyboard = keyboard.Listener(on_press=self.KeyboardInterface)
        self.listener_keyboard.start()

        print(LOGGER.INFO + "RL_Sim start")

    def __del__(self):
        print(LOGGER.INFO + "RL_Sim exit")

    def GetState(self, state):
        """Copy the latest pose, twist and mapped joint data into ``state``."""
        # The quaternion is stored as (x, y, z, w): w goes to index 3.
        state.imu.quaternion[3] = self.pose.orientation.w
        state.imu.quaternion[0] = self.pose.orientation.x
        state.imu.quaternion[1] = self.pose.orientation.y
        state.imu.quaternion[2] = self.pose.orientation.z

        state.imu.gyroscope[0] = self.vel.angular.x
        state.imu.gyroscope[1] = self.vel.angular.y
        state.imu.gyroscope[2] = self.vel.angular.z

        # state.imu.accelerometer

        for i in range(self.params.num_of_dofs):
            state.motor_state.q[i] = self.mapped_joint_positions[i]
            state.motor_state.dq[i] = self.mapped_joint_velocities[i]
            state.motor_state.tauEst[i] = self.mapped_joint_efforts[i]

    def SetCommand(self, command):
        """Publish ``command`` to every joint controller topic."""
        for i in range(self.params.num_of_dofs):
            self.joint_publishers_commands[i].q = command.motor_command.q[i]
            self.joint_publishers_commands[i].dq = command.motor_command.dq[i]
            self.joint_publishers_commands[i].kp = command.motor_command.kp[i]
            self.joint_publishers_commands[i].kd = command.motor_command.kd[i]
            self.joint_publishers_commands[i].tau = command.motor_command.tau[i]

        for i in range(self.params.num_of_dofs):
            self.joint_publishers[self.params.joint_controller_names[i]].publish(self.joint_publishers_commands[i])

    def RobotControl(self):
        """One control tick: serve simulation requests, then run the state machine."""
        if self.control.control_state == STATE.STATE_RESET_SIMULATION:
            set_model_state = SetModelStateRequest().model_state
            gazebo_model_name = f"{self.robot_name}_gazebo"
            set_model_state.model_name = gazebo_model_name
            set_model_state.pose.position.z = 1.0
            set_model_state.reference_frame = "world"
            self.gazebo_set_model_state_client(set_model_state)
            self.control.control_state = STATE.STATE_WAITING

        if self.control.control_state == STATE.STATE_TOGGLE_SIMULATION:
            if self.simulation_running:
                self.gazebo_pause_physics_client()
                print("\r\n" + LOGGER.INFO + "Simulation Stop")
            else:
                self.gazebo_unpause_physics_client()
                print("\r\n" + LOGGER.INFO + "Simulation Start")
            self.simulation_running = not self.simulation_running
            self.control.control_state = STATE.STATE_WAITING

        if self.simulation_running:
            self.GetState(self.robot_state)
            self.StateController(self.robot_state, self.robot_command)
            self.SetCommand(self.robot_command)

    def ModelStatesCallback(self, msg):
        """Store the robot's twist and pose from a ModelStates message.

        The robot is located by the same model name RobotControl uses when
        resetting it (``<robot_name>_gazebo``). If that name is absent, the
        previously hard-coded index 2 is used as a fallback. A message too
        short for the chosen index is ignored and the last values are kept.
        """
        gazebo_model_name = f"{self.robot_name}_gazebo"
        model_names = list(msg.name)
        if gazebo_model_name in model_names:
            index = model_names.index(gazebo_model_name)
        else:
            index = 2  # legacy fallback: position of the robot in the default world
        if index >= len(msg.twist) or index >= len(msg.pose):
            return
        self.vel = msg.twist[index]
        self.pose = msg.pose[index]

    def CmdvelCallback(self, msg):
        """Store the latest /cmd_vel message."""
        self.cmd_vel = msg

    def MapData(self, source_data, target_data):
        """Reorder joint data from the published order into the YAML-defined order."""
        for i in range(len(source_data)):
            target_data[i] = source_data[self.sorted_to_original_index[self.params.joint_controller_names[i]]]

    def JointStatesCallback(self, msg):
        """Remap joint positions, velocities and efforts from a JointState message."""
        self.MapData(msg.position, self.mapped_joint_positions)
        self.MapData(msg.velocity, self.mapped_joint_velocities)
        self.MapData(msg.effort, self.mapped_joint_efforts)

    def RunModel(self):
        """One inference tick: refresh observations, run the policy, update outputs.

        Does nothing unless the controller is in STATE_RL_RUNNING and the
        simulation is running.
        """
        if self.running_state == STATE.STATE_RL_RUNNING and self.simulation_running:
            # self.obs.lin_vel = torch.tensor([[self.vel.linear.x, self.vel.linear.y, self.vel.linear.z]])
            self.obs.ang_vel = torch.tensor(self.robot_state.imu.gyroscope).unsqueeze(0)
            # self.obs.commands = torch.tensor([[self.cmd_vel.linear.x, self.cmd_vel.linear.y, self.cmd_vel.angular.z]])
            self.obs.commands = torch.tensor([[self.control.x, self.control.y, self.control.yaw]])
            self.obs.base_quat = torch.tensor(self.robot_state.imu.quaternion).unsqueeze(0)
            # The state arrays are pre-sized to 32; keep only the first num_of_dofs entries.
            self.obs.dof_pos = torch.tensor(self.robot_state.motor_state.q).narrow(0, 0, self.params.num_of_dofs).unsqueeze(0)
            self.obs.dof_vel = torch.tensor(self.robot_state.motor_state.dq).narrow(0, 0, self.params.num_of_dofs).unsqueeze(0)

            clamped_actions = self.Forward()

            for i in self.params.hip_scale_reduction_indices:
                clamped_actions[0][i] *= self.params.hip_scale_reduction

            self.obs.actions = clamped_actions

            origin_output_torques = self.ComputeTorques(self.obs.actions)

            # self.TorqueProtect(origin_output_torques)

            self.output_torques = torch.clamp(origin_output_torques, -(self.params.torque_limits), self.params.torque_limits)
            self.output_dof_pos = self.ComputePosition(self.obs.actions)

    def ComputeObservation(self):
        """Assemble the policy observation vector and clamp it to +/- clip_obs."""
        obs = torch.cat([
            # self.obs.lin_vel * self.params.lin_vel_scale,
            # self.obs.ang_vel * self.params.ang_vel_scale, # TODO is QuatRotateInverse necessary?
            self.QuatRotateInverse(self.obs.base_quat, self.obs.ang_vel) * self.params.ang_vel_scale,
            self.QuatRotateInverse(self.obs.base_quat, self.obs.gravity_vec),
            self.obs.commands * self.params.commands_scale,
            (self.obs.dof_pos - self.params.default_dof_pos) * self.params.dof_pos_scale,
            self.obs.dof_vel * self.params.dof_vel_scale,
            self.obs.actions
        ], dim = -1)
        clamped_obs = torch.clamp(obs, -self.params.clip_obs, self.params.clip_obs)
        return clamped_obs

    def Forward(self):
        """Run the policy on the current observation and return clamped actions.

        With ``use_history`` enabled the observation is pushed into the
        history buffer and the model receives the six most recent
        observations; otherwise it receives the current observation only.
        """
        torch.set_grad_enabled(False)
        clamped_obs = self.ComputeObservation()
        if self.use_history:
            self.history_obs_buf.insert(clamped_obs)
            history_obs = self.history_obs_buf.get_obs_vec(np.arange(6))
            actions = self.model.forward(history_obs)
        else:
            actions = self.model.forward(clamped_obs)
        clamped_actions = torch.clamp(actions, self.params.clip_actions_lower, self.params.clip_actions_upper)
        return clamped_actions

    def ThreadControl(self):
        """Background loop: call RobotControl every ``params.dt`` seconds until ROS shuts down."""
        thread_period = self.params.dt
        thread_name = "thread_control"
        print(f"[Thread Start] named: {thread_name}, period: {thread_period * 1000:.0f}(ms), cpu unspecified")
        while not rospy.is_shutdown():
            self.RobotControl()
            time.sleep(thread_period)
        print("[Thread End] named: " + thread_name)

    def ThreadRL(self):
        """Background loop: call RunModel every ``dt * decimation`` seconds until ROS shuts down."""
        thread_period = self.params.dt * self.params.decimation
        thread_name = "thread_rl"
        print(f"[Thread Start] named: {thread_name}, period: {thread_period * 1000:.0f}(ms), cpu unspecified")
        while not rospy.is_shutdown():
            self.RunModel()
            time.sleep(thread_period)
        print("[Thread End] named: " + thread_name)
# Script entry point: build the simulator backend (which starts its own
# worker threads) and hand control to the ROS event loop.
if __name__ == "__main__":
    rl_sim = RL_Sim()
    rospy.spin()

View File

@ -61,13 +61,15 @@ RL_Sim::RL_Sim()
this->gazebo_unpause_physics_client = nh.serviceClient<std_srvs::Empty>("/gazebo/unpause_physics"); this->gazebo_unpause_physics_client = nh.serviceClient<std_srvs::Empty>("/gazebo/unpause_physics");
// loop // loop
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this)); this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this));
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this)); this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this));
this->loop_keyboard->start();
this->loop_control->start(); this->loop_control->start();
this->loop_rl->start(); this->loop_rl->start();
// keyboard
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this));
this->loop_keyboard->start();
#ifdef PLOT #ifdef PLOT
this->plot_t = std::vector<int>(this->plot_size, 0); this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_real_joint_pos.resize(this->params.num_of_dofs);
@ -80,6 +82,8 @@ RL_Sim::RL_Sim()
#ifdef CSV_LOGGER #ifdef CSV_LOGGER
this->CSVInit(this->robot_name); this->CSVInit(this->robot_name);
#endif #endif
std::cout << LOGGER::INFO << "RL_Sim start" << std::endl;
} }
RL_Sim::~RL_Sim() RL_Sim::~RL_Sim()
@ -150,10 +154,12 @@ void RL_Sim::RobotControl()
if(simulation_running) if(simulation_running)
{ {
this->gazebo_pause_physics_client.call(empty); this->gazebo_pause_physics_client.call(empty);
std::cout << std::endl << LOGGER::INFO << "Simulation Stop" << std::endl;
} }
else else
{ {
this->gazebo_unpause_physics_client.call(empty); this->gazebo_unpause_physics_client.call(empty);
std::cout << std::endl << LOGGER::INFO << "Simulation Start" << std::endl;
} }
simulation_running = !simulation_running; simulation_running = !simulation_running;
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
@ -230,15 +236,16 @@ void RL_Sim::RunModel()
torch::Tensor RL_Sim::ComputeObservation() torch::Tensor RL_Sim::ComputeObservation()
{ {
torch::Tensor obs = torch::cat({// this->obs.lin_vel * this->params.lin_vel_scale, torch::Tensor obs = torch::cat({
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, // this->obs.lin_vel * this->params.lin_vel_scale,
// this->obs.ang_vel * this->params.ang_vel_scale, // TODO // this->obs.ang_vel * this->params.ang_vel_scale, // TODO is QuatRotateInverse necessery?
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
this->obs.commands * this->params.commands_scale, this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, this->obs.commands * this->params.commands_scale,
this->obs.dof_vel * this->params.dof_vel_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.actions this->obs.dof_vel * this->params.dof_vel_scale,
},1); this->obs.actions
},1);
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs; return clamped_obs;
} }
@ -246,11 +253,8 @@ torch::Tensor RL_Sim::ComputeObservation()
torch::Tensor RL_Sim::Forward() torch::Tensor RL_Sim::Forward()
{ {
torch::autograd::GradMode::set_enabled(false); torch::autograd::GradMode::set_enabled(false);
torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions; torch::Tensor actions;
if(this->use_history) if(this->use_history)
{ {
this->history_obs_buf.insert(clamped_obs); this->history_obs_buf.insert(clamped_obs);
@ -263,7 +267,6 @@ torch::Tensor RL_Sim::Forward()
} }
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
return clamped_actions; return clamped_actions;
} }
@ -297,12 +300,8 @@ void signalHandler(int signum)
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
signal(SIGINT, signalHandler); signal(SIGINT, signalHandler);
ros::init(argc, argv, "rl_sar"); ros::init(argc, argv, "rl_sar");
RL_Sim rl_sar; RL_Sim rl_sar;
ros::spin(); ros::spin();
return 0; return 0;
} }