Merge pull request #10 from Cadene/user/rcadene/2024_03_06_fix_tests
Fix env tests
This commit is contained in:
commit
4cc7e1539e
|
@ -87,6 +87,10 @@ jobs:
|
||||||
#----------------------------------------------
|
#----------------------------------------------
|
||||||
# run tests
|
# run tests
|
||||||
#----------------------------------------------
|
#----------------------------------------------
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate
|
||||||
|
pytest tests
|
||||||
- name: Test train pusht end-to-end
|
- name: Test train pusht end-to-end
|
||||||
run: |
|
run: |
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
|
|
|
@ -8,13 +8,13 @@ import pymunk
|
||||||
import torch
|
import torch
|
||||||
import torchrl
|
import torchrl
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||||
|
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||||
from torchrl.data.replay_buffers.writers import Writer
|
from torchrl.data.replay_buffers.writers import Writer
|
||||||
|
|
||||||
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
|
||||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
|
||||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||||
|
|
||||||
|
@ -111,7 +111,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _download_and_preproc(self):
|
def _download_and_preproc(self):
|
||||||
raw_dir = self.data_dir / "raw"
|
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||||
if not zarr_path.is_dir():
|
if not zarr_path.is_dir():
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
|
@ -9,6 +9,7 @@ def make_env(cfg, transform=None):
|
||||||
"image_size": cfg.env.image_size,
|
"image_size": cfg.env.image_size,
|
||||||
# TODO(rcadene): do we want a specific eval_env_seed?
|
# TODO(rcadene): do we want a specific eval_env_seed?
|
||||||
"seed": cfg.seed,
|
"seed": cfg.seed,
|
||||||
|
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.env.name == "simxarm":
|
if cfg.env.name == "simxarm":
|
||||||
|
|
|
@ -2,6 +2,7 @@ import importlib
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import einops
|
||||||
import torch
|
import torch
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torchrl.data.tensor_specs import (
|
from torchrl.data.tensor_specs import (
|
||||||
|
@ -28,7 +29,7 @@ class PushtEnv(EnvBase):
|
||||||
image_size=None,
|
image_size=None,
|
||||||
seed=1337,
|
seed=1337,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
num_prev_obs=1,
|
num_prev_obs=0,
|
||||||
num_prev_action=0,
|
num_prev_action=0,
|
||||||
):
|
):
|
||||||
super().__init__(device=device, batch_size=[])
|
super().__init__(device=device, batch_size=[])
|
||||||
|
@ -65,7 +66,8 @@ class PushtEnv(EnvBase):
|
||||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||||
if self.num_prev_action > 0:
|
if self.num_prev_action > 0:
|
||||||
self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
raise NotImplementedError()
|
||||||
|
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||||
|
|
||||||
def render(self, mode="rgb_array", width=384, height=384):
|
def render(self, mode="rgb_array", width=384, height=384):
|
||||||
if width != height:
|
if width != height:
|
||||||
|
@ -133,7 +135,7 @@ class PushtEnv(EnvBase):
|
||||||
sum_reward = 0
|
sum_reward = 0
|
||||||
|
|
||||||
if action.ndim == 1:
|
if action.ndim == 1:
|
||||||
action = action.repeat(self.frame_skip, 1)
|
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||||
else:
|
else:
|
||||||
if self.frame_skip > 1:
|
if self.frame_skip > 1:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@ -172,7 +174,7 @@ class PushtEnv(EnvBase):
|
||||||
if self.from_pixels:
|
if self.from_pixels:
|
||||||
image_shape = (3, self.image_size, self.image_size)
|
image_shape = (3, self.image_size, self.image_size)
|
||||||
if self.num_prev_obs > 0:
|
if self.num_prev_obs > 0:
|
||||||
image_shape = (self.num_prev_obs, *image_shape)
|
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||||
|
|
||||||
obs["image"] = BoundedTensorSpec(
|
obs["image"] = BoundedTensorSpec(
|
||||||
low=0,
|
low=0,
|
||||||
|
@ -184,12 +186,12 @@ class PushtEnv(EnvBase):
|
||||||
if not self.pixels_only:
|
if not self.pixels_only:
|
||||||
state_shape = self._env.observation_space["agent_pos"].shape
|
state_shape = self._env.observation_space["agent_pos"].shape
|
||||||
if self.num_prev_obs > 0:
|
if self.num_prev_obs > 0:
|
||||||
state_shape = (self.num_prev_obs, *state_shape)
|
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||||
|
|
||||||
obs["state"] = BoundedTensorSpec(
|
obs["state"] = BoundedTensorSpec(
|
||||||
low=0,
|
low=0,
|
||||||
high=512,
|
high=512,
|
||||||
shape=self._env.observation_space["agent_pos"].shape,
|
shape=state_shape,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
@ -197,11 +199,11 @@ class PushtEnv(EnvBase):
|
||||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||||
state_shape = self._env.observation_space["observation"].shape
|
state_shape = self._env.observation_space["observation"].shape
|
||||||
if self.num_prev_obs > 0:
|
if self.num_prev_obs > 0:
|
||||||
state_shape = (self.num_prev_obs, *state_shape)
|
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||||
|
|
||||||
obs["state"] = UnboundedContinuousTensorSpec(
|
obs["state"] = UnboundedContinuousTensorSpec(
|
||||||
# TODO:
|
# TODO:
|
||||||
shape=self._env.observation_space["observation"].shape,
|
shape=state_shape,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,12 +6,18 @@ from .utils import init_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_name",
|
"env_name,dataset_id",
|
||||||
[
|
[
|
||||||
"simxarm",
|
# TODO(rcadene): simxarm is depreciated for now
|
||||||
"pusht",
|
# ("simxarm", "lift"),
|
||||||
|
("pusht", "pusht"),
|
||||||
|
# TODO(aliberts): add aloha when dataset is available on hub
|
||||||
|
# ("aloha", "sim_insertion_human"),
|
||||||
|
# ("aloha", "sim_insertion_scripted"),
|
||||||
|
# ("aloha", "sim_transfer_cube_human"),
|
||||||
|
# ("aloha", "sim_transfer_cube_scripted"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_factory(env_name):
|
def test_factory(env_name, dataset_id):
|
||||||
cfg = init_config(overrides=[f"env={env_name}"])
|
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
|
|
|
@ -36,6 +36,7 @@ def print_spec_rollout(env):
|
||||||
print("data from rollout:", simple_rollout(100))
|
print("data from rollout:", simple_rollout(100))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Simxarm is deprecated")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"task,from_pixels,pixels_only",
|
"task,from_pixels,pixels_only",
|
||||||
[
|
[
|
||||||
|
@ -80,7 +81,7 @@ def test_pusht(from_pixels, pixels_only):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_name",
|
"env_name",
|
||||||
[
|
[
|
||||||
"simxarm",
|
# "simxarm",
|
||||||
"pusht",
|
"pusht",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue