parkour/onboard_script/a1_ros_run.py

402 lines
18 KiB
Python

#!/home/unitree/agility_ziwenz_venv/bin/python
import os
import os.path as osp
import json
import numpy as np
import torch
from collections import OrderedDict
from functools import partial
from typing import Tuple
import rospy
from std_msgs.msg import Float32MultiArray
from sensor_msgs.msg import Image
import ros_numpy
from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules
from rsl_rl.utils.utils import get_obs_slice
@torch.no_grad()
def handle_forward_depth(ros_msg, model, publisher, output_resolution, device):
    """ROS subscriber callback: embed one forward depth image and publish the embedding.

    The incoming image message is converted with ``ros_numpy`` to a float32
    array, given two leading singleton dimensions, moved to ``device``, resized
    by ``resize2d`` to ``output_resolution`` and fed to ``model``. The model
    output is flattened and published as a ``Float32MultiArray``.

    Args:
        ros_msg: the depth image message received on the subscribed topic
            (presumably a 2-D (H, W) depth image, giving a (1, 1, H, W)
            tensor -- TODO confirm).
        model: the (scripted) visual encoder.
        publisher: the ROS publisher the flattened embedding is sent on.
        output_resolution: target resolution passed to ``resize2d``.
        device: torch device the encoder runs on.
    """
    depth_array = ros_numpy.numpify(ros_msg).astype(np.float32)
    depth = torch.from_numpy(depth_array)[None, None].to(device)
    resized_depth = resize2d(depth, output_resolution)
    latent = model(resized_depth)
    flat_latent = latent.reshape(-1).cpu().numpy().astype(np.float32)
    publisher.publish(Float32MultiArray(data=flat_latent.tolist()))
class StandOnlyModel(torch.nn.Module):
    """Fallback walk policy that gently eases every joint back toward zero.

    It is used when no walk policy directory is given. The joint positions are
    read from ``obs[..., -36:-24]`` and un-scaled by ``dof_pos_scale``; the log
    messages in ``__init__`` state the observation layout this relies on.

    For a joint whose magnitude exceeds ``tolerance``, the target is the
    position moved ``delta`` closer to zero. Every other joint targets zero.
    Targets are divided by ``action_scale`` and clipped to [-1, 1].
    """

    def __init__(self, action_scale, dof_pos_scale, tolerance= 0.2, delta= 0.1):
        for message in (
            "Using stand only model, please make sure the proprioception is 48 dim.",
            "Using stand only model, -36 to -24 must be joint position.",
        ):
            rospy.loginfo(message)
        super().__init__()
        # Per-joint scales (tuple/list) become buffers so they follow .to();
        # scalar scales are kept as plain attributes.
        for attr_name, value in (("action_scale", action_scale), ("dof_pos_scale", dof_pos_scale)):
            if isinstance(value, (tuple, list)):
                self.register_buffer(attr_name, torch.tensor(value))
            else:
                setattr(self, attr_name, value)
        self.tolerance = tolerance
        self.delta = delta

    def forward(self, obs):
        joints = obs[..., -36:-24] / self.dof_pos_scale
        outside_tolerance = torch.abs(joints) > self.tolerance
        # Step each out-of-tolerance joint `delta` toward zero; others go to zero.
        stepped = joints - self.delta * torch.sign(joints)
        targets = torch.where(outside_tolerance, stepped, torch.zeros_like(joints))
        return torch.clip(targets / self.action_scale, -1.0, 1.0)

    def reset(self, *args, **kwargs):
        # Stateless model: nothing to reset (kept for API parity with recurrent policies).
        pass
