163 lines
5.2 KiB
Python
Executable File
163 lines
5.2 KiB
Python
Executable File
import isaacgym
|
|
|
|
assert isaacgym
|
|
import torch
|
|
import numpy as np
|
|
|
|
import glob
|
|
import pickle as pkl
|
|
|
|
from go1_gym.envs import *
|
|
from go1_gym.envs.base.legged_robot_config import Cfg
|
|
from go1_gym.envs.go1.go1_config import config_go1
|
|
from go1_gym.envs.go1.velocity_tracking import VelocityTrackingEasyEnv
|
|
|
|
from tqdm import tqdm
|
|
|
|
def load_policy(logdir):
    """Build an inference policy from the TorchScript checkpoints in *logdir*.

    Two modules are loaded with ``torch.jit.load``:
    ``<logdir>/checkpoints/body_latest.jit`` and
    ``<logdir>/checkpoints/adaptation_module_latest.jit``.

    Args:
        logdir: Path (string) of the run directory that contains the
            ``checkpoints`` folder.

    Returns:
        A callable ``policy(obs, info=None)``. ``obs`` must be a mapping with
        an ``"obs_history"`` tensor. The adaptation module maps that tensor to
        a latent, the latent is concatenated to the observation history along
        the last dimension, and the body module maps the result to the action
        that is returned. When an ``info`` dict is supplied, the latent is
        stored in it under the ``"latent"`` key.
    """
    body = torch.jit.load(logdir + '/checkpoints/body_latest.jit')
    adaptation_module = torch.jit.load(logdir + '/checkpoints/adaptation_module_latest.jit')

    def policy(obs, info=None):
        # Both modules are evaluated on the CPU; move the history over once
        # and reuse it for the two forward passes.
        obs_history = obs["obs_history"].to('cpu')
        latent = adaptation_module.forward(obs_history)
        action = body.forward(torch.cat((obs_history, latent), dim=-1))
        # `info` defaults to None rather than a shared mutable dict, so the
        # latent is only recorded when the caller actually asked for it.
        if info is not None:
            info['latent'] = latent
        return action

    return policy
|
|
|
|
|
|
def load_env(label, headless=False):
    """Create the evaluation environment and policy for a training run.

    The first directory (in sorted order) matching ``../runs/<label>/*`` is
    used as the run directory. Its ``parameters.pkl`` is unpickled and the
    entries of its ``"Cfg"`` mapping are copied onto the global ``Cfg``
    class, after which a set of evaluation-time overrides is applied.

    Args:
        label: Run label, interpolated into the ``../runs/<label>/*`` glob
            (relative to the current working directory).
        headless: Forwarded to ``VelocityTrackingEasyEnv``.

    Returns:
        A tuple ``(env, policy)``: the ``HistoryWrapper``-wrapped environment
        and the callable produced by ``load_policy`` for the same run.

    Raises:
        FileNotFoundError: If no run directory matches the label.
    """
    pattern = f"../runs/{label}/*"
    dirs = glob.glob(pattern)
    if not dirs:
        raise FileNotFoundError(
            f"no run directories match {pattern!r} "
            "(the pattern is resolved relative to the current working directory)"
        )
    logdir = sorted(dirs)[0]

    # NOTE: unpickling can execute arbitrary code; only point this at run
    # directories that come from a trusted source.
    with open(logdir + "/parameters.pkl", 'rb') as file:
        pkl_cfg = pkl.load(file)
    print(pkl_cfg.keys())
    cfg = pkl_cfg["Cfg"]
    print(cfg.keys())

    # Copy every saved option onto the matching section of the global Cfg;
    # sections that Cfg does not define are skipped.
    for key, value in cfg.items():
        if hasattr(Cfg, key):
            for key2, value2 in value.items():
                setattr(getattr(Cfg, key), key2, value2)

    # turn off DR for evaluation script
    for flag in (
        "push_robots",
        "randomize_friction",
        "randomize_gravity",
        "randomize_restitution",
        "randomize_motor_offset",
        "randomize_motor_strength",
        "randomize_friction_indep",
        "randomize_ground_friction",
        "randomize_base_mass",
        "randomize_Kd_factor",
        "randomize_Kp_factor",
        "randomize_joint_friction",
        "randomize_com_displacement",
    ):
        setattr(Cfg.domain_rand, flag, False)

    # Single-environment setup on a small terrain grid.
    Cfg.env.num_recording_envs = 1
    Cfg.env.num_envs = 1
    Cfg.terrain.num_rows = 5
    Cfg.terrain.num_cols = 5
    Cfg.terrain.border_size = 0
    Cfg.terrain.center_robots = True
    Cfg.terrain.center_span = 1
    Cfg.terrain.teleport_robots = True

    # Lag randomization stays enabled during evaluation.
    Cfg.domain_rand.lag_timesteps = 6
    Cfg.domain_rand.randomize_lag_timesteps = True
    Cfg.control.control_type = "actuator_net"

    from go1_gym.envs.wrappers.history_wrapper import HistoryWrapper

    # Honour the caller's `headless` choice instead of always opening a viewer.
    env = VelocityTrackingEasyEnv(sim_device='cuda:0', headless=headless, cfg=Cfg)
    env = HistoryWrapper(env)

    # load policy
    # NOTE(review): neither name below is used in this function; the imports
    # are retained in case they are needed for their side effects - confirm.
    from ml_logger import logger
    from go1_gym_learn.ppo_cse.actor_critic import ActorCritic

    policy = load_policy(logdir)

    return env, policy
