[adding] deploy code and instructions

* latest environments will follow
Ziwen Zhuang 2023-11-08 03:27:50 +08:00
parent c0d55bbb56
commit f771497315
14 changed files with 5873 additions and 18 deletions

83
Deploy.md Normal file

@ -0,0 +1,83 @@
# Deploy the model on your real Unitree robot
This version shows an example of how to deploy the model on the Unitree Go1 robot (with Nvidia Jetson NX).
## Install dependencies on Go1
To deploy the trained model on Go1, please set up a folder on your robot, e.g. `parkour`, and copy the `rsl_rl` folder to it. Then, extract the zip files in `go1_ckpts` to the `parkour` folder. Finally, copy all the files in `onboard_script` to the `parkour` folder.
1. Install ROS and the [Unitree ROS package for Go1](https://github.com/Tsinghua-MARS-Lab/unitree_ros_real.git), then follow its instructions on branch `go1` to set up the robot.
The rest of this guide assumes the ROS workspace is located in `parkour/unitree_ws`.
2. Install PyTorch in a Python 3 environment. The commands below use a Python 3 virtual environment on the Nvidia Jetson NX as an example.
```bash
sudo apt-get install python3-pip python3-dev python3-venv
python3 -m venv parkour_venv
source parkour_venv/bin/activate
```
Download the PyTorch v1.10.0 pip wheel from [this NVIDIA forum thread](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048), then install it with
```bash
pip install torch-1.10.0-cp36-cp36m-linux_aarch64.whl
```
3. Install other dependencies and `rsl_rl`
```bash
pip install numpy==1.16.4 numpy-ros
pip install -e ./rsl_rl
```
4. 3D print the camera mount for Go1 using the step file in `go1_ckpts/go1_camMount_30Down.step`. Use the two pairs of screw holes to mount the Intel Realsense D435i camera on the robot, as shown in the picture below.
<p align="center">
<img src="images/go1_camMount_30Down.png" width="50%"/>
</p>
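Before moving on, it can help to confirm that the virtual environment can locate the modules the onboard scripts import. The following is a minimal sketch and is not part of the repository: the helper name `find_missing_modules` is hypothetical, and the module list is an assumption taken from the import statements of the onboard scripts (`torch`, `numpy`, `rospy`, `ros_numpy`, `rsl_rl`). Run it with `parkour_venv` activated and the ROS workspace sourced, otherwise `rospy` will be reported as missing.

```python
import importlib.util


def find_missing_modules(module_names):
    """Return the names in module_names that cannot be located for import.

    Hypothetical helper for illustration; it only locates top-level modules
    and does not import them, so it is cheap to run on the robot.
    """
    missing = []
    for name in module_names:
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    # Assumed list, taken from the imports of the onboard scripts.
    required = ["torch", "numpy", "rospy", "ros_numpy", "rsl_rl"]
    missing = find_missing_modules(required)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

A module that is found can still fail at import time (for example, a wheel built for another Python version), so this check only catches the most common setup mistakes.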
## Run the model on Go1
***Disclaimer:*** *Always put a safety belt on the robot when the robot moves. The robot may fall down and cause damage to itself or the environment.*
1. Put the robot on the ground, power it on, and switch it to developer mode. Make sure the Intel Realsense D435i camera is installed on the mount and connected to the robot.
2. Open three terminals on the onboard computer (for example, three SSH connections from your computer). They are referred to below as T_ros, T_visual, and T_gru.
3. In T_ros, run
```bash
cd parkour/unitree_ws
source devel/setup.bash
roslaunch unitree_legged_real robot.launch
```
This will launch the ROS node for Go1. Please note that without `dryrun:=False` the robot will not move, although everything else will run as usual. Set `dryrun:=False` only when you are ready to let the robot move.
4. In T_visual, run
```bash
cd parkour
source unitree_ws/devel/setup.bash
source parkour_venv/bin/activate
python go1_visual_embedding.py --logdir Nov02...
```
where `Nov02...` is the logdir of the trained model.
Wait for the script to finish loading the model and to gain access to the Realsense sensor pipeline. The script then prints: `"Realsense frame received. Sending embeddings..."`
Adding the `--enable_vis` option publishes the depth image as a ROS topic, which you can visualize in rviz.
5. In T_gru, run
```bash
cd parkour
source unitree_ws/devel/setup.bash
source parkour_venv/bin/activate
python a1_ros_run.py --mode upboard --logdir Nov02...
```
where `Nov02...` is the logdir of the trained model.
If the ROS node was launched with `dryrun:=False`, the robot will start standing up. Otherwise, add the `--debug` option to `a1_ros_run.py` to see the model output.
If the ROS node was launched with `dryrun:=False`, the script will prompt you: `"Robot stood up! press R1 on the remote control to continue ..."` Press R1 on the remote control to start the standing policy.
Push R-Y forward on the remote control to trigger the parkour policy. The robot will start running and jumping. Release the joystick to stop the robot.
Press R2 on the remote control to shut down the gru script and the ROS node. You can also use it in case of emergency.
Use the `--walkdir` option to load a walking policy (e.g. `Oct24_16...`); it replaces the standing policy. You can then use L-Y and L-X to control the x/y velocity of the robot and R-X to control its yaw velocity.
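The stick controls described above can be summarized as a small mapping from remote readings to a velocity command. The sketch below is a plain-Python restatement of what the `update_low_state` callbacks in the onboard scripts of this commit appear to do; the function name `remote_to_command` is hypothetical, stick readings are assumed to lie in [-1, 1], and the default thresholds mirror the defaults in those scripts (the deployed script takes the velocity range from the training config instead).

```python
import math


def remote_to_command(lx, ly, rx, ry,
                      skill_threshold=0.1,
                      skill_vel_range=(0.0, 1.0),
                      lin_vel_deadband=0.1,
                      ang_vel_deadband=0.1):
    """Map stick readings to an (x_vel, y_vel, yaw_vel) command.

    Hypothetical helper for illustration only.
    """
    if ry > skill_threshold:
        # Parkour mode: R-Y above the threshold is rescaled linearly onto the
        # forward velocity range; lateral and yaw commands are zeroed.
        ratio = (ry - skill_threshold) / (1.0 - skill_threshold)
        low, high = skill_vel_range
        return (low + ratio * (high - low), 0.0, 0.0)
    # Walking mode: L-Y is forward, L-X and R-X are negated so that pushing
    # a stick to the left gives a positive y / yaw command.
    x_vel, y_vel, yaw_vel = ly, -lx, -rx
    if math.hypot(x_vel, y_vel) < lin_vel_deadband:
        x_vel, y_vel = 0.0, 0.0
    if abs(yaw_vel) < ang_vel_deadband:
        yaw_vel = 0.0
    return (x_vel, y_vel, yaw_vel)
```

Releasing R-Y drops the reading below the threshold, so the command falls back to the walking-mode mapping, where the deadbands zero out small stick offsets.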


@ -27,7 +27,8 @@ Conference on Robot Learning (CoRL) 2023, Oral <br>
 To install and run the code for training A1 in simulation, please clone this repository and follow the instructions in [legged_gym/README.md](legged_gym/README.md).
 ## Hardware Deployment ##
-TODO
+To deploy the trained model on your real robot, please follow the instructions in [Deploy.md](Deploy.md).
 ## Trouble Shooting ##
 If you cannot run the distillation part, or all graphics computing goes to GPU 0 despite having multiple GPUs and having set CUDA_VISIBLE_DEVICES, please use docker to isolate each GPU.
@ -35,7 +36,7 @@ If you cannot run the distillation part or all graphics computing goes to GPU 0
 ## To Do (will be done before Nov 2023) ##
 - [ ] Go1 training pipeline in simulation
 - [ ] A1 deployment code
-- [ ] Go1 deployment code
+- [x] Go1 deployment code
 ## Citation ##
 If you find this project helpful to your research, please consider citing us! This is really important to us.

File diff suppressed because it is too large

Binary file not shown.


489
onboard_script/a1_real.py Normal file

@ -0,0 +1,489 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import json
import os
import os.path as osp
from collections import OrderedDict
from typing import Tuple
import rospy
from unitree_legged_msgs.msg import LowState
from unitree_legged_msgs.msg import LegsCmd
from unitree_legged_msgs.msg import Float32MultiArrayStamped
from std_msgs.msg import Float32MultiArray
from geometry_msgs.msg import Twist, Pose
from nav_msgs.msg import Odometry
from sensor_msgs.msg import Image
import ros_numpy
@torch.no_grad()
def resize2d(img, size):
    # Resize a (N, C, H, W) image tensor to `size` by adaptive average pooling.
    return F.adaptive_avg_pool2d(img, size)
@torch.jit.script
def quat_rotate_inverse(q, v):
    shape = q.shape
    q_w = q[:, -1]
    q_vec = q[:, :3]
    a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
    c = q_vec * \
        torch.bmm(q_vec.view(shape[0], 1, 3), v.view(
            shape[0], 3, 1)).squeeze(-1) * 2.0
    return a - b + c
class UnitreeA1Real:
""" This is the handler that works for ROS 1 on unitree. """
def __init__(self,
robot_namespace= "a112138",
low_state_topic= "/low_state",
legs_cmd_topic= "/legs_cmd",
forward_depth_topic = "/camera/depth/image_rect_raw",
forward_depth_embedding_dims = None,
odom_topic= "/odom/filtered",
lin_vel_deadband= 0.1,
ang_vel_deadband= 0.1,
move_by_wireless_remote= False, # if True, command will not listen to move_cmd_subscriber, but wireless remote.
cfg= dict(),
extra_cfg= dict(),
model_device= torch.device("cpu"),
):
"""
NOTE:
* Must call start_ros() before using this class's get_obs() and send_action()
* Joint order of simulation and of real A1 protocol are different, see dof_names
* We store all joints values in the order of simulation in this class
Args:
forward_depth_embedding_dims: If set (a positive integer), the obs will not be built as in a normal env.
The forward depth segment of obs will be substituted by the embedding of the forward depth image
received from the ROS topic.
cfg: same config as a1_config, but as a dict object.
extra_cfg: other configs that are hard to load from file.
"""
self.model_device = model_device
self.num_envs = 1
self.robot_namespace = robot_namespace
self.low_state_topic = low_state_topic
self.legs_cmd_topic = legs_cmd_topic
self.forward_depth_topic = forward_depth_topic
self.forward_depth_embedding_dims = forward_depth_embedding_dims
self.odom_topic = odom_topic
self.lin_vel_deadband = lin_vel_deadband
self.ang_vel_deadband = ang_vel_deadband
self.move_by_wireless_remote = move_by_wireless_remote
self.cfg = cfg
self.extra_cfg = dict(
torque_limits= torch.tensor([33.5] * 12, dtype= torch.float32, device= self.model_device, requires_grad= False), # Nm
# torque_limits= torch.tensor([1, 5, 5] * 4, dtype= torch.float32, device= self.model_device, requires_grad= False), # Nm
dof_map= [ # from isaacgym simulation joint order to URDF order
3, 4, 5,
0, 1, 2,
9, 10,11,
6, 7, 8,
], # real_joint_idx = dof_map[sim_joint_idx]
dof_names= [ # NOTE: order matters. This list is the order in simulation.
"FL_hip_joint",
"FL_thigh_joint",
"FL_calf_joint",
"FR_hip_joint",
"FR_thigh_joint",
"FR_calf_joint",
"RL_hip_joint",
"RL_thigh_joint",
"RL_calf_joint",
"RR_hip_joint",
"RR_thigh_joint",
"RR_calf_joint",
],
# motor strength is multiplied directly to the action.
motor_strength= torch.ones(12, dtype= torch.float32, device= self.model_device, requires_grad= False),
); self.extra_cfg.update(extra_cfg)
if "torque_limits" in self.cfg["control"]:
if isinstance(self.cfg["control"]["torque_limits"], (tuple, list)):
for i in range(len(self.cfg["control"]["torque_limits"])):
self.extra_cfg["torque_limits"][i] = self.cfg["control"]["torque_limits"][i]
else:
self.extra_cfg["torque_limits"][:] = self.cfg["control"]["torque_limits"]
self.command_buf = torch.zeros((self.num_envs, 3,), device= self.model_device, dtype= torch.float32) # zeros for initialization
self.actions = torch.zeros((1, 12), device= model_device, dtype= torch.float32)
self.process_configs()
def start_ros(self):
# initialize several buffers so that the system works even without message update.
# self.low_state_buffer = LowState() # not initialized, let input message update it.
self.base_position_buffer = torch.zeros((self.num_envs, 3), device= self.model_device, requires_grad= False)
self.legs_cmd_publisher = rospy.Publisher(
self.robot_namespace + self.legs_cmd_topic,
LegsCmd,
queue_size= 1,
)
# self.debug_publisher = rospy.Publisher(
# "/DNNmodel_debug",
# Float32MultiArray,
# queue_size= 1,
# )
# NOTE: this launches the subscriber callback function
self.low_state_subscriber = rospy.Subscriber(
self.robot_namespace + self.low_state_topic,
LowState,
self.update_low_state,
queue_size= 1,
)
self.odom_subscriber = rospy.Subscriber(
self.robot_namespace + self.odom_topic,
Odometry,
self.update_base_pose,
queue_size= 1,
)
if not self.move_by_wireless_remote:
self.move_cmd_subscriber = rospy.Subscriber(
"/cmd_vel",
Twist,
self.update_move_cmd,
queue_size= 1,
)
if "forward_depth" in self.all_obs_components:
if not self.forward_depth_embedding_dims:
self.forward_depth_subscriber = rospy.Subscriber(
self.robot_namespace + self.forward_depth_topic,
Image,
self.update_forward_depth,
queue_size= 1,
)
else:
self.forward_depth_subscriber = rospy.Subscriber(
self.robot_namespace + self.forward_depth_topic,
Float32MultiArrayStamped,
self.update_forward_depth_embedding,
queue_size= 1,
)
self.pose_cmd_subscriber = rospy.Subscriber(
"/body_pose",
Pose,
self.dummy_handler,
queue_size= 1,
)
    def wait_untill_ros_working(self):
        rate = rospy.Rate(100)
        while not hasattr(self, "low_state_buffer"):
            rate.sleep()
        rospy.loginfo("UnitreeA1Real.low_state_buffer acquired, stop waiting.")
def process_configs(self):
self.up_axis_idx = 2 # 2 for z, 1 for y -> adapt gravity accordingly
self.gravity_vec = torch.zeros((self.num_envs, 3), dtype= torch.float32)
self.gravity_vec[:, self.up_axis_idx] = -1
self.obs_scales = self.cfg["normalization"]["obs_scales"]
self.obs_scales["dof_pos"] = torch.tensor(self.obs_scales["dof_pos"], device= self.model_device, dtype= torch.float32)
if not isinstance(self.cfg["control"]["damping"]["joint"], (list, tuple)):
self.cfg["control"]["damping"]["joint"] = [self.cfg["control"]["damping"]["joint"]] * 12
if not isinstance(self.cfg["control"]["stiffness"]["joint"], (list, tuple)):
self.cfg["control"]["stiffness"]["joint"] = [self.cfg["control"]["stiffness"]["joint"]] * 12
self.d_gains = torch.tensor(self.cfg["control"]["damping"]["joint"], device= self.model_device, dtype= torch.float32)
self.p_gains = torch.tensor(self.cfg["control"]["stiffness"]["joint"], device= self.model_device, dtype= torch.float32)
self.default_dof_pos = torch.zeros(12, device= self.model_device, dtype= torch.float32)
for i in range(12):
name = self.extra_cfg["dof_names"][i]
default_joint_angle = self.cfg["init_state"]["default_joint_angles"][name]
self.default_dof_pos[i] = default_joint_angle
self.computer_clip_torque = self.cfg["control"].get("computer_clip_torque", True)
rospy.loginfo("Computer Clip Torque (onboard) is " + str(self.computer_clip_torque))
if self.computer_clip_torque:
self.torque_limits = self.extra_cfg["torque_limits"]
rospy.loginfo("[Env] torque limit: {:.1f} {:.1f} {:.1f}".format(*self.torque_limits[:3]))
self.commands_scale = torch.tensor([
self.obs_scales["lin_vel"],
self.obs_scales["lin_vel"],
self.obs_scales["lin_vel"],
], device= self.model_device, requires_grad= False)
self.obs_segments = self.get_obs_segment_from_components(self.cfg["env"]["obs_components"])
self.num_obs = self.get_num_obs_from_components(self.cfg["env"]["obs_components"])
components = self.cfg["env"].get("privileged_obs_components", None)
self.privileged_obs_segments = None if components is None else self.get_obs_segment_from_components(components)
self.num_privileged_obs = None if components is None else self.get_num_obs_from_components(components)
self.all_obs_components = self.cfg["env"]["obs_components"] + (self.cfg["env"].get("privileged_obs_components", []) if components is not None else [])
# store config values to attributes to improve speed
self.clip_obs = self.cfg["normalization"]["clip_observations"]
self.control_type = self.cfg["control"]["control_type"]
self.action_scale = self.cfg["control"]["action_scale"]
rospy.loginfo("[Env] action scale: {:.1f}".format(self.action_scale))
self.clip_actions = self.cfg["normalization"]["clip_actions"]
if self.cfg["normalization"].get("clip_actions_method", None) == "hard":
rospy.loginfo("clip_actions_method with hard mode")
rospy.loginfo("clip_actions_high: " + str(self.cfg["normalization"]["clip_actions_high"]))
rospy.loginfo("clip_actions_low: " + str(self.cfg["normalization"]["clip_actions_low"]))
self.clip_actions_method = "hard"
self.clip_actions_low = torch.tensor(self.cfg["normalization"]["clip_actions_low"], device= self.model_device, dtype= torch.float32)
self.clip_actions_high = torch.tensor(self.cfg["normalization"]["clip_actions_high"], device= self.model_device, dtype= torch.float32)
else:
rospy.loginfo("clip_actions_method is " + str(self.cfg["normalization"].get("clip_actions_method", None)))
self.dof_map = self.extra_cfg["dof_map"]
# get ROS params for hardware configs
self.joint_limits_high = torch.tensor([
rospy.get_param(self.robot_namespace + "/joint_limits/{}_max".format(s)) \
for s in ["hip", "thigh", "calf"] * 4
])
self.joint_limits_low = torch.tensor([
rospy.get_param(self.robot_namespace + "/joint_limits/{}_min".format(s)) \
for s in ["hip", "thigh", "calf"] * 4
])
if "forward_depth" in self.all_obs_components:
resolution = self.cfg["sensor"]["forward_camera"].get(
"output_resolution",
self.cfg["sensor"]["forward_camera"]["resolution"],
)
if not self.forward_depth_embedding_dims:
self.forward_depth_buf = torch.zeros(
(self.num_envs, *resolution),
device= self.model_device,
dtype= torch.float32,
)
else:
self.forward_depth_embedding_buf = torch.zeros(
(1, self.forward_depth_embedding_dims),
device= self.model_device,
dtype= torch.float32,
)
    def _init_height_points(self):
        """ Returns points at which the height measurements are sampled (in base frame)
        Returns:
            [torch.Tensor]: Tensor of shape (num_envs, self.num_height_points, 3)
        """
        return None
    def _get_heights(self):
        """ TODO: get estimated terrain heights around the robot base """
        # currently return a zero tensor with valid size
        return torch.zeros(self.num_envs, 187, device= self.model_device, requires_grad= False)
    def clip_action_before_scale(self, actions):
        actions = torch.clip(actions, -self.clip_actions, self.clip_actions)
        if getattr(self, "clip_actions_method", None) == "hard":
            actions = torch.clip(actions, self.clip_actions_low, self.clip_actions_high)
        return actions
    def clip_by_torque_limit(self, actions_scaled):
        """ Different from simulation, we reverse the process and clip the actions directly,
        so that the PD controller runs on the robot rather than in our script.
        """
        control_type = self.cfg["control"]["control_type"]
        if control_type == "P":
            # bounds on the position error term that keep the PD torque within the limits
            p_limits_low = (-self.torque_limits) + self.d_gains*self.dof_vel
            p_limits_high = (self.torque_limits) + self.d_gains*self.dof_vel
            actions_low = (p_limits_low/self.p_gains) - self.default_dof_pos + self.dof_pos
            actions_high = (p_limits_high/self.p_gains) - self.default_dof_pos + self.dof_pos
        else:
            raise NotImplementedError
        return torch.clip(actions_scaled, actions_low, actions_high)
""" Get obs components and cat to a single obs input """
def _get_proprioception_obs(self):
# base_ang_vel = quat_rotate_inverse(
# torch.tensor(self.low_state_buffer.imu.quaternion).unsqueeze(0),
# torch.tensor(self.low_state_buffer.imu.gyroscope).unsqueeze(0),
# ).to(self.model_device)
# NOTE: Different from the isaacgym.
# The angular velocity is already in base frame, no need to rotate
base_ang_vel = torch.tensor(self.low_state_buffer.imu.gyroscope, device= self.model_device).unsqueeze(0)
projected_gravity = quat_rotate_inverse(
torch.tensor(self.low_state_buffer.imu.quaternion).unsqueeze(0),
self.gravity_vec,
).to(self.model_device)
self.dof_pos = dof_pos = torch.tensor([
self.low_state_buffer.motorState[self.dof_map[i]].q for i in range(12)
], dtype= torch.float32, device= self.model_device).unsqueeze(0)
self.dof_vel = dof_vel = torch.tensor([
self.low_state_buffer.motorState[self.dof_map[i]].dq for i in range(12)
], dtype= torch.float32, device= self.model_device).unsqueeze(0)
# rospy.loginfo_throttle(5,
# "projected_gravity: " + \
# ", ".join([str(i) for i in projected_gravity[0].cpu().tolist()])
# )
# rospy.loginfo_throttle(5, "Hacking projected_gravity y -= 0.15")
# projected_gravity[:, 1] -= 0.15
return torch.cat([
torch.zeros((1, 3), device= self.model_device), # no linear velocity
base_ang_vel * self.obs_scales["ang_vel"],
projected_gravity,
self.command_buf * self.commands_scale,
(dof_pos - self.default_dof_pos) * self.obs_scales["dof_pos"],
dof_vel * self.obs_scales["dof_vel"],
self.actions
], dim= -1)
    def _get_forward_depth_obs(self):
        if not self.forward_depth_embedding_dims:
            return self.forward_depth_buf.flatten(start_dim= 1)
        else:
            if self.low_state_get_time.to_sec() - self.forward_depth_embedding_stamp.to_sec() > 0.4:
                rospy.logerr("Depth embedding is more than 0.4s older than the latest low_state")
            return self.forward_depth_embedding_buf.flatten(start_dim= 1)
def compute_observation(self):
""" use the updated low_state_buffer to compute observation vector """
assert hasattr(self, "legs_cmd_publisher"), "start_ros() not called, ROS handlers are not initialized!"
obs_segments = self.obs_segments
obs = []
for k, v in obs_segments.items():
obs.append(
getattr(self, "_get_" + k + "_obs")() * \
self.obs_scales.get(k, 1.)
)
obs = torch.cat(obs, dim= 1)
self.obs_buf = obs
""" The methods combined with outer model forms the step function
NOTE: the outer user handles the loop frequency.
"""
    def send_action(self, actions):
        """ The function that sends commands to the real robot.
        """
        self.actions = self.clip_action_before_scale(actions)
        if self.computer_clip_torque:
            # start from the clipped actions, as the branch without torque clipping does
            robot_coordinates_action = self.clip_by_torque_limit(self.actions * self.action_scale) + self.default_dof_pos.unsqueeze(0)
        else:
            rospy.logwarn_throttle(60, "You are using control without any torque clip. The network might output torques larger than the system can provide.")
            robot_coordinates_action = self.actions * self.action_scale + self.default_dof_pos.unsqueeze(0)
        # debugging and logging
        # transfered_action = torch.zeros_like(self.actions[0])
        # for i in range(12):
        #     transfered_action[self.dof_map[i]] = self.actions[0, i] + self.default_dof_pos[i]
        # self.debug_publisher.publish(Float32MultiArray(data=
        #     transfered_action\
        #         .cpu().numpy().astype(np.float32).tolist()
        # ))
        # restrict the target action delta in order to avoid robot shutdown (maybe there is another solution)
        # robot_coordinates_action = torch.clip(
        #     robot_coordinates_action,
        #     self.dof_pos - 0.3,
        #     self.dof_pos + 0.3,
        # )
        # wrap the message and publish
        self.publish_legs_cmd(robot_coordinates_action)
    def publish_legs_cmd(self, robot_coordinates_action, kp= None, kd= None):
        """ Publish the joint positions directly to the robot. NOTE: The joint order of the input should
        be the simulation order. The values should be absolute joint positions rather than relative to dof_pos.
        """
        robot_coordinates_action = torch.clip(
            robot_coordinates_action.cpu(),
            self.joint_limits_low,
            self.joint_limits_high,
        )
        legs_cmd = LegsCmd()
        for sim_joint_idx in range(12):
            real_joint_idx = self.dof_map[sim_joint_idx]
            legs_cmd.cmd[real_joint_idx].mode = 10
            legs_cmd.cmd[real_joint_idx].q = robot_coordinates_action[0, sim_joint_idx] if self.control_type == "P" else rospy.get_param(self.robot_namespace + "/PosStopF", (2.146e+9))
            legs_cmd.cmd[real_joint_idx].dq = 0.
            legs_cmd.cmd[real_joint_idx].tau = 0.
            legs_cmd.cmd[real_joint_idx].Kp = self.p_gains[sim_joint_idx] if kp is None else kp
            legs_cmd.cmd[real_joint_idx].Kd = self.d_gains[sim_joint_idx] if kd is None else kd
        self.legs_cmd_publisher.publish(legs_cmd)
    def get_obs(self):
        """ The function that refreshes the buffer and returns the observation vector.
        """
        self.compute_observation()
        self.obs_buf = torch.clip(self.obs_buf, -self.clip_obs, self.clip_obs)
        return self.obs_buf.to(self.model_device)
""" Copied from legged_robot_field. Please check whether these are consistent. """
def get_obs_segment_from_components(self, components):
segments = OrderedDict()
if "proprioception" in components:
segments["proprioception"] = (48,)
if "height_measurements" in components:
segments["height_measurements"] = (187,)
if "forward_depth" in components:
resolution = self.cfg["sensor"]["forward_camera"].get(
"output_resolution",
self.cfg["sensor"]["forward_camera"]["resolution"],
)
segments["forward_depth"] = (1, *resolution)
# The following components are only for rebuilding the non-actor module.
# DO NOT use these in actor network and check consistency with simulator implementation.
if "base_pose" in components:
segments["base_pose"] = (6,) # xyz + rpy
if "robot_config" in components:
segments["robot_config"] = (1 + 3 + 1 + 12,)
if "engaging_block" in components:
# This could be wrong, please check the implementation of BarrierTrack
segments["engaging_block"] = (1 + (4 + 1) + 2,)
if "sidewall_distance" in components:
segments["sidewall_distance"] = (2,)
return segments
    def get_num_obs_from_components(self, components):
        obs_segments = self.get_obs_segment_from_components(components)
        num_obs = 0
        for k, v in obs_segments.items():
            num_obs += np.prod(v)
        return num_obs
""" ROS callbacks and handlers that update the buffer """
def update_low_state(self, ros_msg):
self.low_state_buffer = ros_msg
if self.move_by_wireless_remote:
self.command_buf[0, 0] = self.low_state_buffer.wirelessRemote.ly
self.command_buf[0, 1] = -self.low_state_buffer.wirelessRemote.lx # right-moving stick is positive
self.command_buf[0, 2] = -self.low_state_buffer.wirelessRemote.rx # right-moving stick is positive
# set the command to zero if it is too small
if np.linalg.norm(self.command_buf[0, :2]) < self.lin_vel_deadband:
self.command_buf[0, :2] = 0.
if np.abs(self.command_buf[0, 2]) < self.ang_vel_deadband:
self.command_buf[0, 2] = 0.
self.low_state_get_time = rospy.Time.now()
def update_base_pose(self, ros_msg):
""" update robot odometry for position """
self.base_position_buffer[0, 0] = ros_msg.pose.pose.position.x
self.base_position_buffer[0, 1] = ros_msg.pose.pose.position.y
self.base_position_buffer[0, 2] = ros_msg.pose.pose.position.z
def update_move_cmd(self, ros_msg):
self.command_buf[0, 0] = ros_msg.linear.x
self.command_buf[0, 1] = ros_msg.linear.y
self.command_buf[0, 2] = ros_msg.angular.z
def update_forward_depth(self, ros_msg):
# TODO not checked.
self.forward_depth_header = ros_msg.header
buf = ros_numpy.numpify(ros_msg)
self.forward_depth_buf = resize2d(
torch.from_numpy(buf.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(self.model_device),
self.forward_depth_buf.shape[-2:],
)
    def update_forward_depth_embedding(self, ros_msg):
        rospy.loginfo_once("a1_ros_run received forward depth embedding.")
        self.forward_depth_embedding_stamp = ros_msg.header.stamp
        self.forward_depth_embedding_buf[:] = torch.tensor(ros_msg.data).unsqueeze(0) # (1, d)
def dummy_handler(self, ros_msg):
""" To meet the need of teleop-legged-robots requirements """
pass
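The bounds used in `clip_by_torque_limit` above follow from inverting a position-mode PD law. Assuming the motor torque is `p_gain * (action_scaled + default_dof_pos - dof_pos) - d_gain * dof_vel`, requiring that torque to stay within `±torque_limit` and solving for `action_scaled` gives the two limits. The sketch below works this out for a single joint with plain floats; the function names are hypothetical and the PD law is an assumption inferred from the clipping formula, not a statement about the robot firmware.

```python
def pd_torque(action_scaled, dof_pos, dof_vel, default_dof_pos, p_gain, d_gain):
    """Torque produced by the assumed position-mode PD law for one joint."""
    target = action_scaled + default_dof_pos
    return p_gain * (target - dof_pos) - d_gain * dof_vel


def clip_action_by_torque_limit(action_scaled, dof_pos, dof_vel,
                                default_dof_pos, p_gain, d_gain, torque_limit):
    """Clip one joint's scaled action so the assumed PD torque stays within limits.

    Solving -torque_limit <= p_gain * (a + default - q) - d_gain * dq <= torque_limit
    for the action a gives the bounds below.
    """
    low = (-torque_limit + d_gain * dof_vel) / p_gain - default_dof_pos + dof_pos
    high = (torque_limit + d_gain * dof_vel) / p_gain - default_dof_pos + dof_pos
    return min(max(action_scaled, low), high)


if __name__ == "__main__":
    # Illustrative numbers only, not the gains used on the robot.
    clipped = clip_action_by_torque_limit(
        action_scaled=2.0, dof_pos=0.0, dof_vel=0.0,
        default_dof_pos=0.0, p_gain=40.0, d_gain=1.0, torque_limit=20.0,
    )
    print("clipped action:", clipped)
    print("resulting torque:", pd_torque(clipped, 0.0, 0.0, 0.0, 40.0, 1.0))
```

Clipping the action rather than the torque lets the PD loop keep running on the robot while the script still bounds the torque it can request.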


@ -0,0 +1,402 @@
#!/home/unitree/agility_ziwenz_venv/bin/python
import os
import os.path as osp
import json
import numpy as np
import torch
from collections import OrderedDict
from functools import partial
from typing import Tuple
import rospy
from std_msgs.msg import Float32MultiArray
from sensor_msgs.msg import Image
import ros_numpy
from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules
from rsl_rl.utils.utils import get_obs_slice
@torch.no_grad()
def handle_forward_depth(ros_msg, model, publisher, output_resolution, device):
""" The callback function to handle the forward depth and send the embedding through ROS topic """
buf = ros_numpy.numpify(ros_msg).astype(np.float32)
forward_depth_buf = resize2d(
torch.from_numpy(buf).unsqueeze(0).unsqueeze(0).to(device),
output_resolution,
)
embedding = model(forward_depth_buf)
ros_data = embedding.reshape(-1).cpu().numpy().astype(np.float32)
publisher.publish(Float32MultiArray(data= ros_data.tolist()))
class StandOnlyModel(torch.nn.Module):
def __init__(self, action_scale, dof_pos_scale, tolerance= 0.2, delta= 0.1):
rospy.loginfo("Using stand only model, please make sure the proprioception is 48 dim.")
rospy.loginfo("Using stand only model, -36 to -24 must be joint position.")
super().__init__()
if isinstance(action_scale, (tuple, list)):
self.register_buffer("action_scale", torch.tensor(action_scale))
else:
self.action_scale = action_scale
if isinstance(dof_pos_scale, (tuple, list)):
self.register_buffer("dof_pos_scale", torch.tensor(dof_pos_scale))
else:
self.dof_pos_scale = dof_pos_scale
self.tolerance = tolerance
self.delta = delta
def forward(self, obs):
joint_positions = obs[..., -36:-24] / self.dof_pos_scale
diff_large_mask = torch.abs(joint_positions) > self.tolerance
target_positions = torch.zeros_like(joint_positions)
target_positions[diff_large_mask] = joint_positions[diff_large_mask] - self.delta * torch.sign(joint_positions[diff_large_mask])
return torch.clip(
target_positions / self.action_scale,
-1.0, 1.0,
)
def reset(self, *args, **kwargs):
pass
def load_walk_policy(env, model_dir):
""" Load the walk policy from the model directory """
if model_dir is None:
model = StandOnlyModel(
action_scale= env.action_scale,
dof_pos_scale= env.obs_scales["dof_pos"],
)
policy = torch.jit.script(model)
else:
with open(osp.join(model_dir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
obs_components = config_dict["env"]["obs_components"]
privileged_obs_components = config_dict["env"].get("privileged_obs_components", obs_components)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs= env.get_num_obs_from_components(obs_components),
num_critic_obs= env.get_num_obs_from_components(privileged_obs_components),
num_actions= 12,
**config_dict["policy"],
)
model_names = [i for i in os.listdir(model_dir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(model_dir, model_names[-1]), map_location= "cpu")
model.load_state_dict(state_dict["model_state_dict"])
model_action_scale = torch.tensor(config_dict["control"]["action_scale"]) if isinstance(config_dict["control"]["action_scale"], (tuple, list)) else torch.tensor([config_dict["control"]["action_scale"]])[0]
if not (torch.is_tensor(model_action_scale) and (model_action_scale == env.action_scale).all()):
action_rescale_ratio = model_action_scale / env.action_scale
print("walk_policy action scaling:", action_rescale_ratio.tolist())
else:
action_rescale_ratio = 1.0
memory_module = model.memory_a
actor_mlp = model.actor
@torch.jit.script
def policy_run(obs):
recurrent_embedding = memory_module(obs)
actions = actor_mlp(recurrent_embedding.squeeze(0))
return actions
if (torch.is_tensor(action_rescale_ratio) and (action_rescale_ratio == 1.).all()) \
or (not torch.is_tensor(action_rescale_ratio) and action_rescale_ratio == 1.):
policy = policy_run
else:
policy = lambda x: policy_run(x) * action_rescale_ratio
return policy, model
def standup_procedure(env, ros_rate, angle_tolerance= 0.1,
kp= None,
kd= None,
warmup_timesteps= 25,
device= "cpu",
):
"""
Args:
warmup_timesteps: the number of timesteps to linearly increase the target position
"""
rospy.loginfo("Robot standing up, please wait ...")
target_pos = torch.zeros((1, 12), device= device, dtype= torch.float32)
standup_timestep_i = 0
while not rospy.is_shutdown():
dof_pos = [env.low_state_buffer.motorState[env.dof_map[i]].q for i in range(12)]
diff = [env.default_dof_pos[i].item() - dof_pos[i] for i in range(12)]
direction = [1 if i > 0 else -1 for i in diff]
if standup_timestep_i < warmup_timesteps:
direction = [standup_timestep_i / warmup_timesteps * i for i in direction]
if all([abs(i) < angle_tolerance for i in diff]):
break
print("max joint error (rad):", max([abs(i) for i in diff]), end= "\r")
for i in range(12):
target_pos[0, i] = dof_pos[i] + direction[i] * angle_tolerance if abs(diff[i]) > angle_tolerance else env.default_dof_pos[i]
env.publish_legs_cmd(target_pos,
kp= kp,
kd= kd,
)
ros_rate.sleep()
standup_timestep_i += 1
rospy.loginfo("Robot stood up! press R1 on the remote control to continue ...")
while not rospy.is_shutdown():
if env.low_state_buffer.wirelessRemote.btn.components.R1:
break
if env.low_state_buffer.wirelessRemote.btn.components.L2 or env.low_state_buffer.wirelessRemote.btn.components.R2:
env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= 0, kd= 0.5)
rospy.signal_shutdown("Controller sent stop signal, exiting")
exit(0)
env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= kp, kd= kd)
ros_rate.sleep()
rospy.loginfo("Robot standing up procedure finished!")
class SkilledA1Real(UnitreeA1Real):
""" Some additional methods to help the execution of skill policy """
def __init__(self, *args,
skill_mode_threhold= 0.1,
skill_vel_range= [0.0, 1.0],
**kwargs,
):
self.skill_mode_threhold = skill_mode_threhold
self.skill_vel_range = skill_vel_range
super().__init__(*args, **kwargs)
def is_skill_mode(self):
if self.move_by_wireless_remote:
return self.low_state_buffer.wirelessRemote.ry > self.skill_mode_threhold
else:
# Not implemented yet
return False
def update_low_state(self, ros_msg):
self.low_state_buffer = ros_msg
if self.move_by_wireless_remote and ros_msg.wirelessRemote.ry > self.skill_mode_threhold:
skill_vel = (self.low_state_buffer.wirelessRemote.ry - self.skill_mode_threhold) / (1.0 - self.skill_mode_threhold)
skill_vel *= self.skill_vel_range[1] - self.skill_vel_range[0]
skill_vel += self.skill_vel_range[0]
self.command_buf[0, 0] = skill_vel
self.command_buf[0, 1] = 0.
self.command_buf[0, 2] = 0.
return
return super().update_low_state(ros_msg)
def main(args):
log_level = rospy.DEBUG if args.debug else rospy.INFO
rospy.init_node("a1_legged_gym_" + args.mode, log_level= log_level)
# NOTE: the logdir/walkdir fallback below is not finished yet, so it stays commented out.
# if args.logdir is not None:
# rospy.loginfo("Use logdir/config.json to initialize env proxy.")
# with open(osp.join(args.logdir, "config.json"), "r") as f:
# config_dict = json.load(f, object_pairs_hook= OrderedDict)
# else:
# assert args.walkdir is not None, "You must provide at least a --logdir or --walkdir"
# rospy.logwarn("You did not provide logdir, use walkdir/config.json for initializing env proxy.")
# with open(osp.join(args.walkdir, "config.json"), "r") as f:
# config_dict = json.load(f, object_pairs_hook= OrderedDict)
assert args.logdir is not None
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
duration = config_dict["sim"]["dt"] * config_dict["control"]["decimation"] # in sec
# config_dict["control"]["stiffness"]["joint"] -= 2.5 # kp
model_device = torch.device("cpu") if args.mode == "upboard" else torch.device("cuda")
unitree_real_env = SkilledA1Real(
robot_namespace= args.namespace,
cfg= config_dict,
forward_depth_topic= "/visual_embedding" if args.mode == "upboard" else "/camera/depth/image_rect_raw",
forward_depth_embedding_dims= config_dict["policy"]["visual_latent_size"] if args.mode == "upboard" else None,
move_by_wireless_remote= True,
skill_vel_range= config_dict["commands"]["ranges"]["lin_vel_x"],
model_device= model_device,
# extra_cfg= dict(
# motor_strength= torch.tensor([
# 1., 1./0.9, 1./0.9,
# 1., 1./0.9, 1./0.9,
# 1., 1., 1.,
# 1., 1., 1.,
# ], dtype= torch.float32, device= model_device, requires_grad= False),
# ),
)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs= unitree_real_env.num_obs,
num_critic_obs= unitree_real_env.num_privileged_obs,
num_actions= 12,
obs_segments= unitree_real_env.obs_segments,
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
**config_dict["policy"],
)
config_dict["terrain"]["measure_heights"] = False
# load the model with the latest checkpoint
model_names = [i for i in os.listdir(args.logdir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
model.load_state_dict(state_dict["model_state_dict"])
model.to(model_device)
model.eval()
rospy.loginfo("duration: {}, motor Kp: {}, motor Kd: {}".format(
duration,
config_dict["control"]["stiffness"]["joint"],
config_dict["control"]["damping"]["joint"],
))
# rospy.loginfo("[Env] motor strength: {}".format(unitree_real_env.motor_strength))
if args.mode == "jetson":
embedding_publisher = rospy.Publisher(
args.namespace + "/visual_embedding",
Float32MultiArray,
queue_size= 1,
)
# extract the visual encoder and compile it with TorchScript
visual_encoder = model.visual_encoder
visual_encoder = torch.jit.script(visual_encoder)
forward_depth_subscriber = rospy.Subscriber(
args.namespace + "/camera/depth/image_rect_raw",
Image,
partial(handle_forward_depth,
model= visual_encoder,
publisher= embedding_publisher,
output_resolution= config_dict["sensor"]["forward_camera"].get(
"output_resolution",
config_dict["sensor"]["forward_camera"]["resolution"],
),
device= model_device,
),
queue_size= 1,
)
rospy.spin()
elif args.mode == "upboard":
# extract and build the torch ScriptFunction
memory_module = model.memory_a
actor_mlp = model.actor
@torch.jit.script
def policy(obs):
recurrent_embedding = memory_module(obs)
actions = actor_mlp(recurrent_embedding.squeeze(0))
return actions
walk_policy, walk_model = load_walk_policy(unitree_real_env, args.walkdir)
using_walk_policy = True # switch between skill policy and walk policy
unitree_real_env.start_ros()
unitree_real_env.wait_untill_ros_working()
rate = rospy.Rate(1 / duration)
with torch.no_grad():
if not args.debug:
standup_procedure(unitree_real_env, rate,
angle_tolerance= 0.2,
kp= 40,
kd= 0.5,
warmup_timesteps= 50,
device= model_device,
)
while not rospy.is_shutdown():
# inference_start_time = rospy.get_time()
# check remote controller and decide which policy to use
if unitree_real_env.is_skill_mode():
if using_walk_policy:
rospy.loginfo_throttle(0.1, "switch to skill policy")
using_walk_policy = False
model.reset()
else:
if not using_walk_policy:
rospy.loginfo_throttle(0.1, "switch to walk policy")
using_walk_policy = True
walk_model.reset()
if not using_walk_policy:
obs = unitree_real_env.get_obs()
actions = policy(obs)
else:
walk_obs = unitree_real_env._get_proprioception_obs()
actions = walk_policy(walk_obs)
unitree_real_env.send_action(actions)
# unitree_real_env.send_action(torch.zeros((1, 12)))
# inference_duration = rospy.get_time() - inference_start_time
# rospy.loginfo("inference duration: {:.3f}".format(inference_duration))
# rospy.loginfo("visual_latency: %f", rospy.get_time() - unitree_real_env.forward_depth_embedding_stamp.to_sec())
# motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
# rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
rate.sleep()
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.down:
rospy.loginfo_throttle(0.1, "model reset")
model.reset()
walk_model.reset()
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 2, kd= 0.5)
rospy.signal_shutdown("Controller sent stop signal, exiting")
elif args.mode == "full":
# extract and build the torch ScriptFunction
visual_obs_slice = get_obs_slice(unitree_real_env.obs_segments, "forward_depth")
visual_encoder = model.visual_encoder
memory_module = model.memory_a
actor_mlp = model.actor
@torch.jit.script
def policy(observations: torch.Tensor, obs_start: int, obs_stop: int, obs_shape: Tuple[int, int, int]):
visual_latent = visual_encoder(
observations[..., obs_start:obs_stop].reshape(-1, *obs_shape)
).reshape(1, -1)
obs = torch.cat([
observations[..., :obs_start],
visual_latent,
observations[..., obs_stop:],
], dim= -1)
recurrent_embedding = memory_module(obs)
actions = actor_mlp(recurrent_embedding.squeeze(0))
return actions
unitree_real_env.start_ros()
unitree_real_env.wait_untill_ros_working()
rate = rospy.Rate(1 / duration)
with torch.no_grad():
while not rospy.is_shutdown():
# inference_start_time = rospy.get_time()
obs = unitree_real_env.get_obs()
actions = policy(obs,
obs_start= visual_obs_slice[0].start.item(),
obs_stop= visual_obs_slice[0].stop.item(),
obs_shape= visual_obs_slice[1],
)
unitree_real_env.send_action(actions)
# inference_duration = rospy.get_time() - inference_start_time
motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
rate.sleep()
if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
rospy.signal_shutdown("Controller sent stop signal, exiting")
else:
rospy.logfatal("Unknown mode, exiting")
if __name__ == "__main__":
""" Entry point for running the A1 policy in ROS.
It is written as a single main script, not as reusable or scalable code.
"""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--namespace",
type= str,
default= "/a112138",
)
parser.add_argument("--logdir",
type= str,
help= "The log directory of the trained model",
default= None,
)
parser.add_argument("--walkdir",
type= str,
help= "The log directory of the walking model, not for the skills.",
default= None,
)
parser.add_argument("--mode",
type= str,
help= "The mode to determine which computer to run on.",
choices= ["jetson", "upboard", "full"],
)
parser.add_argument("--debug",
action= "store_true",
)
args = parser.parse_args()
main(args)
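
The stick-to-velocity mapping in `SkilledA1Real.update_low_state` can be read in isolation as a small pure function. The sketch below is illustrative only: `skill_velocity` is a hypothetical helper name that does not exist in this commit, and its defaults simply mirror the constructor defaults shown above. It assumes the right-stick value `ry` is already normalized so that full deflection equals 1.0.

```python
def skill_velocity(ry, threshold=0.1, vel_range=(0.0, 1.0)):
    """Map the right-stick deflection `ry` to a forward velocity command.

    Returns None when the stick is at or below the threshold, which
    corresponds to staying in walk mode in the script above.
    """
    if ry <= threshold:
        return None
    # Rescale (threshold, 1.0] to (0.0, 1.0], then stretch to the velocity range.
    fraction = (ry - threshold) / (1.0 - threshold)
    return vel_range[0] + fraction * (vel_range[1] - vel_range[0])


walk_mode = skill_velocity(0.05)    # below the threshold: no skill command
full_speed = skill_velocity(1.0)    # full deflection: upper end of the range
```

Keeping the mapping separate from the ROS callback makes it easy to check the threshold behaviour without a robot attached.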

@ -0,0 +1,317 @@
import os
import os.path as osp
import numpy as np
import torch
import json
from functools import partial
from collections import OrderedDict
from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules
import rospy
from unitree_legged_msgs.msg import Float32MultiArrayStamped
from sensor_msgs.msg import Image
import ros_numpy
import pyrealsense2 as rs
def get_encoder_script(logdir):
with open(osp.join(logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
model_device = torch.device("cuda")
unitree_real_env = UnitreeA1Real(
robot_namespace= "a112138",
cfg= config_dict,
forward_depth_topic= "", # this env only computes parameters to build the model
forward_depth_embedding_dims= None,
model_device= model_device,
)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs= unitree_real_env.num_obs,
num_critic_obs= unitree_real_env.num_privileged_obs,
num_actions= 12,
obs_segments= unitree_real_env.obs_segments,
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
**config_dict["policy"],
)
model_names = [i for i in os.listdir(logdir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(logdir, model_names[-1]), map_location= "cpu") # use the function argument, not the global `args`
model.load_state_dict(state_dict["model_state_dict"])
model.to(model_device)
model.eval()
visual_encoder = model.visual_encoder
script = torch.jit.script(visual_encoder)
return script, model_device
def get_input_filter(args):
""" Build the depth input filter. It differs from the simulator's processing but tries to close the sim-to-real gap. """
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
image_resolution = config_dict["sensor"]["forward_camera"].get(
"output_resolution",
config_dict["sensor"]["forward_camera"]["resolution"],
)
depth_range = config_dict["sensor"]["forward_camera"].get(
"depth_range",
[0.0, 3.0],
)
depth_range = (depth_range[0] * 1000, depth_range[1] * 1000) # [m] -> [mm]
crop_top, crop_bottom, crop_left, crop_right = args.crop_top, args.crop_bottom, args.crop_left, args.crop_right
crop_far = args.crop_far * 1000
def input_filter(depth_image: torch.Tensor,
crop_top: int,
crop_bottom: int,
crop_left: int,
crop_right: int,
crop_far: float,
depth_min: int,
depth_max: int,
output_height: int,
output_width: int,
):
""" depth_image must have shape [1, 1, H, W] """
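# NOTE: the "-1" keeps the slice non-empty when a crop value is 0,
# but it also drops one extra row (bottom) and one extra column (right).
depth_image = depth_image[:, :,
crop_top: -crop_bottom-1,
crop_left: -crop_right-1,
]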
depth_image[depth_image > crop_far] = depth_max
depth_image = torch.clip(
depth_image,
depth_min,
depth_max,
) / (depth_max - depth_min)
depth_image = resize2d(depth_image, (output_height, output_width))
return depth_image
# input_filter = torch.jit.script(input_filter)
return partial(input_filter,
crop_top= crop_top,
crop_bottom= crop_bottom,
crop_left= crop_left,
crop_right= crop_right,
crop_far= crop_far,
depth_min= depth_range[0],
depth_max= depth_range[1],
output_height= image_resolution[0],
output_width= image_resolution[1],
), depth_range
def get_started_pipeline(
height= 480,
width= 640,
fps= 30,
enable_rgb= False,
):
# By default, rgb is not used.
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
if enable_rgb:
config.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
profile = pipeline.start(config)
# build the sensor filter
hole_filling_filter = rs.hole_filling_filter(2)
spatial_filter = rs.spatial_filter()
spatial_filter.set_option(rs.option.filter_magnitude, 5)
spatial_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
spatial_filter.set_option(rs.option.filter_smooth_delta, 1)
spatial_filter.set_option(rs.option.holes_fill, 4)
temporal_filter = rs.temporal_filter()
temporal_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
temporal_filter.set_option(rs.option.filter_smooth_delta, 1)
# decimation_filter = rs.decimation_filter()
# decimation_filter.set_option(rs.option.filter_magnitude, 2)
def filter_func(frame):
frame = hole_filling_filter.process(frame)
frame = spatial_filter.process(frame)
frame = temporal_filter.process(frame)
# frame = decimation_filter.process(frame)
return frame
return pipeline, filter_func
def main(args):
rospy.init_node("a1_legged_gym_jetson")
input_filter, depth_range = get_input_filter(args)
model_script, model_device = get_encoder_script(args.logdir)
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
if config_dict.get("sensor", dict()).get("forward_camera", dict()).get("refresh_duration", None) is not None:
refresh_duration = config_dict["sensor"]["forward_camera"]["refresh_duration"]
ros_rate = rospy.Rate(1.0 / refresh_duration)
rospy.loginfo("Using refresh duration {}s".format(refresh_duration))
else:
ros_rate = rospy.Rate(args.fps)
rs_pipeline, rs_filters = get_started_pipeline(
height= args.height,
width= args.width,
fps= args.fps,
enable_rgb= args.enable_rgb,
)
# gyro_pipeline = rs.pipeline()
# gyro_config = rs.config()
# gyro_config.enable_stream(rs.stream.gyro, rs.format.motion_xyz32f, 200)
# gyro_profile = gyro_pipeline.start(gyro_config)
embedding_publisher = rospy.Publisher(
args.namespace + "/visual_embedding",
Float32MultiArrayStamped,
queue_size= 1,
)
if args.enable_vis:
depth_image_publisher = rospy.Publisher(
args.namespace + "/camera/depth/image_rect_raw",
Image,
queue_size= 1,
)
network_input_publisher = rospy.Publisher(
args.namespace + "/camera/depth/network_input_raw",
Image,
queue_size= 1,
)
if args.enable_rgb:
rgb_image_publisher = rospy.Publisher(
args.namespace + "/camera/color/image_raw",
Image,
queue_size= 1,
)
rospy.loginfo("Depth range is clipped to [{}, {}] and normalized".format(depth_range[0], depth_range[1]))
rospy.loginfo("ROS, model, realsense have been initialized.")
if args.enable_vis:
rospy.loginfo("Visualization enabled, sending depth{} images".format(", rgb" if args.enable_rgb else ""))
try:
embedding_msg = Float32MultiArrayStamped()
embedding_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
frame_got = False
while not rospy.is_shutdown():
# Wait for the depth image
frames = rs_pipeline.wait_for_frames(int( \
config_dict["sensor"]["forward_camera"]["latency_range"][1] \
* 1000)) # ms
embedding_msg.header.stamp = rospy.Time.now()
depth_frame = frames.get_depth_frame()
if not depth_frame:
continue
if not frame_got:
frame_got = True
rospy.loginfo("Realsense frame received. Sending embeddings...")
if args.enable_rgb:
color_frame = frames.get_color_frame()
# Use this branch to log the time when image is acquired
if args.enable_vis and color_frame is not None:
color_frame = np.asanyarray(color_frame.get_data())
rgb_image_msg = ros_numpy.msgify(Image, color_frame, encoding= "rgb8")
rgb_image_msg.header.stamp = rospy.Time.now()
rgb_image_msg.header.frame_id = args.namespace + "/camera_color_optical_frame"
rgb_image_publisher.publish(rgb_image_msg)
# Process the depth image and publish
depth_frame = rs_filters(depth_frame)
depth_image_ = np.asanyarray(depth_frame.get_data())
depth_image = torch.from_numpy(depth_image_.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(model_device)
depth_image = input_filter(depth_image)
with torch.no_grad():
depth_embedding = model_script(depth_image).reshape(-1).cpu().numpy()
embedding_msg.header.seq += 1
embedding_msg.data = depth_embedding.tolist()
embedding_publisher.publish(embedding_msg)
# Publish the acquired image if needed
if args.enable_vis:
depth_image_msg = ros_numpy.msgify(Image, depth_image_, encoding= "16UC1")
depth_image_msg.header.stamp = rospy.Time.now()
depth_image_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
depth_image_publisher.publish(depth_image_msg)
network_input_np = (\
depth_image.detach().cpu().numpy()[0, 0] * (depth_range[1] - depth_range[0]) \
+ depth_range[0]
).astype(np.uint16)
network_input_msg = ros_numpy.msgify(Image, network_input_np, encoding= "16UC1")
network_input_msg.header.stamp = rospy.Time.now()
network_input_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
network_input_publisher.publish(network_input_msg)
ros_rate.sleep()
finally:
rs_pipeline.stop()
if __name__ == "__main__":
""" This script is designed to load the model and process the realsense image directly
from realsense SDK without realsense ROS wrapper
"""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--namespace",
type= str,
default= "/a112138",
)
parser.add_argument("--logdir",
type= str,
help= "The log directory of the trained model",
)
parser.add_argument("--height",
type= int,
default= 240,
help= "The height of the realsense image",
)
parser.add_argument("--width",
type= int,
default= 424,
help= "The width of the realsense image",
)
parser.add_argument("--fps",
type= int,
default= 30,
help= "The fps of the realsense image",
)
parser.add_argument("--crop_left",
type= int,
default= 60,
help= "Number of pixels to crop from the left edge of the raw pyrealsense depth image.",
)
parser.add_argument("--crop_right",
type= int,
default= 46,
help= "Number of pixels to crop from the right edge of the raw pyrealsense depth image.",
)
parser.add_argument("--crop_top",
type= int,
default= 0,
help= "Number of pixels to crop from the top edge of the raw pyrealsense depth image.",
)
parser.add_argument("--crop_bottom",
type= int,
default= 0,
help= "Number of pixels to crop from the bottom edge of the raw pyrealsense depth image.",
)
parser.add_argument("--crop_far",
type= float,
default= 3.0,
help= "Far limit in meters, applied in addition to the depth range in the config: readings beyond this value are set to the maximum depth (3.0 m by default) in the un-normalized network input.",
)
parser.add_argument("--enable_rgb",
action= "store_true",
help= "Whether to enable rgb image",
)
parser.add_argument("--enable_vis",
action= "store_true",
help= "Whether to publish realsense image",
)
args = parser.parse_args()
main(args)
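
The crop, far-limit, clip and normalize steps of `input_filter` can be sketched without torch. The version below is an assumption-laden illustration, not the script's code: `filter_depth` is a hypothetical name, the image is a plain list of rows in millimeters, the final `resize2d` step is left out, and the crop uses explicit end indices, so unlike the script it does not drop the extra row and column caused by the `-1` in the slice.

```python
def filter_depth(depth_mm, crop_top, crop_bottom, crop_left, crop_right,
                 crop_far_mm, depth_min_mm, depth_max_mm):
    """Crop, clamp and normalize a depth image given as a list of rows (mm)."""
    height, width = len(depth_mm), len(depth_mm[0])
    # Explicit end indices stay valid when a crop value is 0.
    rows = depth_mm[crop_top:height - crop_bottom]
    rows = [row[crop_left:width - crop_right] for row in rows]
    depth_span = depth_max_mm - depth_min_mm
    filtered = []
    for row in rows:
        out_row = []
        for d in row:
            if d > crop_far_mm:
                d = depth_max_mm          # beyond the far limit: treat as max range
            d = min(max(d, depth_min_mm), depth_max_mm)   # clip to the depth range
            out_row.append(d / depth_span)                # divide by the range, as in the script
        filtered.append(out_row)
    return filtered


image = [[0, 1500, 5000],
         [3000, 100, 2999]]
normalized = filter_depth(image, 0, 0, 0, 0,
                          crop_far_mm=3000.0, depth_min_mm=0.0, depth_max_mm=3000.0)
cropped = filter_depth(image, 0, 0, 1, 0,
                       crop_far_mm=3000.0, depth_min_mm=0.0, depth_max_mm=3000.0)
```

With the default range of 0 to 3 m, every reading ends up in [0, 1], and anything past the far limit saturates at 1.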

@ -0,0 +1,317 @@
import os
import os.path as osp
import numpy as np
import torch
import json
from functools import partial
from collections import OrderedDict
from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules
import rospy
from unitree_legged_msgs.msg import Float32MultiArrayStamped
from sensor_msgs.msg import Image
import ros_numpy
import pyrealsense2 as rs
def get_encoder_script(logdir):
with open(osp.join(logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
model_device = torch.device("cuda")
unitree_real_env = UnitreeA1Real(
robot_namespace= "a112138",
cfg= config_dict,
forward_depth_topic= "", # this env only computes parameters to build the model
forward_depth_embedding_dims= None,
model_device= model_device,
)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs= unitree_real_env.num_obs,
num_critic_obs= unitree_real_env.num_privileged_obs,
num_actions= 12,
obs_segments= unitree_real_env.obs_segments,
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
**config_dict["policy"],
)
model_names = [i for i in os.listdir(logdir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(logdir, model_names[-1]), map_location= "cpu") # use the function argument, not the global `args`
model.load_state_dict(state_dict["model_state_dict"])
model.to(model_device)
model.eval()
visual_encoder = model.visual_encoder
script = torch.jit.script(visual_encoder)
return script, model_device
def get_input_filter(args):
""" Build the depth input filter. It differs from the simulator's processing but tries to close the sim-to-real gap. """
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
image_resolution = config_dict["sensor"]["forward_camera"].get(
"output_resolution",
config_dict["sensor"]["forward_camera"]["resolution"],
)
depth_range = config_dict["sensor"]["forward_camera"].get(
"depth_range",
[0.0, 3.0],
)
depth_range = (depth_range[0] * 1000, depth_range[1] * 1000) # [m] -> [mm]
crop_top, crop_bottom, crop_left, crop_right = args.crop_top, args.crop_bottom, args.crop_left, args.crop_right
crop_far = args.crop_far * 1000
def input_filter(depth_image: torch.Tensor,
crop_top: int,
crop_bottom: int,
crop_left: int,
crop_right: int,
crop_far: float,
depth_min: int,
depth_max: int,
output_height: int,
output_width: int,
):
""" depth_image must have shape [1, 1, H, W] """
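# NOTE: the "-1" keeps the slice non-empty when a crop value is 0,
# but it also drops one extra row (bottom) and one extra column (right).
depth_image = depth_image[:, :,
crop_top: -crop_bottom-1,
crop_left: -crop_right-1,
]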
depth_image[depth_image > crop_far] = depth_max
depth_image = torch.clip(
depth_image,
depth_min,
depth_max,
) / (depth_max - depth_min)
depth_image = resize2d(depth_image, (output_height, output_width))
return depth_image
# input_filter = torch.jit.script(input_filter)
return partial(input_filter,
crop_top= crop_top,
crop_bottom= crop_bottom,
crop_left= crop_left,
crop_right= crop_right,
crop_far= crop_far,
depth_min= depth_range[0],
depth_max= depth_range[1],
output_height= image_resolution[0],
output_width= image_resolution[1],
), depth_range
def get_started_pipeline(
height= 480,
width= 640,
fps= 30,
enable_rgb= False,
):
# By default, rgb is not used.
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
if enable_rgb:
config.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
profile = pipeline.start(config)
# build the sensor filter
hole_filling_filter = rs.hole_filling_filter(2)
spatial_filter = rs.spatial_filter()
spatial_filter.set_option(rs.option.filter_magnitude, 5)
spatial_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
spatial_filter.set_option(rs.option.filter_smooth_delta, 1)
spatial_filter.set_option(rs.option.holes_fill, 4)
temporal_filter = rs.temporal_filter()
temporal_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
temporal_filter.set_option(rs.option.filter_smooth_delta, 1)
# decimation_filter = rs.decimation_filter()
# decimation_filter.set_option(rs.option.filter_magnitude, 2)
def filter_func(frame):
frame = hole_filling_filter.process(frame)
frame = spatial_filter.process(frame)
frame = temporal_filter.process(frame)
# frame = decimation_filter.process(frame)
return frame
return pipeline, filter_func
def main(args):
rospy.init_node("a1_legged_gym_jetson")
input_filter, depth_range = get_input_filter(args)
model_script, model_device = get_encoder_script(args.logdir)
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
if config_dict.get("sensor", dict()).get("forward_camera", dict()).get("refresh_duration", None) is not None:
refresh_duration = config_dict["sensor"]["forward_camera"]["refresh_duration"]
ros_rate = rospy.Rate(1.0 / refresh_duration)
rospy.loginfo("Using refresh duration {}s".format(refresh_duration))
else:
ros_rate = rospy.Rate(args.fps)
rs_pipeline, rs_filters = get_started_pipeline(
height= args.height,
width= args.width,
fps= args.fps,
enable_rgb= args.enable_rgb,
)
# gyro_pipeline = rs.pipeline()
# gyro_config = rs.config()
# gyro_config.enable_stream(rs.stream.gyro, rs.format.motion_xyz32f, 200)
# gyro_profile = gyro_pipeline.start(gyro_config)
embedding_publisher = rospy.Publisher(
args.namespace + "/visual_embedding",
Float32MultiArrayStamped,
queue_size= 1,
)
if args.enable_vis:
depth_image_publisher = rospy.Publisher(
args.namespace + "/camera/depth/image_rect_raw",
Image,
queue_size= 1,
)
network_input_publisher = rospy.Publisher(
args.namespace + "/camera/depth/network_input_raw",
Image,
queue_size= 1,
)
if args.enable_rgb:
rgb_image_publisher = rospy.Publisher(
args.namespace + "/camera/color/image_raw",
Image,
queue_size= 1,
)
rospy.loginfo("Depth range is clipped to [{}, {}] and normalized".format(depth_range[0], depth_range[1]))
rospy.loginfo("ROS, model, realsense have been initialized.")
if args.enable_vis:
rospy.loginfo("Visualization enabled, sending depth{} images".format(", rgb" if args.enable_rgb else ""))
try:
embedding_msg = Float32MultiArrayStamped()
embedding_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
frame_got = False
while not rospy.is_shutdown():
# Wait for the depth image
frames = rs_pipeline.wait_for_frames(int( \
config_dict["sensor"]["forward_camera"]["latency_range"][1] \
* 1000)) # ms
embedding_msg.header.stamp = rospy.Time.now()
depth_frame = frames.get_depth_frame()
if not depth_frame:
continue
if not frame_got:
frame_got = True
rospy.loginfo("Realsense frame received. Sending embeddings...")
if args.enable_rgb:
color_frame = frames.get_color_frame()
# Use this branch to log the time when image is acquired
if args.enable_vis and color_frame is not None:
color_frame = np.asanyarray(color_frame.get_data())
rgb_image_msg = ros_numpy.msgify(Image, color_frame, encoding= "rgb8")
rgb_image_msg.header.stamp = rospy.Time.now()
rgb_image_msg.header.frame_id = args.namespace + "/camera_color_optical_frame"
rgb_image_publisher.publish(rgb_image_msg)
# Process the depth image and publish
depth_frame = rs_filters(depth_frame)
depth_image_ = np.asanyarray(depth_frame.get_data())
depth_image = torch.from_numpy(depth_image_.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(model_device)
depth_image = input_filter(depth_image)
with torch.no_grad():
depth_embedding = model_script(depth_image).reshape(-1).cpu().numpy()
embedding_msg.header.seq += 1
embedding_msg.data = depth_embedding.tolist()
embedding_publisher.publish(embedding_msg)
# Publish the acquired image if needed
if args.enable_vis:
depth_image_msg = ros_numpy.msgify(Image, depth_image_, encoding= "16UC1")
depth_image_msg.header.stamp = rospy.Time.now()
depth_image_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
depth_image_publisher.publish(depth_image_msg)
network_input_np = (\
depth_image.detach().cpu().numpy()[0, 0] * (depth_range[1] - depth_range[0]) \
+ depth_range[0]
).astype(np.uint16)
network_input_msg = ros_numpy.msgify(Image, network_input_np, encoding= "16UC1")
network_input_msg.header.stamp = rospy.Time.now()
network_input_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
network_input_publisher.publish(network_input_msg)
ros_rate.sleep()
finally:
rs_pipeline.stop()
if __name__ == "__main__":
""" This script is designed to load the model and process the realsense image directly
from realsense SDK without realsense ROS wrapper
"""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--namespace",
type= str,
default= "/a112138",
)
parser.add_argument("--logdir",
type= str,
help= "The log directory of the trained model",
)
parser.add_argument("--height",
type= int,
default= 270,
help= "The height of the realsense image",
)
parser.add_argument("--width",
type= int,
default= 480,
help= "The width of the realsense image",
)
parser.add_argument("--fps",
type= int,
default= 30,
help= "The fps of the realsense image",
)
parser.add_argument("--crop_left",
type= int,
default= 60,
help= "Number of pixels to crop from the left edge of the raw pyrealsense depth image.",
)
parser.add_argument("--crop_right",
type= int,
default= 46,
help= "Number of pixels to crop from the right edge of the raw pyrealsense depth image.",
)
parser.add_argument("--crop_top",
type= int,
default= 0,
help= "Number of pixels to crop from the top edge of the raw pyrealsense depth image.",
)
parser.add_argument("--crop_bottom",
type= int,
default= 0,
help= "Number of pixels to crop from the bottom edge of the raw pyrealsense depth image.",
)
parser.add_argument("--crop_far",
type= float,
default= 3.0,
help= "Far limit in meters, applied in addition to the depth range in the config: readings beyond this value are set to the maximum depth (3.0 m by default) in the un-normalized network input.",
)
parser.add_argument("--enable_rgb",
action= "store_true",
help= "Whether to enable rgb image",
)
parser.add_argument("--enable_vis",
action= "store_true",
help= "Whether to publish realsense image",
)
args = parser.parse_args()
main(args)
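
All three scripts pick the newest checkpoint by sorting `model_*` file names on their numeric suffix. A minimal sketch of why the numeric key matters is shown below; `latest_checkpoint` is a hypothetical helper, and the file names are made up for illustration.

```python
def latest_checkpoint(names):
    """Return the `model_<step>.<ext>` name with the largest step number."""
    checkpoints = [n for n in names if n.startswith("model_")]
    if not checkpoints:
        raise FileNotFoundError("no model_* checkpoint found")
    # Same key as in the scripts: the integer between the last "_" and the first ".".
    return max(checkpoints, key=lambda n: int(n.split("_")[-1].split(".")[0]))


names = ["config.json", "model_900.pt", "model_10000.pt", "model_2500.pt"]
newest = latest_checkpoint(names)
# A plain string comparison would rank "model_900.pt" above "model_10000.pt".
lexicographic = max(n for n in names if n.startswith("model_"))
```

The explicit error for an empty directory is an addition of this sketch; the scripts above would instead fail with an `IndexError` on `model_names[-1]`.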

@ -164,6 +164,12 @@ class TPPO(PPO):
                     self.actor_critic.action_mean - minibatch.action_labels,
                     dim= -1
                 )
+            elif self.distill_target == "l1":
+                dist_loss = torch.norm(
+                    self.actor_critic.action_mean - minibatch.action_labels,
+                    dim= -1,
+                    p= 1,
+                )
             elif self.distill_target == "tanh":
                 # for tanh, similar to loss function for sigmoid, refer to https://stats.stackexchange.com/questions/12754/matching-loss-function-for-tanh-units-in-a-neural-net
                 dist_loss = F.binary_cross_entropy(

@ -15,6 +15,8 @@ class ActorCriticFieldMutex(ActorCriticMutex):
     def __init__(self,
             *args,
             cmd_vel_mapping = dict(),
+            reset_non_selected = "all",
+            action_smoothing_buffer_len = 1,
             **kwargs,
         ):
         """ NOTE: This implementation only supports subpolicy output to (-1., 1.) range.
@ -24,6 +26,9 @@ class ActorCriticFieldMutex(ActorCriticMutex):
         """
         super().__init__(*args, **kwargs)
         self.cmd_vel_mapping = cmd_vel_mapping
+        self.reset_non_selected = reset_non_selected
+        self.action_smoothing_buffer_len = action_smoothing_buffer_len
+        self.action_smoothing_buffer = None
         # load cmd_scale to assign the cmd_vel during overriding
         self.cmd_scales = []
@ -101,18 +106,33 @@ class ActorCriticFieldMutex(ActorCriticMutex):
     def act_inference(self, observations):
         # run entire batch for each sub policy in case the batch size and length problem.
         policy_selection = self.get_policy_selection(observations)
+        if self.action_smoothing_buffer is None:
+            self.action_smoothing_buffer = torch.zeros(
+                self.action_smoothing_buffer_len,
+                *policy_selection.shape,
+                device= policy_selection.device,
+                dtype= torch.float,
+            ) # (len, N, ..., selection)
+        self.action_smoothing_buffer = torch.cat([
+            self.action_smoothing_buffer[1:],
+            policy_selection.unsqueeze(0),
+        ], dim= 0) # put the new one at the end
         observations = self.recover_last_action(observations, policy_selection)
         if self.cmd_vel_mapping:
             observations = self.override_cmd_vel(observations, policy_selection)
         outputs = [p.act_inference(observations) for p in self.submodules]
-        output = torch.empty_like(outputs[0])
+        output = torch.zeros_like(outputs[0])
         for idx, out in enumerate(outputs):
-            output[policy_selection[..., idx]] = torch.clip(
-                out[policy_selection[..., idx]] * getattr(self, "subpolicy_action_scale_{:d}".format(idx)) / self.env_action_scale,
-                -1., 1.,
-            )
+            output += out * getattr(self, "subpolicy_action_scale_{:d}".format(idx)) / self.env_action_scale \
+                * self.action_smoothing_buffer[..., idx].mean(dim= 0).unsqueeze(-1)
             # choose one or none reset method
-            self.submodules[idx].reset(~policy_selection[..., idx])
+            if self.reset_non_selected == "all" or self.reset_non_selected == True:
+                self.submodules[idx].reset(self.action_smoothing_buffer[..., idx].sum(0) == 0)
+            elif self.reset_non_selected == "when_skill" and idx > 0:
+                self.submodules[idx].reset(torch.logical_and(
+                    ~policy_selection[..., idx],
+                    ~policy_selection[..., 0],
+                ))
             # self.submodules[idx].reset(torch.ones(observations.shape[0], dtype= bool, device= observations.device))
         return output
@@ -122,16 +142,18 @@ class ActorCriticFieldMutex(ActorCriticMutex):
         return super().reset(dones)
 class ActorCriticClimbMutex(ActorCriticFieldMutex):
-    """ A variant to handle climb-up and climb-down with separate policies
-    Climb-down policy will be the last submodule in the list
+    """ A variant to handle jump-up and jump-down with separate policies
+    Jump-down policy will be the last submodule in the list
     """
     JUMP_OBSTACLE_ID = 3 # starting from 0, referring to barrier_track.py:BarrierTrack.track_options_id_dict
     def __init__(self,
             *args,
             sub_policy_paths: list = None,
-            climb_down_policy_path: str = None,
+            jump_down_policy_path: str = None,
+            jump_down_vel: float = None, # can be tuple/list, use it to stop using jump up velocity command
             **kwargs,):
-        sub_policy_paths = sub_policy_paths + [climb_down_policy_path]
+        sub_policy_paths = sub_policy_paths + [jump_down_policy_path]
+        self.jump_down_vel = jump_down_vel
         super().__init__(
             *args,
             sub_policy_paths= sub_policy_paths,
@@ -140,23 +162,28 @@ class ActorCriticClimbMutex(ActorCriticFieldMutex):
     def resample_cmd_vel_current(self, dones=None):
         return_ = super().resample_cmd_vel_current(dones)
-        self.cmd_vel_current[len(self.submodules) - 1] = self.cmd_vel_current[self.JUMP_OBSTACLE_ID]
+        if self.jump_down_vel is None:
+            self.cmd_vel_current[len(self.submodules) - 1] = self.cmd_vel_current[self.JUMP_OBSTACLE_ID]
+        elif isinstance(self.jump_down_vel, (tuple, list)):
+            self.cmd_vel_current[len(self.submodules) - 1] = np.random.uniform(*self.jump_down_vel)
+        else:
+            self.cmd_vel_current[len(self.submodules) - 1] = self.jump_down_vel
         return return_
     def get_policy_selection(self, observations):
         obstacle_id_onehot = super().get_policy_selection(observations)
         obs_slice = get_obs_slice(self.obs_segments, "engaging_block")
         engaging_block_obs = observations[..., obs_slice[0]].reshape(*observations.shape[:-1], *obs_slice[1])
-        climb_up_mask = engaging_block_obs[..., -1] > 0 # climb-up or climb-down
+        jump_up_mask = engaging_block_obs[..., -1] > 0 # jump-up or jump-down
         obstacle_id_onehot = torch.cat([
             obstacle_id_onehot,
             torch.logical_and(
                 obstacle_id_onehot[..., self.JUMP_OBSTACLE_ID],
-                torch.logical_not(climb_up_mask),
+                torch.logical_not(jump_up_mask),
             ).unsqueeze(-1)
         ], dim= -1)
         obstacle_id_onehot[..., self.JUMP_OBSTACLE_ID] = torch.logical_and(
             obstacle_id_onehot[..., self.JUMP_OBSTACLE_ID],
-            climb_up_mask,
+            jump_up_mask,
         )
         return obstacle_id_onehot.to(torch.bool) # (N, ..., selection)
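
The smoothing introduced in `act_inference` can be illustrated in isolation. The following is a minimal sketch in plain Python (lists instead of `torch` tensors, a single environment), not the repository's implementation; `SelectionSmoother` and `blend` are hypothetical names introduced here. It assumes, as the diff suggests, that the buffer starts at zero, that the newest one-hot selection is appended while the oldest is dropped, and that the per-policy mean over the buffer is used as a blend weight.

```python
from collections import deque


class SelectionSmoother:
    """Hypothetical helper: keeps the last `buffer_len` one-hot selections
    and returns their per-policy mean as blend weights."""

    def __init__(self, buffer_len, num_policies):
        # Start from all-zero rows, mirroring the zero-initialised buffer.
        self.buffer = deque(
            ([0.0] * num_policies for _ in range(buffer_len)),
            maxlen=buffer_len,
        )

    def push(self, selection):
        # Appending to a full deque silently drops the oldest row.
        self.buffer.append([float(s) for s in selection])
        n = len(self.buffer)
        return [sum(row[i] for row in self.buffer) / n
                for i in range(len(selection))]


def blend(actions, weights):
    """Weighted sum of per-policy action vectors."""
    dim = len(actions[0])
    return [sum(w * a[j] for w, a in zip(weights, actions))
            for j in range(dim)]


smoother = SelectionSmoother(buffer_len=4, num_policies=2)
for _ in range(4):
    weights = smoother.push([True, False])   # policy 0 selected
weights = smoother.push([False, True])       # switch to policy 1
print(weights)                               # policy 1 fades in gradually
print(blend([[1.0, 1.0], [-1.0, 0.0]], weights))
```

Under these assumptions a switch between sub-policies is spread over `buffer_len` steps instead of happening in one step, and a sub-policy whose weight column sums to zero has not been selected anywhere in the window, which is the condition the `"all"` reset branch checks.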

@@ -76,6 +76,7 @@ class OnPolicyRunner:
         self.tot_timesteps = 0
         self.tot_time = 0
         self.current_learning_iteration = 0
+        self.log_interval = self.cfg.get("log_interval", 1)
         _, _ = self.env.reset()
@@ -130,7 +131,7 @@ class OnPolicyRunner:
             losses, stats = self.alg.update(self.current_learning_iteration)
             stop = time.time()
             learn_time = stop - start
-            if self.log_dir is not None:
+            if self.log_dir is not None and self.current_learning_iteration % self.log_interval == 0:
                 self.log(locals())
             if self.current_learning_iteration % self.save_interval == 0 and self.current_learning_iteration > start_iter:
                 self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
@@ -238,7 +239,7 @@ class OnPolicyRunner:
     def load(self, path, load_optimizer=True):
         loaded_dict = torch.load(path)
         self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
-        if load_optimizer:
+        if load_optimizer and "optimizer_state_dict" in loaded_dict:
             self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
         if "lr_scheduler_state_dict" in loaded_dict:
             if not hasattr(self.alg, "lr_scheduler"):
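
The two runner changes above (logging only every `log_interval` iterations, and skipping the optimizer state when a checkpoint does not contain one) can be sketched without `torch`. This is an illustrative sketch only: `should_log` and `restore` are hypothetical helpers, and the two callbacks stand in for `actor_critic.load_state_dict` and `optimizer.load_state_dict`.

```python
def should_log(iteration, log_dir, log_interval=1):
    """Log only when a log directory is set and the iteration is a
    multiple of log_interval (mirrors the new condition in the diff)."""
    return log_dir is not None and iteration % log_interval == 0


def restore(loaded_dict, load_model, load_optimizer_state, load_optimizer=True):
    """Load the model weights; load the optimizer state only if it was
    requested AND the checkpoint actually contains it.
    Returns True when the optimizer state was restored."""
    load_model(loaded_dict["model_state_dict"])
    if load_optimizer and "optimizer_state_dict" in loaded_dict:
        load_optimizer_state(loaded_dict["optimizer_state_dict"])
        return True
    return False


# With log_interval=50, only every 50th iteration is logged.
print([i for i in range(120) if should_log(i, "logs", log_interval=50)])

# A stripped checkpoint (weights only) no longer raises a KeyError.
calls = []
restored = restore({"model_state_dict": {"w": 1}}, calls.append, calls.append)
print(restored, calls)
```

The membership check is what would let a weights-only checkpoint, such as one exported for deployment, be loaded through the same `load` path.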

@@ -12,6 +12,7 @@ class TwoStageRunner(OnPolicyRunner):
         # load some configs and their default values
         self.pretrain_iterations = self.cfg.get("pretrain_iterations", 0)
+        self.log_interval = self.cfg.get("log_interval", 50)
         assert "pretrain_dataset" in self.cfg, "pretrain_dataset is not defined in the runner cfg object"
         self.rollout_dataset = RolloutDataset(
             **self.cfg["pretrain_dataset"],