Merge remote-tracking branch 'origin/main' into user/rcadene/2024_06_01_custom_visualize_dataset

commit b502a82005
@@ -10,6 +10,7 @@ on:
       - "examples/**"
       - ".github/**"
       - "poetry.lock"
+      - "Makefile"
   push:
     branches:
       - main
@@ -19,6 +20,7 @@ on:
       - "examples/**"
       - ".github/**"
       - "poetry.lock"
+      - "Makefile"

 jobs:
   pytest:
@@ -32,8 +34,8 @@ jobs:
         with:
           lfs: true  # Ensure LFS files are pulled

-      - name: Install EGL
-        run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
+      - name: Install apt dependencies
+        run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg

       - name: Install poetry
         run: |
@@ -70,6 +72,9 @@ jobs:
         with:
           lfs: true  # Ensure LFS files are pulled

+      - name: Install apt dependencies
+        run: sudo apt-get update && sudo apt-get install -y ffmpeg
+
       - name: Install poetry
         run: |
           pipx install poetry && poetry config virtualenvs.in-project true
@@ -104,7 +109,7 @@ jobs:
         with:
           lfs: true  # Ensure LFS files are pulled

-      - name: Install EGL
+      - name: Install apt dependencies
         run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev

       - name: Install poetry
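The hunks above fold ffmpeg into the apt install steps, since the dataset tests now encode frames to mp4. A minimal standalone sketch (not part of the commit) for verifying locally that ffmpeg is available before running the suite:

```python
import shutil
import subprocess


def check_ffmpeg() -> str:
    """Return the installed ffmpeg version banner, or raise if ffmpeg is missing."""
    if shutil.which("ffmpeg") is None:
        raise RuntimeError("ffmpeg not found on PATH; install it, e.g. `apt-get install -y ffmpeg`")
    # `ffmpeg -version` prints the version banner on its first line
    out = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True, check=True)
    return out.stdout.splitlines()[0]


if __name__ == "__main__":
    print(check_ffmpeg())
```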
@@ -0,0 +1,18 @@
+on:
+  push:
+
+name: Secret Leaks
+
+permissions:
+  contents: read
+
+jobs:
+  trufflehog:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Secret Scanning
+        uses: trufflesecurity/trufflehog@main
Makefile (5 lines changed)
@@ -46,6 +46,7 @@ test-act-ete-train:
 	policy.n_action_steps=20 \
 	policy.chunk_size=20 \
 	training.batch_size=2 \
+	training.image_transforms.enable=true \
 	hydra.run.dir=tests/outputs/act/

 test-act-ete-eval:
@@ -73,6 +74,7 @@ test-act-ete-train-amp:
 	policy.chunk_size=20 \
 	training.batch_size=2 \
 	hydra.run.dir=tests/outputs/act_amp/ \
+	training.image_transforms.enable=true \
 	use_amp=true

 test-act-ete-eval-amp:
@@ -100,6 +102,7 @@ test-diffusion-ete-train:
 	training.save_checkpoint=true \
 	training.save_freq=2 \
 	training.batch_size=2 \
+	training.image_transforms.enable=true \
 	hydra.run.dir=tests/outputs/diffusion/

 test-diffusion-ete-eval:
@@ -127,6 +130,7 @@ test-tdmpc-ete-train:
 	training.save_checkpoint=true \
 	training.save_freq=2 \
 	training.batch_size=2 \
+	training.image_transforms.enable=true \
 	hydra.run.dir=tests/outputs/tdmpc/

 test-tdmpc-ete-eval:
@@ -159,5 +163,6 @@ test-act-pusht-tutorial:
 	training.save_model=true \
 	training.save_freq=2 \
 	training.batch_size=2 \
+	training.image_transforms.enable=true \
 	hydra.run.dir=tests/outputs/act_pusht/
 	rm lerobot/configs/policy/created_by_Makefile.yaml
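Each end-to-end target now passes `training.image_transforms.enable=true` as a Hydra override, so the smoke tests also exercise the augmentation path. As a rough illustration of what one target boils down to when invoked from Python rather than make (the `lerobot/scripts/train.py` entry point and the leading arguments are assumptions here, since the hunks only show the tail of each command):

```python
import subprocess

# Hypothetical reconstruction of the `test-act-ete-train` invocation; only the
# Hydra overrides visible in the Makefile hunk are taken from the commit itself.
cmd = [
    "python", "lerobot/scripts/train.py",
    "policy.n_action_steps=20",
    "policy.chunk_size=20",
    "training.batch_size=2",
    "training.image_transforms.enable=true",
    "hydra.run.dir=tests/outputs/act/",
]
subprocess.run(cmd, check=True)
```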
README.md (10 lines changed)
@@ -228,13 +228,13 @@ To add a dataset to the hub, you need to login using a write-access token, which
 huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```

-Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
+Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with:
 ```bash
 python lerobot/scripts/push_dataset_to_hub.py \
-  --data-dir data \
-  --dataset-id aloha_static_pingpong_test \
-  --raw-format aloha_hdf5 \
-  --community-id lerobot
+  --raw-dir data/aloha_static_pingpong_test_raw \
+  --out-dir data \
+  --repo-id lerobot/aloha_static_pingpong_test \
+  --raw-format aloha_hdf5
 ```

 See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
@@ -46,7 +46,7 @@ defaults:
   - policy: diffusion
 ```

-This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overriden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
+This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overridden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.

 Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.

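The corrected sentence describes Hydra's defaults-list merge order: later entries win. A minimal OmegaConf sketch (standalone, not LeRobot code) showing the same left-to-right override behavior:

```python
from omegaconf import OmegaConf

default = OmegaConf.create({"device": "cuda", "training": {"offline_steps": "???"}})
env = OmegaConf.create({"device": "cpu"})
policy = OmegaConf.create({"training": {"offline_steps": 200000}})

# Later configs override earlier ones, mirroring default.yaml <- env/pusht.yaml <- policy/diffusion.yaml.
cfg = OmegaConf.merge(default, env, policy)
print(cfg.device)                  # "cpu": env config overrides default.yaml
print(cfg.training.offline_steps)  # 200000: policy config fills in the `???` placeholder
```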
@@ -0,0 +1,52 @@
+"""
+This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
+augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
+transforms are applied to the observation images before they are returned in the dataset's __get_item__.
+"""
+
+from pathlib import Path
+
+from torchvision.transforms import ToPILImage, v2
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+dataset_repo_id = "lerobot/aloha_static_tape"
+
+# Create a LeRobotDataset with no transformations
+dataset = LeRobotDataset(dataset_repo_id)
+# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
+
+# Get the index of the first observation in the first episode
+first_idx = dataset.episode_data_index["from"][0].item()
+
+# Get the frame corresponding to the first camera
+frame = dataset[first_idx][dataset.camera_keys[0]]
+
+
+# Define the transformations
+transforms = v2.Compose(
+    [
+        v2.ColorJitter(brightness=(0.5, 1.5)),
+        v2.ColorJitter(contrast=(0.5, 1.5)),
+        v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
+    ]
+)
+
+# Create another LeRobotDataset with the defined transformations
+transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
+
+# Get a frame from the transformed dataset
+transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
+
+# Create a directory to store output images
+output_dir = Path("outputs/image_transforms")
+output_dir.mkdir(parents=True, exist_ok=True)
+
+# Save the original frame
+to_pil = ToPILImage()
+to_pil(frame).save(output_dir / "original_frame.png", quality=100)
+print(f"Original frame saved to {output_dir / 'original_frame.png'}.")
+
+# Save the transformed frame
+to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
+print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")
@@ -19,6 +19,7 @@ import torch
 from omegaconf import ListConfig, OmegaConf

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
+from lerobot.common.datasets.transforms import get_image_transforms


 def resolve_delta_timestamps(cfg):
@@ -71,17 +72,37 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:

     resolve_delta_timestamps(cfg)

-    # TODO(rcadene): add data augmentations
+    image_transforms = None
+    if cfg.training.image_transforms.enable:
+        cfg_tf = cfg.training.image_transforms
+        image_transforms = get_image_transforms(
+            brightness_weight=cfg_tf.brightness.weight,
+            brightness_min_max=cfg_tf.brightness.min_max,
+            contrast_weight=cfg_tf.contrast.weight,
+            contrast_min_max=cfg_tf.contrast.min_max,
+            saturation_weight=cfg_tf.saturation.weight,
+            saturation_min_max=cfg_tf.saturation.min_max,
+            hue_weight=cfg_tf.hue.weight,
+            hue_min_max=cfg_tf.hue.min_max,
+            sharpness_weight=cfg_tf.sharpness.weight,
+            sharpness_min_max=cfg_tf.sharpness.min_max,
+            max_num_transforms=cfg_tf.max_num_transforms,
+            random_order=cfg_tf.random_order,
+        )

     if isinstance(cfg.dataset_repo_id, str):
         dataset = LeRobotDataset(
             cfg.dataset_repo_id,
             split=split,
             delta_timestamps=cfg.training.get("delta_timestamps"),
+            image_transforms=image_transforms,
         )
     else:
         dataset = MultiLeRobotDataset(
-            cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
+            cfg.dataset_repo_id,
+            split=split,
+            delta_timestamps=cfg.training.get("delta_timestamps"),
+            image_transforms=image_transforms,
         )

     if cfg.get("override_dataset_stats"):
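`make_dataset` now builds the transform pipeline from the `training.image_transforms` config node before constructing the dataset. The internals of `get_image_transforms` are not shown in this diff; a plausible sketch, assuming it maps each weighted entry onto a torchvision `v2` transform and keeps at most `max_num_transforms` of them (the real implementation may sample by weight per call):

```python
import random

from torchvision.transforms import v2


def get_image_transforms_sketch(
    brightness_weight: float = 1.0, brightness_min_max: tuple = (0.8, 1.2),
    contrast_weight: float = 1.0, contrast_min_max: tuple = (0.8, 1.2),
    sharpness_weight: float = 1.0, sharpness_min_max: tuple = (2.0, 2.0),
    max_num_transforms: int = 3, random_order: bool = False,
):
    """Hypothetical stand-in for lerobot's get_image_transforms, NOT the real implementation."""
    weighted = [
        (brightness_weight, v2.ColorJitter(brightness=brightness_min_max)),
        (contrast_weight, v2.ColorJitter(contrast=contrast_min_max)),
        (sharpness_weight, v2.RandomAdjustSharpness(sharpness_factor=sharpness_min_max[0], p=1.0)),
    ]
    # drop transforms whose weight is zero, then cap the pipeline length
    candidates = [t for w, t in weighted if w > 0]
    k = min(max_num_transforms, len(candidates))
    chosen = random.sample(candidates, k) if random_order else candidates[:k]
    return v2.Compose(chosen)
```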
@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: Callable | None = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.version = version
         self.root = root
         self.split = split
-        self.transform = transform
+        self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         # load data from hub or locally when root is provided
         # TODO(rcadene, aliberts): implement faster transfer
@@ -151,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self.tolerance_s,
             )

-        if self.transform is not None:
-            item = self.transform(item)
+        if self.image_transforms is not None:
+            for cam in self.camera_keys:
+                item[cam] = self.image_transforms(item[cam])

         return item

@@ -168,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             f"  Recorded Frames per Second: {self.fps},\n"
             f"  Camera Keys: {self.camera_keys},\n"
             f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
-            f"  Transformations: {self.transform},\n"
+            f"  Transformations: {self.image_transforms},\n"
             f")"
         )

@@ -202,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.version = version
         obj.root = root
         obj.split = split
-        obj.transform = transform
+        obj.image_transforms = transform
         obj.delta_timestamps = delta_timestamps
         obj.hf_dataset = hf_dataset
         obj.episode_data_index = episode_data_index
@@ -225,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: Callable | None = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()
@@ -239,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 root=root,
                 split=split,
                 delta_timestamps=delta_timestamps,
-                transform=transform,
+                image_transforms=image_transforms,
             )
             for repo_id in repo_ids
         ]
@@ -274,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         self.version = version
         self.root = root
         self.split = split
-        self.transform = transform
+        self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.stats = aggregate_stats(self._datasets)

@@ -380,6 +381,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         for data_key in self.disabled_data_keys:
             if data_key in item:
                 del item[data_key]
+
         return item

     def __repr__(self):
@@ -394,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
             f"  Recorded Frames per Second: {self.fps},\n"
             f"  Camera Keys: {self.camera_keys},\n"
             f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
-            f"  Transformations: {self.transform},\n"
+            f"  Transformations: {self.image_transforms},\n"
             f")"
         )
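The `__getitem__` change above is behaviorally significant: `image_transforms` is now applied only to the camera keys of each item, not to the whole item dict, so states and actions are never touched by augmentation. A minimal standalone sketch of that pattern, assuming items are dicts of tensors and `camera_keys` lists the image entries:

```python
from typing import Callable

import torch


class ItemTransformSketch:
    """Toy illustration of applying image_transforms per camera key (not the real class)."""

    def __init__(self, camera_keys: list[str], image_transforms: Callable | None = None):
        self.camera_keys = camera_keys
        self.image_transforms = image_transforms

    def apply(self, item: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        if self.image_transforms is not None:
            for cam in self.camera_keys:
                # only observation images are transformed; states/actions stay untouched
                item[cam] = self.image_transforms(item[cam])
        return item
```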
@@ -14,156 +14,119 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-This file contains all obsolete download scripts. They are centralized here to not have to load
-useless dependencies when using datasets.
+This file contains download scripts for raw datasets.
+
+Example of usage:
+```
+python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
+--raw-dir data/cadene/pusht_raw \
+--repo-id cadene/pusht_raw
+```
 """

-import io
+import argparse
 import logging
-import shutil
+import warnings
 from pathlib import Path

-import tqdm
 from huggingface_hub import snapshot_download


-def download_raw(raw_dir, dataset_id):
-    if "aloha" in dataset_id or "image" in dataset_id:
-        download_hub(raw_dir, dataset_id)
-    elif "pusht" in dataset_id:
-        download_pusht(raw_dir)
-    elif "xarm" in dataset_id:
-        download_xarm(raw_dir)
-    elif "umi" in dataset_id:
-        download_umi(raw_dir)
-    else:
-        raise ValueError(dataset_id)
-
-
-def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
-    import zipfile
-
-    import requests
-
-    print(f"downloading from {url}")
-    response = requests.get(url, stream=True)
-    if response.status_code == 200:
-        total_size = int(response.headers.get("content-length", 0))
-        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
-
-        zip_file = io.BytesIO()
-        for chunk in response.iter_content(chunk_size=1024):
-            if chunk:
-                zip_file.write(chunk)
-                progress_bar.update(len(chunk))
-
-        progress_bar.close()
-
-        zip_file.seek(0)
-
-        with zipfile.ZipFile(zip_file, "r") as zip_ref:
-            zip_ref.extractall(destination_folder)
-
-
-def download_pusht(raw_dir: str):
-    pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
+def download_raw(raw_dir: Path, repo_id: str):
+    # Check repo_id is well formated
+    if len(repo_id.split("/")) != 2:
+        raise ValueError(
+            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
+        )
+    user_id, dataset_id = repo_id.split("/")
+
+    if not dataset_id.endswith("_raw"):
+        warnings.warn(
+            f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
+            stacklevel=1,
+        )

     raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    download_and_extract_zip(pusht_url, raw_dir)
-    # file is created inside a useful "pusht" directory, so we move it out and delete the dir
-    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
-    shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
-    shutil.rmtree(raw_dir / "pusht")
-
-
-def download_xarm(raw_dir: Path):
-    """Download all xarm datasets at once"""
-    import zipfile
-
-    import gdown
-
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
-    url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
-    zip_path = raw_dir / "data.zip"
-    gdown.download(url, str(zip_path), quiet=False)
-    print("Extracting...")
-    with zipfile.ZipFile(str(zip_path), "r") as zip_f:
-        for pkl_path in zip_f.namelist():
-            if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
-                zip_f.extract(member=pkl_path)
-                # move to corresponding raw directory
-                extract_dir = pkl_path.replace("/buffer.pkl", "")
-                raw_pkl_path = raw_dir / "buffer.pkl"
-                shutil.move(pkl_path, raw_pkl_path)
-                shutil.rmtree(extract_dir)
-    zip_path.unlink()
-
-
-def download_hub(raw_dir: Path, dataset_id: str):
-    raw_dir = Path(raw_dir)
+    # Send warning if raw_dir isn't well formated
+    if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
+        warnings.warn(
+            f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
+            stacklevel=1,
+        )
     raw_dir.mkdir(parents=True, exist_ok=True)

-    logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
-    snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
-    logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
+    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
+    snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
+    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")


-def download_umi(raw_dir: Path):
-    url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
-    zarr_path = raw_dir / "cup_in_the_wild.zarr"
-
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    download_and_extract_zip(url_cup_in_the_wild, zarr_path)
+def download_all_raw_datasets():
+    data_dir = Path("data")
+    repo_ids = [
+        "cadene/pusht_image_raw",
+        "cadene/xarm_lift_medium_image_raw",
+        "cadene/xarm_lift_medium_replay_image_raw",
+        "cadene/xarm_push_medium_image_raw",
+        "cadene/xarm_push_medium_replay_image_raw",
+        "cadene/aloha_sim_insertion_human_image_raw",
+        "cadene/aloha_sim_insertion_scripted_image_raw",
+        "cadene/aloha_sim_transfer_cube_human_image_raw",
+        "cadene/aloha_sim_transfer_cube_scripted_image_raw",
+        "cadene/pusht_raw",
+        "cadene/xarm_lift_medium_raw",
+        "cadene/xarm_lift_medium_replay_raw",
+        "cadene/xarm_push_medium_raw",
+        "cadene/xarm_push_medium_replay_raw",
+        "cadene/aloha_sim_insertion_human_raw",
+        "cadene/aloha_sim_insertion_scripted_raw",
+        "cadene/aloha_sim_transfer_cube_human_raw",
+        "cadene/aloha_sim_transfer_cube_scripted_raw",
+        "cadene/aloha_mobile_cabinet_raw",
+        "cadene/aloha_mobile_chair_raw",
+        "cadene/aloha_mobile_elevator_raw",
+        "cadene/aloha_mobile_shrimp_raw",
+        "cadene/aloha_mobile_wash_pan_raw",
+        "cadene/aloha_mobile_wipe_wine_raw",
+        "cadene/aloha_static_battery_raw",
+        "cadene/aloha_static_candy_raw",
+        "cadene/aloha_static_coffee_raw",
+        "cadene/aloha_static_coffee_new_raw",
+        "cadene/aloha_static_cups_open_raw",
+        "cadene/aloha_static_fork_pick_up_raw",
+        "cadene/aloha_static_pingpong_test_raw",
+        "cadene/aloha_static_pro_pencil_raw",
+        "cadene/aloha_static_screw_driver_raw",
+        "cadene/aloha_static_tape_raw",
+        "cadene/aloha_static_thread_velcro_raw",
+        "cadene/aloha_static_towel_raw",
+        "cadene/aloha_static_vinh_cup_raw",
+        "cadene/aloha_static_vinh_cup_left_raw",
+        "cadene/aloha_static_ziploc_slide_raw",
+        "cadene/umi_cup_in_the_wild_raw",
+    ]
+    for repo_id in repo_ids:
+        raw_dir = data_dir / repo_id
+        download_raw(raw_dir, repo_id)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--raw-dir",
+        type=Path,
+        required=True,
+        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
+    )
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
+    )
+    args = parser.parse_args()
+    download_raw(**vars(args))


 if __name__ == "__main__":
-    data_dir = Path("data")
-    dataset_ids = [
-        "pusht_image",
-        "xarm_lift_medium_image",
-        "xarm_lift_medium_replay_image",
-        "xarm_push_medium_image",
-        "xarm_push_medium_replay_image",
-        "aloha_sim_insertion_human_image",
-        "aloha_sim_insertion_scripted_image",
-        "aloha_sim_transfer_cube_human_image",
-        "aloha_sim_transfer_cube_scripted_image",
-        "pusht",
-        "xarm_lift_medium",
-        "xarm_lift_medium_replay",
-        "xarm_push_medium",
-        "xarm_push_medium_replay",
-        "aloha_sim_insertion_human",
-        "aloha_sim_insertion_scripted",
-        "aloha_sim_transfer_cube_human",
-        "aloha_sim_transfer_cube_scripted",
-        "aloha_mobile_cabinet",
-        "aloha_mobile_chair",
-        "aloha_mobile_elevator",
-        "aloha_mobile_shrimp",
-        "aloha_mobile_wash_pan",
-        "aloha_mobile_wipe_wine",
-        "aloha_static_battery",
-        "aloha_static_candy",
-        "aloha_static_coffee",
-        "aloha_static_coffee_new",
-        "aloha_static_cups_open",
-        "aloha_static_fork_pick_up",
-        "aloha_static_pingpong_test",
-        "aloha_static_pro_pencil",
-        "aloha_static_screw_driver",
-        "aloha_static_tape",
-        "aloha_static_thread_velcro",
-        "aloha_static_towel",
-        "aloha_static_vinh_cup",
-        "aloha_static_vinh_cup_left",
-        "aloha_static_ziploc_slide",
-        "umi_cup_in_the_wild",
-    ]
-    for dataset_id in dataset_ids:
-        raw_dir = data_dir / f"{dataset_id}_raw"
-        download_raw(raw_dir, dataset_id)
+    main()
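`download_raw` now takes a full `repo_id` and validates it before calling `snapshot_download`. A quick usage sketch of the validation behavior as written above (the helper name is mine, the logic mirrors the commit):

```python
import warnings


def split_repo_id(repo_id: str) -> tuple[str, str]:
    """Mirror download_raw's check: 'user/dataset' with an advisory '_raw' suffix."""
    parts = repo_id.split("/")
    if len(parts) != 2:
        raise ValueError(f"expected 'user/dataset', got '{repo_id}'")
    user_id, dataset_id = parts
    if not dataset_id.endswith("_raw"):
        warnings.warn(f"'{dataset_id}' doesn't end with '_raw'", stacklevel=2)
    return user_id, dataset_id


print(split_repo_id("lerobot/pusht_raw"))  # ('lerobot', 'pusht_raw')
print(split_repo_id("cadene/pusht"))       # warns: doesn't end with '_raw'
```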
@@ -30,6 +30,7 @@ from PIL import Image as PILImage

 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -70,16 +71,17 @@ def check_format(raw_dir) -> bool:
         assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."


-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     # only frames from simulation are uncompressed
     compressed_images = "sim" not in raw_dir.name

-    hdf5_files = list(raw_dir.glob("*.hdf5"))
-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
+    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
+    num_episodes = len(hdf5_files)

-    id_from = 0
-    for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx in tqdm.tqdm(ep_ids):
+        ep_path = hdf5_files[ep_idx]
         with h5py.File(ep_path, "r") as ep:
             num_frames = ep["/action"].shape[0]

@@ -114,12 +116,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):

             if video:
                 # save png images in temporary directory
-                tmp_imgs_dir = out_dir / "tmp_images"
+                tmp_imgs_dir = videos_dir / "tmp_images"
                 save_images_concurrently(imgs_array, tmp_imgs_dir)

                 # encode images to a mp4 video
                 fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-                video_path = out_dir / "videos" / fname
+                video_path = videos_dir / fname
                 encode_video_frames(tmp_imgs_dir, video_path, fps)

                 # clean temporary images directory
@@ -147,19 +149,13 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         assert isinstance(ep_idx, int)
         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from += num_frames
-
         gc.collect()

-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    return data_dict, episode_data_index
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict


 def to_hf_dataset(data_dict, video) -> Dataset:
@@ -197,16 +193,22 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 50

-    data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
-    hf_dataset = to_hf_dataset(data_dir, video)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
+    hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
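Across the raw-format converters, the hand-maintained `episode_data_index` bookkeeping is replaced by a single `calculate_episode_data_index(hf_dataset)` call after the dataset is built. Its implementation is not part of this diff; one way to derive the same [from, to) boundaries from an `episode_index` column, as a sketch:

```python
import torch


def episode_data_index_sketch(episode_index: torch.Tensor) -> dict[str, torch.Tensor]:
    """Derive per-episode [from, to) frame ranges from a sorted episode_index column."""
    # positions where the episode id changes mark episode starts
    change = torch.ones_like(episode_index, dtype=torch.bool)
    change[1:] = episode_index[1:] != episode_index[:-1]
    froms = torch.nonzero(change).squeeze(1)
    tos = torch.cat([froms[1:], torch.tensor([len(episode_index)])])
    return {"from": froms, "to": tos}


ep_idx = torch.tensor([0, 0, 0, 1, 1, 2])
print(episode_data_index_sketch(ep_idx))  # from=[0, 3, 5], to=[3, 5, 6]
```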
@@ -17,7 +17,6 @@
 Contains utilities to process raw data format from dora-record
 """

-import logging
 import re
 from pathlib import Path

@@ -26,10 +25,10 @@ import torch
 from datasets import Dataset, Features, Image, Sequence, Value

 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame
-from lerobot.common.utils.utils import init_logging


 def check_format(raw_dir) -> bool:
@@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool:
     return True


-def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     # Load data stream that will be used as reference for the timestamps synchronization
     reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
     if len(reference_files) == 0:
@@ -122,8 +121,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
         raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")

     # Create symlink to raw videos directory (that needs to be absolute not relative)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    videos_dir = out_dir / "videos"
+    videos_dir.parent.mkdir(parents=True, exist_ok=True)
     videos_dir.symlink_to((raw_dir / "videos").absolute())

     # sanity check the video paths are well formated
@@ -156,16 +154,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     else:
         raise ValueError(key)

-    # Get the episode index containing for each unique episode index
-    first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
-    from_ = first_ep_index_df["start_index"].tolist()
-    to_ = from_[1:] + [len(df)]
-    episode_data_index = {
-        "from": from_,
-        "to": to_,
-    }
-
-    return data_dict, episode_data_index
+    return data_dict


 def to_hf_dataset(data_dict, video) -> Dataset:
@@ -203,12 +192,13 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
-    init_logging()
-
-    if debug:
-        logging.warning("debug=True not implemented. Falling back to debug=False.")
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

@@ -220,9 +210,9 @@ def from_raw_to_lerobot_format(
     if not video:
         raise NotImplementedError()

-    data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
+    data_df = load_from_raw(raw_dir, videos_dir, fps, episodes)
     hf_dataset = to_hf_dataset(data_df, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
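The dora converter now receives `videos_dir` directly and only creates its parent before symlinking to the raw videos folder; the comment in the hunk notes the symlink target must be absolute. A tiny standalone pathlib sketch of that idiom (paths here are hypothetical):

```python
from pathlib import Path

raw_dir = Path("data/example_raw")           # hypothetical raw dataset location
videos_dir = Path("outputs/example/videos")  # hypothetical destination for the link

(raw_dir / "videos").mkdir(parents=True, exist_ok=True)  # pretend raw videos exist
videos_dir.parent.mkdir(parents=True, exist_ok=True)
if not videos_dir.exists():
    # a relative symlink target would resolve relative to the link's own directory,
    # so resolve it to an absolute path first
    videos_dir.symlink_to((raw_dir / "videos").absolute())
```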
@@ -27,6 +27,7 @@ from PIL import Image as PILImage

 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -53,7 +54,7 @@ def check_format(raw_dir):
     assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     try:
         import pymunk
         from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -71,7 +72,6 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

     episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
-    num_episodes = zarr_data.meta["episode_ends"].shape[0]
     assert len(
         {zarr_data[key].shape[0] for key in zarr_data.keys()}  # noqa: SIM118
     ), "Some data type dont have the same number of total frames."
@@ -84,25 +84,34 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     states = torch.from_numpy(zarr_data["state"])
     actions = torch.from_numpy(zarr_data["action"])

-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
+    # load data indices from which each episode starts and ends
+    from_ids, to_ids = [], []
+    from_idx = 0
+    for to_idx in zarr_data.meta["episode_ends"]:
+        from_ids.append(from_idx)
+        to_ids.append(to_idx)
+        from_idx = to_idx

-    id_from = 0
-    for ep_idx in tqdm.tqdm(range(num_episodes)):
-        id_to = zarr_data.meta["episode_ends"][ep_idx]
-        num_frames = id_to - id_from
+    num_episodes = len(from_ids)
+
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        from_idx = from_ids[selected_ep_idx]
+        to_idx = to_ids[selected_ep_idx]
+        num_frames = to_idx - from_idx

         # sanity check
-        assert (episode_ids[id_from:id_to] == ep_idx).all()
+        assert (episode_ids[from_idx:to_idx] == ep_idx).all()

         # get image
-        image = imgs[id_from:id_to]
+        image = imgs[from_idx:to_idx]
         assert image.min() >= 0.0
         assert image.max() <= 255.0
         image = image.type(torch.uint8)

         # get state
-        state = states[id_from:id_to]
+        state = states[from_idx:to_idx]
         agent_pos = state[:, :2]
         block_pos = state[:, 2:4]
         block_angle = state[:, 4]
@@ -143,12 +152,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         img_key = "observation.image"
         if video:
             # save png images in temporary directory
-            tmp_imgs_dir = out_dir / "tmp_images"
+            tmp_imgs_dir = videos_dir / "tmp_images"
             save_images_concurrently(imgs_array, tmp_imgs_dir)

             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = out_dir / "videos" / fname
+            video_path = videos_dir / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)

             # clean temporary images directory
@@ -160,7 +169,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
             ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

         ep_dict["observation.state"] = agent_pos
-        ep_dict["action"] = actions[id_from:id_to]
+        ep_dict["action"] = actions[from_idx:to_idx]
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
@@ -172,17 +181,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from += num_frames
-
-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    return data_dict, episode_data_index
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict


 def to_hf_dataset(data_dict, video):
@@ -212,16 +215,22 @@ def to_hf_dataset(data_dict, video):
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 10

-    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
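The pusht converter precomputes episode boundaries from `episode_ends` so that a caller-supplied `episodes` subset can be converted without walking everything. The boundary construction in isolation:

```python
def episode_bounds(episode_ends: list[int]) -> tuple[list[int], list[int]]:
    """Turn cumulative episode end indices into parallel [from, to) lists."""
    from_ids, to_ids = [], []
    from_idx = 0
    for to_idx in episode_ends:
        from_ids.append(from_idx)
        to_ids.append(to_idx)
        from_idx = to_idx
    return from_ids, to_ids


print(episode_bounds([3, 7, 10]))  # ([0, 3, 7], [3, 7, 10])
```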
@@ -19,7 +19,6 @@ import logging
 import shutil
 from pathlib import Path

-import numpy as np
 import torch
 import tqdm
 import zarr
@@ -29,6 +28,7 @@ from PIL import Image as PILImage
 from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -59,23 +59,7 @@ def check_format(raw_dir) -> bool:
     assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


-def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
-    # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
-    from numba import jit
-
-    @jit(nopython=True)
-    def _get_episode_idxs(episode_ends):
-        result = np.zeros((episode_ends[-1],), dtype=np.int64)
-        start_idx = 0
-        for episode_number, end_idx in enumerate(episode_ends):
-            result[start_idx:end_idx] = episode_number
-            start_idx = end_idx
-        return result
-
-    return _get_episode_idxs(episode_ends)
-
-
-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     zarr_path = raw_dir / "cup_in_the_wild.zarr"
     zarr_data = zarr.open(zarr_path, mode="r")

@@ -92,39 +76,41 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     episode_ends = zarr_data["meta/episode_ends"][:]
     num_episodes = episode_ends.shape[0]

-    episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))
-
     # We convert it in torch tensor later because the jit function does not support torch tensors
     episode_ends = torch.from_numpy(episode_ends)

+    # load data indices from which each episode starts and ends
+    from_ids, to_ids = [], []
+    from_idx = 0
+    for to_idx in episode_ends:
+        from_ids.append(from_idx)
+        to_ids.append(to_idx)
+        from_idx = to_idx
+
     ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-
-    id_from = 0
-    for ep_idx in tqdm.tqdm(range(num_episodes)):
-        id_to = episode_ends[ep_idx]
-        num_frames = id_to - id_from
-
-        # sanity heck
-        assert (episode_ids[id_from:id_to] == ep_idx).all()
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        from_idx = from_ids[selected_ep_idx]
+        to_idx = to_ids[selected_ep_idx]
+        num_frames = to_idx - from_idx

         # TODO(rcadene): save temporary images of the episode?

-        state = states[id_from:id_to]
+        state = states[from_idx:to_idx]

         ep_dict = {}

         # load 57MB of images in RAM (400x224x224x3 uint8)
-        imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
+        imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
         img_key = "observation.image"
         if video:
             # save png images in temporary directory
-            tmp_imgs_dir = out_dir / "tmp_images"
+            tmp_imgs_dir = videos_dir / "tmp_images"
             save_images_concurrently(imgs_array, tmp_imgs_dir)

             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = out_dir / "videos" / fname
+            video_path = videos_dir / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)

             # clean temporary images directory
@@ -139,27 +125,18 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
-        ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
-        ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames)
-        ep_dict["end_pose"] = end_pose[id_from:id_to]
-        ep_dict["start_pos"] = start_pos[id_from:id_to]
-        ep_dict["gripper_width"] = gripper_width[id_from:id_to]
+        ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
+        ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
+        ep_dict["end_pose"] = end_pose[from_idx:to_idx]
+        ep_dict["start_pos"] = start_pos[from_idx:to_idx]
+        ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-        id_from += num_frames
-
-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)

-    total_frames = id_from
+    total_frames = data_dict["frame_index"].shape[0]
     data_dict["index"] = torch.arange(0, total_frames, 1)
-
-    return data_dict, episode_data_index
+    return data_dict


 def to_hf_dataset(data_dict, video):
@@ -199,7 +176,13 @@ def to_hf_dataset(data_dict, video):
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

@@ -212,9 +195,9 @@ def from_raw_to_lerobot_format(
         "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
     )

-    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
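The umi converter drops the numba-jitted `get_episode_idxs` helper entirely, since per-frame episode ids are no longer needed. For reference, the same mapping can be had in plain NumPy, as a sketch (not code from the commit):

```python
import numpy as np


def episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
    """Frame-wise episode ids from cumulative episode end indices, without numba."""
    lengths = np.diff(np.concatenate(([0], episode_ends)))
    return np.repeat(np.arange(len(episode_ends)), lengths)


print(episode_idxs(np.array([3, 7, 10])))  # [0 0 0 1 1 1 1 2 2 2]
```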
@@ -27,6 +27,7 @@ from PIL import Image as PILImage

 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -54,37 +55,42 @@ def check_format(raw_dir):
     assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)


-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     pkl_path = raw_dir / "buffer.pkl"

     with open(pkl_path, "rb") as f:
         pkl_data = pickle.load(f)

-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-
-    id_from = 0
-    id_to = 0
-    ep_idx = 0
-    total_frames = pkl_data["actions"].shape[0]
-    for i in tqdm.tqdm(range(total_frames)):
-        id_to += 1
-
-        if not pkl_data["dones"][i]:
+    # load data indices from which each episode starts and ends
+    from_ids, to_ids = [], []
+    from_idx, to_idx = 0, 0
+    for done in pkl_data["dones"]:
+        to_idx += 1
+        if not done:
             continue
+        from_ids.append(from_idx)
+        to_ids.append(to_idx)
+        from_idx = to_idx

-        num_frames = id_to - id_from
+    num_episodes = len(from_ids)

-        image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to])
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        from_idx = from_ids[selected_ep_idx]
+        to_idx = to_ids[selected_ep_idx]
+        num_frames = to_idx - from_idx
+
+        image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx])
         image = einops.rearrange(image, "b c h w -> b h w c")
-        state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to])
-        action = torch.tensor(pkl_data["actions"][id_from:id_to])
+        state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx])
+        action = torch.tensor(pkl_data["actions"][from_idx:to_idx])
         # TODO(rcadene): we have a missing last frame which is the observation when the env is done
         # it is critical to have this frame for tdmpc to predict a "done observation/state"
-        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to])
-        # next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to])
-        next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to])
-        next_done = torch.tensor(pkl_data["dones"][id_from:id_to])
+        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
+        # next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
+        next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx])
+        next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx])

         ep_dict = {}

@@ -92,12 +98,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         img_key = "observation.image"
         if video:
             # save png images in temporary directory
-            tmp_imgs_dir = out_dir / "tmp_images"
+            tmp_imgs_dir = videos_dir / "tmp_images"
             save_images_concurrently(imgs_array, tmp_imgs_dir)

             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = out_dir / "videos" / fname
+            video_path = videos_dir / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)

             # clean temporary images directory
@@ -119,18 +125,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         ep_dict["next.done"] = next_done
         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from = id_to
-        ep_idx += 1
-
-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    return data_dict, episode_data_index
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict


 def to_hf_dataset(data_dict, video):
@@ -161,16 +160,22 @@ def to_hf_dataset(data_dict, video):
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 15

-    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
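The pattern above repeats across the raw-format converters: per-episode `from`/`to` bookkeeping moves out of each `load_from_raw` and is recomputed afterwards from the finished `hf_dataset` via `calculate_episode_data_index`. A minimal sketch of the computation that helper has to perform, assuming the dataset exposes a flat, sorted `episode_index` column (illustrative, not the repo's implementation):

```python
import torch

def episode_data_index_from_episode_ids(episode_ids: list[int]) -> dict[str, torch.Tensor]:
    """Derive episode start/end indices from a flat, sorted episode_index column."""
    from_ids, to_ids = [], []
    prev = None
    for i, ep in enumerate(episode_ids):
        if ep != prev:
            if prev is not None:
                to_ids.append(i)  # previous episode ends where a new one starts
            from_ids.append(i)
            prev = ep
    if episode_ids:
        to_ids.append(len(episode_ids))
    return {"from": torch.tensor(from_ids), "to": torch.tensor(to_ids)}

print(episode_data_index_from_episode_ids([0, 0, 1, 1, 1]))
# {'from': tensor([0, 2]), 'to': tensor([2, 5])}
```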
@@ -0,0 +1,197 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+from typing import Any, Callable, Dict, Sequence
+
+import torch
+from torchvision.transforms import v2
+from torchvision.transforms.v2 import Transform
+from torchvision.transforms.v2 import functional as F  # noqa: N812
+
+
+class RandomSubsetApply(Transform):
+    """Apply a random subset of N transformations from a list of transformations.
+
+    Args:
+        transforms: list of transformations.
+        p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
+            If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
+            have the same probability.
+        n_subset: number of transformations to apply. If ``None``, all transforms are applied.
+            Must be in [1, len(transforms)].
+        random_order: apply transformations in a random order.
+    """
+
+    def __init__(
+        self,
+        transforms: Sequence[Callable],
+        p: list[float] | None = None,
+        n_subset: int | None = None,
+        random_order: bool = False,
+    ) -> None:
+        super().__init__()
+        if not isinstance(transforms, Sequence):
+            raise TypeError("Argument transforms should be a sequence of callables")
+        if p is None:
+            p = [1] * len(transforms)
+        elif len(p) != len(transforms):
+            raise ValueError(
+                f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
+            )
+
+        if n_subset is None:
+            n_subset = len(transforms)
+        elif not isinstance(n_subset, int):
+            raise TypeError("n_subset should be an int or None")
+        elif not (1 <= n_subset <= len(transforms)):
+            raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
+
+        self.transforms = transforms
+        total = sum(p)
+        self.p = [prob / total for prob in p]
+        self.n_subset = n_subset
+        self.random_order = random_order
+
+    def forward(self, *inputs: Any) -> Any:
+        needs_unpacking = len(inputs) > 1
+
+        selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
+        if not self.random_order:
+            selected_indices = selected_indices.sort().values
+
+        selected_transforms = [self.transforms[i] for i in selected_indices]
+
+        for transform in selected_transforms:
+            outputs = transform(*inputs)
+            inputs = outputs if needs_unpacking else (outputs,)
+
+        return outputs
+
+    def extra_repr(self) -> str:
+        return (
+            f"transforms={self.transforms}, "
+            f"p={self.p}, "
+            f"n_subset={self.n_subset}, "
+            f"random_order={self.random_order}"
+        )
+
+
+class SharpnessJitter(Transform):
+    """Randomly change the sharpness of an image or video.
+
+    Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
+    While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image,
+    SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
+    augmentations as a result.
+
+    A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
+    by a factor of 2.
+
+    If the input is a :class:`torch.Tensor`,
+    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+    Args:
+        sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
+            [max(0, 1 - sharpness), 1 + sharpness] or the given
+            [min, max]. Should be non negative numbers.
+    """
+
+    def __init__(self, sharpness: float | Sequence[float]) -> None:
+        super().__init__()
+        self.sharpness = self._check_input(sharpness)
+
+    def _check_input(self, sharpness):
+        if isinstance(sharpness, (int, float)):
+            if sharpness < 0:
+                raise ValueError("If sharpness is a single number, it must be non negative.")
+            sharpness = [1.0 - sharpness, 1.0 + sharpness]
+            sharpness[0] = max(sharpness[0], 0.0)
+        elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
+            sharpness = [float(v) for v in sharpness]
+        else:
+            raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
+
+        if not 0.0 <= sharpness[0] <= sharpness[1]:
+            raise ValueError(f"sharpness values should be between (0., inf), but got {sharpness}.")
+
+        return float(sharpness[0]), float(sharpness[1])
+
+    def _generate_value(self, left: float, right: float) -> float:
+        return torch.empty(1).uniform_(left, right).item()
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
+        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
+
+
+def get_image_transforms(
+    brightness_weight: float = 1.0,
+    brightness_min_max: tuple[float, float] | None = None,
+    contrast_weight: float = 1.0,
+    contrast_min_max: tuple[float, float] | None = None,
+    saturation_weight: float = 1.0,
+    saturation_min_max: tuple[float, float] | None = None,
+    hue_weight: float = 1.0,
+    hue_min_max: tuple[float, float] | None = None,
+    sharpness_weight: float = 1.0,
+    sharpness_min_max: tuple[float, float] | None = None,
+    max_num_transforms: int | None = None,
+    random_order: bool = False,
+):
+    def check_value(name, weight, min_max):
+        if min_max is not None:
+            if len(min_max) != 2:
+                raise ValueError(
+                    f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
+                )
+            if weight < 0.0:
+                raise ValueError(
+                    f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
+                )
+
+    check_value("brightness", brightness_weight, brightness_min_max)
+    check_value("contrast", contrast_weight, contrast_min_max)
+    check_value("saturation", saturation_weight, saturation_min_max)
+    check_value("hue", hue_weight, hue_min_max)
+    check_value("sharpness", sharpness_weight, sharpness_min_max)
+
+    weights = []
+    transforms = []
+    if brightness_min_max is not None and brightness_weight > 0.0:
+        weights.append(brightness_weight)
+        transforms.append(v2.ColorJitter(brightness=brightness_min_max))
+    if contrast_min_max is not None and contrast_weight > 0.0:
+        weights.append(contrast_weight)
+        transforms.append(v2.ColorJitter(contrast=contrast_min_max))
+    if saturation_min_max is not None and saturation_weight > 0.0:
+        weights.append(saturation_weight)
+        transforms.append(v2.ColorJitter(saturation=saturation_min_max))
+    if hue_min_max is not None and hue_weight > 0.0:
+        weights.append(hue_weight)
+        transforms.append(v2.ColorJitter(hue=hue_min_max))
+    if sharpness_min_max is not None and sharpness_weight > 0.0:
+        weights.append(sharpness_weight)
+        transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
+
+    n_subset = len(transforms)
+    if max_num_transforms is not None:
+        n_subset = min(n_subset, max_num_transforms)
+
+    if n_subset == 0:
+        return v2.Identity()
+    else:
+        # TODO(rcadene, aliberts): add v2.ToDtype float16?
+        return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
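For readers skimming the new transforms module above: `get_image_transforms` returns a single callable (an identity transform when nothing is enabled), so it can be dropped into an existing preprocessing pipeline. A usage sketch; the import path is inferred from this commit's layout, and the values mirror the defaults added to the YAML config further below:

```python
import torch

# Import path inferred from this commit; adjust if the module lives elsewhere.
from lerobot.common.datasets.transforms import get_image_transforms

transforms = get_image_transforms(
    brightness_weight=1.0,
    brightness_min_max=(0.8, 1.2),
    sharpness_weight=1.0,
    sharpness_min_max=(0.8, 1.2),
    max_num_transforms=1,  # apply at most one of the two, sampled by weight
    random_order=False,
)

frame = torch.rand(3, 96, 96)  # float image tensor in [0, 1], shape (C, H, W)
augmented = transforms(frame)
assert augmented.shape == frame.shape
```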
@@ -238,5 +238,6 @@ class Logger:

     def log_video(self, video_path: str, step: int, mode: str = "train"):
         assert mode in {"train", "eval"}
+        assert self._wandb is not None
         wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
         self._wandb.log({f"{mode}/video": wandb_video}, step=step)
@@ -239,10 +239,8 @@ class DiffusionModel(nn.Module):
         global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

         # run sampling
-        sample = self.conditional_sample(batch_size, global_cond=global_cond)
+        actions = self.conditional_sample(batch_size, global_cond=global_cond)

-        # `horizon` steps worth of actions (from the first observation).
-        actions = sample[..., : self.config.output_shapes["action"][0]]
         # Extract `n_action_steps` steps worth of actions (from the current observation).
         start = n_obs_steps - 1
         end = start + self.config.n_action_steps
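The diffusion change above is pure bookkeeping: `conditional_sample` now yields action-dimensional samples directly, so the trailing-dimension slice through `output_shapes["action"]` disappears while the temporal slice stays. With illustrative shapes (not taken from a real config):

```python
import torch

batch_size, horizon, action_dim = 2, 16, 7
n_obs_steps, n_action_steps = 2, 8

# Stand-in for the output of conditional_sample: (batch, horizon, action_dim).
actions = torch.zeros(batch_size, horizon, action_dim)

# Extract `n_action_steps` steps worth of actions from the current observation.
start = n_obs_steps - 1
end = start + n_action_steps
selected = actions[:, start:end]
assert selected.shape == (batch_size, n_action_steps, action_dim)
```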
@@ -147,7 +147,7 @@ class Normalize(nn.Module):
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
                 # normalize to [0,1]
-                batch[key] = (batch[key] - min) / (max - min)
+                batch[key] = (batch[key] - min) / (max - min + 1e-8)
                 # normalize to [-1, 1]
                 batch[key] = batch[key] * 2 - 1
             else:
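The `1e-8` added above is a standard guard: when a feature is constant in the dataset stats, `max - min` is zero and min-max normalization would emit NaNs. A short demonstration:

```python
import torch

x = torch.tensor([2.0, 2.0, 2.0])  # constant feature: max == min
min_, max_ = x.min(), x.max()

print((x - min_) / (max_ - min_))         # tensor([nan, nan, nan])
print((x - min_) / (max_ - min_ + 1e-8))  # tensor([0., 0., 0.])
```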
@@ -57,7 +57,7 @@ class Policy(Protocol):
        other items should be logging-friendly, native Python types.
        """

-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Return one action to run in the environment (potentially in batch mode).

        When the model uses a history of observations, or outputs a sequence of actions, this method deals
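Annotating `select_action` on the `Policy` protocol (here, and on `TDMPCPolicy` just below) lets structural type checkers verify call sites. A compressed sketch of why the return annotation matters (simplified protocol, not the repo's full definition):

```python
from typing import Protocol

from torch import Tensor


class Policy(Protocol):
    def select_action(self, batch: dict[str, Tensor]) -> Tensor: ...


def step(policy: Policy, batch: dict[str, Tensor]) -> Tensor:
    # mypy/pyright can now check that select_action returns a Tensor,
    # instead of treating the result as Any.
    return policy.select_action(batch)
```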
@@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         self._prev_mean: torch.Tensor | None = None

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations."""
         batch = self.normalize_inputs(batch)
         batch["observation.image"] = batch[self.input_image_key]
@@ -43,6 +43,40 @@ training:
   save_checkpoint: true
   num_workers: 4
   batch_size: ???
+  image_transforms:
+    # These transforms are all using standard torchvision.transforms.v2
+    # You can find out how these transformations affect images here:
+    # https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
+    # We use a custom RandomSubsetApply container to sample them.
+    # For each transform, the following parameters are available:
+    #   weight: This represents the multinomial probability (with no replacement)
+    #           used for sampling the transform. If the sum of the weights is not 1,
+    #           they will be normalized.
+    #   min_max: Lower & upper bound respectively used for sampling the transform's parameter
+    #            (following uniform distribution) when it's applied.
+    # Set this flag to `true` to enable transforms during training
+    enable: false
+    # This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
+    # It's an integer in the interval [1, number of available transforms].
+    max_num_transforms: 3
+    # By default, transforms are applied in Torchvision's suggested order (shown below).
+    # Set this to True to apply them in a random order.
+    random_order: false
+    brightness:
+      weight: 1
+      min_max: [0.8, 1.2]
+    contrast:
+      weight: 1
+      min_max: [0.8, 1.2]
+    saturation:
+      weight: 1
+      min_max: [0.5, 1.5]
+    hue:
+      weight: 1
+      min_max: [-0.05, 0.05]
+    sharpness:
+      weight: 1
+      min_max: [0.8, 1.2]

 eval:
   n_episodes: 1
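How this `training.image_transforms` node reaches `get_image_transforms` is not part of this diff. A plausible wiring in the dataset factory could look like the following; this is hypothetical glue code, with `cfg` standing for the resolved `image_transforms` sub-config and the import path inferred from this commit:

```python
from lerobot.common.datasets.transforms import get_image_transforms


def make_image_transforms(cfg):
    """Map the `training.image_transforms` Hydra node onto get_image_transforms()."""
    if not cfg.enable:
        return None
    return get_image_transforms(
        brightness_weight=cfg.brightness.weight,
        brightness_min_max=tuple(cfg.brightness.min_max),
        contrast_weight=cfg.contrast.weight,
        contrast_min_max=tuple(cfg.contrast.min_max),
        saturation_weight=cfg.saturation.weight,
        saturation_min_max=tuple(cfg.saturation.min_max),
        hue_weight=cfg.hue.weight,
        hue_min_max=tuple(cfg.hue.min_max),
        sharpness_weight=cfg.sharpness.weight,
        sharpness_min_max=tuple(cfg.sharpness.min_max),
        max_num_transforms=cfg.max_num_transforms,
        random_order=cfg.random_order,
    )
```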
@@ -13,39 +13,71 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""Use this script to get a quick summary of your system config.
+It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
+"""
+
 import platform

-import huggingface_hub
+HAS_HF_HUB = True
+HAS_HF_DATASETS = True
+HAS_NP = True
+HAS_TORCH = True
+HAS_LEROBOT = True

-# import dataset
-import numpy as np
-import torch
+try:
+    import huggingface_hub
+except ImportError:
+    HAS_HF_HUB = False

-from lerobot import __version__ as version
+try:
+    import datasets
+except ImportError:
+    HAS_HF_DATASETS = False

-pt_version = torch.__version__
-pt_cuda_available = torch.cuda.is_available()
-pt_cuda_available = torch.cuda.is_available()
-cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not None else "N/A"
+try:
+    import numpy as np
+except ImportError:
+    HAS_NP = False
+
+try:
+    import torch
+except ImportError:
+    HAS_TORCH = False
+
+try:
+    import lerobot
+except ImportError:
+    HAS_LEROBOT = False
+
+lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A"
+hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
+hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
+np_version = np.__version__ if HAS_NP else "N/A"
+
+torch_version = torch.__version__ if HAS_TORCH else "N/A"
+torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
+cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"


 # TODO(aliberts): refactor into an actual command `lerobot env`
 def display_sys_info() -> dict:
     """Run this to get basic system info to help for tracking issues & bugs."""
     info = {
-        "`lerobot` version": version,
+        "`lerobot` version": lerobot_version,
         "Platform": platform.platform(),
         "Python version": platform.python_version(),
-        "Huggingface_hub version": huggingface_hub.__version__,
-        # TODO(aliberts): Add dataset when https://github.com/huggingface/lerobot/pull/73 is merged
-        # "Dataset version": dataset.__version__,
-        "Numpy version": np.__version__,
-        "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+        "Huggingface_hub version": hf_hub_version,
+        "Dataset version": hf_datasets_version,
+        "Numpy version": np_version,
+        "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
         "Cuda version": cuda_version,
         "Using GPU in script?": "<fill in>",
-        "Using distributed or parallel set-up in script?": "<fill in>",
+        # "Using distributed or parallel set-up in script?": "<fill in>",
     }
-    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
+    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
     print(format_dict(info))
     return info
@@ -61,7 +61,7 @@ from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from huggingface_hub.utils._validators import HFValidationError
 from PIL import Image as PILImage
-from torch import Tensor
+from torch import Tensor, nn
 from tqdm import trange

 from lerobot.common.datasets.factory import make_dataset
@@ -99,13 +99,13 @@ def rollout(
         "reward": A (batch, sequence) tensor of rewards received for applying the actions.
         "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
             environment termination/truncation).
-        "don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
+        "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
             the first True is followed by True's all the way till the end. This can be used for masking
             extraneous elements from the sequences above.

     Args:
         env: The batch of environments.
-        policy: The policy.
+        policy: The policy. Must be a PyTorch nn module.
         seeds: The environments are seeded once at the start of the rollout. If provided, this argument
             specifies the seeds for each of the environments.
         return_observations: Whether to include all observations in the returned rollout data. Observations
@@ -116,6 +116,7 @@ def rollout(
     Returns:
         The dictionary described above.
     """
+    assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
     device = get_device_from_parameters(policy)

     # Reset the policy and environments.
@@ -209,7 +210,7 @@ def eval_policy(
     policy: torch.nn.Module,
     n_episodes: int,
     max_episodes_rendered: int = 0,
-    video_dir: Path | None = None,
+    videos_dir: Path | None = None,
     return_episode_data: bool = False,
     start_seed: int | None = None,
     enable_progbar: bool = False,
@@ -221,7 +222,7 @@ def eval_policy(
         policy: The policy.
         n_episodes: The number of episodes to evaluate.
         max_episodes_rendered: Maximum number of episodes to render into videos.
-        video_dir: Where to save rendered videos.
+        videos_dir: Where to save rendered videos.
         return_episode_data: Whether to return episode data for online training. Incorporates the data into
             the "episodes" key of the returned dictionary.
         start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
@@ -231,6 +232,10 @@ def eval_policy(
     Returns:
         Dictionary with metrics and data regarding the rollouts.
     """
+    if max_episodes_rendered > 0 and not videos_dir:
+        raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
+
+    assert isinstance(policy, Policy)
     start = time.time()
     policy.eval()

@@ -271,11 +276,16 @@ def eval_policy(
         if max_episodes_rendered > 0:
             ep_frames: list[np.ndarray] = []

-        seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
+        if start_seed is None:
+            seeds = None
+        else:
+            seeds = range(
+                start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
+            )
         rollout_data = rollout(
             env,
             policy,
-            seeds=seeds,
+            seeds=list(seeds) if seeds else None,
             return_observations=return_episode_data,
             render_callback=render_frame if max_episodes_rendered > 0 else None,
             enable_progbar=enable_inner_progbar,
@@ -285,7 +295,8 @@ def eval_policy(
         # this won't be included).
         n_steps = rollout_data["done"].shape[1]
         # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
-        done_indices = torch.argmax(rollout_data["done"].to(int), axis=1)  # (batch_size, rollout_steps)
+        done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)

         # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
         # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
         mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
@@ -296,8 +307,12 @@ def eval_policy(
         max_rewards.extend(batch_max_rewards.tolist())
         batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
         all_successes.extend(batch_successes.tolist())
+        if seeds:
             all_seeds.extend(seeds)
+        else:
+            all_seeds.append(None)

+        # FIXME: episode_data is either None or it doesn't exist
         if return_episode_data:
             this_episode_data = _compile_episode_data(
                 rollout_data,
@@ -347,8 +362,9 @@ def eval_policy(
         ):
             if n_episodes_rendered >= max_episodes_rendered:
                 break
-            video_dir.mkdir(parents=True, exist_ok=True)
-            video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
+            videos_dir.mkdir(parents=True, exist_ok=True)
+            video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
             video_paths.append(str(video_path))
             thread = threading.Thread(
                 target=write_video,
@@ -503,22 +519,20 @@ def _compile_episode_data(
     }


-def eval(
-    pretrained_policy_path: str | None = None,
+def main(
+    pretrained_policy_path: Path | None = None,
     hydra_cfg_path: str | None = None,
+    out_dir: str | None = None,
     config_overrides: list[str] | None = None,
 ):
     assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
-    if hydra_cfg_path is None:
-        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
+    if pretrained_policy_path is not None:
+        hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
     else:
         hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
-    out_dir = (
-        f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
-    )

     if out_dir is None:
-        raise NotImplementedError()
+        out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"

     # Check device is available
     device = get_safe_torch_device(hydra_cfg.device, log=True)
@@ -534,10 +548,12 @@ def eval(

     logging.info("Making policy.")
     if hydra_cfg_path is None:
-        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
+        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
     else:
         # Note: We need the dataset stats to pass to the policy's normalization modules.
         policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)

+    assert isinstance(policy, nn.Module)
     policy.eval()

     with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
@@ -546,7 +562,7 @@ def eval(
             policy,
             hydra_cfg.eval.n_episodes,
             max_episodes_rendered=10,
-            video_dir=Path(out_dir) / "eval",
+            videos_dir=Path(out_dir) / "videos",
             start_seed=hydra_cfg.seed,
             enable_progbar=True,
             enable_inner_progbar=True,
@@ -586,6 +602,13 @@ if __name__ == "__main__":
         ),
     )
     parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
+    parser.add_argument(
+        "--out-dir",
+        help=(
+            "Where to save the evaluation outputs. If not provided, outputs are saved in "
+            "outputs/eval/{timestamp}_{env_name}_{policy_name}"
+        ),
+    )
     parser.add_argument(
         "overrides",
         nargs="*",
@@ -594,7 +617,7 @@ if __name__ == "__main__":
     args = parser.parse_args()

     if args.pretrained_policy_name_or_path is None:
-        eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
+        main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
     else:
         try:
             pretrained_policy_path = Path(
@@ -618,4 +641,8 @@ if __name__ == "__main__":
                 "repo ID, nor is it an existing local directory."
             )

-        eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
+        main(
+            pretrained_policy_path=pretrained_policy_path,
+            out_dir=args.out_dir,
+            config_overrides=args.overrides,
+        )
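Since `eval` (which shadowed the Python builtin) is now `main` and takes an explicit `out_dir`, a direct call from the same module looks like the following; the checkpoint path and override values are illustrative, not taken from this diff:

```python
from pathlib import Path

main(
    pretrained_policy_path=Path("outputs/train/run/checkpoints/last/pretrained_model"),
    out_dir="outputs/eval/run",
    config_overrides=["eval.n_episodes=10"],
)
```

The same `out_dir` can be passed from the command line through the new `--out-dir` flag added to the argument parser above.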
@ -18,57 +18,39 @@ Use this script to convert your dataset into LeRobot dataset format and upload i
|
||||||
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
||||||
installation of neural net specific packages like pytorch, tensorflow, jax.
|
installation of neural net specific packages like pytorch, tensorflow, jax.
|
||||||
|
|
||||||
Example:
|
Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
|
||||||
```
|
```
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--data-dir data \
|
--raw-dir data/pusht_raw \
|
||||||
--dataset-id pusht \
|
|
||||||
--raw-format pusht_zarr \
|
--raw-format pusht_zarr \
|
||||||
--community-id lerobot \
|
--repo-id lerobot/pusht
|
||||||
--dry-run 1 \
|
|
||||||
--save-to-disk 1 \
|
|
||||||
--save-tests-to-disk 0 \
|
|
||||||
--debug 1
|
|
||||||
|
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--data-dir data \
|
--raw-dir data/xarm_lift_medium_raw \
|
||||||
--dataset-id xarm_lift_medium \
|
|
||||||
--raw-format xarm_pkl \
|
--raw-format xarm_pkl \
|
||||||
--community-id lerobot \
|
--repo-id lerobot/xarm_lift_medium
|
||||||
--dry-run 1 \
|
|
||||||
--save-to-disk 1 \
|
|
||||||
--save-tests-to-disk 0 \
|
|
||||||
--debug 1
|
|
||||||
|
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--data-dir data \
|
--raw-dir data/aloha_sim_insertion_scripted_raw \
|
||||||
--dataset-id aloha_sim_insertion_scripted \
|
|
||||||
--raw-format aloha_hdf5 \
|
--raw-format aloha_hdf5 \
|
||||||
--community-id lerobot \
|
--repo-id lerobot/aloha_sim_insertion_scripted
|
||||||
--dry-run 1 \
|
|
||||||
--save-to-disk 1 \
|
|
||||||
--save-tests-to-disk 0 \
|
|
||||||
--debug 1
|
|
||||||
|
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--data-dir data \
|
--raw-dir data/umi_cup_in_the_wild_raw \
|
||||||
--dataset-id umi_cup_in_the_wild \
|
|
||||||
--raw-format umi_zarr \
|
--raw-format umi_zarr \
|
||||||
--community-id lerobot \
|
--repo-id lerobot/umi_cup_in_the_wild
|
||||||
--dry-run 1 \
|
|
||||||
--save-to-disk 1 \
|
|
||||||
--save-tests-to-disk 0 \
|
|
||||||
--debug 1
|
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi, create_branch
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import compute_stats
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
|
@ -77,15 +59,15 @@ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_r
|
||||||
from lerobot.common.datasets.utils import flatten_dict
|
from lerobot.common.datasets.utils import flatten_dict
|
||||||
|
|
||||||
|
|
||||||
def get_from_raw_to_lerobot_format_fn(raw_format):
|
def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||||
if raw_format == "pusht_zarr":
|
if raw_format == "pusht_zarr":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "umi_zarr":
|
elif raw_format == "umi_zarr":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "aloha_hdf5":
|
elif raw_format == "aloha_hdf5":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "aloha_dora":
|
elif raw_format == "dora_parquet":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "xarm_pkl":
|
elif raw_format == "xarm_pkl":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||||
else:
|
else:
|
||||||
|
@ -96,7 +78,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
|
||||||
return from_raw_to_lerobot_format
|
return from_raw_to_lerobot_format
|
||||||
|
|
||||||
|
|
||||||
def save_meta_data(info, stats, episode_data_index, meta_data_dir):
|
def save_meta_data(
|
||||||
|
info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
|
||||||
|
):
|
||||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# save info
|
# save info
|
||||||
|
@ -114,7 +98,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
|
||||||
save_file(episode_data_index, ep_data_idx_path)
|
save_file(episode_data_index, ep_data_idx_path)
|
||||||
|
|
||||||
|
|
||||||
def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
|
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
|
||||||
"""Expect all meta data files to be all stored in a single "meta_data" directory.
|
"""Expect all meta data files to be all stored in a single "meta_data" directory.
|
||||||
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
|
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
|
||||||
"""
|
"""
|
||||||
|
@ -128,7 +112,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def push_videos_to_hub(repo_id, videos_dir, revision):
|
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
|
||||||
"""Expect mp4 files to be all stored in a single "videos" directory.
|
"""Expect mp4 files to be all stored in a single "videos" directory.
|
||||||
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
|
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
|
||||||
"""
|
"""
|
||||||
|
@ -144,39 +128,61 @@ def push_videos_to_hub(repo_id, videos_dir, revision):
|
||||||
|
|
||||||
|
|
||||||
def push_dataset_to_hub(
|
def push_dataset_to_hub(
|
||||||
data_dir: Path,
|
raw_dir: Path,
|
||||||
dataset_id: str,
|
raw_format: str,
|
||||||
raw_format: str | None,
|
repo_id: str,
|
||||||
community_id: str,
|
push_to_hub: bool = True,
|
||||||
revision: str,
|
local_dir: Path | None = None,
|
||||||
dry_run: bool,
|
fps: int | None = None,
|
||||||
save_to_disk: bool,
|
video: bool = True,
|
||||||
tests_data_dir: Path,
|
batch_size: int = 32,
|
||||||
save_tests_to_disk: bool,
|
num_workers: int = 8,
|
||||||
fps: int | None,
|
episodes: list[int] | None = None,
|
||||||
video: bool,
|
force_override: bool = False,
|
||||||
batch_size: int,
|
cache_dir: Path = Path("/tmp"),
|
||||||
num_workers: int,
|
tests_data_dir: Path | None = None,
|
||||||
debug: bool,
|
|
||||||
):
|
):
|
||||||
repo_id = f"{community_id}/{dataset_id}"
|
# Check repo_id is well formated
|
||||||
|
if len(repo_id.split("/")) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
|
||||||
|
)
|
||||||
|
user_id, dataset_id = repo_id.split("/")
|
||||||
|
|
||||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
# Robustify when `raw_dir` is str instead of Path
|
||||||
|
raw_dir = Path(raw_dir)
|
||||||
|
if not raw_dir.exists():
|
||||||
|
raise NotADirectoryError(
|
||||||
|
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:"
|
||||||
|
f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
|
||||||
|
)
|
||||||
|
|
||||||
out_dir = data_dir / repo_id
|
if local_dir:
|
||||||
meta_data_dir = out_dir / "meta_data"
|
# Robustify when `local_dir` is str instead of Path
|
||||||
videos_dir = out_dir / "videos"
|
local_dir = Path(local_dir)
|
||||||
|
|
||||||
tests_out_dir = tests_data_dir / repo_id
|
# Send warning if local_dir isn't well formated
|
||||||
tests_meta_data_dir = tests_out_dir / "meta_data"
|
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
|
||||||
tests_videos_dir = tests_out_dir / "videos"
|
warnings.warn(
|
||||||
|
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
|
||||||
|
stacklevel=1,
|
||||||
|
)
|
||||||
|
|
||||||
if out_dir.exists():
|
# Check we don't override an existing `local_dir` by mistake
|
||||||
shutil.rmtree(out_dir)
|
if local_dir.exists():
|
||||||
|
if force_override:
|
||||||
|
shutil.rmtree(local_dir)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
|
||||||
|
|
||||||
if tests_out_dir.exists() and save_tests_to_disk:
|
meta_data_dir = local_dir / "meta_data"
|
||||||
shutil.rmtree(tests_out_dir)
|
videos_dir = local_dir / "videos"
|
||||||
|
else:
|
||||||
|
# Temporary directory used to store images, videos, meta_data
|
||||||
|
meta_data_dir = Path(cache_dir) / "meta_data"
|
||||||
|
videos_dir = Path(cache_dir) / "videos"
|
||||||
|
|
||||||
|
# Download the raw dataset if available
|
||||||
if not raw_dir.exists():
|
if not raw_dir.exists():
|
||||||
download_raw(raw_dir, dataset_id)
|
download_raw(raw_dir, dataset_id)
|
||||||
|
|
||||||
|
@ -185,14 +191,14 @@ def push_dataset_to_hub(
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
# raw_format = auto_find_raw_format(raw_dir)
|
# raw_format = auto_find_raw_format(raw_dir)
|
||||||
|
|
||||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
|
||||||
|
|
||||||
# convert dataset from original raw format to LeRobot format
|
# convert dataset from original raw format to LeRobot format
|
||||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
|
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||||
|
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||||
|
raw_dir, videos_dir, fps, video, episodes
|
||||||
|
)
|
||||||
|
|
||||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
version=revision,
|
|
||||||
hf_dataset=hf_dataset,
|
hf_dataset=hf_dataset,
|
||||||
episode_data_index=episode_data_index,
|
episode_data_index=episode_data_index,
|
||||||
info=info,
|
info=info,
|
||||||
|
@ -200,102 +206,80 @@ def push_dataset_to_hub(
|
||||||
)
|
)
|
||||||
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
|
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
|
||||||
|
|
||||||
if save_to_disk:
|
if local_dir:
|
||||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||||
hf_dataset.save_to_disk(str(out_dir / "train"))
|
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||||
|
|
||||||
if not dry_run or save_to_disk:
|
if push_to_hub or local_dir:
|
||||||
# mandatory for upload
|
# mandatory for upload
|
||||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||||
|
|
||||||
if not dry_run:
|
if push_to_hub:
|
||||||
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
|
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||||
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
|
|
||||||
|
|
||||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
|
|
||||||
|
|
||||||
if video:
|
if video:
|
||||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||||
push_videos_to_hub(repo_id, videos_dir, revision=revision)
|
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||||
|
|
||||||
if save_tests_to_disk:
|
if tests_data_dir:
|
||||||
# get the first episode
|
# get the first episode
|
||||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||||
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
||||||
|
|
||||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||||
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
|
test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
|
||||||
|
|
||||||
save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
|
tests_meta_data = tests_data_dir / repo_id / "meta_data"
|
||||||
|
save_meta_data(info, stats, episode_data_index, tests_meta_data)
|
||||||
|
|
||||||
# copy videos of first episode to tests directory
|
# copy videos of first episode to tests directory
|
||||||
episode_index = 0
|
episode_index = 0
|
||||||
|
tests_videos_dir = tests_data_dir / repo_id / "videos"
|
||||||
tests_videos_dir.mkdir(parents=True, exist_ok=True)
|
tests_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||||
for key in lerobot_dataset.video_frame_keys:
|
for key in lerobot_dataset.video_frame_keys:
|
||||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||||
|
|
||||||
if not save_to_disk and out_dir.exists():
|
if local_dir is None:
|
||||||
# remove possible temporary files remaining in the output directory
|
# clear cache
|
||||||
shutil.rmtree(out_dir)
|
shutil.rmtree(meta_data_dir)
|
||||||
|
shutil.rmtree(videos_dir)
|
||||||
|
|
||||||
|
return lerobot_dataset
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data-dir",
|
"--raw-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
required=True,
|
required=True,
|
||||||
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dataset-id",
|
|
||||||
type=str,
|
|
||||||
        required=True,
        help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
    )
    # TODO(rcadene): add automatic detection of the format
    parser.add_argument(
        "--raw-format",
        type=str,
        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
        required=True,
        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
    )
    parser.add_argument(
        "--community-id",
        "--repo-id",
        type=str,
        default="lerobot",
        required=True,
        help="Community or user ID under which the dataset will be hosted on the Hub.",
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    parser.add_argument(
        "--revision",
        "--local-dir",
        type=str,
        default=CODEBASE_VERSION,
        help="Codebase version used to generate the dataset.",
    )
    parser.add_argument(
        "--dry-run",
        type=int,
        default=0,
        help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
    )
    parser.add_argument(
        "--save-to-disk",
        type=int,
        default=1,
        help="Save the dataset in the directory specified by `--data-dir`.",
    )
    parser.add_argument(
        "--tests-data-dir",
        type=Path,
        default="tests/data",
        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
        help="Directory containing tests artifacts datasets.",
    )
    parser.add_argument(
        "--save-tests-to-disk",
        "--push-to-hub",
        type=int,
        default=1,
        help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
        help="Upload to hub.",
    )
    parser.add_argument(
        "--fps",
@ -321,10 +305,21 @@ def main():
        help="Number of processes of Dataloader for computing the dataset statistics.",
    )
    parser.add_argument(
        "--debug",
        "--episodes",
        type=int,
        nargs="*",
        help="When provided, only converts the provided episodes (e.g. `--episodes 2 3 4`). Useful to test the code on 1 episode.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Debug mode process the first episode only.",
        help="When set to 1, removes the provided output directory if it already exists. By default, raises a ValueError exception.",
    )
    parser.add_argument(
        "--tests-data-dir",
        type=Path,
        help="When provided, saves tests artifacts into the given directory (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
    )

    args = parser.parse_args()
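Put together, the new arguments map onto a single programmatic call. The sketch below mirrors the keyword usage in tests/test_push_dataset_to_hub.py further down this diff; the paths are placeholders, not files shipped by this commit.

```python
from pathlib import Path

from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub

# Convert a locally stored raw dataset without uploading it to the Hub.
push_dataset_to_hub(
    raw_dir=Path("data/pusht_raw"),        # placeholder raw data directory
    raw_format="pusht_zarr",               # one of the formats listed in --raw-format
    repo_id="lerobot/pusht",               # community or user name `/` dataset name
    push_to_hub=False,                     # mirrors --push-to-hub 0
    local_dir=Path("data/lerobot/pusht"),  # mirrors --local-dir
    episodes=[0],                          # mirrors --episodes 0 (single-episode test run)
)
```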
@ -24,6 +24,7 @@ import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler

from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
@ -150,6 +151,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
    grad_norm = info["grad_norm"]
    lr = info["lr"]
    update_s = info["update_s"]
    dataloading_s = info["dataloading_s"]

    # A sample is an (observation, action) pair, where observation and action
    # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
@ -170,6 +172,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
        f"lr:{lr:0.1e}",
        # in seconds
        f"updt_s:{update_s:.3f}",
        f"data_s:{dataloading_s:.3f}",  # if not ~0, you are bottlenecked by cpu or io
    ]
    logging.info(" ".join(log_items))
@ -290,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    # Create environment used for evaluating checkpoints during training on simulation data.
    # On real-world data, no need to create an environment as evaluations are done outside train.py,
    # using eval.py instead, with the gym_dora environment and dora-rs.
    eval_env = None
    if cfg.training.eval_freq > 0:
        logging.info("make_env")
        eval_env = make_env(cfg)
@ -300,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        dataset_stats=offline_dataset.stats if not cfg.resume else None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
    )
    assert isinstance(policy, nn.Module)
    # Create optimizer and scheduler
    # Temporary hack to move optimizer out of policy
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@ -325,14 +329,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

    # Note: this helper will be used in offline and online training loops.
    def evaluate_and_checkpoint_if_needed(step):
        _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
        step_identifier = f"{step:0{_num_digits}d}"

        if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
            logging.info(f"Eval policy at step {step}")
            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
                assert eval_env is not None
                eval_info = eval_policy(
                    eval_env,
                    policy,
                    cfg.eval.n_episodes,
                    video_dir=Path(out_dir) / "eval",
                    videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
                    max_episodes_rendered=4,
                    start_seed=cfg.seed,
                )
@ -350,9 +358,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                policy,
                optimizer,
                lr_scheduler,
                identifier=str(step).zfill(
                    max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
                ),
                identifier=step_identifier,
            )
            logging.info("Resume training")

@ -382,7 +388,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    for _ in range(step, cfg.training.offline_steps):
        if step == 0:
            logging.info("Start offline training on a fixed dataset")

        start_time = time.perf_counter()
        batch = next(dl_iter)
        dataloading_s = time.perf_counter() - start_time

        for key in batch:
            batch[key] = batch[key].to(device, non_blocking=True)
@ -397,6 +406,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
            use_amp=cfg.use_amp,
        )

        train_info["dataloading_s"] = dataloading_s

        if step % cfg.training.log_freq == 0:
            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)

@ -406,6 +417,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

        step += 1

    if eval_env:
        eval_env.close()
    logging.info("End of training")
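The comment above about samples spanning multiple timestamps is easiest to see on a concrete batch. The shapes below are illustrative assumptions (they depend on the policy's `delta_timestamps` configuration), not values fixed by this commit:

```python
import torch

# Stand-in for one batch from `next(dl_iter)`; keys and shapes are assumptions.
batch = {
    "observation.state": torch.zeros(2, 1, 14),  # (batch_size, n_obs_steps, state_dim)
    "action": torch.zeros(2, 20, 14),            # (batch_size, n_action_steps, action_dim)
}
# Each of the `batch_size` samples is one (observation, action) pair,
# where both parts can stack several timestamps along dim 1.
num_samples = batch["action"].shape[0]
assert num_samples == 2
```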
@ -66,28 +66,31 @@ import gc
import logging
import time
from pathlib import Path
from typing import Iterator

import numpy as np
import rerun as rr
import torch
import torch.utils.data
import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


class EpisodeSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, episode_index):
    def __init__(self, dataset: LeRobotDataset, episode_index: int):
        from_idx = dataset.episode_data_index["from"][episode_index].item()
        to_idx = dataset.episode_data_index["to"][episode_index].item()
        self.frame_ids = range(from_idx, to_idx)

    def __iter__(self):
    def __iter__(self) -> Iterator:
        return iter(self.frame_ids)

    def __len__(self):
    def __len__(self) -> int:
        return len(self.frame_ids)


def to_hwc_uint8_numpy(chw_float32_torch):
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    assert chw_float32_torch.dtype == torch.float32
    assert chw_float32_torch.ndim == 3
    c, h, w = chw_float32_torch.shape
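The `EpisodeSampler` above is what lets the script walk one episode's frames in order; a minimal usage sketch (assuming the `lerobot/pusht` dataset is available locally or on the Hub):

```python
import torch

# EpisodeSampler is defined at module level in lerobot/scripts/visualize_dataset.py.
from lerobot.scripts.visualize_dataset import EpisodeSampler
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")
sampler = EpisodeSampler(dataset, episode_index=0)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
for batch in dataloader:
    pass  # every frame in `batch` belongs to episode 0, in chronological order
```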
@ -106,6 +109,7 @@ def visualize_dataset(
    ws_port: int = 9087,
    save: bool = False,
    output_dir: Path | None = None,
    root: Path | None = None,
) -> Path | None:
    if save:
        assert (
@ -113,7 +117,7 @@ def visualize_dataset(
        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."

    logging.info("Loading dataset")
    dataset = LeRobotDataset(repo_id)
    dataset = LeRobotDataset(repo_id, root=root)

    logging.info("Loading dataloader")
    episode_sampler = EpisodeSampler(dataset, episode_index)
@ -224,7 +228,8 @@ def main():
        help=(
            "Mode of viewing: 'local' or 'distant'. "
            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
            "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
            "'distant' creates a server on the distant machine where the data is stored. "
            "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
        ),
    )
    parser.add_argument(
@ -245,8 +250,8 @@ def main():
        default=0,
        help=(
            "Save a .rrd file in the directory provided by `--output-dir`. "
            "It also deactivates the spawning of a viewer. ",
            "It also deactivates the spawning of a viewer. "
            "Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
        ),
    )
    parser.add_argument(
@ -0,0 +1,142 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Visualize the effects of image transforms for a given configuration.

This script generates examples of transformed images as they are output by a LeRobot dataset.
Additionally, each individual transform can be visualized separately, as well as examples of combined transforms.

--- Usage Examples ---

Increase hue jitter
```
python lerobot/scripts/visualize_image_transforms.py \
    dataset_repo_id=lerobot/aloha_mobile_shrimp \
    training.image_transforms.hue.min_max=[-0.25,0.25]
```

Increase brightness & brightness weight
```
python lerobot/scripts/visualize_image_transforms.py \
    dataset_repo_id=lerobot/aloha_mobile_shrimp \
    training.image_transforms.brightness.weight=10.0 \
    training.image_transforms.brightness.min_max=[1.0,2.0]
```

Blur images and disable saturation & hue
```
python lerobot/scripts/visualize_image_transforms.py \
    dataset_repo_id=lerobot/aloha_mobile_shrimp \
    training.image_transforms.sharpness.weight=10.0 \
    training.image_transforms.sharpness.min_max=[0.0,1.0] \
    training.image_transforms.saturation.weight=0.0 \
    training.image_transforms.hue.weight=0.0
```

Use all transforms with random order
```
python lerobot/scripts/visualize_image_transforms.py \
    dataset_repo_id=lerobot/aloha_mobile_shrimp \
    training.image_transforms.max_num_transforms=5 \
    training.image_transforms.random_order=true
```
"""

from pathlib import Path

import hydra
from torchvision.transforms import ToPILImage

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms

OUTPUT_DIR = Path("outputs/image_transforms")
N_EXAMPLES = 5
to_pil = ToPILImage()


def save_config_all_transforms(cfg, original_frame, output_dir):
    tf = get_image_transforms(
        brightness_weight=cfg.brightness.weight,
        brightness_min_max=cfg.brightness.min_max,
        contrast_weight=cfg.contrast.weight,
        contrast_min_max=cfg.contrast.min_max,
        saturation_weight=cfg.saturation.weight,
        saturation_min_max=cfg.saturation.min_max,
        hue_weight=cfg.hue.weight,
        hue_min_max=cfg.hue.min_max,
        sharpness_weight=cfg.sharpness.weight,
        sharpness_min_max=cfg.sharpness.min_max,
        max_num_transforms=cfg.max_num_transforms,
        random_order=cfg.random_order,
    )

    output_dir_all = output_dir / "all"
    output_dir_all.mkdir(parents=True, exist_ok=True)

    for i in range(1, N_EXAMPLES + 1):
        transformed_frame = tf(original_frame)
        to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)

    print("Combined transforms examples saved to:")
    print(f"    {output_dir_all}")


def save_config_single_transforms(cfg, original_frame, output_dir):
    transforms = [
        "brightness",
        "contrast",
        "saturation",
        "hue",
        "sharpness",
    ]
    print("Individual transforms examples saved to:")
    for transform in transforms:
        kwargs = {
            f"{transform}_weight": cfg[f"{transform}"].weight,
            f"{transform}_min_max": cfg[f"{transform}"].min_max,
        }
        tf = get_image_transforms(**kwargs)
        output_dir_single = output_dir / f"{transform}"
        output_dir_single.mkdir(parents=True, exist_ok=True)

        for i in range(1, N_EXAMPLES + 1):
            transformed_frame = tf(original_frame)
            to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)

        print(f"    {output_dir_single}")


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms(cfg):
    dataset = LeRobotDataset(cfg.dataset_repo_id)

    output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get 1st frame from 1st camera of 1st episode
    original_frame = dataset[0][dataset.camera_keys[0]]
    to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
    print("\nOriginal frame saved to:")
    print(f"    {output_dir / 'original_frame.png'}.")

    save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
    save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)


if __name__ == "__main__":
    visualize_transforms()
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a
size 3686488
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0e3b4bde97c34606536b655c1e6a23316c9157bd21dcbc73a97500fb985607f
size 40551392
@ -0,0 +1,86 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

import torch
from safetensors.torch import save_file

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH


def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
    cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
    cfg_tf = cfg.training.image_transforms
    default_tf = get_image_transforms(
        brightness_weight=cfg_tf.brightness.weight,
        brightness_min_max=cfg_tf.brightness.min_max,
        contrast_weight=cfg_tf.contrast.weight,
        contrast_min_max=cfg_tf.contrast.min_max,
        saturation_weight=cfg_tf.saturation.weight,
        saturation_min_max=cfg_tf.saturation.min_max,
        hue_weight=cfg_tf.hue.weight,
        hue_min_max=cfg_tf.hue.min_max,
        sharpness_weight=cfg_tf.sharpness.weight,
        sharpness_min_max=cfg_tf.sharpness.min_max,
        max_num_transforms=cfg_tf.max_num_transforms,
        random_order=cfg_tf.random_order,
    )

    with seeded_context(1337):
        img_tf = default_tf(original_frame)

    save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")


def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
    transforms = {
        "brightness": [(0.5, 0.5), (2.0, 2.0)],
        "contrast": [(0.5, 0.5), (2.0, 2.0)],
        "saturation": [(0.5, 0.5), (2.0, 2.0)],
        "hue": [(-0.25, -0.25), (0.25, 0.25)],
        "sharpness": [(0.5, 0.5), (2.0, 2.0)],
    }

    frames = {"original_frame": original_frame}
    for transform, values in transforms.items():
        for min_max in values:
            kwargs = {
                f"{transform}_weight": 1.0,
                f"{transform}_min_max": min_max,
            }
            tf = get_image_transforms(**kwargs)
            key = f"{transform}_{min_max[0]}_{min_max[1]}"
            frames[key] = tf(original_frame)

    save_file(frames, output_dir / "single_transforms.safetensors")


def main():
    dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
    output_dir = Path(ARTIFACT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    original_frame = dataset[0][dataset.camera_keys[0]]

    save_single_transforms(original_frame, output_dir)
    save_default_config_transform(original_frame, output_dir)


if __name__ == "__main__":
    main()
@ -0,0 +1,260 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

import numpy as np
import pytest
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F  # noqa: N812

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel

ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"


def load_png_to_tensor(path: Path):
    return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)


@pytest.fixture
def img():
    dataset = LeRobotDataset(DATASET_REPO_ID)
    return dataset[0][dataset.camera_keys[0]]


@pytest.fixture
def img_random():
    return torch.rand(3, 480, 640)


@pytest.fixture
def color_jitters():
    return [
        v2.ColorJitter(brightness=0.5),
        v2.ColorJitter(contrast=0.5),
        v2.ColorJitter(saturation=0.5),
    ]


@pytest.fixture
def single_transforms():
    return load_file(ARTIFACT_DIR / "single_transforms.safetensors")


@pytest.fixture
def default_transforms():
    return load_file(ARTIFACT_DIR / "default_transforms.safetensors")


def test_get_image_transforms_no_transform(img):
    tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
    torch.testing.assert_close(tf_actual(img), img)


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img, min_max):
    tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
    tf_expected = v2.ColorJitter(brightness=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img, min_max):
    tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
    tf_expected = v2.ColorJitter(contrast=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img, min_max):
    tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
    tf_expected = v2.ColorJitter(saturation=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img, min_max):
    tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
    tf_expected = v2.ColorJitter(hue=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img, min_max):
    tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
    tf_expected = SharpnessJitter(sharpness=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


def test_get_image_transforms_max_num_transforms(img):
    tf_actual = get_image_transforms(
        brightness_min_max=(0.5, 0.5),
        contrast_min_max=(0.5, 0.5),
        saturation_min_max=(0.5, 0.5),
        hue_min_max=(0.5, 0.5),
        sharpness_min_max=(0.5, 0.5),
        random_order=False,
    )
    tf_expected = v2.Compose(
        [
            v2.ColorJitter(brightness=(0.5, 0.5)),
            v2.ColorJitter(contrast=(0.5, 0.5)),
            v2.ColorJitter(saturation=(0.5, 0.5)),
            v2.ColorJitter(hue=(0.5, 0.5)),
            SharpnessJitter(sharpness=(0.5, 0.5)),
        ]
    )
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@require_x86_64_kernel
def test_get_image_transforms_random_order(img):
    out_imgs = []
    tf = get_image_transforms(
        brightness_min_max=(0.5, 0.5),
        contrast_min_max=(0.5, 0.5),
        saturation_min_max=(0.5, 0.5),
        hue_min_max=(0.5, 0.5),
        sharpness_min_max=(0.5, 0.5),
        random_order=True,
    )
    with seeded_context(1337):
        for _ in range(10):
            out_imgs.append(tf(img))

    for i in range(1, len(out_imgs)):
        with pytest.raises(AssertionError):
            torch.testing.assert_close(out_imgs[0], out_imgs[i])


@pytest.mark.parametrize(
    "transform, min_max_values",
    [
        ("brightness", [(0.5, 0.5), (2.0, 2.0)]),
        ("contrast", [(0.5, 0.5), (2.0, 2.0)]),
        ("saturation", [(0.5, 0.5), (2.0, 2.0)]),
        ("hue", [(-0.25, -0.25), (0.25, 0.25)]),
        ("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
    ],
)
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
    for min_max in min_max_values:
        kwargs = {
            f"{transform}_weight": 1.0,
            f"{transform}_min_max": min_max,
        }
        tf = get_image_transforms(**kwargs)
        actual = tf(img)
        key = f"{transform}_{min_max[0]}_{min_max[1]}"
        expected = single_transforms[key]
        torch.testing.assert_close(actual, expected)


@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
    cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
    cfg_tf = cfg.training.image_transforms
    default_tf = get_image_transforms(
        brightness_weight=cfg_tf.brightness.weight,
        brightness_min_max=cfg_tf.brightness.min_max,
        contrast_weight=cfg_tf.contrast.weight,
        contrast_min_max=cfg_tf.contrast.min_max,
        saturation_weight=cfg_tf.saturation.weight,
        saturation_min_max=cfg_tf.saturation.min_max,
        hue_weight=cfg_tf.hue.weight,
        hue_min_max=cfg_tf.hue.min_max,
        sharpness_weight=cfg_tf.sharpness.weight,
        sharpness_min_max=cfg_tf.sharpness.min_max,
        max_num_transforms=cfg_tf.max_num_transforms,
        random_order=cfg_tf.random_order,
    )

    with seeded_context(1337):
        actual = default_tf(img)

    expected = default_transforms["default"]

    torch.testing.assert_close(actual, expected)


@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(p, img):
    flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
    random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
    actual = random_choice(img)

    p_horz, _ = p
    if p_horz:
        torch.testing.assert_close(actual, F.horizontal_flip(img))
    else:
        torch.testing.assert_close(actual, F.vertical_flip(img))


def test_random_subset_apply_random_order(img):
    flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
    random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
    # We can't really check whether the transforms are actually applied in random order. However,
    # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
    # applies them in random order, we can use a fixed order to compute the expected value.
    actual = random_order(img)
    expected = v2.Compose(flips)(img)
    torch.testing.assert_close(actual, expected)


def test_random_subset_apply_valid_transforms(color_jitters, img):
    transform = RandomSubsetApply(color_jitters)
    output = transform(img)
    assert output.shape == img.shape


def test_random_subset_apply_probability_length_mismatch(color_jitters):
    with pytest.raises(ValueError):
        RandomSubsetApply(color_jitters, p=[0.5, 0.5])


@pytest.mark.parametrize("n_subset", [0, 5])
def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
    with pytest.raises(ValueError):
        RandomSubsetApply(color_jitters, n_subset=n_subset)


def test_sharpness_jitter_valid_range_tuple(img):
    tf = SharpnessJitter((0.1, 2.0))
    output = tf(img)
    assert output.shape == img.shape


def test_sharpness_jitter_valid_range_float(img):
    tf = SharpnessJitter(0.5)
    output = tf(img)
    assert output.shape == img.shape


def test_sharpness_jitter_invalid_range_min_negative():
    with pytest.raises(ValueError):
        SharpnessJitter((-0.1, 2.0))


def test_sharpness_jitter_invalid_range_max_smaller():
    with pytest.raises(ValueError):
        SharpnessJitter((2.0, 0.1))
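`test_random_subset_apply_random_order` above relies on horizontal and vertical flips commuting; the claim is easy to verify directly with the same torchvision functional ops the test imports:

```python
import torch
from torchvision.transforms.v2 import functional as F  # noqa: N812

img = torch.rand(3, 480, 640)
# Either order of the two flips yields the same 180-degree rotation,
# which is why the test can compare against a fixed-order Compose.
torch.testing.assert_close(
    F.vertical_flip(F.horizontal_flip(img)),
    F.horizontal_flip(F.vertical_flip(img)),
)
```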
@ -0,0 +1,352 @@
"""
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
Also, this file contains backward compatibility tests. Because they are slow and require downloading the raw datasets,
we skip them for now in our CI.

Example to run backward compatibility tests locally:
```
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""

from pathlib import Path

import numpy as np
import pytest
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg


def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/img",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
    zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
    zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])

    store.close()


def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/camera0_rgb",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_end_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_start_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/robot0_eef_rot_axis_angle",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_gripper_width",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])

    store.close()


def _mock_download_raw_xarm(raw_dir, num_frames=4):
    import pickle

    dataset_dict = {
        "observations": {
            "rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
            "state": np.random.randn(num_frames, 4),
        },
        "actions": np.random.randn(num_frames, 3),
        "rewards": np.random.randn(num_frames),
        "masks": np.random.randn(num_frames),
        "dones": np.array([False, True, True, True]),
    }

    raw_dir.mkdir(parents=True, exist_ok=True)
    pkl_path = raw_dir / "buffer.pkl"
    with open(pkl_path, "wb") as f:
        pickle.dump(dataset_dict, f)


def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
    import h5py

    for ep_idx in range(num_episodes):
        raw_dir.mkdir(parents=True, exist_ok=True)
        path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
        with h5py.File(str(path_h5), "w") as f:
            f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset(
                "observations/images/top",
                data=np.random.randint(
                    0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
                ),
            )


def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
    from datetime import datetime, timedelta, timezone

    import pandas

    def write_parquet(key, timestamps, values):
        data = {
            "timestamp_utc": timestamps,
            key: values,
        }
        df = pandas.DataFrame(data)
        raw_dir.mkdir(parents=True, exist_ok=True)
        df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")

    episode_indices = [None, None, -1, None, None, -1, None, None, -1]
    episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]

    cam_key = "observation.images.cam_high"
    timestamps = []
    actions = []
    states = []
    frames = []
    # `+ num_episodes` for buffer frames associated to episode_index=-1
    for i, frame_idx in enumerate(frame_indices):
        t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
        action = np.random.randn(21).tolist()
        state = np.random.randn(21).tolist()
        ep_idx = episode_indices_mapping[i]
        frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
        timestamps.append(t_utc)
        actions.append(action)
        states.append(state)
        frames.append(frame)

    write_parquet(cam_key, timestamps, frames)
    write_parquet("observation.state", timestamps, states)
    write_parquet("action", timestamps, actions)
    write_parquet("episode_index", timestamps, episode_indices)

    # write fake mp4 file for each episode
    for ep_idx in range(num_episodes):
        imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)

        tmp_imgs_dir = raw_dir / "tmp_images"
        save_images_concurrently(imgs_array, tmp_imgs_dir)

        fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
        video_path = raw_dir / "videos" / fname
        encode_video_frames(tmp_imgs_dir, video_path, fps)


def _mock_download_raw(raw_dir, repo_id):
    if "wrist_gripper" in repo_id:
        _mock_download_raw_dora(raw_dir)
    elif "aloha" in repo_id:
        _mock_download_raw_aloha(raw_dir)
    elif "pusht" in repo_id:
        _mock_download_raw_pusht(raw_dir)
    elif "xarm" in repo_id:
        _mock_download_raw_xarm(raw_dir)
    elif "umi" in repo_id:
        _mock_download_raw_umi(raw_dir)
    else:
        raise ValueError(repo_id)


def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
    with pytest.raises(ValueError):
        push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")


def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
    tmpdir = Path(tmpdir)
    out_dir = tmpdir / "out"
    raw_dir = tmpdir / "raw"
    # mkdir to skip download
    raw_dir.mkdir(parents=True, exist_ok=True)
    with pytest.raises(ValueError):
        push_dataset_to_hub(
            raw_dir=raw_dir,
            raw_format="some_format",
            repo_id="user/dataset",
            local_dir=out_dir,
            force_override=False,
        )


@pytest.mark.parametrize(
    "required_packages, raw_format, repo_id",
    [
        (["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
        (None, "xarm_pkl", "lerobot/xarm_lift_medium"),
        (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
        (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
        (None, "dora_parquet", "cadene/wrist_gripper"),
    ],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
    num_episodes = 3
    tmpdir = Path(tmpdir)

    raw_dir = tmpdir / f"{repo_id}_raw"
    _mock_download_raw(raw_dir, repo_id)

    local_dir = tmpdir / repo_id

    lerobot_dataset = push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
    )

    # minimal generic tests on the local directory containing LeRobotDataset
    assert (local_dir / "meta_data" / "info.json").exists()
    assert (local_dir / "meta_data" / "stats.safetensors").exists()
    assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
    for i in range(num_episodes):
        for cam_key in lerobot_dataset.camera_keys:
            assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
    assert (local_dir / "train" / "dataset_info.json").exists()
    assert (local_dir / "train" / "state.json").exists()
    assert len(list((local_dir / "train").glob("*.arrow"))) > 0

    # minimal generic tests on the item
    item = lerobot_dataset[0]
    assert "index" in item
    assert "episode_index" in item
    assert "timestamp" in item
    for cam_key in lerobot_dataset.camera_keys:
        assert cam_key in item


@pytest.mark.parametrize(
    "raw_format, repo_id",
    [
        # TODO(rcadene): add raw dataset test artifacts
        ("pusht_zarr", "lerobot/pusht"),
        ("xarm_pkl", "lerobot/xarm_lift_medium"),
        ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
        ("umi_zarr", "lerobot/umi_cup_in_the_wild"),
        ("dora_parquet", "cadene/wrist_gripper"),
    ],
)
@pytest.mark.skip(
    "Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
    _, dataset_id = repo_id.split("/")

    tmpdir = Path(tmpdir)
    raw_dir = tmpdir / f"{dataset_id}_raw"
    local_dir = tmpdir / repo_id

    push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
        episodes=[0],
    )

    ds_actual = LeRobotDataset(repo_id, root=tmpdir)
    ds_reference = LeRobotDataset(repo_id)

    assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)

    def check_same_items(item1, item2):
        assert item1.keys() == item2.keys(), "Keys mismatch"

        for key in item1:
            if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
                assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
            else:
                assert item1[key] == item2[key], f"Mismatch found in key: {key}"

    for i in range(len(ds_reference.hf_dataset)):
        item_reference = ds_reference.hf_dataset[i]
        item_actual = ds_actual.hf_dataset[i]
        check_same_items(item_reference, item_actual)
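A detail worth noting in the zarr mocks above: `meta/episode_ends` holds cumulative end indices, so `[1, 3, 4]` splits the 4 mock frames into episodes of lengths 1, 2 and 1. Decoding that convention (an assumption based on the replay-buffer layout the mocks imitate, not an API introduced here):

```python
import numpy as np

episode_ends = np.array([1, 3, 4])  # as written by the mocks above
starts = np.concatenate(([0], episode_ends[:-1]))
for ep_idx, (start, end) in enumerate(zip(starts, episode_ends)):
    print(f"episode {ep_idx}: frames [{start}, {end})")
# episode 0: frames [0, 1)
# episode 1: frames [1, 3)
# episode 2: frames [3, 4)
```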
@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

import pytest

from lerobot.scripts.visualize_dataset import visualize_dataset
@ -30,3 +32,20 @@ def test_visualize_dataset(tmpdir, repo_id):
        serve=False,
    )
    assert rrd_path.exists()


@pytest.mark.parametrize(
    "repo_id",
    ["lerobot/pusht"],
)
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
def test_visualize_local_dataset(tmpdir, repo_id, root):
    rrd_path = visualize_dataset(
        repo_id,
        episode_index=0,
        batch_size=32,
        save=True,
        output_dir=tmpdir,
        root=root,
    )
    assert rrd_path.exists()
@ -76,6 +76,7 @@ def require_env(func):
    """
    Decorator that skips the test if the required environment package is not installed.
    As it needs 'env_name' in args, it also checks whether it is provided as an argument.
    If 'env_name' is None, this check is skipped.
    """

    @wraps(func)
@ -91,7 +92,7 @@ def require_env(func):

        # Perform the package check
        package_name = f"gym_{env_name}"
        if not is_package_available(package_name):
        if env_name is not None and not is_package_available(package_name):
            pytest.skip(f"gym-{env_name} not installed")

        return func(*args, **kwargs)
@ -99,6 +100,38 @@ def require_env(func):
    return wrapper


def require_package_arg(func):
    """
    Decorator that skips the test if the required package is not installed.
    This is similar to `require_env` but more general in that it can check any package (not just environments).
    As it needs 'required_packages' in args, it also checks whether it is provided as an argument.
    If 'required_packages' is None, this check is skipped.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Determine if 'required_packages' is provided and extract its value
        arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
        if "required_packages" in arg_names:
            # Get the index of 'required_packages' and retrieve the value from args
            index = arg_names.index("required_packages")
            required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
        else:
            raise ValueError("Function does not have 'required_packages' as an argument.")

        if required_packages is None:
            return func(*args, **kwargs)

        # Perform the package check
        for package in required_packages:
            if not is_package_available(package):
                pytest.skip(f"{package} not installed")

        return func(*args, **kwargs)

    return wrapper

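The new `require_package_arg` decorator pairs with a `required_packages` parametrize argument, as in tests/test_push_dataset_to_hub.py above; a minimal usage sketch:

```python
import pytest

# require_package_arg is the decorator defined just above in tests/utils.py.
from tests.utils import require_package_arg


@pytest.mark.parametrize(
    "required_packages, raw_format",
    [(["imagecodecs"], "umi_zarr"), (None, "xarm_pkl")],
)
@require_package_arg
def test_some_conversion(required_packages, raw_format):
    ...  # skipped automatically when any listed package is missing
```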
def require_package(package_name):
    """
    Decorator that skips the test if the specified package is not installed.