52 lines
2.1 KiB
Python
Executable File
52 lines
2.1 KiB
Python
Executable File
# import isaacgym
|
|
|
|
# assert isaacgym, "import isaacgym before pytorch"
|
|
import torch
|
|
|
|
|
|
class HistoryWrapper:
    """Wrap a vectorized environment and maintain a rolling observation history.

    ``self.obs_history`` has shape
    ``(env.num_envs, obs_history_length * env.num_obs)``. Each push
    (``step``, ``get_observations``, ``get_obs``) drops the oldest
    ``env.num_obs`` columns on the left and appends the new observation on the
    right, so the most recent frame is always last. Pushed observations are
    expected to be ``(env.num_envs, env.num_obs)`` tensors; with any other
    width the history width would drift.

    Any attribute not found on the wrapper is looked up on the wrapped env.

    Args:
        env: Wrapped environment. It must provide ``cfg`` (either a dict with
            ``cfg["env"]["num_observation_history"]`` or an object with
            ``cfg.env.num_observation_history``), ``num_obs``, ``num_envs``,
            ``device`` and ``num_privileged_obs``.
        reset_history_on_done: Keyword-only. If True, ``step`` zeroes the
            history rows of environments whose ``done`` flag is set before
            appending the new observation, so a fresh episode does not carry
            frames from the previous one. Defaults to False (original
            behavior).
    """

    def __init__(self, env, *, reset_history_on_done=False):
        self.env = env
        self.reset_history_on_done = reset_history_on_done

        # The config may be a plain nested dict or an attribute-style object.
        if isinstance(self.env.cfg, dict):
            self.obs_history_length = self.env.cfg["env"]["num_observation_history"]
        else:
            self.obs_history_length = self.env.cfg.env.num_observation_history

        self.num_obs_history = self.obs_history_length * self.env.num_obs
        self.obs_history = torch.zeros(self.env.num_envs, self.num_obs_history, dtype=torch.float,
                                       device=self.env.device, requires_grad=False)
        self.num_privileged_obs = self.env.num_privileged_obs

    def _append_to_history(self, obs, reset_mask=None):
        """Push ``obs`` as the newest frame, optionally zeroing masked rows first.

        Args:
            obs: New observation, expected ``(num_envs, env.num_obs)``.
            reset_mask: Optional per-env flags; rows where it is truthy are
                cleared before ``obs`` is appended. Assumed to hold one flag
                per env.

        Returns:
            The new ``self.obs_history`` tensor.
        """
        history = self.obs_history
        if reset_mask is not None:
            mask = torch.as_tensor(reset_mask, device=history.device).bool().reshape(-1, 1)
            # Out-of-place, so the history tensor handed out by the previous
            # call is not modified behind the caller's back.
            history = history.masked_fill(mask, 0)
        # Drop the oldest frame (leftmost num_obs columns) and append the newest.
        self.obs_history = torch.cat((history[:, self.env.num_obs:], obs), dim=-1)
        return self.obs_history

    def step(self, action):
        """Step the env and push the returned observation into the history.

        Returns:
            ``(obs_dict, rew, done, info)``. ``obs_dict`` has keys ``'obs'``,
            ``'privileged_obs'`` (taken from ``info["privileged_obs"]``) and
            ``'obs_history'``. ``rew``, ``done`` and ``info`` are passed
            through unchanged from ``env.step``.

        Raises:
            KeyError: If ``info`` has no ``"privileged_obs"`` entry.
        """
        obs, rew, done, info = self.env.step(action)
        privileged_obs = info["privileged_obs"]

        # NOTE(review): presumes ``done`` holds one flag per env, as returned by
        # the wrapped env's step(); only consulted when the option is enabled.
        reset_mask = done if self.reset_history_on_done else None
        obs_history = self._append_to_history(obs, reset_mask)
        return {'obs': obs, 'privileged_obs': privileged_obs, 'obs_history': obs_history}, rew, done, info

    def get_observations(self):
        """Fetch the env's current observation and push it into the history.

        Every call pushes a frame, so calling this in addition to ``step``
        adds an extra frame to the history.
        """
        obs = self.env.get_observations()
        privileged_obs = self.env.get_privileged_observations()
        obs_history = self._append_to_history(obs)
        return {'obs': obs, 'privileged_obs': privileged_obs, 'obs_history': obs_history}

    def get_obs(self):
        """Same as ``get_observations`` but reads the observation via ``env.get_obs()``.

        Every call pushes a frame into the history.
        """
        obs = self.env.get_obs()
        privileged_obs = self.env.get_privileged_observations()
        obs_history = self._append_to_history(obs)
        return {'obs': obs, 'privileged_obs': privileged_obs, 'obs_history': obs_history}

    def reset_idx(self, env_ids):
        """Reset the given envs via ``env.reset_idx`` and zero their history rows.

        Returns whatever ``env.reset_idx`` returns.
        """
        # NOTE(review): if the wrapped env resets environments on its own (for
        # example inside env.step) it does not go through this method, so their
        # history is not cleared here; ``reset_history_on_done`` covers that
        # case inside ``step``.
        ret = self.env.reset_idx(env_ids)
        self.obs_history[env_ids, :] = 0
        return ret

    def reset(self):
        """Reset every env via ``env.reset`` and clear the whole history.

        Returns:
            Dict with ``'obs'`` (the raw return value of ``env.reset()``),
            ``'privileged_obs'`` (from ``env.get_privileged_observations()``)
            and ``'obs_history'`` (all zeros; the reset observation itself is
            not pushed into it).
        """
        ret = self.env.reset()
        privileged_obs = self.env.get_privileged_observations()
        self.obs_history[:, :] = 0
        return {"obs": ret, "privileged_obs": privileged_obs, "obs_history": self.obs_history}

    def __getattr__(self, name):
        """Forward attribute lookups that fail on the wrapper to the wrapped env."""
        # __getattr__ only runs after normal lookup fails. If ``env`` itself is
        # missing (instance created via __new__, e.g. by copy/pickle before its
        # state is restored), evaluating ``self.env`` would re-enter this method
        # forever; raise AttributeError instead of recursing.
        if name == "env":
            raise AttributeError(f"{type(self).__name__} object has no attribute 'env'")
        return getattr(self.env, name)
|