unitree_rl_gym/legged_gym/utils/helpers.py

192 lines
7.2 KiB
Python
Raw Normal View History

2023-10-11 15:38:49 +08:00
import os
import copy
import torch
import numpy as np
import random
from isaacgym import gymapi
from isaacgym import gymutil
from legged_gym import LEGGED_GYM_ROOT_DIR, LEGGED_GYM_ENVS_DIR
def class_to_dict(obj) -> dict:
    """Recursively convert a config-style object into plain nested dicts.

    Anything without a ``__dict__`` (numbers, strings, lists, ...) is
    returned unchanged. Otherwise every attribute listed by ``dir(obj)``
    whose name does not start with an underscore is converted; list-valued
    attributes are converted element by element.
    """
    if not hasattr(obj, "__dict__"):
        return obj

    converted = {}
    for name in dir(obj):
        # Skip private and dunder attributes.
        if name.startswith("_"):
            continue
        value = getattr(obj, name)
        if isinstance(value, list):
            converted[name] = [class_to_dict(item) for item in value]
        else:
            converted[name] = class_to_dict(value)
    return converted
def update_class_from_dict(obj, dict):
    """Apply the entries of ``dict`` to ``obj`` as attributes, in place.

    If the attribute currently stored on ``obj`` under a key is itself a
    class, the corresponding value is treated as a nested mapping and applied
    to that class recursively. Every other key is simply (re)assigned with
    ``setattr``, which also creates attributes that did not exist before.
    """
    for name, value in dict.items():
        current = getattr(obj, name, None)
        if not isinstance(current, type):
            setattr(obj, name, value)
            continue
        # Nested config class: descend instead of overwriting it.
        update_class_from_dict(current, value)
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) generators.

    A ``seed`` of ``-1`` is replaced by a value drawn with
    ``np.random.randint(0, 10000)``. The seed actually used is printed and
    also stored in the ``PYTHONHASHSEED`` environment variable.
    """
    chosen = np.random.randint(0, 10000) if seed == -1 else seed
    print("Setting seed: {}".format(chosen))

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(chosen)
    os.environ['PYTHONHASHSEED'] = str(chosen)
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(chosen)
def parse_sim_params(args, cfg):
    """Build an Isaac Gym ``SimParams`` from parsed arguments and a config dict.

    Adapted from Isaac Gym Preview 2. Values are applied in three stages:
    engine-related settings taken from ``args``, then anything under
    ``cfg["sim"]`` (which may override the former), and finally the PhysX
    thread count when a positive ``num_threads`` was passed.

    Returns:
        The populated ``gymapi.SimParams`` instance.
    """
    sim_params = gymapi.SimParams()
    engine = args.physics_engine
    using_physx = engine == gymapi.SIM_PHYSX

    # Stage 1: values that come straight from the parsed arguments.
    if engine == gymapi.SIM_FLEX:
        # NOTE(review): `args.device` is not assigned anywhere in this file
        # (get_args sets `sim_device`) -- confirm the parser provides it.
        if args.device != "cpu":
            print("WARNING: Using Flex with GPU instead of PHYSX!")
    elif using_physx:
        sim_params.physx.use_gpu = args.use_gpu
        sim_params.physx.num_subscenes = args.subscenes
    sim_params.use_gpu_pipeline = args.use_gpu_pipeline

    # Stage 2: sim options from the config update/override the values above.
    if "sim" in cfg:
        gymutil.parse_sim_config(cfg["sim"], sim_params)

    # Stage 3: a thread count given on the command line wins over the config.
    if using_physx and args.num_threads > 0:
        sim_params.physx.num_threads = args.num_threads

    return sim_params
def get_load_path(root, load_run=-1, checkpoint=-1):
    """Resolve the path of the model checkpoint file to load.

    Args:
        root: Directory holding one sub-directory per training run.
        load_run: Name of the run directory inside ``root``. ``-1`` selects
            the last run in sorted order (the ``exported`` entry is ignored).
            The string ``"-1"`` is accepted as well, because ``--load_run``
            is parsed as ``str`` on the command line.
        checkpoint: Checkpoint number, giving ``model_<checkpoint>.pt``.
            ``-1`` selects the last file in the run whose name contains
            ``model``.

    Returns:
        The joined path ``<root>/<run>/<model file>``.

    Raises:
        ValueError: If ``root`` cannot be listed or contains no runs, or if
            the latest checkpoint is requested but the run has no model file.
    """
    try:
        runs = os.listdir(root)
        # TODO sort by date to handle change of month
        runs.sort()
        if 'exported' in runs:
            runs.remove('exported')
        last_run = os.path.join(root, runs[-1])
    except (OSError, IndexError) as err:
        # Only "cannot list" / "nothing listed" mean "no runs"; anything else
        # (e.g. KeyboardInterrupt) must propagate untouched.
        raise ValueError(f"No runs in this directory: {root}") from err

    if load_run in (-1, "-1"):
        run_dir = last_run
    else:
        run_dir = os.path.join(root, load_run)

    if checkpoint == -1:
        models = [name for name in os.listdir(run_dir) if 'model' in name]
        if not models:
            raise ValueError(f"No model checkpoints in this directory: {run_dir}")
        # Left-pad names with zeros so that e.g. model_100.pt sorts before
        # model_1500.pt instead of comparing purely lexicographically.
        models.sort(key=lambda m: '{0:0>15}'.format(m))
        model = models[-1]
    else:
        model = "model_{}.pt".format(checkpoint)

    return os.path.join(run_dir, model)
def update_cfg_from_args(env_cfg, cfg_train, args):
    """Override config values with options given on the command line.

    Either config may be ``None``, in which case it is left alone. An
    option only overrides its config value when it is not ``None``; the
    ``resume`` flag overrides only when it is truthy, so it can switch
    resuming on but never off.

    Returns:
        The ``(env_cfg, cfg_train)`` pair, modified in place.
    """
    # environment: number of envs
    if env_cfg is not None and args.num_envs is not None:
        env_cfg.env.num_envs = args.num_envs

    if cfg_train is not None:
        # seed
        if args.seed is not None:
            cfg_train.seed = args.seed

        # alg runner parameters, applied in this order
        runner_fields = (
            "max_iterations",
            "resume",
            "experiment_name",
            "run_name",
            "load_run",
            "checkpoint",
        )
        for field in runner_fields:
            value = getattr(args, field)
            if field == "resume":
                should_override = bool(value)
            else:
                should_override = value is not None
            if should_override:
                setattr(cfg_train.runner, field, value)

    return env_cfg, cfg_train
def get_args():
    """Parse the command-line arguments for training/playing a policy.

    The project-specific options below are handed to
    ``gymutil.parse_arguments`` as ``custom_parameters``. Afterwards
    ``sim_device_id`` and ``sim_device`` are derived from the parsed
    ``compute_device_id`` and ``sim_device_type``; a ``cuda`` device gets the
    device id appended (e.g. ``cuda:0``).

    Returns:
        The parsed argument namespace with the derived fields added.
    """
    custom_parameters = [
        # The help text used to be a copy of the resume/checkpoint description.
        {"name": "--task", "type": str, "default": "go2", "help": "Name of the task to run."},
        {"name": "--resume", "action": "store_true", "default": False, "help": "Resume training from a checkpoint"},
        {"name": "--experiment_name", "type": str, "help": "Name of the experiment to run or load. Overrides config file if provided."},
        {"name": "--run_name", "type": str, "help": "Name of the run. Overrides config file if provided."},
        {"name": "--load_run", "type": str, "help": "Name of the run to load when resume=True. If -1: will load the last run. Overrides config file if provided."},
        {"name": "--checkpoint", "type": int, "help": "Saved model checkpoint number. If -1: will load the last checkpoint. Overrides config file if provided."},

        {"name": "--headless", "action": "store_true", "default": False, "help": "Force display off at all times"},
        {"name": "--horovod", "action": "store_true", "default": False, "help": "Use horovod for multi-gpu training"},
        {"name": "--rl_device", "type": str, "default": "cuda:0", "help": 'Device used by the RL algorithm, (cpu, gpu, cuda:0, cuda:1 etc..)'},
        {"name": "--num_envs", "type": int, "help": "Number of environments to create. Overrides config file if provided."},
        {"name": "--seed", "type": int, "help": "Random seed. Overrides config file if provided."},
        {"name": "--max_iterations", "type": int, "help": "Maximum number of training iterations. Overrides config file if provided."},
    ]
    # parse arguments
    args = gymutil.parse_arguments(
        description="RL Policy",
        custom_parameters=custom_parameters)

    # name alignment
    args.sim_device_id = args.compute_device_id
    args.sim_device = args.sim_device_type
    if args.sim_device == 'cuda':
        args.sim_device += f":{args.sim_device_id}"
    return args
def export_policy_as_jit(actor_critic, path):
    """Export the policy of ``actor_critic`` as a TorchScript file in ``path``.

    When ``actor_critic`` has a ``memory_a`` attribute the export is delegated
    to ``PolicyExporterLSTM``. Otherwise a CPU copy of ``actor_critic.actor``
    is scripted and saved as ``policy_1.pt``; ``path`` is created if missing.
    """
    if hasattr(actor_critic, 'memory_a'):
        # assumes LSTM: TODO add GRU
        PolicyExporterLSTM(actor_critic).export(path)
        return

    os.makedirs(path, exist_ok=True)
    # Work on a deep copy so the caller's actor is not moved to the CPU.
    cpu_actor = copy.deepcopy(actor_critic.actor).to('cpu')
    scripted_actor = torch.jit.script(cpu_actor)
    scripted_actor.save(os.path.join(path, 'policy_1.pt'))
class PolicyExporterLSTM(torch.nn.Module):
    """Scriptable wrapper bundling a recurrent memory with its actor.

    Deep copies of ``actor_critic.actor`` and ``actor_critic.memory_a.rnn``
    are stored, and the recurrent hidden/cell states are kept as buffers of
    shape ``(num_layers, 1, hidden_size)`` that are updated in place on every
    forward call.
    """

    def __init__(self, actor_critic):
        super().__init__()
        self.actor = copy.deepcopy(actor_critic.actor)
        self.is_recurrent = actor_critic.is_recurrent
        rnn = copy.deepcopy(actor_critic.memory_a.rnn)
        rnn.cpu()
        self.memory = rnn
        # Both recurrent states share one shape and start out as zeros.
        state_shape = (rnn.num_layers, 1, rnn.hidden_size)
        for buffer_name in ('hidden_state', 'cell_state'):
            self.register_buffer(buffer_name, torch.zeros(state_shape))

    def forward(self, x):
        # Add a leading dimension for the memory and drop it again before
        # handing the output to the actor.
        memory_in = x.unsqueeze(0)
        out, (new_hidden, new_cell) = self.memory(memory_in, (self.hidden_state, self.cell_state))
        # Write in place so the registered buffers carry the state over to
        # the next call.
        self.hidden_state[:] = new_hidden
        self.cell_state[:] = new_cell
        actions = self.actor(out.squeeze(0))
        return actions

    @torch.jit.export
    def reset_memory(self):
        # Exported explicitly so it stays callable on the scripted module.
        self.hidden_state[:] = 0.0
        self.cell_state[:] = 0.0

    def export(self, path):
        """Script this module on the CPU and save it as ``policy_lstm_1.pt`` in ``path``."""
        os.makedirs(path, exist_ok=True)
        target = os.path.join(path, 'policy_lstm_1.pt')
        self.to('cpu')
        scripted = torch.jit.script(self)
        scripted.save(target)