def load_walk_policy(env, model_dir):
    """ Load the walk policy from the model directory.

    Args:
        env: the robot env proxy. Reads ``action_scale``,
            ``obs_scales["dof_pos"]`` and ``get_num_obs_from_components``.
        model_dir: log directory of a trained walk policy, containing
            ``config.json`` and ``model_<iteration>.<ext>`` checkpoints. Pass
            ``None`` to fall back to ``StandOnlyModel``.

    Returns:
        ``(policy, model)``:
        - ``policy`` maps an observation tensor to actions. Its actions are
          rescaled when the walk policy was trained with a different
          ``action_scale`` than ``env.action_scale``.
        - ``model`` is the underlying module; callers use its ``reset()``.

    Raises:
        FileNotFoundError: if ``model_dir`` contains no ``model_*`` checkpoint.
    """
    if model_dir is None:
        model = StandOnlyModel(
            action_scale= env.action_scale,
            dof_pos_scale= env.obs_scales["dof_pos"],
        )
        return torch.jit.script(model), model

    with open(osp.join(model_dir, "config.json"), "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)
    obs_components = config_dict["env"]["obs_components"]
    privileged_obs_components = config_dict["env"].get("privileged_obs_components", obs_components)
    model = getattr(modules, config_dict["runner"]["policy_class_name"])(
        num_actor_obs= env.get_num_obs_from_components(obs_components),
        num_critic_obs= env.get_num_obs_from_components(privileged_obs_components),
        num_actions= 12,
        **config_dict["policy"],
    )

    # Pick the checkpoint with the highest iteration number in its file name.
    model_names = [name for name in os.listdir(model_dir) if name.startswith("model_")]
    if not model_names:
        raise FileNotFoundError("No 'model_*' checkpoint found in walk policy directory {!r}".format(model_dir))
    model_names.sort(key= lambda name: int(name.split("_")[-1].split(".")[0]))
    state_dict = torch.load(osp.join(model_dir, model_names[-1]), map_location= "cpu")
    model.load_state_dict(state_dict["model_state_dict"])
    # Inference only: switch to eval mode before the submodules are captured by
    # torch.jit.script below, the same way the skill model is prepared in main().
    model.eval()

    # The walk policy may have been trained with a different action scale than
    # the one the env applies; if so, rescale its actions to match.
    raw_action_scale = config_dict["control"]["action_scale"]
    if isinstance(raw_action_scale, (tuple, list)):
        model_action_scale = torch.tensor(raw_action_scale)
    else:
        model_action_scale = torch.tensor([raw_action_scale])[0]
    if (model_action_scale == env.action_scale).all():
        action_rescale_ratio = 1.0
    else:
        action_rescale_ratio = model_action_scale / env.action_scale
        print("walk_policy action scaling:", action_rescale_ratio.tolist())

    memory_module = model.memory_a
    actor_mlp = model.actor

    @torch.jit.script
    def policy_run(obs):
        recurrent_embedding = memory_module(obs)
        actions = actor_mlp(recurrent_embedding.squeeze(0))
        return actions

    if torch.is_tensor(action_rescale_ratio):
        needs_rescale = not bool((action_rescale_ratio == 1.).all())
    else:
        needs_rescale = action_rescale_ratio != 1.
    if not needs_rescale:
        return policy_run, model
    return (lambda x: policy_run(x) * action_rescale_ratio), model
def standup_procedure(env, ros_rate, angle_tolerance= 0.1,
        kp= None,
        kd= None,
        warmup_timesteps= 25,
        device= "cpu",
    ):
    """ Move the legs to ``env.default_dof_pos`` in small steps, then wait for R1.

    Each control tick, every joint farther than ``angle_tolerance`` from its
    default position is commanded ``angle_tolerance`` toward the default; this
    step is scaled down during the warm-up. Joints already within the
    tolerance are commanded straight to the default. The stepping loop ends
    once every joint is within the tolerance.

    Afterwards the default pose is held until R1 is pressed. Pressing L2 or R2
    instead commands the default pose with ``kp=0, kd=0.5``, shuts ROS down
    and exits the process. If ROS shuts down while stepping, the function
    returns early.

    Args:
        env: the robot env proxy. Reads ``low_state_buffer``, ``dof_map`` and
            ``default_dof_pos``; commands via ``publish_legs_cmd``.
        ros_rate: rate object whose ``sleep()`` paces both loops.
        angle_tolerance: per-tick step size and convergence threshold (rad).
        kp, kd: gains forwarded to ``publish_legs_cmd``.
        warmup_timesteps: the number of timesteps to linearly increase the
            step size from 0 up to ``angle_tolerance``.
        device: device of the target-position tensor.
    """
    import sys  # sys.exit: the builtin exit() is intended for the interactive shell only

    rospy.loginfo("Robot standing up, please wait ...")
    target_pos = torch.zeros((1, 12), device= device, dtype= torch.float32)
    standup_timestep_i = 0
    while not rospy.is_shutdown():
        dof_pos = [env.low_state_buffer.motorState[env.dof_map[i]].q for i in range(12)]
        diff = [env.default_dof_pos[i].item() - dof_pos[i] for i in range(12)]
        direction = [1 if d > 0 else -1 for d in diff]
        if standup_timestep_i < warmup_timesteps:
            # Linearly ramp the step size up to avoid a jerk on the first ticks.
            ramp = standup_timestep_i / warmup_timesteps
            direction = [ramp * d for d in direction]
        if all(abs(d) < angle_tolerance for d in diff):
            break
        print("max joint error (rad):", max(abs(d) for d in diff), end= "\r")
        for i in range(12):
            if abs(diff[i]) > angle_tolerance:
                target_pos[0, i] = dof_pos[i] + direction[i] * angle_tolerance
            else:
                target_pos[0, i] = env.default_dof_pos[i]
        env.publish_legs_cmd(target_pos,
            kp= kp,
            kd= kd,
        )
        ros_rate.sleep()
        standup_timestep_i += 1
    if rospy.is_shutdown():
        # Interrupted before reaching the default pose; do not claim success.
        return

    rospy.loginfo("Robot stood up! press R1 on the remote control to continue ...")
    while not rospy.is_shutdown():
        buttons = env.low_state_buffer.wirelessRemote.btn.components
        if buttons.R1:
            break
        if buttons.L2 or buttons.R2:
            env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= 0, kd= 0.5)
            rospy.signal_shutdown("Controller send stop signal, exiting")
            sys.exit(0)
        env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= kp, kd= kd)
        ros_rate.sleep()
    rospy.loginfo("Robot standing up procedure finished!")
