import argparse
import logging
import os
from pathlib import Path

from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME

from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import create_lerobot_dataset_card
from lerobot.common.utils.utils import init_logging
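# Upload a locally ported DROID LeRobot dataset to the Hugging Face Hub, sharding the file
# upload across Slurm workers with datatrove (or running locally with `--slurm 0`).
#
# Illustrative invocation (script path, repo id, partition and logs dir are placeholders):
#   python path/to/this_script.py \
#       --repo-id your-name/droid \
#       --logs-dir /path/to/logs \
#       --partition your_cpu_partition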


class UploadDataset(PipelineStep):
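    """datatrove pipeline step that uploads a local LeRobot dataset to the Hugging Face Hub.

    The Hub repo (and optional branch) is created once when the step is instantiated; each
    task then pre-uploads its shard of files as LFS objects and commits them in small batches.
    """
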
    def __init__(
        self,
        repo_id: str,
        branch: str | None = None,
        revision: str | None = None,
        tags: list | None = None,
        license: str | None = "apache-2.0",
        private: bool = False,
        distant_repo_id: str | None = None,
        **card_kwargs,
    ):
        super().__init__()
        self.repo_id = repo_id
        self.distant_repo_id = self.repo_id if distant_repo_id is None else distant_repo_id
        self.branch = branch
        self.tags = tags
        self.license = license
        self.private = private
        self.card_kwargs = card_kwargs
        self.revision = revision if revision else CODEBASE_VERSION

        if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
            logging.warning(
                'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
                "variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
            )

        self.create_repo()

    def create_repo(self):
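        """Create the Hub repo (and branch), push a dataset card if missing, and list local files."""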
        logging.info(f"Loading meta data from {self.repo_id}...")
        meta = LeRobotDatasetMetadata(self.repo_id)

        logging.info(f"Creating repo {self.distant_repo_id}...")
        hub_api = HfApi()
        hub_api.create_repo(
            repo_id=self.distant_repo_id,
            private=self.private,
            repo_type="dataset",
            exist_ok=True,
        )
        if self.branch:
            hub_api.create_branch(
                repo_id=self.distant_repo_id,
                branch=self.branch,
                revision=self.revision,
                repo_type="dataset",
                exist_ok=True,
            )

        if not hub_api.file_exists(
            self.distant_repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch
        ):
            card = create_lerobot_dataset_card(
                tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs
            )
            card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch)

        def list_files_recursively(directory):
            base_path = Path(directory)
            return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]

        logging.info(f"Listing all local files from {self.repo_id}...")
        self.file_paths = list_files_recursively(meta.root)
        self.file_paths = sorted(self.file_paths)

    def create_chunks(self, lst, n):
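        """Split `lst` into `n` contiguous chunks whose sizes differ by at most one element."""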
        from itertools import islice

        it = iter(lst)
        sizes = [len(lst) // n + (i < len(lst) % n) for i in range(n)]
        return [list(islice(it, size)) for size in sizes]

    def create_commits(self, additions):
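        """Commit pre-uploaded additions in small batches, retrying with exponential backoff when
        concurrent commits from other workers trigger the "A commit has happened since" error."""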
        import logging
        import math
        import random
        import time

        from huggingface_hub import create_commit
        from huggingface_hub.utils import HfHubHTTPError

        FILES_BETWEEN_COMMITS = 10  # noqa: N806
        BASE_DELAY = 0.1  # noqa: N806
        MAX_RETRIES = 12  # noqa: N806

        # Split the files into smaller chunks for faster commits
        # and to avoid the "A commit has happened since" error.
        num_chunks = math.ceil(len(additions) / FILES_BETWEEN_COMMITS)
        chunks = self.create_chunks(additions, num_chunks)

        for chunk in chunks:
            retries = 0
            while True:
                try:
                    create_commit(
                        self.distant_repo_id,
                        repo_type="dataset",
                        operations=chunk,
                        commit_message=f"DataTrove upload ({len(chunk)} files)",
                        revision=self.branch,
                    )
                    logging.info("create_commit completed!")
                    break
                except HfHubHTTPError as e:
                    if "A commit has happened since" in e.server_message:
                        if retries >= MAX_RETRIES:
                            logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
                            raise e
                        logging.info("Commit creation race condition issue. Waiting...")
                        time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
                        retries += 1
                    else:
                        raise e

    def run(self, data=None, rank: int = 0, world_size: int = 1):
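        """Upload the shard of files assigned to this task (`rank` out of `world_size`):
        pre-upload LFS objects, then create the commits in batches."""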
        import logging

        from datasets.utils.tqdm import disable_progress_bars
        from huggingface_hub import CommitOperationAdd, preupload_lfs_files

        from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
        from lerobot.common.utils.utils import init_logging

        init_logging()
        disable_progress_bars()

        chunks = self.create_chunks(self.file_paths, world_size)
        file_paths = chunks[rank]

        if len(file_paths) == 0:
            raise ValueError(f"No files to upload for task {rank=} ({world_size=}).")

        logging.info("Pre-uploading LFS files...")
        for i, path in enumerate(file_paths):
            logging.info(f"{i}: {path}")

        meta = LeRobotDatasetMetadata(self.repo_id)
        additions = [
            CommitOperationAdd(path_in_repo=path, path_or_fileobj=meta.root / path) for path in file_paths
        ]
        preupload_lfs_files(
            repo_id=self.distant_repo_id, repo_type="dataset", additions=additions, revision=self.branch
        )

        logging.info("Creating commits...")
        self.create_commits(additions)
        logging.info("Done!")


def make_upload_executor(
    repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
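    """Build the datatrove executor that runs the `UploadDataset` step, on Slurm or locally (one worker)."""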
    kwargs = {
        "pipeline": [
            UploadDataset(repo_id),
        ],
        "logging_dir": str(logs_dir / job_name),
    }

    if slurm:
        kwargs.update(
            {
                "job_name": job_name,
                "tasks": DROID_SHARDS,
                "workers": workers,
                "time": "08:00:00",
                "partition": partition,
                "cpus_per_task": cpus_per_task,
                "sbatch_args": {"mem-per-cpu": mem_per_cpu},
            }
        )
        executor = SlurmPipelineExecutor(**kwargs)
    else:
        kwargs.update(
            {
                "tasks": DROID_SHARDS,
                "workers": 1,
            }
        )
        executor = LocalPipelineExecutor(**kwargs)

    return executor


def main():
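    """Parse command-line arguments and run the upload executor."""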
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        help="Path to logs directory for `datatrove`.",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="upload_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=50,
        help="Number of slurm workers. It should be less than the maximum number of shards.",
    )
    parser.add_argument(
        "--partition",
        type=str,
        help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
    )
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=8,
        help="Number of cpus that each slurm worker will use.",
    )
    parser.add_argument(
        "--mem-per-cpu",
        type=str,
        default="1950M",
        help="Memory per cpu that each worker will use.",
    )

    init_logging()

    args = parser.parse_args()
    kwargs = vars(args)
    kwargs["slurm"] = kwargs.pop("slurm") == 1
    upload_executor = make_upload_executor(**kwargs)
    upload_executor.run()


if __name__ == "__main__":
    main()