Merge branch 'main' into add-dot

commit 035e95a41b

@@ -41,7 +41,7 @@ jobs:
 
       - name: Get changed files
         id: changed-files
-        uses: tj-actions/changed-files@v44
+        uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42
         with:
           files: docker/**
           json: "true"

@@ -126,7 +126,7 @@ jobs:
         # portaudio19-dev is needed to install pyaudio
         run: |
           sudo apt-get update && \
-          sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
+          sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
 
       - name: Install uv and python
         uses: astral-sh/setup-uv@v5

README.md

@@ -98,14 +98,18 @@ conda create -y -n lerobot python=3.10
 conda activate lerobot
 ```
 
+When using `miniconda`, if you don't have `ffmpeg` in your environment:
+```bash
+conda install ffmpeg
+```
+
 Install 🤗 LeRobot:
 ```bash
-pip install -e .
+pip install --no-binary=av -e .
 ```
 
-> **NOTE:** Depending on your platform, If you encounter any build errors during this step
-you may need to install `cmake` and `build-essential` for building some of our dependencies.
-On linux: `sudo apt-get install cmake build-essential`
+> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
+`sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
 
 For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
 - [aloha](https://github.com/huggingface/gym-aloha)

@@ -114,7 +118,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
 
 For instance, to install 🤗 LeRobot with aloha and pusht, use:
 ```bash
-pip install -e ".[aloha, pusht]"
+pip install --no-binary=av -e ".[aloha, pusht]"
 ```
 
 To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with

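Since `--no-binary=av` forces pip to build PyAV from source against the FFmpeg installed above, a quick post-install check can confirm the linkage. A minimal sketch, assuming PyAV's `library_versions` mapping is available on your version:

```python
# Sanity check after `pip install --no-binary=av -e .`:
# print the FFmpeg library versions PyAV was compiled against.
# Assumes PyAV exposes `library_versions` (a dict of library name -> version tuple).
import av

for lib, version in av.library_versions.items():
    print(f"{lib}: {'.'.join(str(part) for part in version)}")
```
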
@@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
 ### Decoding parameters
 **Decoder**
 We tested two video decoding backends from torchvision:
-- `pyav` (default)
+- `pyav`
 - `video_reader` (requires to build torchvision from source)
 
 **Requested timestamps**

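For context, the decoding backend is chosen per dataset; a minimal usage sketch, using the `video_backend` constructor argument that appears in the `LeRobotDataset` changes later in this diff (the repo id is illustrative):

```python
# Minimal sketch: load a dataset and decode with torchvision's `pyav` backend.
# `video_backend` is the constructor argument changed later in this diff.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht", video_backend="pyav")
item = dataset[0]  # video frames for this index are decoded with the chosen backend
```
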
@@ -67,7 +67,7 @@ def parse_int_or_none(value) -> int | None:
 def check_datasets_formats(repo_ids: list) -> None:
     for repo_id in repo_ids:
         dataset = LeRobotDataset(repo_id)
-        if dataset.video:
+        if len(dataset.meta.video_keys) > 0:
             raise ValueError(
                 f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
             )

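The benchmark guard above now inspects metadata instead of a boolean flag; a standalone sketch of the same check, assuming `meta.video_keys` lists the dataset's video-encoded feature keys as the new condition implies:

```python
# Sketch mirroring the updated guard: a dataset is a "video dataset" when its
# metadata lists at least one video-encoded feature key.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")
if len(dataset.meta.video_keys) > 0:
    print(f"video dataset, keys: {dataset.meta.video_keys}")
else:
    print("image dataset: frames are stored as individual images")
```
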
@@ -59,15 +59,9 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
 
 #### 5. Install LeRobot with dependencies for the feetech motors:
 ```bash
-cd ~/lerobot && pip install -e ".[feetech]"
+cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
 ```
 
-*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
-```
 Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
 Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
 

@@ -583,6 +577,13 @@ Let's explain it:
 
 Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
 
+To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so100_test` policy:
+```bash
+python lerobot/scripts/train.py \
+  --config_path=outputs/train/act_so100_test/checkpoints/last/pretrained_model/train_config.json \
+  --resume=true
+```
+
 ## K. Evaluate your policy
 
 You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:

@@ -69,7 +69,7 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
 
 #### 5. Install LeRobot with dependencies for the feetech motors:
 ```bash
-cd ~/lerobot && pip install -e ".[feetech]"
+cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
 ```
 
 ## C. Install LeRobot on laptop

@@ -110,15 +110,9 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
 
 #### 5. Install LeRobot with dependencies for the feetech motors:
 ```bash
-cd ~/lerobot && pip install -e ".[feetech]"
+cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
 ```
 
-*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
-```
 Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
 Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
 

@@ -399,6 +393,10 @@ python lerobot/scripts/control_robot.py \
 ```
 
 # F. Teleoperate
 
+> [!TIP]
+> If you're using a Mac, you might need to give Terminal permission to access your keyboard. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
+
 To teleoperate SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
 ```bash
 python lerobot/scripts/control_robot.py \

@@ -33,14 +33,7 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
 
 5. Install LeRobot with dependencies for the feetech motors:
 ```bash
-cd ~/lerobot && pip install -e ".[feetech]"
+cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
-```
-
-For Linux only (not Mac), install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
 ```
 
 ## Configure the motors

@@ -18,7 +18,7 @@ training outputs directory. In the latter case, you might want to run examples/3
 
 It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
 ```bash
-pip install -e ".[pusht]"`
+pip install --no-binary=av -e ".[pusht]"`
 ```
 """
 

@@ -33,7 +33,7 @@ First, install the additional dependencies required for robots built with dynami
 
 Using `pip`:
 ```bash
-pip install -e ".[dynamixel]"
+pip install --no-binary=av -e ".[dynamixel]"
 ```
 
 Using `poetry`:

@@ -46,13 +46,6 @@ Using `uv`:
 uv sync --extra "dynamixel"
 ```
 
-/!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:
-```bash
-conda install -c conda-forge ffmpeg
-pip uninstall opencv-python
-conda install -c conda-forge "opencv>=4.10.0"
-```
-
 You are now ready to plug the 5V power supply to the motor bus of the leader arm (the smaller one) since all its motors only require 5V.
 
 Then plug the 12V power supply to the motor bus of the follower arm. It has two motors that need 12V, and the rest will be powered with 5V through the voltage convertor.

@@ -292,6 +285,11 @@ Steps:
 - Scan for devices. All 12 motors should appear.
 - Select the motors one by one and move the arm. Check that the graphical indicator near the top right shows the movement.
 
+** There is a common issue with the Dynamixel XL430-W250 motors where the motors become undiscoverable after upgrading their firmware from Mac and Windows Dynamixel Wizard2 applications. When this occurs, it is required to do a firmware recovery (Select `DYNAMIXEL Firmware Recovery` and follow the prompts). There are two known workarounds to conduct this firmware reset:
+1) Install the Dynamixel Wizard on a linux machine and complete the firmware recovery
+2) Use the Dynamixel U2D2 in order to perform the reset with Windows or Mac. This U2D2 can be purchased [here](https://www.robotis.us/u2d2/).
+For either solution, open DYNAMIXEL Wizard 2.0 and select the appropriate port. You will likely be unable to see the motor in the GUI at this time. Select `Firmware Recovery`, carefully choose the correct model, and wait for the process to complete. Finally, re-scan to confirm the firmware recovery was successful.
+
 **Read and Write with DynamixelMotorsBus**
 
 To get familiar with how `DynamixelMotorsBus` communicates with the motors, you can start by reading data from them. Copy paste this code in the same interactive python session:

@@ -829,11 +827,6 @@ It contains:
 - `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
 
 Troubleshooting:
-- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
-```bash
-pip uninstall opencv-python
-conda install -c conda-forge opencv=4.10.0
-```
 - On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
   - install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
   - or, install [Homebrew](https://brew.sh) and run `brew install ffmpeg` (it should be compiled with `libsvtav1`),

@@ -45,18 +45,11 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
 
 6. Install LeRobot with stretch dependencies:
 ```bash
-cd ~/lerobot && pip install -e ".[stretch]"
+cd ~/lerobot && pip install --no-binary=av -e ".[stretch]"
 ```
 
 > **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
 
-For Linux only (not Mac), install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
-```
-
 7. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
 ```bash
 stretch_system_check.py

@@ -32,14 +32,7 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
 
 5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
 ```bash
-cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
+cd ~/lerobot && pip install --no-binary=av -e ".[dynamixel, intelrealsense]"
-```
-
-For Linux only (not Mac), install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
 ```
 
 ## Teleoperate

@@ -1,243 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import shutil
-from pathlib import Path
-
-import numpy as np
-from huggingface_hub import HfApi
-
-from lerobot.common.constants import HF_LEROBOT_HOME
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
-from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
-
-PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
-PUSHT_FEATURES = {
-    "observation.state": {
-        "dtype": "float32",
-        "shape": (2,),
-        "names": {
-            "axes": ["x", "y"],
-        },
-    },
-    "action": {
-        "dtype": "float32",
-        "shape": (2,),
-        "names": {
-            "axes": ["x", "y"],
-        },
-    },
-    "next.reward": {
-        "dtype": "float32",
-        "shape": (1,),
-        "names": None,
-    },
-    "next.success": {
-        "dtype": "bool",
-        "shape": (1,),
-        "names": None,
-    },
-    "observation.environment_state": {
-        "dtype": "float32",
-        "shape": (16,),
-        "names": [
-            "keypoints",
-        ],
-    },
-    "observation.image": {
-        "dtype": None,
-        "shape": (3, 96, 96),
-        "names": [
-            "channels",
-            "height",
-            "width",
-        ],
-    },
-}
-
-
-def build_features(mode: str) -> dict:
-    features = PUSHT_FEATURES
-    if mode == "keypoints":
-        features.pop("observation.image")
-    else:
-        features.pop("observation.environment_state")
-        features["observation.image"]["dtype"] = mode
-
-    return features
-
-
-def load_raw_dataset(zarr_path: Path):
-    try:
-        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
-            ReplayBuffer as DiffusionPolicyReplayBuffer,
-        )
-    except ModuleNotFoundError as e:
-        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
-        raise e
-
-    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
-    return zarr_data
-
-
-def calculate_coverage(zarr_data):
-    try:
-        import pymunk
-        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
-    except ModuleNotFoundError as e:
-        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
-        raise e
-
-    block_pos = zarr_data["state"][:, 2:4]
-    block_angle = zarr_data["state"][:, 4]
-
-    num_frames = len(block_pos)
-
-    coverage = np.zeros((num_frames,), dtype=np.float32)
-    # 8 keypoints with 2 coords each
-    keypoints = np.zeros((num_frames, 16), dtype=np.float32)
-
-    # Set x, y, theta (in radians)
-    goal_pos_angle = np.array([256, 256, np.pi / 4])
-    goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
-
-    for i in range(num_frames):
-        space = pymunk.Space()
-        space.gravity = 0, 0
-        space.damping = 0
-
-        # Add walls.
-        walls = [
-            PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
-            PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
-            PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
-            PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
-        ]
-        space.add(*walls)
-
-        block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
-        goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
-        block_geom = pymunk_to_shapely(block_body, block_body.shapes)
-        intersection_area = goal_geom.intersection(block_geom).area
-        goal_area = goal_geom.area
-        coverage[i] = intersection_area / goal_area
-        keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten()
-
-    return coverage, keypoints
-
-
-def calculate_success(coverage: float, success_threshold: float):
-    return coverage > success_threshold
-
-
-def calculate_reward(coverage: float, success_threshold: float):
-    return np.clip(coverage / success_threshold, 0, 1)
-
-
-def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
-    if mode not in ["video", "image", "keypoints"]:
-        raise ValueError(mode)
-
-    if (HF_LEROBOT_HOME / repo_id).exists():
-        shutil.rmtree(HF_LEROBOT_HOME / repo_id)
-
-    if not raw_dir.exists():
-        download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
-
-    zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")
-
-    env_state = zarr_data["state"][:]
-    agent_pos = env_state[:, :2]
-
-    action = zarr_data["action"][:]
-    image = zarr_data["img"]  # (b, h, w, c)
-
-    if image.dtype == np.float32 and image.max() == np.float32(255):
-        # HACK: images are loaded as float32 but they actually encode uint8 data
-        image = image.astype(np.uint8)
-
-    episode_data_index = {
-        "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
-        "to": zarr_data.meta["episode_ends"],
-    }
-
-    # Calculate success and reward based on the overlapping area
-    # of the T-object and the T-area.
-    coverage, keypoints = calculate_coverage(zarr_data)
-    success = calculate_success(coverage, success_threshold=0.95)
-    reward = calculate_reward(coverage, success_threshold=0.95)
-
-    features = build_features(mode)
-    dataset = LeRobotDataset.create(
-        repo_id=repo_id,
-        fps=10,
-        robot_type="2d pointer",
-        features=features,
-        image_writer_threads=4,
-    )
-    episodes = range(len(episode_data_index["from"]))
-    for ep_idx in episodes:
-        from_idx = episode_data_index["from"][ep_idx]
-        to_idx = episode_data_index["to"][ep_idx]
-        num_frames = to_idx - from_idx
-
-        for frame_idx in range(num_frames):
-            i = from_idx + frame_idx
-            idx = i + (frame_idx < num_frames - 1)
-            frame = {
-                "action": action[i],
-                # Shift reward and success by +1 until the last item of the episode
-                "next.reward": reward[idx : idx + 1],
-                "next.success": success[idx : idx + 1],
-                "task": PUSHT_TASK,
-            }
-
-            frame["observation.state"] = agent_pos[i]
-
-            if mode == "keypoints":
-                frame["observation.environment_state"] = keypoints[i]
-            else:
-                frame["observation.image"] = image[i]
-
-            dataset.add_frame(frame)
-
-        dataset.save_episode()
-
-    if push_to_hub:
-        dataset.push_to_hub()
-        hub_api = HfApi()
-        hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
-
-
-if __name__ == "__main__":
-    # To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
-    repo_id = "lerobot/pusht"
-
-    modes = ["video", "image", "keypoints"]
-    # Uncomment if you want to try with a specific mode
-    # modes = ["video"]
-    # modes = ["image"]
-    # modes = ["keypoints"]
-
-    raw_dir = Path("data/lerobot-raw/pusht_raw")
-    for mode in modes:
-        if mode in ["image", "keypoints"]:
-            repo_id += f"_{mode}"
-
-        # download and load raw dataset, create LeRobotDataset, populate it, push to hub
-        main(raw_dir, repo_id=repo_id, mode=mode)
-
-        # Uncomment if you want to load the local dataset and explore it
-        # dataset = LeRobotDataset(repo_id=repo_id)
-        # breakpoint()

@@ -67,8 +67,9 @@ from lerobot.common.datasets.utils import (
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,
-    decode_video_frames_torchvision,
+    decode_video_frames,
     encode_video_frames,
+    get_safe_default_codec,
     get_video_info,
 )
 from lerobot.common.robot_devices.robots.utils import Robot

@@ -462,8 +463,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
             video files are already present on local disk, they won't be downloaded again. Defaults to
             True.
-        video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
-            a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
+        video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available in the platform; otherwise, defaults to 'pyav'.
+            You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
         """
         super().__init__()
         self.repo_id = repo_id

@@ -473,7 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episodes = episodes
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
-        self.video_backend = video_backend if video_backend else "pyav"
+        self.video_backend = video_backend if video_backend else get_safe_default_codec()
         self.delta_indices = None
 
         # Unused attributes

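The body of `get_safe_default_codec` is not shown in this diff; a minimal sketch of what such a helper could look like, assuming it simply prefers torchcodec when importable, consistent with the docstring change above (an assumption, not the actual implementation):

```python
# Hedged sketch (assumption): prefer torchcodec when it is importable on this
# platform, otherwise fall back to torchvision's pyav decoder.
import importlib.util


def get_safe_default_codec() -> str:
    if importlib.util.find_spec("torchcodec") is not None:
        return "torchcodec"
    return "pyav"
```
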
@@ -707,9 +708,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item = {}
         for vid_key, query_ts in query_timestamps.items():
             video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
-            frames = decode_video_frames_torchvision(
-                video_path, query_ts, self.tolerance_s, self.video_backend
-            )
+            frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
             item[vid_key] = frames.squeeze(0)
 
         return item

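The renamed `decode_video_frames` keeps the old positional arguments and takes the backend last, as the call above shows; a usage sketch (the video path here is hypothetical):

```python
# Usage sketch for the unified decoder entry point, mirroring the call above.
from lerobot.common.datasets.video_utils import decode_video_frames

frames = decode_video_frames(
    "videos/observation.image/episode_000000.mp4",  # hypothetical path
    [0.0, 0.1, 0.2],  # requested timestamps in seconds
    1e-4,             # tolerance_s: max allowed deviation from requested timestamps
    "pyav",           # backend: "torchcodec", "pyav", or "video_reader"
)
print(frames.shape)  # one decoded frame per requested timestamp
```
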
@@ -1029,7 +1028,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_timestamps = None
         obj.delta_indices = None
         obj.episode_data_index = None
-        obj.video_backend = video_backend if video_backend is not None else "pyav"
+        obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
         return obj
 

@@ -1,85 +0,0 @@
-https://drive.google.com/file/d/1_SOJkgfP5yZyVjMhTt3nwhvyUjcnlI51/view?usp=drive_link
-https://drive.google.com/file/d/1rmgN8UUzph1qwJnzG1d-uOafodn-gLvb/view?usp=drive_link
-https://drive.google.com/file/d/1NYQ-XxsBVinB6dUoZmVWweT83367P3i2/view?usp=drive_link
-https://drive.google.com/file/d/1oAv_j74zxxCJieMG7r5Vl2BeHK1__3s3/view?usp=drive_link
-https://drive.google.com/file/d/1wFUJQROsrTJt64YRuIeExhFjr2wnK5uu/view?usp=drive_link
-https://drive.google.com/file/d/1KzL3Tt0Le7jVl58XVRUcmigmXjyiuhbK/view?usp=drive_link
-https://drive.google.com/file/d/1qy_YBladeHtianSSGtgAPSHtMin7msvf/view?usp=drive_link
-https://drive.google.com/file/d/1rA_F0V_qL_nyuC_0aBKCisF4-0TIkF2Y/view?usp=drive_link
-https://drive.google.com/file/d/1hw-8qMpz9VgSt62XoASqNRuPECpCwJQP/view?usp=drive_link
-https://drive.google.com/file/d/1BpHOl9rKMzdvNGka6js7C0s40hH6vnDA/view?usp=drive_link
-https://drive.google.com/file/d/1PazhkhiDnJ-OUMyDVDFxEZNKQQqHiNWS/view?usp=drive_link
-https://drive.google.com/file/d/1lZ665R6ATl57dypxH4dGJ2NSt6XYnbuz/view?usp=drive_link
-https://drive.google.com/file/d/1V9HzLaf-tlG15wUzT7KrTDCS_z1vi5NV/view?usp=drive_link
-https://drive.google.com/file/d/1aKauWiXoKqbNwn_2xs4MrmLlaNYlVNmO/view?usp=drive_link
-https://drive.google.com/file/d/1WVD5DFhriO1YmmOgiVHhacR6HWoTPxav/view?usp=drive_link
-https://drive.google.com/file/d/1_X43WgeBAsfkhH9EmpyPki8U9joMeAGC/view?usp=drive_link
-https://drive.google.com/file/d/1t8x0GqWoNKWtnBsB7_D40Z34nL9ak4kf/view?usp=drive_link
-https://drive.google.com/file/d/15V_f26WaKOXjKnq2T3HRWAmtQUi4lbu2/view?usp=drive_link
-https://drive.google.com/file/d/11VFIAsiSDsMOBANgrOcZBpKB9AFWnLy7/view?usp=drive_link
-https://drive.google.com/file/d/1M0NS7vVaxJv3FHnuRYtdwTFYF7We4LxP/view?usp=drive_link
-https://drive.google.com/file/d/1mR0OItTNqFnVLoczcyKYlm6drAy778lO/view?usp=drive_link
-https://drive.google.com/file/d/1NbVFWDQAh-z4JJ4D-Zw6Lps9kdvpqh2j/view?usp=drive_link
-https://drive.google.com/file/d/1JQoZGBzl4W3QG26-n39tefcGN0fDRMbB/view?usp=drive_link
-https://drive.google.com/file/d/1VBjHl-TvZpncopvasIP5G9gecbB2a5f6/view?usp=drive_link
-https://drive.google.com/file/d/1VzSf6zaB21nahm7MsPwroXbJ84NIwq0b/view?usp=drive_link
-https://drive.google.com/file/d/1OtNnfMEydNtZOcivs4k6E_uJSpf8PkGy/view?usp=drive_link
-https://drive.google.com/file/d/14nVvpvsrFr_03Pa_N7MKzwnRwibOUYM6/view?usp=drive_link
-https://drive.google.com/file/d/1M8li6duiO2r3lv_9HhF_XJn0oZUIEK5F/view?usp=drive_link
-https://drive.google.com/file/d/1Cpzea6fO14lxAaNfSBifqoa4ekhCiLD1/view?usp=drive_link
-https://drive.google.com/file/d/1mbxRTm5vlbsY9UJ0jfjM6j9D7kPJjBpG/view?usp=drive_link
-https://drive.google.com/file/d/1RXD1i6IfWsHRlCxVmG04h2h5Ycm_WwZN/view?usp=drive_link
-https://drive.google.com/file/d/1QFqFSwDGOk1BkgGmqgCcc2BRWnJ6R3MA/view?usp=drive_link
-https://drive.google.com/file/d/1bFqWR8DQM0ZUxxtS2bl-RANQvukeFLzp/view?usp=drive_link
-https://drive.google.com/file/d/1pR-rH3yNGoyPdD4hJ6-3lXQ-PstBx9du/view?usp=drive_link
-https://drive.google.com/file/d/107OAwLY-hva9HeQLIK7VCh-ytdDabVjr/view?usp=drive_link
-https://drive.google.com/file/d/1Tpl08QOaSZ37GTO4awFWSdD8wBR9xdlT/view?usp=drive_link
-https://drive.google.com/file/d/1MR164AOM-0S1T6RX8xKTV2IHyaCvpqAW/view?usp=drive_link
-https://drive.google.com/file/d/1_wknJfVnStIhJ82lU_QtcrwahsqYIsr8/view?usp=drive_link
-https://drive.google.com/file/d/1ZuEktWrbYkTx0l5pj3WiZ2CJrfbDOHNo/view?usp=drive_link
-https://drive.google.com/file/d/15G_10hkkkq6yxvyI5NGZirlF-RzduR2F/view?usp=drive_link
-https://drive.google.com/file/d/1DBKxg3ONqh7dhLuX6oh1Yyo2x383V1Hp/view?usp=drive_link
-https://drive.google.com/file/d/1B5iDBkTUr5vopDddV_fHud18SqAHhauS/view?usp=drive_link
-https://drive.google.com/file/d/1acwFV0eenRkki1QcjSKH5xqOtys-P3Pr/view?usp=drive_link
-https://drive.google.com/file/d/1S47BI83xyrh-FKXsvAQqer98Biu_p8XK/view?usp=drive_link
-https://drive.google.com/file/d/1JL6DmBZl3uyq9dyLfgSqtGF06e7E9JwM/view?usp=drive_link
-https://drive.google.com/file/d/16WvRS4Kjog8Pxgr0E3sGGnI01YwL9Uql/view?usp=drive_link
-https://drive.google.com/file/d/12ttGqL33IPWg0-s1SD44rr22M6LiSQBr/view?usp=drive_link
-https://drive.google.com/file/d/1OyZqqnldTU_DliRbr6x0C4a_iWPwIN7j/view?usp=drive_link
-https://drive.google.com/file/d/1oYk00IpLnR9fesLfD15Ebe7nVBffEbcS/view?usp=drive_link
-https://drive.google.com/file/d/1eyE2-MQduCEqCd-5_kl5zsoOEERAzpZD/view?usp=drive_link
-https://drive.google.com/file/d/1ir1Ya-vO0d97pfvbePlUeuKTTRc0qIMU/view?usp=drive_link
-https://drive.google.com/file/d/1hOi-JnqlMt47gVnLZHMTqeojyYVErohl/view?usp=drive_link
-https://drive.google.com/file/d/1NFFw5_PqigQ7xGqsL-MNq2B1r5yAscCf/view?usp=drive_link
-https://drive.google.com/file/d/1uftq1-Zlh8d2sNLWrlVcKYQUwZTD7o24/view?usp=drive_link
-https://drive.google.com/file/d/1-ax19dSLPacVgk000T-m3l4flPcg07pM/view?usp=drive_link
-https://drive.google.com/file/d/126y-lgn86-ZmCz8hooF1THKJGGObw3OB/view?usp=drive_link
-https://drive.google.com/file/d/1JiDniK0VmDIkk92AbBILb8J2Ba59PWML/view?usp=drive_link
-https://drive.google.com/file/d/1kr8nPIRljiU0R4J9SMgj80o1FPQxzu9z/view?usp=drive_link
-https://drive.google.com/file/d/1bbThWRij1pKBh_kFgV8FwK0sXtTHBoLX/view?usp=drive_link
-https://drive.google.com/file/d/1WenzDW6lxk1xkOFm-OiGFfc0ROskAuKU/view?usp=drive_link
-https://drive.google.com/file/d/1MiKRzuzUn1yN-k_6kPJJzIGy7dT-nnsD/view?usp=drive_link
-https://drive.google.com/file/d/17rRg2tcmB-gNhQ0KoZJQmNfyFeoij1jH/view?usp=drive_link
-https://drive.google.com/file/d/11mokBpvrY3ld6sY5WztREtJ1jgqfQV70/view?usp=drive_link
-https://drive.google.com/file/d/1Il_6IOx9NDp1bX_KHizJfBwzTufTmn86/view?usp=drive_link
-https://drive.google.com/file/d/1KswtJGsxJ7eeBDAmNA_aeLjOxcH6MIxa/view?usp=drive_link
-https://drive.google.com/file/d/1gzMhi5uWu4C3Y6WbQ3L-08V96GxTZrRR/view?usp=drive_link
-https://drive.google.com/file/d/1nRQFtaBxfUCYc2W90Qibh0kHCt6YQCfc/view?usp=drive_link
-https://drive.google.com/file/d/1vs-gyW-KheqHbUATwAhA2mmR9GOGw7f_/view?usp=drive_link
-https://drive.google.com/file/d/1MuxzGOA2fgLaHryq82KkQumtuRJGcUOC/view?usp=drive_link
-https://drive.google.com/file/d/1IIwxZnGlqrXLUXqG6yMO0r7uhCvhpk9e/view?usp=drive_link
-https://drive.google.com/file/d/1vE7XPyaFcXP4DtTY5Y9WKIt7zWgmX-Cr/view?usp=drive_link
-https://drive.google.com/file/d/1j-bIV09gr21RC3-x1N_pK4RPLV3fmWKz/view?usp=drive_link
-https://drive.google.com/file/d/1t3nW1rD3S-EL0Oymb5U7ZAj5UMkydkln/view?usp=drive_link
-https://drive.google.com/file/d/14hbfHCdMKtJZ41F9CQReMec2jeRFTOqR/view?usp=drive_link
-https://drive.google.com/file/d/1x-hUyOSne5BW0AzQ3W6_Pf4g5yXQWi9M/view?usp=drive_link
-https://drive.google.com/file/d/1sw9JqRg6E-3P84I3ZhzTrJMu0vuiaMmP/view?usp=drive_link
-https://drive.google.com/file/d/1LuqhQlL4MGZhB_6THmkovRxrlP26BbdC/view?usp=drive_link
-https://drive.google.com/file/d/15C5K6v_lkjnMSmUvVyqHQKwh2N166e7K/view?usp=drive_link
-https://drive.google.com/file/d/1ns_9eSsQeeoZ10nlbkLy8tu0GmJFSnkt/view?usp=drive_link
-https://drive.google.com/file/d/1NpzWJeK6CqjxzjIMYe6aYdX8xGsQwD4o/view?usp=drive_link
-https://drive.google.com/file/d/1NMLezwufKJ9_8xTc9KQThSzVVD71B9Ui/view?usp=drive_link
-https://drive.google.com/file/d/1aa71DCUqs6oXlIxX35jgsmsgm-NlDxPV/view?usp=drive_link
-https://drive.google.com/file/d/1UJzkIZzAL0j-D5YQBnoq7mHvttASy12O/view?usp=drive_link
-https://drive.google.com/file/d/1nPgx36HIJFb7oI94VbRzWjpPP2GANxzG/view?usp=drive_link
-https://drive.google.com/file/d/1NovAP-KVJjqcuvWy3d6G4ptGGAIDqcCx/view?usp=drive_link

@@ -1,55 +0,0 @@
-https://drive.google.com/file/d/11M3Ye0r5agMaaicPbVGD0q2Hb3rGklbb/view?usp=drive_link
-https://drive.google.com/file/d/1-tx7SvYYgSvXCvnf_EI2OVdwK-CkFY6S/view?usp=drive_link
-https://drive.google.com/file/d/1EWJunmOpMHaU1hE106wwpbkGYcjQXYAF/view?usp=drive_link
-https://drive.google.com/file/d/1IDn95Z7FSiCckrSENtGV4u3RyFHNQSDY/view?usp=drive_link
-https://drive.google.com/file/d/1CwzvWj1i7QOtqrZvsCZ6BdZaKNDfpN32/view?usp=drive_link
-https://drive.google.com/file/d/1HvAvlhm77nAD3Td24QPSeq8lw-Rl_aOh/view?usp=drive_link
-https://drive.google.com/file/d/1t-suKYOPhXH666RpAYNRp2QU_DOy3AeM/view?usp=drive_link
-https://drive.google.com/file/d/18xpKgWh7RWyjMN5PkLTOo-AxsAadAuRw/view?usp=drive_link
-https://drive.google.com/file/d/1oci5Eto-ztv-AQNz8EnwZveBIhxvk-xJ/view?usp=drive_link
-https://drive.google.com/file/d/1Y-t_4vxdE6NpHO0DLJR8f3mD0Q-Wj5-c/view?usp=drive_link
-https://drive.google.com/file/d/1lylRqbbbB8bgtpsBWMPACmHJreuKmllv/view?usp=drive_link
-https://drive.google.com/file/d/1yliSyMig_NXShWfQx6qyW7Ijf2Y5lFK6/view?usp=drive_link
-https://drive.google.com/file/d/1XXhwJsJbeb7KXAooGvJapnm9bjnGUmxS/view?usp=drive_link
-https://drive.google.com/file/d/1_xs1f3hW2JArKyvfF7UWubWjyROGTLs6/view?usp=drive_link
-https://drive.google.com/file/d/1WVEHpr6EqKCZbkHapQSTXJq4xE4SWFT-/view?usp=drive_link
-https://drive.google.com/file/d/1RqOHv9pEQGvW8NUA7ynffFmG999TL_Az/view?usp=drive_link
-https://drive.google.com/file/d/1cu5AgD2gh-uA3PFJmzxxzNaF3qOSlYY1/view?usp=drive_link
-https://drive.google.com/file/d/1SsrXqiPclNrnYToPZ9Uq-k3y0C4qdHT1/view?usp=drive_link
-https://drive.google.com/file/d/1-J7EXf0vjkLIfSqT8ICEsP6CTjzSLBop/view?usp=drive_link
-https://drive.google.com/file/d/11O7ewUmoZXfyyKjy_6B5RW4DpjICxqBT/view?usp=drive_link
-https://drive.google.com/file/d/1iic44kZoCsjNsfAz2cMstZ9-WQvAhblF/view?usp=drive_link
-https://drive.google.com/file/d/1yLV1lVX-2WnWQldGlnQZ0x7QBuDiVkL3/view?usp=drive_link
-https://drive.google.com/file/d/1Tybp9ru98TTbGn4eyROpUQwDFuALWXmk/view?usp=drive_link
-https://drive.google.com/file/d/13E9OTMiipVJByDs5-J19oWwAz7l94LTN/view?usp=drive_link
-https://drive.google.com/file/d/1EeTpJQdMSliw4JzSMtJ6CyTvVdexjM4M/view?usp=drive_link
-https://drive.google.com/file/d/1NHyNwoFqzeAu-1_PSpq5JfxaiD_xbpn9/view?usp=drive_link
-https://drive.google.com/file/d/1fJcS0phDp4xm_FyGaJ5wr9Pe4KqtHaxD/view?usp=drive_link
-https://drive.google.com/file/d/12AqrLUaewDPEcFRqPZeZFb_TQ0Lfi3At/view?usp=drive_link
-https://drive.google.com/file/d/1x_hd4Qsq1oJS-aj2t3qM7WbbV7KZj05b/view?usp=drive_link
-https://drive.google.com/file/d/14OUSUArmsB068hs6BuEIXQhI1Cyz8Sf0/view?usp=drive_link
-https://drive.google.com/file/d/16zlzh1T5zeUJQnFf382NXkFEKEnDub4O/view?usp=drive_link
-https://drive.google.com/file/d/1IbDltmN-NEFCNtr1TO4ILxEgQ94rtjWv/view?usp=drive_link
-https://drive.google.com/file/d/15gmlf8Gx9455pZ1AlqcCSwh3nDPxMzSr/view?usp=drive_link
-https://drive.google.com/file/d/1qHpRL1oZfIMo_vxnm8qfwQ-7l0BZIVva/view?usp=drive_link
-https://drive.google.com/file/d/1H1xskIgiFZivkYn23rMzH3xePGOh3VTC/view?usp=drive_link
-https://drive.google.com/file/d/1avls6Pv0kYiCMNVknbc1zQsgy64MUDMM/view?usp=drive_link
-https://drive.google.com/file/d/1MmWVgCj5khc8KMIifmt3EzF1o-CtPyyn/view?usp=drive_link
-https://drive.google.com/file/d/1U0kCc_xqW0WNppf4sbnK14euWKdPZtzB/view?usp=drive_link
-https://drive.google.com/file/d/16CaEyQscOuhLj23PEGDTL9DeyNkohkMn/view?usp=drive_link
-https://drive.google.com/file/d/1Iu8uM6UUJ0zW8tvN-9UiOe_4oSNzEutg/view?usp=drive_link
-https://drive.google.com/file/d/1UImqiBaIxCR-1DNJaZhHqeHhaySOtVIr/view?usp=drive_link
-https://drive.google.com/file/d/1VpU2V_leIoRIyv_lAvE7eLHBG8DxCTnp/view?usp=drive_link
-https://drive.google.com/file/d/1_Q8J27OT3Xby7QY6yHvIJauFRWEMxkRm/view?usp=drive_link
-https://drive.google.com/file/d/1bantmVo1L9Xz4tbiNw_a1UC2Z_HPO1wT/view?usp=drive_link
-https://drive.google.com/file/d/1IRIXMJMCBDkBjbaHvAlEiBogSvZ1jK_3/view?usp=drive_link
-https://drive.google.com/file/d/1mAHXKjiFbjwydypW2t5Lv8_H5x6nHegl/view?usp=drive_link
-https://drive.google.com/file/d/1SfyY796fLrBCMY39OcyuxZafqSCRZPZk/view?usp=drive_link
-https://drive.google.com/file/d/1X-44sZ8CcfzIskc0dvSx882o1yFhHaZB/view?usp=drive_link
-https://drive.google.com/file/d/1BOIWCCCk6DLD4Bmvc75ZbbLi9AQm-1ao/view?usp=drive_link
-https://drive.google.com/file/d/1RuyDtRE1kk76sw-wP8vx5SgLoPF3PA_H/view?usp=drive_link
-https://drive.google.com/file/d/1c4eoQiBbGuy3CTAQDUSkd84Ponh1roAQ/view?usp=drive_link
-https://drive.google.com/file/d/19PXB9z4Ljq6dsbf9TqcOrrP5SRbw2Tc_/view?usp=drive_link
-https://drive.google.com/file/d/1nn1VVZVoIXWdYDozR7XHXE4mPLQG80PQ/view?usp=drive_link
-https://drive.google.com/file/d/1MBdFGOKPV8GUhwoSsJ_Ky3qAMLM2Bv3K/view?usp=drive_link
-https://drive.google.com/file/d/1of3k_M-7Nh3I1TndcWedxK4ca9dn8Sc5/view?usp=drive_link

@@ -1,20 +0,0 @@
-https://drive.google.com/file/d/12ctkOAdkCNGN1JLbZb5ww3XTBn2LFpGI/view?usp=drive_link
-https://drive.google.com/file/d/1G_Vd46_4fq6O64gHHjUbJX5Ld44ZZx0y/view?usp=drive_link
-https://drive.google.com/file/d/1uKgUy73B3xBogQAOUhfZjO0X5qZGsi2c/view?usp=drive_link
-https://drive.google.com/file/d/1fu9cIrfI-fE2LhdGUxbx7-8Ci_PF8Ypm/view?usp=drive_link
-https://drive.google.com/file/d/1Ygk9ZPJzx8xw2A9JF3NHbJ44TqnvSTQR/view?usp=drive_link
-https://drive.google.com/file/d/18m5xPuccNsEB20WPshm3zhxmXc6k63ED/view?usp=drive_link
-https://drive.google.com/file/d/1DiqqxC44rriviRQpqogcv0-EB-Y6nr9g/view?usp=drive_link
-https://drive.google.com/file/d/1qPdaoTVDizJXkfXLioWU7iJ8hqCXSyOQ/view?usp=drive_link
-https://drive.google.com/file/d/1Fj9kIA_mG7f67WFfACJEaZ7izcHG7vUm/view?usp=drive_link
-https://drive.google.com/file/d/1WpYehZnI2P7dUdJPfkE-ij1rqCnjZEbB/view?usp=drive_link
-https://drive.google.com/file/d/1_zwWkT4jPyzB38STWb6whlzsPzXmfA9r/view?usp=drive_link
-https://drive.google.com/file/d/1U6-J4I_fPlSFFGfhZPxS5_YzKXwXIZYp/view?usp=drive_link
-https://drive.google.com/file/d/1pRhxxcTfZp5tQo_EScvJUwfc3amiS6Vk/view?usp=drive_link
-https://drive.google.com/file/d/1lWLntqra83RlYU_gN7Vostnfydf6gutd/view?usp=drive_link
-https://drive.google.com/file/d/1vIBKo0x-NYEHV1FvRpco1lQMpRdAWAIL/view?usp=drive_link
-https://drive.google.com/file/d/1pdrLV3JTQou_XH0Aap61Ssf60iVKm1jJ/view?usp=drive_link
-https://drive.google.com/file/d/1QTsLoQ7SwmKdQHjBGVDaR2uTwfFwtrOf/view?usp=drive_link
-https://drive.google.com/file/d/1Gytai8M_12J36GY6L_TulEcOC-035jwS/view?usp=drive_link
-https://drive.google.com/file/d/14LJudNc629NT-i8xreXtzl27ce_DxOFJ/view?usp=drive_link
-https://drive.google.com/file/d/1sBvPCODbzxGAI0S3lgN5cSG9Go3lRi00/view?usp=drive_link

@@ -1,18 +0,0 @@
-https://drive.google.com/file/d/1MJn9GbC8p9lN4gC9KDMLEkTkP_gGpXj0/view?usp=drive_link
-https://drive.google.com/file/d/1-4LXgjl7ZCOgp-8GCJmFRD8OeqN5Jf7-/view?usp=drive_link
-https://drive.google.com/file/d/1Ho06Ce0SPbqU3juaMxNUwAt3zCRLGC8W/view?usp=drive_link
-https://drive.google.com/file/d/1ivHoj7_7olBSxH-Y8kqXEW7ttITK-45j/view?usp=drive_link
-https://drive.google.com/file/d/1qjY4hM_IvZ8cq2II_n9MeJbvyeuN4oBP/view?usp=drive_link
-https://drive.google.com/file/d/1rKVhO_f92-7sw13T8hTVrza3B9oAVgoy/view?usp=drive_link
-https://drive.google.com/file/d/1pcLPHO8fBkc1-CRa88tyQtEueE4xiXNi/view?usp=drive_link
-https://drive.google.com/file/d/1Vev_chCsIeEdvQ8poEYNsOJFGy_QU8kZ/view?usp=drive_link
-https://drive.google.com/file/d/1l5G4zpRkxSLCQjvGPYSN4zfCvVRQuzMz/view?usp=drive_link
-https://drive.google.com/file/d/14vgthE1eoakXkr2-DRw50E6lAqYOiUuE/view?usp=drive_link
-https://drive.google.com/file/d/17nPSmKKmgQ2B7zkzWrZYiLM3RBuFod82/view?usp=drive_link
-https://drive.google.com/file/d/1QcDsxplVvb_ID9BVrihl5FvlC-j7waXi/view?usp=drive_link
-https://drive.google.com/file/d/18pEejBpI-eEVaWAAjBCyC0vgbX3T1Esj/view?usp=drive_link
-https://drive.google.com/file/d/1H8eH6_IRODtEFT6WoM77ltR5OoOrqXmI/view?usp=drive_link
-https://drive.google.com/file/d/1IWlpFRZhoxyG4nS13CWK4leZVk5wbNx4/view?usp=drive_link
-https://drive.google.com/file/d/1PbZA8_OCGmMLxNP9xbkLRSChniL4uGxl/view?usp=drive_link
-https://drive.google.com/file/d/1p9XAdmG2f_WeflNO4DIJ_tr1rK6M9B4B/view?usp=drive_link
-https://drive.google.com/file/d/1nS59Et1cNAvKo3Y4SeSGRuZD5TvBbCF3/view?usp=drive_link

@@ -1 +0,0 @@
-https://drive.google.com/drive/folders/1S8eFg98IaGAIKVZ8QFWG1bx4mHa-O204

@@ -1,4 +0,0 @@
-https://drive.google.com/drive/folders/1tC_g1AJ8lglBLY-fjsQrG6DMBa3Ucp-0
-https://drive.google.com/file/d/1fG_Yi2MJrFjiUVN3XoiWXLtTxHlwwaDv/view?usp=drive_link
-https://drive.google.com/file/d/1WX32VWfzzX3Blmd06DRxLwFbMJfVe7P4/view?usp=drive_link
-https://drive.google.com/file/d/18onsX3vXg3xkFwP5bVUCjdV4n9TRn0C9/view?usp=drive_link

@@ -1,3 +0,0 @@
-https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF
-https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link
-https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link

@@ -1,3 +0,0 @@
-https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N
-https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link
-https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link

@@ -1,3 +0,0 @@
-https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo
-https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link
-https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link

@@ -1,3 +0,0 @@
-https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj
-https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link
-https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link

@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/19qS_n7vKgDcPeTMnvDHQ5-n73xEbJz5D
-https://drive.google.com/file/d/1oC31By0A2bsBeHyUwBdQw1z4ng6yi9Za/view?usp=drive_link

@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1m5rQ6UVH8Q9RQp_6c0CxkQ88-L-ScO7q
-https://drive.google.com/file/d/1wHz2qcmwcVG0C0CZ9MjQDQcmj4OY9_a3/view?usp=drive_link

@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1seQGay470nGQ-knBI5TjsTr8iL9Qws5q
-https://drive.google.com/file/d/1T89hSX5U99wLGvGTE7yUBaQPOpyj6Sai/view?usp=drive_link

@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1t3eDc5Rg0DveyRe8oTm6Dia_FYU5mXyf
-https://drive.google.com/file/d/1TXFaduTakvS0ZWJqKCX-HIvYglum_5CY/view?usp=drive_link

@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1Z9X3DNzd6LS0FFjQemNUMoMA5yk5VQOh
-https://drive.google.com/file/d/1Wlyc0vTkjXuWB6zbaVOWhEfD7BmPgUV_/view?usp=drive_link

@@ -1,53 +0,0 @@
-https://drive.google.com/drive/folders/1DYgB4ifX4uIid9m9jnC0Zdz8Nf7ZC0fc
-https://drive.google.com/file/d/1Eb-NRNk_FmVleCbU_Ng5Y4dfcjTKN7Rv/view?usp=drive_link
-https://drive.google.com/file/d/1dkhjEADakT-44l9jf-nK4x89kr4yG_qb/view?usp=drive_link
-https://drive.google.com/file/d/14hDhgcZkVqNExGb4tIXpSjMshhqZETch/view?usp=drive_link
-https://drive.google.com/file/d/1zVMEHpHbuNyP5A_lYU7RPSLB-4V0yfZw/view?usp=drive_link
-https://drive.google.com/file/d/1JtgDjBvy7FnRpFzrx_foC3quorYQFAR-/view?usp=drive_link
-https://drive.google.com/file/d/1EHdneB6F-PP0dQlX8qPaXbxmKoBy_YwO/view?usp=drive_link
-https://drive.google.com/file/d/17Z0jjVBy1OPKREPu77_n_rQzorDiapji/view?usp=drive_link
-https://drive.google.com/file/d/1F4i23qPJ_qTf5jWjfLo4ARGJChznYWt3/view?usp=drive_link
-https://drive.google.com/file/d/1kZtXWM3uS0-rLblydBfJ0mMcVnMMXw9w/view?usp=drive_link
-https://drive.google.com/file/d/1mNODox87xFfY5Z_o5mcLsr8SHb39jDik/view?usp=drive_link
-https://drive.google.com/file/d/1Ob44VdmEUA93FKDECiRb5Ogz2xQg5IWp/view?usp=drive_link
-https://drive.google.com/file/d/1fdQLdjj3Cwv33R1wZhfrLz9Del8mqgHb/view?usp=drive_link
-https://drive.google.com/file/d/1Yu3L3ft21zP__XL8pCfhb788ZleuW1n5/view?usp=drive_link
-https://drive.google.com/file/d/1ozBBWXVZ9hXDh9ooHUNroHdYm8UDqnhJ/view?usp=drive_link
-https://drive.google.com/file/d/1o0TGqvfWw_Lunxb5ubKDS21Lr_WC0h75/view?usp=drive_link
-https://drive.google.com/file/d/1jZnd5eP5L6BH5l98BPN6OnoQx3fu8e9n/view?usp=drive_link
-https://drive.google.com/file/d/1S5sYbz8wcLYp0V67v13i4PRcBxodn4Hg/view?usp=drive_link
-https://drive.google.com/file/d/1rFeg_x6ftJYwPtBv34D3h2L2cpDLeR4G/view?usp=drive_link
-https://drive.google.com/file/d/1GvS3lcm4o6nm_scUk0XxKeVFNmzjucDZ/view?usp=drive_link
-https://drive.google.com/file/d/1-9i0riphC7NhhDahcQfD1QoBXP5gF90A/view?usp=drive_link
-https://drive.google.com/file/d/15p_IqGsMbKuvzMS872THAZr-3SBtb1Fr/view?usp=drive_link
-https://drive.google.com/file/d/1ToyYcBfJL8gbQn0q_59zPLsFmm7dmMJo/view?usp=drive_link
-https://drive.google.com/file/d/1e_7PNH7CYafE4pAebP7ZdI7XFbmEcy_i/view?usp=drive_link
-https://drive.google.com/file/d/1JoabvGVsIQdug2xOhUIhetEIyDM91y_Y/view?usp=drive_link
-https://drive.google.com/file/d/1kOMw1y0lmnVaCjwZICfzCsx6e0Z8MNGR/view?usp=drive_link
-https://drive.google.com/file/d/16it_wd1JOevUQTK2_CvF_pBACTgpIPgM/view?usp=drive_link
-https://drive.google.com/file/d/1IRcCj9HnJSfbyMgr5XEERGlEnWeZQwOc/view?usp=drive_link
-https://drive.google.com/file/d/1Z2dIJfq_S3liGmPN9Rphvkmucnmw7tlb/view?usp=drive_link
-https://drive.google.com/file/d/1J3NoAjzndGx9yNyaBOJHdNny1epzUoBt/view?usp=drive_link
-https://drive.google.com/file/d/18nOvxV1k8FSmBrhT4TPo2sKKSZXougyx/view?usp=drive_link
-https://drive.google.com/file/d/1CT8FxclafFMjSd7gCWVw3VSeryeiF04i/view?usp=drive_link
-https://drive.google.com/file/d/16M9KVqQMFfSsXfypK0bocFft8Nz3j2Rt/view?usp=drive_link
-https://drive.google.com/file/d/18QPVkw6bj6HW8LTPrQLWrrUX4R6RcF42/view?usp=drive_link
-https://drive.google.com/file/d/1hQTVtA5hBTE_StXpJafTZJ3tgt2VQQ_t/view?usp=drive_link
-https://drive.google.com/file/d/1Dn-d5g69H6EgAWgsFdrcbJKtz7ySsCQ8/view?usp=drive_link
-https://drive.google.com/file/d/13hMr16483P7ALYv73yMRUN37fJdVQM62/view?usp=drive_link
-https://drive.google.com/file/d/1848yN3XMN5zJMEgApt6KzrWgfRPfimtv/view?usp=drive_link
-https://drive.google.com/file/d/1oAD9kSnS0fTgj-CjD4u9VdZ5X67IOIMa/view?usp=drive_link
-https://drive.google.com/file/d/1ilzIWLCCG5b_KgF5s0wdN2I5-lFNpwC1/view?usp=drive_link
-https://drive.google.com/file/d/1rjsT2YBjnidxod1s9s-myAYz8boHr-WB/view?usp=drive_link
-https://drive.google.com/file/d/18Gg48HTub15bd8qzbhiCUufbVy0fbN5G/view?usp=drive_link
-https://drive.google.com/file/d/1WsSnQSqmMTVSRwrhT1Y-v782My2zcjLm/view?usp=drive_link
-https://drive.google.com/file/d/1ea9ZCvoyc-xqiFXgeDcA_mOWsw7VUuoi/view?usp=drive_link
-https://drive.google.com/file/d/1wv1v3-XhPgbNzp62BXbJTDzMPu2tlDUc/view?usp=drive_link
-https://drive.google.com/file/d/18-ikzt8LoZ83Gi3goKCELs4U4z8hrRoF/view?usp=drive_link
-https://drive.google.com/file/d/16Bjhp7JNCXkGuLvyNcZowAx3W-Y-15DV/view?usp=drive_link
-https://drive.google.com/file/d/1Gc-KRI-xwcp1fMR55ugbrLg_5y3SPde-/view?usp=drive_link
-https://drive.google.com/file/d/1oP72Q386Z4Sy5MMm-t5yNogIe5Van_9k/view?usp=drive_link
-https://drive.google.com/file/d/112T90eDUDVH-SyOV7UnZl5bscAH2hcfq/view?usp=drive_link
-https://drive.google.com/file/d/1y-uKOesRRhjgDtFbG_j65f4SGg0v8XDg/view?usp=drive_link
-https://drive.google.com/file/d/1LOP05OagoI3km-ZKQBrS204A85UVk7Ok/view?usp=drive_link
-https://drive.google.com/file/d/1QkHQKgasVzWsmdPvkXgGhWyQ84d93_Az/view?usp=drive_link

@@ -1 +0,0 @@
-https://drive.google.com/drive/folders/1Ut2cv6o6Pkfgg46DgwVUM7Z5PkNG8eJ-

@@ -1 +0,0 @@
-https://drive.google.com/drive/folders/1FqxPV0PgvgIu8XFjtvZSPSExuNcxVVAY

@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1SKtG0ct9q0nVdYssJNMWSOjikcXliT58
-https://drive.google.com/file/d/1nchD21O30B3i3LDoqramo1zgW5YvpJIN/view?usp=drive_link

@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1_4DHf2cma0xsChLQFghwigX6Ukti5-zQ
-https://drive.google.com/file/d/1_8vS4hDNDgUQY-SmekrNaa7dF67QJYU-/view?usp=drive_link

@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1_4DHf2cma0xsChLQFghwigX6Ukti5-zQ
-https://drive.google.com/file/d/1_8vS4hDNDgUQY-SmekrNaa7dF67QJYU-/view?usp=drive_link

@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1fAD7vkyTGTFB_nGXIKofCU1U05oE3MFv
-https://drive.google.com/file/d/1XzyQ2B6LLvcurIonOpEu4nij2qwNWshH/view?usp=drive_link

@ -1,53 +0,0 @@
https://drive.google.com/drive/folders/13EQsVsnxT86K20QAoyE_YpsFbQ7fZQdu
https://drive.google.com/file/d/1-W_JHghZG65FNTVhw1SXhtQrazdLL3Ue/view?usp=drive_link
https://drive.google.com/file/d/1VwRJgdWUo-2nQaNM7Bs77-fsm8iwUxEo/view?usp=drive_link
https://drive.google.com/file/d/1wFzGRo5iYA13WLi6IV1ry64RyahQBFio/view?usp=drive_link
https://drive.google.com/file/d/1IKtQzQ-n-UTv64hYpReu2R4cqUvmNQqD/view?usp=drive_link
https://drive.google.com/file/d/1GicVci9OiuuZZH79i5Mg7AtWod94MzwT/view?usp=drive_link
https://drive.google.com/file/d/1JVnIoR7EIQp70T4eAf9RX65JcTrzsjQc/view?usp=drive_link
https://drive.google.com/file/d/1W2xr4h23ucjPrc-mBEeqnACsfaImpc0p/view?usp=drive_link
https://drive.google.com/file/d/10xj_0V7A07o3uCa7v5omUrTC0YlPW8H3/view?usp=drive_link
https://drive.google.com/file/d/1FOc3EMaCy8Mb0_a7PuXLAwKwvxkbKmwU/view?usp=drive_link
https://drive.google.com/file/d/143PgDXBcf2GQ0Q07ZPMVMfBgZDd5sLJG/view?usp=drive_link
https://drive.google.com/file/d/1pE5Tyj0LlGbGWvUzuhixp86Ibu55Ez3I/view?usp=drive_link
https://drive.google.com/file/d/141668b1VzX80ncrVJPzhkoAeIFB4MEK9/view?usp=drive_link
https://drive.google.com/file/d/1bw12lo37p1ZvRvErHsll7cEYi2OxscvZ/view?usp=drive_link
https://drive.google.com/file/d/1zfnMFvbgBjl6SzYhksbaOzfbwLrCN6tb/view?usp=drive_link
https://drive.google.com/file/d/1-GIszA6mUJMaNB-tdh9r9skc77SWA0VX/view?usp=drive_link
https://drive.google.com/file/d/1fTB0zWFYU6zh4IIUFT2zX_OkwYqmElwY/view?usp=drive_link
https://drive.google.com/file/d/1gPIPNKGmrO9c7gKF7SP0SuUYbIBBq8z1/view?usp=drive_link
https://drive.google.com/file/d/12JeJ-dQd5lYyn6PlDOGdE-ChVeiZ-Uv0/view?usp=drive_link
https://drive.google.com/file/d/100_20cgCqerU6qoh3TfTbwLy9mlDAFEG/view?usp=drive_link
https://drive.google.com/file/d/111oAGJ76ku_pYgbBoIdZAC1_XEQcPI__/view?usp=drive_link
https://drive.google.com/file/d/1UhC8L-354ZQ2gblPFGI35EMsVwfpuKa0/view?usp=drive_link
https://drive.google.com/file/d/1sIXQSgUR_xdrNtGrL6QGBnkLMKErsIp1/view?usp=drive_link
https://drive.google.com/file/d/16Ax77bDSIXnsn4GFL8XYKKT1P6bPpfMd/view?usp=drive_link
https://drive.google.com/file/d/1pgRVYwwVIsWq_qsWqZpe1UBzZfF5Fa9D/view?usp=drive_link
https://drive.google.com/file/d/1jtimaZkWsY1P5gC2bbS64H_WCUU7HXN2/view?usp=drive_link
https://drive.google.com/file/d/1N6Bh02P-RiTEgtx1YH1Db_X3TGpP-X_r/view?usp=drive_link
https://drive.google.com/file/d/14Fy8EwJ8d9Vh97Yt1VOvUChSCrfIjBij/view?usp=drive_link
https://drive.google.com/file/d/1IRuv42dvIMPuKhcMZmuXaBjJ-lPFOmQd/view?usp=drive_link
https://drive.google.com/file/d/16XWzNY2D8ucVVn5geBgsVdhm3ppO4que/view?usp=drive_link
https://drive.google.com/file/d/1xsVOoQgthK_L_SDrmq_JvQgUpAvPEAY8/view?usp=drive_link
https://drive.google.com/file/d/1bZbw66DyEMvnJnzkdUUNbKjvNKg8KFYM/view?usp=drive_link
https://drive.google.com/file/d/1CyTVkdrNGGpouCXr4CfhKbMzE6Ah3oo3/view?usp=drive_link
https://drive.google.com/file/d/1hDRyeM-XEDpHXpptbT8LvNnlQUR3PWOh/view?usp=drive_link
https://drive.google.com/file/d/1XhHWxbra8Iy5irQZ83IvxwaJqHq9x4s1/view?usp=drive_link
https://drive.google.com/file/d/1haZcn6aM1o4JlmP9tJj3x2enrxiPaDSD/view?usp=drive_link
https://drive.google.com/file/d/1ypDyuUTbljaBZ34f-t7lj3O_0bRmyX2n/view?usp=drive_link
https://drive.google.com/file/d/1ILEEZo_tA9_ChIAprr2mPaNVKZi5vXsO/view?usp=drive_link
https://drive.google.com/file/d/1U7nVYFaGE8vVTfLCW33D74xOjDcqfgyJ/view?usp=drive_link
https://drive.google.com/file/d/1rZ93_rmCov5SMDxPkfM3qthcRELZrQX6/view?usp=drive_link
https://drive.google.com/file/d/1mYO1b_csddtyE3qT6cwLiw-m2w2_1Lxh/view?usp=drive_link
https://drive.google.com/file/d/1xz7Q5x2jikY8wJQjMRQpRws6AnfWlHm5/view?usp=drive_link
https://drive.google.com/file/d/1OO8GaO-0FrSZRd1kxMYwBmubyiLOWnbl/view?usp=drive_link
https://drive.google.com/file/d/1EXn4NVDmf-4_HCy34mYwT-vwK2CFI9ev/view?usp=drive_link
https://drive.google.com/file/d/10hH70XhXRL9C5SnAG4toHtfHqfJUJo4H/view?usp=drive_link
https://drive.google.com/file/d/18tiBcxea0guUai4lwsXQvt0q2LZ8ZnnJ/view?usp=drive_link
https://drive.google.com/file/d/1Q8R8qv37vk5PQ5kQ2ibx6BFLOySD0VpX/view?usp=drive_link
https://drive.google.com/file/d/17aNriHzjhdibCyuUjQoMFZqjybJZtggG/view?usp=drive_link
https://drive.google.com/file/d/1LVjEYHSdeKm6CotU1QguIeNEPaIaFl_1/view?usp=drive_link
https://drive.google.com/file/d/1ufAhE_EkgJ85slg2EW8aW_grOzE_Lmxd/view?usp=drive_link
https://drive.google.com/file/d/1wtzLtXrkw9eXRGESTPIOlpl1tInu-b2m/view?usp=drive_link
https://drive.google.com/file/d/1Mk5qvVtD_QHwGOUApRq76TUw2T5THu6f/view?usp=drive_link
https://drive.google.com/file/d/1y1WQ3hboWVJ68KEYQQ3OhreGuaUpSgwc/view?usp=drive_link
@ -1,52 +0,0 @@
https://drive.google.com/drive/folders/1dxWh6YFZUDt6qXIoxgD9bla3CiFjZ11C
https://drive.google.com/file/d/1hNBJN00SCAlOl0ZEgm7RRGbAGDjyBs0p/view?usp=drive_link
https://drive.google.com/file/d/17He0CVwXGeoMmXg4SHKo-osNn7YPKVL7/view?usp=drive_link
https://drive.google.com/file/d/1laNKUVID1x2CV6a2O2WQjwFewKu4lidL/view?usp=drive_link
https://drive.google.com/file/d/1pNf36xbZJGRArYLmNAvRj5y6CoqdC6kB/view?usp=drive_link
https://drive.google.com/file/d/1_4E1-y3JXk5I0ebycLYM70YDPK9g52gZ/view?usp=drive_link
https://drive.google.com/file/d/1PHfzhGPdbolKyOpS3FnR2w7Q8zUlJXSk/view?usp=drive_link
https://drive.google.com/file/d/17ls2PPN-Pi3tEuK059cwV2_iDT8aGhOO/view?usp=drive_link
https://drive.google.com/file/d/1LWsg6PmCT00Kv_N_slrmcwKmQPGoBT3k/view?usp=drive_link
https://drive.google.com/file/d/12LckrchoHTUVH7rxi8J7zD9dA19GXvoW/view?usp=drive_link
https://drive.google.com/file/d/1VqrJKjAIkj5gtFXL69grdSeu9CyaqnSw/view?usp=drive_link
https://drive.google.com/file/d/1g5rQYDBZvW-kUtYPeyF3qmd53v6k7kXu/view?usp=drive_link
https://drive.google.com/file/d/10kUgaSJ0TS7teaG83G3Rf_DG4XGrBt6A/view?usp=drive_link
https://drive.google.com/file/d/1je9XmneZQZvTma5adMJICUPDovW3ppei/view?usp=drive_link
https://drive.google.com/file/d/1v28r6bedwZGbUPVVTVImXhK-42XdtGfj/view?usp=drive_link
https://drive.google.com/file/d/1-TEEx9sGVvzMMaNXYfQMtY2JJ6cvl0dT/view?usp=drive_link
https://drive.google.com/file/d/1YdBKdJFP9rJWBUX7qrOYL_gfUA8o6J9M/view?usp=drive_link
https://drive.google.com/file/d/1X9vffwQHNUSKLXr2RlYNtbWDIFCIDfdF/view?usp=drive_link
https://drive.google.com/file/d/11hqesqa5kvEe5FABUnZRcvmOhR373cYM/view?usp=drive_link
https://drive.google.com/file/d/1ltTTECjEcbQPgS3UPRgMzaE2x9n6H7dC/view?usp=drive_link
https://drive.google.com/file/d/1Zxqfa29JdwT-bfMpivi6IG2vz34d21dD/view?usp=drive_link
https://drive.google.com/file/d/11LQlVxS5hz494dYUJ_PNRPx2NHIJbQns/view?usp=drive_link
https://drive.google.com/file/d/1i1JhNtnZpO_E8rAv8gxBP3ZTZRvcvsZi/view?usp=drive_link
https://drive.google.com/file/d/11jOXAr2EULUO4Qkm748634lg4UUFho5U/view?usp=drive_link
https://drive.google.com/file/d/1rj67wur8DdB_Pipwx24bY43xu4X1eQ5e/view?usp=drive_link
https://drive.google.com/file/d/15ZTm6lO6f_JQy_4SNfrOu3iPYn1Ro8mh/view?usp=drive_link
https://drive.google.com/file/d/1q4gBtqWPJtCwXEvknGgN0WHGp7Vfn1b9/view?usp=drive_link
https://drive.google.com/file/d/1t17keyre47AYqm8GgXiQ7EcvcUkeSiDQ/view?usp=drive_link
https://drive.google.com/file/d/1OYUPGxtZgOF86Ng_BEOTXm_XOYpuQPsO/view?usp=drive_link
https://drive.google.com/file/d/1cBjbGHi3dwWHtx6r9EQJi0JT_CE3LuHt/view?usp=drive_link
https://drive.google.com/file/d/14qaMyF0mcbCB-fCYKNyo5_2NahSC6D5u/view?usp=drive_link
https://drive.google.com/file/d/12FgX86eA7Y5co9ULBVK80XMsiKQSs-Ri/view?usp=drive_link
https://drive.google.com/file/d/1yvoHWidf-jdBVw6qCCXOFfkVwKj_2hPk/view?usp=drive_link
https://drive.google.com/file/d/1a2SugsSDlC8UtUrFzp-_KAwyZckQOvdQ/view?usp=drive_link
https://drive.google.com/file/d/1l8pILBFSAosypWJMza2K09Vm7rug9axm/view?usp=drive_link
https://drive.google.com/file/d/1hfPQ8dBCk97PnOhq6_MIISm3IEzcOxJG/view?usp=drive_link
https://drive.google.com/file/d/1PPAUwlJCFKpms8cqF_k1v2_fCgDBOc3S/view?usp=drive_link
https://drive.google.com/file/d/1lVKQZeqFfK3amEmLuFhYLUFQ2eyE8rOW/view?usp=drive_link
https://drive.google.com/file/d/1K9iPMLfDowcIFoyzpvgn88dQ6x6kVwNG/view?usp=drive_link
https://drive.google.com/file/d/1PNvMqG9tL7QxeLaYBGHiWYR6SYb5iIct/view?usp=drive_link
https://drive.google.com/file/d/1xkRtzbvIkUsylx9hrFLGQsJn0h1EYu-5/view?usp=drive_link
https://drive.google.com/file/d/1nxMRrJlSayjDIfr5CmHO1NzAw3COhsLi/view?usp=drive_link
https://drive.google.com/file/d/1Qs3WEyMGrmagiHIkkFEueWNnJhkUeR1s/view?usp=drive_link
https://drive.google.com/file/d/1D-G2_Q0SS3M8zyJbg_XzkF2ANPw1HTuX/view?usp=drive_link
https://drive.google.com/file/d/1mdmJsDGO-YtJAOF_yPKl6lq4PJOIbQhT/view?usp=drive_link
https://drive.google.com/file/d/11m9bwfop_sPmnQr_8amB6EEsrbAeG_z5/view?usp=drive_link
https://drive.google.com/file/d/19tyYt5FMn5kru0g9o2nMJhKPnsDqkIZv/view?usp=drive_link
https://drive.google.com/file/d/1XvTpUdsVTZ-vydvdYYmynbma--HfUGSl/view?usp=drive_link
https://drive.google.com/file/d/1MO3hFu68J6NohTzr9aB_fY02VA6QSOqj/view?usp=drive_link
https://drive.google.com/file/d/1Lh-UjwAk__04YOTWINF_QGVU8SjetVaY/view?usp=drive_link
https://drive.google.com/file/d/1jkSOUwZV5GJ7rZlVeErjcu0DBQs8Np0d/view?usp=drive_link
https://drive.google.com/file/d/1VIN1eLI-93WrVQwCjsv6XQr353DqqBYA/view?usp=drive_link
@ -1,8 +0,0 @@
https://drive.google.com/drive/folders/1EgKar7rWBmTIRmeJYZciSwjZx3uP2mHO
https://drive.google.com/file/d/12eYWQO15atK2hBjXhynPJd9MKAj_42pz/view?usp=drive_link
https://drive.google.com/file/d/1Ul4oEeICJDjgfYTl4H1uaisTzVYIM6wd/view?usp=drive_link
https://drive.google.com/file/d/1WSF-OG8lKSe2wVYCv5D1aJNipxpgddk-/view?usp=drive_link
https://drive.google.com/file/d/1_ppD5j5sFh26aWW0JmhLzJMeNB-lCArk/view?usp=drive_link
https://drive.google.com/file/d/1WUp846dgWXYhu4oJfhHxiU6YL_7N6s4W/view?usp=drive_link
https://drive.google.com/file/d/1HRZNAIoAQw_uYiPwnBvtBioQoqiqoXdA/view?usp=drive_link
https://drive.google.com/file/d/1hedGq-QDMnIn8GlXXBC3GiEJ_Y-LTxyt/view?usp=drive_link
@ -1,634 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)

Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
"""

from __future__ import annotations

import math
import numbers
import os
from functools import cached_property

import numcodecs
import numpy as np
import zarr


def check_chunks_compatible(chunks: tuple, shape: tuple):
    assert len(shape) == len(chunks)
    for c in chunks:
        assert isinstance(c, numbers.Integral)
        assert c > 0


def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
    old_arr = group[name]
    if chunks is None:
        chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
    check_chunks_compatible(chunks, old_arr.shape)

    if compressor is None:
        compressor = old_arr.compressor

    if (chunks == old_arr.chunks) and (compressor == old_arr.compressor):
        # no change
        return old_arr

    # rechunk recompress
    group.move(name, tmp_key)
    old_arr = group[tmp_key]
    n_copied, n_skipped, n_bytes_copied = zarr.copy(
        source=old_arr,
        dest=group,
        name=name,
        chunks=chunks,
        compressor=compressor,
    )
    del group[tmp_key]
    arr = group[name]
    return arr


def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=None):
    """
    Common shapes
    T,D
    T,N,D
    T,H,W,C
    T,N,H,W,C
    """
    itemsize = np.dtype(dtype).itemsize
    # reversed
    rshape = list(shape[::-1])
    if max_chunk_length is not None:
        rshape[-1] = int(max_chunk_length)
    split_idx = len(shape) - 1
    for i in range(len(shape) - 1):
        this_chunk_bytes = itemsize * np.prod(rshape[:i])
        next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
        if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
            split_idx = i

    rchunks = rshape[:split_idx]
    item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
    this_max_chunk_length = rshape[split_idx]
    next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
    rchunks.append(next_chunk_length)
    len_diff = len(shape) - len(rchunks)
    rchunks.extend([1] * len_diff)
    chunks = tuple(rchunks[::-1])
    # print(np.prod(chunks) * itemsize / target_chunk_bytes)
    return chunks
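
# A worked example of the chunking heuristic above (illustrative, not part of the
# original file): for uint8 frames of shape (T, H, W, C) = (1000, 96, 96, 3), one
# frame is 96 * 96 * 3 = 27,648 bytes, so ceil(2e6 / 27,648) = 73 frames fit the
# ~2 MB target and only the time dimension is chunked:
#
#   get_optimal_chunks(shape=(1000, 96, 96, 3), dtype=np.uint8)  # -> (73, 96, 96, 3)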


class ReplayBuffer:
    """
    Zarr-based temporal datastructure.
    Assumes first dimension to be time. Only chunk in time dimension.
    """

    def __init__(self, root: zarr.Group | dict[str, dict]):
        """
        Dummy constructor. Use copy_from* and create_from* class methods instead.
        """
        assert "data" in root
        assert "meta" in root
        assert "episode_ends" in root["meta"]
        for value in root["data"].values():
            assert value.shape[0] == root["meta"]["episode_ends"][-1]
        self.root = root

    # ============= create constructors ===============
    @classmethod
    def create_empty_zarr(cls, storage=None, root=None):
        if root is None:
            if storage is None:
                storage = zarr.MemoryStore()
            root = zarr.group(store=storage)
        root.require_group("data", overwrite=False)
        meta = root.require_group("meta", overwrite=False)
        if "episode_ends" not in meta:
            meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
        return cls(root=root)

    @classmethod
    def create_empty_numpy(cls):
        root = {"data": {}, "meta": {"episode_ends": np.zeros((0,), dtype=np.int64)}}
        return cls(root=root)

    @classmethod
    def create_from_group(cls, group, **kwargs):
        if "data" not in group:
            # create from scratch
            buffer = cls.create_empty_zarr(root=group, **kwargs)
        else:
            # already exists
            buffer = cls(root=group, **kwargs)
        return buffer

    @classmethod
    def create_from_path(cls, zarr_path, mode="r", **kwargs):
        """
        Open an on-disk zarr directly (for datasets larger than memory).
        Slower.
        """
        group = zarr.open(os.path.expanduser(zarr_path), mode)
        return cls.create_from_group(group, **kwargs)

    # ============= copy constructors ===============
    @classmethod
    def copy_from_store(
        cls,
        src_store,
        store=None,
        keys=None,
        chunks: dict[str, tuple] | None = None,
        compressors: dict | str | numcodecs.abc.Codec | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """
        Load to memory.
        """
        src_root = zarr.group(src_store)
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        root = None
        if store is None:
            # numpy backend
            meta = {}
            for key, value in src_root["meta"].items():
                if len(value.shape) == 0:
                    meta[key] = np.array(value)
                else:
                    meta[key] = value[:]

            if keys is None:
                keys = src_root["data"].keys()
            data = {}
            for key in keys:
                arr = src_root["data"][key]
                data[key] = arr[:]

            root = {"meta": meta, "data": data}
        else:
            root = zarr.group(store=store)
            # copy without recompression
            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
            )
            data_group = root.create_group("data", overwrite=True)
            if keys is None:
                keys = src_root["data"].keys()
            for key in keys:
                value = src_root["data"][key]
                cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
                cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
                if cks == value.chunks and cpr == value.compressor:
                    # copy without recompression
                    this_path = "/data/" + key
                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                        source=src_store,
                        dest=store,
                        source_path=this_path,
                        dest_path=this_path,
                        if_exists=if_exists,
                    )
                else:
                    # copy with recompression
                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
                        source=value,
                        dest=data_group,
                        name=key,
                        chunks=cks,
                        compressor=cpr,
                        if_exists=if_exists,
                    )
        buffer = cls(root=root)
        return buffer

    @classmethod
    def copy_from_path(
        cls,
        zarr_path,
        backend=None,
        store=None,
        keys=None,
        chunks: dict[str, tuple] | None = None,
        compressors: dict | str | numcodecs.abc.Codec | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """
        Copy an on-disk zarr to in-memory compressed.
        Recommended.
        """
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        if backend == "numpy":
            print("backend argument is deprecated!")
            store = None
        group = zarr.open(os.path.expanduser(zarr_path), "r")
        return cls.copy_from_store(
            src_store=group.store,
            store=store,
            keys=keys,
            chunks=chunks,
            compressors=compressors,
            if_exists=if_exists,
            **kwargs,
        )
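
    # A minimal usage sketch of the copy constructors above (illustrative; the
    # path is hypothetical). This loads and recompresses an on-disk zarr into an
    # in-memory buffer, which is the recommended route per the docstring:
    #
    #   buffer = ReplayBuffer.copy_from_path("~/data/pusht/pusht_replay.zarr")
    #   print(buffer.n_episodes, buffer.n_steps)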

    # ============= save methods ===============
    def save_to_store(
        self,
        store,
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
        if_exists="replace",
        **kwargs,
    ):
        root = zarr.group(store)
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        if self.backend == "zarr":
            # recompression free copy
            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                source=self.root.store,
                dest=store,
                source_path="/meta",
                dest_path="/meta",
                if_exists=if_exists,
            )
        else:
            meta_group = root.create_group("meta", overwrite=True)
            # save meta, no chunking
            for key, value in self.root["meta"].items():
                _ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)

        # save data, chunk
        data_group = root.create_group("data", overwrite=True)
        for key, value in self.root["data"].items():
            cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
            cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
            if isinstance(value, zarr.Array):
                if cks == value.chunks and cpr == value.compressor:
                    # copy without recompression
                    this_path = "/data/" + key
                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                        source=self.root.store,
                        dest=store,
                        source_path=this_path,
                        dest_path=this_path,
                        if_exists=if_exists,
                    )
                else:
                    # copy with recompression
                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
                        source=value,
                        dest=data_group,
                        name=key,
                        chunks=cks,
                        compressor=cpr,
                        if_exists=if_exists,
                    )
            else:
                # numpy
                _ = data_group.array(name=key, data=value, chunks=cks, compressor=cpr)
        return store

    def save_to_path(
        self,
        zarr_path,
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
        if_exists="replace",
        **kwargs,
    ):
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
        return self.save_to_store(
            store, chunks=chunks, compressors=compressors, if_exists=if_exists, **kwargs
        )

    @staticmethod
    def resolve_compressor(compressor="default"):
        if compressor == "default":
            compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
        elif compressor == "disk":
            compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
        return compressor

    @classmethod
    def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
        # allows compressor to be explicitly set to None
        cpr = "nil"
        if isinstance(compressors, dict):
            if key in compressors:
                cpr = cls.resolve_compressor(compressors[key])
            elif isinstance(array, zarr.Array):
                cpr = array.compressor
        else:
            cpr = cls.resolve_compressor(compressors)
        # backup default
        if cpr == "nil":
            cpr = cls.resolve_compressor("default")
        return cpr

    @classmethod
    def _resolve_array_chunks(cls, chunks: dict | tuple, key, array):
        cks = None
        if isinstance(chunks, dict):
            if key in chunks:
                cks = chunks[key]
            elif isinstance(array, zarr.Array):
                cks = array.chunks
        elif isinstance(chunks, tuple):
            cks = chunks
        else:
            raise TypeError(f"Unsupported chunks type {type(chunks)}")
        # backup default
        if cks is None:
            cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
        # check
        check_chunks_compatible(chunks=cks, shape=array.shape)
        return cks

    # ============= properties =================
    @cached_property
    def data(self):
        return self.root["data"]

    @cached_property
    def meta(self):
        return self.root["meta"]

    def update_meta(self, data):
        # sanitize data
        np_data = {}
        for key, value in data.items():
            if isinstance(value, np.ndarray):
                np_data[key] = value
            else:
                arr = np.array(value)
                if arr.dtype == object:
                    raise TypeError(f"Invalid value type {type(value)}")
                np_data[key] = arr

        meta_group = self.meta
        if self.backend == "zarr":
            for key, value in np_data.items():
                _ = meta_group.array(
                    name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
                )
        else:
            meta_group.update(np_data)

        return meta_group

    @property
    def episode_ends(self):
        return self.meta["episode_ends"]

    def get_episode_idxs(self):
        import numba

        # The original called `numba.jit(nopython=True)` as a bare statement, which
        # is a no-op; it is meant to decorate the helper below.
        @numba.jit(nopython=True)
        def _get_episode_idxs(episode_ends):
            result = np.zeros((episode_ends[-1],), dtype=np.int64)
            for i in range(len(episode_ends)):
                start = 0
                if i > 0:
                    start = episode_ends[i - 1]
                end = episode_ends[i]
                for idx in range(start, end):
                    result[idx] = i
            return result

        return _get_episode_idxs(self.episode_ends)

    @property
    def backend(self):
        backend = "numpy"
        if isinstance(self.root, zarr.Group):
            backend = "zarr"
        return backend

    # =========== dict-like API ==============
    def __repr__(self) -> str:
        if self.backend == "zarr":
            return str(self.root.tree())
        else:
            return super().__repr__()

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def __getitem__(self, key):
        return self.data[key]

    def __contains__(self, key):
        return key in self.data

    # =========== our API ==============
    @property
    def n_steps(self):
        if len(self.episode_ends) == 0:
            return 0
        return self.episode_ends[-1]

    @property
    def n_episodes(self):
        return len(self.episode_ends)

    @property
    def chunk_size(self):
        if self.backend == "zarr":
            return next(iter(self.data.arrays()))[-1].chunks[0]
        return None

    @property
    def episode_lengths(self):
        ends = self.episode_ends[:]
        ends = np.insert(ends, 0, 0)
        lengths = np.diff(ends)
        return lengths

    def add_episode(
        self,
        data: dict[str, np.ndarray],
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
    ):
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        assert len(data) > 0
        is_zarr = self.backend == "zarr"

        curr_len = self.n_steps
        episode_length = None
        for value in data.values():
            assert len(value.shape) >= 1
            if episode_length is None:
                episode_length = len(value)
            else:
                assert episode_length == len(value)
        new_len = curr_len + episode_length

        for key, value in data.items():
            new_shape = (new_len,) + value.shape[1:]
            # create array
            if key not in self.data:
                if is_zarr:
                    cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
                    cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
                    arr = self.data.zeros(
                        name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
                    )
                else:
                    # copy data to prevent modifying the input
                    arr = np.zeros(shape=new_shape, dtype=value.dtype)
                    self.data[key] = arr
            else:
                arr = self.data[key]
                assert value.shape[1:] == arr.shape[1:]
                # same method for both zarr and numpy
                if is_zarr:
                    arr.resize(new_shape)
                else:
                    arr.resize(new_shape, refcheck=False)
            # copy data
            arr[-value.shape[0] :] = value

        # append to episode ends
        episode_ends = self.episode_ends
        if is_zarr:
            episode_ends.resize(episode_ends.shape[0] + 1)
        else:
            episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
        episode_ends[-1] = new_len

        # rechunk
        if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
            rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))

    def drop_episode(self):
        is_zarr = self.backend == "zarr"
        episode_ends = self.episode_ends[:].copy()
        assert len(episode_ends) > 0
        start_idx = 0
        if len(episode_ends) > 1:
            start_idx = episode_ends[-2]
        for value in self.data.values():
            new_shape = (start_idx,) + value.shape[1:]
            if is_zarr:
                value.resize(new_shape)
            else:
                value.resize(new_shape, refcheck=False)
        if is_zarr:
            self.episode_ends.resize(len(episode_ends) - 1)
        else:
            self.episode_ends.resize(len(episode_ends) - 1, refcheck=False)

    def pop_episode(self):
        assert self.n_episodes > 0
        episode = self.get_episode(self.n_episodes - 1, copy=True)
        self.drop_episode()
        return episode

    def extend(self, data):
        self.add_episode(data)

    def get_episode(self, idx, copy=False):
        idx = list(range(len(self.episode_ends)))[idx]
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx - 1]
        end_idx = self.episode_ends[idx]
        result = self.get_steps_slice(start_idx, end_idx, copy=copy)
        return result

    def get_episode_slice(self, idx):
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx - 1]
        end_idx = self.episode_ends[idx]
        return slice(start_idx, end_idx)

    def get_steps_slice(self, start, stop, step=None, copy=False):
        _slice = slice(start, stop, step)

        result = {}
        for key, value in self.data.items():
            x = value[_slice]
            if copy and isinstance(value, np.ndarray):
                x = x.copy()
            result[key] = x
        return result

    # =========== chunking =============
    def get_chunks(self) -> dict:
        assert self.backend == "zarr"
        chunks = {}
        for key, value in self.data.items():
            chunks[key] = value.chunks
        return chunks

    def set_chunks(self, chunks: dict):
        assert self.backend == "zarr"
        for key, value in chunks.items():
            if key in self.data:
                arr = self.data[key]
                if value != arr.chunks:
                    check_chunks_compatible(chunks=value, shape=arr.shape)
                    rechunk_recompress_array(self.data, key, chunks=value)

    def get_compressors(self) -> dict:
        assert self.backend == "zarr"
        compressors = {}
        for key, value in self.data.items():
            compressors[key] = value.compressor
        return compressors

    def set_compressors(self, compressors: dict):
        assert self.backend == "zarr"
        for key, value in compressors.items():
            if key in self.data:
                arr = self.data[key]
                compressor = self.resolve_compressor(value)
                if compressor != arr.compressor:
                    rechunk_recompress_array(self.data, key, compressor=compressor)
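
# An end-to-end sketch of the ReplayBuffer API above (episode data made up for
# illustration): episodes are appended along the time axis and indexed back out
# through `episode_ends`:
#
#   buffer = ReplayBuffer.create_empty_numpy()
#   buffer.add_episode({"state": np.zeros((10, 2)), "action": np.ones((10, 2))})
#   buffer.add_episode({"state": np.zeros((5, 2)), "action": np.ones((5, 2))})
#   assert buffer.n_episodes == 2 and buffer.n_steps == 15
#   last = buffer.get_episode(-1)  # dict of arrays for the 5-step episode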
@ -1,202 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains download scripts for raw datasets.

Example of usage:
```
python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
--raw-dir data/lerobot-raw/pusht_raw \
--repo-id lerobot-raw/pusht_raw
```
"""

import argparse
import logging
import warnings
from pathlib import Path

from huggingface_hub import snapshot_download

from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id

# {raw_repo_id: raw_format}
AVAILABLE_RAW_REPO_IDS = {
    "lerobot-raw/aloha_mobile_cabinet_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_chair_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_elevator_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_shrimp_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_wash_pan_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_wipe_wine_raw": "aloha_hdf5",
    "lerobot-raw/aloha_sim_insertion_human_raw": "aloha_hdf5",
    "lerobot-raw/aloha_sim_insertion_scripted_raw": "aloha_hdf5",
    "lerobot-raw/aloha_sim_transfer_cube_human_raw": "aloha_hdf5",
    "lerobot-raw/aloha_sim_transfer_cube_scripted_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_battery_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_candy_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_coffee_new_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_coffee_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_cups_open_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_fork_pick_up_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_pingpong_test_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_pro_pencil_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_screw_driver_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_tape_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_thread_velcro_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_towel_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_vinh_cup_left_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_vinh_cup_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_ziploc_slide_raw": "aloha_hdf5",
    "lerobot-raw/umi_cup_in_the_wild_raw": "umi_zarr",
    "lerobot-raw/pusht_raw": "pusht_zarr",
    "lerobot-raw/unitreeh1_fold_clothes_raw": "aloha_hdf5",
    "lerobot-raw/unitreeh1_rearrange_objects_raw": "aloha_hdf5",
    "lerobot-raw/unitreeh1_two_robot_greeting_raw": "aloha_hdf5",
    "lerobot-raw/unitreeh1_warehouse_raw": "aloha_hdf5",
    "lerobot-raw/xarm_lift_medium_raw": "xarm_pkl",
    "lerobot-raw/xarm_lift_medium_replay_raw": "xarm_pkl",
    "lerobot-raw/xarm_push_medium_raw": "xarm_pkl",
    "lerobot-raw/xarm_push_medium_replay_raw": "xarm_pkl",
    "lerobot-raw/fractal20220817_data_raw": "openx_rlds.fractal20220817_data",
    "lerobot-raw/kuka_raw": "openx_rlds.kuka",
    "lerobot-raw/bridge_openx_raw": "openx_rlds.bridge_openx",
    "lerobot-raw/taco_play_raw": "openx_rlds.taco_play",
    "lerobot-raw/jaco_play_raw": "openx_rlds.jaco_play",
    "lerobot-raw/berkeley_cable_routing_raw": "openx_rlds.berkeley_cable_routing",
    "lerobot-raw/roboturk_raw": "openx_rlds.roboturk",
    "lerobot-raw/nyu_door_opening_surprising_effectiveness_raw": "openx_rlds.nyu_door_opening_surprising_effectiveness",
    "lerobot-raw/viola_raw": "openx_rlds.viola",
    "lerobot-raw/berkeley_autolab_ur5_raw": "openx_rlds.berkeley_autolab_ur5",
    "lerobot-raw/toto_raw": "openx_rlds.toto",
    "lerobot-raw/language_table_raw": "openx_rlds.language_table",
    "lerobot-raw/columbia_cairlab_pusht_real_raw": "openx_rlds.columbia_cairlab_pusht_real",
    "lerobot-raw/stanford_kuka_multimodal_dataset_raw": "openx_rlds.stanford_kuka_multimodal_dataset",
    "lerobot-raw/nyu_rot_dataset_raw": "openx_rlds.nyu_rot_dataset",
    "lerobot-raw/io_ai_tech_raw": "openx_rlds.io_ai_tech",
    "lerobot-raw/stanford_hydra_dataset_raw": "openx_rlds.stanford_hydra_dataset",
    "lerobot-raw/austin_buds_dataset_raw": "openx_rlds.austin_buds_dataset",
    "lerobot-raw/nyu_franka_play_dataset_raw": "openx_rlds.nyu_franka_play_dataset",
    "lerobot-raw/maniskill_dataset_raw": "openx_rlds.maniskill_dataset",
    "lerobot-raw/furniture_bench_dataset_raw": "openx_rlds.furniture_bench_dataset",
    "lerobot-raw/cmu_franka_exploration_dataset_raw": "openx_rlds.cmu_franka_exploration_dataset",
    "lerobot-raw/ucsd_kitchen_dataset_raw": "openx_rlds.ucsd_kitchen_dataset",
    "lerobot-raw/ucsd_pick_and_place_dataset_raw": "openx_rlds.ucsd_pick_and_place_dataset",
    "lerobot-raw/spoc_raw": "openx_rlds.spoc",
    "lerobot-raw/austin_sailor_dataset_raw": "openx_rlds.austin_sailor_dataset",
    "lerobot-raw/austin_sirius_dataset_raw": "openx_rlds.austin_sirius_dataset",
    "lerobot-raw/bc_z_raw": "openx_rlds.bc_z",
    "lerobot-raw/utokyo_pr2_opening_fridge_raw": "openx_rlds.utokyo_pr2_opening_fridge",
    "lerobot-raw/utokyo_pr2_tabletop_manipulation_raw": "openx_rlds.utokyo_pr2_tabletop_manipulation",
    "lerobot-raw/utokyo_xarm_pick_and_place_raw": "openx_rlds.utokyo_xarm_pick_and_place",
    "lerobot-raw/utokyo_xarm_bimanual_raw": "openx_rlds.utokyo_xarm_bimanual",
    "lerobot-raw/utokyo_saytap_raw": "openx_rlds.utokyo_saytap",
    "lerobot-raw/robo_net_raw": "openx_rlds.robo_net",
    "lerobot-raw/robo_set_raw": "openx_rlds.robo_set",
    "lerobot-raw/berkeley_mvp_raw": "openx_rlds.berkeley_mvp",
    "lerobot-raw/berkeley_rpt_raw": "openx_rlds.berkeley_rpt",
    "lerobot-raw/kaist_nonprehensile_raw": "openx_rlds.kaist_nonprehensile",
    "lerobot-raw/stanford_mask_vit_raw": "openx_rlds.stanford_mask_vit",
    "lerobot-raw/tokyo_u_lsmo_raw": "openx_rlds.tokyo_u_lsmo",
    "lerobot-raw/dlr_sara_pour_raw": "openx_rlds.dlr_sara_pour",
    "lerobot-raw/dlr_sara_grid_clamp_raw": "openx_rlds.dlr_sara_grid_clamp",
    "lerobot-raw/dlr_edan_shared_control_raw": "openx_rlds.dlr_edan_shared_control",
    "lerobot-raw/asu_table_top_raw": "openx_rlds.asu_table_top",
    "lerobot-raw/stanford_robocook_raw": "openx_rlds.stanford_robocook",
    "lerobot-raw/imperialcollege_sawyer_wrist_cam_raw": "openx_rlds.imperialcollege_sawyer_wrist_cam",
    "lerobot-raw/iamlab_cmu_pickup_insert_raw": "openx_rlds.iamlab_cmu_pickup_insert",
    "lerobot-raw/uiuc_d3field_raw": "openx_rlds.uiuc_d3field",
    "lerobot-raw/utaustin_mutex_raw": "openx_rlds.utaustin_mutex",
    "lerobot-raw/berkeley_fanuc_manipulation_raw": "openx_rlds.berkeley_fanuc_manipulation",
    "lerobot-raw/cmu_playing_with_food_raw": "openx_rlds.cmu_playing_with_food",
    "lerobot-raw/cmu_play_fusion_raw": "openx_rlds.cmu_play_fusion",
    "lerobot-raw/cmu_stretch_raw": "openx_rlds.cmu_stretch",
    "lerobot-raw/berkeley_gnm_recon_raw": "openx_rlds.berkeley_gnm_recon",
    "lerobot-raw/berkeley_gnm_cory_hall_raw": "openx_rlds.berkeley_gnm_cory_hall",
    "lerobot-raw/berkeley_gnm_sac_son_raw": "openx_rlds.berkeley_gnm_sac_son",
    "lerobot-raw/droid_raw": "openx_rlds.droid",
    "lerobot-raw/droid_100_raw": "openx_rlds.droid100",
    "lerobot-raw/fmb_raw": "openx_rlds.fmb",
    "lerobot-raw/dobbe_raw": "openx_rlds.dobbe",
    "lerobot-raw/usc_cloth_sim_raw": "openx_rlds.usc_cloth_sim",
    "lerobot-raw/plex_robosuite_raw": "openx_rlds.plex_robosuite",
    "lerobot-raw/conq_hose_manipulation_raw": "openx_rlds.conq_hose_manipulation",
    "lerobot-raw/vima_raw": "openx_rlds.vima",
    "lerobot-raw/robot_vqa_raw": "openx_rlds.robot_vqa",
    "lerobot-raw/mimic_play_raw": "openx_rlds.mimic_play",
    "lerobot-raw/tidybot_raw": "openx_rlds.tidybot",
    "lerobot-raw/eth_agent_affordances_raw": "openx_rlds.eth_agent_affordances",
}


def download_raw(raw_dir: Path, repo_id: str):
    check_repo_id(repo_id)
    user_id, dataset_id = repo_id.split("/")

    if not dataset_id.endswith("_raw"):
        warnings.warn(
            f"""`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this
            naming convention by renaming your repository is advised, but not mandatory.""",
            stacklevel=1,
        )

    # Send warning if raw_dir isn't well formatted
    if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
        warnings.warn(
            f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that
            matches the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised,
            but not mandatory.""",
            stacklevel=1,
        )
    raw_dir.mkdir(parents=True, exist_ok=True)

    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
    snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")


def download_all_raw_datasets(data_dir: Path | None = None):
    if data_dir is None:
        data_dir = Path("data")
    for repo_id in AVAILABLE_RAW_REPO_IDS:
        raw_dir = data_dir / repo_id
        download_raw(raw_dir, repo_id)


def main():
    parser = argparse.ArgumentParser(
        description=f"""A script to download raw datasets from Hugging Face hub to a local directory. Here is a
        non-exhaustive list of available repositories to use in `--repo-id`: {list(AVAILABLE_RAW_REPO_IDS.keys())}""",
    )

    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="""Repository identifier on Hugging Face: a community or a user name `/` the name of
        the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).""",
    )
    args = parser.parse_args()
    download_raw(**vars(args))


if __name__ == "__main__":
    main()
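
# The script can also be driven from Python; a minimal sketch (the local path is
# only an example) mirroring the CLI invocation shown in the module docstring:
#
#   from pathlib import Path
#   download_raw(Path("data/lerobot-raw/pusht_raw"), repo_id="lerobot-raw/pusht_raw")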
@ -1,184 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Use this script to batch encode lerobot datasets from their raw format to LeRobotDataset and push their updated
version to the hub. Under the hood, this script reuses 'push_dataset_to_hub.py'. It assumes that you already
downloaded raw datasets, which you can do with the related '_download_raw.py' script.

For instance, for codebase_version = 'v1.6', the following command was run, assuming raw datasets from
lerobot-raw were downloaded in 'raw/datasets/directory':
```bash
python lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py \
    --raw-dir raw/datasets/directory \
    --raw-repo-ids lerobot-raw \
    --local-dir push/datasets/directory \
    --tests-data-dir tests/data \
    --push-repo lerobot \
    --vcodec libsvtav1 \
    --pix-fmt yuv420p \
    --g 2 \
    --crf 30
```
"""

import argparse
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub


def get_push_repo_id_from_raw(raw_repo_id: str, push_repo: str) -> str:
    dataset_id_raw = raw_repo_id.split("/")[1]
    dataset_id = dataset_id_raw.removesuffix("_raw")
    return f"{push_repo}/{dataset_id}"
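
# For example, get_push_repo_id_from_raw("lerobot-raw/pusht_raw", "lerobot") returns
# "lerobot/pusht": the "_raw" suffix is dropped and the dataset is re-homed under
# the push namespace.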


def encode_datasets(
    raw_dir: Path,
    raw_repo_ids: list[str],
    push_repo: str,
    vcodec: str,
    pix_fmt: str,
    g: int,
    crf: int,
    local_dir: Path | None = None,
    tests_data_dir: Path | None = None,
    raw_format: str | None = None,
    dry_run: bool = False,
) -> None:
    if len(raw_repo_ids) == 1 and raw_repo_ids[0].lower() == "lerobot-raw":
        raw_repo_ids_format = AVAILABLE_RAW_REPO_IDS
    else:
        if raw_format is None:
            raise ValueError(raw_format)
        raw_repo_ids_format = {id_: raw_format for id_ in raw_repo_ids}

    for raw_repo_id, repo_raw_format in raw_repo_ids_format.items():
        check_repo_id(raw_repo_id)
        dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
        dataset_raw_dir = raw_dir / raw_repo_id
        dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
        encoding = {
            "vcodec": vcodec,
            "pix_fmt": pix_fmt,
            "g": g,
            "crf": crf,
        }

        if not dataset_raw_dir.is_dir():
            raise NotADirectoryError(dataset_raw_dir)

        if not dry_run:
            push_dataset_to_hub(
                dataset_raw_dir,
                raw_format=repo_raw_format,
                repo_id=dataset_repo_id_push,
                local_dir=dataset_dir,
                resume=True,
                encoding=encoding,
                tests_data_dir=tests_data_dir,
            )
        else:
            print(
                f"DRY RUN: {dataset_raw_dir} --> {dataset_dir} --> {dataset_repo_id_push}@{CODEBASE_VERSION}"
            )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--raw-dir",
        type=Path,
        default=Path("data"),
        help="Directory where raw datasets are located.",
    )
    parser.add_argument(
        "--raw-repo-ids",
        type=str,
        nargs="*",
        default=["lerobot-raw"],
        help="""Raw dataset repo ids. If 'lerobot-raw', the keys from `AVAILABLE_RAW_REPO_IDS` will be
        used and raw datasets will be fetched from the 'lerobot-raw/' repo and pushed with their
        associated format. It is assumed that each dataset is located at `raw_dir / raw_repo_id` """,
    )
    parser.add_argument(
        "--raw-format",
        type=str,
        default=None,
        help="""Raw format to use for the raw repo-ids. Must be specified if --raw-repo-ids is not
        'lerobot-raw'""",
    )
    parser.add_argument(
        "--local-dir",
        type=Path,
        default=None,
        help="""When provided, writes the dataset converted to LeRobotDataset format in this directory
        (e.g. `data/lerobot/aloha_mobile_chair`).""",
    )
    parser.add_argument(
        "--push-repo",
        type=str,
        default="lerobot",
        help="Repo to upload datasets to",
    )
    parser.add_argument(
        "--vcodec",
        type=str,
        default="libsvtav1",
        help="Codec to use for encoding videos",
    )
    parser.add_argument(
        "--pix-fmt",
        type=str,
        default="yuv420p",
        help="Pixel format (chroma subsampling) to be used for encoding",
    )
    parser.add_argument(
        "--g",
        type=int,
        default=2,
        help="Group of pictures size to be used for encoding.",
    )
    parser.add_argument(
        "--crf",
        type=int,
        default=30,
        help="Constant rate factor to be used for encoding.",
    )
    parser.add_argument(
        "--tests-data-dir",
        type=Path,
        default=None,
        help=(
            "When provided, save tests artifacts into the given directory "
            "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
        ),
    )
    parser.add_argument(
        "--dry-run",
        type=int,
        default=0,
        help="If not set to 0, this script won't download or upload anything.",
    )
    args = parser.parse_args()
    encode_datasets(**vars(args))


if __name__ == "__main__":
    main()
@ -1,326 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# imagecodecs/numcodecs.py

# Copyright (c) 2021-2022, Christoph Gohlke
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

# Copied from: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/codecs/imagecodecs_numcodecs.py#L1
"""Additional numcodecs implemented using imagecodecs."""

__version__ = "2022.9.26"

__all__ = ("register_codecs",)

import imagecodecs
import numpy
from numcodecs.abc import Codec
from numcodecs.registry import get_codec, register_codec

# TODO (azouitine): Remove useless codecs


def protective_squeeze(x: numpy.ndarray):
    """
    Squeeze dim only if it's not the last dim.
    Image dim expected to be *, H, W, C
    """
    img_shape = x.shape[-3:]
    if len(x.shape) > 3:
        n_imgs = numpy.prod(x.shape[:-3])
        if n_imgs > 1:
            img_shape = (-1,) + img_shape
    return x.reshape(img_shape)
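
# Shape examples (illustrative): an input of shape (2, 4, 96, 96, 3) has
# n_imgs = 8 > 1 and is flattened to (-1, 96, 96, 3) -> (8, 96, 96, 3), while an
# input of shape (1, 96, 96, 3) has n_imgs = 1 and collapses to a single
# (96, 96, 3) image.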


def get_default_image_compressor(**kwargs):
    if imagecodecs.JPEGXL:
        # has JPEGXL
        this_kwargs = {
            "effort": 3,
            "distance": 0.3,
            # bug in libjxl, invalid codestream for non-lossless
            # when decoding speed > 1
            "decodingspeed": 1,
        }
        this_kwargs.update(kwargs)
        return JpegXl(**this_kwargs)
    else:
        this_kwargs = {"level": 50}
        this_kwargs.update(kwargs)
        return Jpeg2k(**this_kwargs)
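
# A hedged sketch of how this factory plugs into zarr (array shape is illustrative;
# `register_codecs()` is defined further down in this file, past the end of this
# excerpt, and an `import zarr` is assumed):
#
#   register_codecs()  # make the codecs below resolvable by their codec_id
#   z = zarr.zeros((100, 96, 96, 3), chunks=(10, 96, 96, 3), dtype="uint8",
#                  compressor=get_default_image_compressor())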


class Jpeg2k(Codec):
    """JPEG 2000 codec for numcodecs."""

    codec_id = "imagecodecs_jpeg2k"

    def __init__(
        self,
        level=None,
        codecformat=None,
        colorspace=None,
        tile=None,
        reversible=None,
        bitspersample=None,
        resolutions=None,
        numthreads=None,
        verbose=0,
    ):
        self.level = level
        self.codecformat = codecformat
        self.colorspace = colorspace
        self.tile = None if tile is None else tuple(tile)
        self.reversible = reversible
        self.bitspersample = bitspersample
        self.resolutions = resolutions
        self.numthreads = numthreads
        self.verbose = verbose

    def encode(self, buf):
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.jpeg2k_encode(
            buf,
            level=self.level,
            codecformat=self.codecformat,
            colorspace=self.colorspace,
            tile=self.tile,
            reversible=self.reversible,
            bitspersample=self.bitspersample,
            resolutions=self.resolutions,
            numthreads=self.numthreads,
            verbose=self.verbose,
        )

    def decode(self, buf, out=None):
        return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)


class JpegXl(Codec):
    """JPEG XL codec for numcodecs."""

    codec_id = "imagecodecs_jpegxl"

    def __init__(
        self,
        # encode
        level=None,
        effort=None,
        distance=None,
        lossless=None,
        decodingspeed=None,
        photometric=None,
        planar=None,
        usecontainer=None,
        # decode
        index=None,
        keeporientation=None,
        # both
        numthreads=None,
    ):
        """
        Return JPEG XL image from numpy array.
        Float must be in nominal range 0..1.

        Currently L, LA, RGB, RGBA images are supported in contig mode.
        Extra channels are only supported for grayscale images in planar mode.

        Parameters
        ----------
        level : Defaults to None, i.e. not overwriting lossless and decodingspeed options.
            When < 0: Use lossless compression
            When in [0,1,2,3,4]: Sets the decoding speed tier for the provided options.
            Minimum is 0 (slowest to decode, best quality/density), and maximum
            is 4 (fastest to decode, at the cost of some quality/density).
        effort : Defaults to 3.
            Sets encoder effort/speed level without affecting decoding speed.
            Valid values are, from faster to slower speed: 1:lightning 2:thunder
            3:falcon 4:cheetah 5:hare 6:wombat 7:squirrel 8:kitten 9:tortoise.
            Speed: lightning, thunder, falcon, cheetah, hare, wombat, squirrel, kitten, tortoise
            control the encoder effort in ascending order.
            This also affects memory usage: using lower effort will typically reduce memory
            consumption during encoding.
            lightning and thunder are fast modes useful for lossless mode (modular).
            falcon disables all of the following tools.
            cheetah enables coefficient reordering, context clustering, and heuristics for selecting DCT sizes and quantization steps.
            hare enables Gaborish filtering, chroma from luma, and an initial estimate of quantization steps.
            wombat enables error diffusion quantization and full DCT size selection heuristics.
|
|
||||||
squirrel (default) enables dots, patches, and spline detection, and full context clustering.
|
|
||||||
kitten optimizes the adaptive quantization for a psychovisual metric.
|
|
||||||
tortoise enables a more thorough adaptive quantization search.
|
|
||||||
distance : Default to 1.0
|
|
||||||
Sets the distance level for lossy compression: target max butteraugli distance,
|
|
||||||
lower = higher quality. Range: 0 .. 15. 0.0 = mathematically lossless
|
|
||||||
(however, use JxlEncoderSetFrameLossless instead to use true lossless,
|
|
||||||
as setting distance to 0 alone is not the only requirement).
|
|
||||||
1.0 = visually lossless. Recommended range: 0.5 .. 3.0.
|
|
||||||
lossess : Default to False.
|
|
||||||
Use lossess encoding.
|
|
||||||
decodingspeed : Default to 0.
|
|
||||||
Duplicate to level. [0,4]
|
|
||||||
photometric : Return JxlColorSpace value.
|
|
||||||
Default logic is quite complicated but works most of the time.
|
|
||||||
Accepted value:
|
|
||||||
int: [-1,3]
|
|
||||||
str: ['RGB',
|
|
||||||
'WHITEISZERO', 'MINISWHITE',
|
|
||||||
'BLACKISZERO', 'MINISBLACK', 'GRAY',
|
|
||||||
'XYB', 'KNOWN']
|
|
||||||
planar : Enable multi-channel mode.
|
|
||||||
Default to false.
|
|
||||||
usecontainer :
|
|
||||||
Forces the encoder to use the box-based container format (BMFF)
|
|
||||||
even when not necessary.
|
|
||||||
When using JxlEncoderUseBoxes, JxlEncoderStoreJPEGMetadata or
|
|
||||||
JxlEncoderSetCodestreamLevel with level 10, the encoder will
|
|
||||||
automatically also use the container format, it is not necessary
|
|
||||||
to use JxlEncoderUseContainer for those use cases.
|
|
||||||
By default this setting is disabled.
|
|
||||||
index : Selectively decode frames for animation.
|
|
||||||
Default to 0, decode all frames.
|
|
||||||
When set to > 0, decode that frame index only.
|
|
||||||
keeporientation :
|
|
||||||
Enables or disables preserving of as-in-bitstream pixeldata orientation.
|
|
||||||
Some images are encoded with an Orientation tag indicating that the
|
|
||||||
decoder must perform a rotation and/or mirroring to the encoded image data.
|
|
||||||
|
|
||||||
If skip_reorientation is JXL_FALSE (the default): the decoder will apply
|
|
||||||
the transformation from the orientation setting, hence rendering the image
|
|
||||||
according to its specified intent. When producing a JxlBasicInfo, the decoder
|
|
||||||
will always set the orientation field to JXL_ORIENT_IDENTITY (matching the
|
|
||||||
returned pixel data) and also align xsize and ysize so that they correspond
|
|
||||||
to the width and the height of the returned pixel data.
|
|
||||||
|
|
||||||
If skip_reorientation is JXL_TRUE: the decoder will skip applying the
|
|
||||||
transformation from the orientation setting, returning the image in
|
|
||||||
the as-in-bitstream pixeldata orientation. This may be faster to decode
|
|
||||||
since the decoder doesnt have to apply the transformation, but can
|
|
||||||
cause wrong display of the image if the orientation tag is not correctly
|
|
||||||
taken into account by the user.
|
|
||||||
|
|
||||||
By default, this option is disabled, and the returned pixel data is
|
|
||||||
re-oriented according to the images Orientation setting.
|
|
||||||
threads : Default to 1.
|
|
||||||
If <= 0, use all cores.
|
|
||||||
If > 32, clipped to 32.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.level = level
|
|
||||||
self.effort = effort
|
|
||||||
self.distance = distance
|
|
||||||
self.lossless = bool(lossless)
|
|
||||||
self.decodingspeed = decodingspeed
|
|
||||||
self.photometric = photometric
|
|
||||||
self.planar = planar
|
|
||||||
self.usecontainer = usecontainer
|
|
||||||
self.index = index
|
|
||||||
self.keeporientation = keeporientation
|
|
||||||
self.numthreads = numthreads
|
|
||||||
|
|
||||||
def encode(self, buf):
|
|
||||||
# TODO: only squeeze all but last dim
|
|
||||||
buf = protective_squeeze(numpy.asarray(buf))
|
|
||||||
return imagecodecs.jpegxl_encode(
|
|
||||||
buf,
|
|
||||||
level=self.level,
|
|
||||||
effort=self.effort,
|
|
||||||
distance=self.distance,
|
|
||||||
lossless=self.lossless,
|
|
||||||
decodingspeed=self.decodingspeed,
|
|
||||||
photometric=self.photometric,
|
|
||||||
planar=self.planar,
|
|
||||||
usecontainer=self.usecontainer,
|
|
||||||
numthreads=self.numthreads,
|
|
||||||
)
|
|
||||||
|
|
||||||
def decode(self, buf, out=None):
|
|
||||||
return imagecodecs.jpegxl_decode(
|
|
||||||
buf,
|
|
||||||
index=self.index,
|
|
||||||
keeporientation=self.keeporientation,
|
|
||||||
numthreads=self.numthreads,
|
|
||||||
out=out,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _flat(out):
|
|
||||||
"""Return numpy array as contiguous view of bytes if possible."""
|
|
||||||
if out is None:
|
|
||||||
return None
|
|
||||||
view = memoryview(out)
|
|
||||||
if view.readonly or not view.contiguous:
|
|
||||||
return None
|
|
||||||
return view.cast("B")
|
|
||||||
|
|
||||||
|
|
||||||
def register_codecs(codecs=None, force=False, verbose=True):
|
|
||||||
"""Register codecs in this module with numcodecs."""
|
|
||||||
for name, cls in globals().items():
|
|
||||||
if not hasattr(cls, "codec_id") or name == "Codec":
|
|
||||||
continue
|
|
||||||
if codecs is not None and cls.codec_id not in codecs:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
try: # noqa: SIM105
|
|
||||||
get_codec({"id": cls.codec_id})
|
|
||||||
except TypeError:
|
|
||||||
# registered, but failed
|
|
||||||
pass
|
|
||||||
except ValueError:
|
|
||||||
# not registered yet
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
if not force:
|
|
||||||
if verbose:
|
|
||||||
log_warning(f"numcodec {cls.codec_id!r} already registered")
|
|
||||||
continue
|
|
||||||
if verbose:
|
|
||||||
log_warning(f"replacing registered numcodec {cls.codec_id!r}")
|
|
||||||
register_codec(cls)
|
|
||||||
|
|
||||||
|
|
||||||
def log_warning(msg, *args, **kwargs):
|
|
||||||
"""Log message with level WARNING."""
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.getLogger(__name__).warning(msg, *args, **kwargs)
|
|
|
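

# Hedged usage sketch (added for illustration; not part of the original module).
# It assumes `zarr`, `numcodecs` and `imagecodecs` are installed; the array shape
# and codec settings are made up.
def _example_zarr_usage():
    import numpy as np
    import zarr

    register_codecs()  # registers Jpeg2k and JpegXl with numcodecs (defined above)

    frames = np.zeros((16, 96, 96, 3), dtype=np.uint8)  # a batch of RGB frames
    z = zarr.zeros(
        frames.shape,
        chunks=(1, *frames.shape[1:]),  # one image per chunk: protective_squeeze -> (H, W, C)
        dtype=frames.dtype,
        compressor=get_default_image_compressor(),  # JpegXl if available, else Jpeg2k
    )
    z[:] = frames  # every chunk is compressed through the registered image codec
    return z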
@@ -1,233 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
"""

import gc
import shutil
from pathlib import Path

import h5py
import numpy as np
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
    calculate_episode_data_index,
    concatenate_episodes,
    get_default_encoding,
    save_images_concurrently,
)
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames


def get_cameras(hdf5_data):
    # ignore depth channel, not currently handled
    # TODO(rcadene): add depth
    rgb_cameras = [key for key in hdf5_data["/observations/images"].keys() if "depth" not in key]  # noqa: SIM118
    return rgb_cameras


def check_format(raw_dir) -> bool:
    # only frames from simulation are uncompressed
    compressed_images = "sim" not in raw_dir.name

    hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
    assert len(hdf5_paths) != 0
    for hdf5_path in hdf5_paths:
        with h5py.File(hdf5_path, "r") as data:
            assert "/action" in data
            assert "/observations/qpos" in data

            assert data["/action"].ndim == 2
            assert data["/observations/qpos"].ndim == 2

            num_frames = data["/action"].shape[0]
            assert num_frames == data["/observations/qpos"].shape[0]

            for camera in get_cameras(data):
                assert num_frames == data[f"/observations/images/{camera}"].shape[0]

                if compressed_images:
                    assert data[f"/observations/images/{camera}"].ndim == 2
                else:
                    assert data[f"/observations/images/{camera}"].ndim == 4
                    b, h, w, c = data[f"/observations/images/{camera}"].shape
                    assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."


def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    # only frames from simulation are uncompressed
    compressed_images = "sim" not in raw_dir.name

    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
    num_episodes = len(hdf5_files)

    ep_dicts = []
    ep_ids = episodes if episodes else range(num_episodes)
    for ep_idx in tqdm.tqdm(ep_ids):
        ep_path = hdf5_files[ep_idx]
        with h5py.File(ep_path, "r") as ep:
            num_frames = ep["/action"].shape[0]

            # last step of demonstration is considered done
            done = torch.zeros(num_frames, dtype=torch.bool)
            done[-1] = True

            state = torch.from_numpy(ep["/observations/qpos"][:])
            action = torch.from_numpy(ep["/action"][:])
            if "/observations/qvel" in ep:
                velocity = torch.from_numpy(ep["/observations/qvel"][:])
            if "/observations/effort" in ep:
                effort = torch.from_numpy(ep["/observations/effort"][:])

            ep_dict = {}

            for camera in get_cameras(ep):
                img_key = f"observation.images.{camera}"

                if compressed_images:
                    import cv2

                    # load one compressed image after the other in RAM and uncompress
                    imgs_array = []
                    for data in ep[f"/observations/images/{camera}"]:
                        imgs_array.append(cv2.imdecode(data, 1))
                    imgs_array = np.array(imgs_array)

                else:
                    # load all images in RAM
                    imgs_array = ep[f"/observations/images/{camera}"][:]

                if video:
                    # save png images in temporary directory
                    tmp_imgs_dir = videos_dir / "tmp_images"
                    save_images_concurrently(imgs_array, tmp_imgs_dir)

                    # encode images to a mp4 video
                    fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                    video_path = videos_dir / fname
                    encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))

                    # clean temporary images directory
                    shutil.rmtree(tmp_imgs_dir)

                    # store the reference to the video frame
                    ep_dict[img_key] = [
                        {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
                    ]
                else:
                    ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

            ep_dict["observation.state"] = state
            if "/observations/qvel" in ep:
                ep_dict["observation.velocity"] = velocity
            if "/observations/effort" in ep:
                ep_dict["observation.effort"] = effort
            ep_dict["action"] = action
            ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
            ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
            ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
            ep_dict["next.done"] = done
            # TODO(rcadene): add reward and success by computing them in sim

            assert isinstance(ep_idx, int)
            ep_dicts.append(ep_dict)

        gc.collect()

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}

    keys = [key for key in data_dict if "observation.images." in key]
    for key in keys:
        if video:
            features[key] = VideoFrame()
        else:
            features[key] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    if "observation.velocity" in data_dict:
        features["observation.velocity"] = Sequence(
            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
        )
    if "observation.effort" in data_dict:
        features["observation.effort"] = Sequence(
            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
        )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 50

    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = get_default_encoding()

    return hf_dataset, episode_data_index, info
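

# Hedged usage sketch (added for illustration; not part of the original file).
# The directory names are hypothetical; `raw_dir` is expected to contain the
# episode_*.hdf5 files produced by the ALOHA stack.
def _example_aloha_conversion():
    from pathlib import Path

    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir=Path("data/aloha_sim_insertion_human"),
        videos_dir=Path("outputs/videos"),
        fps=None,  # falls back to the ALOHA default of 50 (see above)
        video=True,  # encode camera frames to mp4 and store frame references
    )
    return hf_dataset, episode_data_index, info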
@@ -1,107 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process raw data format of PNG image files recorded with capture_camera_feed.py
"""

from pathlib import Path

import torch
from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
    calculate_episode_data_index,
    concatenate_episodes,
)
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.datasets.video_utils import VideoFrame


def check_format(raw_dir: Path) -> bool:
    image_paths = list(raw_dir.glob("frame_*.png"))
    if len(image_paths) == 0:
        raise ValueError


def load_from_raw(raw_dir: Path, fps: int, episodes: list[int] | None = None):
    if episodes is not None:
        # TODO(aliberts): add support for multi-episodes.
        raise NotImplementedError()

    ep_dict = {}
    ep_idx = 0

    image_paths = sorted(raw_dir.glob("frame_*.png"))
    num_frames = len(image_paths)

    ep_dict["observation.image"] = [PILImage.open(x) for x in image_paths]
    ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
    ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
    ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps

    ep_dicts = [ep_dict]
    data_dict = concatenate_episodes(ep_dicts)
    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}
    if video:
        features["observation.image"] = VideoFrame()
    else:
        features["observation.image"] = Image()

    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    if video or episodes or encoding is not None:
        # TODO(aliberts): support this
        raise NotImplementedError

    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 30

    data_dict = load_from_raw(raw_dir, fps, episodes)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
@@ -1,233 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process raw data format from dora-record
"""

import re
import warnings
from pathlib import Path

import pandas as pd
import torch
from datasets import Dataset, Features, Image, Sequence, Value

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame


def check_format(raw_dir) -> bool:
    assert raw_dir.exists()

    leader_file = list(raw_dir.glob("*.parquet"))
    if len(leader_file) == 0:
        raise ValueError(f"Missing parquet files in '{raw_dir}'")
    return True


def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
    # Load data stream that will be used as reference for the timestamps synchronization
    reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
    if len(reference_files) == 0:
        raise ValueError(f"Missing reference files for camera, starting with 'observation.images.cam_', in '{raw_dir}'")
    # select first camera in alphanumeric order
    reference_key = sorted(reference_files)[0].stem
    reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
    reference_df = reference_df[["timestamp_utc", reference_key]]

    # Merge all data streams using a nearest-timestamp strategy
    df = reference_df
    for path in raw_dir.glob("*.parquet"):
        key = path.stem  # action or observation.state or ...
        if key == reference_key:
            continue
        if "failed_episode_index" in key:
            # TODO(rcadene): add support for removing episodes that are tagged as "failed"
            continue
        modality_df = pd.read_parquet(path)
        modality_df = modality_df[["timestamp_utc", key]]
        df = pd.merge_asof(
            df,
            modality_df,
            on="timestamp_utc",
            # "nearest" is preferred over "backward", since the latter can desynchronize camera
            # timestamps by matching timestamps that are too far apart, in order to fit the
            # backward constraint. This is not the case for "nearest". Note, however, that
            # "nearest" might synchronize the reference camera with other cameras on slightly
            # future timestamps.
            direction="nearest",
            tolerance=pd.Timedelta(f"{1 / fps} seconds"),
        )
    # Remove rows with episode_index -1, which indicate data recorded in between episodes
    df = df[df["episode_index"] != -1]

    image_keys = [key for key in df if "observation.images." in key]

    def get_episode_index(row):
        episode_index_per_cam = {}
        for key in image_keys:
            path = row[key][0]["path"]
            match = re.search(r"_(\d{6})\.mp4", path)
            if not match:
                raise ValueError(path)
            episode_index = int(match.group(1))
            episode_index_per_cam[key] = episode_index
        if len(set(episode_index_per_cam.values())) != 1:
            raise ValueError(
                f"All cameras are expected to belong to the same episode, but got {episode_index_per_cam}"
            )
        return episode_index

    df["episode_index"] = df.apply(get_episode_index, axis=1)

    # dora only uses arrays, so single values are encapsulated into a list
    df["frame_index"] = df.groupby("episode_index").cumcount()
    df = df.reset_index()
    df["index"] = df.index

    # set 'next.done' to True for the last frame of each episode
    df["next.done"] = False
    df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True

    df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
    # each episode starts with timestamp 0 to match the ones from the video
    df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])

    del df["timestamp_utc"]

    # sanity check
    has_nan = df.isna().any().any()
    if has_nan:
        raise ValueError("Dataset contains NaN values.")

    # sanity check episode indices go from 0 to n-1
    ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
    expected_ep_ids = list(range(df["episode_index"].max() + 1))
    if ep_ids != expected_ep_ids:
        raise ValueError(f"Episode indices are {ep_ids} instead of {expected_ep_ids}")

    # Create symlink to raw videos directory (which needs to be absolute, not relative)
    videos_dir.parent.mkdir(parents=True, exist_ok=True)
    videos_dir.symlink_to((raw_dir / "videos").absolute())

    # sanity check the video paths are well formatted
    for key in df:
        if "observation.images." not in key:
            continue
        for ep_idx in ep_ids:
            video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
            if not video_path.exists():
                raise ValueError(f"Video file not found in {video_path}")

    data_dict = {}
    for key in df:
        # is video frame
        if "observation.images." in key:
            # we need `[0]` because dora only uses arrays, so single values are encapsulated into a list.
            # it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
            data_dict[key] = [video_frame[0] for video_frame in df[key].values]

            # sanity check the video path is well formatted
            video_path = videos_dir.parent / data_dict[key][0]["path"]
            if not video_path.exists():
                raise ValueError(f"Video file not found in {video_path}")
        # is number
        elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
            data_dict[key] = torch.from_numpy(df[key].values)
        # is vector
        elif df[key].iloc[0].shape[0] > 1:
            data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
        else:
            raise ValueError(key)

    return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}

    keys = [key for key in data_dict if "observation.images." in key]
    for key in keys:
        if video:
            features[key] = VideoFrame()
        else:
            features[key] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    if "observation.velocity" in data_dict:
        features["observation.velocity"] = Sequence(
            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
        )
    if "observation.effort" in data_dict:
        features["observation.effort"] = Sequence(
            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
        )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 30
    else:
        raise NotImplementedError()

    if not video:
        raise NotImplementedError()

    if encoding is not None:
        warnings.warn(
            "Video encoding is currently done outside of LeRobot for the dora_parquet format.",
            stacklevel=1,
        )

    data_df = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    hf_dataset = to_hf_dataset(data_df, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = "unknown"

    return hf_dataset, episode_data_index, info
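

# Hedged illustration (added; not part of the original file) of the "nearest"
# merge strategy used in load_from_raw above: every modality row is matched to
# the closest reference-camera timestamp within a 1/fps tolerance. The
# timestamps and values below are made up.
def _example_nearest_merge():
    import pandas as pd

    fps = 30
    ref = pd.DataFrame(
        {
            "timestamp_utc": pd.to_datetime([0, 33, 66], unit="ms"),
            "observation.images.cam_high": [["f0"], ["f1"], ["f2"]],
        }
    )
    mod = pd.DataFrame(
        {
            "timestamp_utc": pd.to_datetime([1, 35, 70], unit="ms"),
            "observation.state": [[0.0], [0.1], [0.2]],
        }
    )
    # Each camera frame is paired with the nearest state sample, past or future.
    return pd.merge_asof(
        ref, mod, on="timestamp_utc", direction="nearest", tolerance=pd.Timedelta(f"{1 / fps} seconds")
    )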
@@ -1,312 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.

NOTE: You need to install tensorflow and tensorflow_datasets before running this script.

Example:
    python lerobot/scripts/push_dataset_to_hub.py \
    --raw-dir /path/to/data/bridge_dataset/1.0.0/ \
    --repo-id your_hub/sampled_bridge_data_v2 \
    --raw-format rlds \
    --episodes 3 4 5 8 9

Exact dataset fps defined in openx/config.py, obtained from:
https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
"""

import shutil
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
    calculate_episode_data_index,
    concatenate_episodes,
    get_default_encoding,
    save_images_concurrently,
)
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

np.set_printoptions(precision=2)


def tf_to_torch(data):
    return torch.from_numpy(data.numpy())


def tf_img_convert(img):
    if img.dtype == tf.string:
        img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
    elif img.dtype != tf.uint8:
        raise ValueError(f"Unsupported image dtype: {img.dtype}")
    return img.numpy()


def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
    """
    In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
    entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
    trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.

    NOTE: adapted from DLimp library https://github.com/kvablack/dlimp/
    """
    steps = traj.pop("steps")

    traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]

    # broadcast metadata to the length of the trajectory
    metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)

    # put steps back in
    assert "traj_metadata" not in steps
    traj = {**steps, "traj_metadata": metadata}

    assert "_len" not in traj
    assert "_traj_index" not in traj
    assert "_frame_index" not in traj
    traj["_len"] = tf.repeat(traj_len, traj_len)
    traj["_traj_index"] = tf.repeat(i, traj_len)
    traj["_frame_index"] = tf.range(traj_len)

    return traj


def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """
    Args:
        raw_dir (Path): directory containing the raw tensorflow_datasets build of the dataset.
        videos_dir (Path): directory where encoded mp4 videos and temporary episode dicts are written.
        fps (int): frames per second, used to compute frame timestamps.
        video (bool): if True, encode camera frames to mp4 videos instead of storing raw images.
        episodes (list[int] | None, optional): episode indices to convert. Defaults to None (all episodes).
        encoding (dict | None, optional): optional overrides for the video encoding settings. Defaults to None.
    """
    ds_builder = tfds.builder_from_directory(str(raw_dir))
    dataset = ds_builder.as_dataset(
        split="all",
        decoders={"steps": tfds.decode.SkipDecoding()},
    )

    dataset_info = ds_builder.info
    print("dataset_info: ", dataset_info)

    ds_length = len(dataset)
    dataset = dataset.take(ds_length)
    # "flatten" the dataset so that we can apply trajectory-level map() easily
    # each [obs][key] has a shape of (frame_size, ...)
    dataset = dataset.enumerate().map(_broadcast_metadata_rlds)

    # we would apply a standardization transform if the dataset name were provided;
    # since the goal is to convert any rlds-formatted dataset, we instead
    # search for 'image' keys in the observations
    image_keys = []
    state_keys = []
    observation_info = dataset_info.features["steps"]["observation"]
    for key in observation_info:
        # check whether the key is for an image or a vector observation
        if len(observation_info[key].shape) == 3:
            # only adding uint8 images discards depth images
            if observation_info[key].dtype == tf.uint8:
                image_keys.append(key)
        else:
            state_keys.append(key)

    lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None

    print(" - image_keys: ", image_keys)
    print(" - lang_key: ", lang_key)

    it = iter(dataset)

    ep_dicts = []
    # Init temp path to save ep_dicts in case of crash
    tmp_ep_dicts_dir = videos_dir.parent.joinpath("ep_dicts")
    tmp_ep_dicts_dir.mkdir(parents=True, exist_ok=True)

    # check if ep_dicts have already been saved in /tmp
    starting_ep_idx = 0
    saved_ep_dicts = [str(ep) for ep in tmp_ep_dicts_dir.iterdir()]
    if len(saved_ep_dicts) > 0:
        saved_ep_dicts.sort()
        # get last ep_idx number
        starting_ep_idx = int(saved_ep_dicts[-1][-13:-3]) + 1
        for i in range(starting_ep_idx):
            episode = next(it)
            ep_dicts.append(torch.load(saved_ep_dicts[i]))

    # if the user specified episodes, skip the ones not in the list
    if episodes is not None:
        if ds_length == 0:
            raise ValueError("No episodes found.")
        # convert episodes index to sorted list
        episodes = sorted(episodes)

    for ep_idx in tqdm.tqdm(range(starting_ep_idx, ds_length)):
        episode = next(it)

        # if user specified episodes, skip the ones not in the list
        if episodes is not None:
            if len(episodes) == 0:
                break
            if ep_idx == episodes[0]:
                # process this episode
                print(" selecting episode idx: ", ep_idx)
                episodes.pop(0)
            else:
                continue  # skip

        num_frames = episode["action"].shape[0]

        ep_dict = {}
        for key in state_keys:
            ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])

        ep_dict["action"] = tf_to_torch(episode["action"])
        ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
        ep_dict["next.done"] = tf_to_torch(episode["is_last"])
        ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
        ep_dict["is_first"] = tf_to_torch(episode["is_first"])
        ep_dict["discount"] = tf_to_torch(episode["discount"])

        # If lang_key is present, convert the entire tensor at once
        if lang_key is not None:
            ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]

        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)

        image_array_dict = {key: [] for key in image_keys}

        for im_key in image_keys:
            imgs = episode["observation"][im_key]
            image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]

        # loop through all cameras
        for im_key in image_keys:
            img_key = f"observation.images.{im_key}"
            imgs_array = image_array_dict[im_key]
            imgs_array = np.array(imgs_array)
            if video:
                # save png images in temporary directory
                tmp_imgs_dir = videos_dir / "tmp_images"
                save_images_concurrently(imgs_array, tmp_imgs_dir)

                # encode images to a mp4 video
                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                video_path = videos_dir / fname
                encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))

                # clean temporary images directory
                shutil.rmtree(tmp_imgs_dir)

                # store the reference to the video frame
                ep_dict[img_key] = [
                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
                ]
            else:
                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

        path_ep_dict = tmp_ep_dicts_dir.joinpath(
            "ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
        )
        torch.save(ep_dict, path_ep_dict)

        ep_dicts.append(ep_dict)

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}

    for key in data_dict:
        # check if vector state obs
        if key.startswith("observation.") and "observation.images." not in key:
            features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
        # check if image obs
        elif "observation.images." in key:
            if video:
                features[key] = VideoFrame()
            else:
                features[key] = Image()

    if "language_instruction" in data_dict:
        features["language_instruction"] = Value(dtype="string", id=None)

    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )

    features["is_terminal"] = Value(dtype="bool", id=None)
    features["is_first"] = Value(dtype="bool", id=None)
    features["discount"] = Value(dtype="float32", id=None)

    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.reward"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = get_default_encoding()

    return hf_dataset, episode_data_index, info
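

# Hedged toy example (added; not part of the original file) of what
# _broadcast_metadata_rlds does: trajectory-level metadata is repeated to the
# length of the "steps" entry and merged back in. Values are made up.
def _example_broadcast_metadata():
    traj = {
        "steps": {"action": tf.zeros([3, 2])},  # 3 frames of 2-dim actions
        "episode_id": tf.constant(7),  # scalar, per-trajectory metadata
    }
    out = _broadcast_metadata_rlds(tf.constant(0), traj)
    # out["_len"] == [3, 3, 3]; out["traj_metadata"]["episode_id"] == [7, 7, 7]
    return out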
@ -1,275 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
|
|
||||||
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
import zarr
|
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
|
||||||
from PIL import Image as PILImage
|
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
|
||||||
calculate_episode_data_index,
|
|
||||||
concatenate_episodes,
|
|
||||||
get_default_encoding,
|
|
||||||
save_images_concurrently,
|
|
||||||
)
|
|
||||||
from lerobot.common.datasets.utils import (
|
|
||||||
hf_transform_to_torch,
|
|
||||||
)
|
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
|
||||||
|
|
||||||
|
|
||||||
def check_format(raw_dir):
|
|
||||||
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
|
|
||||||
zarr_data = zarr.open(zarr_path, mode="r")
|
|
||||||
|
|
||||||
required_datasets = {
|
|
||||||
"data/action",
|
|
||||||
"data/img",
|
|
||||||
"data/keypoint",
|
|
||||||
"data/n_contacts",
|
|
||||||
"data/state",
|
|
||||||
"meta/episode_ends",
|
|
||||||
}
|
|
||||||
for dataset in required_datasets:
|
|
||||||
assert dataset in zarr_data
|
|
||||||
nb_frames = zarr_data["data/img"].shape[0]
|
|
||||||
|
|
||||||
required_datasets.remove("meta/episode_ends")
|
|
||||||
|
|
||||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
|
||||||
|
|
||||||
|
|
||||||
def load_from_raw(
|
|
||||||
raw_dir: Path,
|
|
||||||
videos_dir: Path,
|
|
||||||
fps: int,
|
|
||||||
video: bool,
|
|
||||||
episodes: list[int] | None = None,
|
|
||||||
keypoints_instead_of_image: bool = False,
|
|
||||||
encoding: dict | None = None,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
import pymunk
|
|
||||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
|
||||||
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
|
|
||||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
|
||||||
)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
|
||||||
raise e
|
|
||||||
# as define in gmy-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
|
|
||||||
success_threshold = 0.95 # 95% coverage,
|
|
||||||
|
|
||||||
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
|
|
||||||
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
|
|
||||||
|
|
||||||
episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
|
|
||||||
assert len(
|
|
||||||
{zarr_data[key].shape[0] for key in zarr_data.keys()} # noqa: SIM118
|
|
||||||
), "Some data type dont have the same number of total frames."
|
|
||||||
|
|
||||||
# TODO(rcadene): verify that goal pose is expected to be fixed
|
|
||||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
|
||||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
|
||||||
|
|
||||||
imgs = torch.from_numpy(zarr_data["img"]) # b h w c
|
|
||||||
states = torch.from_numpy(zarr_data["state"])
|
|
||||||
actions = torch.from_numpy(zarr_data["action"])
|
|
||||||
|
|
||||||
# load data indices from which each episode starts and ends
|
|
||||||
from_ids, to_ids = [], []
|
|
||||||
from_idx = 0
|
|
||||||
for to_idx in zarr_data.meta["episode_ends"]:
|
|
||||||
from_ids.append(from_idx)
|
|
||||||
to_ids.append(to_idx)
|
|
||||||
from_idx = to_idx
|
|
||||||
|
|
||||||
num_episodes = len(from_ids)
|
|
||||||
|
|
||||||
ep_dicts = []
|
|
||||||
ep_ids = episodes if episodes else range(num_episodes)
|
|
||||||
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
|
|
||||||
from_idx = from_ids[selected_ep_idx]
|
|
||||||
to_idx = to_ids[selected_ep_idx]
|
|
||||||
num_frames = to_idx - from_idx
|
|
||||||
|
|
||||||
# sanity check
|
|
||||||
assert (episode_ids[from_idx:to_idx] == ep_idx).all()
|
|
||||||
|
|
||||||
# get image
|
|
||||||
if not keypoints_instead_of_image:
|
|
||||||
image = imgs[from_idx:to_idx]
|
|
||||||
assert image.min() >= 0.0
|
|
||||||
assert image.max() <= 255.0
|
|
||||||
image = image.type(torch.uint8)
|
|
||||||
|
|
||||||
# get state
|
|
||||||
state = states[from_idx:to_idx]
|
|
||||||
agent_pos = state[:, :2]
|
|
||||||
block_pos = state[:, 2:4]
|
|
||||||
block_angle = state[:, 4]
|
|
||||||
|
|
||||||
# get reward, success, done, and (maybe) keypoints
|
|
||||||
reward = torch.zeros(num_frames)
|
|
||||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
|
||||||
if keypoints_instead_of_image:
|
|
||||||
keypoints = torch.zeros(num_frames, 16) # 8 keypoints each with 2 coords
|
|
||||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
|
||||||
for i in range(num_frames):
|
|
||||||
space = pymunk.Space()
|
|
||||||
space.gravity = 0, 0
|
|
||||||
space.damping = 0
|
|
||||||
|
|
||||||
# Add walls.
|
|
||||||
walls = [
|
|
||||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
|
||||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
|
||||||
                PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
                PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
            ]
            space.add(*walls)

            block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
            goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
            block_geom = pymunk_to_shapely(block_body, block_body.shapes)
            intersection_area = goal_geom.intersection(block_geom).area
            goal_area = goal_geom.area
            coverage = intersection_area / goal_area
            reward[i] = np.clip(coverage / success_threshold, 0, 1)
            success[i] = coverage > success_threshold
            if keypoints_instead_of_image:
                keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())

        # last step of demonstration is considered done
        done[-1] = True

        ep_dict = {}

        if not keypoints_instead_of_image:
            imgs_array = [x.numpy() for x in image]
            img_key = "observation.image"
            if video:
                # save png images in temporary directory
                tmp_imgs_dir = videos_dir / "tmp_images"
                save_images_concurrently(imgs_array, tmp_imgs_dir)

                # encode images to a mp4 video
                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                video_path = videos_dir / fname
                encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))

                # clean temporary images directory
                shutil.rmtree(tmp_imgs_dir)

                # store the reference to the video frame
                ep_dict[img_key] = [
                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
                ]
            else:
                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = agent_pos
        if keypoints_instead_of_image:
            ep_dict["observation.environment_state"] = keypoints
        ep_dict["action"] = actions[from_idx:to_idx]
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        # ep_dict["next.observation.image"] = image[1:],
        # ep_dict["next.observation.state"] = agent_pos[1:],
        # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
        ep_dict["next.reward"] = torch.cat([reward[1:], reward[[-1]]])
        ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
        ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
        ep_dicts.append(ep_dict)

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
    features = {}

    if not keypoints_instead_of_image:
        if video:
            features["observation.image"] = VideoFrame()
        else:
            features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    if keypoints_instead_of_image:
        features["observation.environment_state"] = Sequence(
            length=data_dict["observation.environment_state"].shape[1],
            feature=Value(dtype="float32", id=None),
        )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.reward"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["next.success"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    # Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
    # with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
    keypoints_instead_of_image = False

    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 10

    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
    hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video if not keypoints_instead_of_image else 0,
    }
    if video:
        info["encoding"] = get_default_encoding()

    return hf_dataset, episode_data_index, info
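For reference, the coverage-based reward computed in the loop above reduces to intersection-over-goal-area. A minimal standalone sketch, assuming plain shapely polygons in place of the pymunk bodies (the square shapes below are made up for illustration):

```python
import numpy as np
from shapely.geometry import Polygon

goal_geom = Polygon([(0, 0), (4, 0), (4, 4), (0, 4)])   # goal footprint
block_geom = Polygon([(2, 2), (6, 2), (6, 6), (2, 6)])  # current block footprint
success_threshold = 0.95  # pusht convention: 95% of the goal must be covered

coverage = goal_geom.intersection(block_geom).area / goal_geom.area
reward = np.clip(coverage / success_threshold, 0, 1)
success = coverage > success_threshold
print(f"{coverage=:.2f} {reward=:.2f} {success=}")  # coverage=0.25 reward=0.26 success=False
```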
@ -1,234 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""

import logging
import shutil
from pathlib import Path

import torch
import tqdm
import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import (
    calculate_episode_data_index,
    concatenate_episodes,
    get_default_encoding,
    save_images_concurrently,
)
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames


def check_format(raw_dir) -> bool:
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    zarr_data = zarr.open(zarr_path, mode="r")

    required_datasets = {
        "data/robot0_demo_end_pose",
        "data/robot0_demo_start_pose",
        "data/robot0_eef_pos",
        "data/robot0_eef_rot_axis_angle",
        "data/robot0_gripper_width",
        "meta/episode_ends",
        "data/camera0_rgb",
    }
    for dataset in required_datasets:
        if dataset not in zarr_data:
            return False

    # mandatory to access zarr_data
    register_codecs()
    nb_frames = zarr_data["data/camera0_rgb"].shape[0]

    required_datasets.remove("meta/episode_ends")
    assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    zarr_data = zarr.open(zarr_path, mode="r")

    # We process the image data separately because it is too large to fit in memory
    end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
    start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
    eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
    eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
    gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])

    states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
    states = torch.cat([states_pos, gripper_width], dim=1)

    episode_ends = zarr_data["meta/episode_ends"][:]
    num_episodes = episode_ends.shape[0]

    # We convert it to a torch tensor later because the jit function does not support torch tensors
    episode_ends = torch.from_numpy(episode_ends)

    # load data indices from which each episode starts and ends
    from_ids, to_ids = [], []
    from_idx = 0
    for to_idx in episode_ends:
        from_ids.append(from_idx)
        to_ids.append(to_idx)
        from_idx = to_idx

    ep_dicts_dir = videos_dir / "ep_dicts"
    ep_dicts_dir.mkdir(exist_ok=True, parents=True)
    ep_dicts = []

    ep_ids = episodes if episodes else range(num_episodes)
    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
        ep_dict_path = ep_dicts_dir / f"{ep_idx}"
        if not ep_dict_path.is_file():
            from_idx = from_ids[selected_ep_idx]
            to_idx = to_ids[selected_ep_idx]
            num_frames = to_idx - from_idx

            # TODO(rcadene): save temporary images of the episode?

            state = states[from_idx:to_idx]

            ep_dict = {}

            # load 57MB of images in RAM (400x224x224x3 uint8)
            imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
            img_key = "observation.image"
            if video:
                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                video_path = videos_dir / fname
                if not video_path.is_file():
                    # save png images in temporary directory
                    tmp_imgs_dir = videos_dir / "tmp_images"
                    save_images_concurrently(imgs_array, tmp_imgs_dir)

                    # encode images to a mp4 video
                    encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))

                    # clean temporary images directory
                    shutil.rmtree(tmp_imgs_dir)

                # store the reference to the video frame
                ep_dict[img_key] = [
                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
                ]
            else:
                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

            ep_dict["observation.state"] = state
            ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
            ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
            ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
            ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
            ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
            ep_dict["end_pose"] = end_pose[from_idx:to_idx]
            ep_dict["start_pos"] = start_pos[from_idx:to_idx]
            ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
            torch.save(ep_dict, ep_dict_path)
        else:
            ep_dict = torch.load(ep_dict_path)

        ep_dicts.append(ep_dict)

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def to_hf_dataset(data_dict, video):
    features = {}

    if video:
        features["observation.image"] = VideoFrame()
    else:
        features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["index"] = Value(dtype="int64", id=None)
    features["episode_data_index_from"] = Value(dtype="int64", id=None)
    features["episode_data_index_to"] = Value(dtype="int64", id=None)
    # `start_pos` and `end_pos` respectively represent the positions of the end-effector
    # at the beginning and the end of the episode.
    # `gripper_width` indicates the distance between the grippers, and this value is included
    # in the state vector, which comprises the concatenation of the end-effector position
    # and gripper width.
    features["end_pose"] = Sequence(
        length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["start_pos"] = Sequence(
        length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["gripper_width"] = Sequence(
        length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
    )

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        # For umi cup in the wild: https://arxiv.org/pdf/2402.10329#table.caption.16
        fps = 10

    if not video:
        logging.warning(
            "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
        )

    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = get_default_encoding()

    return hf_dataset, episode_data_index, info
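The `meta/episode_ends` array above stores one cumulative end index per episode; a small sketch of the start/end bookkeeping with made-up values:

```python
import numpy as np

episode_ends = np.array([3, 7, 12])  # frame index where each episode ends (exclusive)

from_ids, to_ids = [], []
from_idx = 0
for to_idx in episode_ends:
    from_ids.append(from_idx)
    to_ids.append(int(to_idx))
    from_idx = int(to_idx)

print(list(zip(from_ids, to_ids)))  # [(0, 3), (3, 7), (7, 12)]
```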
@ -1,200 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""

import pickle
import shutil
from pathlib import Path

import einops
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
    calculate_episode_data_index,
    concatenate_episodes,
    get_default_encoding,
    save_images_concurrently,
)
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames


def check_format(raw_dir):
    keys = {"actions", "rewards", "dones"}
    nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}

    xarm_files = list(raw_dir.glob("*.pkl"))
    assert len(xarm_files) > 0

    with open(xarm_files[0], "rb") as f:
        dataset_dict = pickle.load(f)

    assert isinstance(dataset_dict, dict)
    assert all(k in dataset_dict for k in keys)

    # Check for consistent lengths in nested keys
    expected_len = len(dataset_dict["actions"])
    assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict)

    for key, subkeys in nested_keys.items():
        nested_dict = dataset_dict.get(key, {})
        assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)


def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    pkl_path = raw_dir / "buffer.pkl"

    with open(pkl_path, "rb") as f:
        pkl_data = pickle.load(f)

    # load data indices from which each episode starts and ends
    from_ids, to_ids = [], []
    from_idx, to_idx = 0, 0
    for done in pkl_data["dones"]:
        to_idx += 1
        if not done:
            continue
        from_ids.append(from_idx)
        to_ids.append(to_idx)
        from_idx = to_idx

    num_episodes = len(from_ids)

    ep_dicts = []
    ep_ids = episodes if episodes else range(num_episodes)
    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
        from_idx = from_ids[selected_ep_idx]
        to_idx = to_ids[selected_ep_idx]
        num_frames = to_idx - from_idx

        image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx])
        image = einops.rearrange(image, "b c h w -> b h w c")
        state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx])
        action = torch.tensor(pkl_data["actions"][from_idx:to_idx])
        # TODO(rcadene): we have a missing last frame which is the observation when the env is done
        # it is critical to have this frame for tdmpc to predict a "done observation/state"
        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
        # next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
        next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx])
        next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx])

        ep_dict = {}

        imgs_array = [x.numpy() for x in image]
        img_key = "observation.image"
        if video:
            # save png images in temporary directory
            tmp_imgs_dir = videos_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to a mp4 video
            fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
            video_path = videos_dir / fname
            encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))

            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)

            # store the reference to the video frame
            ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
        else:
            ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = state
        ep_dict["action"] = action
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        # ep_dict["next.observation.image"] = next_image
        # ep_dict["next.observation.state"] = next_state
        ep_dict["next.reward"] = next_reward
        ep_dict["next.done"] = next_done
        ep_dicts.append(ep_dict)

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def to_hf_dataset(data_dict, video):
    features = {}

    if video:
        features["observation.image"] = VideoFrame()
    else:
        features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.reward"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)
    # TODO(rcadene): add success
    # features["next.success"] = Value(dtype='bool', id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 15

    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = get_default_encoding()

    return hf_dataset, episode_data_index, info
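A quick check of the `einops.rearrange` layout conversion used in `load_from_raw` above (dummy tensor, shapes are illustrative):

```python
import einops
import torch

image = torch.zeros(8, 3, 84, 84)                      # (batch, channel, height, width)
image = einops.rearrange(image, "b c h w -> b h w c")  # channel-last, as PIL expects
print(image.shape)  # torch.Size([8, 84, 84, 3])
```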
@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import json
import logging
import subprocess
@ -29,6 +30,46 @@ from datasets.features.features import register_feature
from PIL import Image


def get_safe_default_codec():
    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"
    else:
        logging.warning(
            "'torchcodec' is not available on your platform, falling back to 'pyav' as a default decoder"
        )
        return "pyav"


def decode_video_frames(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str | None = None,
) -> torch.Tensor:
    """
    Decodes video frames using the specified backend.

    Args:
        video_path (Path): Path to the video file.
        timestamps (list[float]): List of timestamps to extract frames.
        tolerance_s (float): Allowed deviation in seconds for frame retrieval.
        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available on the platform; otherwise, defaults to "pyav".

    Returns:
        torch.Tensor: Decoded frames.

    Currently supports torchcodec on cpu and pyav.
    """
    if backend is None:
        backend = get_safe_default_codec()
    if backend == "torchcodec":
        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
    elif backend in ["pyav", "video_reader"]:
        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
    else:
        raise ValueError(f"Unsupported video backend: {backend}")


def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
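Hypothetical usage of the new `decode_video_frames` dispatcher (the path and timestamps below are made up):

```python
from lerobot.common.datasets.video_utils import decode_video_frames

frames = decode_video_frames(
    "videos/observation.image_episode_000000.mp4",
    timestamps=[0.0, 0.1, 0.2],  # query times in seconds
    tolerance_s=1e-4,            # max allowed |query - decoded| per frame
    backend=None,                # auto-select: torchcodec if installed, else pyav
)
print(frames.shape)  # (3, C, H, W), float32 in [0, 1]
```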
@ -127,6 +168,81 @@ def decode_video_frames_torchvision(
    return closest_frames


def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """

    if importlib.util.find_spec("torchcodec"):
        from torchcodec.decoders import VideoDecoder
    else:
        raise ImportError("torchcodec is required but not available.")

    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
    loaded_frames = []
    loaded_ts = []
    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # convert timestamps to frame indices
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        " It means that the closest frame that can be loaded from the video is too far away in time."
        " This might be due to synchronization issues with timestamps during data collection."
        " To be safe, we advise ignoring this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames


def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
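The timestamp-to-index mapping inside `decode_video_frames_torchcodec` is a simple rounding against the average fps; a worked example with a made-up frame rate:

```python
average_fps = 30.0
timestamps = [0.0, 0.52, 1.0]
frame_indices = [round(ts * average_fps) for ts in timestamps]
print(frame_indices)  # [0, 16, 30]
```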
@ -141,6 +257,7 @@ def encode_video_frames(
) -> None:
    """More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)
    video_path.parent.mkdir(parents=True, exist_ok=True)

    ffmpeg_args = OrderedDict(
@ -13,7 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any

import einops
import gymnasium as gym
import numpy as np
import torch
from torch import Tensor
@ -86,3 +90,38 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
        policy_features[policy_key] = feature

    return policy_features


def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
    first_type = type(env.envs[0])  # Get type of first env
    return all(type(e) is first_type for e in env.envs)  # Fast type check


def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
    with warnings.catch_warnings():
        warnings.simplefilter("once", UserWarning)  # Apply filter only in this function

        if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
            warnings.warn(
                "The environment does not have 'task_description' and 'task'. Some policies require these features.",
                UserWarning,
                stacklevel=2,
            )
        if not are_all_envs_same_type(env):
            warnings.warn(
                "The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
                UserWarning,
                stacklevel=2,
            )


def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
    """Adds task feature to the observation dict with respect to the first environment attribute."""
    if hasattr(env.envs[0], "task_description"):
        observation["task"] = env.call("task_description")
    elif hasattr(env.envs[0], "task"):
        observation["task"] = env.call("task")
    else:  # For envs without language instructions, e.g. aloha transfer cube, etc.
        num_envs = observation[list(observation.keys())[0]].shape[0]
        observation["task"] = ["" for _ in range(num_envs)]
    return observation
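A minimal sketch of how the new helpers are meant to be used around a vectorized rollout, assuming they are importable from `lerobot.common.envs.utils` (the CartPole stand-in has no task attributes, so the empty-task fallback fires):

```python
import gymnasium as gym

from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types

env = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * 4)
obs, _ = env.reset()

check_env_attributes_and_types(env)       # warns once: no 'task_description'/'task'
observation = {"observation.state": obs}  # hypothetical key naming
observation = add_envs_task(env, observation)
print(observation["task"])  # ['', '', '', '']
```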
@ -119,9 +119,7 @@ class ACTPolicy(PreTrainedPolicy):
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
            batch["observation.images"] = [batch[key] for key in self.config.image_features]

        # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
        # we are ensembling over.
@ -149,9 +147,8 @@ class ACTPolicy(PreTrainedPolicy):
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
            batch["observation.images"] = [batch[key] for key in self.config.image_features]
        batch = self.normalize_targets(batch)
        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@ -413,11 +410,10 @@ class ACT(nn.Module):
                    "actions must be provided when using the variational objective in training mode."
                )

        batch_size = (
            batch["observation.images"]
            if "observation.images" in batch
            else batch["observation.environment_state"]
        ).shape[0]
        if "observation.images" in batch:
            batch_size = batch["observation.images"][0].shape[0]
        else:
            batch_size = batch["observation.environment_state"].shape[0]

        # Prepare the latent for input to the transformer encoder.
        if self.config.use_vae and "action" in batch:
@ -490,20 +486,21 @@ class ACT(nn.Module):
        all_cam_features = []
        all_cam_pos_embeds = []

        for cam_index in range(batch["observation.images"].shape[-4]):
            cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
            # buffer
        # For a list of images, the H and W may vary but H*W is constant.
        for img in batch["observation.images"]:
            cam_features = self.backbone(img)["feature_map"]
            cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
            cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
            cam_features = self.encoder_img_feat_input_proj(cam_features)

            # Rearrange features to (sequence, batch, dim).
            cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
            cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")

            all_cam_features.append(cam_features)
            all_cam_pos_embeds.append(cam_pos_embed)
        # Concatenate camera observation feature maps and positional embeddings along the width dimension,
        # and move to (sequence, batch, dim).
        all_cam_features = torch.cat(all_cam_features, axis=-1)
        encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
        all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
        encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
        encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
        encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))

        # Stack all tokens along the sequence dimension.
        encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
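A small sketch of why the list-of-images form above works: per-camera feature maps with different h and w can share the token axis as long as h*w matches (channel count and shapes below are invented):

```python
import torch

feat_cam0 = torch.zeros(8, 512, 15, 20)  # (B, C, h, w), h*w = 300
feat_cam1 = torch.zeros(8, 512, 20, 15)  # different h and w, same h*w
tokens = [f.flatten(2).permute(2, 0, 1) for f in (feat_cam0, feat_cam1)]  # each (h*w, B, C)
print(torch.cat(tokens, dim=0).shape)  # torch.Size([600, 8, 512])
```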
@ -26,6 +26,7 @@ from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.dot.configuration_dot import DOTConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
        from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy

        return PI0Policy
    elif name == "pi0fast":
        from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy

        return PI0FASTPolicy
    elif name == "dot":
        from lerobot.common.policies.dot.modeling_dot import DOTPolicy

@ -74,6 +79,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return VQBeTConfig(**kwargs)
    elif policy_type == "pi0":
        return PI0Config(**kwargs)
    elif policy_type == "pi0fast":
        return PI0FASTConfig(**kwargs)
    elif policy_type == "dot":
        return DOTConfig(**kwargs)
    else:
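A hypothetical round trip through the factory for the new "pi0fast" entry:

```python
from lerobot.common.policies.factory import get_policy_class, make_policy_config

config = make_policy_config("pi0fast", chunk_size=10, n_action_steps=5)
policy_cls = get_policy_class("pi0fast")
print(policy_cls.name, type(config).__name__)  # pi0fast PI0FASTConfig
```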
@ -24,7 +24,7 @@ Designed by Physical Intelligence. Ported from Jax by Hugging Face.

Install pi0 extra dependencies:
```bash
pip install -e ".[pi0]"
pip install --no-binary=av -e ".[pi0]"
```

Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
@ -313,7 +313,7 @@ class PI0Policy(PreTrainedPolicy):
        state = self.prepare_state(batch)
        lang_tokens, lang_masks = self.prepare_language(batch)
        actions = self.prepare_action(batch)
        actions_is_pad = batch.get("actions_is_pad")
        actions_is_pad = batch.get("action_is_pad")

        loss_dict = {}
        losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
@ -0,0 +1,136 @@
from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
    CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


@PreTrainedConfig.register_subclass("pi0fast")
@dataclass
class PI0FASTConfig(PreTrainedConfig):
    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 10
    n_action_steps: int = 5

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded
    max_state_dim: int = 32  # 32
    max_action_dim: int = 32  # 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (224, 224)
    interpolate_like_pi: bool = False

    # Add empty images. Used by pi0_aloha_sim which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False

    # Tokenizer
    tokenizer_max_length: int = 48

    # Projector
    proj_width: int = 1024

    # Decoding
    max_decoding_steps: int = 256
    fast_skip_tokens: int = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens
    max_input_seq_len: int = 256  # 512

    # Utils
    use_cache: bool = True

    # Frozen parameters
    freeze_vision_encoder: bool = True
    freeze_lm_head: bool = True

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-5

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    checkpoint_path: str = None

    padding_side: str = "right"

    precision: str = "bfloat16"
    grad_clip_norm: float = 1

    # Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
    # In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
    relaxed_action_decoding: bool = True

    def __post_init__(self):
        super().__post_init__()

        """Input validation (not exhaustive)."""
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )

    def validate_features(self) -> None:
        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.grad_clip_norm,
        )

    def get_scheduler_preset(self):
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list:
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
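The `__post_init__` guard above in action; requesting more action steps than the chunk size raises (values are arbitrary):

```python
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig

try:
    PI0FASTConfig(chunk_size=10, n_action_steps=20)
except ValueError as e:
    print(e)  # The chunk size is the upper bound for the number of action steps ...
```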
@ -0,0 +1,973 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
|
||||||
|
|
||||||
|
[Paper](https://arxiv.org/abs/2501.09747)
|
||||||
|
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||||
|
|
||||||
|
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||||
|
|
||||||
|
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
--policy.path=lerobot/pi0fast_base \
|
||||||
|
--dataset.repo_id=danaaubakirova/koch_test
|
||||||
|
```
|
||||||
|
|
||||||
|
Example of training the pi0+FAST neural network with from scratch:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
--policy.type=pi0fast \
|
||||||
|
--dataset.repo_id=danaaubakirova/koch_test
|
||||||
|
```
|
||||||
|
|
||||||
|
Example of using the pi0 pretrained model outside LeRobot training framework:
|
||||||
|
```python
|
||||||
|
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from PIL import Image
|
||||||
|
from scipy.fft import idct
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
|
||||||
|
from transformers.cache_utils import HybridCache, StaticCache
|
||||||
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
|
|
||||||
|
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||||
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
|
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||||
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
|
PRECISION = {
|
||||||
|
"float16": torch.float16,
|
||||||
|
"float32": torch.float32,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(x, min_val, max_val):
|
||||||
|
return (x - min_val) / (max_val - min_val)
|
||||||
|
|
||||||
|
|
||||||
|
def unnormalize(x, min_val, max_val):
|
||||||
|
return x * (max_val - min_val) + min_val
|
||||||
|
|
||||||
|
|
||||||
|
def safe_arcsin(value):
|
||||||
|
# This ensures that the input stays within
|
||||||
|
# [−1,1] to avoid invalid values for arcsin
|
||||||
|
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||||
|
|
||||||
|
|
||||||
|
def aloha_gripper_to_angular(value):
|
||||||
|
# Aloha transforms the gripper positions into a linear space. The following code
|
||||||
|
# reverses this transformation to be consistent with pi0 which is pretrained in
|
||||||
|
# angular space.
|
||||||
|
#
|
||||||
|
# These values are coming from the Aloha code:
|
||||||
|
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||||
|
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||||
|
|
||||||
|
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||||
|
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||||
|
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||||
|
return safe_arcsin(value)
|
||||||
|
|
||||||
|
# The constants are taken from the Interbotix code.
|
||||||
|
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||||
|
|
||||||
|
# Normalize to [0, 1].
|
||||||
|
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||||
|
return normalize(value, min_val=0.4, max_val=1.5)
|
||||||
|
|
||||||
|
|
||||||
|
def aloha_gripper_from_angular(value):
|
||||||
|
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||||
|
# Note that the units are still angular but the range is different.
|
||||||
|
|
||||||
|
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||||
|
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||||
|
|
||||||
|
# These values are coming from the Aloha code:
|
||||||
|
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||||
|
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||||
|
|
||||||
|
|
||||||
|
def aloha_gripper_from_angular_inv(value):
|
||||||
|
# Directly inverts the gripper_from_angular function.
|
||||||
|
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||||
|
return normalize(value, min_val=0.4, max_val=1.5)
|
||||||
|
|
||||||
|
|
||||||
|
class PI0FASTPolicy(PreTrainedPolicy):
|
||||||
|
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
|
||||||
|
|
||||||
|
config_class = PI0FASTConfig
|
||||||
|
name = "pi0fast"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PI0FASTConfig,
|
||||||
|
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||||
|
the configuration class is used.
|
||||||
|
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||||
|
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__(config)
|
||||||
|
config.validate_features()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||||
|
self.normalize_targets = Normalize(
|
||||||
|
config.output_features, config.normalization_mapping, dataset_stats
|
||||||
|
)
|
||||||
|
self.unnormalize_outputs = Unnormalize(
|
||||||
|
config.output_features, config.normalization_mapping, dataset_stats
|
||||||
|
)
|
||||||
|
|
||||||
|
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||||
|
self.model = PI0FAST(config)
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""This should be called whenever the environment is reset."""
|
||||||
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
|
def get_optim_params(self) -> dict:
|
||||||
|
return self.parameters()
|
||||||
|
|
||||||
|
def _pi_aloha_decode_state(self, state):
|
||||||
|
# Flip the joints.
|
||||||
|
for motor_idx in [1, 2, 8, 9]:
|
||||||
|
state[:, motor_idx] *= -1
|
||||||
|
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||||
|
for motor_idx in [6, 13]:
|
||||||
|
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||||
|
return state
|
||||||
|
|
||||||
|
def _pi_aloha_encode_actions(self, actions):
|
||||||
|
# Flip the joints.
|
||||||
|
for motor_idx in [1, 2, 8, 9]:
|
||||||
|
actions[:, :, motor_idx] *= -1
|
||||||
|
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||||
|
for motor_idx in [6, 13]:
|
||||||
|
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def _pi_aloha_encode_actions_inv(self, actions):
|
||||||
|
# Flip the joints again.
|
||||||
|
for motor_idx in [1, 2, 8, 9]:
|
||||||
|
actions[:, :, motor_idx] *= -1
|
||||||
|
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||||
|
for motor_idx in [6, 13]:
|
||||||
|
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||||
|
return actions
|
||||||
|
|
||||||
|
@torch.no_grad
|
||||||
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
|
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||||
|
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||||
|
queue is empty.
|
||||||
|
"""
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
if self.config.adapt_to_pi_aloha:
|
||||||
|
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||||
|
|
||||||
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
|
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||||
|
# querying the policy.
|
||||||
|
if len(self._action_queue) == 0:
|
||||||
|
actions = self.model.generate_actions(batch)
|
||||||
|
|
||||||
|
actions = actions[:, : self.config.n_action_steps]
|
||||||
|
|
||||||
|
original_action_dim = self.config.action_feature.shape[
|
||||||
|
0
|
||||||
|
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||||
|
actions = actions[:, :, :original_action_dim]
|
||||||
|
|
||||||
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
|
||||||
|
if self.config.adapt_to_pi_aloha:
|
||||||
|
actions = self._pi_aloha_encode_actions(actions)
|
||||||
|
|
||||||
|
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||||
|
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||||
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
if self.config.adapt_to_pi_aloha:
|
||||||
|
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||||
|
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||||
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch = self.normalize_targets(batch)
|
||||||
|
loss_dict = self.model.forward(batch)
|
||||||
|
return loss_dict["loss"], loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
def block_causal_update_causal_mask(
    attention_mask,
    token_type_ids=None,
    past_key_values=None,
    cache_position=None,
    input_tensor=None,
    attn_implementation: str = "eager",
    dtype: torch.dtype = torch.float32,
):
    """
    Update the causal mask during training and generation. It can be customized for different attention masks.
    """
    if attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None
    using_static_cache = isinstance(past_key_values, StaticCache)
    min_dtype = torch.finfo(dtype).min

    if input_tensor is None:
        input_tensor = attention_mask

    inputs_lead_dim, sequence_length = input_tensor.shape[:2]

    if using_static_cache or isinstance(past_key_values, HybridCache):
        target_length = past_key_values.get_max_cache_shape()
    else:
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else cache_position[0] + sequence_length + 1
        )

    # Handle precomputed attention masks
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    # Causal mask initialization
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )

    # Standard causal masking (triu ensures tokens can only attend to the past)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)

    # Apply block causal mask
    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(causal_mask.device).bool()
        cumsum = torch.cumsum(token_type_ids, dim=1)
        block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]

        # Combine causal_mask with the block-wise attention mask
        causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
        causal_mask = causal_mask[:, None, :, :]
    else:
        # Apply past cache position constraint
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
            -1, 1
        )
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # Copy to contiguous memory for in-place edits
        mask_length = attention_mask.shape[-1]

        # Apply padding mask
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
            causal_mask.device
        )
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask
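To see what the `cumsum` trick in the `token_type_ids` branch produces, here is a toy example (values chosen for illustration): prefix tokens (type 0) attend bidirectionally among themselves, while suffix tokens (type 1) attend to the full prefix plus earlier suffix tokens.

```python
import torch

token_type_ids = torch.tensor([[0, 0, 1, 1]])  # 2 prefix tokens, 2 suffix tokens
cumsum = torch.cumsum(token_type_ids, dim=1)   # tensor([[0, 0, 1, 2]])
allowed = cumsum[:, None, :] <= cumsum[:, :, None]  # allowed[batch, query, key]
print(allowed[0].int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
```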


def prepare_inputs_for_generation(
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    pixel_values=None,
    attention_mask=None,
    token_type_ids=None,
    use_cache=True,
    num_logits_to_keep=None,
    labels=None,
    self=None,  # bound to the PaliGemma model instance via functools.partial (see PI0FAST.__init__)
    **kwargs,
):
    # Create the block causal attention mask
    if cache_position[0] > 0 and input_ids.shape[1] > 0:
        input_tensor = input_ids[:, -1:]
        new_positions = (
            torch.ones(
                (position_ids.shape[0], input_ids.shape[1]),
                dtype=position_ids.dtype,
                device=position_ids.device,
            ).cumsum(-1)
            + position_ids[:, -1:]
        )
        position_ids = torch.cat([position_ids, new_positions], dim=-1)
    else:
        input_tensor = inputs_embeds
    attention_mask = block_causal_update_causal_mask(
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        cache_position=cache_position,
        input_tensor=input_tensor,
        token_type_ids=token_type_ids,
        dtype=self.dtype,
        attn_implementation=self.config.text_config._attn_implementation,
    )
    # Overwritten -- custom `position_ids` and `pixel_values` handling
    model_inputs = self.language_model.prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cache_position=cache_position,
        use_cache=use_cache,
        num_logits_to_keep=num_logits_to_keep,
        token_type_ids=token_type_ids,
        **kwargs,
    )

    # Position ids in PaliGemma are 1-indexed
    if model_inputs.get("position_ids") is not None:
        model_inputs["position_ids"] += 1
    # In the cached decoding stage, pixel values should be None because the input ids no longer contain
    # the special image token. Otherwise pixel values must be passed to the model.
    # NOTE: use_cache=False always needs pixel_values.
    if cache_position[0] == 0:
        model_inputs["pixel_values"] = pixel_values
    is_training = token_type_ids is not None and labels is not None
    if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
        input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
        )
        model_inputs["attention_mask"] = causal_mask

    return model_inputs


class PI0FAST(nn.Module):
    def __init__(self, config: PI0FASTConfig):
        super().__init__()
        self.config = config

        # TODO: move tokenizers into the Policy
        fast_tokenizer_path = "physical-intelligence/fast"
        pi0_paligemma_path = "google/paligemma-3b-pt-224"
        self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
        self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
        self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
        self.fast_skip_tokens = self.config.fast_skip_tokens
        self.max_input_seq_len = self.config.max_input_seq_len
        self.action_horizon = self.config.chunk_size
        # Dimensionality of the real action space (not config.max_action_dim).
        self.action_dim = self.config.action_feature.shape[0]
        precision = config.precision
        torch_precision = PRECISION.get(precision, torch.float32)
        self.pad_token_id = (
            self.paligemma_tokenizer.pad_token_id
            if hasattr(self.paligemma_tokenizer, "pad_token_id")
            else self.paligemma_tokenizer.eos_token_id
        )

        paligemma_config = CONFIG_MAPPING["paligemma"](
            transformers_version="4.48.1",
            _vocab_size=257152,
            bos_token_id=2,
            eos_token_id=1,
            hidden_size=2048,
            image_token_index=257152,
            model_type="paligemma",
            pad_token_id=0,
            projection_dim=2048,
            text_config={
                "hidden_activation": "gelu_pytorch_tanh",
                "hidden_size": 2048,
                "intermediate_size": 16384,
                "model_type": "gemma",
                "num_attention_heads": 8,
                "num_hidden_layers": 18,
                "num_image_tokens": 256,
                "num_key_value_heads": 1,
                "torch_dtype": precision,
                "vocab_size": 257152,
                "_attn_implementation": "eager",
            },
            vision_config={
                "hidden_size": 1152,
                "intermediate_size": 4304,
                "model_type": "siglip_vision_model",
                "num_attention_heads": 16,
                "num_hidden_layers": 27,
                "num_image_tokens": 256,
                "patch_size": 14,
                "projection_dim": 2048,
                "projector_hidden_act": "gelu_pytorch_tanh",
                "torch_dtype": precision,
                "vision_use_head": False,
            },
        )
        self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)

        # Bind the module-level `prepare_inputs_for_generation` (with its block-causal mask) to the model.
        self.pi0_paligemma.prepare_inputs_for_generation = partial(
            prepare_inputs_for_generation, self=self.pi0_paligemma
        )
        # Cast the main submodules to the requested precision (e.g. bfloat16).
        params_to_change_dtype = [
            "language_model",
            "vision_tower",
            "multi_modal",
        ]
        for name, param in self.pi0_paligemma.named_parameters():
            if any(selector in name for selector in params_to_change_dtype):
                param.data = param.data.to(dtype=torch_precision)
        self.set_requires_grad()
        self.image_keys = self.config.image_features.keys()
        self.ignore_index = self.pi0_paligemma.config.ignore_index
        self.padding_side = self.config.padding_side

    def set_requires_grad(self):
        if self.config.freeze_vision_encoder:
            self.pi0_paligemma.vision_tower.eval()
            for params in self.pi0_paligemma.vision_tower.parameters():
                params.requires_grad = False
        # To avoid the unused-params issue with distributed training
        if self.config.freeze_lm_head:
            for name, params in self.pi0_paligemma.named_parameters():
                if "embed_tokens" in name:  # lm heads and the embedding layer are tied
                    params.requires_grad = False

    def embed_tokens(self, tokens: torch.Tensor):
        return self.pi0_paligemma.language_model.model.embed_tokens(tokens)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)

    def prepare_images(self, batch):
        """Preprocess a LeRobot batch into Pi0 inputs."""
        images = []
        img_masks = []
        present_img_keys = [key for key in self.image_keys if key in batch]
        if len(present_img_keys) == 0:
            raise ValueError(
                f"All image features are missing from the batch. At least one expected. "
                f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
            )

        # Preprocess image features present in the batch
        num_empty_cameras = 0
        for key in self.image_keys:
            if key in present_img_keys:
                img = batch[key]

                if self.config.resize_imgs_with_padding is not None:
                    img = resize_with_pad(
                        img,
                        *self.config.resize_imgs_with_padding,
                        pad_value=0,
                        interpolate_like_pi=self.config.interpolate_like_pi,
                    )

                # Normalize from range [0, 1] to [-1, 1] as expected by SigLIP
                img = img * 2.0 - 1.0

                bsize = img.shape[0]
                device = img.device
                mask = torch.ones(bsize, dtype=torch.bool, device=device)
            else:
                if num_empty_cameras >= self.config.empty_cameras:
                    continue
                # Reuse the shape of the last processed camera to build an all -1 placeholder image.
                img = torch.ones_like(img) * -1
                bsize = img.shape[0]
                device = img.device
                mask = torch.ones(bsize, dtype=torch.bool, device=device)
                num_empty_cameras += 1

            images.append(img)
            img_masks.append(mask)
        return images, img_masks

    def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
        # Per-sample min-max normalization over the (time, action_dim) block into [-1, 1].
        mins = actions.amin(dim=(1, 2), keepdim=True)
        maxs = actions.amax(dim=(1, 2), keepdim=True)
        return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1

    def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
        return out
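The reflection above maps FAST token ids into the unused tail of the PaliGemma vocabulary, and it is its own inverse, which is why the same helper can be reused when decoding. With made-up numbers:

```python
vocab_size = 257152      # PaliGemma tokenizer vocabulary size
fast_skip_tokens = 128   # assumed value of config.fast_skip_tokens

def act_to_paligemma(tok: int) -> int:
    return vocab_size - 1 - fast_skip_tokens - tok

assert act_to_paligemma(0) == 257023                 # FAST token 0 -> last usable id
assert act_to_paligemma(act_to_paligemma(42)) == 42  # involution: applying twice decodes
```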

    def fast_tokenizer_wrapper(self, actions_norm):
        """
        A wrapper for self.fast_tokenizer that ensures batch processing,
        conversion to PyTorch tensors, and returns a dictionary without padding.
        """
        batch_tokens = self.fast_tokenizer(actions_norm)
        fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")

        return fast_out

    def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
        # Cumulative sum over the non-pad mask: everything past `prefix_len` is the suffix block.
        cumsum_mask = (padded_mask != 0).cumsum(dim=1)
        token_type_ids = cumsum_mask > prefix_len
        return token_type_ids

    def create_input_tokens(self, state, lang_text, actions=None):
        bsize = state.shape[0]
        device = state.device
        # Discretize the state into 256 uniform bins over [-1, 1].
        bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
        discretized = torch.bucketize(state, bins) - 1
        discretized = discretized[:, :32]

        prefix_texts = []
        state_text = []
        for txt, disc in zip(lang_text, discretized, strict=False):
            cleaned = txt.lower().strip().replace("_", " ")
            state_str = " ".join(str(val.item()) for val in disc)
            prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
            state_text.append(f"State: {state_str};\n")

        prefix_out = self.paligemma_tokenizer(
            prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
        )
        prefix_ids = prefix_out["input_ids"].to(device)
        prefix_mask = prefix_out["attention_mask"].to(device)
        prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()

        if actions is not None:
            actions_norm = self.normalize_actions(actions)
            actions_pad = F.pad(
                actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
            )[:, :, : self.config.max_action_dim]
            fast_out = self.fast_tokenizer_wrapper(
                actions_pad.cpu(),
            )
            act_ids = fast_out["input_ids"]
            act_mask = fast_out["attention_mask"].to(device)

            act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
            # Replace the remapped FAST padding token (FAST token 0) with the PaliGemma pad token
            act_ids = torch.where(
                act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
                self.pad_token_id,
                act_ids,
            )

            eos_token = torch.tensor(
                [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
            ).expand(bsize, -1)
            eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
            bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
            bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
            bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
            act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
            act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
            act_mask = act_mask.to(device)
        else:
            # No actions at inference time: empty suffix.
            act_ids = torch.empty(bsize, 0, dtype=torch.long, device=device)
            act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
        final_ids = torch.cat([prefix_ids, act_ids], dim=1)

        final_mask = torch.cat([prefix_mask, act_mask], dim=1)
        batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}

        # Use the tokenizer pad function
        padded_output = self.paligemma_tokenizer.pad(
            batch_inputs, padding="longest", max_length=180, return_tensors="pt"
        )
        padded_mask = padded_output["attention_mask"]

        # Mark everything after the prefix (i.e. the action suffix) for autoregressive attention.
        att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens

        token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)

        padded_output["padded_mask"] = padded_output.pop("attention_mask")
        padded_output["attention_mask"] = att_mask
        # The loss is computed neither on the prefix nor on padding
        padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
        padded_output["token_type_ids"] = token_type_ids
        return padded_output
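For reference, the prompt a single sample ends up with looks roughly like the following (state values made up): the prefix carries the task string and the discretized state, and the suffix holds the remapped FAST action tokens between an "Action: " marker and EOS.

```python
task = "stack the red block"
state_str = "112 87 203 54"  # state bucketized into [0, 255] bins
prefix = f"Task: {task}, State: {state_str};\n"
# Training suffix (teacher forcing): "Action: " + remapped FAST token ids + <eos>
print(prefix + "Action: <act_1> ... <act_k> <eos>")
```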

    def shift_padding_side(
        self,
        tokens: torch.Tensor,
        ar_mask: torch.Tensor,
        padding_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        targets: torch.Tensor,
        token_type_ids: torch.Tensor,
        padding_side: str = "right",
    ) -> tuple[torch.Tensor, ...]:
        if padding_side not in ["right", "left"]:
            return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids

        new_tokens = torch.empty_like(tokens)
        new_ar_masks = torch.empty_like(ar_mask)
        new_padding_mask = torch.empty_like(padding_mask)
        new_loss_mask = torch.empty_like(loss_mask)
        new_targets = torch.empty_like(targets)
        new_token_type_ids = torch.empty_like(token_type_ids)
        batch_size = tokens.shape[0]
        for i in range(batch_size):
            padding_indices = torch.where(padding_mask[i] == 0)[0]
            non_padding_indices = torch.where(padding_mask[i] == 1)[0]
            if padding_side == "left":
                new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
            else:
                new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
            new_tokens[i] = tokens[i].index_select(0, new_indices)
            new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
            new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
            new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
            new_targets[i] = targets[i].index_select(0, new_indices)
            new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)

        return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids

    def forward(self, batch: dict[str, Tensor]):
        device = batch[OBS_ROBOT].device
        # TODO: keep like this or move to the policy .forward
        images, img_masks = self.prepare_images(batch)

        padded_outs = self.create_input_tokens(
            state=batch[OBS_ROBOT],
            lang_text=batch["task"],
            actions=batch[ACTION],
        )

        embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
            images,
            img_masks,
            padded_outs["input_ids"],
            padded_outs["padded_mask"],
            padded_outs["attention_mask"],
            padded_outs["loss_mask"],
            padded_outs["token_type_ids"],
            padding_side=self.padding_side,
        )
        position_ids = torch.cumsum(pad_masks, dim=1) - 1
        token_type_ids = token_type_ids.to(dtype=torch.int64)
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
        pad_masks = block_causal_update_causal_mask(
            attention_mask=pad_masks,
            past_key_values=None,
            cache_position=cache_position,
            input_tensor=embs,
            token_type_ids=token_type_ids,
            dtype=self.pi0_paligemma.dtype,
            attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
        )
        outputs = self.pi0_paligemma.forward(
            input_ids=None,
            token_type_ids=None,
            attention_mask=pad_masks,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=embs,
            use_cache=False,
            labels=None,
        )

        logits = outputs.logits

        loss_fct = nn.CrossEntropyLoss(reduction="none")

        # Shift left for next-step prediction
        logits = logits[:, :-1, :]
        targets = targets[:, 1:].to(device)  # Shift targets
        loss_mask = loss_mask[:, 1:].to(device)  # Ensure correct shape

        # Compute per-token loss
        token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))

        # Apply loss mask
        token_loss = token_loss * loss_mask.reshape(-1)

        # Compute final loss
        loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)

        # Return loss dictionary
        loss_dict = {"ce_loss": loss.item(), "loss": loss}
        return loss_dict
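The loss computation above is a standard masked next-token cross-entropy; in isolation it amounts to the following sketch (shapes are illustrative):

```python
import torch
from torch import nn

vocab = 32
logits = torch.randn(2, 9, vocab)          # (batch, seq_len - 1, vocab) after the left shift
targets = torch.randint(0, vocab, (2, 9))  # targets shifted right by one
loss_mask = torch.randint(0, 2, (2, 9))    # 1 on action tokens, 0 on prefix/padding

per_token = nn.CrossEntropyLoss(reduction="none")(
    logits.reshape(-1, vocab), targets.reshape(-1)
)
loss = (per_token * loss_mask.reshape(-1)).sum() / loss_mask.sum().clamp(min=1)
```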

    def decode_actions_with_fast(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
        relaxed_decoding: bool = True,
    ) -> np.ndarray:
        """
        Adapts the original decoding in FAST to always return actions instead of zeros.
        """
        self.time_horizon = (
            time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
        )
        self.action_dim = (
            action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
        )

        # Cache the time horizon and action dimension for the next call
        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim

        assert self.time_horizon is not None and self.action_dim is not None, (
            "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
        )

        decoded_actions = []
        for token in tokens:
            try:
                decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
                if relaxed_decoding:
                    # Expected sequence length
                    expected_seq_len = self.time_horizon * self.action_dim
                    diff = expected_seq_len - decoded_dct_coeff.shape[0]
                    # Truncate on the right if too long
                    if diff < 0:
                        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]
                    # Pad with zeros if too short
                    elif diff > 0:
                        decoded_dct_coeff = np.pad(
                            decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
                        )

                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
                assert decoded_dct_coeff.shape == (
                    self.time_horizon,
                    self.action_dim,
                ), (
                    f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
                )
            except Exception as e:
                print(f"Error decoding tokens: {e}")
                print(f"Tokens: {token}")
                decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
            decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
        return np.stack(decoded_actions)
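FAST encodes an action chunk as quantized DCT coefficients, so decoding ends with an inverse DCT along the time axis. A minimal round trip (the tokenizer's scale factor is left out for brevity):

```python
import numpy as np
from scipy.fft import dct, idct

actions = np.random.randn(10, 7)             # (time_horizon, action_dim)
coeffs = dct(actions, axis=0, norm="ortho")  # forward transform, per action dimension
recovered = idct(coeffs, axis=0, norm="ortho")
assert np.allclose(actions, recovered)
```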

    def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
        """
        Extracts actions from predicted output tokens using the FAST model.

        Args:
            tokens (torch.Tensor): The input tensor of tokenized outputs.
            action_horizon (int): The number of timesteps for actions.
            action_dim (int): The dimensionality of each action.

        Returns:
            torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
        """
        # Decode predicted output tokens
        decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
        # Strip the "Action: " marker and anything after the first "|" separator.
        cleaned_tokens = [
            tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
            for tokens_sequence in decoded_tokens
        ]
        raw_action_tokens = [
            self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
            for sample_tokens in cleaned_tokens
        ]
        # Map PaliGemma token ids back to FAST action token ids (the mapping is an involution).
        action_tokens = [
            self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
        ]
        # Returns the tensor of decoded actions per sample in a list
        decoded_actions = [
            torch.tensor(
                self.decode_actions_with_fast(
                    tok.tolist(),
                    time_horizon=action_horizon,
                    action_dim=action_dim,
                    relaxed_decoding=self.config.relaxed_action_decoding,
                ),
                device=tokens.device,
            ).squeeze(0)
            for tok in action_tokens
        ]

        return torch.stack(decoded_actions, dim=0)

    def generate_actions(self, batch: dict[str, Tensor]):
        # TODO: keep like this or move to the policy .forward
        images, img_masks = self.prepare_images(batch)

        padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
        embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
            images,
            img_masks,
            padded_outs["input_ids"],
            padded_outs["padded_mask"],
            padded_outs["attention_mask"],
            padded_outs["loss_mask"],
            padded_outs["token_type_ids"],
            padding_side="left",
        )
        token_type_ids = token_type_ids.to(dtype=torch.int64)
        prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
        output_tokens = self.pi0_paligemma.generate(
            input_ids=None,
            attention_mask=pad_masks,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=embs,
            use_cache=self.config.use_cache,
            max_new_tokens=self.config.max_decoding_steps,
            do_sample=False,
            num_beams=1,
            token_type_ids=token_type_ids,
        )
        actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
        return actions

    def embed_image(self, image: torch.Tensor):
        return self.pi0_paligemma.get_image_features(image)

    def embed_inputs(
        self,
        images,
        img_masks,
        tokens,
        pad_mask,
        ar_mask,
        loss_mask,
        token_type_ids,
        padding_side: str = "right",
    ):
        # TODO: avoid lists and torch.cat; prefer pre-allocation with torch.empty
        # All images in the list have the same size, so everything can be vectorized.
        device = images[0].device
        image_embedding_dim = images[0].shape[-1]  # TODO: should come from self.config
        all_images = torch.stack(images, dim=1).to(device)
        b, n, c, h, w = all_images.shape
        all_images = all_images.view(b * n, c, h, w)
        embedded = self.embed_image(all_images).to(device)
        b_n, p, image_embedding_dim = embedded.shape  # Extract current dimensions
        m = b_n // b  # Compute the number of images per sample dynamically

        # Reshape dynamically
        embedded = embedded.view(b, m, p, image_embedding_dim)
        tokens_embs = self.embed_tokens(tokens.to(device))

        img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
        num_img_emb = embedded.shape[2]
        img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
        img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)

        image_target_tokens = (
            torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
        ).reshape(b, -1)
        image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)

        embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim)  # Shape: (B, N*P, D)

        embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
        pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
        att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
        loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
        targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
        token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)

        # Shift pad tokens to the left (.generate()) or right (.train())
        embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
            embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
        )

        targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
        return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids


def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
    # Assumed to be a no-op when the image already matches the target width/height.
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")

    cur_height, cur_width = img.shape[2:]

    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    if interpolate_like_pi:
        # Resize with PIL on CPU to mirror the original Pi implementation's interpolation.
        img = (img * 255.0).to(dtype=torch.uint8)
        img = img.permute(0, 2, 3, 1)
        original_device = img.device
        img = img.to(device="cpu").numpy()
        imgs = []
        for sub_img in img:
            sub_img = Image.fromarray(sub_img)
            resized_img = sub_img.resize((resized_width, resized_height), resample=2)  # 2 = PIL bilinear
            resized_img = torch.from_numpy(np.array(resized_img))
            imgs.append(resized_img)
        img = torch.stack(imgs, dim=0)
        img = img.permute(0, 3, 1, 2)
        resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
    else:
        resized_img = F.interpolate(
            img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
        )

    pad_height = max(0, int(height - resized_height))
    pad_width = max(0, int(width - resized_width))

    # Pad on the left and top of the image
    padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
    return padded_img
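Example call (shapes illustrative, assuming the function above is in scope): a (b, c, h, w) batch in [0, 1] is scaled to fit inside the target while preserving aspect ratio, then zero-padded on the left/top.

```python
import torch

imgs = torch.rand(2, 3, 480, 640)
out = resize_with_pad(imgs, width=224, height=224, pad_value=0)
assert out.shape == (2, 3, 224, 224)
```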

@@ -122,7 +122,7 @@ class TDMPCPolicy(PreTrainedPolicy):

         # When the action queue is depleted, populate it again by querying the policy.
         if len(self._queues["action"]) == 0:
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}

             # Remove the time dimensions as it is not handled yet.
             for key in batch:

@@ -474,7 +474,7 @@ class ManipulatorRobot:
                 # Used when record_data=True
                 follower_goal_pos[name] = goal_pos

-                goal_pos = goal_pos.numpy().astype(np.int32)
+                goal_pos = goal_pos.numpy().astype(np.float32)
                 self.follower_arms[name].write("Goal_Position", goal_pos)
                 self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t

@@ -596,7 +596,7 @@ class ManipulatorRobot:
             action_sent.append(goal_pos)

             # Send goal position to each follower
-            goal_pos = goal_pos.numpy().astype(np.int32)
+            goal_pos = goal_pos.numpy().astype(np.float32)
             self.follower_arms[name].write("Goal_Position", goal_pos)

         return torch.cat(action_sent)

@@ -69,7 +69,13 @@ class WandBLogger:
         os.environ["WANDB_SILENT"] = "True"
         import wandb

-        wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
+        wandb_run_id = (
+            cfg.wandb.run_id
+            if cfg.wandb.run_id
+            else get_wandb_run_id_from_filesystem(self.log_dir)
+            if cfg.resume
+            else None
+        )
         wandb.init(
             id=wandb_run_id,
             project=self.cfg.project,

@@ -84,6 +90,7 @@ class WandBLogger:
             # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
             job_type="train_eval",
             resume="must" if cfg.resume else None,
+            mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
         )
         print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
         logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")

@@ -20,6 +20,7 @@ from lerobot.common import (
     policies,  # noqa: F401
 )
 from lerobot.common.datasets.transforms import ImageTransformsConfig
+from lerobot.common.datasets.video_utils import get_safe_default_codec


 @dataclass

@@ -35,7 +36,7 @@ class DatasetConfig:
     image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
     revision: str | None = None
     use_imagenet_stats: bool = True
-    video_backend: str = "pyav"
+    video_backend: str = field(default_factory=get_safe_default_codec)


 @dataclass
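With this change the dataset default is resolved at runtime instead of hardcoding `pyav`. A rough sketch of what `get_safe_default_codec` could look like (the actual implementation lives in `lerobot/common/datasets/video_utils.py` and may differ; this is an assumption, not a copy):

```python
import importlib.util

def get_safe_default_codec() -> str:
    # Prefer torchcodec where it can be installed (see the platform markers
    # added to pyproject.toml below); otherwise fall back to pyav.
    if importlib.util.find_spec("torchcodec") is not None:
        return "torchcodec"
    return "pyav"
```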

@@ -46,6 +47,8 @@ class WandBConfig:
     project: str = "lerobot"
     entity: str | None = None
     notes: str | None = None
+    run_id: str | None = None
+    mode: str | None = None  # Allowed values: 'online', 'offline', 'disabled'. Defaults to 'online'


 @dataclass

@@ -79,7 +79,9 @@ class TrainPipelineConfig(HubMixin):
             # The entire train config is already loaded, we just need to get the checkpoint dir
             config_path = parser.parse_arg("config_path")
             if not config_path:
-                raise ValueError("A config_path is expected when resuming a run.")
+                raise ValueError(
+                    f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
+                )
             if not Path(config_path).resolve().exists():
                 raise NotADirectoryError(
                     f"{config_path=} is expected to be a local path. "

@@ -66,7 +66,7 @@ from torch import Tensor, nn
 from tqdm import trange

 from lerobot.common.envs.factory import make_env
-from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.utils import get_device_from_parameters

@@ -124,7 +124,6 @@ def rollout(

     # Reset the policy and environments.
     policy.reset()
-
     observation, info = env.reset(seed=seeds)
     if render_callback is not None:
         render_callback(env)

@@ -145,6 +144,7 @@ def rollout(
         disable=inside_slurm(),  # we dont want progress bar when we use slurm, since it clutters the logs
         leave=False,
     )
+    check_env_attributes_and_types(env)
     while not np.all(done):
         # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
         observation = preprocess_observation(observation)

@@ -155,6 +155,10 @@ def rollout(
             key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
         }

+        # Infer "task" from attributes of environments.
+        # TODO: works with SyncVectorEnv but not AsyncVectorEnv
+        observation = add_envs_task(env, observation)
+
         with torch.inference_mode():
             action = policy.select_action(observation)

@@ -1,364 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
-or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
-installation of neural net specific packages like pytorch, tensorflow, jax.
-
-Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
-```
-python lerobot/scripts/push_dataset_to_hub.py \
---raw-dir data/pusht_raw \
---raw-format pusht_zarr \
---repo-id lerobot/pusht
-
-python lerobot/scripts/push_dataset_to_hub.py \
---raw-dir data/xarm_lift_medium_raw \
---raw-format xarm_pkl \
---repo-id lerobot/xarm_lift_medium
-
-python lerobot/scripts/push_dataset_to_hub.py \
---raw-dir data/aloha_sim_insertion_scripted_raw \
---raw-format aloha_hdf5 \
---repo-id lerobot/aloha_sim_insertion_scripted
-
-python lerobot/scripts/push_dataset_to_hub.py \
---raw-dir data/umi_cup_in_the_wild_raw \
---raw-format umi_zarr \
---repo-id lerobot/umi_cup_in_the_wild
-```
-"""
-
-import argparse
-import json
-import shutil
-import warnings
-from pathlib import Path
-from typing import Any
-
-import torch
-from huggingface_hub import HfApi
-from safetensors.torch import save_file
-
-from lerobot.common.datasets.compute_stats import compute_stats
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
-from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
-from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
-
-
-def get_from_raw_to_lerobot_format_fn(raw_format: str):
-    if raw_format == "pusht_zarr":
-        from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
-    elif raw_format == "umi_zarr":
-        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
-    elif raw_format == "aloha_hdf5":
-        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
-    elif raw_format in ["rlds", "openx"]:
-        from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
-    elif raw_format == "dora_parquet":
-        from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
-    elif raw_format == "xarm_pkl":
-        from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
-    elif raw_format == "cam_png":
-        from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
-    else:
-        raise ValueError(
-            f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
-        )
-
-    return from_raw_to_lerobot_format
-
-
-def save_meta_data(
-    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
-):
-    meta_data_dir.mkdir(parents=True, exist_ok=True)
-
-    # save info
-    info_path = meta_data_dir / "info.json"
-    with open(str(info_path), "w") as f:
-        json.dump(info, f, indent=4)
-
-    # save stats
-    stats_path = meta_data_dir / "stats.safetensors"
-    save_file(flatten_dict(stats), stats_path)
-
-    # save episode_data_index
-    episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
-    ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
-    save_file(episode_data_index, ep_data_idx_path)
-
-
-def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
-    """Expect all meta data files to be all stored in a single "meta_data" directory.
-    On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
-    """
-    api = HfApi()
-    api.upload_folder(
-        folder_path=meta_data_dir,
-        path_in_repo="meta_data",
-        repo_id=repo_id,
-        revision=revision,
-        repo_type="dataset",
-    )
-
-
-def push_dataset_card_to_hub(
-    repo_id: str,
-    revision: str | None,
-    tags: list | None = None,
-    license: str = "apache-2.0",
-    **card_kwargs,
-):
-    """Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
-    card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
-    card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
-
-
-def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
-    """Expect mp4 files to be all stored in a single "videos" directory.
-    On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
-    """
-    api = HfApi()
-    api.upload_folder(
-        folder_path=videos_dir,
-        path_in_repo="videos",
-        repo_id=repo_id,
-        revision=revision,
-        repo_type="dataset",
-        allow_patterns="*.mp4",
-    )
-
-
-def push_dataset_to_hub(
-    raw_dir: Path,
-    raw_format: str,
-    repo_id: str,
-    push_to_hub: bool = True,
-    local_dir: Path | None = None,
-    fps: int | None = None,
-    video: bool = True,
-    batch_size: int = 32,
-    num_workers: int = 8,
-    episodes: list[int] | None = None,
-    force_override: bool = False,
-    resume: bool = False,
-    cache_dir: Path = Path("/tmp"),
-    tests_data_dir: Path | None = None,
-    encoding: dict | None = None,
-):
-    check_repo_id(repo_id)
-    user_id, dataset_id = repo_id.split("/")
-
-    # Robustify when `raw_dir` is str instead of Path
-    raw_dir = Path(raw_dir)
-    if not raw_dir.exists():
-        raise NotADirectoryError(
-            f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
-            f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
-        )
-
-    if local_dir:
-        # Robustify when `local_dir` is str instead of Path
-        local_dir = Path(local_dir)
-
-        # Send warning if local_dir isn't well formatted
-        if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
-            warnings.warn(
-                f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
-                stacklevel=1,
-            )
-
-        # Check we don't override an existing `local_dir` by mistake
-        if local_dir.exists():
-            if force_override:
-                shutil.rmtree(local_dir)
-            elif not resume:
-                raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
-
-        meta_data_dir = local_dir / "meta_data"
-        videos_dir = local_dir / "videos"
-    else:
-        # Temporary directory used to store images, videos, meta_data
-        meta_data_dir = Path(cache_dir) / "meta_data"
-        videos_dir = Path(cache_dir) / "videos"
-
-    if raw_format is None:
-        # TODO(rcadene, adilzouitine): implement auto_find_raw_format
-        raise NotImplementedError()
-        # raw_format = auto_find_raw_format(raw_dir)
-
-    # convert dataset from original raw format to LeRobot format
-    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
-
-    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
-        raw_dir,
-        videos_dir,
-        fps,
-        video,
-        episodes,
-        encoding,
-    )
-
-    lerobot_dataset = LeRobotDataset.from_preloaded(
-        repo_id=repo_id,
-        hf_dataset=hf_dataset,
-        episode_data_index=episode_data_index,
-        info=info,
-        videos_dir=videos_dir,
-    )
-    stats = compute_stats(lerobot_dataset, batch_size, num_workers)
-
-    if local_dir:
-        hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
-        hf_dataset.save_to_disk(str(local_dir / "train"))
-
-    if push_to_hub or local_dir:
-        # mandatory for upload
-        save_meta_data(info, stats, episode_data_index, meta_data_dir)
-
-    if push_to_hub:
-        hf_dataset.push_to_hub(repo_id, revision="main")
-        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
-        push_dataset_card_to_hub(repo_id, revision="main")
-        if video:
-            push_videos_to_hub(repo_id, videos_dir, revision="main")
-        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
-
-    if tests_data_dir:
-        # get the first episode
-        num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
-        test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
-        episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}
-
-        test_hf_dataset = test_hf_dataset.with_format(None)
-        test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
-
-        tests_meta_data = tests_data_dir / repo_id / "meta_data"
-        save_meta_data(info, stats, episode_data_index, tests_meta_data)
-
-        # copy videos of first episode to tests directory
-        episode_index = 0
-        tests_videos_dir = tests_data_dir / repo_id / "videos"
-        tests_videos_dir.mkdir(parents=True, exist_ok=True)
-        for key in lerobot_dataset.camera_keys:
-            fname = f"{key}_episode_{episode_index:06d}.mp4"
-            shutil.copy(videos_dir / fname, tests_videos_dir / fname)
-
-    if local_dir is None:
-        # clear cache
-        shutil.rmtree(meta_data_dir)
-        shutil.rmtree(videos_dir)
-
-    return lerobot_dataset
-
-
-def main():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--raw-dir",
-        type=Path,
-        required=True,
-        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
-    )
-    # TODO(rcadene): add automatic detection of the format
-    parser.add_argument(
-        "--raw-format",
-        type=str,
-        required=True,
-        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).",
-    )
-    parser.add_argument(
-        "--repo-id",
-        type=str,
-        required=True,
-        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
-    )
-    parser.add_argument(
-        "--local-dir",
-        type=Path,
-        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
-    )
-    parser.add_argument(
-        "--push-to-hub",
-        type=int,
-        default=1,
-        help="Upload to hub.",
-    )
-    parser.add_argument(
-        "--fps",
-        type=int,
-        help="Frame rate used to collect videos. If not provided, use the default one specified in the code.",
-    )
-    parser.add_argument(
-        "--video",
-        type=int,
-        default=1,
-        help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
-    )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=32,
-        help="Batch size loaded by DataLoader for computing the dataset statistics.",
-    )
-    parser.add_argument(
-        "--num-workers",
-        type=int,
-        default=8,
-        help="Number of processes of Dataloader for computing the dataset statistics.",
-    )
-    parser.add_argument(
-        "--episodes",
-        type=int,
-        nargs="*",
-        help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
-    )
-    parser.add_argument(
-        "--force-override",
-        type=int,
-        default=0,
-        help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
-    )
-    parser.add_argument(
-        "--resume",
-        type=int,
-        default=0,
-        help="When set to 1, resumes a previous run.",
-    )
-    parser.add_argument(
-        "--cache-dir",
-        type=Path,
-        required=False,
-        default="/tmp",
-        help="Directory to store the temporary videos and images generated while creating the dataset.",
-    )
-    parser.add_argument(
-        "--tests-data-dir",
-        type=Path,
-        help=(
-            "When provided, save tests artifacts into the given directory "
-            "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
-        ),
-    )
-
-    args = parser.parse_args()
-    push_dataset_to_hub(**vars(args))
-
-
-if __name__ == "__main__":
-    main()

@@ -133,7 +133,7 @@ def train(cfg: TrainPipelineConfig):
     eval_env = None
     if cfg.eval_freq > 0 and cfg.env is not None:
         logging.info("Creating env")
-        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
+        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

     logging.info("Creating policy")
     policy = make_policy(

@@ -62,13 +62,14 @@ dependencies = [
     "omegaconf>=2.3.0",
     "opencv-python>=4.9.0",
     "packaging>=24.2",
-    "av>=12.0.5",
+    "av>=12.0.5,<13.0.0",
     "pymunk>=6.6.0",
     "pynput>=1.7.7",
     "pyzmq>=26.2.1",
     "rerun-sdk>=0.21.0",
     "termcolor>=2.4.0",
     "torch>=2.2.1",
+    "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
     "torchvision>=0.21.0",
     "wandb>=0.16.3",
     "zarr>=2.17.0",