class SkilledA1Real(UnitreeA1Real):
    """ UnitreeA1Real with a wireless-remote "skill mode" used to run the skill policy.

    With ``move_by_wireless_remote`` enabled, the robot is in skill mode while
    the remote's ``ry`` value exceeds ``skill_mode_threhold``. In skill mode,
    ``update_low_state`` maps ``ry`` linearly from
    ``(skill_mode_threhold, 1]`` onto ``skill_vel_range`` and writes the result
    to ``command_buf[0, 0]``, zeroing ``command_buf[0, 1]`` and
    ``command_buf[0, 2]``. In that case the parent's ``update_low_state`` is
    not called. (Presumably these are the forward, lateral and yaw commands;
    confirm against UnitreeA1Real.)

    Keyword args:
        skill_mode_threhold: ``ry`` threshold for entering skill mode. The
            spelling is kept because it is part of the keyword interface.
        skill_vel_range: ``(low, high)`` velocity range mapped from ``ry``.
    """
    def __init__(self, *args,
            skill_mode_threhold= 0.1,
            skill_vel_range= (0.0, 1.0),  # immutable default (was a shared list)
            **kwargs,
        ):
        # Set before super().__init__() -- presumably so they already exist if
        # the parent triggers update_low_state during construction; keep this order.
        self.skill_mode_threhold = skill_mode_threhold
        self.skill_vel_range = skill_vel_range
        super().__init__(*args, **kwargs)

    def _skill_mode_requested(self, wireless_remote):
        """ Return whether the remote's ``ry`` value is above the skill-mode threshold. """
        return wireless_remote.ry > self.skill_mode_threhold

    def is_skill_mode(self):
        if not self.move_by_wireless_remote:
            # Not implemented yet for non-remote command sources.
            return False
        return self._skill_mode_requested(self.low_state_buffer.wirelessRemote)

    def update_low_state(self, ros_msg):
        self.low_state_buffer = ros_msg
        if self.move_by_wireless_remote and self._skill_mode_requested(ros_msg.wirelessRemote):
            # Rescale ry from (threshold, 1] to (0, 1], then onto skill_vel_range.
            skill_vel = (ros_msg.wirelessRemote.ry - self.skill_mode_threhold) / (1.0 - self.skill_mode_threhold)
            skill_vel *= self.skill_vel_range[1] - self.skill_vel_range[0]
            skill_vel += self.skill_vel_range[0]
            self.command_buf[0, 0] = skill_vel
            self.command_buf[0, 1] = 0.
            self.command_buf[0, 2] = 0.
            return
        return super().update_low_state(ros_msg)
def _stop_requested(env):
    """ Return whether L2 or R2 (the stop buttons) is pressed on the wireless remote. """
    buttons = env.low_state_buffer.wirelessRemote.btn.components
    return bool(buttons.L2 or buttons.R2)


