walk-these-ways-go2/go1_gym_deploy/envs/history_wrapper.py

52 lines
2.1 KiB
Python
Raw Normal View History

2024-01-28 17:11:38 +08:00
# import isaacgym
# assert isaacgym, "import isaacgym before pytorch"
import torch
class HistoryWrapper:
def __init__(self, env):
self.env = env
if isinstance(self.env.cfg, dict):
self.obs_history_length = self.env.cfg["env"]["num_observation_history"]
else:
self.obs_history_length = self.env.cfg.env.num_observation_history
self.num_obs_history = self.obs_history_length * self.env.num_obs
self.obs_history = torch.zeros(self.env.num_envs, self.num_obs_history, dtype=torch.float,
device=self.env.device, requires_grad=False)
self.num_privileged_obs = self.env.num_privileged_obs
def step(self, action):
obs, rew, done, info = self.env.step(action)
privileged_obs = info["privileged_obs"]
self.obs_history = torch.cat((self.obs_history[:, self.env.num_obs:], obs), dim=-1)
return {'obs': obs, 'privileged_obs': privileged_obs, 'obs_history': self.obs_history}, rew, done, info
def get_observations(self):
obs = self.env.get_observations()
privileged_obs = self.env.get_privileged_observations()
self.obs_history = torch.cat((self.obs_history[:, self.env.num_obs:], obs), dim=-1)
return {'obs': obs, 'privileged_obs': privileged_obs, 'obs_history': self.obs_history}
def get_obs(self):
obs = self.env.get_obs()
privileged_obs = self.env.get_privileged_observations()
self.obs_history = torch.cat((self.obs_history[:, self.env.num_obs:], obs), dim=-1)
return {'obs': obs, 'privileged_obs': privileged_obs, 'obs_history': self.obs_history}
def reset_idx(self, env_ids): # it might be a problem that this isn't getting called!!
ret = self.env.reset_idx(env_ids)
self.obs_history[env_ids, :] = 0
return ret
def reset(self):
ret = self.env.reset()
privileged_obs = self.env.get_privileged_observations()
self.obs_history[:, :] = 0
return {"obs": ret, "privileged_obs": privileged_obs, "obs_history": self.obs_history}
def __getattr__(self, name):
return getattr(self.env, name)