# File listing metadata: 138 lines, 5.2 KiB, Python, executable file
# License: see [LICENSE, LICENSES/legged_gym/LICENSE]
|
|
|
|
import sys
|
|
|
|
import gym
|
|
import torch
|
|
from isaacgym import gymapi, gymutil
|
|
|
|
from gym import spaces
|
|
import numpy as np
|
|
|
|
|
|
# Base class for RL tasks
|
|
class BaseTask(gym.Env):
    """Base class for RL tasks running on Isaac Gym.

    The constructor stores the simulation settings, allocates the
    per-environment buffers, calls ``self.create_sim()`` and
    ``gym.prepare_sim(self.sim)``, and opens a viewer unless ``headless`` is
    set. ``create_sim`` is not defined in this class and ``self.sim`` is not
    assigned here, so subclasses are expected to provide both. Subclasses
    must also implement ``reset_idx`` and ``step``.
    """

    def __init__(self, cfg, sim_params, physics_engine, sim_device, headless, eval_cfg=None):
        """Set up devices, buffers, the simulation and (optionally) a viewer.

        Args:
            cfg: Task configuration. This constructor reads
                ``cfg.env.num_observations``, ``cfg.env.num_privileged_obs``,
                ``cfg.env.num_actions`` and ``cfg.env.num_envs``.
            sim_params: Simulation parameters; stored as-is. Its
                ``use_gpu_pipeline`` flag is read to choose the tensor device.
            physics_engine: Physics engine identifier. The string
                ``"SIM_PHYSX"`` is translated to ``gymapi.SIM_PHYSX``; any
                other value is stored unchanged.
            sim_device: Device string, parsed with
                ``gymutil.parse_device_str``.
            headless: When falsy, a viewer is created and keyboard shortcuts
                are registered.
            eval_cfg: Optional evaluation configuration. When given, its
                ``env.num_envs`` environments are added on top of the
                training environments.
        """
        self.gym = gymapi.acquire_gym()

        # Accept the string name as an alias for the gymapi constant.
        if isinstance(physics_engine, str) and physics_engine == "SIM_PHYSX":
            physics_engine = gymapi.SIM_PHYSX

        self.sim_params = sim_params
        self.physics_engine = physics_engine
        self.sim_device = sim_device
        sim_device_type, self.sim_device_id = gymutil.parse_device_str(self.sim_device)
        self.headless = headless

        # env device is GPU only if sim is on GPU and use_gpu_pipeline=True,
        # otherwise returned tensors are copied to CPU by physX.
        if sim_device_type == 'cuda' and sim_params.use_gpu_pipeline:
            self.device = self.sim_device
        else:
            self.device = 'cpu'

        # Graphics device for rendering. The simulation device is used in both
        # windowed and headless mode (the former headless branch assigned the
        # very same value), so the id is never switched to -1 ("no rendering").
        self.graphics_device_id = self.sim_device_id

        self.num_obs = cfg.env.num_observations
        self.num_privileged_obs = cfg.env.num_privileged_obs
        self.num_actions = cfg.env.num_actions

        # Evaluation environments, if any, are counted on top of the training ones.
        self.num_train_envs = cfg.env.num_envs
        self.num_eval_envs = 0 if eval_cfg is None else eval_cfg.env.num_envs
        self.num_envs = self.num_eval_envs + self.num_train_envs

        # optimization flags for pytorch JIT
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)

        # allocate buffers (one row / entry per environment)
        self.obs_buf = torch.zeros(self.num_envs, self.num_obs, device=self.device, dtype=torch.float)
        self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
        self.rew_buf_pos = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
        self.rew_buf_neg = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
        # reset_buf starts at one for every environment.
        self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long)
        self.episode_length_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)
        self.time_out_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
        if self.num_privileged_obs is not None:
            self.privileged_obs_buf = torch.zeros(self.num_envs, self.num_privileged_obs,
                                                  device=self.device, dtype=torch.float)
        else:
            # No privileged observations configured: skip the buffer instead of
            # failing inside torch.zeros; get_privileged_observations() then
            # returns None.
            self.privileged_obs_buf = None

        self.extras = {}

        # create envs, sim and viewer
        self.create_sim()
        self.gym.prepare_sim(self.sim)

        # todo: read from config
        self.enable_viewer_sync = True
        self.viewer = None

        # if running with a viewer, set up keyboard shortcuts and camera
        if not self.headless:
            self.viewer = self.gym.create_viewer(self.sim, gymapi.CameraProperties())
            # subscribe to keyboard shortcuts
            self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_ESCAPE, "QUIT")
            self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_V, "toggle_viewer_sync")

    def get_observations(self):
        """Return the observation buffer (``self.obs_buf``)."""
        return self.obs_buf

    def get_privileged_observations(self):
        """Return the privileged observation buffer, or None if none was allocated."""
        return self.privileged_obs_buf

    def reset_idx(self, env_ids):
        """Reset selected robots. Must be implemented by subclasses."""
        raise NotImplementedError

    def reset(self):
        """Reset all robots.

        Calls ``reset_idx`` for every environment index, then performs one
        ``step`` with all-zero actions.

        Returns:
            The first two elements (observations, privileged observations) of
            the 5-tuple returned by ``step``.
        """
        self.reset_idx(torch.arange(self.num_envs, device=self.device))
        zero_actions = torch.zeros(self.num_envs, self.num_actions, device=self.device, requires_grad=False)
        obs, privileged_obs, _, _, _ = self.step(zero_actions)
        return obs, privileged_obs

    def step(self, actions):
        """Advance the task with ``actions``. Must be implemented by subclasses."""
        raise NotImplementedError

    def render_gui(self, sync_frame_time=True):
        """Process viewer events and draw a frame; does nothing without a viewer.

        Exits the process (``sys.exit``) when the viewer window has been closed
        or a "QUIT" action event with a positive value is received. A positive
        "toggle_viewer_sync" event flips ``self.enable_viewer_sync``.

        Args:
            sync_frame_time: When true and viewer sync is enabled,
                ``gym.sync_frame_time`` is called after drawing.
        """
        if not self.viewer:
            return

        # check for window closed
        if self.gym.query_viewer_has_closed(self.viewer):
            sys.exit()

        # check for keyboard events
        for evt in self.gym.query_viewer_action_events(self.viewer):
            if evt.action == "QUIT" and evt.value > 0:
                sys.exit()
            elif evt.action == "toggle_viewer_sync" and evt.value > 0:
                self.enable_viewer_sync = not self.enable_viewer_sync

        # fetch results
        if self.device != 'cpu':
            self.gym.fetch_results(self.sim, True)

        # step graphics
        if self.enable_viewer_sync:
            self.gym.step_graphics(self.sim)
            self.gym.draw_viewer(self.viewer, self.sim, True)
            if sync_frame_time:
                self.gym.sync_frame_time(self.sim)
        else:
            # Sync disabled: only poll events, without drawing.
            self.gym.poll_viewer_events(self.viewer)

    def close(self):
        """Destroy the viewer (if one was created) and the simulation."""
        # Key the teardown on the actual handle rather than the headless flag,
        # and clear it so a repeated close() does not destroy the viewer twice.
        if self.viewer is not None:
            self.gym.destroy_viewer(self.viewer)
            self.viewer = None
        self.gym.destroy_sim(self.sim)
|