def _build_skill_model(env, config_dict, logdir, device):
    """ Build the skill policy, load the newest ``model_*`` checkpoint in ``logdir``, and set it to eval mode on ``device``. """
    model = getattr(modules, config_dict["runner"]["policy_class_name"])(
        num_actor_obs= env.num_obs,
        num_critic_obs= env.num_privileged_obs,
        num_actions= 12,
        obs_segments= env.obs_segments,
        privileged_obs_segments= env.privileged_obs_segments,
        **config_dict["policy"],
    )
    # load the model with the latest checkpoint
    model_names = [name for name in os.listdir(logdir) if name.startswith("model_")]
    if not model_names:
        raise FileNotFoundError("No 'model_*' checkpoint found in {!r}".format(logdir))
    model_names.sort(key= lambda name: int(name.split("_")[-1].split(".")[0]))
    state_dict = torch.load(osp.join(logdir, model_names[-1]), map_location= "cpu")
    model.load_state_dict(state_dict["model_state_dict"])
    model.to(device)
    model.eval()
    return model


def _run_jetson(env, model, config_dict, namespace, device):
    """ Encode forward depth images with the visual encoder and publish the embedding (blocks in rospy.spin()). """
    embedding_publisher = rospy.Publisher(
        namespace + "/visual_embedding",
        Float32MultiArray,
        queue_size= 1,
    )
    # extract and build the torch ScriptFunction
    visual_encoder = torch.jit.script(model.visual_encoder)
    forward_camera_cfg = config_dict["sensor"]["forward_camera"]
    output_resolution = forward_camera_cfg.get("output_resolution", forward_camera_cfg["resolution"])
    # Bound to a name so the subscriber object stays referenced while spinning.
    forward_depth_subscriber = rospy.Subscriber(
        namespace + "/camera/depth/image_rect_raw",
        Image,
        partial(handle_forward_depth,
            model= visual_encoder,
            publisher= embedding_publisher,
            output_resolution= output_resolution,
            device= device,
        ),
        queue_size= 1,
    )
    rospy.spin()


def _run_upboard(env, model, walkdir, debug, duration, device):
    """ Run the skill policy and the walk policy, switching between them via the remote (DOWN resets, L2/R2 stops). """
    # extract and build the torch ScriptFunction
    memory_module = model.memory_a
    actor_mlp = model.actor

    @torch.jit.script
    def policy(obs):
        recurrent_embedding = memory_module(obs)
        actions = actor_mlp(recurrent_embedding.squeeze(0))
        return actions

    walk_policy, walk_model = load_walk_policy(env, walkdir)
    using_walk_policy = True  # switch between skill policy and walk policy
    env.start_ros()
    env.wait_untill_ros_working()
    rate = rospy.Rate(1 / duration)
    with torch.no_grad():
        if not debug:
            standup_procedure(env, rate,
                angle_tolerance= 0.2,
                kp= 40,
                kd= 0.5,
                warmup_timesteps= 50,
                device= device,
            )
        while not rospy.is_shutdown():
            # check remote controller and decide which policy to use
            if env.is_skill_mode():
                if using_walk_policy:
                    rospy.loginfo_throttle(0.1, "switch to skill policy")
                    using_walk_policy = False
                    model.reset()
            elif not using_walk_policy:
                rospy.loginfo_throttle(0.1, "switch to walk policy")
                using_walk_policy = True
                walk_model.reset()
            if using_walk_policy:
                actions = walk_policy(env._get_proprioception_obs())
            else:
                actions = policy(env.get_obs())
            env.send_action(actions)
            rate.sleep()
            if env.low_state_buffer.wirelessRemote.btn.components.down:
                rospy.loginfo_throttle(0.1, "model reset")
                model.reset()
                walk_model.reset()
            if _stop_requested(env):
                env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= 2, kd= 0.5)
                rospy.signal_shutdown("Controller send stop signal, exiting")


def _run_full(env, model, duration):
    """ Run the visual encoder and the policy together on this computer (stops on L2/R2). """
    # extract and build the torch ScriptFunction
    visual_obs_slice = get_obs_slice(env.obs_segments, "forward_depth")
    visual_encoder = model.visual_encoder
    memory_module = model.memory_a
    actor_mlp = model.actor

    @torch.jit.script
    def policy(observations: torch.Tensor, obs_start: int, obs_stop: int, obs_shape: Tuple[int, int, int]):
        visual_latent = visual_encoder(
            observations[..., obs_start:obs_stop].reshape(-1, *obs_shape)
        ).reshape(1, -1)
        obs = torch.cat([
            observations[..., :obs_start],
            visual_latent,
            observations[..., obs_stop:],
        ], dim= -1)
        recurrent_embedding = memory_module(obs)
        actions = actor_mlp(recurrent_embedding.squeeze(0))
        return actions

    env.start_ros()
    env.wait_untill_ros_working()
    rate = rospy.Rate(1 / duration)
    # visual_obs_slice is computed once above, so resolve its bounds once
    # instead of on every control tick.
    obs_start = visual_obs_slice[0].start.item()
    obs_stop = visual_obs_slice[0].stop.item()
    obs_shape = visual_obs_slice[1]
    with torch.no_grad():
        while not rospy.is_shutdown():
            obs = env.get_obs()
            actions = policy(obs,
                obs_start= obs_start,
                obs_stop= obs_stop,
                obs_shape= obs_shape,
            )
            env.send_action(actions)
            motor_temperatures = [motor_state.temperature for motor_state in env.low_state_buffer.motorState]
            rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
            rate.sleep()
            if _stop_requested(env):
                env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
                rospy.signal_shutdown("Controller send stop signal, exiting")


