[adding] deploy code and instructions
* latest environments will follow
This commit is contained in:
parent
c0d55bbb56
commit
f771497315
|
@ -0,0 +1,83 @@
|
|||
# Deploy the model on your real Unitree robot
|
||||
|
||||
This version shows an example of how to deploy the model on the Unittree Go1 robot (with Nvidia Jetson NX).
|
||||
|
||||
## Install dependencies on Go1
|
||||
To deploy the trained model on Go1, please set up a folder on your robot, e.g. `parkour`, and copy the `rsl_rl` folder to it. Then, extract the zip files in `go1_ckpts` to the `parkour` folder. Finally, copy all the files in `onboard_script` to the `parkour` folder.
|
||||
|
||||
1. Install ROS and the [unitree ros package for Go1](https://github.com/Tsinghua-MARS-Lab/unitree_ros_real.git) and follow the instructions to set up the robot on branch `go1`
|
||||
|
||||
Assuming the ros workspace is located in `parkour/unitree_ws`
|
||||
|
||||
2. Install pytorch on a Python 3 environment. (take the Python3 virtual environment on Nvidia Jetson NX as an example)
|
||||
```bash
|
||||
sudo apt-get install python3-pip python3-dev python3-venv
|
||||
python3 -m venv parkour_venv
|
||||
source parkour_venv/bin/activate
|
||||
```
|
||||
Download the pip wheel file from [here](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048) with v1.10.0. Then install it with
|
||||
```bash
|
||||
pip install torch-1.10.0-cp36-cp36m-linux_aarch64.whl
|
||||
```
|
||||
|
||||
3. Install other dependencies and `rsl_rl`
|
||||
```bash
|
||||
pip install numpy==1.16.4 numpy-ros
|
||||
pip install -e ./rsl_rl
|
||||
```
|
||||
|
||||
4. 3D print the camera mount for Go1 using the step file in `go1_ckpts/go1_camMount_30Down.step`. Use the two pairs of screw holes to mount the Intel Realsense D435i camera on the robot, as shown in the picture below.
|
||||
|
||||
<p align="center">
|
||||
<img src="images/go1_camMount_30Down.png" width="50%"/>
|
||||
</p>
|
||||
|
||||
## Run the model on Go1
|
||||
|
||||
***Disclaimer:*** *Always put a safety belt on the robot when the robot moves. The robot may fall down and cause damage to itself or the environment.*
|
||||
|
||||
1. Put the robot on the ground, power on the robot, and turn the robot into developer mode. Make sure your Intel Realsense D435i camera is connected to the robot and the camera is installed on the mount.
|
||||
|
||||
2. Launch 3 terminals onboard (whether 3 ssh connections from your computer or something else), named T_ros, T_visual, T_gru.
|
||||
|
||||
|
||||
3. In T_ros, run
|
||||
```bash
|
||||
cd parkour/unitree_ws
|
||||
source devel/setup.bash
|
||||
roslaunch unitree_legged_real robot.launch
|
||||
```
|
||||
This will launch the ros node for Go1. Please note that without `dryrun:=False` the robot will not move but do anything else. Set `dryrun:=False` only when you are ready to let the robot move.
|
||||
|
||||
4. In T_visual, run
|
||||
```bash
|
||||
cd parkour
|
||||
source unitree_ws/devel/setup.bash
|
||||
source parkour_venv/bin/activate
|
||||
python go1_visual_embedding.py --logdir Nov02...
|
||||
```
|
||||
where `Nov02...` is the logdir of the trained model.
|
||||
|
||||
Wait for the script to finish loading the model and get access to the Realsense sensor pipeline. Then, you can see the script prompting you: `"Realsense frame received. Sending embeddings..."`
|
||||
|
||||
Adding the `--enable_vis` option will enable the depth image message as a rostopic. You can visualize the depth image in rviz.
|
||||
|
||||
5. In T_gru, run
|
||||
```bash
|
||||
cd parkour
|
||||
source unitree_ws/devel/setup.bash
|
||||
source parkour_venv/bin/activate
|
||||
python a1_ros_run.py --mode upboard --logdir Nov02...
|
||||
```
|
||||
where `Nov02...` is the logdir of the trained model.
|
||||
|
||||
If the ros node is launched with `dryrun:=False`, the robot will start standing up. Otherwise, add `--debug` option on `a1_ros_run.py` to see the model output.
|
||||
|
||||
If the ros node is launched with `dryrun:=False`, the script will prompt you: `"Robot stood up! press R1 on the remote control to continue ..."` Press R1 on the remote control to start the standing policy
|
||||
|
||||
Pushing R-Y forward on the remote control to trigger the parkour policy. The robot will start running and jumping. Release the joystick to stop the robot.
|
||||
|
||||
Press R2 on the remote control to shutdown the gru script and the ros node. You can also use it in case of emergency.
|
||||
|
||||
Using the `--walkdir` option to load the walking policy e.g. (Oct24_16...), the standing policy will be replaced by the walking policy. Then you can use L-Y, L-X to control the x/y velocity of the robot and use R-X to control the yaw velocity of the robot.
|
||||
|
|
@ -27,7 +27,8 @@ Conference on Robot Learning (CoRL) 2023, Oral <br>
|
|||
To install and run the code for training A1 in simulation, please clone this repository and follow the instructions in [legged_gym/README.md](legged_gym/README.md).
|
||||
|
||||
## Hardware Deployment ##
|
||||
TODO
|
||||
To deploy the trained model on your real robot, please follow the instructions in [Deploy.md](Deploy.md).
|
||||
|
||||
|
||||
## Trouble Shooting ##
|
||||
If you cannot run the distillation part or all graphics computing goes to GPU 0 dispite you have multiple GPUs and have set the CUDA_VISIBLE_DEVICES, please use docker to isolate each GPU.
|
||||
|
@ -35,7 +36,7 @@ If you cannot run the distillation part or all graphics computing goes to GPU 0
|
|||
## To Do (will be done before Nov 2023) ##
|
||||
- [ ] Go1 training pipeline in simulation
|
||||
- [ ] A1 deployment code
|
||||
- [ ] Go1 deployment code
|
||||
- [x] Go1 deployment code
|
||||
|
||||
## Citation ##
|
||||
If you find this project helpful to your research, please consider cite us! This is really important to us.
|
||||
|
|
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
After Width: | Height: | Size: 961 KiB |
|
@ -0,0 +1,489 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
import json
|
||||
import os
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple
|
||||
|
||||
import rospy
|
||||
from unitree_legged_msgs.msg import LowState
|
||||
from unitree_legged_msgs.msg import LegsCmd
|
||||
from unitree_legged_msgs.msg import Float32MultiArrayStamped
|
||||
from std_msgs.msg import Float32MultiArray
|
||||
from geometry_msgs.msg import Twist, Pose
|
||||
from nav_msgs.msg import Odometry
|
||||
from sensor_msgs.msg import Image
|
||||
|
||||
import ros_numpy
|
||||
|
||||
@torch.no_grad()
|
||||
def resize2d(img, size):
|
||||
return (F.adaptive_avg_pool2d(Variable(img), size)).data
|
||||
|
||||
@torch.jit.script
|
||||
def quat_rotate_inverse(q, v):
|
||||
shape = q.shape
|
||||
q_w = q[:, -1]
|
||||
q_vec = q[:, :3]
|
||||
a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
|
||||
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
|
||||
c = q_vec * \
|
||||
torch.bmm(q_vec.view(shape[0], 1, 3), v.view(
|
||||
shape[0], 3, 1)).squeeze(-1) * 2.0
|
||||
return a - b + c
|
||||
|
||||
class UnitreeA1Real:
|
||||
""" This is the handler that works for ROS 1 on unitree. """
|
||||
def __init__(self,
|
||||
robot_namespace= "a112138",
|
||||
low_state_topic= "/low_state",
|
||||
legs_cmd_topic= "/legs_cmd",
|
||||
forward_depth_topic = "/camera/depth/image_rect_raw",
|
||||
forward_depth_embedding_dims = None,
|
||||
odom_topic= "/odom/filtered",
|
||||
lin_vel_deadband= 0.1,
|
||||
ang_vel_deadband= 0.1,
|
||||
move_by_wireless_remote= False, # if True, command will not listen to move_cmd_subscriber, but wireless remote.
|
||||
cfg= dict(),
|
||||
extra_cfg= dict(),
|
||||
model_device= torch.device("cpu"),
|
||||
):
|
||||
"""
|
||||
NOTE:
|
||||
* Must call start_ros() before using this class's get_obs() and send_action()
|
||||
* Joint order of simulation and of real A1 protocol are different, see dof_names
|
||||
* We store all joints values in the order of simulation in this class
|
||||
Args:
|
||||
forward_depth_embedding_dims: If a real number, the obs will not be built as a normal env.
|
||||
The segment of obs will be subsituted by the embedding of forward depth image from the
|
||||
ROS topic.
|
||||
cfg: same config from a1_config but a dict object.
|
||||
extra_cfg: some other configs that is hard to load from file.
|
||||
"""
|
||||
self.model_device = model_device
|
||||
self.num_envs = 1
|
||||
self.robot_namespace = robot_namespace
|
||||
self.low_state_topic = low_state_topic
|
||||
self.legs_cmd_topic = legs_cmd_topic
|
||||
self.forward_depth_topic = forward_depth_topic
|
||||
self.forward_depth_embedding_dims = forward_depth_embedding_dims
|
||||
self.odom_topic = odom_topic
|
||||
self.lin_vel_deadband = lin_vel_deadband
|
||||
self.ang_vel_deadband = ang_vel_deadband
|
||||
self.move_by_wireless_remote = move_by_wireless_remote
|
||||
self.cfg = cfg
|
||||
self.extra_cfg = dict(
|
||||
torque_limits= torch.tensor([33.5] * 12, dtype= torch.float32, device= self.model_device, requires_grad= False), # Nm
|
||||
# torque_limits= torch.tensor([1, 5, 5] * 4, dtype= torch.float32, device= self.model_device, requires_grad= False), # Nm
|
||||
dof_map= [ # from isaacgym simulation joint order to URDF order
|
||||
3, 4, 5,
|
||||
0, 1, 2,
|
||||
9, 10,11,
|
||||
6, 7, 8,
|
||||
], # real_joint_idx = dof_map[sim_joint_idx]
|
||||
dof_names= [ # NOTE: order matters. This list is the order in simulation.
|
||||
"FL_hip_joint",
|
||||
"FL_thigh_joint",
|
||||
"FL_calf_joint",
|
||||
|
||||
"FR_hip_joint",
|
||||
"FR_thigh_joint",
|
||||
"FR_calf_joint",
|
||||
|
||||
"RL_hip_joint",
|
||||
"RL_thigh_joint",
|
||||
"RL_calf_joint",
|
||||
|
||||
"RR_hip_joint",
|
||||
"RR_thigh_joint",
|
||||
"RR_calf_joint",
|
||||
],
|
||||
# motor strength is multiplied directly to the action.
|
||||
motor_strength= torch.ones(12, dtype= torch.float32, device= self.model_device, requires_grad= False),
|
||||
); self.extra_cfg.update(extra_cfg)
|
||||
if "torque_limits" in self.cfg["control"]:
|
||||
if isinstance(self.cfg["control"]["torque_limits"], (tuple, list)):
|
||||
for i in range(len(self.cfg["control"]["torque_limits"])):
|
||||
self.extra_cfg["torque_limits"][i] = self.cfg["control"]["torque_limits"][i]
|
||||
else:
|
||||
self.extra_cfg["torque_limits"][:] = self.cfg["control"]["torque_limits"]
|
||||
self.command_buf = torch.zeros((self.num_envs, 3,), device= self.model_device, dtype= torch.float32) # zeros for initialization
|
||||
self.actions = torch.zeros((1, 12), device= model_device, dtype= torch.float32)
|
||||
|
||||
self.process_configs()
|
||||
|
||||
def start_ros(self):
|
||||
# initialze several buffers so that the system works even without message update.
|
||||
# self.low_state_buffer = LowState() # not initialized, let input message update it.
|
||||
self.base_position_buffer = torch.zeros((self.num_envs, 3), device= self.model_device, requires_grad= False)
|
||||
self.legs_cmd_publisher = rospy.Publisher(
|
||||
self.robot_namespace + self.legs_cmd_topic,
|
||||
LegsCmd,
|
||||
queue_size= 1,
|
||||
)
|
||||
# self.debug_publisher = rospy.Publisher(
|
||||
# "/DNNmodel_debug",
|
||||
# Float32MultiArray,
|
||||
# queue_size= 1,
|
||||
# )
|
||||
# NOTE: this launches the subscriber callback function
|
||||
self.low_state_subscriber = rospy.Subscriber(
|
||||
self.robot_namespace + self.low_state_topic,
|
||||
LowState,
|
||||
self.update_low_state,
|
||||
queue_size= 1,
|
||||
)
|
||||
self.odom_subscriber = rospy.Subscriber(
|
||||
self.robot_namespace + self.odom_topic,
|
||||
Odometry,
|
||||
self.update_base_pose,
|
||||
queue_size= 1,
|
||||
)
|
||||
if not self.move_by_wireless_remote:
|
||||
self.move_cmd_subscriber = rospy.Subscriber(
|
||||
"/cmd_vel",
|
||||
Twist,
|
||||
self.update_move_cmd,
|
||||
queue_size= 1,
|
||||
)
|
||||
if "forward_depth" in self.all_obs_components:
|
||||
if not self.forward_depth_embedding_dims:
|
||||
self.forward_depth_subscriber = rospy.Subscriber(
|
||||
self.robot_namespace + self.forward_depth_topic,
|
||||
Image,
|
||||
self.update_forward_depth,
|
||||
queue_size= 1,
|
||||
)
|
||||
else:
|
||||
self.forward_depth_subscriber = rospy.Subscriber(
|
||||
self.robot_namespace + self.forward_depth_topic,
|
||||
Float32MultiArrayStamped,
|
||||
self.update_forward_depth_embedding,
|
||||
queue_size= 1,
|
||||
)
|
||||
self.pose_cmd_subscriber = rospy.Subscriber(
|
||||
"/body_pose",
|
||||
Pose,
|
||||
self.dummy_handler,
|
||||
queue_size= 1,
|
||||
)
|
||||
|
||||
def wait_untill_ros_working(self):
|
||||
rate = rospy.Rate(100)
|
||||
while not hasattr(self, "low_state_buffer"):
|
||||
rate.sleep()
|
||||
rospy.loginfo("UnitreeA1Real.low_state_buffer acquired, stop waiting.")
|
||||
|
||||
def process_configs(self):
|
||||
self.up_axis_idx = 2 # 2 for z, 1 for y -> adapt gravity accordingly
|
||||
self.gravity_vec = torch.zeros((self.num_envs, 3), dtype= torch.float32)
|
||||
self.gravity_vec[:, self.up_axis_idx] = -1
|
||||
|
||||
self.obs_scales = self.cfg["normalization"]["obs_scales"]
|
||||
self.obs_scales["dof_pos"] = torch.tensor(self.obs_scales["dof_pos"], device= self.model_device, dtype= torch.float32)
|
||||
if not isinstance(self.cfg["control"]["damping"]["joint"], (list, tuple)):
|
||||
self.cfg["control"]["damping"]["joint"] = [self.cfg["control"]["damping"]["joint"]] * 12
|
||||
if not isinstance(self.cfg["control"]["stiffness"]["joint"], (list, tuple)):
|
||||
self.cfg["control"]["stiffness"]["joint"] = [self.cfg["control"]["stiffness"]["joint"]] * 12
|
||||
self.d_gains = torch.tensor(self.cfg["control"]["damping"]["joint"], device= self.model_device, dtype= torch.float32)
|
||||
self.p_gains = torch.tensor(self.cfg["control"]["stiffness"]["joint"], device= self.model_device, dtype= torch.float32)
|
||||
self.default_dof_pos = torch.zeros(12, device= self.model_device, dtype= torch.float32)
|
||||
for i in range(12):
|
||||
name = self.extra_cfg["dof_names"][i]
|
||||
default_joint_angle = self.cfg["init_state"]["default_joint_angles"][name]
|
||||
self.default_dof_pos[i] = default_joint_angle
|
||||
|
||||
self.computer_clip_torque = self.cfg["control"].get("computer_clip_torque", True)
|
||||
rospy.loginfo("Computer Clip Torque (onboard) is " + str(self.computer_clip_torque))
|
||||
if self.computer_clip_torque:
|
||||
self.torque_limits = self.extra_cfg["torque_limits"]
|
||||
rospy.loginfo("[Env] torque limit: {:.1f} {:.1f} {:.1f}".format(*self.torque_limits[:3]))
|
||||
|
||||
self.commands_scale = torch.tensor([
|
||||
self.obs_scales["lin_vel"],
|
||||
self.obs_scales["lin_vel"],
|
||||
self.obs_scales["lin_vel"],
|
||||
], device= self.model_device, requires_grad= False)
|
||||
|
||||
self.obs_segments = self.get_obs_segment_from_components(self.cfg["env"]["obs_components"])
|
||||
self.num_obs = self.get_num_obs_from_components(self.cfg["env"]["obs_components"])
|
||||
components = self.cfg["env"].get("privileged_obs_components", None)
|
||||
self.privileged_obs_segments = None if components is None else self.get_num_obs_from_components(components)
|
||||
self.num_privileged_obs = None if components is None else self.get_num_obs_from_components(components)
|
||||
self.all_obs_components = self.cfg["env"]["obs_components"] + (self.cfg["env"].get("privileged_obs_components", []) if components is not None else [])
|
||||
|
||||
# store config values to attributes to improve speed
|
||||
self.clip_obs = self.cfg["normalization"]["clip_observations"]
|
||||
self.control_type = self.cfg["control"]["control_type"]
|
||||
self.action_scale = self.cfg["control"]["action_scale"]
|
||||
rospy.loginfo("[Env] action scale: {:.1f}".format(self.action_scale))
|
||||
self.clip_actions = self.cfg["normalization"]["clip_actions"]
|
||||
if self.cfg["normalization"].get("clip_actions_method", None) == "hard":
|
||||
rospy.loginfo("clip_actions_method with hard mode")
|
||||
rospy.loginfo("clip_actions_high: " + str(self.cfg["normalization"]["clip_actions_high"]))
|
||||
rospy.loginfo("clip_actions_low: " + str(self.cfg["normalization"]["clip_actions_low"]))
|
||||
self.clip_actions_method = "hard"
|
||||
self.clip_actions_low = torch.tensor(self.cfg["normalization"]["clip_actions_low"], device= self.model_device, dtype= torch.float32)
|
||||
self.clip_actions_high = torch.tensor(self.cfg["normalization"]["clip_actions_high"], device= self.model_device, dtype= torch.float32)
|
||||
else:
|
||||
rospy.loginfo("clip_actions_method is " + str(self.cfg["normalization"].get("clip_actions_method", None)))
|
||||
self.dof_map = self.extra_cfg["dof_map"]
|
||||
|
||||
# get ROS params for hardware configs
|
||||
self.joint_limits_high = torch.tensor([
|
||||
rospy.get_param(self.robot_namespace + "/joint_limits/{}_max".format(s)) \
|
||||
for s in ["hip", "thigh", "calf"] * 4
|
||||
])
|
||||
self.joint_limits_low = torch.tensor([
|
||||
rospy.get_param(self.robot_namespace + "/joint_limits/{}_min".format(s)) \
|
||||
for s in ["hip", "thigh", "calf"] * 4
|
||||
])
|
||||
|
||||
if "forward_depth" in self.all_obs_components:
|
||||
resolution = self.cfg["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
self.cfg["sensor"]["forward_camera"]["resolution"],
|
||||
)
|
||||
if not self.forward_depth_embedding_dims:
|
||||
self.forward_depth_buf = torch.zeros(
|
||||
(self.num_envs, *resolution),
|
||||
device= self.model_device,
|
||||
dtype= torch.float32,
|
||||
)
|
||||
else:
|
||||
self.forward_depth_embedding_buf = torch.zeros(
|
||||
(1, self.forward_depth_embedding_dims),
|
||||
device= self.model_device,
|
||||
dtype= torch.float32,
|
||||
)
|
||||
|
||||
def _init_height_points(self):
|
||||
""" Returns points at which the height measurments are sampled (in base frame)
|
||||
|
||||
Returns:
|
||||
[torch.Tensor]: Tensor of shape (num_envs, self.num_height_points, 3)
|
||||
"""
|
||||
return None
|
||||
|
||||
def _get_heights(self):
|
||||
""" TODO: get estimated terrain heights around the robot base """
|
||||
# currently return a zero tensor with valid size
|
||||
return torch.zeros(self.num_envs, 187, device= self.model_device, requires_grad= False)
|
||||
|
||||
def clip_action_before_scale(self, actions):
|
||||
actions = torch.clip(actions, -self.clip_actions, self.clip_actions)
|
||||
if getattr(self, "clip_actions_method", None) == "hard":
|
||||
actions = torch.clip(actions, self.clip_actions_low, self.clip_actions_high)
|
||||
|
||||
return actions
|
||||
|
||||
def clip_by_torque_limit(self, actions_scaled):
|
||||
""" Different from simulation, we reverse the process and clip the actions directly,
|
||||
so that the PD controller runs in robot but not our script.
|
||||
"""
|
||||
control_type = self.cfg["control"]["control_type"]
|
||||
if control_type == "P":
|
||||
p_limits_low = (-self.torque_limits) + self.d_gains*self.dof_vel
|
||||
p_limits_high = (self.torque_limits) + self.d_gains*self.dof_vel
|
||||
actions_low = (p_limits_low/self.p_gains) - self.default_dof_pos + self.dof_pos
|
||||
actions_high = (p_limits_high/self.p_gains) - self.default_dof_pos + self.dof_pos
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return torch.clip(actions_scaled, actions_low, actions_high)
|
||||
|
||||
""" Get obs components and cat to a single obs input """
|
||||
def _get_proprioception_obs(self):
|
||||
# base_ang_vel = quat_rotate_inverse(
|
||||
# torch.tensor(self.low_state_buffer.imu.quaternion).unsqueeze(0),
|
||||
# torch.tensor(self.low_state_buffer.imu.gyroscope).unsqueeze(0),
|
||||
# ).to(self.model_device)
|
||||
# NOTE: Different from the isaacgym.
|
||||
# The anglar velocity is already in base frame, no need to rotate
|
||||
base_ang_vel = torch.tensor(self.low_state_buffer.imu.gyroscope, device= self.model_device).unsqueeze(0)
|
||||
projected_gravity = quat_rotate_inverse(
|
||||
torch.tensor(self.low_state_buffer.imu.quaternion).unsqueeze(0),
|
||||
self.gravity_vec,
|
||||
).to(self.model_device)
|
||||
self.dof_pos = dof_pos = torch.tensor([
|
||||
self.low_state_buffer.motorState[self.dof_map[i]].q for i in range(12)
|
||||
], dtype= torch.float32, device= self.model_device).unsqueeze(0)
|
||||
self.dof_vel = dof_vel = torch.tensor([
|
||||
self.low_state_buffer.motorState[self.dof_map[i]].dq for i in range(12)
|
||||
], dtype= torch.float32, device= self.model_device).unsqueeze(0)
|
||||
|
||||
# rospy.loginfo_throttle(5,
|
||||
# "projected_gravity: " + \
|
||||
# ", ".join([str(i) for i in projected_gravity[0].cpu().tolist()])
|
||||
# )
|
||||
# rospy.loginfo_throttle(5, "Hacking projected_gravity y -= 0.15")
|
||||
# projected_gravity[:, 1] -= 0.15
|
||||
|
||||
return torch.cat([
|
||||
torch.zeros((1, 3), device= self.model_device), # no linear velocity
|
||||
base_ang_vel * self.obs_scales["ang_vel"],
|
||||
projected_gravity,
|
||||
self.command_buf * self.commands_scale,
|
||||
(dof_pos - self.default_dof_pos) * self.obs_scales["dof_pos"],
|
||||
dof_vel * self.obs_scales["dof_vel"],
|
||||
self.actions
|
||||
], dim= -1)
|
||||
|
||||
def _get_forward_depth_obs(self):
|
||||
if not self.forward_depth_embedding_dims:
|
||||
return self.forward_depth_buf.flatten(start_dim= 1)
|
||||
else:
|
||||
if self.low_state_get_time.to_sec() - self.forward_depth_embedding_stamp.to_sec() > 0.4:
|
||||
rospy.logerr("Getting depth embedding later than low_state later than 0.4s")
|
||||
return self.forward_depth_embedding_buf.flatten(start_dim= 1)
|
||||
|
||||
def compute_observation(self):
|
||||
""" use the updated low_state_buffer to compute observation vector """
|
||||
assert hasattr(self, "legs_cmd_publisher"), "start_ros() not called, ROS handlers are not initialized!"
|
||||
obs_segments = self.obs_segments
|
||||
obs = []
|
||||
for k, v in obs_segments.items():
|
||||
obs.append(
|
||||
getattr(self, "_get_" + k + "_obs")() * \
|
||||
self.obs_scales.get(k, 1.)
|
||||
)
|
||||
obs = torch.cat(obs, dim= 1)
|
||||
self.obs_buf = obs
|
||||
|
||||
""" The methods combined with outer model forms the step function
|
||||
NOTE: the outer user handles the loop frequency.
|
||||
"""
|
||||
def send_action(self, actions):
|
||||
""" The function that send commands to the real robot.
|
||||
"""
|
||||
self.actions = self.clip_action_before_scale(actions)
|
||||
if self.computer_clip_torque:
|
||||
robot_coordinates_action = self.clip_by_torque_limit(actions * self.action_scale) + self.default_dof_pos.unsqueeze(0)
|
||||
else:
|
||||
rospy.logwarn_throttle(60, "You are using control without any torque clip. The network might output torques larger than the system can provide.")
|
||||
robot_coordinates_action = self.actions * self.action_scale + self.default_dof_pos.unsqueeze(0)
|
||||
|
||||
# debugging and logging
|
||||
# transfered_action = torch.zeros_like(self.actions[0])
|
||||
# for i in range(12):
|
||||
# transfered_action[self.dof_map[i]] = self.actions[0, i] + self.default_dof_pos[i]
|
||||
# self.debug_publisher.publish(Float32MultiArray(data=
|
||||
# transfered_action\
|
||||
# .cpu().numpy().astype(np.float32).tolist()
|
||||
# ))
|
||||
|
||||
# restrict the target action delta in order to avoid robot shutdown (maybe there is another solution)
|
||||
# robot_coordinates_action = torch.clip(
|
||||
# robot_coordinates_action,
|
||||
# self.dof_pos - 0.3,
|
||||
# self.dof_pos + 0.3,
|
||||
# )
|
||||
|
||||
# wrap the message and publish
|
||||
self.publish_legs_cmd(robot_coordinates_action)
|
||||
|
||||
def publish_legs_cmd(self, robot_coordinates_action, kp= None, kd= None):
|
||||
""" publish the joint position directly to the robot. NOTE: The joint order from input should
|
||||
be in simulation order. The value should be absolute value rather than related to dof_pos.
|
||||
"""
|
||||
robot_coordinates_action = torch.clip(
|
||||
robot_coordinates_action.cpu(),
|
||||
self.joint_limits_low,
|
||||
self.joint_limits_high,
|
||||
)
|
||||
legs_cmd = LegsCmd()
|
||||
for sim_joint_idx in range(12):
|
||||
real_joint_idx = self.dof_map[sim_joint_idx]
|
||||
legs_cmd.cmd[real_joint_idx].mode = 10
|
||||
legs_cmd.cmd[real_joint_idx].q = robot_coordinates_action[0, sim_joint_idx] if self.control_type == "P" else rospy.get_param(self.robot_namespace + "/PosStopF", (2.146e+9))
|
||||
legs_cmd.cmd[real_joint_idx].dq = 0.
|
||||
legs_cmd.cmd[real_joint_idx].tau = 0.
|
||||
legs_cmd.cmd[real_joint_idx].Kp = self.p_gains[sim_joint_idx] if kp is None else kp
|
||||
legs_cmd.cmd[real_joint_idx].Kd = self.d_gains[sim_joint_idx] if kd is None else kd
|
||||
self.legs_cmd_publisher.publish(legs_cmd)
|
||||
|
||||
def get_obs(self):
|
||||
""" The function that refreshes the buffer and return the observation vector.
|
||||
"""
|
||||
self.compute_observation()
|
||||
self.obs_buf = torch.clip(self.obs_buf, -self.clip_obs, self.clip_obs)
|
||||
return self.obs_buf.to(self.model_device)
|
||||
|
||||
""" Copied from legged_robot_field. Please check whether these are consistent. """
|
||||
def get_obs_segment_from_components(self, components):
|
||||
segments = OrderedDict()
|
||||
if "proprioception" in components:
|
||||
segments["proprioception"] = (48,)
|
||||
if "height_measurements" in components:
|
||||
segments["height_measurements"] = (187,)
|
||||
if "forward_depth" in components:
|
||||
resolution = self.cfg["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
self.cfg["sensor"]["forward_camera"]["resolution"],
|
||||
)
|
||||
segments["forward_depth"] = (1, *resolution)
|
||||
# The following components are only for rebuilding the non-actor module.
|
||||
# DO NOT use these in actor network and check consistency with simulator implementation.
|
||||
if "base_pose" in components:
|
||||
segments["base_pose"] = (6,) # xyz + rpy
|
||||
if "robot_config" in components:
|
||||
segments["robot_config"] = (1 + 3 + 1 + 12,)
|
||||
if "engaging_block" in components:
|
||||
# This could be wrong, please check the implementation of BarrierTrack
|
||||
segments["engaging_block"] = (1 + (4 + 1) + 2,)
|
||||
if "sidewall_distance" in components:
|
||||
segments["sidewall_distance"] = (2,)
|
||||
return segments
|
||||
|
||||
def get_num_obs_from_components(self, components):
|
||||
obs_segments = self.get_obs_segment_from_components(components)
|
||||
num_obs = 0
|
||||
for k, v in obs_segments.items():
|
||||
num_obs += np.prod(v)
|
||||
return num_obs
|
||||
|
||||
""" ROS callbacks and handlers that update the buffer """
|
||||
def update_low_state(self, ros_msg):
|
||||
self.low_state_buffer = ros_msg
|
||||
if self.move_by_wireless_remote:
|
||||
self.command_buf[0, 0] = self.low_state_buffer.wirelessRemote.ly
|
||||
self.command_buf[0, 1] = -self.low_state_buffer.wirelessRemote.lx # right-moving stick is positive
|
||||
self.command_buf[0, 2] = -self.low_state_buffer.wirelessRemote.rx # right-moving stick is positive
|
||||
# set the command to zero if it is too small
|
||||
if np.linalg.norm(self.command_buf[0, :2]) < self.lin_vel_deadband:
|
||||
self.command_buf[0, :2] = 0.
|
||||
if np.abs(self.command_buf[0, 2]) < self.ang_vel_deadband:
|
||||
self.command_buf[0, 2] = 0.
|
||||
self.low_state_get_time = rospy.Time.now()
|
||||
|
||||
def update_base_pose(self, ros_msg):
|
||||
""" update robot odometry for position """
|
||||
self.base_position_buffer[0, 0] = ros_msg.pose.pose.position.x
|
||||
self.base_position_buffer[0, 1] = ros_msg.pose.pose.position.y
|
||||
self.base_position_buffer[0, 2] = ros_msg.pose.pose.position.z
|
||||
|
||||
def update_move_cmd(self, ros_msg):
|
||||
self.command_buf[0, 0] = ros_msg.linear.x
|
||||
self.command_buf[0, 1] = ros_msg.linear.y
|
||||
self.command_buf[0, 2] = ros_msg.angular.z
|
||||
|
||||
def update_forward_depth(self, ros_msg):
|
||||
# TODO not checked.
|
||||
self.forward_depth_header = ros_msg.header
|
||||
buf = ros_numpy.numpify(ros_msg)
|
||||
self.forward_depth_buf = resize2d(
|
||||
torch.from_numpy(buf.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(self.model_device),
|
||||
self.forward_depth_buf.shape[-2:],
|
||||
)
|
||||
|
||||
def update_forward_depth_embedding(self, ros_msg):
|
||||
rospy.loginfo_once("a1_ros_run recieved forward depth embedding.")
|
||||
self.forward_depth_embedding_stamp = ros_msg.header.stamp
|
||||
self.forward_depth_embedding_buf[:] = torch.tensor(ros_msg.data).unsqueeze(0) # (1, d)
|
||||
|
||||
def dummy_handler(self, ros_msg):
|
||||
""" To meet the need of teleop-legged-robots requirements """
|
||||
pass
|
|
@ -0,0 +1,402 @@
|
|||
#!/home/unitree/agility_ziwenz_venv/bin/python
|
||||
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
import rospy
|
||||
from std_msgs.msg import Float32MultiArray
|
||||
from sensor_msgs.msg import Image
|
||||
import ros_numpy
|
||||
|
||||
from a1_real import UnitreeA1Real, resize2d
|
||||
from rsl_rl import modules
|
||||
from rsl_rl.utils.utils import get_obs_slice
|
||||
|
||||
@torch.no_grad()
|
||||
def handle_forward_depth(ros_msg, model, publisher, output_resolution, device):
|
||||
""" The callback function to handle the forward depth and send the embedding through ROS topic """
|
||||
buf = ros_numpy.numpify(ros_msg).astype(np.float32)
|
||||
forward_depth_buf = resize2d(
|
||||
torch.from_numpy(buf).unsqueeze(0).unsqueeze(0).to(device),
|
||||
output_resolution,
|
||||
)
|
||||
embedding = model(forward_depth_buf)
|
||||
ros_data = embedding.reshape(-1).cpu().numpy().astype(np.float32)
|
||||
publisher.publish(Float32MultiArray(data= ros_data.tolist()))
|
||||
|
||||
class StandOnlyModel(torch.nn.Module):
|
||||
def __init__(self, action_scale, dof_pos_scale, tolerance= 0.2, delta= 0.1):
|
||||
rospy.loginfo("Using stand only model, please make sure the proprioception is 48 dim.")
|
||||
rospy.loginfo("Using stand only model, -36 to -24 must be joint position.")
|
||||
super().__init__()
|
||||
if isinstance(action_scale, (tuple, list)):
|
||||
self.register_buffer("action_scale", torch.tensor(action_scale))
|
||||
else:
|
||||
self.action_scale = action_scale
|
||||
if isinstance(dof_pos_scale, (tuple, list)):
|
||||
self.register_buffer("dof_pos_scale", torch.tensor(dof_pos_scale))
|
||||
else:
|
||||
self.dof_pos_scale = dof_pos_scale
|
||||
self.tolerance = tolerance
|
||||
self.delta = delta
|
||||
|
||||
def forward(self, obs):
|
||||
joint_positions = obs[..., -36:-24] / self.dof_pos_scale
|
||||
diff_large_mask = torch.abs(joint_positions) > self.tolerance
|
||||
target_positions = torch.zeros_like(joint_positions)
|
||||
target_positions[diff_large_mask] = joint_positions[diff_large_mask] - self.delta * torch.sign(joint_positions[diff_large_mask])
|
||||
return torch.clip(
|
||||
target_positions / self.action_scale,
|
||||
-1.0, 1.0,
|
||||
)
|
||||
|
||||
def reset(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def load_walk_policy(env, model_dir):
|
||||
""" Load the walk policy from the model directory """
|
||||
if model_dir == None:
|
||||
model = StandOnlyModel(
|
||||
action_scale= env.action_scale,
|
||||
dof_pos_scale= env.obs_scales["dof_pos"],
|
||||
)
|
||||
policy = torch.jit.script(model)
|
||||
|
||||
else:
|
||||
with open(osp.join(model_dir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
obs_components = config_dict["env"]["obs_components"]
|
||||
privileged_obs_components = config_dict["env"].get("privileged_obs_components", obs_components)
|
||||
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
|
||||
num_actor_obs= env.get_num_obs_from_components(obs_components),
|
||||
num_critic_obs= env.get_num_obs_from_components(privileged_obs_components),
|
||||
num_actions= 12,
|
||||
**config_dict["policy"],
|
||||
)
|
||||
model_names = [i for i in os.listdir(model_dir) if i.startswith("model_")]
|
||||
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
|
||||
state_dict = torch.load(osp.join(model_dir, model_names[-1]), map_location= "cpu")
|
||||
model.load_state_dict(state_dict["model_state_dict"])
|
||||
model_action_scale = torch.tensor(config_dict["control"]["action_scale"]) if isinstance(config_dict["control"]["action_scale"], (tuple, list)) else torch.tensor([config_dict["control"]["action_scale"]])[0]
|
||||
if not (torch.is_tensor(model_action_scale) and (model_action_scale == env.action_scale).all()):
|
||||
action_rescale_ratio = model_action_scale / env.action_scale
|
||||
print("walk_policy action scaling:", action_rescale_ratio.tolist())
|
||||
else:
|
||||
action_rescale_ratio = 1.0
|
||||
memory_module = model.memory_a
|
||||
actor_mlp = model.actor
|
||||
@torch.jit.script
|
||||
def policy_run(obs):
|
||||
recurrent_embedding = memory_module(obs)
|
||||
actions = actor_mlp(recurrent_embedding.squeeze(0))
|
||||
return actions
|
||||
if (torch.is_tensor(action_rescale_ratio) and (action_rescale_ratio == 1.).all()) \
|
||||
or (not torch.is_tensor(action_rescale_ratio) and action_rescale_ratio == 1.):
|
||||
policy = policy_run
|
||||
else:
|
||||
policy = lambda x: policy_run(x) * action_rescale_ratio
|
||||
|
||||
return policy, model
|
||||
|
||||
def standup_procedure(env, ros_rate, angle_tolerance= 0.1,
|
||||
kp= None,
|
||||
kd= None,
|
||||
warmup_timesteps= 25,
|
||||
device= "cpu",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
warmup_timesteps: the number of timesteps to linearly increase the target position
|
||||
"""
|
||||
rospy.loginfo("Robot standing up, please wait ...")
|
||||
|
||||
target_pos = torch.zeros((1, 12), device= device, dtype= torch.float32)
|
||||
standup_timestep_i = 0
|
||||
while not rospy.is_shutdown():
|
||||
dof_pos = [env.low_state_buffer.motorState[env.dof_map[i]].q for i in range(12)]
|
||||
diff = [env.default_dof_pos[i].item() - dof_pos[i] for i in range(12)]
|
||||
direction = [1 if i > 0 else -1 for i in diff]
|
||||
if standup_timestep_i < warmup_timesteps:
|
||||
direction = [standup_timestep_i / warmup_timesteps * i for i in direction]
|
||||
if all([abs(i) < angle_tolerance for i in diff]):
|
||||
break
|
||||
print("max joint error (rad):", max([abs(i) for i in diff]), end= "\r")
|
||||
for i in range(12):
|
||||
target_pos[0, i] = dof_pos[i] + direction[i] * angle_tolerance if abs(diff[i]) > angle_tolerance else env.default_dof_pos[i]
|
||||
env.publish_legs_cmd(target_pos,
|
||||
kp= kp,
|
||||
kd= kd,
|
||||
)
|
||||
ros_rate.sleep()
|
||||
standup_timestep_i += 1
|
||||
|
||||
rospy.loginfo("Robot stood up! press R1 on the remote control to continue ...")
|
||||
while not rospy.is_shutdown():
|
||||
if env.low_state_buffer.wirelessRemote.btn.components.R1:
|
||||
break
|
||||
if env.low_state_buffer.wirelessRemote.btn.components.L2 or env.low_state_buffer.wirelessRemote.btn.components.R2:
|
||||
env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= 0, kd= 0.5)
|
||||
rospy.signal_shutdown("Controller send stop signal, exiting")
|
||||
exit(0)
|
||||
env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= kp, kd= kd)
|
||||
ros_rate.sleep()
|
||||
rospy.loginfo("Robot standing up procedure finished!")
|
||||
|
||||
class SkilledA1Real(UnitreeA1Real):
|
||||
""" Some additional methods to help the execution of skill policy """
|
||||
def __init__(self, *args,
|
||||
skill_mode_threhold= 0.1,
|
||||
skill_vel_range= [0.0, 1.0],
|
||||
**kwargs,
|
||||
):
|
||||
self.skill_mode_threhold = skill_mode_threhold
|
||||
self.skill_vel_range = skill_vel_range
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def is_skill_mode(self):
|
||||
if self.move_by_wireless_remote:
|
||||
return self.low_state_buffer.wirelessRemote.ry > self.skill_mode_threhold
|
||||
else:
|
||||
# Not implemented yet
|
||||
return False
|
||||
|
||||
def update_low_state(self, ros_msg):
|
||||
self.low_state_buffer = ros_msg
|
||||
if self.move_by_wireless_remote and ros_msg.wirelessRemote.ry > self.skill_mode_threhold:
|
||||
skill_vel = (self.low_state_buffer.wirelessRemote.ry - self.skill_mode_threhold) / (1.0 - self.skill_mode_threhold)
|
||||
skill_vel *= self.skill_vel_range[1] - self.skill_vel_range[0]
|
||||
skill_vel += self.skill_vel_range[0]
|
||||
self.command_buf[0, 0] = skill_vel
|
||||
self.command_buf[0, 1] = 0.
|
||||
self.command_buf[0, 2] = 0.
|
||||
return
|
||||
return super().update_low_state(ros_msg)
|
||||
|
||||
def main(args):
|
||||
log_level = rospy.DEBUG if args.debug else rospy.INFO
|
||||
rospy.init_node("a1_legged_gym_" + args.mode, log_level= log_level)
|
||||
|
||||
""" Not finished this modification yet """
|
||||
# if args.logdir is not None:
|
||||
# rospy.loginfo("Use logdir/config.json to initialize env proxy.")
|
||||
# with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
# config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
# else:
|
||||
# assert args.walkdir is not None, "You must provide at least a --logdir or --walkdir"
|
||||
# rospy.logwarn("You did not provide logdir, use walkdir/config.json for initializing env proxy.")
|
||||
# with open(osp.join(args.walkdir, "config.json"), "r") as f:
|
||||
# config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
assert args.logdir is not None
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
|
||||
duration = config_dict["sim"]["dt"] * config_dict["control"]["decimation"] # in sec
|
||||
# config_dict["control"]["stiffness"]["joint"] -= 2.5 # kp
|
||||
|
||||
model_device = torch.device("cpu") if args.mode == "upboard" else torch.device("cuda")
|
||||
|
||||
unitree_real_env = SkilledA1Real(
|
||||
robot_namespace= args.namespace,
|
||||
cfg= config_dict,
|
||||
forward_depth_topic= "/visual_embedding" if args.mode == "upboard" else "/camera/depth/image_rect_raw",
|
||||
forward_depth_embedding_dims= config_dict["policy"]["visual_latent_size"] if args.mode == "upboard" else None,
|
||||
move_by_wireless_remote= True,
|
||||
skill_vel_range= config_dict["commands"]["ranges"]["lin_vel_x"],
|
||||
model_device= model_device,
|
||||
# extra_cfg= dict(
|
||||
# motor_strength= torch.tensor([
|
||||
# 1., 1./0.9, 1./0.9,
|
||||
# 1., 1./0.9, 1./0.9,
|
||||
# 1., 1., 1.,
|
||||
# 1., 1., 1.,
|
||||
# ], dtype= torch.float32, device= model_device, requires_grad= False),
|
||||
# ),
|
||||
)
|
||||
|
||||
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
|
||||
num_actor_obs= unitree_real_env.num_obs,
|
||||
num_critic_obs= unitree_real_env.num_privileged_obs,
|
||||
num_actions= 12,
|
||||
obs_segments= unitree_real_env.obs_segments,
|
||||
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
|
||||
**config_dict["policy"],
|
||||
)
|
||||
config_dict["terrain"]["measure_heights"] = False
|
||||
# load the model with the latest checkpoint
|
||||
model_names = [i for i in os.listdir(args.logdir) if i.startswith("model_")]
|
||||
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
|
||||
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
|
||||
model.load_state_dict(state_dict["model_state_dict"])
|
||||
model.to(model_device)
|
||||
model.eval()
|
||||
|
||||
rospy.loginfo("duration: {}, motor Kp: {}, motor Kd: {}".format(
|
||||
duration,
|
||||
config_dict["control"]["stiffness"]["joint"],
|
||||
config_dict["control"]["damping"]["joint"],
|
||||
))
|
||||
# rospy.loginfo("[Env] motor strength: {}".format(unitree_real_env.motor_strength))
|
||||
|
||||
if args.mode == "jetson":
|
||||
embeding_publisher = rospy.Publisher(
|
||||
args.namespace + "/visual_embedding",
|
||||
Float32MultiArray,
|
||||
queue_size= 1,
|
||||
)
|
||||
# extract and build the torch ScriptFunction
|
||||
visual_encoder = model.visual_encoder
|
||||
visual_encoder = torch.jit.script(visual_encoder)
|
||||
|
||||
forward_depth_subscriber = rospy.Subscriber(
|
||||
args.namespace + "/camera/depth/image_rect_raw",
|
||||
Image,
|
||||
partial(handle_forward_depth,
|
||||
model= visual_encoder,
|
||||
publisher= embeding_publisher,
|
||||
output_resolution= config_dict["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
config_dict["sensor"]["forward_camera"]["resolution"],
|
||||
),
|
||||
device= model_device,
|
||||
),
|
||||
queue_size= 1,
|
||||
)
|
||||
rospy.spin()
|
||||
elif args.mode == "upboard":
|
||||
# extract and build the torch ScriptFunction
|
||||
memory_module = model.memory_a
|
||||
actor_mlp = model.actor
|
||||
@torch.jit.script
|
||||
def policy(obs):
|
||||
recurrent_embedding = memory_module(obs)
|
||||
actions = actor_mlp(recurrent_embedding.squeeze(0))
|
||||
return actions
|
||||
|
||||
walk_policy, walk_model = load_walk_policy(unitree_real_env, args.walkdir)
|
||||
|
||||
using_walk_policy = True # switch between skill policy and walk policy
|
||||
unitree_real_env.start_ros()
|
||||
unitree_real_env.wait_untill_ros_working()
|
||||
rate = rospy.Rate(1 / duration)
|
||||
with torch.no_grad():
|
||||
if not args.debug:
|
||||
standup_procedure(unitree_real_env, rate,
|
||||
angle_tolerance= 0.2,
|
||||
kp= 40,
|
||||
kd= 0.5,
|
||||
warmup_timesteps= 50,
|
||||
device= model_device,
|
||||
)
|
||||
while not rospy.is_shutdown():
|
||||
# inference_start_time = rospy.get_time()
|
||||
# check remote controller and decide which policy to use
|
||||
if unitree_real_env.is_skill_mode():
|
||||
if using_walk_policy:
|
||||
rospy.loginfo_throttle(0.1, "switch to skill policy")
|
||||
using_walk_policy = False
|
||||
model.reset()
|
||||
else:
|
||||
if not using_walk_policy:
|
||||
rospy.loginfo_throttle(0.1, "switch to walk policy")
|
||||
using_walk_policy = True
|
||||
walk_model.reset()
|
||||
if not using_walk_policy:
|
||||
obs = unitree_real_env.get_obs()
|
||||
actions = policy(obs)
|
||||
else:
|
||||
walk_obs = unitree_real_env._get_proprioception_obs()
|
||||
actions = walk_policy(walk_obs)
|
||||
unitree_real_env.send_action(actions)
|
||||
# unitree_real_env.send_action(torch.zeros((1, 12)))
|
||||
# inference_duration = rospy.get_time() - inference_start_time
|
||||
# rospy.loginfo("inference duration: {:.3f}".format(inference_duration))
|
||||
# rospy.loginfo("visual_latency: %f", rospy.get_time() - unitree_real_env.forward_depth_embedding_stamp.to_sec())
|
||||
# motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
|
||||
# rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
|
||||
rate.sleep()
|
||||
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.down:
|
||||
rospy.loginfo_throttle(0.1, "model reset")
|
||||
model.reset()
|
||||
walk_model.reset()
|
||||
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
|
||||
unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 2, kd= 0.5)
|
||||
rospy.signal_shutdown("Controller send stop signal, exiting")
|
||||
elif args.mode == "full":
|
||||
# extract and build the torch ScriptFunction
|
||||
visual_obs_slice = get_obs_slice(unitree_real_env.obs_segments, "forward_depth")
|
||||
visual_encoder = model.visual_encoder
|
||||
memory_module = model.memory_a
|
||||
actor_mlp = model.actor
|
||||
@torch.jit.script
|
||||
def policy(observations: torch.Tensor, obs_start: int, obs_stop: int, obs_shape: Tuple[int, int, int]):
|
||||
visual_latent = visual_encoder(
|
||||
observations[..., obs_start:obs_stop].reshape(-1, *obs_shape)
|
||||
).reshape(1, -1)
|
||||
obs = torch.cat([
|
||||
observations[..., :obs_start],
|
||||
visual_latent,
|
||||
observations[..., obs_stop:],
|
||||
], dim= -1)
|
||||
recurrent_embedding = memory_module(obs)
|
||||
actions = actor_mlp(recurrent_embedding.squeeze(0))
|
||||
return actions
|
||||
|
||||
unitree_real_env.start_ros()
|
||||
unitree_real_env.wait_untill_ros_working()
|
||||
rate = rospy.Rate(1 / duration)
|
||||
with torch.no_grad():
|
||||
while not rospy.is_shutdown():
|
||||
# inference_start_time = rospy.get_time()
|
||||
obs = unitree_real_env.get_obs()
|
||||
actions = policy(obs,
|
||||
obs_start= visual_obs_slice[0].start.item(),
|
||||
obs_stop= visual_obs_slice[0].stop.item(),
|
||||
obs_shape= visual_obs_slice[1],
|
||||
)
|
||||
unitree_real_env.send_action(actions)
|
||||
# inference_duration = rospy.get_time() - inference_start_time
|
||||
motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
|
||||
rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
|
||||
rate.sleep()
|
||||
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
|
||||
unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
|
||||
rospy.signal_shutdown("Controller send stop signal, exiting")
|
||||
else:
|
||||
rospy.logfatal("Unknown mode, exiting")
|
||||
|
||||
if __name__ == "__main__":
|
||||
""" The script to run the A1 script in ROS.
|
||||
It's designed as a main function and not designed to be a scalable code.
|
||||
"""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--namespace",
|
||||
type= str,
|
||||
default= "/a112138",
|
||||
)
|
||||
parser.add_argument("--logdir",
|
||||
type= str,
|
||||
help= "The log directory of the trained model",
|
||||
default= None,
|
||||
)
|
||||
parser.add_argument("--walkdir",
|
||||
type= str,
|
||||
help= "The log directory of the walking model, not for the skills.",
|
||||
default= None,
|
||||
)
|
||||
parser.add_argument("--mode",
|
||||
type= str,
|
||||
help= "The mode to determine which computer to run on.",
|
||||
choices= ["jetson", "upboard", "full"],
|
||||
)
|
||||
parser.add_argument("--debug",
|
||||
action= "store_true",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,317 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import torch
|
||||
import json
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
|
||||
from a1_real import UnitreeA1Real, resize2d
|
||||
from rsl_rl import modules
|
||||
|
||||
import rospy
|
||||
from unitree_legged_msgs.msg import Float32MultiArrayStamped
|
||||
from sensor_msgs.msg import Image
|
||||
import ros_numpy
|
||||
|
||||
import pyrealsense2 as rs
|
||||
|
||||
def get_encoder_script(logdir):
|
||||
with open(osp.join(logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
|
||||
model_device = torch.device("cuda")
|
||||
|
||||
unitree_real_env = UnitreeA1Real(
|
||||
robot_namespace= "a112138",
|
||||
cfg= config_dict,
|
||||
forward_depth_topic= "", # this env only computes parameters to build the model
|
||||
forward_depth_embedding_dims= None,
|
||||
model_device= model_device,
|
||||
)
|
||||
|
||||
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
|
||||
num_actor_obs= unitree_real_env.num_obs,
|
||||
num_critic_obs= unitree_real_env.num_privileged_obs,
|
||||
num_actions= 12,
|
||||
obs_segments= unitree_real_env.obs_segments,
|
||||
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
|
||||
**config_dict["policy"],
|
||||
)
|
||||
model_names = [i for i in os.listdir(logdir) if i.startswith("model_")]
|
||||
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
|
||||
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
|
||||
model.load_state_dict(state_dict["model_state_dict"])
|
||||
model.to(model_device)
|
||||
model.eval()
|
||||
|
||||
visual_encoder = model.visual_encoder
|
||||
script = torch.jit.script(visual_encoder)
|
||||
|
||||
return script, model_device
|
||||
|
||||
def get_input_filter(args):
|
||||
""" This is the filter different from the simulator, but try to close the gap. """
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
image_resolution = config_dict["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
config_dict["sensor"]["forward_camera"]["resolution"],
|
||||
)
|
||||
depth_range = config_dict["sensor"]["forward_camera"].get(
|
||||
"depth_range",
|
||||
[0.0, 3.0],
|
||||
)
|
||||
depth_range = (depth_range[0] * 1000, depth_range[1] * 1000) # [m] -> [mm]
|
||||
crop_top, crop_bottom, crop_left, crop_right = args.crop_top, args.crop_bottom, args.crop_left, args.crop_right
|
||||
crop_far = args.crop_far * 1000
|
||||
|
||||
def input_filter(depth_image: torch.Tensor,
|
||||
crop_top: int,
|
||||
crop_bottom: int,
|
||||
crop_left: int,
|
||||
crop_right: int,
|
||||
crop_far: float,
|
||||
depth_min: int,
|
||||
depth_max: int,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
):
|
||||
""" depth_image must have shape [1, 1, H, W] """
|
||||
depth_image = depth_image[:, :,
|
||||
crop_top: -crop_bottom-1,
|
||||
crop_left: -crop_right-1,
|
||||
]
|
||||
depth_image[depth_image > crop_far] = depth_max
|
||||
depth_image = torch.clip(
|
||||
depth_image,
|
||||
depth_min,
|
||||
depth_max,
|
||||
) / (depth_max - depth_min)
|
||||
depth_image = resize2d(depth_image, (output_height, output_width))
|
||||
return depth_image
|
||||
# input_filter = torch.jit.script(input_filter)
|
||||
|
||||
return partial(input_filter,
|
||||
crop_top= crop_top,
|
||||
crop_bottom= crop_bottom,
|
||||
crop_left= crop_left,
|
||||
crop_right= crop_right,
|
||||
crop_far= crop_far,
|
||||
depth_min= depth_range[0],
|
||||
depth_max= depth_range[1],
|
||||
output_height= image_resolution[0],
|
||||
output_width= image_resolution[1],
|
||||
), depth_range
|
||||
|
||||
def get_started_pipeline(
|
||||
height= 480,
|
||||
width= 640,
|
||||
fps= 30,
|
||||
enable_rgb= False,
|
||||
):
|
||||
# By default, rgb is not used.
|
||||
pipeline = rs.pipeline()
|
||||
config = rs.config()
|
||||
config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
|
||||
if enable_rgb:
|
||||
config.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
|
||||
profile = pipeline.start(config)
|
||||
|
||||
# build the sensor filter
|
||||
hole_filling_filter = rs.hole_filling_filter(2)
|
||||
spatial_filter = rs.spatial_filter()
|
||||
spatial_filter.set_option(rs.option.filter_magnitude, 5)
|
||||
spatial_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
|
||||
spatial_filter.set_option(rs.option.filter_smooth_delta, 1)
|
||||
spatial_filter.set_option(rs.option.holes_fill, 4)
|
||||
temporal_filter = rs.temporal_filter()
|
||||
temporal_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
|
||||
temporal_filter.set_option(rs.option.filter_smooth_delta, 1)
|
||||
# decimation_filter = rs.decimation_filter()
|
||||
# decimation_filter.set_option(rs.option.filter_magnitude, 2)
|
||||
|
||||
def filter_func(frame):
|
||||
frame = hole_filling_filter.process(frame)
|
||||
frame = spatial_filter.process(frame)
|
||||
frame = temporal_filter.process(frame)
|
||||
# frame = decimation_filter.process(frame)
|
||||
return frame
|
||||
|
||||
return pipeline, filter_func
|
||||
|
||||
def main(args):
|
||||
rospy.init_node("a1_legged_gym_jetson")
|
||||
|
||||
input_filter, depth_range = get_input_filter(args)
|
||||
model_script, model_device = get_encoder_script(args.logdir)
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
if config_dict.get("sensor", dict()).get("forward_camera", dict()).get("refresh_duration", None) is not None:
|
||||
refresh_duration = config_dict["sensor"]["forward_camera"]["refresh_duration"]
|
||||
ros_rate = rospy.Rate(1.0 / refresh_duration)
|
||||
rospy.loginfo("Using refresh duration {}s".format(refresh_duration))
|
||||
else:
|
||||
ros_rate = rospy.Rate(args.fps)
|
||||
|
||||
rs_pipeline, rs_filters = get_started_pipeline(
|
||||
height= args.height,
|
||||
width= args.width,
|
||||
fps= args.fps,
|
||||
enable_rgb= args.enable_rgb,
|
||||
)
|
||||
|
||||
# gyro_pipeline = rs.pipeline()
|
||||
# gyro_config = rs.config()
|
||||
# gyro_config.enable_stream(rs.stream.gyro, rs.format.motion_xyz32f, 200)
|
||||
# gyro_profile = gyro_pipeline.start(gyro_config)
|
||||
|
||||
embedding_publisher = rospy.Publisher(
|
||||
args.namespace + "/visual_embedding",
|
||||
Float32MultiArrayStamped,
|
||||
queue_size= 1,
|
||||
)
|
||||
|
||||
if args.enable_vis:
|
||||
depth_image_publisher = rospy.Publisher(
|
||||
args.namespace + "/camera/depth/image_rect_raw",
|
||||
Image,
|
||||
queue_size= 1,
|
||||
)
|
||||
network_input_publisher = rospy.Publisher(
|
||||
args.namespace + "/camera/depth/network_input_raw",
|
||||
Image,
|
||||
queue_size= 1,
|
||||
)
|
||||
if args.enable_rgb:
|
||||
rgb_image_publisher = rospy.Publisher(
|
||||
args.namespace + "/camera/color/image_raw",
|
||||
Image,
|
||||
queue_size= 1,
|
||||
)
|
||||
|
||||
rospy.loginfo("Depth range is clipped to [{}, {}] and normalized".format(depth_range[0], depth_range[1]))
|
||||
rospy.loginfo("ROS, model, realsense have been initialized.")
|
||||
if args.enable_vis:
|
||||
rospy.loginfo("Visualization enabled, sending depth{} images".format(", rgb" if args.enable_rgb else ""))
|
||||
try:
|
||||
embedding_msg = Float32MultiArrayStamped()
|
||||
embedding_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
|
||||
frame_got = False
|
||||
while not rospy.is_shutdown():
|
||||
# Wait for the depth image
|
||||
frames = rs_pipeline.wait_for_frames(int( \
|
||||
config_dict["sensor"]["forward_camera"]["latency_range"][1] \
|
||||
* 1000)) # ms
|
||||
embedding_msg.header.stamp = rospy.Time.now()
|
||||
depth_frame = frames.get_depth_frame()
|
||||
if not depth_frame:
|
||||
continue
|
||||
if not frame_got:
|
||||
frame_got = True
|
||||
rospy.loginfo("Realsense frame recieved. Sending embeddings...")
|
||||
if args.enable_rgb:
|
||||
color_frame = frames.get_color_frame()
|
||||
# Use this branch to log the time when image is acquired
|
||||
if args.enable_vis and not color_frame is None:
|
||||
color_frame = np.asanyarray(color_frame.get_data())
|
||||
rgb_image_msg = ros_numpy.msgify(Image, color_frame, encoding= "rgb8")
|
||||
rgb_image_msg.header.stamp = rospy.Time.now()
|
||||
rgb_image_msg.header.frame_id = args.namespace + "/camera_color_optical_frame"
|
||||
rgb_image_publisher.publish(rgb_image_msg)
|
||||
|
||||
# Process the depth image and publish
|
||||
depth_frame = rs_filters(depth_frame)
|
||||
depth_image_ = np.asanyarray(depth_frame.get_data())
|
||||
depth_image = torch.from_numpy(depth_image_.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(model_device)
|
||||
depth_image = input_filter(depth_image)
|
||||
with torch.no_grad():
|
||||
depth_embedding = model_script(depth_image).reshape(-1).cpu().numpy()
|
||||
embedding_msg.header.seq += 1
|
||||
embedding_msg.data = depth_embedding.tolist()
|
||||
embedding_publisher.publish(embedding_msg)
|
||||
|
||||
# Publish the acquired image if needed
|
||||
if args.enable_vis:
|
||||
depth_image_msg = ros_numpy.msgify(Image, depth_image_, encoding= "16UC1")
|
||||
depth_image_msg.header.stamp = rospy.Time.now()
|
||||
depth_image_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
|
||||
depth_image_publisher.publish(depth_image_msg)
|
||||
network_input_np = (\
|
||||
depth_image.detach().cpu().numpy()[0, 0] * (depth_range[1] - depth_range[0]) \
|
||||
+ depth_range[0]
|
||||
).astype(np.uint16)
|
||||
network_input_msg = ros_numpy.msgify(Image, network_input_np, encoding= "16UC1")
|
||||
network_input_msg.header.stamp = rospy.Time.now()
|
||||
network_input_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
|
||||
network_input_publisher.publish(network_input_msg)
|
||||
|
||||
ros_rate.sleep()
|
||||
finally:
|
||||
rs_pipeline.stop()
|
||||
|
||||
if __name__ == "__main__":
|
||||
""" This script is designed to load the model and process the realsense image directly
|
||||
from realsense SDK without realsense ROS wrapper
|
||||
"""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--namespace",
|
||||
type= str,
|
||||
default= "/a112138",
|
||||
)
|
||||
parser.add_argument("--logdir",
|
||||
type= str,
|
||||
help= "The log directory of the trained model",
|
||||
)
|
||||
parser.add_argument("--height",
|
||||
type= int,
|
||||
default= 240,
|
||||
help= "The height of the realsense image",
|
||||
)
|
||||
parser.add_argument("--width",
|
||||
type= int,
|
||||
default= 424,
|
||||
help= "The width of the realsense image",
|
||||
)
|
||||
parser.add_argument("--fps",
|
||||
type= int,
|
||||
default= 30,
|
||||
help= "The fps of the realsense image",
|
||||
)
|
||||
parser.add_argument("--crop_left",
|
||||
type= int,
|
||||
default= 60,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_right",
|
||||
type= int,
|
||||
default= 46,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_top",
|
||||
type= int,
|
||||
default= 0,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_bottom",
|
||||
type= int,
|
||||
default= 0,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_far",
|
||||
type= float,
|
||||
default= 3.0,
|
||||
help= "asside from the config far limit, make all depth readings larger than this value to be 3.0 in un-normalized network input."
|
||||
)
|
||||
parser.add_argument("--enable_rgb",
|
||||
action= "store_true",
|
||||
help= "Whether to enable rgb image",
|
||||
)
|
||||
parser.add_argument("--enable_vis",
|
||||
action= "store_true",
|
||||
help= "Whether to publish realsense image",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,317 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import torch
|
||||
import json
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
|
||||
from a1_real import UnitreeA1Real, resize2d
|
||||
from rsl_rl import modules
|
||||
|
||||
import rospy
|
||||
from unitree_legged_msgs.msg import Float32MultiArrayStamped
|
||||
from sensor_msgs.msg import Image
|
||||
import ros_numpy
|
||||
|
||||
import pyrealsense2 as rs
|
||||
|
||||
def get_encoder_script(logdir):
|
||||
with open(osp.join(logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
|
||||
model_device = torch.device("cuda")
|
||||
|
||||
unitree_real_env = UnitreeA1Real(
|
||||
robot_namespace= "a112138",
|
||||
cfg= config_dict,
|
||||
forward_depth_topic= "", # this env only computes parameters to build the model
|
||||
forward_depth_embedding_dims= None,
|
||||
model_device= model_device,
|
||||
)
|
||||
|
||||
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
|
||||
num_actor_obs= unitree_real_env.num_obs,
|
||||
num_critic_obs= unitree_real_env.num_privileged_obs,
|
||||
num_actions= 12,
|
||||
obs_segments= unitree_real_env.obs_segments,
|
||||
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
|
||||
**config_dict["policy"],
|
||||
)
|
||||
model_names = [i for i in os.listdir(logdir) if i.startswith("model_")]
|
||||
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
|
||||
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
|
||||
model.load_state_dict(state_dict["model_state_dict"])
|
||||
model.to(model_device)
|
||||
model.eval()
|
||||
|
||||
visual_encoder = model.visual_encoder
|
||||
script = torch.jit.script(visual_encoder)
|
||||
|
||||
return script, model_device
|
||||
|
||||
def get_input_filter(args):
|
||||
""" This is the filter different from the simulator, but try to close the gap. """
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
image_resolution = config_dict["sensor"]["forward_camera"].get(
|
||||
"output_resolution",
|
||||
config_dict["sensor"]["forward_camera"]["resolution"],
|
||||
)
|
||||
depth_range = config_dict["sensor"]["forward_camera"].get(
|
||||
"depth_range",
|
||||
[0.0, 3.0],
|
||||
)
|
||||
depth_range = (depth_range[0] * 1000, depth_range[1] * 1000) # [m] -> [mm]
|
||||
crop_top, crop_bottom, crop_left, crop_right = args.crop_top, args.crop_bottom, args.crop_left, args.crop_right
|
||||
crop_far = args.crop_far * 1000
|
||||
|
||||
def input_filter(depth_image: torch.Tensor,
|
||||
crop_top: int,
|
||||
crop_bottom: int,
|
||||
crop_left: int,
|
||||
crop_right: int,
|
||||
crop_far: float,
|
||||
depth_min: int,
|
||||
depth_max: int,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
):
|
||||
""" depth_image must have shape [1, 1, H, W] """
|
||||
depth_image = depth_image[:, :,
|
||||
crop_top: -crop_bottom-1,
|
||||
crop_left: -crop_right-1,
|
||||
]
|
||||
depth_image[depth_image > crop_far] = depth_max
|
||||
depth_image = torch.clip(
|
||||
depth_image,
|
||||
depth_min,
|
||||
depth_max,
|
||||
) / (depth_max - depth_min)
|
||||
depth_image = resize2d(depth_image, (output_height, output_width))
|
||||
return depth_image
|
||||
# input_filter = torch.jit.script(input_filter)
|
||||
|
||||
return partial(input_filter,
|
||||
crop_top= crop_top,
|
||||
crop_bottom= crop_bottom,
|
||||
crop_left= crop_left,
|
||||
crop_right= crop_right,
|
||||
crop_far= crop_far,
|
||||
depth_min= depth_range[0],
|
||||
depth_max= depth_range[1],
|
||||
output_height= image_resolution[0],
|
||||
output_width= image_resolution[1],
|
||||
), depth_range
|
||||
|
||||
def get_started_pipeline(
|
||||
height= 480,
|
||||
width= 640,
|
||||
fps= 30,
|
||||
enable_rgb= False,
|
||||
):
|
||||
# By default, rgb is not used.
|
||||
pipeline = rs.pipeline()
|
||||
config = rs.config()
|
||||
config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
|
||||
if enable_rgb:
|
||||
config.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
|
||||
profile = pipeline.start(config)
|
||||
|
||||
# build the sensor filter
|
||||
hole_filling_filter = rs.hole_filling_filter(2)
|
||||
spatial_filter = rs.spatial_filter()
|
||||
spatial_filter.set_option(rs.option.filter_magnitude, 5)
|
||||
spatial_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
|
||||
spatial_filter.set_option(rs.option.filter_smooth_delta, 1)
|
||||
spatial_filter.set_option(rs.option.holes_fill, 4)
|
||||
temporal_filter = rs.temporal_filter()
|
||||
temporal_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
|
||||
temporal_filter.set_option(rs.option.filter_smooth_delta, 1)
|
||||
# decimation_filter = rs.decimation_filter()
|
||||
# decimation_filter.set_option(rs.option.filter_magnitude, 2)
|
||||
|
||||
def filter_func(frame):
|
||||
frame = hole_filling_filter.process(frame)
|
||||
frame = spatial_filter.process(frame)
|
||||
frame = temporal_filter.process(frame)
|
||||
# frame = decimation_filter.process(frame)
|
||||
return frame
|
||||
|
||||
return pipeline, filter_func
|
||||
|
||||
def main(args):
|
||||
rospy.init_node("a1_legged_gym_jetson")
|
||||
|
||||
input_filter, depth_range = get_input_filter(args)
|
||||
model_script, model_device = get_encoder_script(args.logdir)
|
||||
with open(osp.join(args.logdir, "config.json"), "r") as f:
|
||||
config_dict = json.load(f, object_pairs_hook= OrderedDict)
|
||||
if config_dict.get("sensor", dict()).get("forward_camera", dict()).get("refresh_duration", None) is not None:
|
||||
refresh_duration = config_dict["sensor"]["forward_camera"]["refresh_duration"]
|
||||
ros_rate = rospy.Rate(1.0 / refresh_duration)
|
||||
rospy.loginfo("Using refresh duration {}s".format(refresh_duration))
|
||||
else:
|
||||
ros_rate = rospy.Rate(args.fps)
|
||||
|
||||
rs_pipeline, rs_filters = get_started_pipeline(
|
||||
height= args.height,
|
||||
width= args.width,
|
||||
fps= args.fps,
|
||||
enable_rgb= args.enable_rgb,
|
||||
)
|
||||
|
||||
# gyro_pipeline = rs.pipeline()
|
||||
# gyro_config = rs.config()
|
||||
# gyro_config.enable_stream(rs.stream.gyro, rs.format.motion_xyz32f, 200)
|
||||
# gyro_profile = gyro_pipeline.start(gyro_config)
|
||||
|
||||
embedding_publisher = rospy.Publisher(
|
||||
args.namespace + "/visual_embedding",
|
||||
Float32MultiArrayStamped,
|
||||
queue_size= 1,
|
||||
)
|
||||
|
||||
if args.enable_vis:
|
||||
depth_image_publisher = rospy.Publisher(
|
||||
args.namespace + "/camera/depth/image_rect_raw",
|
||||
Image,
|
||||
queue_size= 1,
|
||||
)
|
||||
network_input_publisher = rospy.Publisher(
|
||||
args.namespace + "/camera/depth/network_input_raw",
|
||||
Image,
|
||||
queue_size= 1,
|
||||
)
|
||||
if args.enable_rgb:
|
||||
rgb_image_publisher = rospy.Publisher(
|
||||
args.namespace + "/camera/color/image_raw",
|
||||
Image,
|
||||
queue_size= 1,
|
||||
)
|
||||
|
||||
rospy.loginfo("Depth range is clipped to [{}, {}] and normalized".format(depth_range[0], depth_range[1]))
|
||||
rospy.loginfo("ROS, model, realsense have been initialized.")
|
||||
if args.enable_vis:
|
||||
rospy.loginfo("Visualization enabled, sending depth{} images".format(", rgb" if args.enable_rgb else ""))
|
||||
try:
|
||||
embedding_msg = Float32MultiArrayStamped()
|
||||
embedding_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
|
||||
frame_got = False
|
||||
while not rospy.is_shutdown():
|
||||
# Wait for the depth image
|
||||
frames = rs_pipeline.wait_for_frames(int( \
|
||||
config_dict["sensor"]["forward_camera"]["latency_range"][1] \
|
||||
* 1000)) # ms
|
||||
embedding_msg.header.stamp = rospy.Time.now()
|
||||
depth_frame = frames.get_depth_frame()
|
||||
if not depth_frame:
|
||||
continue
|
||||
if not frame_got:
|
||||
frame_got = True
|
||||
rospy.loginfo("Realsense frame recieved. Sending embeddings...")
|
||||
if args.enable_rgb:
|
||||
color_frame = frames.get_color_frame()
|
||||
# Use this branch to log the time when image is acquired
|
||||
if args.enable_vis and not color_frame is None:
|
||||
color_frame = np.asanyarray(color_frame.get_data())
|
||||
rgb_image_msg = ros_numpy.msgify(Image, color_frame, encoding= "rgb8")
|
||||
rgb_image_msg.header.stamp = rospy.Time.now()
|
||||
rgb_image_msg.header.frame_id = args.namespace + "/camera_color_optical_frame"
|
||||
rgb_image_publisher.publish(rgb_image_msg)
|
||||
|
||||
# Process the depth image and publish
|
||||
depth_frame = rs_filters(depth_frame)
|
||||
depth_image_ = np.asanyarray(depth_frame.get_data())
|
||||
depth_image = torch.from_numpy(depth_image_.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(model_device)
|
||||
depth_image = input_filter(depth_image)
|
||||
with torch.no_grad():
|
||||
depth_embedding = model_script(depth_image).reshape(-1).cpu().numpy()
|
||||
embedding_msg.header.seq += 1
|
||||
embedding_msg.data = depth_embedding.tolist()
|
||||
embedding_publisher.publish(embedding_msg)
|
||||
|
||||
# Publish the acquired image if needed
|
||||
if args.enable_vis:
|
||||
depth_image_msg = ros_numpy.msgify(Image, depth_image_, encoding= "16UC1")
|
||||
depth_image_msg.header.stamp = rospy.Time.now()
|
||||
depth_image_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
|
||||
depth_image_publisher.publish(depth_image_msg)
|
||||
network_input_np = (\
|
||||
depth_image.detach().cpu().numpy()[0, 0] * (depth_range[1] - depth_range[0]) \
|
||||
+ depth_range[0]
|
||||
).astype(np.uint16)
|
||||
network_input_msg = ros_numpy.msgify(Image, network_input_np, encoding= "16UC1")
|
||||
network_input_msg.header.stamp = rospy.Time.now()
|
||||
network_input_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
|
||||
network_input_publisher.publish(network_input_msg)
|
||||
|
||||
ros_rate.sleep()
|
||||
finally:
|
||||
rs_pipeline.stop()
|
||||
|
||||
if __name__ == "__main__":
|
||||
""" This script is designed to load the model and process the realsense image directly
|
||||
from realsense SDK without realsense ROS wrapper
|
||||
"""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--namespace",
|
||||
type= str,
|
||||
default= "/a112138",
|
||||
)
|
||||
parser.add_argument("--logdir",
|
||||
type= str,
|
||||
help= "The log directory of the trained model",
|
||||
)
|
||||
parser.add_argument("--height",
|
||||
type= int,
|
||||
default= 270,
|
||||
help= "The height of the realsense image",
|
||||
)
|
||||
parser.add_argument("--width",
|
||||
type= int,
|
||||
default= 480,
|
||||
help= "The width of the realsense image",
|
||||
)
|
||||
parser.add_argument("--fps",
|
||||
type= int,
|
||||
default= 30,
|
||||
help= "The fps of the realsense image",
|
||||
)
|
||||
parser.add_argument("--crop_left",
|
||||
type= int,
|
||||
default= 60,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_right",
|
||||
type= int,
|
||||
default= 46,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_top",
|
||||
type= int,
|
||||
default= 0,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_bottom",
|
||||
type= int,
|
||||
default= 0,
|
||||
help= "num of pixel to crop in the original pyrealsense readings."
|
||||
)
|
||||
parser.add_argument("--crop_far",
|
||||
type= float,
|
||||
default= 3.0,
|
||||
help= "asside from the config far limit, make all depth readings larger than this value to be 3.0 in un-normalized network input."
|
||||
)
|
||||
parser.add_argument("--enable_rgb",
|
||||
action= "store_true",
|
||||
help= "Whether to enable rgb image",
|
||||
)
|
||||
parser.add_argument("--enable_vis",
|
||||
action= "store_true",
|
||||
help= "Whether to publish realsense image",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -164,6 +164,12 @@ class TPPO(PPO):
|
|||
self.actor_critic.action_mean - minibatch.action_labels,
|
||||
dim= -1
|
||||
)
|
||||
elif self.distill_target == "l1":
|
||||
dist_loss = torch.norm(
|
||||
self.actor_critic.action_mean - minibatch.action_labels,
|
||||
dim= -1,
|
||||
p= 1,
|
||||
)
|
||||
elif self.distill_target == "tanh":
|
||||
# for tanh, similar to loss function for sigmoid, refer to https://stats.stackexchange.com/questions/12754/matching-loss-function-for-tanh-units-in-a-neural-net
|
||||
dist_loss = F.binary_cross_entropy(
|
||||
|
|
|
@ -15,6 +15,8 @@ class ActorCriticFieldMutex(ActorCriticMutex):
|
|||
def __init__(self,
|
||||
*args,
|
||||
cmd_vel_mapping = dict(),
|
||||
reset_non_selected = "all",
|
||||
action_smoothing_buffer_len = 1,
|
||||
**kwargs,
|
||||
):
|
||||
""" NOTE: This implementation only supports subpolicy output to (-1., 1.) range.
|
||||
|
@ -24,6 +26,9 @@ class ActorCriticFieldMutex(ActorCriticMutex):
|
|||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cmd_vel_mapping = cmd_vel_mapping
|
||||
self.reset_non_selected = reset_non_selected
|
||||
self.action_smoothing_buffer_len = action_smoothing_buffer_len
|
||||
self.action_smoothing_buffer = None
|
||||
|
||||
# load cmd_scale to assign the cmd_vel during overriding
|
||||
self.cmd_scales = []
|
||||
|
@ -101,18 +106,33 @@ class ActorCriticFieldMutex(ActorCriticMutex):
|
|||
def act_inference(self, observations):
|
||||
# run entire batch for each sub policy in case the batch size and length problem.
|
||||
policy_selection = self.get_policy_selection(observations)
|
||||
if self.action_smoothing_buffer is None:
|
||||
self.action_smoothing_buffer = torch.zeros(
|
||||
self.action_smoothing_buffer_len,
|
||||
*policy_selection.shape,
|
||||
device= policy_selection.device,
|
||||
dtype= torch.float,
|
||||
) # (len, N, ..., selection)
|
||||
self.action_smoothing_buffer = torch.cat([
|
||||
self.action_smoothing_buffer[1:],
|
||||
policy_selection.unsqueeze(0),
|
||||
], dim= 0) # put the new one at the end
|
||||
observations = self.recover_last_action(observations, policy_selection)
|
||||
if self.cmd_vel_mapping:
|
||||
observations = self.override_cmd_vel(observations, policy_selection)
|
||||
outputs = [p.act_inference(observations) for p in self.submodules]
|
||||
output = torch.empty_like(outputs[0])
|
||||
output = torch.zeros_like(outputs[0])
|
||||
for idx, out in enumerate(outputs):
|
||||
output[policy_selection[..., idx]] = torch.clip(
|
||||
out[policy_selection[..., idx]] * getattr(self, "subpolicy_action_scale_{:d}".format(idx)) / self.env_action_scale,
|
||||
-1., 1.,
|
||||
)
|
||||
output += out * getattr(self, "subpolicy_action_scale_{:d}".format(idx)) / self.env_action_scale \
|
||||
* self.action_smoothing_buffer[..., idx].mean(dim= 0).unsqueeze(-1)
|
||||
# choose one or none reset method
|
||||
self.submodules[idx].reset(~policy_selection[..., idx])
|
||||
if self.reset_non_selected == "all" or self.reset_non_selected == True:
|
||||
self.submodules[idx].reset(self.action_smoothing_buffer[..., idx].sum(0) == 0)
|
||||
elif self.reset_non_selected == "when_skill" and idx > 0:
|
||||
self.submodules[idx].reset(torch.logical_and(
|
||||
~policy_selection[..., idx],
|
||||
~policy_selection[..., 0],
|
||||
))
|
||||
# self.submodules[idx].reset(torch.ones(observations.shape[0], dtype= bool, device= observations.device))
|
||||
return output
|
||||
|
||||
|
@ -122,16 +142,18 @@ class ActorCriticFieldMutex(ActorCriticMutex):
|
|||
return super().reset(dones)
|
||||
|
||||
class ActorCriticClimbMutex(ActorCriticFieldMutex):
|
||||
""" A variant to handle climb-up and climb-down with seperate policies
|
||||
Climb-down policy will be the last submodule in the list
|
||||
""" A variant to handle jump-up and jump-down with seperate policies
|
||||
Jump-down policy will be the last submodule in the list
|
||||
"""
|
||||
JUMP_OBSTACLE_ID = 3 # starting from 0, referring to barrker_track.py:BarrierTrack.track_options_id_dict
|
||||
def __init__(self,
|
||||
*args,
|
||||
sub_policy_paths: list = None,
|
||||
climb_down_policy_path: str = None,
|
||||
jump_down_policy_path: str = None,
|
||||
jump_down_vel: float = None, # can be tuple/list, use it to stop using jump up velocity command
|
||||
**kwargs,):
|
||||
sub_policy_paths = sub_policy_paths + [climb_down_policy_path]
|
||||
sub_policy_paths = sub_policy_paths + [jump_down_policy_path]
|
||||
self.jump_down_vel = jump_down_vel
|
||||
super().__init__(
|
||||
*args,
|
||||
sub_policy_paths= sub_policy_paths,
|
||||
|
@ -140,23 +162,28 @@ class ActorCriticClimbMutex(ActorCriticFieldMutex):
|
|||
|
||||
def resample_cmd_vel_current(self, dones=None):
|
||||
return_ = super().resample_cmd_vel_current(dones)
|
||||
self.cmd_vel_current[len(self.submodules) - 1] = self.cmd_vel_current[self.JUMP_OBSTACLE_ID]
|
||||
if self.jump_down_vel is None:
|
||||
self.cmd_vel_current[len(self.submodules) - 1] = self.cmd_vel_current[self.JUMP_OBSTACLE_ID]
|
||||
elif isinstance(self.jump_down_vel, (tuple, list)):
|
||||
self.cmd_vel_current[len(self.submodules) - 1] = np.random.uniform(*self.jump_down_vel)
|
||||
else:
|
||||
self.cmd_vel_current[len(self.submodules) - 1] = self.jump_down_vel
|
||||
return return_
|
||||
|
||||
def get_policy_selection(self, observations):
|
||||
obstacle_id_onehot = super().get_policy_selection(observations)
|
||||
obs_slice = get_obs_slice(self.obs_segments, "engaging_block")
|
||||
engaging_block_obs = observations[..., obs_slice[0]].reshape(*observations.shape[:-1], *obs_slice[1])
|
||||
climb_up_mask = engaging_block_obs[..., -1] > 0 # climb-up or climb-down
|
||||
jump_up_mask = engaging_block_obs[..., -1] > 0 # jump-up or jump-down
|
||||
obstacle_id_onehot = torch.cat([
|
||||
obstacle_id_onehot,
|
||||
torch.logical_and(
|
||||
obstacle_id_onehot[..., self.JUMP_OBSTACLE_ID],
|
||||
torch.logical_not(climb_up_mask),
|
||||
torch.logical_not(jump_up_mask),
|
||||
).unsqueeze(-1)
|
||||
], dim= -1)
|
||||
obstacle_id_onehot[..., self.JUMP_OBSTACLE_ID] = torch.logical_and(
|
||||
obstacle_id_onehot[..., self.JUMP_OBSTACLE_ID],
|
||||
climb_up_mask,
|
||||
jump_up_mask,
|
||||
)
|
||||
return obstacle_id_onehot.to(torch.bool) # (N, ..., selection)
|
||||
|
|
|
@ -76,6 +76,7 @@ class OnPolicyRunner:
|
|||
self.tot_timesteps = 0
|
||||
self.tot_time = 0
|
||||
self.current_learning_iteration = 0
|
||||
self.log_interval = self.cfg.get("log_interval", 1)
|
||||
|
||||
_, _ = self.env.reset()
|
||||
|
||||
|
@ -130,7 +131,7 @@ class OnPolicyRunner:
|
|||
losses, stats = self.alg.update(self.current_learning_iteration)
|
||||
stop = time.time()
|
||||
learn_time = stop - start
|
||||
if self.log_dir is not None:
|
||||
if self.log_dir is not None and self.current_learning_iteration % self.log_interval == 0:
|
||||
self.log(locals())
|
||||
if self.current_learning_iteration % self.save_interval == 0 and self.current_learning_iteration > start_iter:
|
||||
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
|
||||
|
@ -238,7 +239,7 @@ class OnPolicyRunner:
|
|||
def load(self, path, load_optimizer=True):
|
||||
loaded_dict = torch.load(path)
|
||||
self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
|
||||
if load_optimizer:
|
||||
if load_optimizer and "optimizer_state_dict" in loaded_dict:
|
||||
self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
|
||||
if "lr_scheduler_state_dict" in loaded_dict:
|
||||
if not hasattr(self.alg, "lr_scheduler"):
|
||||
|
|
|
@ -12,6 +12,7 @@ class TwoStageRunner(OnPolicyRunner):
|
|||
|
||||
# load some configs and their default values
|
||||
self.pretrain_iterations = self.cfg.get("pretrain_iterations", 0)
|
||||
self.log_interval = self.cfg.get("log_interval", 50)
|
||||
assert "pretrain_dataset" in self.cfg, "pretrain_dataset is not defined in the runner cfg object"
|
||||
self.rollout_dataset = RolloutDataset(
|
||||
**self.cfg["pretrain_dataset"],
|
||||
|
|
Loading…
Reference in New Issue