import pytest
import torch
from lerobot.common.datasets.factory import make_dataset
import gymnasium as gym
from gymnasium.utils.env_checker import check_env

from lerobot.common.envs.factory import make_env
from lerobot.common.utils import init_hydra_config

from lerobot.common.envs.utils import preprocess_observation

from .utils import DEVICE, DEFAULT_CONFIG_PATH


@pytest.mark.parametrize(
    "env_task, obs_type",
    [
        # NOTE(review): the "state" observation type is disabled for aloha here;
        # the reason is not visible in this file — confirm before re-enabling.
        # ("AlohaInsertion-v0", "state"),
        ("AlohaInsertion-v0", "pixels"),
        ("AlohaInsertion-v0", "pixels_agent_pos"),
        ("AlohaTransferCube-v0", "pixels"),
        ("AlohaTransferCube-v0", "pixels_agent_pos"),
    ],
)
def test_aloha(env_task, obs_type):
    """Run gymnasium's ``check_env`` on a ``gym_aloha`` task with the given observation type.

    The environment is closed in a ``finally`` clause so that a failing check
    does not leak the resources held by the environment.

    Args:
        env_task: Task id inside the ``gym_aloha`` namespace, e.g. ``"AlohaInsertion-v0"``.
        obs_type: Observation type forwarded to ``gym.make``.
    """
    # Imported only for its side effect; presumably it registers the
    # ``gym_aloha/*`` environment ids used by ``gym.make`` below.
    from lerobot.common.envs import aloha as gym_aloha  # noqa: F401

    env = gym.make(f"gym_aloha/{env_task}", obs_type=obs_type)
    try:
        # Check the raw environment, bypassing wrappers added by gym.make.
        check_env(env.unwrapped)
    finally:
        # Nothing else in this test releases the env, so close it even on failure.
        env.close()



@pytest.mark.parametrize(
    "env_task, obs_type",
    [
        ("XarmLift-v0", "state"),
        ("XarmLift-v0", "pixels"),
        ("XarmLift-v0", "pixels_agent_pos"),
        # TODO(aliberts): Add gym_xarm other tasks
    ],
)
def test_xarm(env_task, obs_type):
    """Run gymnasium's ``check_env`` on a ``gym_xarm`` task with the given observation type.

    The environment is closed in a ``finally`` clause so that a failing check
    does not leak the resources held by the environment.

    Args:
        env_task: Task id inside the ``gym_xarm`` namespace, e.g. ``"XarmLift-v0"``.
        obs_type: Observation type forwarded to ``gym.make``.
    """
    # Imported only for its side effect; presumably it registers the
    # ``gym_xarm/*`` environment ids used by ``gym.make`` below.
    import gym_xarm  # noqa: F401

    env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type)
    try:
        # Check the raw environment, bypassing wrappers added by gym.make.
        check_env(env.unwrapped)
    finally:
        # Nothing else in this test releases the env, so close it even on failure.
        env.close()



@pytest.mark.parametrize(
    "env_task, obs_type",
    [
        ("PushTPixels-v0", "state"),
        ("PushTPixels-v0", "pixels"),
        ("PushTPixels-v0", "pixels_agent_pos"),
    ],
)
def test_pusht(env_task, obs_type):
    """Run gymnasium's ``check_env`` on a ``gym_pusht`` task with the given observation type.

    The environment is closed in a ``finally`` clause so that a failing check
    does not leak the resources held by the environment.

    Args:
        env_task: Task id inside the ``gym_pusht`` namespace, e.g. ``"PushTPixels-v0"``.
        obs_type: Observation type forwarded to ``gym.make``.
    """
    # Imported only for its side effect; presumably it registers the
    # ``gym_pusht/*`` environment ids used by ``gym.make`` below.
    import gym_pusht  # noqa: F401

    env = gym.make(f"gym_pusht/{env_task}", obs_type=obs_type)
    try:
        # Check the raw environment, bypassing wrappers added by gym.make.
        check_env(env.unwrapped)
    finally:
        # Nothing else in this test releases the env, so close it even on failure.
        env.close()


@pytest.mark.parametrize(
    "env_name",
    [
        "pusht",
        "simxarm",
        "aloha",
    ],
)
def test_factory(env_name):
    """Check that preprocessed ``make_env`` observations contain float32 images in ``[0, 1]``.

    A hydra config is built for ``env_name``. The dataset and the environment
    are then created from it. The first observation returned by ``env.reset()``
    is passed through ``preprocess_observation`` with the dataset's transform.
    Every key listed in ``dataset.image_keys`` is then checked for dtype and
    value range.

    Args:
        env_name: Value used for the ``env=`` hydra override.
    """
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[f"env={env_name}", f"device={DEVICE}"],
    )

    # The dataset supplies both the transform applied to the observation and
    # the list of image keys that are checked below.
    dataset = make_dataset(cfg)

    env = make_env(cfg, num_parallel_envs=1)
    try:
        obs, _info = env.reset()
        obs = preprocess_observation(obs, transform=dataset.transform)
        for key in dataset.image_keys:
            img = obs[key]
            assert img.dtype == torch.float32, f"{key}: expected torch.float32, got {img.dtype}"
            # TODO(rcadene): we assume for now that image normalization takes place in the model
            assert img.max() <= 1.0, f"{key}: max value {img.max().item()} exceeds 1.0"
            assert img.min() >= 0.0, f"{key}: min value {img.min().item()} is below 0.0"
    finally:
        # Release the environment even when an assertion above fails.
        env.close()