def main(args):
    """ Run the onboard role selected by ``args.mode``.

    Modes:
        - ``"jetson"``: encode depth images from the camera topic and publish
          the embedding on ``<namespace>/visual_embedding``.
        - ``"upboard"``: consume that embedding, run the skill policy and the
          walk policy on CPU, and switch between them with the remote.
        - ``"full"``: run encoder and policy together on CUDA.

    Raises:
        ValueError: if ``args.logdir`` is not given.
        FileNotFoundError: if ``args.logdir`` holds no ``model_*`` checkpoint.
    """
    # TODO: allow initializing the env proxy from walkdir/config.json when
    # --logdir is not given (modification not finished yet).
    if args.logdir is None:
        raise ValueError("--logdir is required: it provides config.json and the skill policy checkpoint")
    log_level = rospy.DEBUG if args.debug else rospy.INFO
    rospy.init_node("a1_legged_gym_" + args.mode, log_level= log_level)
    with open(osp.join(args.logdir, "config.json"), "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)
    duration = config_dict["sim"]["dt"] * config_dict["control"]["decimation"]  # in sec
    # config_dict["control"]["stiffness"]["joint"] -= 2.5 # kp
    is_upboard = args.mode == "upboard"
    model_device = torch.device("cpu") if is_upboard else torch.device("cuda")
    unitree_real_env = SkilledA1Real(
        robot_namespace= args.namespace,
        cfg= config_dict,
        forward_depth_topic= "/visual_embedding" if is_upboard else "/camera/depth/image_rect_raw",
        forward_depth_embedding_dims= config_dict["policy"]["visual_latent_size"] if is_upboard else None,
        move_by_wireless_remote= True,
        skill_vel_range= config_dict["commands"]["ranges"]["lin_vel_x"],
        model_device= model_device,
        # extra_cfg= dict(
        #     motor_strength= torch.tensor([
        #         1., 1./0.9, 1./0.9,
        #         1., 1./0.9, 1./0.9,
        #         1., 1., 1.,
        #         1., 1., 1.,
        #     ], dtype= torch.float32, device= model_device, requires_grad= False),
        # ),
    )
    model = _build_skill_model(unitree_real_env, config_dict, args.logdir, model_device)
    # NOTE(review): set after the env proxy was built from this same dict;
    # whether the env reads it later is not visible here.
    config_dict["terrain"]["measure_heights"] = False
    rospy.loginfo("duration: {}, motor Kp: {}, motor Kd: {}".format(
        duration,
        config_dict["control"]["stiffness"]["joint"],
        config_dict["control"]["damping"]["joint"],
    ))
    if args.mode == "jetson":
        _run_jetson(unitree_real_env, model, config_dict, args.namespace, model_device)
    elif args.mode == "upboard":
        _run_upboard(unitree_real_env, model, args.walkdir, args.debug, duration, model_device)
    elif args.mode == "full":
        _run_full(unitree_real_env, model, duration)
    else:
        rospy.logfatal("Unknown mode, exiting")
if __name__ == "__main__":
    # The script to run the A1 policies in ROS.
    # It's designed as a main function and not designed to be a scalable code.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--namespace",
        type= str,
        default= "/a112138",
    )
    # main() cannot run without a log directory, so reject a missing one up front.
    parser.add_argument("--logdir",
        type= str,
        help= "The log directory of the trained model",
        required= True,
    )
    parser.add_argument("--walkdir",
        type= str,
        help= "The log directory of the walking model, not for the skills.",
        default= None,
    )
    # Required: without it args.mode is None and main() fails building the node name.
    parser.add_argument("--mode",
        type= str,
        help= "The mode to determine which computer to run on.",
        choices= ["jetson", "upboard", "full"],
        required= True,
    )
    parser.add_argument("--debug",
        action= "store_true",
    )
    args = parser.parse_args()
    main(args)