# Go2Py_SIM/Go2Py/sim/gym/envs/wrappers/history_wrapper.py
#
# (Repository-viewer metadata: 75 lines, 2.7 KiB, Python, executable file.)

# Isaac Gym refuses to load if torch has already been imported, so import it first.
import isaacgym
assert isaacgym  # reference the import so linters don't drop it as unused
import gym
import torch
class HistoryWrapper(gym.Wrapper):
    """Gym wrapper that adds a rolling observation history to every observation.

    The wrapped env is expected to provide ``cfg.env.num_observation_history``,
    ``num_envs``, ``num_obs``, ``num_privileged_obs``, ``device``,
    ``get_observations()``, ``get_privileged_observations()`` and ``reset_idx``,
    and its ``step`` must return ``(obs, rew, done, info)`` with
    ``info["privileged_obs"]`` populated.

    ``obs_history`` has shape ``(num_envs, num_observation_history * num_obs)``.
    Each update drops the oldest ``num_obs`` columns on the left and appends the
    newest observation on the right.
    """

    def __init__(self, env, *, reset_history_on_done=False):
        """Wrap ``env`` and allocate a zero-filled history buffer.

        Args:
            env: The vectorized environment to wrap.
            reset_history_on_done: If True, ``step`` clears the history of every
                env whose ``done`` flag is set before appending the new
                observation, so a history never spans two episodes. Defaults to
                False, which keeps the historical behaviour. Resets performed
                inside the wrapped env's own ``step`` do not go through this
                wrapper's ``reset_idx``.
        """
        super().__init__(env)
        self.env = env
        self.reset_history_on_done = reset_history_on_done

        self.obs_history_length = self.env.cfg.env.num_observation_history
        # ``num_obs`` is not defined on the wrapper; gym.Wrapper forwards the
        # lookup to the wrapped env.
        self.num_obs_history = self.obs_history_length * self.num_obs
        self.obs_history = torch.zeros(self.env.num_envs, self.num_obs_history, dtype=torch.float,
                                       device=self.env.device, requires_grad=False)
        # Not a no-op: the right-hand side is forwarded to the wrapped env, and
        # the assignment caches that value as an attribute of the wrapper itself.
        self.num_privileged_obs = self.num_privileged_obs

    def _push(self, obs, done=None):
        """Return a new history tensor with the oldest frame dropped and ``obs`` appended.

        If ``done`` is given, rows whose flag is set are zeroed before ``obs``
        is appended. ``done`` is assumed to be a ``(num_envs,)`` tensor
        (TODO confirm against the wrapped env).
        """
        kept = self.obs_history[:, self.env.num_obs:]
        if done is not None:
            # Out-of-place on purpose: the previous history tensor has already
            # been returned to the caller and must not change under it.
            kept = kept.masked_fill(done.bool().unsqueeze(-1), 0)
        return torch.cat((kept, obs), dim=-1)

    def step(self, action):
        """Step the wrapped env and append the new observation to the history.

        Returns:
            ``(obs_dict, rew, done, info)``, where ``obs_dict`` has keys
            ``'obs'``, ``'privileged_obs'`` (read from ``info``) and
            ``'obs_history'``.
        """
        # Privileged information is delivered through ``info`` by the wrapped env.
        obs, rew, done, info = self.env.step(action)
        privileged_obs = info["privileged_obs"]
        self.obs_history = self._push(obs, done if self.reset_history_on_done else None)
        return {'obs': obs, 'privileged_obs': privileged_obs,
                'obs_history': self.obs_history}, rew, done, info

    def get_observations(self):
        """Fetch the current observations and push them onto the history.

        Note that every call shifts the history, so calling this repeatedly
        without stepping records the same observation several times.
        """
        obs = self.env.get_observations()
        privileged_obs = self.env.get_privileged_observations()
        self.obs_history = self._push(obs)
        return {'obs': obs, 'privileged_obs': privileged_obs, 'obs_history': self.obs_history}

    def reset_idx(self, env_ids):
        """Reset ``env_ids`` on the wrapped env and zero their history rows.

        Returns whatever the wrapped env's ``reset_idx`` returns.
        """
        # Call the wrapped env directly. gym.Wrapper defines no ``reset_idx``,
        # and attribute lookup through ``super()`` bypasses
        # ``Wrapper.__getattr__``, so ``super().reset_idx`` raises AttributeError.
        ret = self.env.reset_idx(env_ids)
        self.obs_history[env_ids, :] = 0
        return ret

    def reset(self):
        """Reset all envs, clear the whole history and return the observation dict."""
        # Assumes the wrapped env's reset returns the observation tensor (TODO confirm).
        ret = super().reset()
        privileged_obs = self.env.get_privileged_observations()
        self.obs_history[:, :] = 0
        return {"obs": ret, "privileged_obs": privileged_obs, "obs_history": self.obs_history}
if __name__ == "__main__":
    # Smoke test: build the mini-cheetah velocity-tracking env, wrap it, take a
    # few zero-action steps and print/render the wrapped observations.
    from tqdm import trange
    import matplotlib.pyplot as plt

    from Go2Py.sim.gym.envs.base.legged_robot_config import Cfg
    from Go2Py.sim.gym.envs.mini_cheetah.mini_cheetah_config import config_mini_cheetah

    config_mini_cheetah(Cfg)
    test_env = gym.make("VelocityTrackingEasyEnv-v0", cfg=Cfg)
    env = HistoryWrapper(test_env)
    env.reset()

    # Zero command for the 12 actuated joints, allocated on the env's device.
    action = torch.zeros(test_env.num_envs, 12, device=test_env.device)
    for _ in trange(3):
        obs, rew, done, info = env.step(action)
        print(obs.keys())
        print(f"obs: {obs['obs']}")
        print(f"privileged obs: {obs['privileged_obs']}")
        print(f"obs_history: {obs['obs_history']}")

        img = env.render('rgb_array')
        plt.imshow(img)
        plt.show()