diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py index 622fbd14..e05c742e 100644 --- a/examples/port_datasets/pusht_zarr.py +++ b/examples/port_datasets/pusht_zarr.py @@ -4,7 +4,8 @@ from pathlib import Path import numpy as np import torch -from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset +from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface." @@ -134,8 +135,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T if mode not in ["video", "image", "keypoints"]: raise ValueError(mode) - if (LEROBOT_HOME / repo_id).exists(): - shutil.rmtree(LEROBOT_HOME / repo_id) + if (HF_LEROBOT_HOME / repo_id).exists(): + shutil.rmtree(HF_LEROBOT_HOME / repo_id) if not raw_dir.exists(): download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw") diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py index 34da4ac0..d0c9845a 100644 --- a/lerobot/common/constants.py +++ b/lerobot/common/constants.py @@ -1,4 +1,9 @@ # keys +import os +from pathlib import Path + +from huggingface_hub.constants import HF_HOME + OBS_ENV = "observation.environment_state" OBS_ROBOT = "observation.state" OBS_IMAGE = "observation.image" @@ -15,3 +20,13 @@ TRAINING_STEP = "training_step.json" OPTIMIZER_STATE = "optimizer_state.safetensors" OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json" SCHEDULER_STATE = "scheduler_state.json" + +# cache dir +default_cache_path = Path(HF_HOME) / "lerobot" +HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser() + +if "LEROBOT_HOME" in os.environ: + raise ValueError( + f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n" + "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead." + ) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index dfdb3618..d4224b7e 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import os import shutil from pathlib import Path from typing import Callable @@ -29,6 +28,7 @@ from huggingface_hub import HfApi, snapshot_download from huggingface_hub.constants import REPOCARD_NAME from packaging import version +from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.utils import ( @@ -71,7 +71,6 @@ from lerobot.common.robot_devices.robots.utils import Robot # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md CODEBASE_VERSION = "v2.1" -LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser() class LeRobotDatasetMetadata: @@ -84,7 +83,7 @@ class LeRobotDatasetMetadata: ): self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION - self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id + self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id try: if force_cache_sync: @@ -308,7 +307,7 @@ class LeRobotDatasetMetadata: """Creates metadata for a LeRobotDataset.""" obj = cls.__new__(cls) obj.repo_id = repo_id - obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id + obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id obj.root.mkdir(parents=True, exist_ok=False) @@ -463,7 +462,7 @@ class LeRobotDataset(torch.utils.data.Dataset): """ super().__init__() self.repo_id = repo_id - self.root = Path(root) if root else LEROBOT_HOME / repo_id + self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.episodes = episodes @@ -1056,7 +1055,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): ): super().__init__() self.repo_ids = repo_ids - self.root = Path(root) if root else LEROBOT_HOME + self.root = Path(root) if root else HF_LEROBOT_HOME self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids} # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which # are handled by this class. diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 7d80d2b7..3201dcf2 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -1,6 +1,6 @@ -from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME +from lerobot.common.constants import HF_LEROBOT_HOME -LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing" +LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing" DUMMY_REPO_ID = "dummy/repo" DUMMY_ROBOT_TYPE = "dummy_robot" DUMMY_MOTOR_FEATURES = { diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 3e8b531d..6d358eea 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -581,9 +581,9 @@ def test_create_branch(): def test_dataset_feature_with_forward_slash_raises_error(): # make sure dir does not exist - from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME + from lerobot.common.constants import HF_LEROBOT_HOME - dataset_dir = LEROBOT_HOME / "lerobot/test/with/slash" + dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash" # make sure does not exist if dataset_dir.exists(): dataset_dir.rmdir()