fix tests

This commit is contained in:
Remi Cadene 2024-03-06 13:55:12 +00:00 committed by Simon Alibert
parent c2c0ef9927
commit 524d29aa80
4 changed files with 23 additions and 14 deletions

View File

@ -9,6 +9,7 @@ def make_env(cfg, transform=None):
"image_size": cfg.env.image_size, "image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed? # TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed, "seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
} }
if cfg.env.name == "simxarm": if cfg.env.name == "simxarm":

View File

@ -2,6 +2,7 @@ import importlib
from collections import deque from collections import deque
from typing import Optional from typing import Optional
import einops
import torch import torch
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.tensor_specs import ( from torchrl.data.tensor_specs import (
@ -28,7 +29,7 @@ class PushtEnv(EnvBase):
image_size=None, image_size=None,
seed=1337, seed=1337,
device="cpu", device="cpu",
num_prev_obs=1, num_prev_obs=0,
num_prev_action=0, num_prev_action=0,
): ):
super().__init__(device=device, batch_size=[]) super().__init__(device=device, batch_size=[])
@ -65,7 +66,8 @@ class PushtEnv(EnvBase):
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs) self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0: if self.num_prev_action > 0:
self._prev_action_queue = deque(maxlen=self.num_prev_action) raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def render(self, mode="rgb_array", width=384, height=384): def render(self, mode="rgb_array", width=384, height=384):
if width != height: if width != height:
@ -133,7 +135,7 @@ class PushtEnv(EnvBase):
sum_reward = 0 sum_reward = 0
if action.ndim == 1: if action.ndim == 1:
action = action.repeat(self.frame_skip, 1) action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else: else:
if self.frame_skip > 1: if self.frame_skip > 1:
raise NotImplementedError() raise NotImplementedError()
@ -172,7 +174,7 @@ class PushtEnv(EnvBase):
if self.from_pixels: if self.from_pixels:
image_shape = (3, self.image_size, self.image_size) image_shape = (3, self.image_size, self.image_size)
if self.num_prev_obs > 0: if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs, *image_shape) image_shape = (self.num_prev_obs + 1, *image_shape)
obs["image"] = BoundedTensorSpec( obs["image"] = BoundedTensorSpec(
low=0, low=0,
@ -184,12 +186,12 @@ class PushtEnv(EnvBase):
if not self.pixels_only: if not self.pixels_only:
state_shape = self._env.observation_space["agent_pos"].shape state_shape = self._env.observation_space["agent_pos"].shape
if self.num_prev_obs > 0: if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape) state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = BoundedTensorSpec( obs["state"] = BoundedTensorSpec(
low=0, low=0,
high=512, high=512,
shape=self._env.observation_space["agent_pos"].shape, shape=state_shape,
dtype=torch.float32, dtype=torch.float32,
device=self.device, device=self.device,
) )
@ -197,11 +199,11 @@ class PushtEnv(EnvBase):
# TODO(rcadene): add observation_space achieved_goal and desired_goal? # TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = self._env.observation_space["observation"].shape state_shape = self._env.observation_space["observation"].shape
if self.num_prev_obs > 0: if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape) state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec( obs["state"] = UnboundedContinuousTensorSpec(
# TODO: # TODO:
shape=self._env.observation_space["observation"].shape, shape=state_shape,
dtype=torch.float32, dtype=torch.float32,
device=self.device, device=self.device,
) )

View File

@ -6,12 +6,17 @@ from .utils import init_config
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name", "env_name,dataset_id",
[ [
"simxarm", # TODO(rcadene): simxarm is depreciated for now
"pusht", # ("simxarm", "lift"),
("pusht", "pusht"),
("aloha", "sim_insertion_human"),
("aloha", "sim_insertion_scripted"),
("aloha", "sim_transfer_cube_human"),
("aloha", "sim_transfer_cube_scripted"),
], ],
) )
def test_factory(env_name): def test_factory(env_name, dataset_id):
cfg = init_config(overrides=[f"env={env_name}"]) cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
offline_buffer = make_offline_buffer(cfg) offline_buffer = make_offline_buffer(cfg)

View File

@ -36,6 +36,7 @@ def print_spec_rollout(env):
print("data from rollout:", simple_rollout(100)) print("data from rollout:", simple_rollout(100))
@pytest.mark.skip(reason="Simxarm is deprecated")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"task,from_pixels,pixels_only", "task,from_pixels,pixels_only",
[ [
@ -80,7 +81,7 @@ def test_pusht(from_pixels, pixels_only):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name", "env_name",
[ [
"simxarm", # "simxarm",
"pusht", "pusht",
], ],
) )