import pytest
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
import torch
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase

from lerobot.common.policies.factory import make_policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH

@pytest.mark.parametrize(
    "env_name,policy_name,extra_overrides",
    [
        ("simxarm", "tdmpc", ["policy.mpc=true"]),
        ("pusht", "tdmpc", ["policy.mpc=false"]),
        ("pusht", "diffusion", []),
        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
        # TODO(aliberts): simxarm not working with diffusion
        # ("simxarm", "diffusion", []),
    ],
)
def test_concrete_policy(env_name, policy_name, extra_overrides):
    """
    Tests:
        - Making the policy object.
        - Updating the policy.
        - Using the policy to select actions at inference time.
    """
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            f"device={DEVICE}",
        ]
        + extra_overrides
    )
    # Check that we can make the policy object.
    policy = make_policy(cfg)
    # Check that we can make the offline buffer and an environment that reuses its transform.
    offline_buffer = make_offline_buffer(cfg)
    env = make_env(cfg, transform=offline_buffer.transform)

    if env_name != "aloha":
        # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
        # seq_length as a list is not supported for now.
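        # Check that we can update the policy with a batch from the offline buffer.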
        policy.update(offline_buffer, torch.tensor(0, device=DEVICE))

    # Check that we can run select_actions (via the policy's forward) and get an action of the appropriate shape.
    action = policy(
        env.observation_spec.rand()["observation"].to(DEVICE),
        torch.tensor(0, device=DEVICE),
    )
    assert action.shape == env.action_spec.shape


def test_abstract_policy_forward():
    """
    Given an underlying policy that produces an action trajectory with n_action_steps actions, this test checks that:
        - The policy is invoked the expected number of times during a rollout.
        - The environment's termination condition is respected even when it occurs part way through an action trajectory.
        - The observations are returned correctly.
    """

    n_action_steps = 8  # our test policy outputs action trajectories that are 8 steps long
    terminate_at = 10  # some number that is greater than n_action_steps but not a multiple of it
    rollout_max_steps = terminate_at + 1  # some number greater than terminate_at

    # A minimal environment for testing.
    class StubEnv(EnvBase):

        def __init__(self):
            super().__init__()
            self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
            self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))

        def _step(self, tensordict: TensorDict) -> TensorDict:
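            # Report the step count as both observation and reward, and flag termination when the
            # incoming action equals terminate_at.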
            self.invocation_count += 1
            return TensorDict(
                {
                    "observation": torch.tensor([self.invocation_count]),
                    "reward": torch.tensor([self.invocation_count]),
                    "terminated": torch.tensor(
                        tensordict["action"].item() == terminate_at
                    ),
                },
                batch_size=[],  # match the env's empty batch size
            )

        def _reset(self, tensordict: TensorDict) -> TensorDict:
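            # Reset the step counter and return the initial (zero) observation and reward.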
            self.invocation_count = 0
            return TensorDict(
                {
                    "observation": torch.tensor([self.invocation_count]),
                    "reward": torch.tensor([self.invocation_count]),
                },
                batch_size=[],
            )

        def _set_seed(self, seed: int | None):
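            # Seeding is a no-op for this deterministic stub.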
            return

    class StubPolicy(AbstractPolicy):
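        # A minimal policy that always proposes the same n_action_steps-long action trajectory and
        # counts how many times it is queried.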
        name = "stub"

        def __init__(self):
            super().__init__(n_action_steps)
            self.n_policy_invocations = 0

        def update(self):
            pass

        def select_actions(self):
            self.n_policy_invocations += 1
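            # A fixed (1, n_action_steps, 1) trajectory with actions 0, 1, ..., n_action_steps - 1.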
            return torch.stack(
                [torch.tensor([i]) for i in range(self.n_action_steps)]
            ).unsqueeze(0)

    env = StubEnv()
    policy = StubPolicy()
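    # Wrap the policy in a TensorDictModule so that env.rollout can call it and store its output under "action".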
    policy = TensorDictModule(
        policy,
        in_keys=[],
        out_keys=["action"],
    )

    # Run the rollout; the stubs above track their own invocation counts so we can check them below.
    rollout = env.rollout(rollout_max_steps, policy)

    assert len(rollout) == terminate_at + 1  # +1 for the reset observation
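    # The policy is only queried for a new action trajectory once every n_action_steps steps.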
    assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
    assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))