Use `HF_HOME` env variable (#753)
This commit is contained in:
parent
fbf2f2222a
commit
2487228ea7
|
@ -4,7 +4,8 @@ from pathlib import Path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||||
|
|
||||||
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
||||||
|
@ -134,8 +135,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||||
if mode not in ["video", "image", "keypoints"]:
|
if mode not in ["video", "image", "keypoints"]:
|
||||||
raise ValueError(mode)
|
raise ValueError(mode)
|
||||||
|
|
||||||
if (LEROBOT_HOME / repo_id).exists():
|
if (HF_LEROBOT_HOME / repo_id).exists():
|
||||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
|
||||||
|
|
||||||
if not raw_dir.exists():
|
if not raw_dir.exists():
|
||||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||||
|
|
|
@ -1,4 +1,9 @@
|
||||||
# keys
|
# keys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub.constants import HF_HOME
|
||||||
|
|
||||||
OBS_ENV = "observation.environment_state"
|
OBS_ENV = "observation.environment_state"
|
||||||
OBS_ROBOT = "observation.state"
|
OBS_ROBOT = "observation.state"
|
||||||
OBS_IMAGE = "observation.image"
|
OBS_IMAGE = "observation.image"
|
||||||
|
@ -15,3 +20,13 @@ TRAINING_STEP = "training_step.json"
|
||||||
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||||
SCHEDULER_STATE = "scheduler_state.json"
|
SCHEDULER_STATE = "scheduler_state.json"
|
||||||
|
|
||||||
|
# cache dir
|
||||||
|
default_cache_path = Path(HF_HOME) / "lerobot"
|
||||||
|
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||||
|
|
||||||
|
if "LEROBOT_HOME" in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||||
|
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
|
||||||
|
)
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
@ -29,6 +28,7 @@ from huggingface_hub import HfApi, snapshot_download
|
||||||
from huggingface_hub.constants import REPOCARD_NAME
|
from huggingface_hub.constants import REPOCARD_NAME
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
|
@ -71,7 +71,6 @@ from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
|
|
||||||
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
||||||
CODEBASE_VERSION = "v2.1"
|
CODEBASE_VERSION = "v2.1"
|
||||||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDatasetMetadata:
|
class LeRobotDatasetMetadata:
|
||||||
|
@ -84,7 +83,7 @@ class LeRobotDatasetMetadata:
|
||||||
):
|
):
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if force_cache_sync:
|
if force_cache_sync:
|
||||||
|
@ -308,7 +307,7 @@ class LeRobotDatasetMetadata:
|
||||||
"""Creates metadata for a LeRobotDataset."""
|
"""Creates metadata for a LeRobotDataset."""
|
||||||
obj = cls.__new__(cls)
|
obj = cls.__new__(cls)
|
||||||
obj.repo_id = repo_id
|
obj.repo_id = repo_id
|
||||||
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||||
|
|
||||||
obj.root.mkdir(parents=True, exist_ok=False)
|
obj.root.mkdir(parents=True, exist_ok=False)
|
||||||
|
|
||||||
|
@ -463,7 +462,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.root = Path(root) if root else LEROBOT_HOME / repo_id
|
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
|
@ -1056,7 +1055,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_ids = repo_ids
|
self.repo_ids = repo_ids
|
||||||
self.root = Path(root) if root else LEROBOT_HOME
|
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||||
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||||
# are handled by this class.
|
# are handled by this class.
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
|
|
||||||
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
|
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
|
||||||
DUMMY_REPO_ID = "dummy/repo"
|
DUMMY_REPO_ID = "dummy/repo"
|
||||||
DUMMY_ROBOT_TYPE = "dummy_robot"
|
DUMMY_ROBOT_TYPE = "dummy_robot"
|
||||||
DUMMY_MOTOR_FEATURES = {
|
DUMMY_MOTOR_FEATURES = {
|
||||||
|
|
|
@ -581,9 +581,9 @@ def test_create_branch():
|
||||||
|
|
||||||
def test_dataset_feature_with_forward_slash_raises_error():
|
def test_dataset_feature_with_forward_slash_raises_error():
|
||||||
# make sure dir does not exist
|
# make sure dir does not exist
|
||||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
|
|
||||||
dataset_dir = LEROBOT_HOME / "lerobot/test/with/slash"
|
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||||
# make sure does not exist
|
# make sure does not exist
|
||||||
if dataset_dir.exists():
|
if dataset_dir.exists():
|
||||||
dataset_dir.rmdir()
|
dataset_dir.rmdir()
|
||||||
|
|
Loading…
Reference in New Issue