mirror of https://github.com/fan-ziqi/rl_sar.git
feat: add RL python sdk
This commit is contained in:
parent
9d2ed695cc
commit
c153746bce
|
@ -8,4 +8,4 @@ logs
|
||||||
*fldlar*
|
*fldlar*
|
||||||
.cache
|
.cache
|
||||||
*.json
|
*.json
|
||||||
# *gr1t1*
|
__pycache__
|
12
README.md
12
README.md
|
@ -75,14 +75,22 @@ Before running, copy the trained pt model file to `rl_sar/src/rl_sar/models/YOUR
|
||||||
|
|
||||||
### Simulation
|
### Simulation
|
||||||
|
|
||||||
Open a new terminal, launch the gazebo simulation environment
|
Open a terminal, launch the gazebo simulation environment
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source devel/setup.bash
|
source devel/setup.bash
|
||||||
roslaunch rl_sar gazebo_<ROBOT>.launch
|
roslaunch rl_sar gazebo_<ROBOT>.launch
|
||||||
```
|
```
|
||||||
|
|
||||||
Where \<ROBOT\> can be `a1` or `gr1t1`.
|
Open a new terminal, launch the control program
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source devel/setup.bash
|
||||||
|
(for cpp version) rosrun rl_sar rl_sim
|
||||||
|
(for python version) rosrun rl_sar rl_sim.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Where \<ROBOT\> can be `a1` or `gr1t1` or `gr1t2`.
|
||||||
|
|
||||||
Control:
|
Control:
|
||||||
* Press **\<Enter\>** to toggle simulation start/stop.
|
* Press **\<Enter\>** to toggle simulation start/stop.
|
||||||
|
|
12
README_CN.md
12
README_CN.md
|
@ -75,14 +75,22 @@ catkin build
|
||||||
|
|
||||||
### 仿真
|
### 仿真
|
||||||
|
|
||||||
新建终端,启动gazebo仿真环境
|
打开一个终端,启动gazebo仿真环境
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source devel/setup.bash
|
source devel/setup.bash
|
||||||
roslaunch rl_sar gazebo_<ROBOT>.launch
|
roslaunch rl_sar gazebo_<ROBOT>.launch
|
||||||
```
|
```
|
||||||
|
|
||||||
其中 \<ROBOT\> 可以是 `a1` 或 `gr1t1`.
|
打开一个新终端,启动控制程序
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source devel/setup.bash
|
||||||
|
(for cpp version) rosrun rl_sar rl_sim
|
||||||
|
(for python version) rosrun rl_sar rl_sim.py
|
||||||
|
```
|
||||||
|
|
||||||
|
其中 \<ROBOT\> 可以是 `a1` 或 `gr1t1` 或 `gr1t2`.
|
||||||
|
|
||||||
控制:
|
控制:
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ find_package(catkin REQUIRED COMPONENTS
|
||||||
geometry_msgs
|
geometry_msgs
|
||||||
robot_msgs
|
robot_msgs
|
||||||
robot_joint_controller
|
robot_joint_controller
|
||||||
|
rospy
|
||||||
)
|
)
|
||||||
|
|
||||||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
||||||
|
@ -37,6 +38,7 @@ include_directories(${YAML_CPP_INCLUDE_DIR})
|
||||||
catkin_package(
|
catkin_package(
|
||||||
CATKIN_DEPENDS
|
CATKIN_DEPENDS
|
||||||
robot_joint_controller
|
robot_joint_controller
|
||||||
|
rospy
|
||||||
)
|
)
|
||||||
|
|
||||||
include_directories(library/unitree_legged_sdk_3.2/include)
|
include_directories(library/unitree_legged_sdk_3.2/include)
|
||||||
|
@ -78,3 +80,9 @@ target_link_libraries(rl_real_a1
|
||||||
${catkin_LIBRARIES} ${EXTRA_LIBS}
|
${catkin_LIBRARIES} ${EXTRA_LIBS}
|
||||||
rl_sdk observation_buffer yaml-cpp
|
rl_sdk observation_buffer yaml-cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
catkin_install_python(PROGRAMS
|
||||||
|
scripts/rl_sim.py
|
||||||
|
scripts/rl_sdk.py
|
||||||
|
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
|
||||||
|
)
|
|
@ -1,7 +1,7 @@
|
||||||
<launch>
|
<launch>
|
||||||
<arg name="wname" default="stairs"/>
|
<arg name="wname" default="stairs"/>
|
||||||
<arg name="rname" default="a1"/>
|
<arg name="rname" default="a1"/>
|
||||||
<param name="robot_name" type="str" value="a1"/>
|
<param name="robot_name" type="str" value="$(arg rname)"/>
|
||||||
<param name="use_history" type="bool" value="true"/>
|
<param name="use_history" type="bool" value="true"/>
|
||||||
<param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/>
|
<param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/>
|
||||||
<arg name="robot_path" value="(find $(arg rname)_description)"/>
|
<arg name="robot_path" value="(find $(arg rname)_description)"/>
|
||||||
|
@ -37,8 +37,6 @@
|
||||||
<!-- Load joint controller configurations from YAML file to parameter server -->
|
<!-- Load joint controller configurations from YAML file to parameter server -->
|
||||||
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
|
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
|
||||||
|
|
||||||
<!-- <rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam> -->
|
|
||||||
|
|
||||||
<!-- load the controllers -->
|
<!-- load the controllers -->
|
||||||
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
|
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
|
||||||
output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller
|
output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller
|
||||||
|
@ -53,6 +51,4 @@
|
||||||
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
|
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
|
||||||
</node>
|
</node>
|
||||||
|
|
||||||
<node pkg="rl_sar" type="rl_sim" name="rl_sim" output="screen"/>
|
|
||||||
|
|
||||||
</launch>
|
</launch>
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
<launch>
|
<launch>
|
||||||
<arg name="wname" default="stairs"/>
|
<arg name="wname" default="stairs"/>
|
||||||
<arg name="rname" default="gr1t1"/>
|
<arg name="rname" default="gr1t1"/>
|
||||||
<param name="robot_name" type="str" value="gr1t1"/>
|
<param name="robot_name" type="str" value="$(arg rname)"/>
|
||||||
<param name="use_history" type="bool" value="false"/>
|
<param name="use_history" type="bool" value="false"/>
|
||||||
<param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/>
|
<param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/>
|
||||||
<arg name="robot_path" value="(find $(arg rname)_description)"/>
|
<arg name="robot_path" value="(find $(arg rname)_description)"/>
|
||||||
|
@ -33,8 +33,6 @@
|
||||||
<!-- Load joint controller configurations from YAML file to parameter server -->
|
<!-- Load joint controller configurations from YAML file to parameter server -->
|
||||||
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
|
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>
|
||||||
|
|
||||||
<!-- <rosparam param="/gr1t1_gazebo/joint_state_controller/publish_rate">5000</rosparam> -->
|
|
||||||
|
|
||||||
<!-- load the controllers -->
|
<!-- load the controllers -->
|
||||||
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
|
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
|
||||||
output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller
|
output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller
|
||||||
|
@ -47,6 +45,4 @@
|
||||||
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
|
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
|
||||||
</node>
|
</node>
|
||||||
|
|
||||||
<node pkg="rl_sar" type="rl_sim" name="rl_sim" output="screen"/>
|
|
||||||
|
|
||||||
</launch>
|
</launch>
|
||||||
|
|
|
@ -1,31 +1,28 @@
|
||||||
#include "rl_sdk.hpp"
|
#include "rl_sdk.hpp"
|
||||||
|
|
||||||
/* You may need to override this ComputeObservation() function
|
/* You may need to override this ComputeObservation() function
|
||||||
torch::Tensor RL::ComputeObservation()
|
torch::Tensor RL_XXX::ComputeObservation()
|
||||||
{
|
{
|
||||||
torch::Tensor obs = torch::cat({this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
|
torch::Tensor obs = torch::cat({
|
||||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
|
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
|
||||||
this->obs.commands * this->params.commands_scale,
|
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
this->obs.commands * this->params.commands_scale,
|
||||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||||
this->obs.actions
|
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||||
},1);
|
this->obs.actions
|
||||||
|
},1);
|
||||||
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
||||||
return clamped_obs;
|
return clamped_obs;
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/* You may need to override this Forward() function
|
/* You may need to override this Forward() function
|
||||||
torch::Tensor RL::Forward()
|
torch::Tensor RL_XXX::Forward()
|
||||||
{
|
{
|
||||||
torch::autograd::GradMode::set_enabled(false);
|
torch::autograd::GradMode::set_enabled(false);
|
||||||
|
|
||||||
torch::Tensor clamped_obs = this->ComputeObservation();
|
torch::Tensor clamped_obs = this->ComputeObservation();
|
||||||
|
|
||||||
torch::Tensor actions = this->model.forward({clamped_obs}).toTensor();
|
torch::Tensor actions = this->model.forward({clamped_obs}).toTensor();
|
||||||
|
|
||||||
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
|
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
|
||||||
|
|
||||||
return clamped_actions;
|
return clamped_actions;
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -25,6 +25,8 @@
|
||||||
<exec_depend>robot_state_publisher</exec_depend>
|
<exec_depend>robot_state_publisher</exec_depend>
|
||||||
<exec_depend>roscpp</exec_depend>
|
<exec_depend>roscpp</exec_depend>
|
||||||
<exec_depend>std_msgs</exec_depend>
|
<exec_depend>std_msgs</exec_depend>
|
||||||
|
<build_depend>rospy</build_depend>
|
||||||
|
<exec_depend>rospy</exec_depend>
|
||||||
<depend>robot_msgs</depend>
|
<depend>robot_msgs</depend>
|
||||||
<depend>robot_joint_controller</depend>
|
<depend>robot_joint_controller</depend>
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class ObservationBuffer:
|
||||||
|
def __init__(self, num_envs, num_obs, include_history_steps):
|
||||||
|
|
||||||
|
self.num_envs = num_envs
|
||||||
|
self.num_obs = num_obs
|
||||||
|
self.include_history_steps = include_history_steps
|
||||||
|
|
||||||
|
self.num_obs_total = num_obs * include_history_steps
|
||||||
|
|
||||||
|
self.obs_buf = torch.zeros(self.num_envs, self.num_obs_total, dtype=torch.float)
|
||||||
|
|
||||||
|
def reset(self, reset_idxs, new_obs):
|
||||||
|
self.obs_buf[reset_idxs] = new_obs.repeat(1, self.include_history_steps)
|
||||||
|
|
||||||
|
def insert(self, new_obs):
|
||||||
|
# Shift observations back.
|
||||||
|
self.obs_buf[:, : self.num_obs * (self.include_history_steps - 1)] = self.obs_buf[:,self.num_obs : self.num_obs * self.include_history_steps].clone()
|
||||||
|
|
||||||
|
# Add new observation.
|
||||||
|
self.obs_buf[:, -self.num_obs:] = new_obs
|
||||||
|
|
||||||
|
def get_obs_vec(self, obs_ids):
|
||||||
|
"""Gets history of observations indexed by obs_ids.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
obs_ids: An array of integers with which to index the desired
|
||||||
|
observations, where 0 is the latest observation and
|
||||||
|
include_history_steps - 1 is the oldest observation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
obs = []
|
||||||
|
for obs_id in reversed(sorted(obs_ids)):
|
||||||
|
slice_idx = self.include_history_steps - obs_id - 1
|
||||||
|
obs.append(self.obs_buf[:, slice_idx * self.num_obs : (slice_idx + 1) * self.num_obs])
|
||||||
|
return torch.cat(obs, dim=-1)
|
|
@ -0,0 +1,356 @@
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from pynput import keyboard
|
||||||
|
from enum import Enum, auto
|
||||||
|
|
||||||
|
CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../config.yaml")
|
||||||
|
|
||||||
|
class LOGGER:
|
||||||
|
INFO = "\033[0;37m[INFO]\033[0m "
|
||||||
|
WARNING = "\033[0;33m[WARNING]\033[0m "
|
||||||
|
ERROR = "\033[0;31m[ERROR]\033[0m "
|
||||||
|
DEBUG = "\033[0;32m[DEBUG]\033[0m "
|
||||||
|
|
||||||
|
class RobotCommand:
|
||||||
|
def __init__(self):
|
||||||
|
self.motor_command = self.MotorCommand()
|
||||||
|
|
||||||
|
class MotorCommand:
|
||||||
|
def __init__(self):
|
||||||
|
self.q = [0.0] * 32
|
||||||
|
self.dq = [0.0] * 32
|
||||||
|
self.tau = [0.0] * 32
|
||||||
|
self.kp = [0.0] * 32
|
||||||
|
self.kd = [0.0] * 32
|
||||||
|
|
||||||
|
class RobotState:
|
||||||
|
def __init__(self):
|
||||||
|
self.imu = self.IMU()
|
||||||
|
self.motor_state = self.MotorState()
|
||||||
|
|
||||||
|
class IMU:
|
||||||
|
def __init__(self):
|
||||||
|
self.quaternion = [1.0, 0.0, 0.0, 0.0] # w, x, y, z
|
||||||
|
self.gyroscope = [0.0, 0.0, 0.0]
|
||||||
|
self.accelerometer = [0.0, 0.0, 0.0]
|
||||||
|
|
||||||
|
class MotorState:
|
||||||
|
def __init__(self):
|
||||||
|
self.q = [0.0] * 32
|
||||||
|
self.dq = [0.0] * 32
|
||||||
|
self.ddq = [0.0] * 32
|
||||||
|
self.tauEst = [0.0] * 32
|
||||||
|
self.cur = [0.0] * 32
|
||||||
|
|
||||||
|
class STATE(Enum):
|
||||||
|
STATE_WAITING = 0
|
||||||
|
STATE_POS_GETUP = auto()
|
||||||
|
STATE_RL_INIT = auto()
|
||||||
|
STATE_RL_RUNNING = auto()
|
||||||
|
STATE_POS_GETDOWN = auto()
|
||||||
|
STATE_RESET_SIMULATION = auto()
|
||||||
|
STATE_TOGGLE_SIMULATION = auto()
|
||||||
|
|
||||||
|
class Control:
|
||||||
|
def __init__(self):
|
||||||
|
self.control_state = STATE.STATE_WAITING
|
||||||
|
self.x = 0.0
|
||||||
|
self.y = 0.0
|
||||||
|
self.yaw = 0.0
|
||||||
|
|
||||||
|
class ModelParams:
|
||||||
|
def __init__(self):
|
||||||
|
self.model_name = None
|
||||||
|
self.dt = None
|
||||||
|
self.decimation = None
|
||||||
|
self.num_observations = None
|
||||||
|
self.damping = None
|
||||||
|
self.stiffness = None
|
||||||
|
self.action_scale = None
|
||||||
|
self.hip_scale_reduction = None
|
||||||
|
self.hip_scale_reduction_indices = None
|
||||||
|
self.clip_actions_upper = None
|
||||||
|
self.clip_actions_lower = None
|
||||||
|
self.num_of_dofs = None
|
||||||
|
self.lin_vel_scale = None
|
||||||
|
self.ang_vel_scale = None
|
||||||
|
self.dof_pos_scale = None
|
||||||
|
self.dof_vel_scale = None
|
||||||
|
self.clip_obs = None
|
||||||
|
self.torque_limits = None
|
||||||
|
self.rl_kd = None
|
||||||
|
self.rl_kp = None
|
||||||
|
self.fixed_kp = None
|
||||||
|
self.fixed_kd = None
|
||||||
|
self.commands_scale = None
|
||||||
|
self.default_dof_pos = None
|
||||||
|
self.joint_controller_names = None
|
||||||
|
|
||||||
|
class Observations:
|
||||||
|
def __init__(self):
|
||||||
|
self.lin_vel = None
|
||||||
|
self.ang_vel = None
|
||||||
|
self.gravity_vec = None
|
||||||
|
self.commands = None
|
||||||
|
self.base_quat = None
|
||||||
|
self.dof_pos = None
|
||||||
|
self.dof_vel = None
|
||||||
|
self.actions = None
|
||||||
|
|
||||||
|
class RL:
|
||||||
|
# Static variables
|
||||||
|
start_state = RobotState()
|
||||||
|
now_state = RobotState()
|
||||||
|
getup_percent = 0.0
|
||||||
|
getdown_percent = 0.0
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
### public in cpp ###
|
||||||
|
self.params = ModelParams()
|
||||||
|
self.obs = Observations()
|
||||||
|
|
||||||
|
self.robot_state = RobotState()
|
||||||
|
self.robot_command = RobotCommand()
|
||||||
|
|
||||||
|
# control
|
||||||
|
self.control = Control()
|
||||||
|
|
||||||
|
# others
|
||||||
|
self.robot_name = ""
|
||||||
|
self.running_state = STATE.STATE_RL_RUNNING # default running_state set to STATE_RL_RUNNING
|
||||||
|
self.simulation_running = False
|
||||||
|
|
||||||
|
### protected in cpp ###
|
||||||
|
# rl module
|
||||||
|
self.model = None
|
||||||
|
self.walk_model = None
|
||||||
|
self.stand_model = None
|
||||||
|
|
||||||
|
# output buffer
|
||||||
|
self.output_torques = torch.zeros(1, 32)
|
||||||
|
self.output_dof_pos = torch.zeros(1, 32)
|
||||||
|
|
||||||
|
def InitObservations(self):
|
||||||
|
self.obs.lin_vel = torch.zeros(1, 3, dtype=torch.float)
|
||||||
|
self.obs.ang_vel = torch.zeros(1, 3, dtype=torch.float)
|
||||||
|
self.obs.gravity_vec = torch.tensor([[0.0, 0.0, -1.0]])
|
||||||
|
self.obs.commands = torch.zeros(1, 3, dtype=torch.float)
|
||||||
|
self.obs.base_quat = torch.zeros(1, 4, dtype=torch.float)
|
||||||
|
self.obs.dof_pos = self.params.default_dof_pos
|
||||||
|
self.obs.dof_vel = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float)
|
||||||
|
self.obs.actions = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float)
|
||||||
|
|
||||||
|
def InitOutputs(self):
|
||||||
|
self.output_torques = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float)
|
||||||
|
self.output_dof_pos = self.params.default_dof_pos
|
||||||
|
|
||||||
|
def InitControl(self):
|
||||||
|
self.control.control_state = STATE.STATE_WAITING
|
||||||
|
self.control.x = 0.0
|
||||||
|
self.control.y = 0.0
|
||||||
|
self.control.yaw = 0.0
|
||||||
|
|
||||||
|
def ComputeTorques(self, actions):
|
||||||
|
actions_scaled = actions * self.params.action_scale
|
||||||
|
output_torques = self.params.rl_kp * (actions_scaled + self.params.default_dof_pos - self.obs.dof_pos) - self.params.rl_kd * self.obs.dof_vel
|
||||||
|
return output_torques
|
||||||
|
|
||||||
|
def ComputePosition(self, actions):
|
||||||
|
actions_scaled = actions * self.params.action_scale
|
||||||
|
return actions_scaled + self.params.default_dof_pos
|
||||||
|
|
||||||
|
def QuatRotateInverse(self, q, v):
|
||||||
|
shape = q.shape
|
||||||
|
q_w = q[:, -1]
|
||||||
|
q_vec = q[:, :3]
|
||||||
|
a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
|
||||||
|
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
|
||||||
|
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
|
||||||
|
return a - b + c
|
||||||
|
|
||||||
|
def StateController(self, state, command):
|
||||||
|
# waiting
|
||||||
|
if self.running_state == STATE.STATE_WAITING:
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
command.motor_command.q[i] = state.motor_state.q[i]
|
||||||
|
if self.control.control_state == STATE.STATE_POS_GETUP:
|
||||||
|
self.control.control_state = STATE.STATE_WAITING
|
||||||
|
self.getup_percent = 0.0
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
self.now_state.motor_state.q[i] = state.motor_state.q[i]
|
||||||
|
self.start_state.motor_state.q[i] = self.now_state.motor_state.q[i]
|
||||||
|
self.running_state = STATE.STATE_POS_GETUP
|
||||||
|
print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETUP")
|
||||||
|
|
||||||
|
# stand up (position control)
|
||||||
|
elif self.running_state == STATE.STATE_POS_GETUP:
|
||||||
|
if self.getup_percent < 1.0:
|
||||||
|
self.getup_percent += 1 / 500.0
|
||||||
|
self.getup_percent = min(self.getup_percent, 1.0)
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
command.motor_command.q[i] = (1 - self.getup_percent) * self.now_state.motor_state.q[i] + self.getup_percent * self.params.default_dof_pos[0][i].item()
|
||||||
|
command.motor_command.dq[i] = 0
|
||||||
|
command.motor_command.kp[i] = self.params.fixed_kp[0][i].item()
|
||||||
|
command.motor_command.kd[i] = self.params.fixed_kd[0][i].item()
|
||||||
|
command.motor_command.tau[i] = 0
|
||||||
|
print("\r" + LOGGER.INFO + f"Getting up {self.getup_percent * 100.0:.1f}", end='', flush=True)
|
||||||
|
|
||||||
|
if self.control.control_state == STATE.STATE_RL_INIT:
|
||||||
|
self.control.control_state = STATE.STATE_WAITING
|
||||||
|
self.running_state = STATE.STATE_RL_INIT
|
||||||
|
print("\r\n" + LOGGER.INFO + "Switching to STATE_RL_INIT")
|
||||||
|
|
||||||
|
elif self.control.control_state == STATE.STATE_POS_GETDOWN:
|
||||||
|
self.control.control_state = STATE.STATE_WAITING
|
||||||
|
self.getdown_percent = 0.0
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
self.now_state.motor_state.q[i] = state.motor_state.q[i]
|
||||||
|
self.running_state = STATE.STATE_POS_GETDOWN
|
||||||
|
print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETDOWN")
|
||||||
|
|
||||||
|
# init obs and start rl loop
|
||||||
|
elif self.running_state == STATE.STATE_RL_INIT:
|
||||||
|
if self.getup_percent == 1:
|
||||||
|
self.InitObservations()
|
||||||
|
self.InitOutputs()
|
||||||
|
self.InitControl()
|
||||||
|
self.running_state = STATE.STATE_RL_RUNNING
|
||||||
|
print("\r\n" + LOGGER.INFO + "Switching to STATE_RL_RUNNING")
|
||||||
|
|
||||||
|
# rl loop
|
||||||
|
if self.running_state == STATE.STATE_RL_RUNNING:
|
||||||
|
print("\r" + LOGGER.INFO + f"RL Controller x: {self.control.x:.1f} y: {self.control.y:.1f} yaw: {self.control.yaw:.1f}", end='', flush=True)
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
command.motor_command.q[i] = self.output_dof_pos[0][i].item()
|
||||||
|
command.motor_command.dq[i] = 0
|
||||||
|
command.motor_command.kp[i] = self.params.rl_kp[0][i].item()
|
||||||
|
command.motor_command.kd[i] = self.params.rl_kd[0][i].item()
|
||||||
|
command.motor_command.tau[i] = 0
|
||||||
|
|
||||||
|
if self.control.control_state == STATE.STATE_POS_GETDOWN:
|
||||||
|
self.control.control_state = STATE.STATE_WAITING
|
||||||
|
self.getdown_percent = 0.0
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
self.now_state.motor_state.q[i] = state.motor_state.q[i]
|
||||||
|
self.running_state = STATE.STATE_POS_GETDOWN
|
||||||
|
print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETDOWN")
|
||||||
|
|
||||||
|
elif self.control.control_state == STATE.STATE_POS_GETUP:
|
||||||
|
self.control.control_state = STATE.STATE_WAITING
|
||||||
|
self.getup_percent = 0.0
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
self.now_state.motor_state.q[i] = state.motor_state.q[i]
|
||||||
|
self.running_state = STATE.STATE_POS_GETUP
|
||||||
|
print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETUP")
|
||||||
|
|
||||||
|
# get down (position control)
|
||||||
|
elif self.running_state == STATE.STATE_POS_GETDOWN:
|
||||||
|
if self.getdown_percent < 1.0:
|
||||||
|
self.getdown_percent += 1 / 500.0
|
||||||
|
self.getdown_percent = min(1.0, self.getdown_percent)
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
command.motor_command.q[i] = (1 - self.getdown_percent) * self.now_state.motor_state.q[i] + self.getdown_percent * self.start_state.motor_state.q[i]
|
||||||
|
command.motor_command.dq[i] = 0
|
||||||
|
command.motor_command.kp[i] = self.params.fixed_kp[0][i].item()
|
||||||
|
command.motor_command.kd[i] = self.params.fixed_kd[0][i].item()
|
||||||
|
command.motor_command.tau[i] = 0
|
||||||
|
print("\r" + LOGGER.INFO + f"Getting down {self.getdown_percent * 100.0:.1f}", end='', flush=True)
|
||||||
|
|
||||||
|
if self.getdown_percent == 1:
|
||||||
|
self.InitObservations()
|
||||||
|
self.InitOutputs()
|
||||||
|
self.InitControl()
|
||||||
|
self.running_state = STATE.STATE_WAITING
|
||||||
|
print("\r\n" + LOGGER.INFO + "Switching to STATE_WAITING")
|
||||||
|
|
||||||
|
def TorqueProtect(self, origin_output_torques):
|
||||||
|
out_of_range_indices = []
|
||||||
|
out_of_range_values = []
|
||||||
|
|
||||||
|
for i in range(origin_output_torques.size(1)):
|
||||||
|
torque_value = origin_output_torques[0][i].item()
|
||||||
|
limit_lower = -self.params.torque_limits[0][i].item()
|
||||||
|
limit_upper = self.params.torque_limits[0][i].item()
|
||||||
|
|
||||||
|
if torque_value < limit_lower or torque_value > limit_upper:
|
||||||
|
out_of_range_indices.append(i)
|
||||||
|
out_of_range_values.append(torque_value)
|
||||||
|
|
||||||
|
if out_of_range_indices:
|
||||||
|
for i, index in enumerate(out_of_range_indices):
|
||||||
|
value = out_of_range_values[i]
|
||||||
|
limit_lower = -self.params.torque_limits[0][index].item()
|
||||||
|
limit_upper = self.params.torque_limits[0][index].item()
|
||||||
|
|
||||||
|
print(LOGGER.WARNING + f"Torque({index + 1})={value} out of range({limit_lower}, {limit_upper})")
|
||||||
|
|
||||||
|
# Just a reminder, no protection
|
||||||
|
self.control.control_state = STATE.STATE_POS_GETDOWN
|
||||||
|
print(LOGGER.INFO + "Switching to STATE_POS_GETDOWN")
|
||||||
|
|
||||||
|
def KeyboardInterface(self, key):
|
||||||
|
try:
|
||||||
|
if hasattr(key, 'char'):
|
||||||
|
if key.char == '0':
|
||||||
|
self.control.control_state = STATE.STATE_POS_GETUP
|
||||||
|
elif key.char == 'p':
|
||||||
|
self.control.control_state = STATE.STATE_RL_INIT
|
||||||
|
elif key.char == '1':
|
||||||
|
self.control.control_state = STATE.STATE_POS_GETDOWN
|
||||||
|
elif key.char == 'w':
|
||||||
|
self.control.x += 0.1
|
||||||
|
elif key.char == 's':
|
||||||
|
self.control.x -= 0.1
|
||||||
|
elif key.char == 'a':
|
||||||
|
self.control.yaw += 0.1
|
||||||
|
elif key.char == 'd':
|
||||||
|
self.control.yaw -= 0.1
|
||||||
|
elif key.char == 'j':
|
||||||
|
self.control.y += 0.1
|
||||||
|
elif key.char == 'l':
|
||||||
|
self.control.y -= 0.1
|
||||||
|
elif key.char == 'r':
|
||||||
|
self.control.control_state = STATE.STATE_RESET_SIMULATION
|
||||||
|
else:
|
||||||
|
if key == keyboard.Key.enter:
|
||||||
|
self.control.control_state = STATE.STATE_TOGGLE_SIMULATION
|
||||||
|
elif key == keyboard.Key.space:
|
||||||
|
self.control.x = 0
|
||||||
|
self.control.y = 0
|
||||||
|
self.control.yaw = 0
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def ReadYaml(self, robot_name):
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r') as f:
|
||||||
|
config = yaml.safe_load(f)[robot_name]
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
print(LOGGER.ERROR + "The file '{CONFIG_PATH}' does not exist")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.params.model_name = config["model_name"]
|
||||||
|
self.params.dt = config["dt"]
|
||||||
|
self.params.decimation = config["decimation"]
|
||||||
|
self.params.num_observations = config["num_observations"]
|
||||||
|
self.params.clip_obs = config["clip_obs"]
|
||||||
|
self.params.action_scale = config["action_scale"]
|
||||||
|
self.params.hip_scale_reduction = config["hip_scale_reduction"]
|
||||||
|
self.params.hip_scale_reduction_indices = config["hip_scale_reduction_indices"]
|
||||||
|
self.params.clip_actions_upper = torch.tensor(config["clip_actions_upper"]).view(1, -1)
|
||||||
|
self.params.clip_actions_lower = torch.tensor(config["clip_actions_lower"]).view(1, -1)
|
||||||
|
self.params.num_of_dofs = config["num_of_dofs"]
|
||||||
|
self.params.lin_vel_scale = config["lin_vel_scale"]
|
||||||
|
self.params.ang_vel_scale = config["ang_vel_scale"]
|
||||||
|
self.params.dof_pos_scale = config["dof_pos_scale"]
|
||||||
|
self.params.dof_vel_scale = config["dof_vel_scale"]
|
||||||
|
self.params.commands_scale = torch.tensor([self.params.lin_vel_scale, self.params.lin_vel_scale, self.params.ang_vel_scale])
|
||||||
|
self.params.rl_kp = torch.tensor(config["rl_kp"]).view(1, -1)
|
||||||
|
self.params.rl_kd = torch.tensor(config["rl_kd"]).view(1, -1)
|
||||||
|
self.params.fixed_kp = torch.tensor(config["fixed_kp"]).view(1, -1)
|
||||||
|
self.params.fixed_kd = torch.tensor(config["fixed_kd"]).view(1, -1)
|
||||||
|
self.params.torque_limits = torch.tensor(config["torque_limits"]).view(1, -1)
|
||||||
|
self.params.default_dof_pos = torch.tensor(config["default_dof_pos"]).view(1, -1)
|
||||||
|
self.params.joint_controller_names = config["joint_controller_names"]
|
||||||
|
|
|
@ -0,0 +1,234 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import rospy
|
||||||
|
import numpy as np
|
||||||
|
from gazebo_msgs.msg import ModelStates
|
||||||
|
from sensor_msgs.msg import JointState
|
||||||
|
from geometry_msgs.msg import Twist, Pose
|
||||||
|
from robot_msgs.msg import MotorCommand
|
||||||
|
from gazebo_msgs.srv import SetModelState, SetModelStateRequest
|
||||||
|
from std_srvs.srv import Empty
|
||||||
|
|
||||||
|
path = os.path.abspath(".")
|
||||||
|
sys.path.insert(0, path + "/src/rl_sar/scripts")
|
||||||
|
from rl_sdk import *
|
||||||
|
from observation_buffer import *
|
||||||
|
|
||||||
|
class RL_Sim(RL):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# member variables for RL_Sim
|
||||||
|
self.vel = Twist()
|
||||||
|
self.pose = Pose()
|
||||||
|
self.cmd_vel = Twist()
|
||||||
|
|
||||||
|
# start ros node
|
||||||
|
rospy.init_node("rl_sim")
|
||||||
|
|
||||||
|
# read params from yaml
|
||||||
|
self.robot_name = rospy.get_param("robot_name", "")
|
||||||
|
self.ReadYaml(self.robot_name)
|
||||||
|
|
||||||
|
# history
|
||||||
|
self.use_history = rospy.get_param("use_history", "")
|
||||||
|
if self.use_history:
|
||||||
|
self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, 6)
|
||||||
|
|
||||||
|
# Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
|
||||||
|
# the mapping table is established according to the order defined in the YAML file
|
||||||
|
sorted_joint_controller_names = sorted(self.params.joint_controller_names)
|
||||||
|
self.sorted_to_original_index = {}
|
||||||
|
for i in range(len(self.params.joint_controller_names)):
|
||||||
|
self.sorted_to_original_index[sorted_joint_controller_names[i]] = i
|
||||||
|
self.mapped_joint_positions = [0.0] * self.params.num_of_dofs
|
||||||
|
self.mapped_joint_velocities = [0.0] * self.params.num_of_dofs
|
||||||
|
self.mapped_joint_efforts = [0.0] * self.params.num_of_dofs
|
||||||
|
|
||||||
|
# init
|
||||||
|
torch.set_grad_enabled(False)
|
||||||
|
self.joint_publishers_commands = [MotorCommand() for _ in range(self.params.num_of_dofs)]
|
||||||
|
self.InitObservations()
|
||||||
|
self.InitOutputs()
|
||||||
|
self.InitControl()
|
||||||
|
|
||||||
|
# model
|
||||||
|
model_path = os.path.join(os.path.dirname(__file__), f"../models/{self.robot_name}/{self.params.model_name}")
|
||||||
|
self.model = torch.jit.load(model_path)
|
||||||
|
|
||||||
|
# publisher
|
||||||
|
self.ros_namespace = rospy.get_param("ros_namespace", "")
|
||||||
|
self.joint_publishers = {}
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
topic_name = f"{self.ros_namespace}{self.params.joint_controller_names[i]}/command"
|
||||||
|
self.joint_publishers[self.params.joint_controller_names[i]] = rospy.Publisher(topic_name, MotorCommand, queue_size=10)
|
||||||
|
|
||||||
|
# subscriber
|
||||||
|
self.cmd_vel_subscriber = rospy.Subscriber("/cmd_vel", Twist, self.CmdvelCallback, queue_size=10)
|
||||||
|
self.model_state_subscriber = rospy.Subscriber("/gazebo/model_states", ModelStates, self.ModelStatesCallback, queue_size=10)
|
||||||
|
joint_states_topic = f"{self.ros_namespace}joint_states"
|
||||||
|
self.joint_state_subscriber = rospy.Subscriber(joint_states_topic, JointState, self.JointStatesCallback, queue_size=10)
|
||||||
|
|
||||||
|
# service
|
||||||
|
self.gazebo_set_model_state_client = rospy.ServiceProxy("/gazebo/set_model_state", SetModelState)
|
||||||
|
self.gazebo_pause_physics_client = rospy.ServiceProxy("/gazebo/pause_physics", Empty)
|
||||||
|
self.gazebo_unpause_physics_client = rospy.ServiceProxy("/gazebo/unpause_physics", Empty)
|
||||||
|
|
||||||
|
# loops
|
||||||
|
self.thread_control = threading.Thread(target=self.ThreadControl)
|
||||||
|
self.thread_rl = threading.Thread(target=self.ThreadRL)
|
||||||
|
self.thread_control.start()
|
||||||
|
self.thread_rl.start()
|
||||||
|
|
||||||
|
# keyboard
|
||||||
|
self.listener_keyboard = keyboard.Listener(on_press=self.KeyboardInterface)
|
||||||
|
self.listener_keyboard.start()
|
||||||
|
|
||||||
|
print(LOGGER.INFO + "RL_Sim start")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
print(LOGGER.INFO + "RL_Sim exit")
|
||||||
|
|
||||||
|
def GetState(self, state):
|
||||||
|
state.imu.quaternion[3] = self.pose.orientation.w
|
||||||
|
state.imu.quaternion[0] = self.pose.orientation.x
|
||||||
|
state.imu.quaternion[1] = self.pose.orientation.y
|
||||||
|
state.imu.quaternion[2] = self.pose.orientation.z
|
||||||
|
|
||||||
|
state.imu.gyroscope[0] = self.vel.angular.x
|
||||||
|
state.imu.gyroscope[1] = self.vel.angular.y
|
||||||
|
state.imu.gyroscope[2] = self.vel.angular.z
|
||||||
|
|
||||||
|
# state.imu.accelerometer
|
||||||
|
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
state.motor_state.q[i] = self.mapped_joint_positions[i]
|
||||||
|
state.motor_state.dq[i] = self.mapped_joint_velocities[i]
|
||||||
|
state.motor_state.tauEst[i] = self.mapped_joint_efforts[i]
|
||||||
|
|
||||||
|
def SetCommand(self, command):
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
self.joint_publishers_commands[i].q = command.motor_command.q[i]
|
||||||
|
self.joint_publishers_commands[i].dq = command.motor_command.dq[i]
|
||||||
|
self.joint_publishers_commands[i].kp = command.motor_command.kp[i]
|
||||||
|
self.joint_publishers_commands[i].kd = command.motor_command.kd[i]
|
||||||
|
self.joint_publishers_commands[i].tau = command.motor_command.tau[i]
|
||||||
|
|
||||||
|
for i in range(self.params.num_of_dofs):
|
||||||
|
self.joint_publishers[self.params.joint_controller_names[i]].publish(self.joint_publishers_commands[i])
|
||||||
|
|
||||||
|
def RobotControl(self):
|
||||||
|
if self.control.control_state == STATE.STATE_RESET_SIMULATION:
|
||||||
|
set_model_state = SetModelStateRequest().model_state
|
||||||
|
gazebo_model_name = f"{self.robot_name}_gazebo"
|
||||||
|
set_model_state.model_name = gazebo_model_name
|
||||||
|
set_model_state.pose.position.z = 1.0
|
||||||
|
set_model_state.reference_frame = "world"
|
||||||
|
self.gazebo_set_model_state_client(set_model_state)
|
||||||
|
self.control.control_state = STATE.STATE_WAITING
|
||||||
|
if self.control.control_state == STATE.STATE_TOGGLE_SIMULATION:
|
||||||
|
if self.simulation_running:
|
||||||
|
self.gazebo_pause_physics_client()
|
||||||
|
print("\r\n" + LOGGER.INFO + "Simulation Stop")
|
||||||
|
else:
|
||||||
|
self.gazebo_unpause_physics_client()
|
||||||
|
print("\r\n" + LOGGER.INFO + "Simulation Start")
|
||||||
|
self.simulation_running = not self.simulation_running
|
||||||
|
self.control.control_state = STATE.STATE_WAITING
|
||||||
|
|
||||||
|
if self.simulation_running:
|
||||||
|
self.GetState(self.robot_state)
|
||||||
|
self.StateController(self.robot_state, self.robot_command)
|
||||||
|
self.SetCommand(self.robot_command)
|
||||||
|
|
||||||
|
def ModelStatesCallback(self, msg):
|
||||||
|
self.vel = msg.twist[2]
|
||||||
|
self.pose = msg.pose[2]
|
||||||
|
|
||||||
|
def CmdvelCallback(self, msg):
|
||||||
|
self.cmd_vel = msg
|
||||||
|
|
||||||
|
def MapData(self, source_data, target_data):
|
||||||
|
for i in range(len(source_data)):
|
||||||
|
target_data[i] = source_data[self.sorted_to_original_index[self.params.joint_controller_names[i]]]
|
||||||
|
|
||||||
|
def JointStatesCallback(self, msg):
|
||||||
|
self.MapData(msg.position, self.mapped_joint_positions)
|
||||||
|
self.MapData(msg.velocity, self.mapped_joint_velocities)
|
||||||
|
self.MapData(msg.effort, self.mapped_joint_efforts)
|
||||||
|
|
||||||
|
def RunModel(self):
|
||||||
|
if self.running_state == STATE.STATE_RL_RUNNING and self.simulation_running:
|
||||||
|
# self.obs.lin_vel = torch.tensor([[self.vel.linear.x, self.vel.linear.y, self.vel.linear.z]])
|
||||||
|
self.obs.ang_vel = torch.tensor(self.robot_state.imu.gyroscope).unsqueeze(0)
|
||||||
|
# self.obs.commands = torch.tensor([[self.cmd_vel.linear.x, self.cmd_vel.linear.y, self.cmd_vel.angular.z]])
|
||||||
|
self.obs.commands = torch.tensor([[self.control.x, self.control.y, self.control.yaw]])
|
||||||
|
self.obs.base_quat = torch.tensor(self.robot_state.imu.quaternion).unsqueeze(0)
|
||||||
|
self.obs.dof_pos = torch.tensor(self.robot_state.motor_state.q).narrow(0, 0, self.params.num_of_dofs).unsqueeze(0)
|
||||||
|
self.obs.dof_vel = torch.tensor(self.robot_state.motor_state.dq).narrow(0, 0, self.params.num_of_dofs).unsqueeze(0)
|
||||||
|
|
||||||
|
clamped_actions = self.Forward()
|
||||||
|
|
||||||
|
for i in self.params.hip_scale_reduction_indices:
|
||||||
|
clamped_actions[0][i] *= self.params.hip_scale_reduction
|
||||||
|
|
||||||
|
self.obs.actions = clamped_actions
|
||||||
|
|
||||||
|
origin_output_torques = self.ComputeTorques(self.obs.actions)
|
||||||
|
|
||||||
|
# self.TorqueProtect(origin_output_torques)
|
||||||
|
|
||||||
|
self.output_torques = torch.clamp(origin_output_torques, -(self.params.torque_limits), self.params.torque_limits)
|
||||||
|
self.output_dof_pos = self.ComputePosition(self.obs.actions)
|
||||||
|
|
||||||
|
def ComputeObservation(self):
|
||||||
|
obs = torch.cat([
|
||||||
|
# self.obs.lin_vel * self.params.lin_vel_scale,
|
||||||
|
# self.obs.ang_vel * self.params.ang_vel_scale, # TODO is QuatRotateInverse necessery?
|
||||||
|
self.QuatRotateInverse(self.obs.base_quat, self.obs.ang_vel) * self.params.ang_vel_scale,
|
||||||
|
self.QuatRotateInverse(self.obs.base_quat, self.obs.gravity_vec),
|
||||||
|
self.obs.commands * self.params.commands_scale,
|
||||||
|
(self.obs.dof_pos - self.params.default_dof_pos) * self.params.dof_pos_scale,
|
||||||
|
self.obs.dof_vel * self.params.dof_vel_scale,
|
||||||
|
self.obs.actions
|
||||||
|
], dim = -1)
|
||||||
|
clamped_obs = torch.clamp(obs, -self.params.clip_obs, self.params.clip_obs)
|
||||||
|
return clamped_obs
|
||||||
|
|
||||||
|
def Forward(self):
|
||||||
|
torch.set_grad_enabled(False)
|
||||||
|
clamped_obs = self.ComputeObservation()
|
||||||
|
if self.use_history:
|
||||||
|
self.history_obs_buf.insert(clamped_obs)
|
||||||
|
history_obs = self.history_obs_buf.get_obs_vec(np.arange(6))
|
||||||
|
actions = self.model.forward(history_obs)
|
||||||
|
else:
|
||||||
|
actions = self.model.forward(clamped_obs)
|
||||||
|
clamped_actions = torch.clamp(actions, self.params.clip_actions_lower, self.params.clip_actions_upper)
|
||||||
|
return clamped_actions
|
||||||
|
|
||||||
|
def ThreadControl(self):
|
||||||
|
thread_period = self.params.dt
|
||||||
|
thread_name = "thread_control"
|
||||||
|
print(f"[Thread Start] named: {thread_name}, period: {thread_period * 1000:.0f}(ms), cpu unspecified")
|
||||||
|
while not rospy.is_shutdown():
|
||||||
|
self.RobotControl()
|
||||||
|
time.sleep(thread_period)
|
||||||
|
print("[Thread End] named: " + thread_name)
|
||||||
|
|
||||||
|
def ThreadRL(self):
|
||||||
|
thread_period = self.params.dt * self.params.decimation
|
||||||
|
thread_name = "thread_rl"
|
||||||
|
print(f"[Thread Start] named: {thread_name}, period: {thread_period * 1000:.0f}(ms), cpu unspecified")
|
||||||
|
while not rospy.is_shutdown():
|
||||||
|
self.RunModel()
|
||||||
|
time.sleep(thread_period)
|
||||||
|
print("[Thread End] named: " + thread_name)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
rl_sim = RL_Sim()
|
||||||
|
rospy.spin()
|
||||||
|
|
|
@ -61,13 +61,15 @@ RL_Sim::RL_Sim()
|
||||||
this->gazebo_unpause_physics_client = nh.serviceClient<std_srvs::Empty>("/gazebo/unpause_physics");
|
this->gazebo_unpause_physics_client = nh.serviceClient<std_srvs::Empty>("/gazebo/unpause_physics");
|
||||||
|
|
||||||
// loop
|
// loop
|
||||||
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this));
|
|
||||||
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this));
|
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this));
|
||||||
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this));
|
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this));
|
||||||
this->loop_keyboard->start();
|
|
||||||
this->loop_control->start();
|
this->loop_control->start();
|
||||||
this->loop_rl->start();
|
this->loop_rl->start();
|
||||||
|
|
||||||
|
// keyboard
|
||||||
|
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this));
|
||||||
|
this->loop_keyboard->start();
|
||||||
|
|
||||||
#ifdef PLOT
|
#ifdef PLOT
|
||||||
this->plot_t = std::vector<int>(this->plot_size, 0);
|
this->plot_t = std::vector<int>(this->plot_size, 0);
|
||||||
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
|
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
|
||||||
|
@ -80,6 +82,8 @@ RL_Sim::RL_Sim()
|
||||||
#ifdef CSV_LOGGER
|
#ifdef CSV_LOGGER
|
||||||
this->CSVInit(this->robot_name);
|
this->CSVInit(this->robot_name);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
std::cout << LOGGER::INFO << "RL_Sim start" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
RL_Sim::~RL_Sim()
|
RL_Sim::~RL_Sim()
|
||||||
|
@ -150,10 +154,12 @@ void RL_Sim::RobotControl()
|
||||||
if(simulation_running)
|
if(simulation_running)
|
||||||
{
|
{
|
||||||
this->gazebo_pause_physics_client.call(empty);
|
this->gazebo_pause_physics_client.call(empty);
|
||||||
|
std::cout << std::endl << LOGGER::INFO << "Simulation Stop" << std::endl;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
this->gazebo_unpause_physics_client.call(empty);
|
this->gazebo_unpause_physics_client.call(empty);
|
||||||
|
std::cout << std::endl << LOGGER::INFO << "Simulation Start" << std::endl;
|
||||||
}
|
}
|
||||||
simulation_running = !simulation_running;
|
simulation_running = !simulation_running;
|
||||||
this->control.control_state = STATE_WAITING;
|
this->control.control_state = STATE_WAITING;
|
||||||
|
@ -230,15 +236,16 @@ void RL_Sim::RunModel()
|
||||||
|
|
||||||
torch::Tensor RL_Sim::ComputeObservation()
|
torch::Tensor RL_Sim::ComputeObservation()
|
||||||
{
|
{
|
||||||
torch::Tensor obs = torch::cat({// this->obs.lin_vel * this->params.lin_vel_scale,
|
torch::Tensor obs = torch::cat({
|
||||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
|
// this->obs.lin_vel * this->params.lin_vel_scale,
|
||||||
// this->obs.ang_vel * this->params.ang_vel_scale, // TODO
|
// this->obs.ang_vel * this->params.ang_vel_scale, // TODO is QuatRotateInverse necessery?
|
||||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
|
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
|
||||||
this->obs.commands * this->params.commands_scale,
|
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
this->obs.commands * this->params.commands_scale,
|
||||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||||
this->obs.actions
|
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||||
},1);
|
this->obs.actions
|
||||||
|
},1);
|
||||||
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
||||||
return clamped_obs;
|
return clamped_obs;
|
||||||
}
|
}
|
||||||
|
@ -246,11 +253,8 @@ torch::Tensor RL_Sim::ComputeObservation()
|
||||||
torch::Tensor RL_Sim::Forward()
|
torch::Tensor RL_Sim::Forward()
|
||||||
{
|
{
|
||||||
torch::autograd::GradMode::set_enabled(false);
|
torch::autograd::GradMode::set_enabled(false);
|
||||||
|
|
||||||
torch::Tensor clamped_obs = this->ComputeObservation();
|
torch::Tensor clamped_obs = this->ComputeObservation();
|
||||||
|
|
||||||
torch::Tensor actions;
|
torch::Tensor actions;
|
||||||
|
|
||||||
if(this->use_history)
|
if(this->use_history)
|
||||||
{
|
{
|
||||||
this->history_obs_buf.insert(clamped_obs);
|
this->history_obs_buf.insert(clamped_obs);
|
||||||
|
@ -263,7 +267,6 @@ torch::Tensor RL_Sim::Forward()
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
|
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
|
||||||
|
|
||||||
return clamped_actions;
|
return clamped_actions;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -297,12 +300,8 @@ void signalHandler(int signum)
|
||||||
int main(int argc, char **argv)
|
int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
signal(SIGINT, signalHandler);
|
signal(SIGINT, signalHandler);
|
||||||
|
|
||||||
ros::init(argc, argv, "rl_sar");
|
ros::init(argc, argv, "rl_sar");
|
||||||
|
|
||||||
RL_Sim rl_sar;
|
RL_Sim rl_sar;
|
||||||
|
|
||||||
ros::spin();
|
ros::spin();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue