Merge branch 'main' of github.com:huggingface/lerobot

This commit is contained in:
jess-moss 2024-07-22 08:48:50 -05:00
commit 3fde016246
69 changed files with 3785 additions and 522 deletions

View File

@ -54,3 +54,31 @@ jobs:
- name: Poetry check - name: Poetry check
run: poetry check run: poetry check
poetry_relax:
name: Poetry relax
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v3
- name: Install poetry
run: pipx install poetry
- name: Install poetry-relax
run: poetry self add poetry-relax
- name: Poetry relax
id: poetry_relax
run: |
output=$(poetry relax --check 2>&1)
if echo "$output" | grep -q "Proposing updates"; then
echo "$output"
echo ""
echo "Some dependencies have caret '^' version requirement added by poetry by default."
echo "Please replace them with '>='. You can do this by hand or use poetry-relax to do this."
exit 1
else
echo "$output"
fi

View File

@ -16,3 +16,5 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Secret Scanning - name: Secret Scanning
uses: trufflesecurity/trufflehog@main uses: trufflesecurity/trufflehog@main
with:
extra_args: --only-verified

1
.gitignore vendored
View File

@ -122,7 +122,6 @@ celerybeat.pid
.env .env
.venv .venv
venv/ venv/
ENV/
env.bak/ env.bak/
venv.bak/ venv.bak/

View File

@ -14,11 +14,11 @@ repos:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v3.15.2 rev: v3.16.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.3 rev: v0.5.2
hooks: hooks:
- id: ruff - id: ruff
args: [--fix] args: [--fix]
@ -31,3 +31,7 @@ repos:
args: args:
- "--check" - "--check"
- "--no-update" - "--no-update"
- repo: https://github.com/gitleaks/gitleaks
rev: v8.18.4
hooks:
- id: gitleaks

View File

@ -21,7 +21,7 @@ RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
COPY . /lerobot COPY . /lerobot
WORKDIR /lerobot WORKDIR /lerobot
RUN pip install --upgrade --no-cache-dir pip RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]" \ RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]" \
--extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://download.pytorch.org/whl/cpu
# Set EGL as the rendering backend for MuJoCo # Set EGL as the rendering backend for MuJoCo

View File

@ -43,7 +43,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libsvtav1-dev libsvtav1enc-dev libsvtav1dec-dev \ libsvtav1-dev libsvtav1enc-dev libsvtav1dec-dev \
libdav1d-dev libdav1d-dev
# Install gh cli tool # Install gh cli tool
RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \ RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
&& mkdir -p -m 755 /etc/apt/keyrings \ && mkdir -p -m 755 /etc/apt/keyrings \

View File

