Use `HF_HOME` env variable (#753)

This commit is contained in:
Simon Alibert 2025-02-19 14:49:46 +01:00 committed by GitHub
parent fbf2f2222a
commit 2487228ea7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 28 additions and 13 deletions

View File

@ -4,7 +4,8 @@ from pathlib import Path
import numpy as np
import torch
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
@ -134,8 +135,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
if mode not in ["video", "image", "keypoints"]:
raise ValueError(mode)
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if (HF_LEROBOT_HOME / repo_id).exists():
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
if not raw_dir.exists():
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")

View File

@ -1,4 +1,9 @@
# keys
import os
from pathlib import Path
from huggingface_hub.constants import HF_HOME
OBS_ENV = "observation.environment_state"
OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
@ -15,3 +20,13 @@ TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
# cache dir
# Local root for LeRobot data: the `HF_LEROBOT_HOME` environment variable when it is
# set, otherwise a "lerobot" folder under `HF_HOME`. A leading "~" is expanded.
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.environ.get("HF_LEROBOT_HOME", default_cache_path)).expanduser()

# The old variable name is rejected outright (at import time) rather than silently
# ignored, so a stale configuration cannot point to an unexpected directory.
if "LEROBOT_HOME" in os.environ:
    raise ValueError(
        f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
        "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
    )

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
from pathlib import Path
from typing import Callable
@ -29,6 +28,7 @@ from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from packaging import version
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
@ -71,7 +71,6 @@ from lerobot.common.robot_devices.robots.utils import Robot
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.1"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
class LeRobotDatasetMetadata:
@ -84,7 +83,7 @@ class LeRobotDatasetMetadata:
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
try:
if force_cache_sync:
@ -308,7 +307,7 @@ class LeRobotDatasetMetadata:
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
@ -463,7 +462,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else LEROBOT_HOME / repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
@ -1056,7 +1055,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else LEROBOT_HOME
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.

View File

@ -1,6 +1,6 @@
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.constants import HF_LEROBOT_HOME
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_ROBOT_TYPE = "dummy_robot"
DUMMY_MOTOR_FEATURES = {

View File

@ -581,9 +581,9 @@ def test_create_branch():
def test_dataset_feature_with_forward_slash_raises_error():
# make sure dir does not exist
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.constants import HF_LEROBOT_HOME
dataset_dir = LEROBOT_HOME / "lerobot/test/with/slash"
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
# make sure does not exist
if dataset_dir.exists():
dataset_dir.rmdir()