Merge pull request #10 from Cadene/user/rcadene/2024_03_06_fix_tests

Fix env tests
This commit is contained in:
Simon Alibert 2024-03-08 12:30:08 +01:00 committed by GitHub
commit 4cc7e1539e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 31 additions and 17 deletions

View File

@ -87,6 +87,10 @@ jobs:
#---------------------------------------------- #----------------------------------------------
# run tests # run tests
#---------------------------------------------- #----------------------------------------------
- name: Run tests
run: |
source .venv/bin/activate
pytest tests
- name: Test train pusht end-to-end - name: Test train pusht end-to-end
run: | run: |
source .venv/bin/activate source .venv/bin/activate

View File

@ -8,13 +8,13 @@ import pymunk
import torch import torch
import torchrl import torchrl
import tqdm import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer from torchrl.data.replay_buffers.writers import Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.datasets.abstract import AbstractExperienceReplay from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.utils import download_and_extract_zip from lerobot.common.datasets.utils import download_and_extract_zip
@ -111,7 +111,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
) )
def _download_and_preproc(self): def _download_and_preproc(self):
raw_dir = self.data_dir / "raw" raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve() zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir(): if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True) raw_dir.mkdir(parents=True, exist_ok=True)

View File

@ -9,6 +9,7 @@ def make_env(cfg, transform=None):
"image_size": cfg.env.image_size, "image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed? # TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed, "seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
} }
if cfg.env.name == "simxarm": if cfg.env.name == "simxarm":

View File

@ -2,6 +2,7 @@ import importlib
from collections import deque from collections import deque
from typing import Optional from typing import Optional
import einops
import torch import torch
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.tensor_specs import ( from torchrl.data.tensor_specs import (
@ -28,7 +29,7 @@ class PushtEnv(EnvBase):
image_size=None, image_size=None,
seed=1337, seed=1337,
device="cpu", device="cpu",
num_prev_obs=1, num_prev_obs=0,
num_prev_action=0, num_prev_action=0,
): ):
super().__init__(device=device, batch_size=[]) super().__init__(device=device, batch_size=[])
@ -65,7 +66,8 @@ class PushtEnv(EnvBase):
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs) self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0: if self.num_prev_action > 0:
self._prev_action_queue = deque(maxlen=self.num_prev_action) raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def render(self, mode="rgb_array", width=384, height=384): def render(self, mode="rgb_array", width=384, height=384):
if width != height: if width != height:
@ -133,7 +135,7 @@ class PushtEnv(EnvBase):
sum_reward = 0 sum_reward = 0
if action.ndim == 1: if action.ndim == 1:
action = action.repeat(self.frame_skip, 1) action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else: else:
if self.frame_skip > 1: if self.frame_skip > 1:
raise NotImplementedError() raise NotImplementedError()
@ -172,7 +174,7 @@ class PushtEnv(EnvBase):
if self.from_pixels: if self.from_pixels:
image_shape = (3, self.image_size, self.image_size) image_shape = (3, self.image_size, self.image_size)
if self.num_prev_obs > 0: if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs, *image_shape) image_shape = (self.num_prev_obs + 1, *image_shape)
obs["image"] = BoundedTensorSpec( obs["image"] = BoundedTensorSpec(
low=0, low=0,
@ -184,12 +186,12 @@ class PushtEnv(EnvBase):
if not self.pixels_only: if not self.pixels_only:
state_shape = self._env.observation_space["agent_pos"].shape state_shape = self._env.observation_space["agent_pos"].shape
if self.num_prev_obs > 0: if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape) state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = BoundedTensorSpec( obs["state"] = BoundedTensorSpec(
low=0, low=0,
high=512, high=512,
shape=self._env.observation_space["agent_pos"].shape, shape=state_shape,
dtype=torch.float32, dtype=torch.float32,
device=self.device, device=self.device,
) )
@ -197,11 +199,11 @@ class PushtEnv(EnvBase):
# TODO(rcadene): add observation_space achieved_goal and desired_goal? # TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = self._env.observation_space["observation"].shape state_shape = self._env.observation_space["observation"].shape
if self.num_prev_obs > 0: if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape) state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec( obs["state"] = UnboundedContinuousTensorSpec(
# TODO: # TODO:
shape=self._env.observation_space["observation"].shape, shape=state_shape,
dtype=torch.float32, dtype=torch.float32,
device=self.device, device=self.device,
) )

View File

@ -6,12 +6,18 @@ from .utils import init_config
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name", "env_name,dataset_id",
[ [
"simxarm", # TODO(rcadene): simxarm is deprecated for now
"pusht", # ("simxarm", "lift"),
("pusht", "pusht"),
# TODO(aliberts): add aloha when dataset is available on hub
# ("aloha", "sim_insertion_human"),
# ("aloha", "sim_insertion_scripted"),
# ("aloha", "sim_transfer_cube_human"),
# ("aloha", "sim_transfer_cube_scripted"),
], ],
) )
def test_factory(env_name): def test_factory(env_name, dataset_id):
cfg = init_config(overrides=[f"env={env_name}"]) cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
offline_buffer = make_offline_buffer(cfg) offline_buffer = make_offline_buffer(cfg)

View File

@ -36,6 +36,7 @@ def print_spec_rollout(env):
print("data from rollout:", simple_rollout(100)) print("data from rollout:", simple_rollout(100))
@pytest.mark.skip(reason="Simxarm is deprecated")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"task,from_pixels,pixels_only", "task,from_pixels,pixels_only",
[ [
@ -80,7 +81,7 @@ def test_pusht(from_pixels, pixels_only):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name", "env_name",
[ [
"simxarm", # "simxarm",
"pusht", "pusht",
], ],
) )