@ -9,7 +9,7 @@ ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \ build-essential cmake \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \ python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get clean && rm -rf /var/lib/apt/lists/*
@ -23,7 +23,7 @@ RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
COPY . /lerobot COPY . /lerobot
WORKDIR /lerobot WORKDIR /lerobot
RUN pip install --upgrade --no-cache-dir pip RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]" RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]"
# Set EGL as the rendering backend for MuJoCo # Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl" ENV MUJOCO_GL="egl"

View File

@ -80,7 +80,7 @@ policy:
n_vae_encoder_layers: 4 n_vae_encoder_layers: 4
# Inference. # Inference.
temporal_ensemble_momentum: null temporal_ensemble_coeff: null
# Training and loss computation. # Training and loss computation.
dropout: 0.1 dropout: 0.1

View File

@ -35,15 +35,16 @@ from lerobot.common.datasets.utils import (
) )
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/codebase_version.md
CODEBASE_VERSION = "v1.5" CODEBASE_VERSION = "v1.5"
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
class LeRobotDataset(torch.utils.data.Dataset): class LeRobotDataset(torch.utils.data.Dataset):
def __init__( def __init__(
self, self,
repo_id: str, repo_id: str,
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR, root: Path | None = DATA_DIR,
split: str = "train", split: str = "train",
image_transforms: Callable | None = None, image_transforms: Callable | None = None,
@ -52,7 +53,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
): ):
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
self.version = version
self.root = root self.root = root
self.split = split self.split = split
self.image_transforms = image_transforms self.image_transforms = image_transforms
@ -60,16 +60,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
# load data from hub or locally when root is provided # load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer # TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
self.hf_dataset = load_hf_dataset(repo_id, version, root, split) self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
if split == "train": if split == "train":
self.episode_data_index = load_episode_data_index(repo_id, version, root) self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
else: else:
self.episode_data_index = calculate_episode_data_index(self.hf_dataset) self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
self.hf_dataset = reset_episode_index(self.hf_dataset) self.hf_dataset = reset_episode_index(self.hf_dataset)
self.stats = load_stats(repo_id, version, root) self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
self.info = load_info(repo_id, version, root) self.info = load_info(repo_id, CODEBASE_VERSION, root)
if self.video: if self.video:
self.videos_dir = load_videos(repo_id, version, root) self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
self.video_backend = video_backend if video_backend is not None else "pyav" self.video_backend = video_backend if video_backend is not None else "pyav"
@property @property
@ -164,7 +164,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return ( return (
f"{self.__class__.__name__}(\n" f"{self.__class__.__name__}(\n"
f" Repository ID: '{self.repo_id}',\n" f" Repository ID: '{self.repo_id}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n" f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n" f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n" f" Number of Episodes: {self.num_episodes},\n"
@ -173,6 +172,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Camera Keys: {self.camera_keys},\n" f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n" f" Transformations: {self.image_transforms},\n"
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
f")" f")"
) )
@ -180,7 +180,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
def from_preloaded( def from_preloaded(
cls, cls,
repo_id: str = "from_preloaded", repo_id: str = "from_preloaded",
version: str | None = CODEBASE_VERSION,
root: Path | None = None, root: Path | None = None,
split: str = "train", split: str = "train",
transform: callable = None, transform: callable = None,
@ -204,7 +203,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# create an empty object of type LeRobotDataset # create an empty object of type LeRobotDataset
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.repo_id = repo_id obj.repo_id = repo_id
obj.version = version
obj.root = root obj.root = root
obj.split = split obj.split = split
obj.image_transforms = transform obj.image_transforms = transform
@ -228,7 +226,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __init__( def __init__(
self, self,
repo_ids: list[str], repo_ids: list[str],
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR, root: Path | None = DATA_DIR,
split: str = "train", split: str = "train",
image_transforms: Callable | None = None, image_transforms: Callable | None = None,
@ -242,7 +239,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self._datasets = [ self._datasets = [
LeRobotDataset( LeRobotDataset(
repo_id, repo_id,
version=version,
root=root, root=root,
split=split, split=split,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
@ -279,7 +275,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
) )
self.disabled_data_keys.update(extra_keys) self.disabled_data_keys.update(extra_keys)
self.version = version
self.root = root self.root = root
self.split = split self.split = split
self.image_transforms = image_transforms self.image_transforms = image_transforms
@ -395,7 +390,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
return ( return (
f"{self.__class__.__name__}(\n" f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n" f" Repository IDs: '{self.repo_ids}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n" f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n" f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n" f" Number of Episodes: {self.num_episodes},\n"

View File

@ -0,0 +1,57 @@
## Using / Updating `CODEBASE_VERSION` (for maintainers)
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
lerobot/common/datasets/lerobot_dataset.py) variable.
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
`info.json` metadata.
### Uploading a new dataset
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
dataset with the current `CODEBASE_VERSION`.
### Updating an existing dataset
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change and backward compatibility is maintained.
However, you will need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
```python
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
api = HfApi()
for repo_id in available_datasets:
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if CODEBASE_VERSION in branches:
# First check if the newer version already exists.
print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
print("Exiting early")
break
else:
# Now create a branch named after the new version by branching out from "main"
# which is expected to be the preceding version
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
print(f"{repo_id} successfully updated")
```

View File

@ -31,6 +31,44 @@ from pathlib import Path
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
AVAILABLE_RAW_REPO_IDS = [
"lerobot-raw/aloha_mobile_cabinet_raw",
"lerobot-raw/aloha_mobile_chair_raw",
"lerobot-raw/aloha_mobile_elevator_raw",
"lerobot-raw/aloha_mobile_shrimp_raw",
"lerobot-raw/aloha_mobile_wash_pan_raw",
"lerobot-raw/aloha_mobile_wipe_wine_raw",
"lerobot-raw/aloha_sim_insertion_human_raw",
"lerobot-raw/aloha_sim_insertion_scripted_raw",
"lerobot-raw/aloha_sim_transfer_cube_human_raw",
"lerobot-raw/aloha_sim_transfer_cube_scripted_raw",
"lerobot-raw/aloha_static_battery_raw",
"lerobot-raw/aloha_static_candy_raw",
"lerobot-raw/aloha_static_coffee_new_raw",
"lerobot-raw/aloha_static_coffee_raw",
"lerobot-raw/aloha_static_cups_open_raw",
"lerobot-raw/aloha_static_fork_pick_up_raw",
"lerobot-raw/aloha_static_pingpong_test_raw",
"lerobot-raw/aloha_static_pro_pencil_raw",
"lerobot-raw/aloha_static_screw_driver_raw",
"lerobot-raw/aloha_static_tape_raw",
"lerobot-raw/aloha_static_thread_velcro_raw",
"lerobot-raw/aloha_static_towel_raw",
"lerobot-raw/aloha_static_vinh_cup_left_raw",
"lerobot-raw/aloha_static_vinh_cup_raw",
"lerobot-raw/aloha_static_ziploc_slide_raw",
"lerobot-raw/pusht_raw",
"lerobot-raw/umi_cup_in_the_wild_raw",
"lerobot-raw/unitreeh1_fold_clothes_raw",
"lerobot-raw/unitreeh1_rearrange_objects_raw",
"lerobot-raw/unitreeh1_two_robot_greeting_raw",
"lerobot-raw/unitreeh1_warehouse_raw",
"lerobot-raw/xarm_lift_medium_raw",
"lerobot-raw/xarm_lift_medium_replay_raw",
"lerobot-raw/xarm_push_medium_raw",
"lerobot-raw/xarm_push_medium_replay_raw",
]
def download_raw(raw_dir: Path, repo_id: str): def download_raw(raw_dir: Path, repo_id: str):
# Check repo_id is well formated # Check repo_id is well formated
@ -46,7 +84,6 @@ def download_raw(raw_dir: Path, repo_id: str):
stacklevel=1, stacklevel=1,
) )
raw_dir = Path(raw_dir)
# Send warning if raw_dir isn't well formated # Send warning if raw_dir isn't well formated
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id: if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
warnings.warn( warnings.warn(
@ -56,61 +93,21 @@ def download_raw(raw_dir: Path, repo_id: str):
raw_dir.mkdir(parents=True, exist_ok=True) raw_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}") logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir) snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}") logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
def download_all_raw_datasets(): def download_all_raw_datasets():
data_dir = Path("data") data_dir = Path("data")
repo_ids = [ for repo_id in AVAILABLE_RAW_REPO_IDS:
"cadene/pusht_image_raw",
"cadene/xarm_lift_medium_image_raw",
"cadene/xarm_lift_medium_replay_image_raw",
"cadene/xarm_push_medium_image_raw",
"cadene/xarm_push_medium_replay_image_raw",
"cadene/aloha_sim_insertion_human_image_raw",
"cadene/aloha_sim_insertion_scripted_image_raw",
"cadene/aloha_sim_transfer_cube_human_image_raw",
"cadene/aloha_sim_transfer_cube_scripted_image_raw",
"cadene/pusht_raw",
"cadene/xarm_lift_medium_raw",
"cadene/xarm_lift_medium_replay_raw",
"cadene/xarm_push_medium_raw",
"cadene/xarm_push_medium_replay_raw",
"cadene/aloha_sim_insertion_human_raw",
"cadene/aloha_sim_insertion_scripted_raw",
"cadene/aloha_sim_transfer_cube_human_raw",
"cadene/aloha_sim_transfer_cube_scripted_raw",
"cadene/aloha_mobile_cabinet_raw",
"cadene/aloha_mobile_chair_raw",
"cadene/aloha_mobile_elevator_raw",
"cadene/aloha_mobile_shrimp_raw",
"cadene/aloha_mobile_wash_pan_raw",
"cadene/aloha_mobile_wipe_wine_raw",
"cadene/aloha_static_battery_raw",
"cadene/aloha_static_candy_raw",
"cadene/aloha_static_coffee_raw",
"cadene/aloha_static_coffee_new_raw",
"cadene/aloha_static_cups_open_raw",
"cadene/aloha_static_fork_pick_up_raw",
"cadene/aloha_static_pingpong_test_raw",
"cadene/aloha_static_pro_pencil_raw",
"cadene/aloha_static_screw_driver_raw",
"cadene/aloha_static_tape_raw",
"cadene/aloha_static_thread_velcro_raw",
"cadene/aloha_static_towel_raw",
"cadene/aloha_static_vinh_cup_raw",
"cadene/aloha_static_vinh_cup_left_raw",
"cadene/aloha_static_ziploc_slide_raw",
"cadene/umi_cup_in_the_wild_raw",
]
for repo_id in repo_ids:
raw_dir = data_dir / repo_id raw_dir = data_dir / repo_id
download_raw(raw_dir, repo_id) download_raw(raw_dir, repo_id)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(
description=f"A script to download raw datasets from Hugging Face hub to a local directory. Here is a non exhaustive list of available repositories to use in `--repo-id`: {AVAILABLE_RAW_REPO_IDS}",
)
parser.add_argument( parser.add_argument(
"--raw-dir", "--raw-dir",

View File

@ -28,6 +28,7 @@ import tqdm
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index, calculate_episode_data_index,
@ -210,6 +211,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -23,6 +23,7 @@ import torch
from datasets import Dataset, Features, Image, Value from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
from lerobot.common.datasets.video_utils import VideoFrame from lerobot.common.datasets.video_utils import VideoFrame
@ -95,6 +96,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -24,6 +24,7 @@ import pandas as pd
import torch import torch
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index, calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
@ -214,6 +215,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_df, video) hf_dataset = to_hf_dataset(data_df, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -25,6 +25,7 @@ import zarr
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index, calculate_episode_data_index,
@ -258,6 +259,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image) hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video if not keypoints_instead_of_image else 0, "video": video if not keypoints_instead_of_image else 0,
} }

View File

@ -25,6 +25,7 @@ import zarr
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
@ -199,6 +200,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -25,6 +25,7 @@ import tqdm
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index, calculate_episode_data_index,
@ -177,6 +178,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -15,13 +15,15 @@
# limitations under the License. # limitations under the License.
import json import json
import re import re
import warnings
from functools import cache
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
import datasets import datasets
import torch import torch
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
from huggingface_hub import hf_hub_download, snapshot_download from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from PIL import Image as PILImage from PIL import Image as PILImage
from safetensors.torch import load_file from safetensors.torch import load_file
from torchvision import transforms from torchvision import transforms
@ -80,7 +82,28 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict return items_dict
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: @cache
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
warnings.warn(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
The requested version ('{version}') is not found. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
stacklevel=1,
)
if "main" not in branches:
raise ValueError(f"Version 'main' not found on {repo_id}")
return "main"
else:
return version
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc.""" """hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None: if root is not None:
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train")) hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
@ -101,7 +124,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"' f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
) )
else: else:
hf_dataset = load_dataset(repo_id, revision=version, split=split) safe_version = get_hf_dataset_safe_version(repo_id, version)
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
hf_dataset.set_transform(hf_transform_to_torch) hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset return hf_dataset
@ -119,8 +144,9 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
if root is not None: if root is not None:
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors" path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else: else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download( path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
) )
return load_file(path) return load_file(path)
@ -137,7 +163,10 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
if root is not None: if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors" path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else: else:
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version) safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
)
stats = load_file(path) stats = load_file(path)
return unflatten_dict(stats) return unflatten_dict(stats)
@ -154,7 +183,8 @@ def load_info(repo_id, version, root) -> dict:
if root is not None: if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json" path = Path(root) / repo_id / "meta_data" / "info.json"
else: else:
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version) safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
with open(path) as f: with open(path) as f:
info = json.load(f) info = json.load(f)
@ -166,7 +196,8 @@ def load_videos(repo_id, version, root) -> Path:
path = Path(root) / repo_id / "videos" path = Path(root) / repo_id / "videos"
else: else:
# TODO(rcadene): we download the whole repo here. see if we can avoid this # TODO(rcadene): we download the whole repo here. see if we can avoid this
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version) safe_version = get_hf_dataset_safe_version(repo_id, version)
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
path = Path(repo_dir) / "videos" path = Path(repo_dir) / "videos"
return path return path

View File

@ -207,7 +207,8 @@ def encode_video_frames(
ffmpeg_args.append("-y") ffmpeg_args.append("-y")
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)] ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
subprocess.run(ffmpeg_cmd, check=True) # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
@dataclass @dataclass

View File

@ -19,7 +19,7 @@ import gymnasium as gym
from omegaconf import DictConfig from omegaconf import DictConfig
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv: def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the evaluation config. """Makes a gym vector environment according to the evaluation config.
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1. n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
@ -27,6 +27,9 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
if n_envs is not None and n_envs < 1: if n_envs is not None and n_envs < 1:
raise ValueError("`n_envs must be at least 1") raise ValueError("`n_envs must be at least 1")
if cfg.env.name == "real_world":
return
package_name = f"gym_{cfg.env.name}" package_name = f"gym_{cfg.env.name}"
try: try:

View File

@ -26,7 +26,10 @@ class ACTConfig:
Those are: `input_shapes` and 'output_shapes`. Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs: Notes on the inputs and outputs:
- At least one key starting with "observation.image is required as an input. - Either:
- At least one key starting with "observation.image is required as an input.
AND/OR
- The key "observation.environment_state" is required as input.
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
views. Right now we only support all images having the same shape. views. Right now we only support all images having the same shape.
- May optionally work without an "observation.state" key for the proprioceptive robot state. - May optionally work without an "observation.state" key for the proprioceptive robot state.
@ -73,12 +76,10 @@ class ACTConfig:
documentation in the policy class). documentation in the policy class).
latent_dim: The VAE's latent dimension. latent_dim: The VAE's latent dimension.
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder. n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
actions for a given time step over multiple policy invocations. Updates are calculated as: ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
x = αx + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different 1 when using this feature, as inference needs to happen at every step to form an ensemble. For
parameter here: they refer to a weighting scheme wᵢ = exp(-mi) and set m = 0.01. With our more information on how ensembling works, please see `ACTTemporalEnsembler`.
formulation, this is equivalent to α = exp(-0.01) 0.99. When this parameter is provided, we
require `n_action_steps == 1` (since we need to query the policy every step anyway).
dropout: Dropout to use in the transformer layers (see code for details). dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@ -136,7 +137,8 @@ class ACTConfig:
n_vae_encoder_layers: int = 4 n_vae_encoder_layers: int = 4
# Inference. # Inference.
temporal_ensemble_momentum: float | None = None # Note: the value used in ACT when temporal ensembling is enabled is 0.01.
temporal_ensemble_coeff: float | None = None
# Training and loss computation. # Training and loss computation.
dropout: float = 0.1 dropout: float = 0.1
@ -148,7 +150,7 @@ class ACTConfig:
raise ValueError( raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
) )
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1: if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
raise NotImplementedError( raise NotImplementedError(
"`n_action_steps` must be 1 when using temporal ensembling. This is " "`n_action_steps` must be 1 when using temporal ensembling. This is "
"because the policy needs to be queried every step to compute the ensembled action." "because the policy needs to be queried every step to compute the ensembled action."
@ -162,3 +164,8 @@ class ACTConfig:
raise ValueError( raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
) )
if (
not any(k.startswith("observation.image") for k in self.input_shapes)
and "observation.environment_state" not in self.input_shapes
):
raise ValueError("You must provide at least one image or the environment state among the inputs.")

View File

@ -77,12 +77,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset() self.reset()
def reset(self): def reset(self):
"""This should be called whenever the environment is reset.""" """This should be called whenever the environment is reset."""
if self.config.temporal_ensemble_momentum is not None: if self.config.temporal_ensemble_coeff is not None:
self._ensembled_actions = None self.temporal_ensembler.reset()
else: else:
self._action_queue = deque([], maxlen=self.config.n_action_steps) self._action_queue = deque([], maxlen=self.config.n_action_steps)
@ -97,26 +100,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self.eval() self.eval()
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self.expected_image_keys) > 0:
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# the first action. # we are ensembling over.
if self.config.temporal_ensemble_momentum is not None: if self.config.temporal_ensemble_coeff is not None:
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
if self._ensembled_actions is None: action = self.temporal_ensembler.update(actions)
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self._ensembled_actions = actions.clone()
else:
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the EMA update for those entries.
alpha = self.config.temporal_ensemble_momentum
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
# The last action, which has no prior moving average, needs to get concatenated onto the end.
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
# "Consume" the first action.
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
return action return action
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
@ -135,7 +127,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation.""" """Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self.expected_image_keys) > 0:
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch) batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@ -160,6 +153,97 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
return loss_dict return loss_dict
class ACTTemporalEnsembler:
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
"""Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
coefficient works:
- Setting it to 0 uniformly weighs all actions.
- Setting it positive gives more weight to older actions.
- Setting it negative gives more weight to newer actions.
NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
results in older actions being weighed more highly than newer actions (the experiments documented in
https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
detrimental: doing so aggressively may diminish the benefits of action chunking).
Here we use an online method for computing the average rather than caching a history of actions in
order to compute the average offline. For a simple 1D sequence it looks something like:
```
import torch
seq = torch.linspace(8, 8.5, 100)
print(seq)
m = 0.01
exp_weights = torch.exp(-m * torch.arange(len(seq)))
print(exp_weights)
# Calculate offline
avg = (exp_weights * seq).sum() / exp_weights.sum()
print("offline", avg)
# Calculate online
for i, item in enumerate(seq):
if i == 0:
avg = item
continue
avg *= exp_weights[:i].sum()
avg += item * exp_weights[i]
avg /= exp_weights[:i+1].sum()
print("online", avg)
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
def reset(self):
"""Resets the online computation variables."""
self.ensembled_actions = None
# (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
self.ensembled_actions_count = None
def update(self, actions: Tensor) -> Tensor:
"""
Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self.ensembled_actions = actions.clone()
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
self.ensembled_actions[:, 0],
self.ensembled_actions[:, 1:],
self.ensembled_actions_count[1:],
)
return action
class ACT(nn.Module): class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy. """Action Chunking Transformer: The underlying neural network for ACTPolicy.
@ -200,12 +284,14 @@ class ACT(nn.Module):
self.config = config self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_input_state = "observation.state" in config.input_shapes self.use_robot_state = "observation.state" in config.input_shapes
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
self.use_env_state = "observation.environment_state" in config.input_shapes
if self.config.use_vae: if self.config.use_vae:
self.vae_encoder = ACTEncoder(config) self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension. # Projection layer for joint-space configuration to hidden dimension.
if self.use_input_state: if self.use_robot_state:
self.vae_encoder_robot_state_input_proj = nn.Linear( self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model config.input_shapes["observation.state"][0], config.dim_model
) )
@ -218,7 +304,7 @@ class ACT(nn.Module):
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension. # dimension.
num_input_token_encoder = 1 + config.chunk_size num_input_token_encoder = 1 + config.chunk_size
if self.use_input_state: if self.use_robot_state:
num_input_token_encoder += 1 num_input_token_encoder += 1
self.register_buffer( self.register_buffer(
"vae_encoder_pos_enc", "vae_encoder_pos_enc",
@ -226,34 +312,45 @@ class ACT(nn.Module):
) )
# Backbone for image feature extraction. # Backbone for image feature extraction.
backbone_model = getattr(torchvision.models, config.vision_backbone)( if self.use_images:
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], backbone_model = getattr(torchvision.models, config.vision_backbone)(
weights=config.pretrained_backbone_weights, replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
norm_layer=FrozenBatchNorm2d, weights=config.pretrained_backbone_weights,
) norm_layer=FrozenBatchNorm2d,
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature )
# map). # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# Note: The forward method of this returns a dict: {"feature_map": output}. # feature map).
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) # Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective). # Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config) self.encoder = ACTEncoder(config)
self.decoder = ACTDecoder(config) self.decoder = ACTDecoder(config)
# Transformer encoder input projections. The tokens will be structured like # Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels]. # [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.use_input_state: if self.use_robot_state:
self.encoder_robot_state_input_proj = nn.Linear( self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model config.input_shapes["observation.state"][0], config.dim_model
) )
if self.use_env_state:
self.encoder_env_state_input_proj = nn.Linear(
config.input_shapes["observation.environment_state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model) self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d( if self.use_images:
backbone_model.fc.in_features, config.dim_model, kernel_size=1 self.encoder_img_feat_input_proj = nn.Conv2d(
) backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings. # Transformer encoder positional embeddings.
num_input_token_decoder = 2 if self.use_input_state else 1 n_1d_tokens = 1 # for the latent
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model) if self.use_robot_state:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) n_1d_tokens += 1
if self.use_env_state:
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.use_images:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder. # Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
@ -274,10 +371,13 @@ class ACT(nn.Module):
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder). """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure: `batch` should have the following structure:
{ {
"observation.state": (B, state_dim) batch of robot states. "observation.state" (optional): (B, state_dim) batch of robot states.
"observation.images": (B, n_cameras, C, H, W) batch of images. "observation.images": (B, n_cameras, C, H, W) batch of images.
AND/OR
"observation.environment_state": (B, env_dim) batch of environment states.
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions. "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
} }
@ -291,7 +391,11 @@ class ACT(nn.Module):
"action" in batch "action" in batch
), "actions must be provided when using the variational objective in training mode." ), "actions must be provided when using the variational objective in training mode."
batch_size = batch["observation.images"].shape[0] batch_size = (
batch["observation.images"]
if "observation.images" in batch
else batch["observation.environment_state"]
).shape[0]
# Prepare the latent for input to the transformer encoder. # Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch: if self.config.use_vae and "action" in batch:
@ -299,12 +403,12 @@ class ACT(nn.Module):
cls_embed = einops.repeat( cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D) ) # (B, 1, D)
if self.use_input_state: if self.use_robot_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.use_input_state: if self.use_robot_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else: else:
vae_encoder_input = [cls_embed, action_embed] vae_encoder_input = [cls_embed, action_embed]
@ -318,7 +422,7 @@ class ACT(nn.Module):
# sequence depending whether we use the input states or not (cls and robot state) # sequence depending whether we use the input states or not (cls and robot state)
# False means not a padding token. # False means not a padding token.
cls_joint_is_pad = torch.full( cls_joint_is_pad = torch.full(
(batch_size, 2 if self.use_input_state else 1), (batch_size, 2 if self.use_robot_state else 1),
False, False,
device=batch["observation.state"].device, device=batch["observation.state"].device,
) )
@ -347,56 +451,55 @@ class ACT(nn.Module):
batch["observation.state"].device batch["observation.state"].device
) )
# Prepare all other transformer encoder inputs. # Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.use_robot_state:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.use_env_state:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings. # Camera observation features and positional embeddings.
all_cam_features = [] if self.use_images:
all_cam_pos_embeds = [] all_cam_features = []
images = batch["observation.images"] all_cam_pos_embeds = []
images = batch["observation.images"]
for cam_index in range(images.shape[-4]): for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"] cam_features = self.backbone(images[:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) # buffer
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
all_cam_features.append(cam_features) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_pos_embeds.append(cam_pos_embed) all_cam_features.append(cam_features)
# Concatenate camera observation feature maps and positional embeddings along the width dimension. all_cam_pos_embeds.append(cam_pos_embed)
encoder_in = torch.cat(all_cam_features, axis=-1) # Concatenate camera observation feature maps and positional embeddings along the width dimension,
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1) # and move to (sequence, batch, dim).
all_cam_features = torch.cat(all_cam_features, axis=-1)
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
# Get positional embeddings for robot state and latent. # Stack all tokens along the sequence dimension.
if self.use_input_state: encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C) encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
encoder_in = torch.cat(
[
torch.stack(encoder_in_feats, axis=0),
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
]
)
pos_embed = torch.cat(
[
self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
cam_pos_embed.flatten(2).permute(2, 0, 1),
],
axis=0,
)
# Forward pass through the transformer modules. # Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros( decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model), (self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype, dtype=encoder_in_pos_embed.dtype,
device=pos_embed.device, device=encoder_in_pos_embed.device,
) )
decoder_out = self.decoder( decoder_out = self.decoder(
decoder_in, decoder_in,
encoder_out, encoder_out,
encoder_pos_embed=pos_embed, encoder_pos_embed=encoder_in_pos_embed,
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
) )

View File

@ -298,7 +298,8 @@ class VQBeTModel(nn.Module):
# bin prediction head / offset prediction head part of VQ-BeT # bin prediction head / offset prediction head part of VQ-BeT
self.action_head = VQBeTHead(config) self.action_head = VQBeTHead(config)
num_tokens = self.config.n_action_pred_token + self.config.action_chunk_size - 1 # Action tokens for: each observation step, the current action token, and all future action tokens.
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
self.register_buffer( self.register_buffer(
"select_target_actions_indices", "select_target_actions_indices",
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]), torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),

View File

@ -0,0 +1,404 @@
"""
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
"""
import argparse
import concurrent.futures
import math
import shutil
import threading
import time
from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
import cv2
import numpy as np
from PIL import Image
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
from lerobot.scripts.control_robot import busy_wait
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
# when other threads are used to save the images.
cv2.setNumThreads(1)
# The maximum opencv device index depends on your operating system. For instance,
# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
# When you change the USB port or reboot the computer, the operating system might
# treat the same cameras as new devices. Thus we select a higher bound to search indices.
MAX_OPENCV_INDEX = 60
def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX):
camera_ids = []
for camera_idx in range(max_index_search_range):
camera = cv2.VideoCapture(camera_idx)
is_open = camera.isOpened()
camera.release()
if is_open:
print(f"Camera found at index {camera_idx}")
camera_ids.append(camera_idx)
if raise_when_empty and len(camera_ids) == 0:
raise OSError(
"Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2."
)
return camera_ids
def save_image(img_array, camera_index, frame_index, images_dir):
img = Image.fromarray(img_array)
path = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
def save_images_from_cameras(
images_dir: Path, camera_ids=None, fps=None, width=None, height=None, record_time_s=2
):
if camera_ids is None:
print("Finding available camera indices")
camera_ids = find_camera_indices()
print("Connecting cameras")
cameras = []
for cam_idx in camera_ids:
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height)
camera.connect()
print(
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
)
cameras.append(camera)
images_dir = Path(
images_dir,
)
if images_dir.exists():
shutil.rmtree(
images_dir,
)
images_dir.mkdir(parents=True, exist_ok=True)
print(f"Saving images to {images_dir}")
frame_index = 0
start_time = time.perf_counter()
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
while True:
now = time.perf_counter()
for camera in cameras:
# If we use async_read when fps is None, the loop will go full speed, and we will endup
# saving the same images from the cameras multiple times until the RAM/disk is full.
image = camera.read() if fps is None else camera.async_read()
executor.submit(
save_image,
image,
camera.camera_index,
frame_index,
images_dir,
)
if fps is not None:
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
if time.perf_counter() - start_time > record_time_s:
break
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
frame_index += 1
print(f"Images have been saved to {images_dir}")
@dataclass
class OpenCVCameraConfig:
"""
Example of tested options for Intel Real Sense D405:
```python
OpenCVCameraConfig(30, 640, 480)
OpenCVCameraConfig(60, 640, 480)
OpenCVCameraConfig(90, 640, 480)
OpenCVCameraConfig(30, 1280, 720)
```
"""
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"Expected color_mode values are 'rgb' or 'bgr', but {self.color_mode} is provided."
)
class OpenCVCamera:
"""
The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
An OpenCVCamera instance requires a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera
like a webcam of a laptop, the camera index is expected to be 0, but it might also be very different, and the camera index
might change if you reboot your computer or re-plug your camera. This behavior depends on your operation system.
To find the camera indices of your cameras, you can run our utility script that will be save a few frames for each camera:
```bash
python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras
```
When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
of the given camera will be used.
Example of usage of the class:
```python
camera = OpenCVCamera(camera_index=0)
camera.connect()
color_image = camera.read()
# when done using the camera, consider disconnecting
camera.disconnect()
```
Example of changing default fps, width, height and color_mode:
```python
camera = OpenCVCamera(0, fps=30, width=1280, height=720)
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
camera = OpenCVCamera(0, fps=90, width=640, height=480)
camera = connect()
camera = OpenCVCamera(0, fps=90, width=640, height=480, color_mode="bgr")
camera = connect()
```
"""
def __init__(self, camera_index: int, config: OpenCVCameraConfig | None = None, **kwargs):
if config is None:
config = OpenCVCameraConfig()
# Overwrite config arguments using kwargs
config = replace(config, **kwargs)
self.camera_index = camera_index
self.fps = config.fps
self.width = config.width
self.height = config.height
self.color_mode = config.color_mode
if not isinstance(self.camera_index, int):
raise ValueError(
f"Camera index must be provided as an int, but {self.camera_index} was given instead."
)
self.camera = None
self.is_connected = False
self.thread = None
self.stop_event = None
self.color_image = None
self.logs = {}
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(f"Camera {self.camera_index} is already connected.")
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(self.camera_index)
is_camera_open = tmp_camera.isOpened()
# Release camera to make it accessible for `find_camera_indices`
del tmp_camera
# If the camera doesn't work, display the camera indices corresponding to
# valid cameras.
if not is_camera_open:
# Verify that the provided `camera_index` is valid before printing the traceback
available_cam_ids = find_camera_indices()
if self.camera_index not in available_cam_ids:
raise ValueError(
f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead."
)
raise OSError(f"Can't access camera {self.camera_index}.")
# Secondly, create the camera that will be used downstream.
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
# needs to be re-created.
self.camera = cv2.VideoCapture(self.camera_index)
if self.fps is not None:
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
if self.width is not None:
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
if self.height is not None:
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
raise OSError(
f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}."
)
if self.width is not None and self.width != actual_width:
raise OSError(
f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}."
)
if self.height is not None and self.height != actual_height:
raise OSError(
f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}."
)
self.fps = actual_fps
self.width = actual_width
self.height = actual_height
self.is_connected = True
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
"""Read a frame from the camera returned in the format (height, width, channels)
(e.g. (640, 480, 3)), contrarily to the pytorch format which is channel first.
Note: Reading a frame is done every `camera.fps` times per second, and it is blocking.
If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
)
start_time = time.perf_counter()
ret, color_image = self.camera.read()
if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
)
# OpenCV uses BGR format as default (blue, green red) for all operations, including displaying images.
# However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks,
# so we convert the image color from BGR to RGB.
if requested_color_mode == "rgb":
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
h, w, _ = color_image.shape
if h != self.height or w != self.width:
raise OSError(
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
)
# log the number of seconds it took to read the image
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
# log the utc time at which the image was received
self.logs["timestamp_utc"] = capture_timestamp_utc()
return color_image
def read_loop(self):
while self.stop_event is None or not self.stop_event.is_set():
self.color_image = self.read()
def async_read(self):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
)
if self.thread is None:
self.stop_event = threading.Event()
self.thread = Thread(target=self.read_loop, args=())
self.thread.daemon = True
self.thread.start()
num_tries = 0
while self.color_image is None:
num_tries += 1
time.sleep(1 / self.fps)
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
raise Exception(
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
)
return self.color_image
def disconnect(self):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
)
if self.thread is not None and self.thread.is_alive():
# wait for the thread to finish
self.stop_event.set()
self.thread.join()
self.thread = None
self.stop_event = None
self.camera.release()
self.camera = None
self.is_connected = False
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Save a few frames using `OpenCVCamera` for all cameras connected to the computer, or a selected subset."
)
parser.add_argument(
"--camera-ids",
type=int,
nargs="*",
default=None,
help="List of camera indices used to instantiate the `OpenCVCamera`. If not provided, find and use all available camera indices.",
)
parser.add_argument(
"--fps",
type=int,
default=None,
help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.",
)
parser.add_argument(
"--width",
type=str,
default=None,
help="Set the width for all cameras. If not provided, use the default width of each camera.",
)
parser.add_argument(
"--height",
type=str,
default=None,
help="Set the height for all cameras. If not provided, use the default height of each camera.",
)
parser.add_argument(
"--images-dir",
type=Path,
default="outputs/images_from_opencv_cameras",
help="Set directory to save a few frames for each camera.",
)
parser.add_argument(
"--record-time-s",
type=float,
default=2.0,
help="Set the number of seconds used to record the frames. By default, 2 seconds.",
)
args = parser.parse_args()
save_images_from_cameras(**vars(args))

View File

@ -0,0 +1,47 @@
from pathlib import Path
from typing import Protocol
import cv2
import numpy as np
def write_shape_on_image_inplace(image):
height, width = image.shape[:2]
text = f"Width: {width} Height: {height}"
# Define the font, scale, color, and thickness
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1
color = (255, 0, 0) # Blue in BGR
thickness = 2
position = (10, height - 10) # 10 pixels from the bottom-left corner
cv2.putText(image, text, position, font, font_scale, color, thickness)
def save_color_image(image, path, write_shape=False):
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
if write_shape:
write_shape_on_image_inplace(image)
cv2.imwrite(str(path), image)
def save_depth_image(depth, path, write_shape=False):
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
# Apply colormap on depth image (image must be converted to 8-bit per pixel first)
depth_image = cv2.applyColorMap(cv2.convertScaleAbs(depth, alpha=0.03), cv2.COLORMAP_JET)
if write_shape:
write_shape_on_image_inplace(depth_image)
cv2.imwrite(str(path), depth_image)
# Defines a camera type
class Camera(Protocol):
def connect(self): ...
def read(self, temporary_color: str | None = None) -> np.ndarray: ...
def async_read(self) -> np.ndarray: ...
def disconnect(self): ...

View File

@ -0,0 +1,492 @@
import enum
import time
import traceback
from copy import deepcopy
from pathlib import Path
import numpy as np
from dynamixel_sdk import (
COMM_SUCCESS,
DXL_HIBYTE,
DXL_HIWORD,
DXL_LOBYTE,
DXL_LOWORD,
GroupSyncRead,
GroupSyncWrite,
PacketHandler,
PortHandler,
)
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 2.0
BAUD_RATE = 1_000_000
TIMEOUT_MS = 1000
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
# https://emanual.robotis.com/docs/en/dxl/x/xm430-w350
# https://emanual.robotis.com/docs/en/dxl/x/xm540-w270
# data_name: (address, size_byte)
X_SERIES_CONTROL_TABLE = {
"Model_Number": (0, 2),
"Model_Information": (2, 4),
"Firmware_Version": (6, 1),
"ID": (7, 1),
"Baud_Rate": (8, 1),
"Return_Delay_Time": (9, 1),
"Drive_Mode": (10, 1),
"Operating_Mode": (11, 1),
"Secondary_ID": (12, 1),
"Protocol_Type": (13, 1),
"Homing_Offset": (20, 4),
"Moving_Threshold": (24, 4),
"Temperature_Limit": (31, 1),
"Max_Voltage_Limit": (32, 2),
"Min_Voltage_Limit": (34, 2),
"PWM_Limit": (36, 2),
"Current_Limit": (38, 2),
"Acceleration_Limit": (40, 4),
"Velocity_Limit": (44, 4),
"Max_Position_Limit": (48, 4),
"Min_Position_Limit": (52, 4),
"Shutdown": (63, 1),
"Torque_Enable": (64, 1),
"LED": (65, 1),
"Status_Return_Level": (68, 1),
"Registered_Instruction": (69, 1),
"Hardware_Error_Status": (70, 1),
"Velocity_I_Gain": (76, 2),
"Velocity_P_Gain": (78, 2),
"Position_D_Gain": (80, 2),
"Position_I_Gain": (82, 2),
"Position_P_Gain": (84, 2),
"Feedforward_2nd_Gain": (88, 2),
"Feedforward_1st_Gain": (90, 2),
"Bus_Watchdog": (98, 1),
"Goal_PWM": (100, 2),
"Goal_Current": (102, 2),
"Goal_Velocity": (104, 4),
"Profile_Acceleration": (108, 4),
"Profile_Velocity": (112, 4),
"Goal_Position": (116, 4),
"Realtime_Tick": (120, 2),
"Moving": (122, 1),
"Moving_Status": (123, 1),
"Present_PWM": (124, 2),
"Present_Current": (126, 2),
"Present_Velocity": (128, 4),
"Present_Position": (132, 4),
"Velocity_Trajectory": (136, 4),
"Position_Trajectory": (140, 4),
"Present_Input_Voltage": (144, 2),
"Present_Temperature": (146, 1),
}
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
MODEL_CONTROL_TABLE = {
"x_series": X_SERIES_CONTROL_TABLE,
"xl330-m077": X_SERIES_CONTROL_TABLE,
"xl330-m288": X_SERIES_CONTROL_TABLE,
"xl430-w250": X_SERIES_CONTROL_TABLE,
"xm430-w350": X_SERIES_CONTROL_TABLE,
"xm540-w270": X_SERIES_CONTROL_TABLE,
}
NUM_READ_RETRY = 10
def get_group_sync_key(data_name, motor_names):
group_key = f"{data_name}_" + "_".join(motor_names)
return group_key
def get_result_name(fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
rslt_name = f"{fn_name}_{group_key}"
return rslt_name
def get_queue_name(fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
queue_name = f"{fn_name}_{group_key}"
return queue_name
def get_log_name(var_name, fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
log_name = f"{var_name}_{fn_name}_{group_key}"
return log_name
def assert_same_address(model_ctrl_table, motor_models, data_name):
all_addr = []
all_bytes = []
for model in motor_models:
addr, bytes = model_ctrl_table[model][data_name]
all_addr.append(addr)
all_bytes.append(bytes)
if len(set(all_addr)) != 1:
raise NotImplementedError(
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
)
if len(set(all_bytes)) != 1:
raise NotImplementedError(
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
)
def find_available_ports():
ports = []
for path in Path("/dev").glob("tty*"):
ports.append(str(path))
return ports
def find_port():
print("Finding all available ports for the DynamixelMotorsBus.")
ports_before = find_available_ports()
print(ports_before)
print("Remove the usb cable from your DynamixelMotorsBus and press Enter when done.")
input()
time.sleep(0.5)
ports_after = find_available_ports()
ports_diff = list(set(ports_before) - set(ports_after))
if len(ports_diff) == 1:
port = ports_diff[0]
print(f"The port of this DynamixelMotorsBus is '{port}'")
print("Reconnect the usb cable.")
elif len(ports_diff) == 0:
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
else:
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
class TorqueMode(enum.Enum):
ENABLED = 1
DISABLED = 0
class OperatingMode(enum.Enum):
VELOCITY = 1
POSITION = 3
EXTENDED_POSITION = 4
CURRENT_CONTROLLED_POSITION = 5
PWM = 16
UNKNOWN = -1
class DriveMode(enum.Enum):
NON_INVERTED = 0
INVERTED = 1
class DynamixelMotorsBus:
# TODO(rcadene): Add a script to find the motor indices without DynamixelWizzard2
"""
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
A DynamixelMotorsBus instance requires a port (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
```bash
python lerobot/common/robot_devices/motors/dynamixel.py
>>> Finding all available ports for the DynamixelMotorsBus.
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
>>> Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
>>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751.
>>> Reconnect the usb cable.
```
To find the motor indices, use [DynamixelWizzard2](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_wizard2).
Example of usage for 1 motor connected to the bus:
```python
motor_name = "gripper"
motor_index = 6
motor_model = "xl330-m077"
motors_bus = DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0031751",
motors={motor_name: (motor_index, motor_model)},
)
motors_bus.connect()
motors_bus.teleop_step()
# when done, consider disconnecting
motors_bus.disconnect()
```
"""
def __init__(
self,
port: str,
motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None,
):
self.port = port
self.motors = motors
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table)
self.port_handler = None
self.packet_handler = None
self.calibration = None
self.is_connected = False
self.group_readers = {}
self.group_writers = {}
self.logs = {}
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
)
self.port_handler = PortHandler(self.port)
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
try:
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
except Exception:
traceback.print_exc()
print(
"\nTry running `python lerobot/common/robot_devices/motors/dynamixel.py` to make sure you are using the correct port.\n"
)
raise
self.port_handler.setBaudRate(BAUD_RATE)
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
self.is_connected = True
@property
def motor_names(self) -> list[int]:
return list(self.motors.keys())
def set_calibration(self, calibration: dict[str, tuple[int, bool]]):
self.calibration = calibration
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
if not self.calibration:
return values
if motor_names is None:
motor_names = self.motor_names
for i, name in enumerate(motor_names):
homing_offset, drive_mode = self.calibration[name]
if values[i] is not None:
if drive_mode:
values[i] *= -1
values[i] += homing_offset
return values
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
if not self.calibration:
return values
if motor_names is None:
motor_names = self.motor_names
for i, name in enumerate(motor_names):
homing_offset, drive_mode = self.calibration[name]
if values[i] is not None:
values[i] -= homing_offset
if drive_mode:
values[i] *= -1
return values
def read(self, data_name, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
)
start_time = time.perf_counter()
if motor_names is None:
motor_names = self.motor_names
if isinstance(motor_names, str):
motor_names = [motor_names]
motor_ids = []
models = []
for name in motor_names:
motor_idx, model = self.motors[name]
motor_ids.append(motor_idx)
models.append(model)
assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names)
if data_name not in self.group_readers:
# create new group reader
self.group_readers[group_key] = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
for idx in motor_ids:
self.group_readers[group_key].addParam(idx)
for _ in range(NUM_READ_RETRY):
comm = self.group_readers[group_key].txRxPacket()
if comm == COMM_SUCCESS:
break
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
values = []
for idx in motor_ids:
value = self.group_readers[group_key].getData(idx, addr, bytes)
values.append(value)
values = np.array(values)
# Convert to signed int to use range [-2048, 2048] for our motor positions.
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
values = values.astype(np.int32)
if data_name in CALIBRATION_REQUIRED:
values = self.apply_calibration(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
self.logs[ts_utc_name] = capture_timestamp_utc()
return values
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
)
start_time = time.perf_counter()
if motor_names is None:
motor_names = self.motor_names
if isinstance(motor_names, str):
motor_names = [motor_names]
if isinstance(values, (int, float, np.integer)):
values = [int(values)] * len(motor_names)
values = np.array(values)
motor_ids = []
models = []
for name in motor_names:
motor_idx, model = self.motors[name]
motor_ids.append(motor_idx)
models.append(model)
if data_name in CALIBRATION_REQUIRED:
values = self.revert_calibration(values, motor_names)
values = values.tolist()
assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names)
init_group = data_name not in self.group_readers
if init_group:
self.group_writers[group_key] = GroupSyncWrite(
self.port_handler, self.packet_handler, addr, bytes
)
for idx, value in zip(motor_ids, values, strict=True):
# Note: No need to convert back into unsigned int, since this byte preprocessing
# already handles it for us.
if bytes == 1:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
]
elif bytes == 2:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
DXL_HIBYTE(DXL_LOWORD(value)),
]
elif bytes == 4:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
DXL_HIBYTE(DXL_LOWORD(value)),
DXL_LOBYTE(DXL_HIWORD(value)),
DXL_HIBYTE(DXL_HIWORD(value)),
]
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead."
)
if init_group:
self.group_writers[group_key].addParam(idx, data)
else:
self.group_writers[group_key].changeParam(idx, data)
comm = self.group_writers[group_key].txPacket()
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?
# log the utc time when the write has been completed
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
self.logs[ts_utc_name] = capture_timestamp_utc()
def disconnect(self):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
)
if self.port_handler is not None:
self.port_handler.closePort()
self.port_handler = None
self.packet_handler = None
self.group_readers = {}
self.group_writers = {}
self.is_connected = False
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()
if __name__ == "__main__":
# Helper to find the usb port associated to all your DynamixelMotorsBus.
find_port()

View File

@ -0,0 +1,10 @@
from typing import Protocol
class MotorsBus(Protocol):
def motor_names(self): ...
def set_calibration(self): ...
def apply_calibration(self): ...
def revert_calibration(self): ...
def read(self): ...
def write(self): ...

View File

@ -0,0 +1,46 @@
def make_robot(name):
if name == "koch":
# TODO(rcadene): Add configurable robot from command line and yaml config
# TODO(rcadene): Add example with and without cameras
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
from lerobot.common.robot_devices.robots.koch import KochRobot
robot = KochRobot(
leader_arms={
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl330-m077"),
"shoulder_lift": (2, "xl330-m077"),
"elbow_flex": (3, "xl330-m077"),
"wrist_flex": (4, "xl330-m077"),
"wrist_roll": (5, "xl330-m077"),
"gripper": (6, "xl330-m077"),
},
),
},
follower_arms={
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl430-w250"),
"shoulder_lift": (2, "xl430-w250"),
"elbow_flex": (3, "xl330-m288"),
"wrist_flex": (4, "xl330-m288"),
"wrist_roll": (5, "xl330-m288"),
"gripper": (6, "xl330-m288"),
},
),
},
cameras={
"laptop": OpenCVCamera(0, fps=30, width=640, height=480),
"phone": OpenCVCamera(1, fps=30, width=640, height=480),
},
)
else:
raise ValueError(f"Robot '{name}' not found.")
return robot

View File

@ -0,0 +1,548 @@
import pickle
import time
from dataclasses import dataclass, field, replace
from pathlib import Path
import numpy as np
import torch
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
TorqueMode,
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
URL_HORIZONTAL_POSITION = {
"follower": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/follower_horizontal.png",
"leader": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/leader_horizontal.png",
}
URL_90_DEGREE_POSITION = {
"follower": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/follower_90_degree.png",
"leader": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/leader_90_degree.png",
}
########################################################################
# Calibration logic
########################################################################
TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
GRIPPER_OPEN = np.array([-400])
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
for i in range(len(values)):
if values[i] is not None:
values[i] += homing_offset[i]
return values
def apply_drive_mode(values: np.array, drive_mode: np.array) -> np.array:
for i in range(len(values)):
if values[i] is not None and drive_mode[i]:
values[i] = -values[i]
return values
def apply_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
values = apply_drive_mode(values, drive_mode)
values = apply_homing_offset(values, homing_offset)
return values
def revert_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
"""
Transform working position into real position for the robot.
"""
values = apply_homing_offset(
values,
np.array([-homing_offset if homing_offset is not None else None for homing_offset in homing_offset]),
)
values = apply_drive_mode(values, drive_mode)
return values
def revert_appropriate_positions(positions: np.array, drive_mode: list[bool]) -> np.array:
for i, revert in enumerate(drive_mode):
if not revert and positions[i] is not None:
positions[i] = -positions[i]
return positions
def compute_corrections(positions: np.array, drive_mode: list[bool], target_position: np.array) -> np.array:
correction = revert_appropriate_positions(positions, drive_mode)
for i in range(len(positions)):
if correction[i] is not None:
if drive_mode[i]:
correction[i] -= target_position[i]
else:
correction[i] += target_position[i]
return correction
def compute_nearest_rounded_positions(positions: np.array) -> np.array:
return np.array(
[
round(positions[i] / 1024) * 1024 if positions[i] is not None else None
for i in range(len(positions))
]
)
def compute_homing_offset(
arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array
) -> np.array:
# Get the present positions of the servos
present_positions = apply_calibration(
arm.read("Present_Position"), np.array([0, 0, 0, 0, 0, 0]), drive_mode
)
nearest_positions = compute_nearest_rounded_positions(present_positions)
correction = compute_corrections(nearest_positions, drive_mode, target_position)
return correction
def compute_drive_mode(arm: DynamixelMotorsBus, offset: np.array):
# Get current positions
present_positions = apply_calibration(
arm.read("Present_Position"), offset, np.array([False, False, False, False, False, False])
)
nearest_positions = compute_nearest_rounded_positions(present_positions)
# construct 'drive_mode' list comparing nearest_positions and TARGET_90_DEGREE_POSITION
drive_mode = []
for i in range(len(nearest_positions)):
drive_mode.append(nearest_positions[i] != TARGET_90_DEGREE_POSITION[i])
return drive_mode
def reset_arm(arm: MotorsBus):
# To be configured, all servos must be in "torque disable" mode
arm.write("Torque_Enable", TorqueMode.DISABLED.value)
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper)
# TODO(rcadene): why?
# Use 'position control current based' for gripper
arm.write("Operating_Mode", OperatingMode.CURRENT_CONTROLLED_POSITION.value, "gripper")
# Make sure the native calibration (homing offset abd drive mode) is disabled, since we use our own calibration layer to be more generic
arm.write("Homing_Offset", 0)
arm.write("Drive_Mode", DriveMode.NON_INVERTED.value)
def run_arm_calibration(arm: MotorsBus, name: str, arm_type: str):
"""Example of usage:
```python
run_arm_calibration(arm, "left", "follower")
```
"""
reset_arm(arm)
# TODO(rcadene): document what position 1 mean
print(
f"Please move the '{name} {arm_type}' arm to the horizontal position (gripper fully closed, see {URL_HORIZONTAL_POSITION[arm_type]})"
)
input("Press Enter to continue...")
horizontal_homing_offset = compute_homing_offset(
arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION
)
# TODO(rcadene): document what position 2 mean
print(
f"Please move the '{name} {arm_type}' arm to the 90 degree position (gripper fully open, see {URL_90_DEGREE_POSITION[arm_type]})"
)
input("Press Enter to continue...")
drive_mode = compute_drive_mode(arm, horizontal_homing_offset)
homing_offset = compute_homing_offset(arm, drive_mode, TARGET_90_DEGREE_POSITION)
# Invert offset for all drive_mode servos
for i in range(len(drive_mode)):
if drive_mode[i]:
homing_offset[i] = -homing_offset[i]
print("Calibration is done!")
print("=====================================")
print(" HOMING_OFFSET: ", " ".join([str(i) for i in homing_offset]))
print(" DRIVE_MODE: ", " ".join([str(i) for i in drive_mode]))
print("=====================================")
return homing_offset, drive_mode
########################################################################
# Alexander Koch robot arm
########################################################################
@dataclass
class KochRobotConfig:
"""
Example of usage:
```python
KochRobotConfig()
```
"""
# Define all components of the robot
leader_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
cameras: dict[str, Camera] = field(default_factory=lambda: {})
class KochRobot:
# TODO(rcadene): Implement force feedback
"""Tau Robotics: https://tau-robotics.com
Example of highest frequency teleoperation without camera:
```python
# Defines how to communicate with the motors of the leader and follower arms
leader_arms = {
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl330-m077"),
"shoulder_lift": (2, "xl330-m077"),
"elbow_flex": (3, "xl330-m077"),
"wrist_flex": (4, "xl330-m077"),
"wrist_roll": (5, "xl330-m077"),
"gripper": (6, "xl330-m077"),
},
),
}
follower_arms = {
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl430-w250"),
"shoulder_lift": (2, "xl430-w250"),
"elbow_flex": (3, "xl330-m288"),
"wrist_flex": (4, "xl330-m288"),
"wrist_roll": (5, "xl330-m288"),
"gripper": (6, "xl330-m288"),
},
),
}
robot = KochRobot(leader_arms, follower_arms)
# Connect motors buses and cameras if any (Required)
robot.connect()
while True:
robot.teleop_step()
```
Example of highest frequency data collection without camera:
```python
# Assumes leader and follower arms have been instantiated already (see first example)
robot = KochRobot(leader_arms, follower_arms)
robot.connect()
while True:
observation, action = robot.teleop_step(record_data=True)
```
Example of highest frequency data collection with cameras:
```python
# Defines how to communicate with 2 cameras connected to the computer.
# Here, the webcam of the mackbookpro and the iphone (connected in USB to the macbookpro)
# can be reached respectively using the camera indices 0 and 1. These indices can be
# arbitrary. See the documentation of `OpenCVCamera` to find your own camera indices.
cameras = {
"macbookpro": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
"iphone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
}
# Assumes leader and follower arms have been instantiated already (see first example)
robot = KochRobot(leader_arms, follower_arms, cameras)
robot.connect()
while True:
observation, action = robot.teleop_step(record_data=True)
```
Example of controlling the robot with a policy (without running multiple policies in parallel to ensure highest frequency):
```python
# Assumes leader and follower arms + cameras have been instantiated already (see previous example)
robot = KochRobot(leader_arms, follower_arms, cameras)
robot.connect()
while True:
# Uses the follower arms and cameras to capture an observation
observation = robot.capture_observation()
# Assumes a policy has been instantiated
with torch.inference_mode():
action = policy.select_action(observation)
# Orders the robot to move
robot.send_action(action)
```
Example of disconnecting which is not mandatory since we disconnect when the object is deleted:
```python
robot.disconnect()
```
"""
def __init__(
self,
config: KochRobotConfig | None = None,
calibration_path: Path = ".cache/calibration/koch.pkl",
**kwargs,
):
if config is None:
config = KochRobotConfig()
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
self.calibration_path = Path(calibration_path)
self.leader_arms = self.config.leader_arms
self.follower_arms = self.config.follower_arms
self.cameras = self.config.cameras
self.is_connected = False
self.logs = {}
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
"KochRobot is already connected. Do not run `robot.connect()` twice."
)
if not self.leader_arms and not self.follower_arms and not self.cameras:
raise ValueError(
"KochRobot doesn't have any device to connect. See example of usage in docstring of the class."
)
# Connect the arms
for name in self.follower_arms:
self.follower_arms[name].connect()
self.leader_arms[name].connect()
# Reset the arms and load or run calibration
if self.calibration_path.exists():
# Reset all arms before setting calibration
for name in self.follower_arms:
reset_arm(self.follower_arms[name])
for name in self.leader_arms:
reset_arm(self.leader_arms[name])
with open(self.calibration_path, "rb") as f:
calibration = pickle.load(f)
else:
# Run calibration process which begins by reseting all arms
calibration = self.run_calibration()
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.calibration_path, "wb") as f:
pickle.dump(calibration, f)
# Set calibration
for name in self.follower_arms:
self.follower_arms[name].set_calibration(calibration[f"follower_{name}"])
for name in self.leader_arms:
self.leader_arms[name].set_calibration(calibration[f"leader_{name}"])
# Set better PID values to close the gap between recored states and actions
# TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor
for name in self.follower_arms:
self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")
# Enable torque on all motors of the follower arms
for name in self.follower_arms:
self.follower_arms[name].write("Torque_Enable", 1)
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms.
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN, "gripper")
# Connect the cameras
for name in self.cameras:
self.cameras[name].connect()
self.is_connected = True
def run_calibration(self):
calibration = {}
for name in self.follower_arms:
homing_offset, drive_mode = run_arm_calibration(self.follower_arms[name], name, "follower")
calibration[f"follower_{name}"] = {}
for idx, motor_name in enumerate(self.follower_arms[name].motor_names):
calibration[f"follower_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx])
for name in self.leader_arms:
homing_offset, drive_mode = run_arm_calibration(self.leader_arms[name], name, "leader")
calibration[f"leader_{name}"] = {}
for idx, motor_name in enumerate(self.leader_arms[name].motor_names):
calibration[f"leader_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx])
return calibration
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"KochRobot is not connected. You need to run `robot.connect()`."
)
# Prepare to assign the positions of the leader to the follower
leader_pos = {}
for name in self.leader_arms:
now = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position")
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - now
follower_goal_pos = {}
for name in self.leader_arms:
follower_goal_pos[name] = leader_pos[name]
# Send action
for name in self.follower_arms:
now = time.perf_counter()
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name])
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - now
# Early exit when recording data is not requested
if not record_data:
return
# TODO(rcadene): Add velocity and other info
# Read follower position
follower_pos = {}
for name in self.follower_arms:
now = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
# Create state by concatenating follower current position
state = []
for name in self.follower_arms:
if name in follower_pos:
state.append(follower_pos[name])
state = np.concatenate(state)
# Create action by concatenating follower goal position
action = []
for name in self.follower_arms:
if name in follower_goal_pos:
action.append(follower_goal_pos[name])
action = np.concatenate(action)
# Capture images from cameras
images = {}
for name in self.cameras:
now = time.perf_counter()
images[name] = self.cameras[name].async_read()
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
# Populate output dictionnaries and format to pytorch
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = torch.from_numpy(state)
action_dict["action"] = torch.from_numpy(action)
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
return obs_dict, action_dict
def capture_observation(self):
"""The returned observations do not have a batch dimension."""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"KochRobot is not connected. You need to run `robot.connect()`."
)
# Read follower position
follower_pos = {}
for name in self.follower_arms:
now = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
# Create state by concatenating follower current position
state = []
for name in self.follower_arms:
if name in follower_pos:
state.append(follower_pos[name])
state = np.concatenate(state)
# Capture images from cameras
images = {}
for name in self.cameras:
now = time.perf_counter()
images[name] = self.cameras[name].async_read()
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
# Populate output dictionnaries and format to pytorch
obs_dict = {}
obs_dict["observation.state"] = torch.from_numpy(state)
for name in self.cameras:
# Convert to pytorch format: channel first and float32 in [0,1]
img = torch.from_numpy(images[name])
img = img.type(torch.float32) / 255
img = img.permute(2, 0, 1).contiguous()
obs_dict[f"observation.images.{name}"] = img
return obs_dict
def send_action(self, action: torch.Tensor):
"""The provided action is expected to be a vector."""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"KochRobot is not connected. You need to run `robot.connect()`."
)
from_idx = 0
to_idx = 0
follower_goal_pos = {}
for name in self.follower_arms:
if name in self.follower_arms:
to_idx += len(self.follower_arms[name].motor_names)
follower_goal_pos[name] = action[from_idx:to_idx].numpy()
from_idx = to_idx
for name in self.follower_arms:
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name].astype(np.int32))
def disconnect(self):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"KochRobot is not connected. You need to run `robot.connect()` before disconnecting."
)
for name in self.follower_arms:
self.follower_arms[name].disconnect()
for name in self.leader_arms:
self.leader_arms[name].disconnect()
for name in self.cameras:
self.cameras[name].disconnect()
self.is_connected = False
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()

View File

@ -0,0 +1,9 @@
from typing import Protocol
class Robot(Protocol):
def init_teleop(self): ...
def run_calibration(self): ...
def teleop_step(self, record_data=False): ...
def capture_observation(self): ...
def send_action(self, action): ...

View File

@ -0,0 +1,19 @@
class RobotDeviceNotConnectedError(Exception):
"""Exception raised when the robot device is not connected."""
def __init__(
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
):
self.message = message
super().__init__(self.message)
class RobotDeviceAlreadyConnectedError(Exception):
"""Exception raised when the robot device is already connected."""
def __init__(
self,
message="This robot device is already connected. Try not calling `robot_device.connect()` twice.",
):
self.message = message
super().__init__(self.message)

View File

@ -17,7 +17,7 @@ import logging
import os.path as osp import os.path as osp
import random import random
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Generator from typing import Any, Generator
@ -172,3 +172,7 @@ def print_cuda_memory_usage():
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
def capture_timestamp_utc():
return datetime.now(timezone.utc)

10
lerobot/configs/env/koch_real.yaml vendored Normal file
View File

@ -0,0 +1,10 @@
# @package _global_
fps: 30
env:
name: real_world
task: null
state_dim: 6
action_dim: 6
fps: ${fps}

View File

@ -75,7 +75,7 @@ policy:
n_vae_encoder_layers: 4 n_vae_encoder_layers: 4
# Inference. # Inference.
temporal_ensemble_momentum: null temporal_ensemble_coeff: null
# Training and loss computation. # Training and loss computation.
dropout: 0.1 dropout: 0.1

View File

@ -0,0 +1,102 @@
# @package _global_
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_koch_real \
# env=koch_real
# ```
seed: 1000
dataset_repo_id: lerobot/koch_pick_place_lego
override_dataset_stats:
observation.images.laptop:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.phone:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: -1
save_freq: 10000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.laptop: [3, 480, 640]
observation.images.phone: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.laptop: mean_std
observation.images.phone: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@ -107,7 +107,7 @@ policy:
n_vae_encoder_layers: 4 n_vae_encoder_layers: 4
# Inference. # Inference.
temporal_ensemble_momentum: null temporal_ensemble_coeff: null
# Training and loss computation. # Training and loss computation.
dropout: 0.1 dropout: 0.1

View File

@ -103,7 +103,7 @@ policy:
n_vae_encoder_layers: 4 n_vae_encoder_layers: 4
# Inference. # Inference.
temporal_ensemble_momentum: null temporal_ensemble_coeff: null
# Training and loss computation. # Training and loss computation.
dropout: 0.1 dropout: 0.1

View File

@ -0,0 +1,735 @@
"""
Examples of usage:
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
```bash
python lerobot/scripts/control_robot.py teleoperate
```
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
```bash
python lerobot/scripts/control_robot.py teleoperate \
--fps 30
```
- Record one episode in order to test replay:
```bash
python lerobot/scripts/control_robot.py record_dataset \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
--num-episodes 1 \
--run-compute-stats 0
```
- Visualize dataset:
```bash
python lerobot/scripts/visualize_dataset.py \
--root tmp/data \
--repo-id $USER/koch_test \
--episode-index 0
```
- Replay this test episode:
```bash
python lerobot/scripts/control_robot.py replay_episode \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
--episode 0
```
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
```bash
python lerobot/scripts/control_robot.py record_dataset \
--fps 30 \
--root data \
--repo-id $USER/koch_pick_place_lego \
--num-episodes 50 \
--run-compute-stats 1 \
--warmup-time-s 2 \
--episode-time-s 30 \
--reset-time-s 10
```
**NOTE**: You can use your keyboard to control data recording flow.
- Tap right arrow key '->' to early exit while recording an episode and go to resseting the environment.
- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode.
- Tap left arrow key '<-' to early exit and re-record the current episode.
- Tap escape key 'esc' to stop the data recording.
This might require a sudo permission to allow your terminal to monitor keyboard events.
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
To avoid resuming by deleting the dataset, use `--force-override 1`.
- Train on this dataset with the ACT policy:
```bash
DATA_DIR=data python lerobot/scripts/train.py \
policy=act_koch_real \
env=koch_real \
dataset_repo_id=$USER/koch_pick_place_lego \
hydra.run.dir=outputs/train/act_koch_real
```
- Run the pretrained policy on the robot:
```bash
python lerobot/scripts/control_robot.py run_policy \
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
```
"""
import argparse
import concurrent.futures
import json
import logging
import os
import platform
import shutil
import time
from contextlib import nullcontext
from pathlib import Path
import torch
import tqdm
from huggingface_hub import create_branch
from omegaconf import DictConfig
from PIL import Image
from termcolor import colored
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
########################################################################################
# Utilities
########################################################################################
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
img = Image.fromarray(img_tensor.numpy())
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
def busy_wait(seconds):
# Significantly more accurate than `time.sleep`, and mendatory for our use case,
# but it consumes CPU cycles.
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
end_time = time.perf_counter() + seconds
while time.perf_counter() < end_time:
pass
def none_or_int(value):
if value == "None":
return None
return int(value)
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items += [f"ep:{episode_index}"]
if frame_index is not None:
log_items += [f"frame:{frame_index}"]
def log_dt(shortname, dt_val_s):
nonlocal log_items
log_items += [f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"]
# total step time displayed in milliseconds and its frequency
log_dt("dt", dt_s)
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRlead", robot.logs[key])
for name in robot.follower_arms:
key = f"write_follower_{name}_goal_pos_dt_s"
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
key = f"read_follower_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
for name in robot.cameras:
key = f"read_camera_{name}_dt_s"
if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key])
info_str = " ".join(log_items)
if fps is not None:
actual_fps = 1 / dt_s
if actual_fps < fps - 1:
info_str = colored(info_str, "yellow")
logging.info(info_str)
def get_is_headless():
if platform.system() == "Linux":
display = os.environ.get("DISPLAY")
if display is None or display == "":
return True
return False
########################################################################################
# Control modes
########################################################################################
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
robot.connect()
start_time = time.perf_counter()
while True:
now = time.perf_counter()
robot.teleop_step()
if fps is not None:
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
break
def record_dataset(
robot: Robot,
fps: int | None = None,
root="data",
repo_id="lerobot/debug",
warmup_time_s=2,
episode_time_s=10,
reset_time_s=5,
num_episodes=50,
video=True,
run_compute_stats=True,
push_to_hub=True,
num_image_writers=8,
force_override=False,
):
# TODO(rcadene): Add option to record logs
if not video:
raise NotImplementedError()
if not robot.is_connected:
robot.connect()
local_dir = Path(root) / repo_id
if local_dir.exists() and force_override:
shutil.rmtree(local_dir)
episodes_dir = local_dir / "episodes"
episodes_dir.mkdir(parents=True, exist_ok=True)
videos_dir = local_dir / "videos"
videos_dir.mkdir(parents=True, exist_ok=True)
# Logic to resume data recording
rec_info_path = episodes_dir / "data_recording_info.json"
if rec_info_path.exists():
with open(rec_info_path) as f:
rec_info = json.load(f)
episode_index = rec_info["last_episode_index"] + 1
else:
episode_index = 0
is_headless = get_is_headless()
# Execute a few seconds without recording data, to give times
# to the robot devices to connect and start synchronizing.
timestamp = 0
start_time = time.perf_counter()
is_warmup_print = False
while timestamp < warmup_time_s:
if not is_warmup_print:
logging.info("Warming up (no data recording)")
os.system('say "Warmup" &')
is_warmup_print = True
now = time.perf_counter()
observation, action = robot.teleop_step(record_data=True)
if not is_headless:
image_keys = [key for key in observation if "image" in key]
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
exit_early = False
rerecord_episode = False
stop_recording = False
# Only import pynput if not in a headless environment
if is_headless:
logging.info("Headless environment detected. Keyboard input will not be available.")
else:
from pynput import keyboard
def on_press(key):
nonlocal exit_early, rerecord_episode, stop_recording
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
exit_early = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
rerecord_episode = True
exit_early = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
stop_recording = True
exit_early = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
# Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised.
# Using only 4 worker threads to avoid blocking the main thread.
futures = []
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
# Start recording all episodes
while episode_index < num_episodes:
logging.info(f"Recording episode {episode_index}")
os.system(f'say "Recording episode {episode_index}" &')
ep_dict = {}
frame_index = 0
timestamp = 0
start_time = time.perf_counter()
while timestamp < episode_time_s:
now = time.perf_counter()
observation, action = robot.teleop_step(record_data=True)
image_keys = [key for key in observation if "image" in key]
not_image_keys = [key for key in observation if "image" not in key]
for key in image_keys:
futures += [
executor.submit(
save_image, observation[key], key, frame_index, episode_index, videos_dir
)
]
for key in not_image_keys:
if key not in ep_dict:
ep_dict[key] = []
ep_dict[key].append(observation[key])
for key in action:
if key not in ep_dict:
ep_dict[key] = []
ep_dict[key].append(action[key])
frame_index += 1
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time
if exit_early:
exit_early = False
break
if not stop_recording:
# Start resetting env while the executor are finishing
logging.info("Reset the environment")
os.system('say "Reset the environment" &')
timestamp = 0
start_time = time.perf_counter()
# During env reset we save the data and encode the videos
num_frames = frame_index
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
video_path.unlink()
# Store the reference to the video frame, even tho the videos are not yet encoded
ep_dict[key] = []
for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
for key in not_image_keys:
ep_dict[key] = torch.stack(ep_dict[key])
for key in action:
ep_dict[key] = torch.stack(ep_dict[key])
ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
ep_dict["next.done"] = done
ep_path = episodes_dir / f"episode_{episode_index}.pth"
print("Saving episode dictionary...")
torch.save(ep_dict, ep_path)
rec_info = {
"last_episode_index": episode_index,
}
with open(rec_info_path, "w") as f:
json.dump(rec_info, f)
is_last_episode = stop_recording or (episode_index == (num_episodes - 1))
# Wait if necessary
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
while timestamp < reset_time_s and not is_last_episode:
time.sleep(1)
timestamp = time.perf_counter() - start_time
pbar.update(1)
if exit_early:
exit_early = False
break
# Skip updating episode index which forces re-recording episode
if rerecord_episode:
rerecord_episode = False
continue
episode_index += 1
if is_last_episode:
logging.info("Done recording")
os.system('say "Done recording"')
if not is_headless:
listener.stop()
logging.info("Waiting for threads writing the images on disk to terminate...")
for _ in tqdm.tqdm(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
):
pass
break
num_episodes = episode_index
logging.info("Encoding videos")
os.system('say "Encoding videos" &')
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
logging.info("Concatenating episodes")
ep_dicts = []
for episode_index in tqdm.tqdm(range(num_episodes)):
ep_path = episodes_dir / f"episode_{episode_index}.pth"
ep_dict = torch.load(ep_path)
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps,
"video": video,
}
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
if run_compute_stats:
logging.info("Computing dataset statistics")
os.system('say "Computing dataset statistics" &')
stats = compute_stats(lerobot_dataset)
lerobot_dataset.stats = stats
else:
logging.info("Skipping computation of the dataset statistrics")
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
meta_data_dir = local_dir / "meta_data"
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if push_to_hub:
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
logging.info("Exiting")
os.system('say "Exiting" &')
return lerobot_dataset
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
# TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id
if not local_dir.exists():
raise ValueError(local_dir)
dataset = LeRobotDataset(repo_id, root=root)
items = dataset.hf_dataset.select_columns("action")
from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item()
if not robot.is_connected:
robot.connect()
logging.info("Replaying episode")
os.system('say "Replaying episode"')
for idx in range(from_idx, to_idx):
now = time.perf_counter()
action = items[idx]["action"]
robot.send_action(action)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
# TODO(rcadene): Add option to record eval dataset and logs
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
fps = hydra_cfg.env.fps
if not robot.is_connected:
robot.connect()
start_time = time.perf_counter()
while True:
now = time.perf_counter()
observation = robot.capture_observation()
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and hydra_cfg.use_amp
else nullcontext(),
):
# add batch dimension to 1
for name in observation:
observation[name] = observation[name].unsqueeze(0)
if device.type == "mps":
for name in observation:
observation[name] = observation[name].to(device)
action = policy.select_action(observation)
# remove batch dimension
action = action.squeeze(0)
robot.send_action(action.to("cpu"))
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
break
if __name__ == "__main__":
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="mode", required=True)
# Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument(
"--robot",
type=str,
default="koch",
help="Name of the robot provided to the `make_robot(name)` factory function.",
)
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
parser_teleop.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_record.add_argument(
"--root",
type=Path,
default="data",
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_record.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_record.add_argument(
"--warmup-time-s",
type=int,
default=2,
help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
)
parser_record.add_argument(
"--episode-time-s",
type=int,
default=10,
help="Number of seconds for data recording for each episode.",
)
parser_record.add_argument(
"--reset-time-s",
type=int,
default=5,
help="Number of seconds for resetting the environment after each episode.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--run-compute-stats",
type=int,
default=1,
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
)
parser_record.add_argument(
"--push-to-hub",
type=int,
default=1,
help="Upload dataset to Hugging Face hub.",
)
parser_record.add_argument(
"--num-image-writers",
type=int,
default=8,
help="Number of threads writing the frames as png images on disk. Don't set too much as you might get unstable fps due to main thread being blocked.",
)
parser_record.add_argument(
"--force-override",
type=int,
default=0,
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
)
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_replay.add_argument(
"--root",
type=Path,
default="data",
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_replay.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
parser_policy.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
),
)
parser_policy.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
init_logging()
control_mode = args.mode
robot_name = args.robot
kwargs = vars(args)
del kwargs["mode"]
del kwargs["robot"]
robot = make_robot(robot_name)
if control_mode == "teleoperate":
teleoperate(robot, **kwargs)
elif control_mode == "record_dataset":
record_dataset(robot, **kwargs)
elif control_mode == "replay_episode":
replay_episode(robot, **kwargs)
elif control_mode == "run_policy":
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", args.overrides)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
run_policy(robot, policy, hydra_cfg)

View File

@ -578,6 +578,29 @@ def main(
logging.info("End of eval") logging.info("End of eval")
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
return pretrained_policy_path
if __name__ == "__main__": if __name__ == "__main__":
init_logging() init_logging()
@ -619,27 +642,9 @@ if __name__ == "__main__":
if args.pretrained_policy_name_or_path is None: if args.pretrained_policy_name_or_path is None:
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides) main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
else: else:
try: pretrained_policy_path = get_pretrained_policy_path(
pretrained_policy_path = Path( args.pretrained_policy_name_or_path, revision=args.revision
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision) )
)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
main( main(
pretrained_policy_path=pretrained_policy_path, pretrained_policy_path=pretrained_policy_path,

View File

@ -40,60 +40,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
--raw-format umi_zarr \ --raw-format umi_zarr \
--repo-id lerobot/umi_cup_in_the_wild --repo-id lerobot/umi_cup_in_the_wild
``` ```
**WARNING: Updating an existing dataset**
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change and backward compatibility is maintained.
For instance, Pusht has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
However, you will need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
```python
import os
from huggingface_hub import create_branch, hf_hub_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # makes it easier to see the print-out below
NEW_CODEBASE_VERSION = "v1.5" # REPLACE THIS WITH YOUR DESIRED VERSION
for repo_id in available_datasets:
# First check if the newer version already exists.
try:
hf_hub_download(
repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
)
print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
print("Exiting early")
break
except RepositoryNotFoundError:
# Now create a branch.
create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
print(f"{repo_id} successfully updated")
```
On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
above, nor to be compatible with previous codebase versions.
""" """
import argparse import argparse
@ -104,7 +50,7 @@ from pathlib import Path
from typing import Any from typing import Any
import torch import torch
from huggingface_hub import HfApi, create_branch from huggingface_hub import HfApi
from safetensors.torch import save_file from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.compute_stats import compute_stats
@ -208,8 +154,8 @@ def push_dataset_to_hub(
raw_dir = Path(raw_dir) raw_dir = Path(raw_dir)
if not raw_dir.exists(): if not raw_dir.exists():
raise NotADirectoryError( raise NotADirectoryError(
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:" f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw" f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
) )
if local_dir: if local_dir:
@ -270,7 +216,8 @@ def push_dataset_to_hub(
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
if video: if video:
push_videos_to_hub(repo_id, videos_dir, revision="main") push_videos_to_hub(repo_id, videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION) api = HfApi()
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
if tests_data_dir: if tests_data_dir:
# get the first episode # get the first episode

View File

@ -272,7 +272,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
cfg.resume = True cfg.resume = True
elif Logger.get_last_checkpoint_dir(out_dir).exists(): elif Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError( raise RuntimeError(
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists." f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
) )
# log metrics to terminal and wandb # log metrics to terminal and wandb

View File

@ -25,7 +25,7 @@ Increase hue jitter
``` ```
python lerobot/scripts/visualize_image_transforms.py \ python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \ dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.hue.min_max=[-0.25,0.25] training.image_transforms.hue.min_max="[-0.25,0.25]"
``` ```
Increase brightness & brightness weight Increase brightness & brightness weight
@ -33,7 +33,7 @@ Increase brightness & brightness weight
python lerobot/scripts/visualize_image_transforms.py \ python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \ dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.brightness.weight=10.0 \ training.image_transforms.brightness.weight=10.0 \
training.image_transforms.brightness.min_max=[1.0,2.0] training.image_transforms.brightness.min_max="[1.0,2.0]"
``` ```
Blur images and disable saturation & hue Blur images and disable saturation & hue
@ -41,7 +41,7 @@ Blur images and disable saturation & hue
python lerobot/scripts/visualize_image_transforms.py \ python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \ dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.sharpness.weight=10.0 \ training.image_transforms.sharpness.weight=10.0 \
training.image_transforms.sharpness.min_max=[0.0,1.0] \ training.image_transforms.sharpness.min_max="[0.0,1.0]" \
training.image_transforms.saturation.weight=0.0 \ training.image_transforms.saturation.weight=0.0 \
training.image_transforms.hue.weight=0.0 training.image_transforms.hue.weight=0.0
``` ```
@ -172,4 +172,4 @@ def visualize_transforms_cli(cfg):
if __name__ == "__main__": if __name__ == "__main__":
visualize_transforms() visualize_transforms_cli()

Binary file not shown.

After

Width:  |  Height:  |  Size: 416 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 446 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 318 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 420 KiB

603
poetry.lock generated
View File

@ -444,63 +444,63 @@ files = [
[[package]] [[package]]
name = "coverage" name = "coverage"
version = "7.5.4" version = "7.6.0"
description = "Code coverage measurement for Python" description = "Code coverage measurement for Python"
optional = true optional = true
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "coverage-7.5.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6cfb5a4f556bb51aba274588200a46e4dd6b505fb1a5f8c5ae408222eb416f99"}, {file = "coverage-7.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dff044f661f59dace805eedb4a7404c573b6ff0cdba4a524141bc63d7be5c7fd"},
{file = "coverage-7.5.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2174e7c23e0a454ffe12267a10732c273243b4f2d50d07544a91198f05c48f47"}, {file = "coverage-7.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a8659fd33ee9e6ca03950cfdcdf271d645cf681609153f218826dd9805ab585c"},
{file = "coverage-7.5.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2214ee920787d85db1b6a0bd9da5f8503ccc8fcd5814d90796c2f2493a2f4d2e"}, {file = "coverage-7.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7792f0ab20df8071d669d929c75c97fecfa6bcab82c10ee4adb91c7a54055463"},
{file = "coverage-7.5.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1137f46adb28e3813dec8c01fefadcb8c614f33576f672962e323b5128d9a68d"}, {file = "coverage-7.6.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d4b3cd1ca7cd73d229487fa5caca9e4bc1f0bca96526b922d61053ea751fe791"},
{file = "coverage-7.5.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b385d49609f8e9efc885790a5a0e89f2e3ae042cdf12958b6034cc442de428d3"}, {file = "coverage-7.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7e128f85c0b419907d1f38e616c4f1e9f1d1b37a7949f44df9a73d5da5cd53c"},
{file = "coverage-7.5.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b4a474f799456e0eb46d78ab07303286a84a3140e9700b9e154cfebc8f527016"}, {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a94925102c89247530ae1dab7dc02c690942566f22e189cbd53579b0693c0783"},
{file = "coverage-7.5.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:5cd64adedf3be66f8ccee418473c2916492d53cbafbfcff851cbec5a8454b136"}, {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:dcd070b5b585b50e6617e8972f3fbbee786afca71b1936ac06257f7e178f00f6"},
{file = "coverage-7.5.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e564c2cf45d2f44a9da56f4e3a26b2236504a496eb4cb0ca7221cd4cc7a9aca9"}, {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d50a252b23b9b4dfeefc1f663c568a221092cbaded20a05a11665d0dbec9b8fb"},
{file = "coverage-7.5.4-cp310-cp310-win32.whl", hash = "sha256:7076b4b3a5f6d2b5d7f1185fde25b1e54eb66e647a1dfef0e2c2bfaf9b4c88c8"}, {file = "coverage-7.6.0-cp310-cp310-win32.whl", hash = "sha256:0e7b27d04131c46e6894f23a4ae186a6a2207209a05df5b6ad4caee6d54a222c"},
{file = "coverage-7.5.4-cp310-cp310-win_amd64.whl", hash = "sha256:018a12985185038a5b2bcafab04ab833a9a0f2c59995b3cec07e10074c78635f"}, {file = "coverage-7.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:54dece71673b3187c86226c3ca793c5f891f9fc3d8aa183f2e3653da18566169"},
{file = "coverage-7.5.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:db14f552ac38f10758ad14dd7b983dbab424e731588d300c7db25b6f89e335b5"}, {file = "coverage-7.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7b525ab52ce18c57ae232ba6f7010297a87ced82a2383b1afd238849c1ff933"},
{file = "coverage-7.5.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3257fdd8e574805f27bb5342b77bc65578e98cbc004a92232106344053f319ba"}, {file = "coverage-7.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4bea27c4269234e06f621f3fac3925f56ff34bc14521484b8f66a580aacc2e7d"},
{file = "coverage-7.5.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a6612c99081d8d6134005b1354191e103ec9705d7ba2754e848211ac8cacc6b"}, {file = "coverage-7.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed8d1d1821ba5fc88d4a4f45387b65de52382fa3ef1f0115a4f7a20cdfab0e94"},
{file = "coverage-7.5.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d45d3cbd94159c468b9b8c5a556e3f6b81a8d1af2a92b77320e887c3e7a5d080"}, {file = "coverage-7.6.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01c322ef2bbe15057bc4bf132b525b7e3f7206f071799eb8aa6ad1940bcf5fb1"},
{file = "coverage-7.5.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed550e7442f278af76d9d65af48069f1fb84c9f745ae249c1a183c1e9d1b025c"}, {file = "coverage-7.6.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03cafe82c1b32b770a29fd6de923625ccac3185a54a5e66606da26d105f37dac"},
{file = "coverage-7.5.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7a892be37ca35eb5019ec85402c3371b0f7cda5ab5056023a7f13da0961e60da"}, {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0d1b923fc4a40c5832be4f35a5dab0e5ff89cddf83bb4174499e02ea089daf57"},
{file = "coverage-7.5.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8192794d120167e2a64721d88dbd688584675e86e15d0569599257566dec9bf0"}, {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4b03741e70fb811d1a9a1d75355cf391f274ed85847f4b78e35459899f57af4d"},
{file = "coverage-7.5.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:820bc841faa502e727a48311948e0461132a9c8baa42f6b2b84a29ced24cc078"}, {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a73d18625f6a8a1cbb11eadc1d03929f9510f4131879288e3f7922097a429f63"},
{file = "coverage-7.5.4-cp311-cp311-win32.whl", hash = "sha256:6aae5cce399a0f065da65c7bb1e8abd5c7a3043da9dceb429ebe1b289bc07806"}, {file = "coverage-7.6.0-cp311-cp311-win32.whl", hash = "sha256:65fa405b837060db569a61ec368b74688f429b32fa47a8929a7a2f9b47183713"},
{file = "coverage-7.5.4-cp311-cp311-win_amd64.whl", hash = "sha256:d2e344d6adc8ef81c5a233d3a57b3c7d5181f40e79e05e1c143da143ccb6377d"}, {file = "coverage-7.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:6379688fb4cfa921ae349c76eb1a9ab26b65f32b03d46bb0eed841fd4cb6afb1"},
{file = "coverage-7.5.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:54317c2b806354cbb2dc7ac27e2b93f97096912cc16b18289c5d4e44fc663233"}, {file = "coverage-7.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f7db0b6ae1f96ae41afe626095149ecd1b212b424626175a6633c2999eaad45b"},
{file = "coverage-7.5.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:042183de01f8b6d531e10c197f7f0315a61e8d805ab29c5f7b51a01d62782747"}, {file = "coverage-7.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bbdf9a72403110a3bdae77948b8011f644571311c2fb35ee15f0f10a8fc082e8"},
{file = "coverage-7.5.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6bb74ed465d5fb204b2ec41d79bcd28afccf817de721e8a807d5141c3426638"}, {file = "coverage-7.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cc44bf0315268e253bf563f3560e6c004efe38f76db03a1558274a6e04bf5d5"},
{file = "coverage-7.5.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3d45ff86efb129c599a3b287ae2e44c1e281ae0f9a9bad0edc202179bcc3a2e"}, {file = "coverage-7.6.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:da8549d17489cd52f85a9829d0e1d91059359b3c54a26f28bec2c5d369524807"},
{file = "coverage-7.5.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5013ed890dc917cef2c9f765c4c6a8ae9df983cd60dbb635df8ed9f4ebc9f555"}, {file = "coverage-7.6.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0086cd4fc71b7d485ac93ca4239c8f75732c2ae3ba83f6be1c9be59d9e2c6382"},
{file = "coverage-7.5.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1014fbf665fef86cdfd6cb5b7371496ce35e4d2a00cda501cf9f5b9e6fced69f"}, {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1fad32ee9b27350687035cb5fdf9145bc9cf0a094a9577d43e909948ebcfa27b"},
{file = "coverage-7.5.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3684bc2ff328f935981847082ba4fdc950d58906a40eafa93510d1b54c08a66c"}, {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:044a0985a4f25b335882b0966625270a8d9db3d3409ddc49a4eb00b0ef5e8cee"},
{file = "coverage-7.5.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:581ea96f92bf71a5ec0974001f900db495488434a6928a2ca7f01eee20c23805"}, {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:76d5f82213aa78098b9b964ea89de4617e70e0d43e97900c2778a50856dac605"},
{file = "coverage-7.5.4-cp312-cp312-win32.whl", hash = "sha256:73ca8fbc5bc622e54627314c1a6f1dfdd8db69788f3443e752c215f29fa87a0b"}, {file = "coverage-7.6.0-cp312-cp312-win32.whl", hash = "sha256:3c59105f8d58ce500f348c5b56163a4113a440dad6daa2294b5052a10db866da"},
{file = "coverage-7.5.4-cp312-cp312-win_amd64.whl", hash = "sha256:cef4649ec906ea7ea5e9e796e68b987f83fa9a718514fe147f538cfeda76d7a7"}, {file = "coverage-7.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:ca5d79cfdae420a1d52bf177de4bc2289c321d6c961ae321503b2ca59c17ae67"},
{file = "coverage-7.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cdd31315fc20868c194130de9ee6bfd99755cc9565edff98ecc12585b90be882"}, {file = "coverage-7.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d39bd10f0ae453554798b125d2f39884290c480f56e8a02ba7a6ed552005243b"},
{file = "coverage-7.5.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:02ff6e898197cc1e9fa375581382b72498eb2e6d5fc0b53f03e496cfee3fac6d"}, {file = "coverage-7.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:beb08e8508e53a568811016e59f3234d29c2583f6b6e28572f0954a6b4f7e03d"},
{file = "coverage-7.5.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d05c16cf4b4c2fc880cb12ba4c9b526e9e5d5bb1d81313d4d732a5b9fe2b9d53"}, {file = "coverage-7.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2e16f4cd2bc4d88ba30ca2d3bbf2f21f00f382cf4e1ce3b1ddc96c634bc48ca"},
{file = "coverage-7.5.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5986ee7ea0795a4095ac4d113cbb3448601efca7f158ec7f7087a6c705304e4"}, {file = "coverage-7.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6616d1c9bf1e3faea78711ee42a8b972367d82ceae233ec0ac61cc7fec09fa6b"},
{file = "coverage-7.5.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5df54843b88901fdc2f598ac06737f03d71168fd1175728054c8f5a2739ac3e4"}, {file = "coverage-7.6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad4567d6c334c46046d1c4c20024de2a1c3abc626817ae21ae3da600f5779b44"},
{file = "coverage-7.5.4-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ab73b35e8d109bffbda9a3e91c64e29fe26e03e49addf5b43d85fc426dde11f9"}, {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d17c6a415d68cfe1091d3296ba5749d3d8696e42c37fca5d4860c5bf7b729f03"},
{file = "coverage-7.5.4-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:aea072a941b033813f5e4814541fc265a5c12ed9720daef11ca516aeacd3bd7f"}, {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9146579352d7b5f6412735d0f203bbd8d00113a680b66565e205bc605ef81bc6"},
{file = "coverage-7.5.4-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:16852febd96acd953b0d55fc842ce2dac1710f26729b31c80b940b9afcd9896f"}, {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:cdab02a0a941af190df8782aafc591ef3ad08824f97850b015c8c6a8b3877b0b"},
{file = "coverage-7.5.4-cp38-cp38-win32.whl", hash = "sha256:8f894208794b164e6bd4bba61fc98bf6b06be4d390cf2daacfa6eca0a6d2bb4f"}, {file = "coverage-7.6.0-cp38-cp38-win32.whl", hash = "sha256:df423f351b162a702c053d5dddc0fc0ef9a9e27ea3f449781ace5f906b664428"},
{file = "coverage-7.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:e2afe743289273209c992075a5a4913e8d007d569a406ffed0bd080ea02b0633"}, {file = "coverage-7.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:f2501d60d7497fd55e391f423f965bbe9e650e9ffc3c627d5f0ac516026000b8"},
{file = "coverage-7.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b95c3a8cb0463ba9f77383d0fa8c9194cf91f64445a63fc26fb2327e1e1eb088"}, {file = "coverage-7.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7221f9ac9dad9492cecab6f676b3eaf9185141539d5c9689d13fd6b0d7de840c"},
{file = "coverage-7.5.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3d7564cc09dd91b5a6001754a5b3c6ecc4aba6323baf33a12bd751036c998be4"}, {file = "coverage-7.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ddaaa91bfc4477d2871442bbf30a125e8fe6b05da8a0015507bfbf4718228ab2"},
{file = "coverage-7.5.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44da56a2589b684813f86d07597fdf8a9c6ce77f58976727329272f5a01f99f7"}, {file = "coverage-7.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4cbe651f3904e28f3a55d6f371203049034b4ddbce65a54527a3f189ca3b390"},
{file = "coverage-7.5.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e16f3d6b491c48c5ae726308e6ab1e18ee830b4cdd6913f2d7f77354b33f91c8"}, {file = "coverage-7.6.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:831b476d79408ab6ccfadaaf199906c833f02fdb32c9ab907b1d4aa0713cfa3b"},
{file = "coverage-7.5.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbc5958cb471e5a5af41b0ddaea96a37e74ed289535e8deca404811f6cb0bc3d"}, {file = "coverage-7.6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46c3d091059ad0b9c59d1034de74a7f36dcfa7f6d3bde782c49deb42438f2450"},
{file = "coverage-7.5.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a04e990a2a41740b02d6182b498ee9796cf60eefe40cf859b016650147908029"}, {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4d5fae0a22dc86259dee66f2cc6c1d3e490c4a1214d7daa2a93d07491c5c04b6"},
{file = "coverage-7.5.4-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ddbd2f9713a79e8e7242d7c51f1929611e991d855f414ca9996c20e44a895f7c"}, {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:07ed352205574aad067482e53dd606926afebcb5590653121063fbf4e2175166"},
{file = "coverage-7.5.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b1ccf5e728ccf83acd313c89f07c22d70d6c375a9c6f339233dcf792094bcbf7"}, {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:49c76cdfa13015c4560702574bad67f0e15ca5a2872c6a125f6327ead2b731dd"},
{file = "coverage-7.5.4-cp39-cp39-win32.whl", hash = "sha256:56b4eafa21c6c175b3ede004ca12c653a88b6f922494b023aeb1e836df953ace"}, {file = "coverage-7.6.0-cp39-cp39-win32.whl", hash = "sha256:482855914928c8175735a2a59c8dc5806cf7d8f032e4820d52e845d1f731dca2"},
{file = "coverage-7.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:65e528e2e921ba8fd67d9055e6b9f9e34b21ebd6768ae1c1723f4ea6ace1234d"}, {file = "coverage-7.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:543ef9179bc55edfd895154a51792b01c017c87af0ebaae092720152e19e42ca"},
{file = "coverage-7.5.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:79b356f3dd5b26f3ad23b35c75dbdaf1f9e2450b6bcefc6d0825ea0aa3f86ca5"}, {file = "coverage-7.6.0-pp38.pp39.pp310-none-any.whl", hash = "sha256:6fe885135c8a479d3e37a7aae61cbd3a0fb2deccb4dda3c25f92a49189f766d6"},
{file = "coverage-7.5.4.tar.gz", hash = "sha256:a44963520b069e12789d0faea4e9fdb1e410cdc4aab89d94f7f55cbb7fef0353"}, {file = "coverage-7.6.0.tar.gz", hash = "sha256:289cc803fa1dc901f84701ac10c9ee873619320f2f9aff38794db4a4a0268d51"},
] ]
[package.dependencies] [package.dependencies]
@ -615,18 +615,18 @@ optimize = ["orjson"]
[[package]] [[package]]
name = "diffusers" name = "diffusers"
version = "0.27.2" version = "0.29.2"
description = "State-of-the-art diffusion in PyTorch and JAX." description = "State-of-the-art diffusion in PyTorch and JAX."
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = [
{file = "diffusers-0.27.2-py3-none-any.whl", hash = "sha256:85da5cd1098ab428535d592136973ec0c3f12f78148c94b379cb9f02d2414e75"}, {file = "diffusers-0.29.2-py3-none-any.whl", hash = "sha256:d5e9bb13c8097b4eed10df23d1294d2e5a418f53e3f89c7ef228b5b982970428"},
{file = "diffusers-0.27.2.tar.gz", hash = "sha256:6cefd7770d7fc1d139614233aa17cdcd639c138d0c3517b8d8bbc8cf573050a0"}, {file = "diffusers-0.29.2.tar.gz", hash = "sha256:b85f277668e22089cf68b40dd9b76940db7d24ba9cdac107533ed10ab8e4e9db"},
] ]
[package.dependencies] [package.dependencies]
filelock = "*" filelock = "*"
huggingface-hub = ">=0.20.2" huggingface-hub = ">=0.23.2"
importlib-metadata = "*" importlib-metadata = "*"
numpy = "*" numpy = "*"
Pillow = "*" Pillow = "*"
@ -635,13 +635,13 @@ requests = "*"
safetensors = ">=0.3.1" safetensors = ">=0.3.1"
[package.extras] [package.extras]
dev = ["GitPython (<3.1.19)", "Jinja2", "accelerate (>=0.11.0)", "compel (==0.1.8)", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4)", "torchvision", "transformers (>=4.25.1)", "urllib3 (<=2.0.0)"] dev = ["GitPython (<3.1.19)", "Jinja2", "accelerate (>=0.29.3)", "compel (==0.1.8)", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4)", "torchvision", "transformers (>=4.25.1)", "urllib3 (<=2.0.0)"]
docs = ["hf-doc-builder (>=0.3.0)"] docs = ["hf-doc-builder (>=0.3.0)"]
flax = ["flax (>=0.4.1)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)"] flax = ["flax (>=0.4.1)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)"]
quality = ["hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<=2.0.0)"] quality = ["hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<=2.0.0)"]
test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisible-watermark (>=0.2.0)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "torchvision", "transformers (>=4.25.1)"] test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisible-watermark (>=0.2.0)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "torchvision", "transformers (>=4.25.1)"]
torch = ["accelerate (>=0.11.0)", "torch (>=1.4)"] torch = ["accelerate (>=0.29.3)", "torch (>=1.4)"]
training = ["Jinja2", "accelerate (>=0.11.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] training = ["Jinja2", "accelerate (>=0.29.3)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"]
[[package]] [[package]]
name = "dill" name = "dill"
@ -795,7 +795,6 @@ files = [
{file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"}, {file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"},
{file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"}, {file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"}, {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9fdc4e73578bebb1c8d0f8bea2243a5a9e179f08c74d98576123b59b75e5cac"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"}, {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"}, {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"},
{file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"}, {file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"},
@ -807,6 +806,19 @@ files = [
[package.dependencies] [package.dependencies]
pyarrow = "*" pyarrow = "*"
[[package]]
name = "dynamixel-sdk"
version = "3.7.31"
description = "Dynamixel SDK 3. python package"
optional = true
python-versions = "*"
files = [
{file = "dynamixel_sdk-3.7.31-py3-none-any.whl", hash = "sha256:74e8c112ca6b0b869b196dd8c6a44ffd5dd5c1a3cb9fe2030e9933922406b466"},
]
[package.dependencies]
pyserial = "*"
[[package]] [[package]]
name = "einops" name = "einops"
version = "0.8.0" version = "0.8.0"
@ -818,15 +830,25 @@ files = [
{file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"}, {file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"},
] ]
[[package]]
name = "evdev"
version = "1.7.1"
description = "Bindings to the Linux input handling subsystem"
optional = true
python-versions = ">=3.6"
files = [
{file = "evdev-1.7.1.tar.gz", hash = "sha256:0c72c370bda29d857e188d931019c32651a9c1ea977c08c8d939b1ced1637fde"},
]
[[package]] [[package]]
name = "exceptiongroup" name = "exceptiongroup"
version = "1.2.1" version = "1.2.2"
description = "Backport of PEP 654 (exception groups)" description = "Backport of PEP 654 (exception groups)"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
{file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
] ]
[package.extras] [package.extras]
@ -1109,7 +1131,7 @@ pyarrow = ">=12.0.0"
type = "git" type = "git"
url = "https://github.com/dora-rs/dora-lerobot.git" url = "https://github.com/dora-rs/dora-lerobot.git"
reference = "HEAD" reference = "HEAD"
resolved_reference = "2addd1131a3c94f7b70b805577901b7967853e98" resolved_reference = "fda22deba84c46695369736edd34dc740aef45eb"
subdirectory = "gym_dora" subdirectory = "gym_dora"
[[package]] [[package]]
@ -1315,13 +1337,13 @@ files = [
[[package]] [[package]]
name = "huggingface-hub" name = "huggingface-hub"
version = "0.23.4" version = "0.23.5"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = [
{file = "huggingface_hub-0.23.4-py3-none-any.whl", hash = "sha256:3a0b957aa87150addf0cc7bd71b4d954b78e749850e1e7fb29ebbd2db64ca037"}, {file = "huggingface_hub-0.23.5-py3-none-any.whl", hash = "sha256:d7a7d337615e11a45cc14a0ce5a605db6b038dc24af42866f731684825226e90"},
{file = "huggingface_hub-0.23.4.tar.gz", hash = "sha256:35d99016433900e44ae7efe1c209164a5a81dbbcd53a52f99c281dcd7ce22431"}, {file = "huggingface_hub-0.23.5.tar.gz", hash = "sha256:67a9caba79b71235be3752852ca27da86bd54311d2424ca8afdb8dda056edf98"},
] ]
[package.dependencies] [package.dependencies]
@ -1366,13 +1388,13 @@ packaging = "*"
[[package]] [[package]]
name = "identify" name = "identify"
version = "2.5.36" version = "2.6.0"
description = "File identification library for Python" description = "File identification library for Python"
optional = true optional = true
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "identify-2.5.36-py2.py3-none-any.whl", hash = "sha256:37d93f380f4de590500d9dba7db359d0d3da95ffe7f9de1753faa159e71e7dfa"}, {file = "identify-2.6.0-py2.py3-none-any.whl", hash = "sha256:e79ae4406387a9d300332b5fd366d8994f1525e8414984e1a59e058b2eda2dd0"},
{file = "identify-2.5.36.tar.gz", hash = "sha256:e5e00f54165f9047fbebeb4a560f9acfb8af4c88232be60a488e9b68d122745d"}, {file = "identify-2.6.0.tar.gz", hash = "sha256:cb171c685bdc31bcc4c1734698736a7d5b6c8bf2e0c15117f4d469c8640ae5cf"},
] ]
[package.extras] [package.extras]
@ -1719,13 +1741,9 @@ files = [
{file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"}, {file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"},
{file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"}, {file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"},
{file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"}, {file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"},
{file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"}, {file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"},
@ -2152,43 +2170,36 @@ numpy = ">=1.22,<2.1"
[[package]] [[package]]
name = "numcodecs" name = "numcodecs"
version = "0.12.1" version = "0.13.0"
description = "A Python package providing buffer compression and transformation codecs for use in data storage and communication applications." description = "A Python package providing buffer compression and transformation codecs for use in data storage and communication applications."
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.10"
files = [ files = [
{file = "numcodecs-0.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d37f628fe92b3699e65831d5733feca74d2e33b50ef29118ffd41c13c677210e"}, {file = "numcodecs-0.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:56e49f68ce6aeba29f144992524c8897d94f846d02bbcc820dd29d7c5c2a073e"},
{file = "numcodecs-0.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:941b7446b68cf79f089bcfe92edaa3b154533dcbcd82474f994b28f2eedb1c60"}, {file = "numcodecs-0.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:17bc4b568214582f4c623700592f633f3afd920848630049c584fa1e535253ad"},
{file = "numcodecs-0.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e79bf9d1d37199ac00a60ff3adb64757523291d19d03116832e600cac391c51"}, {file = "numcodecs-0.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eed420a9c62d0a569aa94a387f93045f068ad3e7bbd787c6ce70bc5fefbaa7d9"},
{file = "numcodecs-0.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:82d7107f80f9307235cb7e74719292d101c7ea1e393fe628817f0d635b7384f5"}, {file = "numcodecs-0.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:e7d3b9693df52eeaf978d2a56971d01cf9b4e284ae769ec764807f2087cce51d"},
{file = "numcodecs-0.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:eeaf42768910f1c6eebf6c1bb00160728e62c9343df9e2e315dc9fe12e3f6071"}, {file = "numcodecs-0.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f208a1b8b5e66c767ed043812ca74d9045e09b7b2e085d064a585c30b9efc8e7"},
{file = "numcodecs-0.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:135b2d47563f7b9dc5ee6ce3d1b81b0f1397f69309e909f1a35bb0f7c553d45e"}, {file = "numcodecs-0.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a68368d3ce625ec76fcacd84785f6110d30a232909d5c6093a7aa25628880477"},
{file = "numcodecs-0.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a191a8e347ecd016e5c357f2bf41fbcb026f6ffe78fff50c77ab12e96701d155"}, {file = "numcodecs-0.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5904216811f2e9d312c23ffaad3b3d4c7442a3583d3a8bf81ca8319e9f5deb5"},
{file = "numcodecs-0.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:21d8267bd4313f4d16f5b6287731d4c8ebdab236038f29ad1b0e93c9b2ca64ee"}, {file = "numcodecs-0.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:208cab0f4d9cf4409e9c4a4c935e165833786614822c81dee9d865af372da9df"},
{file = "numcodecs-0.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2f84df6b8693206365a5b37c005bfa9d1be486122bde683a7b6446af4b75d862"}, {file = "numcodecs-0.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f3cf462d2357998d7f6baaa0427657b0eeda3eb79fba2b146d2d04542912a513"},
{file = "numcodecs-0.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:760627780a8b6afdb7f942f2a0ddaf4e31d3d7eea1d8498cf0fd3204a33c4618"}, {file = "numcodecs-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ac4dd5556fb126271e93bd1a02266e21b01d3617db448d70d00eec8e034506b4"},
{file = "numcodecs-0.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c258bd1d3dfa75a9b708540d23b2da43d63607f9df76dfa0309a7597d1de3b73"}, {file = "numcodecs-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:820be89729583c91601a6b35c052008cdd2665b25bfedb91b367cc155fb34ba0"},
{file = "numcodecs-0.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e04649ea504aff858dbe294631f098fbfd671baf58bfc04fc48d746554c05d67"}, {file = "numcodecs-0.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:d67a859dd8a7f026829e91cb1799c26720cc9d29ee4ae0060cc7a581670abc06"},
{file = "numcodecs-0.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:caf1a1e6678aab9c1e29d2109b299f7a467bd4d4c34235b1f0e082167846b88f"}, {file = "numcodecs-0.13.0.tar.gz", hash = "sha256:ba4fac7036ea5a078c7afe1d4dffeb9685080d42f19c9c16b12dad866703aa2e"},
{file = "numcodecs-0.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c17687b1fd1fef68af616bc83f896035d24e40e04e91e7e6dae56379eb59fe33"},
{file = "numcodecs-0.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29dfb195f835a55c4d490fb097aac8c1bcb96c54cf1b037d9218492c95e9d8c5"},
{file = "numcodecs-0.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:2f1ba2f4af3fd3ba65b1bcffb717fe65efe101a50a91c368f79f3101dbb1e243"},
{file = "numcodecs-0.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2fbb12a6a1abe95926f25c65e283762d63a9bf9e43c0de2c6a1a798347dfcb40"},
{file = "numcodecs-0.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f2207871868b2464dc11c513965fd99b958a9d7cde2629be7b2dc84fdaab013b"},
{file = "numcodecs-0.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abff3554a6892a89aacf7b642a044e4535499edf07aeae2f2e6e8fc08c9ba07f"},
{file = "numcodecs-0.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:ef964d4860d3e6b38df0633caf3e51dc850a6293fd8e93240473642681d95136"},
{file = "numcodecs-0.12.1.tar.gz", hash = "sha256:05d91a433733e7eef268d7e80ec226a0232da244289614a8f3826901aec1098e"},
] ]
[package.dependencies] [package.dependencies]
numpy = ">=1.7" numpy = ">=1.7"
[package.extras] [package.extras]
docs = ["mock", "numpydoc", "sphinx (<7.0.0)", "sphinx-issues"] docs = ["mock", "numpydoc", "pydata-sphinx-theme", "sphinx (<7.0.0)", "sphinx-issues"]
msgpack = ["msgpack"] msgpack = ["msgpack"]
test = ["coverage", "flake8", "pytest", "pytest-cov"] pcodec = ["pcodec (>=0.2.0)"]
test = ["coverage", "pytest", "pytest-cov"]
test-extras = ["importlib-metadata"] test-extras = ["importlib-metadata"]
zfpy = ["zfpy (>=1.0.0)"] zfpy = ["numpy (<2.0.0)", "zfpy (>=1.0.0)"]
[[package]] [[package]]
name = "numpy" name = "numpy"
@ -2750,52 +2761,55 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
[[package]] [[package]]
name = "pyarrow" name = "pyarrow"
version = "16.1.0" version = "17.0.0"
description = "Python library for Apache Arrow" description = "Python library for Apache Arrow"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"},
{file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"},
{file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"},
{file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"},
{file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
{file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
{file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
{file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
] ]
[package.dependencies] [package.dependencies]
numpy = ">=1.16.6" numpy = ">=1.16.6"
[package.extras]
test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"]
[[package]] [[package]]
name = "pyarrow-hotfix" name = "pyarrow-hotfix"
version = "0.6" version = "0.6"
@ -2987,6 +3001,126 @@ cffi = ">=1.15.0"
[package.extras] [package.extras]
dev = ["aafigure", "matplotlib", "numpy", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"] dev = ["aafigure", "matplotlib", "numpy", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"]
[[package]]
name = "pynput"
version = "1.7.7"
description = "Monitor and control user input devices"
optional = true
python-versions = "*"
files = [
{file = "pynput-1.7.7-py2.py3-none-any.whl", hash = "sha256:afc43f651684c98818de048abc76adf9f2d3d797083cb07c1f82be764a2d44cb"},
]
[package.dependencies]
evdev = {version = ">=1.3", markers = "sys_platform in \"linux\""}
pyobjc-framework-ApplicationServices = {version = ">=8.0", markers = "sys_platform == \"darwin\""}
pyobjc-framework-Quartz = {version = ">=8.0", markers = "sys_platform == \"darwin\""}
python-xlib = {version = ">=0.17", markers = "sys_platform in \"linux\""}
six = "*"
[[package]]
name = "pyobjc-core"
version = "10.3.1"
description = "Python<->ObjC Interoperability Module"
optional = true
python-versions = ">=3.8"
files = [
{file = "pyobjc_core-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ea46d2cda17921e417085ac6286d43ae448113158afcf39e0abe484c58fb3d78"},
{file = "pyobjc_core-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:899d3c84d2933d292c808f385dc881a140cf08632907845043a333a9d7c899f9"},
{file = "pyobjc_core-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:6ff5823d13d0a534cdc17fa4ad47cf5bee4846ce0fd27fc40012e12b46db571b"},
{file = "pyobjc_core-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2581e8e68885bcb0e11ec619e81ef28e08ee3fac4de20d8cc83bc5af5bcf4a90"},
{file = "pyobjc_core-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ea98d4c2ec39ca29e62e0327db21418696161fb138ee6278daf2acbedf7ce504"},
{file = "pyobjc_core-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:4c179c26ee2123d0aabffb9dbc60324b62b6f8614fb2c2328b09386ef59ef6d8"},
{file = "pyobjc_core-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cb901fce65c9be420c40d8a6ee6fff5ff27c6945f44fd7191989b982baa66dea"},
{file = "pyobjc_core-10.3.1.tar.gz", hash = "sha256:b204a80ccc070f9ab3f8af423a3a25a6fd787e228508d00c4c30f8ac538ba720"},
]
[[package]]
name = "pyobjc-framework-applicationservices"
version = "10.3.1"
description = "Wrappers for the framework ApplicationServices on macOS"
optional = true
python-versions = ">=3.8"
files = [
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b694260d423c470cb90c3a7009cfde93e332ea6fb4b9b9526ad3acbd33460e3d"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d886ba1f65df47b77ff7546f3fc9bc7d08cfb6b3c04433b719f6b0689a2c0d1f"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:be157f2c3ffb254064ef38249670af8cada5e519a714d2aa5da3740934d89bc8"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:57737f41731661e4a3b78793ec9173f61242a32fa560c3e4e58484465d049c32"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c429eca69ee675e781e4e55f79e939196b47f02560ad865b1ba9ac753b90bd77"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:4f1814a17041a20adca454044080b52e39a4ebc567ad2c6a48866dd4beaa192a"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1252f1137f83eb2c6b9968d8c591363e8859dd2484bc9441d8f365bcfb43a0e4"},
{file = "pyobjc_framework_applicationservices-10.3.1.tar.gz", hash = "sha256:f27cb64aa4d129ce671fd42638c985eb2a56d544214a95fe3214a007eacc4790"},
]
[package.dependencies]
pyobjc-core = ">=10.3.1"
pyobjc-framework-Cocoa = ">=10.3.1"
pyobjc-framework-CoreText = ">=10.3.1"
pyobjc-framework-Quartz = ">=10.3.1"
[[package]]
name = "pyobjc-framework-cocoa"
version = "10.3.1"
description = "Wrappers for the Cocoa frameworks on macOS"
optional = true
python-versions = ">=3.8"
files = [
{file = "pyobjc_framework_Cocoa-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4cb4f8491ab4d9b59f5187e42383f819f7a46306a4fa25b84f126776305291d1"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5f31021f4f8fdf873b57a97ee1f3c1620dbe285e0b4eaed73dd0005eb72fd773"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:11b4e0bad4bbb44a4edda128612f03cdeab38644bbf174de0c13129715497296"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:de5e62e5ccf2871a94acf3bf79646b20ea893cc9db78afa8d1fe1b0d0f7cbdb0"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c5af24610ab639bd1f521ce4500484b40787f898f691b7a23da3339e6bc8b90"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:a7151186bb7805deea434fae9a4423335e6371d105f29e73cc2036c6779a9dbc"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:743d2a1ac08027fd09eab65814c79002a1d0421d7c0074ffd1217b6560889744"},
{file = "pyobjc_framework_cocoa-10.3.1.tar.gz", hash = "sha256:1cf20714daaa986b488fb62d69713049f635c9d41a60c8da97d835710445281a"},
]
[package.dependencies]
pyobjc-core = ">=10.3.1"
[[package]]
name = "pyobjc-framework-coretext"
version = "10.3.1"
description = "Wrappers for the framework CoreText on macOS"
optional = true
python-versions = ">=3.8"
files = [
{file = "pyobjc_framework_CoreText-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd6123cfccc38e32be884d1a13fb62bd636ecb192b9e8ae2b8011c977dec229e"},
{file = "pyobjc_framework_CoreText-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:834142a14235bd80edaef8d3a28d1e203ed3c988810a9b78005df7c561390288"},
{file = "pyobjc_framework_CoreText-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ae6c09d29eeaf30a67aa70e08a465b1f1e47d12e22b3a34ae8bc8fdb7e2e7342"},
{file = "pyobjc_framework_CoreText-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:51ca95df1db9401366f11a7467f64be57f9a0630d31c357237d4062df0216938"},
{file = "pyobjc_framework_CoreText-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8b75bdc267945b3f33c937c108d79405baf9d7c4cd530f922e5df243082a5031"},
{file = "pyobjc_framework_CoreText-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:029b24c338f58fc32a004256d8559507e4f366dfe4eb09d3144273d536012d90"},
{file = "pyobjc_framework_CoreText-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:418a55047dbff999fcd2b78cca167c4105587020b6c51567cfa28993bbfdc8ed"},
{file = "pyobjc_framework_coretext-10.3.1.tar.gz", hash = "sha256:b8fa2d5078ed774431ae64ba886156e319aec0b8c6cc23dabfd86778265b416f"},
]
[package.dependencies]
pyobjc-core = ">=10.3.1"
pyobjc-framework-Cocoa = ">=10.3.1"
pyobjc-framework-Quartz = ">=10.3.1"
[[package]]
name = "pyobjc-framework-quartz"
version = "10.3.1"
description = "Wrappers for the Quartz frameworks on macOS"
optional = true
python-versions = ">=3.8"
files = [
{file = "pyobjc_framework_Quartz-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5ef4fd315ed2bc42ef77fdeb2bae28a88ec986bd7b8079a87ba3b3475348f96e"},
{file = "pyobjc_framework_Quartz-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:96578d4a3e70164efe44ad7dc320ecd4e211758ffcde5dcd694de1bbdfe090a4"},
{file = "pyobjc_framework_Quartz-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ca35f92486869a41847a1703bb176aab8a53dbfd8e678d1f4d68d8e6e1581c71"},
{file = "pyobjc_framework_Quartz-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:00a0933267e3a46ea4afcc35d117b2efb920f06de797fa66279c52e7057e3590"},
{file = "pyobjc_framework_Quartz-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a161bedb4c5257a02ad56a910cd7eefb28bdb0ea78607df0d70ed4efe4ea54c1"},
{file = "pyobjc_framework_Quartz-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:d7a8028e117a94923a511944bfa9daf9744e212f06cf89010c60934a479863a5"},
{file = "pyobjc_framework_Quartz-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:de00c983b3267eb26fa42c6ed9f15e2bf006bde8afa7fe2b390646aa21a5d6fc"},
{file = "pyobjc_framework_quartz-10.3.1.tar.gz", hash = "sha256:b6d7e346d735c9a7f147cd78e6da79eeae416a0b7d3874644c83a23786c6f886"},
]
[package.dependencies]
pyobjc-core = ">=10.3.1"
pyobjc-framework-Cocoa = ">=10.3.1"
[[package]] [[package]]
name = "pyopengl" name = "pyopengl"
version = "3.1.7" version = "3.1.7"
@ -3012,6 +3146,20 @@ files = [
[package.extras] [package.extras]
diagrams = ["jinja2", "railroad-diagrams"] diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pyserial"
version = "3.5"
description = "Python Serial Port Extension"
optional = true
python-versions = "*"
files = [
{file = "pyserial-3.5-py2.py3-none-any.whl", hash = "sha256:c4451db6ba391ca6ca299fb3ec7bae67a5c55dde170964c7a14ceefec02f2cf0"},
{file = "pyserial-3.5.tar.gz", hash = "sha256:3c77e014170dfffbd816e6ffc205e9842efb10be9f58ec16d3e8675b4925cddb"},
]
[package.extras]
cp2110 = ["hidapi"]
[[package]] [[package]]
name = "pysocks" name = "pysocks"
version = "1.7.1" version = "1.7.1"
@ -3095,6 +3243,20 @@ files = [
[package.dependencies] [package.dependencies]
six = ">=1.5" six = ">=1.5"
[[package]]
name = "python-xlib"
version = "0.33"
description = "Python X Library"
optional = true
python-versions = "*"
files = [
{file = "python-xlib-0.33.tar.gz", hash = "sha256:55af7906a2c75ce6cb280a584776080602444f75815a7aff4d287bb2d7018b32"},
{file = "python_xlib-0.33-py2.py3-none-any.whl", hash = "sha256:c3534038d42e0df2f1392a1b30a15a4ff5fdc2b86cfa94f072bf11b10a164398"},
]
[package.dependencies]
six = ">=1.10.0"
[[package]] [[package]]
name = "pytz" name = "pytz"
version = "2024.1" version = "2024.1"
@ -3131,7 +3293,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -3278,16 +3439,16 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]] [[package]]
name = "rerun-sdk" name = "rerun-sdk"
version = "0.16.1" version = "0.17.0"
description = "The Rerun Logging SDK" description = "The Rerun Logging SDK"
optional = false optional = false
python-versions = "<3.13,>=3.8" python-versions = "<3.13,>=3.8"
files = [ files = [
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:170c6976634008611753e10dfef8cdc395ce8180e634c169e7c61cef2f89a277"}, {file = "rerun_sdk-0.17.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:abd34f746eada83b8bb0bc50007183151981d7ccf18306f3d42165819a3f6fcb"},
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c9a76eab7eb5559276737dad655200e9350df0837158dbc5a896970ab4201454"}, {file = "rerun_sdk-0.17.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:8b0a8a6feab3f8e679801d158216a71d88a81480021587719330f50d083c4d26"},
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:4d6436752d57e8b8038489a0e7e37f0c760b088e96db5fb81667d3a376d63fea"}, {file = "rerun_sdk-0.17.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:ad55807abafb01e527846742e087819aac8e103f1ec15aadc563a4038bb44e1d"},
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:37b7b47948471873e84f224b16f417a94a91c7cbd6c72c68281eeff1ba414b8f"}, {file = "rerun_sdk-0.17.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:9d41f1f475270b1e0d50ddb8cb62e0d828988f0c371ac8457af25c8be5aa1dc0"},
{file = "rerun_sdk-0.16.1-cp38-abi3-win_amd64.whl", hash = "sha256:be88799c8afdf68eafa99e64e2e4f0a484e187e017a180219abbe6bb988acd4e"}, {file = "rerun_sdk-0.17.0-cp38-abi3-win_amd64.whl", hash = "sha256:34e5595a326cbdddfebdf00b08e877358c564fce74cc8c6d617fc89ef3a6aa70"},
] ]
[package.dependencies] [package.dependencies]
@ -3298,6 +3459,7 @@ pyarrow = ">=14.0.2"
typing-extensions = ">=4.5" typing-extensions = ">=4.5"
[package.extras] [package.extras]
notebook = ["rerun-notebook (==0.17.0)"]
tests = ["pytest (==7.1.2)"] tests = ["pytest (==7.1.2)"]
[[package]] [[package]]
@ -3424,27 +3586,32 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"]
[[package]] [[package]]
name = "scikit-image" name = "scikit-image"
version = "0.23.2" version = "0.24.0"
description = "Image processing in Python" description = "Image processing in Python"
optional = true optional = true
python-versions = ">=3.10" python-versions = ">=3.9"
files = [ files = [
{file = "scikit_image-0.23.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f9a8db6c52f8d0e1474ea8320d7b8db442b4d6baa29dd0acbd02f8a49572f18a"}, {file = "scikit_image-0.24.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cb3bc0264b6ab30b43c4179ee6156bc18b4861e78bb329dd8d16537b7bbf827a"},
{file = "scikit_image-0.23.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:524b51a7440e46ed2ebbde7bc288bf2dde1dee2caafdd9513b2aca38a48223b7"}, {file = "scikit_image-0.24.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:9c7a52e20cdd760738da38564ba1fed7942b623c0317489af1a598a8dedf088b"},
{file = "scikit_image-0.23.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b335c229170d787b3fb8c60d220f72049ccf862d5191a3cfda6ac84b995ac4e"}, {file = "scikit_image-0.24.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93f46e6ce42e5409f4d09ce1b0c7f80dd7e4373bcec635b6348b63e3c886eac8"},
{file = "scikit_image-0.23.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08b10781efbd6b084f3c847ff4049b657241ea866b9e331b14bf791dcb3e6661"}, {file = "scikit_image-0.24.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39ee0af13435c57351a3397eb379e72164ff85161923eec0c38849fecf1b4764"},
{file = "scikit_image-0.23.2-cp310-cp310-win_amd64.whl", hash = "sha256:a207352e9a1956dda1424bbe872c7795345187138118e8be6a421aef3b988c2a"}, {file = "scikit_image-0.24.0-cp310-cp310-win_amd64.whl", hash = "sha256:7ac7913b028b8aa780ffae85922894a69e33d1c0bf270ea1774f382fe8bf95e7"},
{file = "scikit_image-0.23.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee83fdb1843ee938eabdfeb9498623282935ea30aa20dffc5d5d16698efb4b2a"}, {file = "scikit_image-0.24.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:272909e02a59cea3ed4aa03739bb88df2625daa809f633f40b5053cf09241831"},
{file = "scikit_image-0.23.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:a158f50d3df4867bbd1c698520ede8bc493e430ad83f54ac1f0d8f57b328779b"}, {file = "scikit_image-0.24.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:190ebde80b4470fe8838764b9b15f232a964f1a20391663e31008d76f0c696f7"},
{file = "scikit_image-0.23.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55de3326be124334b89314e9e04c8971ad98d6681e11a243f71bfb85ef9554b0"}, {file = "scikit_image-0.24.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59c98cc695005faf2b79904e4663796c977af22586ddf1b12d6af2fa22842dc2"},
{file = "scikit_image-0.23.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fce619a6d84fe40c1208fa579b646e93ce13ef0afc3652a23e9782b2c183291a"}, {file = "scikit_image-0.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa27b3a0dbad807b966b8db2d78da734cb812ca4787f7fbb143764800ce2fa9c"},
{file = "scikit_image-0.23.2-cp311-cp311-win_amd64.whl", hash = "sha256:ee65669aa586e110346f567ed5c92d1bd63799a19e951cb83da3f54b0caf7c52"}, {file = "scikit_image-0.24.0-cp311-cp311-win_amd64.whl", hash = "sha256:dacf591ac0c272a111181afad4b788a27fe70d213cfddd631d151cbc34f8ca2c"},
{file = "scikit_image-0.23.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:15bfb4e8d7bd90a967e6a3c3ab6be678063fc45e950b730684a8db46a02ff892"}, {file = "scikit_image-0.24.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6fccceb54c9574590abcddc8caf6cefa57c13b5b8b4260ab3ff88ad8f3c252b3"},
{file = "scikit_image-0.23.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:5736e66d01b11cd90988ec24ab929c80a03af28f690189c951886891ebf63154"}, {file = "scikit_image-0.24.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ccc01e4760d655aab7601c1ba7aa4ddd8b46f494ac46ec9c268df6f33ccddf4c"},
{file = "scikit_image-0.23.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3597ac5d8f51dafbcb7433ef1fdefdefb535f50745b2002ae0a5d651df4f063b"}, {file = "scikit_image-0.24.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18836a18d3a7b6aca5376a2d805f0045826bc6c9fc85331659c33b4813e0b563"},
{file = "scikit_image-0.23.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1978be2abe3c3c3189a99a411d48bbb1306f7c2debb3aefbf426e23947f26623"}, {file = "scikit_image-0.24.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8579bda9c3f78cb3b3ed8b9425213c53a25fa7e994b7ac01f2440b395babf660"},
{file = "scikit_image-0.23.2-cp312-cp312-win_amd64.whl", hash = "sha256:ae32bf0cb02b672ed74d28880ca6f88928ae8dd794d67e04fa3ff4836feb9bd6"}, {file = "scikit_image-0.24.0-cp312-cp312-win_amd64.whl", hash = "sha256:82ab903afa60b2da1da2e6f0c8c65e7c8868c60a869464c41971da929b3e82bc"},
{file = "scikit_image-0.23.2.tar.gz", hash = "sha256:c9da4b2c3117e3e30364a3d14496ee5c72b09eb1a4ab1292b302416faa360590"}, {file = "scikit_image-0.24.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ef04360eda372ee5cd60aebe9be91258639c86ae2ea24093fb9182118008d009"},
{file = "scikit_image-0.24.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:e9aadb442360a7e76f0c5c9d105f79a83d6df0e01e431bd1d5757e2c5871a1f3"},
{file = "scikit_image-0.24.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e37de6f4c1abcf794e13c258dc9b7d385d5be868441de11c180363824192ff7"},
{file = "scikit_image-0.24.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4688c18bd7ec33c08d7bf0fd19549be246d90d5f2c1d795a89986629af0a1e83"},
{file = "scikit_image-0.24.0-cp39-cp39-win_amd64.whl", hash = "sha256:56dab751d20b25d5d3985e95c9b4e975f55573554bd76b0aedf5875217c93e69"},
{file = "scikit_image-0.24.0.tar.gz", hash = "sha256:5d16efe95da8edbeb363e0c4157b99becbd650a60b77f6e3af5768b66cf007ab"},
] ]
[package.dependencies] [package.dependencies]
@ -3509,13 +3676,13 @@ test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "me
[[package]] [[package]]
name = "sentry-sdk" name = "sentry-sdk"
version = "2.7.1" version = "2.10.0"
description = "Python client for Sentry (https://sentry.io)" description = "Python client for Sentry (https://sentry.io)"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
{file = "sentry_sdk-2.7.1-py2.py3-none-any.whl", hash = "sha256:ef1b3d54eb715825657cd4bb3cb42bb4dc85087bac14c56b0fd8c21abd968c9a"}, {file = "sentry_sdk-2.10.0-py2.py3-none-any.whl", hash = "sha256:87b3d413c87d8e7f816cc9334bff255a83d8b577db2b22042651c30c19c09190"},
{file = "sentry_sdk-2.7.1.tar.gz", hash = "sha256:25006c7e68b75aaa5e6b9c6a420ece22e8d7daec4b7a906ffd3a8607b67c037b"}, {file = "sentry_sdk-2.10.0.tar.gz", hash = "sha256:545fcc6e36c335faa6d6cda84669b6e17025f31efbf3b2211ec14efe008b75d1"},
] ]
[package.dependencies] [package.dependencies]
@ -3659,67 +3826,63 @@ test = ["pytest"]
[[package]] [[package]]
name = "setuptools" name = "setuptools"
version = "70.2.0" version = "71.0.1"
description = "Easily download, build, install, upgrade, and uninstall Python packages" description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "setuptools-70.2.0-py3-none-any.whl", hash = "sha256:b8b8060bb426838fbe942479c90296ce976249451118ef566a5a0b7d8b78fb05"}, {file = "setuptools-71.0.1-py3-none-any.whl", hash = "sha256:1eb8ef012efae7f6acbc53ec0abde4bc6746c43087fd215ee09e1df48998711f"},
{file = "setuptools-70.2.0.tar.gz", hash = "sha256:bd63e505105011b25c3c11f753f7e3b8465ea739efddaccef8f0efac2137bac1"}, {file = "setuptools-71.0.1.tar.gz", hash = "sha256:c51d7fd29843aa18dad362d4b4ecd917022131425438251f4e3d766c964dd1ad"},
] ]
[package.extras] [package.extras]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.10.0)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (<7.4)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.10.0)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
[[package]] [[package]]
name = "shapely" name = "shapely"
version = "2.0.4" version = "2.0.5"
description = "Manipulation and analysis of geometric objects" description = "Manipulation and analysis of geometric objects"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "shapely-2.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:011b77153906030b795791f2fdfa2d68f1a8d7e40bce78b029782ade3afe4f2f"}, {file = "shapely-2.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:89d34787c44f77a7d37d55ae821f3a784fa33592b9d217a45053a93ade899375"},
{file = "shapely-2.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9831816a5d34d5170aa9ed32a64982c3d6f4332e7ecfe62dc97767e163cb0b17"}, {file = "shapely-2.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:798090b426142df2c5258779c1d8d5734ec6942f778dab6c6c30cfe7f3bf64ff"},
{file = "shapely-2.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5c4849916f71dc44e19ed370421518c0d86cf73b26e8656192fcfcda08218fbd"}, {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45211276900c4790d6bfc6105cbf1030742da67594ea4161a9ce6812a6721e68"},
{file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:841f93a0e31e4c64d62ea570d81c35de0f6cea224568b2430d832967536308e6"}, {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e119444bc27ca33e786772b81760f2028d930ac55dafe9bc50ef538b794a8e1"},
{file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b4431f522b277c79c34b65da128029a9955e4481462cbf7ebec23aab61fc58"}, {file = "shapely-2.0.5-cp310-cp310-win32.whl", hash = "sha256:9a4492a2b2ccbeaebf181e7310d2dfff4fdd505aef59d6cb0f217607cb042fb3"},
{file = "shapely-2.0.4-cp310-cp310-win32.whl", hash = "sha256:92a41d936f7d6743f343be265ace93b7c57f5b231e21b9605716f5a47c2879e7"}, {file = "shapely-2.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:1e5cb5ee72f1bc7ace737c9ecd30dc174a5295fae412972d3879bac2e82c8fae"},
{file = "shapely-2.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:30982f79f21bb0ff7d7d4a4e531e3fcaa39b778584c2ce81a147f95be1cd58c9"}, {file = "shapely-2.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5bbfb048a74cf273db9091ff3155d373020852805a37dfc846ab71dde4be93ec"},
{file = "shapely-2.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de0205cb21ad5ddaef607cda9a3191eadd1e7a62a756ea3a356369675230ac35"}, {file = "shapely-2.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93be600cbe2fbaa86c8eb70656369f2f7104cd231f0d6585c7d0aa555d6878b8"},
{file = "shapely-2.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7d56ce3e2a6a556b59a288771cf9d091470116867e578bebced8bfc4147fbfd7"}, {file = "shapely-2.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8e71bb9a46814019f6644c4e2560a09d44b80100e46e371578f35eaaa9da1c"},
{file = "shapely-2.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:58b0ecc505bbe49a99551eea3f2e8a9b3b24b3edd2a4de1ac0dc17bc75c9ec07"}, {file = "shapely-2.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5251c28a29012e92de01d2e84f11637eb1d48184ee8f22e2df6c8c578d26760"},
{file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:790a168a808bd00ee42786b8ba883307c0e3684ebb292e0e20009588c426da47"}, {file = "shapely-2.0.5-cp311-cp311-win32.whl", hash = "sha256:35110e80070d664781ec7955c7de557456b25727a0257b354830abb759bf8311"},
{file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4310b5494271e18580d61022c0857eb85d30510d88606fa3b8314790df7f367d"}, {file = "shapely-2.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c6b78c0007a34ce7144f98b7418800e0a6a5d9a762f2244b00ea560525290c9"},
{file = "shapely-2.0.4-cp311-cp311-win32.whl", hash = "sha256:63f3a80daf4f867bd80f5c97fbe03314348ac1b3b70fb1c0ad255a69e3749879"}, {file = "shapely-2.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:03bd7b5fa5deb44795cc0a503999d10ae9d8a22df54ae8d4a4cd2e8a93466195"},
{file = "shapely-2.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:c52ed79f683f721b69a10fb9e3d940a468203f5054927215586c5d49a072de8d"}, {file = "shapely-2.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ff9521991ed9e201c2e923da014e766c1aa04771bc93e6fe97c27dcf0d40ace"},
{file = "shapely-2.0.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5bbd974193e2cc274312da16b189b38f5f128410f3377721cadb76b1e8ca5328"}, {file = "shapely-2.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b65365cfbf657604e50d15161ffcc68de5cdb22a601bbf7823540ab4918a98d"},
{file = "shapely-2.0.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:41388321a73ba1a84edd90d86ecc8bfed55e6a1e51882eafb019f45895ec0f65"}, {file = "shapely-2.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21f64e647a025b61b19585d2247137b3a38a35314ea68c66aaf507a1c03ef6fe"},
{file = "shapely-2.0.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0776c92d584f72f1e584d2e43cfc5542c2f3dd19d53f70df0900fda643f4bae6"}, {file = "shapely-2.0.5-cp312-cp312-win32.whl", hash = "sha256:3ac7dc1350700c139c956b03d9c3df49a5b34aaf91d024d1510a09717ea39199"},
{file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c75c98380b1ede1cae9a252c6dc247e6279403fae38c77060a5e6186c95073ac"}, {file = "shapely-2.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:30e8737983c9d954cd17feb49eb169f02f1da49e24e5171122cf2c2b62d65c95"},
{file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3e700abf4a37b7b8b90532fa6ed5c38a9bfc777098bc9fbae5ec8e618ac8f30"}, {file = "shapely-2.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ff7731fea5face9ec08a861ed351734a79475631b7540ceb0b66fb9732a5f529"},
{file = "shapely-2.0.4-cp312-cp312-win32.whl", hash = "sha256:4f2ab0faf8188b9f99e6a273b24b97662194160cc8ca17cf9d1fb6f18d7fb93f"}, {file = "shapely-2.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff9e520af0c5a578e174bca3c18713cd47a6c6a15b6cf1f50ac17dc8bb8db6a2"},
{file = "shapely-2.0.4-cp312-cp312-win_amd64.whl", hash = "sha256:03152442d311a5e85ac73b39680dd64a9892fa42bb08fd83b3bab4fe6999bfa0"}, {file = "shapely-2.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49b299b91557b04acb75e9732645428470825061f871a2edc36b9417d66c1fc5"},
{file = "shapely-2.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:994c244e004bc3cfbea96257b883c90a86e8cbd76e069718eb4c6b222a56f78b"}, {file = "shapely-2.0.5-cp37-cp37m-win32.whl", hash = "sha256:b5870633f8e684bf6d1ae4df527ddcb6f3895f7b12bced5c13266ac04f47d231"},
{file = "shapely-2.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05ffd6491e9e8958b742b0e2e7c346635033d0a5f1a0ea083547fcc854e5d5cf"}, {file = "shapely-2.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:401cb794c5067598f50518e5a997e270cd7642c4992645479b915c503866abed"},
{file = "shapely-2.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fbdc1140a7d08faa748256438291394967aa54b40009f54e8d9825e75ef6113"}, {file = "shapely-2.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e91ee179af539100eb520281ba5394919067c6b51824e6ab132ad4b3b3e76dd0"},
{file = "shapely-2.0.4-cp37-cp37m-win32.whl", hash = "sha256:5af4cd0d8cf2912bd95f33586600cac9c4b7c5053a036422b97cfe4728d2eb53"}, {file = "shapely-2.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8af6f7260f809c0862741ad08b1b89cb60c130ae30efab62320bbf4ee9cc71fa"},
{file = "shapely-2.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:464157509ce4efa5ff285c646a38b49f8c5ef8d4b340f722685b09bb033c5ccf"}, {file = "shapely-2.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5456dd522800306ba3faef77c5ba847ec30a0bd73ab087a25e0acdd4db2514f"},
{file = "shapely-2.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:489c19152ec1f0e5c5e525356bcbf7e532f311bff630c9b6bc2db6f04da6a8b9"}, {file = "shapely-2.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b714a840402cde66fd7b663bb08cacb7211fa4412ea2a209688f671e0d0631fd"},
{file = "shapely-2.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b79bbd648664aa6f44ef018474ff958b6b296fed5c2d42db60078de3cffbc8aa"}, {file = "shapely-2.0.5-cp38-cp38-win32.whl", hash = "sha256:7e8cf5c252fac1ea51b3162be2ec3faddedc82c256a1160fc0e8ddbec81b06d2"},
{file = "shapely-2.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:674d7baf0015a6037d5758496d550fc1946f34bfc89c1bf247cabdc415d7747e"}, {file = "shapely-2.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:4461509afdb15051e73ab178fae79974387f39c47ab635a7330d7fee02c68a3f"},
{file = "shapely-2.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6cd4ccecc5ea5abd06deeaab52fcdba372f649728050c6143cc405ee0c166679"}, {file = "shapely-2.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7545a39c55cad1562be302d74c74586f79e07b592df8ada56b79a209731c0219"},
{file = "shapely-2.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb5cdcbbe3080181498931b52a91a21a781a35dcb859da741c0345c6402bf00c"}, {file = "shapely-2.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4c83a36f12ec8dee2066946d98d4d841ab6512a6ed7eb742e026a64854019b5f"},
{file = "shapely-2.0.4-cp38-cp38-win32.whl", hash = "sha256:55a38dcd1cee2f298d8c2ebc60fc7d39f3b4535684a1e9e2f39a80ae88b0cea7"}, {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89e640c2cd37378480caf2eeda9a51be64201f01f786d127e78eaeff091ec897"},
{file = "shapely-2.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:ec555c9d0db12d7fd777ba3f8b75044c73e576c720a851667432fabb7057da6c"}, {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06efe39beafde3a18a21dde169d32f315c57da962826a6d7d22630025200c5e6"},
{file = "shapely-2.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3f9103abd1678cb1b5f7e8e1af565a652e036844166c91ec031eeb25c5ca8af0"}, {file = "shapely-2.0.5-cp39-cp39-win32.whl", hash = "sha256:8203a8b2d44dcb366becbc8c3d553670320e4acf0616c39e218c9561dd738d92"},
{file = "shapely-2.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:263bcf0c24d7a57c80991e64ab57cba7a3906e31d2e21b455f493d4aab534aaa"}, {file = "shapely-2.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:7fed9dbfbcfec2682d9a047b9699db8dcc890dfca857ecba872c42185fc9e64e"},
{file = "shapely-2.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ddf4a9bfaac643e62702ed662afc36f6abed2a88a21270e891038f9a19bc08fc"}, {file = "shapely-2.0.5.tar.gz", hash = "sha256:bff2366bc786bfa6cb353d6b47d0443c570c32776612e527ee47b6df63fcfe32"},
{file = "shapely-2.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:485246fcdb93336105c29a5cfbff8a226949db37b7473c89caa26c9bae52a242"},
{file = "shapely-2.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8de4578e838a9409b5b134a18ee820730e507b2d21700c14b71a2b0757396acc"},
{file = "shapely-2.0.4-cp39-cp39-win32.whl", hash = "sha256:9dab4c98acfb5fb85f5a20548b5c0abe9b163ad3525ee28822ffecb5c40e724c"},
{file = "shapely-2.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:31c19a668b5a1eadab82ff070b5a260478ac6ddad3a5b62295095174a8d26398"},
{file = "shapely-2.0.4.tar.gz", hash = "sha256:5dc736127fac70009b8d309a0eeb74f3e08979e530cf7017f2f507ef62e6cfb8"},
] ]
[package.dependencies] [package.dependencies]
@ -3764,17 +3927,20 @@ files = [
[[package]] [[package]]
name = "sympy" name = "sympy"
version = "1.12.1" version = "1.13.0"
description = "Computer algebra system (CAS) in Python" description = "Computer algebra system (CAS) in Python"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, {file = "sympy-1.13.0-py3-none-any.whl", hash = "sha256:6b0b32a4673fb91bd3cac3b55406c8e01d53ae22780be467301cc452f6680c92"},
{file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, {file = "sympy-1.13.0.tar.gz", hash = "sha256:3b6af8f4d008b9a1a6a4268b335b984b23835f26d1d60b0526ebc71d48a25f57"},
] ]
[package.dependencies] [package.dependencies]
mpmath = ">=1.1.0,<1.4.0" mpmath = ">=1.1.0,<1.4"
[package.extras]
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
[[package]] [[package]]
name = "tbb" name = "tbb"
@ -4326,6 +4492,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
aloha = ["gym-aloha"] aloha = ["gym-aloha"]
dev = ["debugpy", "pre-commit"] dev = ["debugpy", "pre-commit"]
dora = ["gym-dora"] dora = ["gym-dora"]
koch = ["dynamixel-sdk", "pynput"]
pusht = ["gym-pusht"] pusht = ["gym-pusht"]
test = ["pytest", "pytest-cov", "pytest-mock"] test = ["pytest", "pytest-cov", "pytest-mock"]
umi = ["imagecodecs"] umi = ["imagecodecs"]
@ -4335,4 +4502,4 @@ xarm = ["gym-xarm"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "91a402588458645c146da00cccf7627c5dddad61bd1168e539900eaec99987b3" content-hash = "882b44dada0890dd4e1c727d3363d95cbe1a4adf1d80aa5263080597d80be42c"

View File

@ -38,12 +38,12 @@ einops = ">=0.8.0"
pymunk = ">=6.6.0" pymunk = ">=6.6.0"
zarr = ">=2.17.0" zarr = ">=2.17.0"
numba = ">=0.59.0" numba = ">=0.59.0"
torch = "^2.2.1" torch = ">=2.2.1"
opencv-python = ">=4.9.0" opencv-python = ">=4.9.0"
diffusers = "^0.27.2" diffusers = ">=0.27.2"
torchvision = ">=0.17.1" torchvision = ">=0.17.1"
h5py = ">=3.10.0" h5py = ">=3.10.0"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"} huggingface-hub = {extras = ["hf-transfer"], version = ">=0.23.0"}
gymnasium = ">=0.29.1" gymnasium = ">=0.29.1"
cmake = ">=3.29.0.1" cmake = ">=3.29.0.1"
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true } gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
@ -54,15 +54,18 @@ pre-commit = {version = ">=3.7.0", optional = true}
debugpy = {version = ">=1.8.1", optional = true} debugpy = {version = ">=1.8.1", optional = true}
pytest = {version = ">=8.1.0", optional = true} pytest = {version = ">=8.1.0", optional = true}
pytest-cov = {version = ">=5.0.0", optional = true} pytest-cov = {version = ">=5.0.0", optional = true}
datasets = "^2.19.0" datasets = ">=2.19.0"
imagecodecs = { version = ">=2024.1.1", optional = true } imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5" pyav = ">=12.0.5"
moviepy = ">=1.0.3" moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1" rerun-sdk = ">=0.15.1"
deepdiff = ">=7.0.1" deepdiff = ">=7.0.1"
scikit-image = {version = "^0.23.2", optional = true} scikit-image = {version = ">=0.23.2", optional = true}
pandas = {version = "^2.2.2", optional = true} pandas = {version = ">=2.2.2", optional = true}
pytest-mock = {version = "^3.14.0", optional = true} pytest-mock = {version = ">=3.14.0", optional = true}
dynamixel-sdk = {version = ">=3.7.31", optional = true}
pynput = {version = ">=1.7.7", optional = true}
[tool.poetry.extras] [tool.poetry.extras]
@ -74,6 +77,7 @@ dev = ["pre-commit", "debugpy"]
test = ["pytest", "pytest-cov", "pytest-mock"] test = ["pytest", "pytest-cov", "pytest-mock"]
umi = ["imagecodecs"] umi = ["imagecodecs"]
video_benchmark = ["scikit-image", "pandas"] video_benchmark = ["scikit-image", "pandas"]
koch = ["dynamixel-sdk", "pynput"]
[tool.ruff] [tool.ruff]
line-length = 110 line-length = 110

View File

@ -13,8 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import pytest
from .utils import DEVICE from .utils import DEVICE
def pytest_collection_finish(): def pytest_collection_finish():
print(f"\nTesting with {DEVICE=}") print(f"\nTesting with {DEVICE=}")
@pytest.fixture(scope="session")
def is_koch_available():
try:
from lerobot.common.robot_devices.robots.factory import make_robot
robot = make_robot("koch")
robot.connect()
del robot
return True
except Exception as e:
print("An alexander koch robot is not available.")
print(e)
return False

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:523f220f3acbab0cd4aef8a13c77916634488b1af08a06e4e65d1aecafdc2cae oid sha256:28444747a9cb3876f86ae86fed72e587dbcacfccd87c5c24b8ecac30c3ce3077
size 5104 size 5104

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:95dd049b4386030ced4505586b874f16906f8d89f29b570201782eebcbe4f402 oid sha256:a43a9ddaf8527e3344b22bd21276e1f561e83599d720933b28725b00d94823c0
size 31688 size 31672

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:806851d60b6c492620b7269876eef9ce17756ec03da93f36b351f8aa75be0954 oid sha256:093bff1fbc3bde2547bccbbefc277d02368a8d4a9100b3e4bd47c755798cad68
size 33408 size 33400

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:3f4e0e525aeb22ea94b79e26b39a87e6f2da9fbee33e493906aaf2aad9a7c1ef oid sha256:85bed637e90f15c64e4af01d2dbc5d9c3a370215f2c8c379494fa3acb413bc2e
size 515400 size 515400

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:6dc658a1c1616c7d1c211eb8f87cec3d44f7b67d6b3cea7a6ce12b32d74674da oid sha256:00cf8e548d7ea23aa70de79e05c39990a32a790def824f729e6c98bea31c69bc
size 31688 size 31672

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:01d993c67a9267032fe9fbeff20b4359c209464976ea503040a0a76ae213450a oid sha256:b3a4c2581f48229312a582d91f0adea8078c0c5b744c34d76723edf4731f9003
size 33408 size 33400

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4 oid sha256:aab00b0349901450adbb8e0d7d4af1f743dd88e7e19f1bcfef821de8bdcc957d
size 5104 size 5104

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901 oid sha256:de70c3055aa052f5b811ec7c2994ec6861efe645c6caee41e04a3460598500d5
size 31688 size 31672

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
size 68 size 68

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99 oid sha256:19fdc1edf327e04132c1917024289b3d16e25a1ec2130f3df797fe07434dfbbd
size 34928 size 34920

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716 oid sha256:dcd8ebaefd3ff267eb24654135d1efb179d713e6cfe6917f793a3e2483efd501
size 5104 size 5104

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b oid sha256:107e98647ed1081745476b250df8848c0c430b2aff51d614f6b2db95684467aa
size 30808 size 30800

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6 oid sha256:adbae737c987f912509d3fba06f332bda700bfc2c6d83a09c969e9d7a3ca75f7
size 33608 size 33600

125
tests/test_cameras.py Normal file
View File

@ -0,0 +1,125 @@
import numpy as np
import pytest
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera, save_images_from_cameras
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from tests.utils import require_koch
CAMERA_INDEX = 2
# Maximum absolute difference between two consecutive images recored by a camera.
# This value differs with respect to the camera.
MAX_PIXEL_DIFFERENCE = 25
def compute_max_pixel_difference(first_image, second_image):
return np.abs(first_image.astype(float) - second_image.astype(float)).max()
@require_koch
def test_camera(request):
"""Test assumes that `camera.read()` returns the same image when called multiple times in a row.
So the environment should not change (you shouldnt be in front of the camera) and the camera should not be moving.
Warning: The tests worked for a macbookpro camera, but I am getting assertion error (`np.allclose(color_image, async_color_image)`)
for my iphone camera and my LG monitor camera.
"""
# TODO(rcadene): measure fps in nightly?
# TODO(rcadene): test logs
# TODO(rcadene): add compatibility with other camera APIs
# Test instantiating
camera = OpenCVCamera(CAMERA_INDEX)
# Test reading, async reading, disconnecting before connecting raises an error
with pytest.raises(RobotDeviceNotConnectedError):
camera.read()
with pytest.raises(RobotDeviceNotConnectedError):
camera.async_read()
with pytest.raises(RobotDeviceNotConnectedError):
camera.disconnect()
# Test deleting the object without connecting first
del camera
# Test connecting
camera = OpenCVCamera(CAMERA_INDEX)
camera.connect()
assert camera.is_connected
assert camera.fps is not None
assert camera.width is not None
assert camera.height is not None
# Test connecting twice raises an error
with pytest.raises(RobotDeviceAlreadyConnectedError):
camera.connect()
# Test reading from the camera
color_image = camera.read()
assert isinstance(color_image, np.ndarray)
assert color_image.ndim == 3
h, w, c = color_image.shape
assert c == 3
assert w > h
# Test read and async_read outputs similar images
# ...warming up as the first frames can be black
for _ in range(30):
camera.read()
color_image = camera.read()
async_color_image = camera.async_read()
print(
"max_pixel_difference between read() and async_read()",
compute_max_pixel_difference(color_image, async_color_image),
)
assert np.allclose(color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE)
# Test disconnecting
camera.disconnect()
assert camera.camera is None
assert camera.thread is None
# Test disconnecting with `__del__`
camera = OpenCVCamera(CAMERA_INDEX)
camera.connect()
del camera
# Test acquiring a bgr image
camera = OpenCVCamera(CAMERA_INDEX, color_mode="bgr")
camera.connect()
assert camera.color_mode == "bgr"
bgr_color_image = camera.read()
assert np.allclose(color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE)
del camera
# TODO(rcadene): Add a test for a camera that doesnt support fps=60 and raises an OSError
# TODO(rcadene): Add a test for a camera that supports fps=60
# Test fps=10 raises an OSError
camera = OpenCVCamera(CAMERA_INDEX, fps=10)
with pytest.raises(OSError):
camera.connect()
del camera
# Test width and height can be set
camera = OpenCVCamera(CAMERA_INDEX, fps=30, width=1280, height=720)
camera.connect()
assert camera.fps == 30
assert camera.width == 1280
assert camera.height == 720
color_image = camera.read()
h, w, c = color_image.shape
assert h == 720
assert w == 1280
assert c == 3
del camera
# Test not supported width and height raise an error
camera = OpenCVCamera(CAMERA_INDEX, fps=30, width=0, height=0)
with pytest.raises(OSError):
camera.connect()
del camera
@require_koch
def test_save_images_from_cameras(tmpdir, request):
save_images_from_cameras(tmpdir, record_time_s=1)

View File

@ -0,0 +1,48 @@
from pathlib import Path
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.control_robot import record_dataset, replay_episode, run_policy, teleoperate
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_koch
@require_koch
def test_teleoperate(request):
robot = make_robot("koch")
teleoperate(robot, teleop_time_s=1)
teleoperate(robot, fps=30, teleop_time_s=1)
teleoperate(robot, fps=60, teleop_time_s=1)
del robot
@require_koch
def test_record_dataset_and_replay_episode_and_run_policy(tmpdir, request):
robot_name = "koch"
env_name = "koch_real"
policy_name = "act_koch_real"
root = Path(tmpdir)
repo_id = "lerobot/debug"
robot = make_robot(robot_name)
dataset = record_dataset(
robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2
)
replay_episode(robot, episode=0, fps=30, root=root, repo_id=repo_id)
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
],
)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
run_policy(robot, policy, cfg, run_time_s=1)
del robot

92
tests/test_motors.py Normal file
View File

@ -0,0 +1,92 @@
import time
import numpy as np
import pytest
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from tests.utils import require_koch
@require_koch
def test_motors_bus(request):
# TODO(rcadene): measure fps in nightly?
# TODO(rcadene): test logs
# TODO(rcadene): test calibration
# TODO(rcadene): add compatibility with other motors bus
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
# Test instantiating a common motors structure.
# Here the one from Alexander Koch follower arm.
port = "/dev/tty.usbmodem575E0032081"
motors = {
# name: (index, model)
"shoulder_pan": (1, "xl430-w250"),
"shoulder_lift": (2, "xl430-w250"),
"elbow_flex": (3, "xl330-m288"),
"wrist_flex": (4, "xl330-m288"),
"wrist_roll": (5, "xl330-m288"),
"gripper": (6, "xl330-m288"),
}
motors_bus = DynamixelMotorsBus(port, motors)
# Test reading and writting before connecting raises an error
with pytest.raises(RobotDeviceNotConnectedError):
motors_bus.read("Torque_Enable")
with pytest.raises(RobotDeviceNotConnectedError):
motors_bus.write("Torque_Enable", 1)
with pytest.raises(RobotDeviceNotConnectedError):
motors_bus.disconnect()
# Test deleting the object without connecting first
del motors_bus
# Test connecting
motors_bus = DynamixelMotorsBus(port, motors)
motors_bus.connect()
# Test connecting twice raises an error
with pytest.raises(RobotDeviceAlreadyConnectedError):
motors_bus.connect()
# Test disabling torque and reading torque on all motors
motors_bus.write("Torque_Enable", 0)
values = motors_bus.read("Torque_Enable")
assert isinstance(values, np.ndarray)
assert len(values) == len(motors)
assert (values == 0).all()
# Test writing torque on a specific motor
motors_bus.write("Torque_Enable", 1, "gripper")
# Test reading torque from this specific motor. It is now 1
values = motors_bus.read("Torque_Enable", "gripper")
assert len(values) == 1
assert values[0] == 1
# Test reading torque from all motors. It is 1 for the specific motor,
# and 0 on the others.
values = motors_bus.read("Torque_Enable")
gripper_index = motors_bus.motor_names.index("gripper")
assert values[gripper_index] == 1
assert values.sum() == 1 # gripper is the only motor to have torque 1
# Test writing torque on all motors and it is 1 for all.
motors_bus.write("Torque_Enable", 1)
values = motors_bus.read("Torque_Enable")
assert (values == 1).all()
# Test ordering the motors to move slightly (+1 value among 4096) and this move
# can be executed and seen by the motor position sensor
values = motors_bus.read("Present_Position")
motors_bus.write("Goal_Position", values + 1)
# Give time for the motors to move to the goal position
time.sleep(1)
new_values = motors_bus.read("Present_Position")
assert (new_values == values).all()
@require_koch
def test_find_port(request):
from lerobot.common.robot_devices.motors.dynamixel import find_port
find_port()

View File

@ -16,6 +16,7 @@
import inspect import inspect
from pathlib import Path from pathlib import Path
import einops
import pytest import pytest
import torch import torch
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
@ -26,6 +27,7 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.common.policies.factory import ( from lerobot.common.policies.factory import (
_policy_cfg_from_hydra_cfg, _policy_cfg_from_hydra_cfg,
get_policy_and_config_classes, get_policy_and_config_classes,
@ -33,7 +35,7 @@ from lerobot.common.policies.factory import (
) )
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.train import make_optimizer_and_scheduler from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@ -390,3 +392,62 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all() assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
for key in saved_actions: for key in saved_actions:
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all() assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
def test_act_temporal_ensembler():
"""Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
temporal_ensemble_coeff = 0.01
chunk_size = 100
episode_length = 101
ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
# An batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
# "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
with seeded_context(0):
# Dimension is (batch, episode_length, chunk_size, action_dim(=1))
# Stepping through the episode_length dim is like running inference at each rollout step and getting
# a different action chunk.
batch_seq = torch.stack(
[
torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
],
dim=0,
).unsqueeze(-1) # unsqueeze for action dim
batch_size = batch_seq.shape[0]
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
# dimension of `batch_seq`.
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
for i in range(episode_length):
# Mock a batch of actions.
actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
online_avg = ensembler.update(actions)
# Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
# Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
# What we want to do is take diagonal slices across it starting from the left.
# eg: chunk_size=4, episode_length=6
# ┌───────┐
# │0 1 2 3│
# │1 2 3 4│
# │2 3 4 5│
# │3 4 5 6│
# │4 5 6 7│
# │5 6 7 8│
# └───────┘
chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
offline_avg = (
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
)
# Sanity check. The average should be between the extrema.
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
if __name__ == "__main__":
test_act_temporal_ensembler()

128
tests/test_robots.py Normal file
View File

@ -0,0 +1,128 @@
import pickle
from pathlib import Path
import pytest
import torch
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from tests.utils import require_koch
@require_koch
def test_robot(tmpdir, request):
# TODO(rcadene): measure fps in nightly?
# TODO(rcadene): test logs
# TODO(rcadene): add compatibility with other robots
from lerobot.common.robot_devices.robots.koch import KochRobot
# Save calibration preset
calibration = {
"follower_main": {
"shoulder_pan": (-2048, False),
"shoulder_lift": (2048, True),
"elbow_flex": (-1024, False),
"wrist_flex": (2048, True),
"wrist_roll": (2048, True),
"gripper": (2048, True),
},
"leader_main": {
"shoulder_pan": (-2048, False),
"shoulder_lift": (1024, True),
"elbow_flex": (2048, True),
"wrist_flex": (-2048, False),
"wrist_roll": (2048, True),
"gripper": (2048, True),
},
}
tmpdir = Path(tmpdir)
calibration_path = tmpdir / "calibration.pkl"
calibration_path.parent.mkdir(parents=True, exist_ok=True)
with open(calibration_path, "wb") as f:
pickle.dump(calibration, f)
# Test connecting without devices raises an error
robot = KochRobot()
with pytest.raises(ValueError):
robot.connect()
del robot
# Test using robot before connecting raises an error
robot = KochRobot()
with pytest.raises(RobotDeviceNotConnectedError):
robot.teleop_step()
with pytest.raises(RobotDeviceNotConnectedError):
robot.teleop_step(record_data=True)
with pytest.raises(RobotDeviceNotConnectedError):
robot.capture_observation()
with pytest.raises(RobotDeviceNotConnectedError):
robot.send_action(None)
with pytest.raises(RobotDeviceNotConnectedError):
robot.disconnect()
# Test deleting the object without connecting first
del robot
# Test connecting
robot = make_robot("koch")
# TODO(rcadene): proper monkey patch
robot.calibration_path = calibration_path
robot.connect() # run the manual calibration precedure
assert robot.is_connected
# Test connecting twice raises an error
with pytest.raises(RobotDeviceAlreadyConnectedError):
robot.connect()
# Test disconnecting with `__del__`
del robot
# Test teleop can run
robot = make_robot("koch")
robot.calibration_path = calibration_path
robot.connect()
robot.teleop_step()
# Test data recorded during teleop are well formated
observation, action = robot.teleop_step(record_data=True)
# State
assert "observation.state" in observation
assert isinstance(observation["observation.state"], torch.Tensor)
assert observation["observation.state"].ndim == 1
dim_state = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
assert observation["observation.state"].shape[0] == dim_state
# Cameras
for name in robot.cameras:
assert f"observation.images.{name}" in observation
assert isinstance(observation[f"observation.images.{name}"], torch.Tensor)
assert observation[f"observation.images.{name}"].ndim == 3
# Action
assert "action" in action
assert isinstance(action["action"], torch.Tensor)
assert action["action"].ndim == 1
dim_action = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
assert action["action"].shape[0] == dim_action
# TODO(rcadene): test if observation and action data are returned as expected
# Test capture_observation can run and observation returned are the same (since the arm didnt move)
captured_observation = robot.capture_observation()
assert set(captured_observation.keys()) == set(observation.keys())
for name in captured_observation:
if "image" in name:
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
continue
assert torch.allclose(captured_observation[name], observation[name], atol=1)
# Test send_action can run
robot.send_action(action["action"])
# Test disconnecting
robot.disconnect()
assert not robot.is_connected
for name in robot.follower_arms:
assert not robot.follower_arms[name].is_connected
for name in robot.leader_arms:
assert not robot.leader_arms[name].is_connected
for name in robot.cameras:
assert not robot.cameras[name].is_connected
del robot

View File

@ -20,21 +20,6 @@ import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset from lerobot.scripts.visualize_dataset import visualize_dataset
@pytest.mark.parametrize(
"repo_id",
["lerobot/pusht"],
)
def test_visualize_dataset(tmpdir, repo_id):
rrd_path = visualize_dataset(
repo_id,
episode_index=0,
batch_size=32,
save=True,
output_dir=tmpdir,
)
assert rrd_path.exists()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"repo_id", "repo_id",
["lerobot/pusht"], ["lerobot/pusht"],

View File

@ -147,3 +147,22 @@ def require_package(package_name):
return wrapper return wrapper
return decorator return decorator
def require_koch(func):
"""
Decorator that skips the test if an alexander koch robot is not available
"""
@wraps(func)
def wrapper(*args, **kwargs):
# Access the pytest request context to get the is_koch_available fixture
request = kwargs.get("request")
if request is None:
raise ValueError("The 'request' fixture must be passed to the test function as a parameter.")
if not request.getfixturevalue("is_koch_available"):
pytest.skip("An alexander koch robot is not available.")
return func(*args, **kwargs)
return wrapper