"""Smoke tests for the simulation environments (Aloha, Xarm, PushT) and the env factory."""

import pytest
from tensordict import TensorDict
import torch
from torchrl.envs.utils import check_env_specs, step_mdp
from lerobot.common.datasets.factory import make_dataset
import gymnasium as gym
from gymnasium.utils.env_checker import check_env

from lerobot.common.envs.aloha.env import AlohaEnv
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.pusht.env import PushtEnv
from lerobot.common.envs.simxarm.env import SimxarmEnv
from lerobot.common.utils import init_hydra_config

from .utils import DEVICE, DEFAULT_CONFIG_PATH


def print_spec_rollout(env):
    """Debug helper: print the specs of a torchrl env and a short random rollout.

    Prints the observation/action/reward/done specs, the tensordict returned by
    ``reset()``, one ``rand_step``, and a 100-step rollout driven by random
    actions sampled from ``env.action_spec``.
    """
    print("observation_spec:", env.observation_spec)
    print("action_spec:", env.action_spec)
    print("reward_spec:", env.reward_spec)
    print("done_spec:", env.done_spec)

    td = env.reset()
    print("reset tensordict", td)

    td = env.rand_step(td)
    print("random step tensordict", td)

    def simple_rollout(steps=100):
        # Preallocate an empty TensorDict with batch size [steps]; rows are filled in below.
        data = TensorDict({}, [steps])
        _data = env.reset()
        for i in range(steps):
            _data["action"] = env.action_spec.rand()
            _data = env.step(_data)
            data[i] = _data
            # Move the "next" entries to the root so the next step starts from them;
            # keep_other=True retains the remaining (non-"next") entries as well.
            _data = step_mdp(_data, keep_other=True)
        return data

    print("data from rollout:", simple_rollout(100))


@pytest.mark.parametrize(
    "task,from_pixels,pixels_only",
    [
        ("sim_insertion", True, False),
        ("sim_insertion", True, True),
        ("sim_transfer_cube", True, False),
        ("sim_transfer_cube", True, True),
    ],
)
def test_aloha(task, from_pixels, pixels_only):
    """Check that AlohaEnv's specs are consistent with its actual outputs."""
    env = AlohaEnv(
        task,
        from_pixels=from_pixels,
        pixels_only=pixels_only,
        image_size=[3, 480, 640] if from_pixels else None,
    )
    # print_spec_rollout(env)
    check_env_specs(env)


@pytest.mark.parametrize(
    # Argument names here must match the test function's parameters exactly;
    # previously "task" was used while the function takes "env_task", which
    # made pytest error out at collection time.
    "env_task, obs_type",
    [
        ("XarmLift-v0", "state"),
        ("XarmLift-v0", "pixels"),
        ("XarmLift-v0", "pixels_agent_pos"),
        # TODO(aliberts): Add simxarm other tasks
    ],
)
def test_xarm(env_task, obs_type):
    """Run gymnasium's env checker on the gym_xarm environment for each observation type."""
    # Imported for its side effect: presumably registers the "gym_xarm/..." ids
    # that gym.make resolves below.
    import gym_xarm  # noqa: F401

    env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type)
    # Legacy torchrl-based version of this test, kept for reference:
    # env = SimxarmEnv(
    #     task,
    #     from_pixels=from_pixels,
    #     pixels_only=pixels_only,
    #     image_size=84 if from_pixels else None,
    # )
    # print_spec_rollout(env)
    # check_env_specs(env)
    try:
        check_env(env)
    finally:
        # Release simulator/rendering resources even if the checker fails.
        env.close()


@pytest.mark.parametrize(
    "from_pixels,pixels_only",
    [
        (True, False),
    ],
)
def test_pusht(from_pixels, pixels_only):
    """Check that PushtEnv's specs are consistent with its actual outputs."""
    env = PushtEnv(
        from_pixels=from_pixels,
        pixels_only=pixels_only,
        image_size=96 if from_pixels else None,
    )
    # print_spec_rollout(env)
    check_env_specs(env)


@pytest.mark.parametrize(
    "env_name",
    [
        "simxarm",
        "pusht",
        "aloha",
    ],
)
def test_factory(env_name):
    """Check envs built by ``make_env`` from the hydra config.

    Without a transform, image observations must be uint8. With the dataset's
    transform applied, they must be float32 in [0, 1].
    """
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[f"env={env_name}", f"device={DEVICE}"],
    )

    dataset = make_dataset(cfg)

    env = make_env(cfg)
    # Reset once and inspect every image key on the same observation.
    obs = env.reset()
    for key in dataset.image_keys:
        assert obs.get(key).dtype == torch.uint8
    check_env_specs(env)

    env = make_env(cfg, transform=dataset.transform)
    obs = env.reset()
    for key in dataset.image_keys:
        img = obs.get(key)
        assert img.dtype == torch.float32
        # TODO(rcadene): we assume for now that image normalization takes place in the model
        assert img.max() <= 1.0
        assert img.min() >= 0.0
    check_env_specs(env)