# Go2Py_SIM/Go2Py/sim/gym/envs/base/base_task.py
#
# 138 lines
# 5.2 KiB
# Python
# Executable File

# License: see [LICENSE, LICENSES/legged_gym/LICENSE]
import sys
import gym
import torch
from isaacgym import gymapi, gymutil
from gym import spaces
import numpy as np
# Base class for RL tasks
class BaseTask(gym.Env):
    """Base class for vectorized RL tasks running on Isaac Gym.

    The constructor acquires the gym handle, resolves the simulation and
    tensor devices, allocates the per-environment buffers, builds the
    simulation and, unless running headless, opens a viewer.

    Subclasses are expected to provide:

    * ``create_sim()`` -- called from ``__init__``; it must set ``self.sim``,
      which is used right afterwards by ``prepare_sim`` and the viewer.
    * ``reset_idx(env_ids)`` and ``step(actions)`` -- both raise
      ``NotImplementedError`` here.
    """

    def __init__(self, cfg, sim_params, physics_engine, sim_device, headless, eval_cfg=None):
        """Set up devices, buffers, the simulation and the optional viewer.

        Args:
            cfg: Task configuration. Reads ``cfg.env.num_observations``,
                ``cfg.env.num_privileged_obs``, ``cfg.env.num_actions`` and
                ``cfg.env.num_envs``.
            sim_params: Simulation parameters; ``use_gpu_pipeline`` is read
                here to choose the tensor device.
            physics_engine: Physics engine identifier. The string
                ``"SIM_PHYSX"`` is translated to ``gymapi.SIM_PHYSX``; any
                other value is stored unchanged.
            sim_device: Device string parsed with
                ``gymutil.parse_device_str``.
            headless: When truthy no viewer is created. Expected to be a
                ``bool``.
            eval_cfg: Optional evaluation configuration. When given,
                ``eval_cfg.env.num_envs`` evaluation environments are added
                on top of the training environments.
        """
        self.gym = gymapi.acquire_gym()

        # Accept the engine by name as well as by enum value.
        if isinstance(physics_engine, str) and physics_engine == "SIM_PHYSX":
            physics_engine = gymapi.SIM_PHYSX

        self.sim_params = sim_params
        self.physics_engine = physics_engine
        self.sim_device = sim_device
        sim_device_type, self.sim_device_id = gymutil.parse_device_str(self.sim_device)
        self.headless = headless

        # env device is GPU only if sim is on GPU and use_gpu_pipeline=True,
        # otherwise returned tensors are copied to CPU by physX.
        if sim_device_type == 'cuda' and sim_params.use_gpu_pipeline:
            self.device = self.sim_device
        else:
            self.device = 'cpu'

        # Graphics device for rendering. The simulation device id is used
        # whether or not the task runs headless (the previous code had a
        # headless branch that assigned this very same value, so it was a
        # no-op and has been folded away).
        # NOTE(review): presumably headless runs still need a real graphics
        # device (e.g. for camera rendering) -- confirm before switching
        # headless mode to -1 (no rendering).
        self.graphics_device_id = self.sim_device_id

        self.num_obs = cfg.env.num_observations
        self.num_privileged_obs = cfg.env.num_privileged_obs
        self.num_actions = cfg.env.num_actions

        # Evaluation environments, if any, are counted in addition to the
        # training environments.
        self.num_train_envs = cfg.env.num_envs
        self.num_eval_envs = eval_cfg.env.num_envs if eval_cfg is not None else 0
        self.num_envs = self.num_eval_envs + self.num_train_envs

        # optimization flags for pytorch JIT
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)

        # allocate buffers (one row / entry per environment)
        self.obs_buf = torch.zeros(self.num_envs, self.num_obs, device=self.device, dtype=torch.float)
        self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
        self.rew_buf_pos = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
        self.rew_buf_neg = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
        # reset_buf starts at one for every environment.
        self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long)
        self.episode_length_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)
        self.time_out_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
        if self.num_privileged_obs is not None:
            self.privileged_obs_buf = torch.zeros(self.num_envs, self.num_privileged_obs,
                                                  device=self.device, dtype=torch.float)
        else:
            # Without a privileged observation size there is nothing to
            # allocate; torch.zeros would raise on a None dimension.
            self.privileged_obs_buf = None

        self.extras = {}

        # create envs, sim and viewer
        self.create_sim()
        self.gym.prepare_sim(self.sim)

        # todo: read from config
        self.enable_viewer_sync = True
        self.viewer = None

        # if running with a viewer, set up keyboard shortcuts and camera
        if not self.headless:
            self.viewer = self.gym.create_viewer(self.sim, gymapi.CameraProperties())
            # subscribe to keyboard shortcuts
            self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_ESCAPE, "QUIT")
            self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_V, "toggle_viewer_sync")

    def get_observations(self):
        """Return the observation buffer (``self.obs_buf``)."""
        return self.obs_buf

    def get_privileged_observations(self):
        """Return the privileged observation buffer.

        This is ``None`` when ``cfg.env.num_privileged_obs`` was ``None``.
        """
        return self.privileged_obs_buf

    def reset_idx(self, env_ids):
        """Reset selected robots. Must be implemented by subclasses."""
        raise NotImplementedError

    def reset(self):
        """Reset all robots.

        Calls ``reset_idx`` for every environment index and then ``step``
        once with an all-zero action tensor.

        Returns:
            The ``(obs, privileged_obs)`` pair taken from the first two
            values returned by that ``step`` call.
        """
        all_env_ids = torch.arange(self.num_envs, device=self.device)
        self.reset_idx(all_env_ids)

        zero_actions = torch.zeros(self.num_envs, self.num_actions,
                                   device=self.device, requires_grad=False)
        # step() is expected to return a 5-tuple; only the first two
        # entries are used here.
        obs, privileged_obs, _, _, _ = self.step(zero_actions)
        return obs, privileged_obs

    def step(self, actions):
        """Advance the task by one step. Must be implemented by subclasses."""
        raise NotImplementedError

    def render_gui(self, sync_frame_time=True):
        """Process viewer events and redraw the viewer, if one exists.

        Exits the process via ``sys.exit()`` when the viewer window was
        closed or the "QUIT" action fired. The "toggle_viewer_sync" action
        flips ``self.enable_viewer_sync``.

        Args:
            sync_frame_time: When true (and viewer sync is enabled), call
                ``sync_frame_time`` on the simulation after drawing.
        """
        if not self.viewer:
            return

        # check for window closed
        if self.gym.query_viewer_has_closed(self.viewer):
            sys.exit()

        # check for keyboard events; value > 0 filters for the events that
        # should trigger the action
        for evt in self.gym.query_viewer_action_events(self.viewer):
            if evt.value <= 0:
                continue
            if evt.action == "QUIT":
                sys.exit()
            elif evt.action == "toggle_viewer_sync":
                self.enable_viewer_sync = not self.enable_viewer_sync

        # fetch results
        if self.device != 'cpu':
            self.gym.fetch_results(self.sim, True)

        # step graphics
        if self.enable_viewer_sync:
            self.gym.step_graphics(self.sim)
            self.gym.draw_viewer(self.viewer, self.sim, True)
            if sync_frame_time:
                self.gym.sync_frame_time(self.sim)
        else:
            # Keep the window responsive without redrawing.
            self.gym.poll_viewer_events(self.viewer)

    def close(self):
        """Destroy the viewer (if any) and the simulation.

        Safe to call more than once and on a partially constructed
        instance: each handle is released only if it exists and is cleared
        afterwards, so nothing is destroyed twice.
        """
        viewer = getattr(self, "viewer", None)
        if viewer is not None:
            self.gym.destroy_viewer(viewer)
            self.viewer = None

        sim = getattr(self, "sim", None)
        if sim is not None:
            self.gym.destroy_sim(sim)
            self.sim = None