Add push to hub, Camera 0 and 1, pynput optional

This commit is contained in:
Remi Cadene 2024-07-14 16:35:22 +02:00
parent 3fc351074c
commit d2a17d2c74
4 changed files with 30 additions and 18 deletions

View File

@ -36,8 +36,8 @@ def make_robot(name):
), ),
}, },
cameras={ cameras={
"laptop": OpenCVCamera(1, fps=30, width=640, height=480), "laptop": OpenCVCamera(0, fps=30, width=640, height=480),
"phone": OpenCVCamera(2, fps=30, width=640, height=480), "phone": OpenCVCamera(1, fps=30, width=640, height=480),
}, },
) )
else: else:

View File

@ -91,13 +91,14 @@ from pathlib import Path
import torch import torch
import tqdm import tqdm
from huggingface_hub import create_branch
from omegaconf import DictConfig from omegaconf import DictConfig
from PIL import Image from PIL import Image
from pynput import keyboard from pynput import keyboard
# from safetensors.torch import load_file, save_file # from safetensors.torch import load_file, save_file
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index from lerobot.common.datasets.utils import calculate_episode_data_index
@ -107,7 +108,7 @@ from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import save_meta_data from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
######################################################################################## ########################################################################################
# Utilities # Utilities
@ -209,6 +210,7 @@ def record_dataset(
num_episodes=50, num_episodes=50,
video=True, video=True,
run_compute_stats=True, run_compute_stats=True,
push_to_hub=True,
num_image_writers=8, num_image_writers=8,
force_override=False, force_override=False,
): ):
@ -469,7 +471,12 @@ def record_dataset(
meta_data_dir = local_dir / "meta_data" meta_data_dir = local_dir / "meta_data"
save_meta_data(info, stats, episode_data_index, meta_data_dir) save_meta_data(info, stats, episode_data_index, meta_data_dir)
# TODO(rcadene): push to hub if push_to_hub:
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
logging.info("Exiting") logging.info("Exiting")
os.system('say "Exiting" &') os.system('say "Exiting" &')
@ -613,7 +620,12 @@ if __name__ == "__main__":
default=1, default=1,
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.", help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
) )
parser.add_argument(
"--push-to-hub",
type=int,
default=1,
help="Upload dataset to Hugging Face hub.",
)
parser_record.add_argument( parser_record.add_argument(
"--num-image-writers", "--num-image-writers",
type=int, type=int,

20
poetry.lock generated
View File

@ -835,7 +835,7 @@ files = [
name = "evdev" name = "evdev"
version = "1.7.1" version = "1.7.1"
description = "Bindings to the Linux input handling subsystem" description = "Bindings to the Linux input handling subsystem"
optional = false optional = true
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
{file = "evdev-1.7.1.tar.gz", hash = "sha256:0c72c370bda29d857e188d931019c32651a9c1ea977c08c8d939b1ced1637fde"}, {file = "evdev-1.7.1.tar.gz", hash = "sha256:0c72c370bda29d857e188d931019c32651a9c1ea977c08c8d939b1ced1637fde"},
@ -3014,7 +3014,7 @@ dev = ["aafigure", "matplotlib", "numpy", "pygame", "pyglet (<2.0.0)", "sphinx",
name = "pynput" name = "pynput"
version = "1.7.7" version = "1.7.7"
description = "Monitor and control user input devices" description = "Monitor and control user input devices"
optional = false optional = true
python-versions = "*" python-versions = "*"
files = [ files = [
{file = "pynput-1.7.7-py2.py3-none-any.whl", hash = "sha256:afc43f651684c98818de048abc76adf9f2d3d797083cb07c1f82be764a2d44cb"}, {file = "pynput-1.7.7-py2.py3-none-any.whl", hash = "sha256:afc43f651684c98818de048abc76adf9f2d3d797083cb07c1f82be764a2d44cb"},
@ -3031,7 +3031,7 @@ six = "*"
name = "pyobjc-core" name = "pyobjc-core"
version = "10.3.1" version = "10.3.1"
description = "Python<->ObjC Interoperability Module" description = "Python<->ObjC Interoperability Module"
optional = false optional = true
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pyobjc_core-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ea46d2cda17921e417085ac6286d43ae448113158afcf39e0abe484c58fb3d78"}, {file = "pyobjc_core-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ea46d2cda17921e417085ac6286d43ae448113158afcf39e0abe484c58fb3d78"},
@ -3048,7 +3048,7 @@ files = [
name = "pyobjc-framework-applicationservices" name = "pyobjc-framework-applicationservices"
version = "10.3.1" version = "10.3.1"
description = "Wrappers for the framework ApplicationServices on macOS" description = "Wrappers for the framework ApplicationServices on macOS"
optional = false optional = true
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b694260d423c470cb90c3a7009cfde93e332ea6fb4b9b9526ad3acbd33460e3d"}, {file = "pyobjc_framework_ApplicationServices-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b694260d423c470cb90c3a7009cfde93e332ea6fb4b9b9526ad3acbd33460e3d"},
@ -3071,7 +3071,7 @@ pyobjc-framework-Quartz = ">=10.3.1"
name = "pyobjc-framework-cocoa" name = "pyobjc-framework-cocoa"
version = "10.3.1" version = "10.3.1"
description = "Wrappers for the Cocoa frameworks on macOS" description = "Wrappers for the Cocoa frameworks on macOS"
optional = false optional = true
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pyobjc_framework_Cocoa-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4cb4f8491ab4d9b59f5187e42383f819f7a46306a4fa25b84f126776305291d1"}, {file = "pyobjc_framework_Cocoa-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4cb4f8491ab4d9b59f5187e42383f819f7a46306a4fa25b84f126776305291d1"},
@ -3091,7 +3091,7 @@ pyobjc-core = ">=10.3.1"
name = "pyobjc-framework-coretext" name = "pyobjc-framework-coretext"
version = "10.3.1" version = "10.3.1"
description = "Wrappers for the framework CoreText on macOS" description = "Wrappers for the framework CoreText on macOS"
optional = false optional = true
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pyobjc_framework_CoreText-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd6123cfccc38e32be884d1a13fb62bd636ecb192b9e8ae2b8011c977dec229e"}, {file = "pyobjc_framework_CoreText-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd6123cfccc38e32be884d1a13fb62bd636ecb192b9e8ae2b8011c977dec229e"},
@ -3113,7 +3113,7 @@ pyobjc-framework-Quartz = ">=10.3.1"
name = "pyobjc-framework-quartz" name = "pyobjc-framework-quartz"
version = "10.3.1" version = "10.3.1"
description = "Wrappers for the Quartz frameworks on macOS" description = "Wrappers for the Quartz frameworks on macOS"
optional = false optional = true
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pyobjc_framework_Quartz-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5ef4fd315ed2bc42ef77fdeb2bae28a88ec986bd7b8079a87ba3b3475348f96e"}, {file = "pyobjc_framework_Quartz-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5ef4fd315ed2bc42ef77fdeb2bae28a88ec986bd7b8079a87ba3b3475348f96e"},
@ -3256,7 +3256,7 @@ six = ">=1.5"
name = "python-xlib" name = "python-xlib"
version = "0.33" version = "0.33"
description = "Python X Library" description = "Python X Library"
optional = false optional = true
python-versions = "*" python-versions = "*"
files = [ files = [
{file = "python-xlib-0.33.tar.gz", hash = "sha256:55af7906a2c75ce6cb280a584776080602444f75815a7aff4d287bb2d7018b32"}, {file = "python-xlib-0.33.tar.gz", hash = "sha256:55af7906a2c75ce6cb280a584776080602444f75815a7aff4d287bb2d7018b32"},
@ -4497,7 +4497,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
aloha = ["gym-aloha"] aloha = ["gym-aloha"]
dev = ["debugpy", "pre-commit"] dev = ["debugpy", "pre-commit"]
dora = ["gym-dora"] dora = ["gym-dora"]
koch = ["dynamixel-sdk"] koch = ["dynamixel-sdk", "pynput"]
pusht = ["gym-pusht"] pusht = ["gym-pusht"]
test = ["pytest", "pytest-cov", "pytest-mock"] test = ["pytest", "pytest-cov", "pytest-mock"]
umi = ["imagecodecs"] umi = ["imagecodecs"]
@ -4507,4 +4507,4 @@ xarm = ["gym-xarm"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "6d82706d5216ce065ba5912ea9f802846dfdf7e7b1665f8560cc782fb0dfa354" content-hash = "2c59d869c6b1f2132070387f3d371b5b004765ae853501bbd522eb400738f2d0"

View File

@ -64,7 +64,7 @@ scikit-image = {version = "^0.23.2", optional = true}
pandas = {version = "^2.2.2", optional = true} pandas = {version = "^2.2.2", optional = true}
pytest-mock = {version = "^3.14.0", optional = true} pytest-mock = {version = "^3.14.0", optional = true}
dynamixel-sdk = {version = "^3.7.31", optional = true} dynamixel-sdk = {version = "^3.7.31", optional = true}
pynput = "^1.7.7" pynput = {version = "^1.7.7", optional = true}
@ -77,7 +77,7 @@ dev = ["pre-commit", "debugpy"]
test = ["pytest", "pytest-cov", "pytest-mock"] test = ["pytest", "pytest-cov", "pytest-mock"]
umi = ["imagecodecs"] umi = ["imagecodecs"]
video_benchmark = ["scikit-image", "pandas"] video_benchmark = ["scikit-image", "pandas"]
koch = ["dynamixel-sdk"] koch = ["dynamixel-sdk", "pynput"]
[tool.ruff] [tool.ruff]
line-length = 110 line-length = 110