|
|
|
|
|
|
def play_go1(headless=True):
    """Roll out the pretrained policy with fixed commands and plot the result.

    Loads the environment and policy for the
    ``gait-conditioned-agility/pretrain-v0/train`` run, steps it for 250
    iterations while writing a constant command vector into ``env.commands``,
    records the measured forward velocity and the 12 joint positions of
    environment 0, and finally shows both in a matplotlib figure.

    Args:
        headless: Forwarded to ``load_env``.
    """
    # NOTE(review): `logger` is not used in this function; the import is
    # retained in case it is needed for its side effects - confirm.
    from ml_logger import logger

    label = "gait-conditioned-agility/pretrain-v0/train"

    env, policy = load_env(label, headless=headless)

    num_eval_steps = 250
    # Three-component gait parameters, written into command slots 5..7.
    gaits = {"pronking": [0, 0, 0],
             "trotting": [0.5, 0, 0],
             "bounding": [0, 0.5, 0],
             "pacing": [0, 0, 0.5]}

    x_vel_cmd, y_vel_cmd, yaw_vel_cmd = 1.5, 0.0, 0.0
    body_height_cmd = 0.0
    step_frequency_cmd = 3.0
    gait = torch.tensor(gaits["trotting"])
    footswing_height_cmd = 0.08
    pitch_cmd = 0.0
    roll_cmd = 0.0
    stance_width_cmd = 0.25

    measured_x_vels = np.zeros(num_eval_steps)
    target_x_vels = np.ones(num_eval_steps) * x_vel_cmd
    joint_positions = np.zeros((num_eval_steps, 12))

    obs = env.reset()

    for i in tqdm(range(num_eval_steps)):
        with torch.no_grad():
            actions = policy(obs)
        # The commands are rewritten on every iteration, before stepping.
        env.commands[:, 0] = x_vel_cmd
        env.commands[:, 1] = y_vel_cmd
        env.commands[:, 2] = yaw_vel_cmd
        env.commands[:, 3] = body_height_cmd
        env.commands[:, 4] = step_frequency_cmd
        env.commands[:, 5:8] = gait
        # NOTE(review): slot 8 has no named constant here - presumably a
        # gait duration/duty parameter; verify against the env definition.
        env.commands[:, 8] = 0.5
        env.commands[:, 9] = footswing_height_cmd
        env.commands[:, 10] = pitch_cmd
        env.commands[:, 11] = roll_cmd
        env.commands[:, 12] = stance_width_cmd
        obs, rew, done, info = env.step(actions)

        # Log environment 0 only.
        measured_x_vels[i] = env.base_lin_vel[0, 0]
        joint_positions[i] = env.dof_pos[0, :].cpu()

    # plot target and measured forward velocity
    from matplotlib import pyplot as plt

    # Shared time axis for every curve.
    time_axis = np.linspace(0, num_eval_steps * env.dt, num_eval_steps)

    fig, axs = plt.subplots(2, 1, figsize=(12, 5))
    axs[0].plot(time_axis, measured_x_vels, color='black', linestyle="-", label="Measured")
    axs[0].plot(time_axis, target_x_vels, color='black', linestyle="--", label="Desired")
    axs[0].legend()
    axs[0].set_title("Forward Linear Velocity")
    axs[0].set_xlabel("Time (s)")
    axs[0].set_ylabel("Velocity (m/s)")

    axs[1].plot(time_axis, joint_positions, linestyle="-", label="Measured")
    axs[1].set_title("Joint Positions")
    axs[1].set_xlabel("Time (s)")
    axs[1].set_ylabel("Joint Position (rad)")

    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point. headless=False is passed so that the environment
    # rendering is shown while the rollout runs.
    play_go1(headless=False)
|