unitree_rl_gym/legged_gym/scripts/play.py

52 lines
1.7 KiB
Python
Raw Normal View History

2024-07-17 17:20:50 +08:00
import sys
from legged_gym import LEGGED_GYM_ROOT_DIR
2024-07-18 15:57:29 +08:00
import os
2023-10-11 15:38:49 +08:00
import sys
from legged_gym import LEGGED_GYM_ROOT_DIR
import isaacgym
from legged_gym.envs import *
from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger
import numpy as np
import torch
def play(args):
    """Roll out a trained policy in a test-configured environment.

    The configs registered for ``args.task`` are fetched and adjusted for
    playback: at most 100 environments, a 5x5 terrain grid, no terrain
    curriculum, no observation noise, no friction randomization, no robot
    pushes, and ``env_cfg.env.test`` set to True. The runner is created with
    ``resume = True`` (presumably so it restores a saved checkpoint -- confirm
    against ``task_registry.make_alg_runner``) and its inference policy is
    stepped for ``10 * int(env.max_episode_length)`` iterations.

    If the module-level flag ``EXPORT_POLICY`` is truthy, the actor-critic is
    exported as a JIT module under
    ``<LEGGED_GYM_ROOT_DIR>/logs/<experiment_name>/exported/policies``.
    The flag is only assigned by the ``__main__`` guard, so when this function
    is called from an importing module the flag is treated as False instead of
    raising ``NameError``.

    Args:
        args: Parsed command-line arguments; ``args.task`` selects the task
            and the whole object is forwarded to the task registry.

    Returns:
        None.
    """
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
    # override some parameters for testing
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, 100)
    env_cfg.terrain.num_rows = 5
    env_cfg.terrain.num_cols = 5
    env_cfg.terrain.curriculum = False
    env_cfg.noise.add_noise = False
    env_cfg.domain_rand.randomize_friction = False
    env_cfg.domain_rand.push_robots = False
    env_cfg.env.test = True

    # prepare environment
    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    obs = env.get_observations()

    # load policy
    train_cfg.runner.resume = True
    ppo_runner, train_cfg = task_registry.make_alg_runner(
        env=env, name=args.task, args=args, train_cfg=train_cfg
    )
    policy = ppo_runner.get_inference_policy(device=env.device)

    # export policy as a jit module (used to run it from C++)
    # Guarded lookup: EXPORT_POLICY exists only when this file runs as a script.
    if globals().get('EXPORT_POLICY', False):
        path = os.path.join(
            LEGGED_GYM_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported', 'policies'
        )
        export_policy_as_jit(ppo_runner.alg.actor_critic, path)
        print('Exported policy as jit script to: ', path)

    # Only the observations feed back into the policy; the other step outputs
    # are not used during playback.
    for _ in range(10 * int(env.max_episode_length)):
        actions = policy(obs.detach())
        obs, _, _rews, _dones, _infos = env.step(actions.detach())
if __name__ == '__main__':
    # Script-level switches, set as module globals before play() runs;
    # play() reads EXPORT_POLICY from the module namespace.
    EXPORT_POLICY, RECORD_FRAMES, MOVE_CAMERA = True, False, False

    # Parse the command line and start the playback loop.
    args = get_args()
    play(args)