Merge remote-tracking branch 'origin/main' into user/rcadene/2024_06_01_custom_visualize_dataset
This commit is contained in:
commit
b502a82005
|
@ -10,6 +10,7 @@ on:
|
|||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "Makefile"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
@ -19,6 +20,7 @@ on:
|
|||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "Makefile"
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
|
@ -32,8 +34,8 @@ jobs:
|
|||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install EGL
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
|
@ -70,6 +72,9 @@ jobs:
|
|||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ffmpeg
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
|
@ -104,7 +109,7 @@ jobs:
|
|||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install EGL
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
|
||||
- name: Install poetry
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
on:
|
||||
push:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
7
Makefile
7
Makefile
|
@ -5,7 +5,7 @@ PYTHON_PATH := $(shell which python)
|
|||
# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
|
||||
POETRY_CHECK := $(shell command -v poetry)
|
||||
ifneq ($(POETRY_CHECK),)
|
||||
PYTHON_PATH := $(shell poetry run which python)
|
||||
PYTHON_PATH := $(shell poetry run which python)
|
||||
endif
|
||||
|
||||
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
|
||||
|
@ -46,6 +46,7 @@ test-act-ete-train:
|
|||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-eval:
|
||||
|
@ -73,6 +74,7 @@ test-act-ete-train-amp:
|
|||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act_amp/ \
|
||||
training.image_transforms.enable=true \
|
||||
use_amp=true
|
||||
|
||||
test-act-ete-eval-amp:
|
||||
|
@ -100,6 +102,7 @@ test-diffusion-ete-train:
|
|||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
|
@ -127,6 +130,7 @@ test-tdmpc-ete-train:
|
|||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
|
@ -159,5 +163,6 @@ test-act-pusht-tutorial:
|
|||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act_pusht/
|
||||
rm lerobot/configs/policy/created_by_Makefile.yaml
|
||||
|
|
10
README.md
10
README.md
|
@ -228,13 +228,13 @@ To add a dataset to the hub, you need to login using a write-access token, which
|
|||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
|
||||
Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with:
|
||||
```bash
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id aloha_static_pingpong_test \
|
||||
--raw-format aloha_hdf5 \
|
||||
--community-id lerobot
|
||||
--raw-dir data/aloha_static_pingpong_test_raw \
|
||||
--out-dir data \
|
||||
--repo-id lerobot/aloha_static_pingpong_test \
|
||||
--raw-format aloha_hdf5
|
||||
```
|
||||
|
||||
See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
|
||||
|
|
|
@ -46,7 +46,7 @@ defaults:
|
|||
- policy: diffusion
|
||||
```
|
||||
|
||||
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overriden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
|
||||
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overridden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
|
||||
|
||||
Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.
|
||||
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
"""
|
||||
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
||||
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
||||
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from torchvision.transforms import ToPILImage, v2
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset_repo_id = "lerobot/aloha_static_tape"
|
||||
|
||||
# Create a LeRobotDataset with no transformations
|
||||
dataset = LeRobotDataset(dataset_repo_id)
|
||||
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
|
||||
|
||||
# Get the index of the first observation in the first episode
|
||||
first_idx = dataset.episode_data_index["from"][0].item()
|
||||
|
||||
# Get the frame corresponding to the first camera
|
||||
frame = dataset[first_idx][dataset.camera_keys[0]]
|
||||
|
||||
|
||||
# Define the transformations
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
v2.ColorJitter(brightness=(0.5, 1.5)),
|
||||
v2.ColorJitter(contrast=(0.5, 1.5)),
|
||||
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
|
||||
]
|
||||
)
|
||||
|
||||
# Create another LeRobotDataset with the defined transformations
|
||||
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
|
||||
|
||||
# Get a frame from the transformed dataset
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
|
||||
|
||||
# Create a directory to store output images
|
||||
output_dir = Path("outputs/image_transforms")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save the original frame
|
||||
to_pil = ToPILImage()
|
||||
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
|
||||
print(f"Original frame saved to {output_dir / 'original_frame.png'}.")
|
||||
|
||||
# Save the transformed frame
|
||||
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
|
||||
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")
|
|
@ -19,6 +19,7 @@ import torch
|
|||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
|
||||
|
||||
def resolve_delta_timestamps(cfg):
|
||||
|
@ -71,17 +72,37 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
|||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
# TODO(rcadene): add data augmentations
|
||||
image_transforms = None
|
||||
if cfg.training.image_transforms.enable:
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
|
|
|
@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: Callable | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
|
@ -151,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.tolerance_s,
|
||||
)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
if self.image_transforms is not None:
|
||||
for cam in self.camera_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
return item
|
||||
|
||||
|
@ -168,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.transform},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
@ -202,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj.version = version
|
||||
obj.root = root
|
||||
obj.split = split
|
||||
obj.transform = transform
|
||||
obj.image_transforms = transform
|
||||
obj.delta_timestamps = delta_timestamps
|
||||
obj.hf_dataset = hf_dataset
|
||||
obj.episode_data_index = episode_data_index
|
||||
|
@ -225,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: Callable | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -239,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
root=root,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transform,
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
@ -274,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.stats = aggregate_stats(self._datasets)
|
||||
|
||||
|
@ -380,6 +381,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
for data_key in self.disabled_data_keys:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -394,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.transform},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
|
|
|
@ -14,156 +14,119 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file contains all obsolete download scripts. They are centralized here to not have to load
|
||||
useless dependencies when using datasets.
|
||||
This file contains download scripts for raw datasets.
|
||||
|
||||
Example of usage:
|
||||
```
|
||||
python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
|
||||
--raw-dir data/cadene/pusht_raw \
|
||||
--repo-id cadene/pusht_raw
|
||||
```
|
||||
"""
|
||||
|
||||
import io
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def download_raw(raw_dir, dataset_id):
|
||||
if "aloha" in dataset_id or "image" in dataset_id:
|
||||
download_hub(raw_dir, dataset_id)
|
||||
elif "pusht" in dataset_id:
|
||||
download_pusht(raw_dir)
|
||||
elif "xarm" in dataset_id:
|
||||
download_xarm(raw_dir)
|
||||
elif "umi" in dataset_id:
|
||||
download_umi(raw_dir)
|
||||
else:
|
||||
raise ValueError(dataset_id)
|
||||
def download_raw(raw_dir: Path, repo_id: str):
|
||||
# Check repo_id is well formated
|
||||
if len(repo_id.split("/")) != 2:
|
||||
raise ValueError(
|
||||
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
|
||||
)
|
||||
user_id, dataset_id = repo_id.split("/")
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
import zipfile
|
||||
|
||||
import requests
|
||||
|
||||
print(f"downloading from {url}")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||
|
||||
zip_file = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
zip_file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
zip_file.seek(0)
|
||||
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_folder)
|
||||
|
||||
|
||||
def download_pusht(raw_dir: str):
|
||||
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
if not dataset_id.endswith("_raw"):
|
||||
warnings.warn(
|
||||
f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
raw_dir = Path(raw_dir)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(pusht_url, raw_dir)
|
||||
# file is created inside a useful "pusht" directory, so we move it out and delete the dir
|
||||
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
|
||||
shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
|
||||
shutil.rmtree(raw_dir / "pusht")
|
||||
|
||||
|
||||
def download_xarm(raw_dir: Path):
|
||||
"""Download all xarm datasets at once"""
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
|
||||
raw_dir = Path(raw_dir)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
zip_path = raw_dir / "data.zip"
|
||||
gdown.download(url, str(zip_path), quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
||||
for pkl_path in zip_f.namelist():
|
||||
if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
|
||||
zip_f.extract(member=pkl_path)
|
||||
# move to corresponding raw directory
|
||||
extract_dir = pkl_path.replace("/buffer.pkl", "")
|
||||
raw_pkl_path = raw_dir / "buffer.pkl"
|
||||
shutil.move(pkl_path, raw_pkl_path)
|
||||
shutil.rmtree(extract_dir)
|
||||
zip_path.unlink()
|
||||
|
||||
|
||||
def download_hub(raw_dir: Path, dataset_id: str):
|
||||
raw_dir = Path(raw_dir)
|
||||
# Send warning if raw_dir isn't well formated
|
||||
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
|
||||
stacklevel=1,
|
||||
)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
|
||||
snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
|
||||
logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
|
||||
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
|
||||
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
|
||||
|
||||
def download_umi(raw_dir: Path):
|
||||
url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
|
||||
zarr_path = raw_dir / "cup_in_the_wild.zarr"
|
||||
def download_all_raw_datasets():
|
||||
data_dir = Path("data")
|
||||
repo_ids = [
|
||||
"cadene/pusht_image_raw",
|
||||
"cadene/xarm_lift_medium_image_raw",
|
||||
"cadene/xarm_lift_medium_replay_image_raw",
|
||||
"cadene/xarm_push_medium_image_raw",
|
||||
"cadene/xarm_push_medium_replay_image_raw",
|
||||
"cadene/aloha_sim_insertion_human_image_raw",
|
||||
"cadene/aloha_sim_insertion_scripted_image_raw",
|
||||
"cadene/aloha_sim_transfer_cube_human_image_raw",
|
||||
"cadene/aloha_sim_transfer_cube_scripted_image_raw",
|
||||
"cadene/pusht_raw",
|
||||
"cadene/xarm_lift_medium_raw",
|
||||
"cadene/xarm_lift_medium_replay_raw",
|
||||
"cadene/xarm_push_medium_raw",
|
||||
"cadene/xarm_push_medium_replay_raw",
|
||||
"cadene/aloha_sim_insertion_human_raw",
|
||||
"cadene/aloha_sim_insertion_scripted_raw",
|
||||
"cadene/aloha_sim_transfer_cube_human_raw",
|
||||
"cadene/aloha_sim_transfer_cube_scripted_raw",
|
||||
"cadene/aloha_mobile_cabinet_raw",
|
||||
"cadene/aloha_mobile_chair_raw",
|
||||
"cadene/aloha_mobile_elevator_raw",
|
||||
"cadene/aloha_mobile_shrimp_raw",
|
||||
"cadene/aloha_mobile_wash_pan_raw",
|
||||
"cadene/aloha_mobile_wipe_wine_raw",
|
||||
"cadene/aloha_static_battery_raw",
|
||||
"cadene/aloha_static_candy_raw",
|
||||
"cadene/aloha_static_coffee_raw",
|
||||
"cadene/aloha_static_coffee_new_raw",
|
||||
"cadene/aloha_static_cups_open_raw",
|
||||
"cadene/aloha_static_fork_pick_up_raw",
|
||||
"cadene/aloha_static_pingpong_test_raw",
|
||||
"cadene/aloha_static_pro_pencil_raw",
|
||||
"cadene/aloha_static_screw_driver_raw",
|
||||
"cadene/aloha_static_tape_raw",
|
||||
"cadene/aloha_static_thread_velcro_raw",
|
||||
"cadene/aloha_static_towel_raw",
|
||||
"cadene/aloha_static_vinh_cup_raw",
|
||||
"cadene/aloha_static_vinh_cup_left_raw",
|
||||
"cadene/aloha_static_ziploc_slide_raw",
|
||||
"cadene/umi_cup_in_the_wild_raw",
|
||||
]
|
||||
for repo_id in repo_ids:
|
||||
raw_dir = data_dir / repo_id
|
||||
download_raw(raw_dir, repo_id)
|
||||
|
||||
raw_dir = Path(raw_dir)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(url_cup_in_the_wild, zarr_path)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
download_raw(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_dir = Path("data")
|
||||
dataset_ids = [
|
||||
"pusht_image",
|
||||
"xarm_lift_medium_image",
|
||||
"xarm_lift_medium_replay_image",
|
||||
"xarm_push_medium_image",
|
||||
"xarm_push_medium_replay_image",
|
||||
"aloha_sim_insertion_human_image",
|
||||
"aloha_sim_insertion_scripted_image",
|
||||
"aloha_sim_transfer_cube_human_image",
|
||||
"aloha_sim_transfer_cube_scripted_image",
|
||||
"pusht",
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
"aloha_mobile_cabinet",
|
||||
"aloha_mobile_chair",
|
||||
"aloha_mobile_elevator",
|
||||
"aloha_mobile_shrimp",
|
||||
"aloha_mobile_wash_pan",
|
||||
"aloha_mobile_wipe_wine",
|
||||
"aloha_static_battery",
|
||||
"aloha_static_candy",
|
||||
"aloha_static_coffee",
|
||||
"aloha_static_coffee_new",
|
||||
"aloha_static_cups_open",
|
||||
"aloha_static_fork_pick_up",
|
||||
"aloha_static_pingpong_test",
|
||||
"aloha_static_pro_pencil",
|
||||
"aloha_static_screw_driver",
|
||||
"aloha_static_tape",
|
||||
"aloha_static_thread_velcro",
|
||||
"aloha_static_towel",
|
||||
"aloha_static_vinh_cup",
|
||||
"aloha_static_vinh_cup_left",
|
||||
"aloha_static_ziploc_slide",
|
||||
"umi_cup_in_the_wild",
|
||||
]
|
||||
for dataset_id in dataset_ids:
|
||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||
download_raw(raw_dir, dataset_id)
|
||||
main()
|
||||
|
|
|
@ -30,6 +30,7 @@ from PIL import Image as PILImage
|
|||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
@ -70,16 +71,17 @@ def check_format(raw_dir) -> bool:
|
|||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
|
||||
|
||||
def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
# only frames from simulation are uncompressed
|
||||
compressed_images = "sim" not in raw_dir.name
|
||||
|
||||
hdf5_files = list(raw_dir.glob("*.hdf5"))
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
|
||||
num_episodes = len(hdf5_files)
|
||||
|
||||
id_from = 0
|
||||
for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
|
||||
ep_dicts = []
|
||||
ep_ids = episodes if episodes else range(num_episodes)
|
||||
for ep_idx in tqdm.tqdm(ep_ids):
|
||||
ep_path = hdf5_files[ep_idx]
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
num_frames = ep["/action"].shape[0]
|
||||
|
||||
|
@ -114,12 +116,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = out_dir / "tmp_images"
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = out_dir / "videos" / fname
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# clean temporary images directory
|
||||
|
@ -147,19 +149,13 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
assert isinstance(ep_idx, int)
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
gc.collect()
|
||||
|
||||
# process first episode only
|
||||
if debug:
|
||||
break
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
return data_dict, episode_data_index
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
|
@ -197,16 +193,22 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
|||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 50
|
||||
|
||||
data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
||||
hf_dataset = to_hf_dataset(data_dir, video)
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
Contains utilities to process raw data format from dora-record
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -26,10 +25,10 @@ import torch
|
|||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
def check_format(raw_dir) -> bool:
|
||||
|
@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
|
@ -122,8 +121,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
|||
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
|
||||
|
||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir = out_dir / "videos"
|
||||
videos_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formated
|
||||
|
@ -156,16 +154,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
|||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
# Get the episode index containing for each unique episode index
|
||||
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
|
||||
from_ = first_ep_index_df["start_index"].tolist()
|
||||
to_ = from_[1:] + [len(df)]
|
||||
episode_data_index = {
|
||||
"from": from_,
|
||||
"to": to_,
|
||||
}
|
||||
|
||||
return data_dict, episode_data_index
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
|
@ -203,12 +192,13 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
|||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
init_logging()
|
||||
|
||||
if debug:
|
||||
logging.warning("debug=True not implemented. Falling back to debug=False.")
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
|
@ -220,9 +210,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
|
|||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
|
||||
data_df = load_from_raw(raw_dir, videos_dir, fps, episodes)
|
||||
hf_dataset = to_hf_dataset(data_df, video)
|
||||
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
|
@ -27,6 +27,7 @@ from PIL import Image as PILImage
|
|||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
@ -53,7 +54,7 @@ def check_format(raw_dir):
|
|||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
|
@ -71,7 +72,6 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
|
||||
|
||||
episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
|
||||
num_episodes = zarr_data.meta["episode_ends"].shape[0]
|
||||
assert len(
|
||||
{zarr_data[key].shape[0] for key in zarr_data.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
|
@ -84,25 +84,34 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
states = torch.from_numpy(zarr_data["state"])
|
||||
actions = torch.from_numpy(zarr_data["action"])
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
# load data indices from which each episode starts and ends
|
||||
from_ids, to_ids = [], []
|
||||
from_idx = 0
|
||||
for to_idx in zarr_data.meta["episode_ends"]:
|
||||
from_ids.append(from_idx)
|
||||
to_ids.append(to_idx)
|
||||
from_idx = to_idx
|
||||
|
||||
id_from = 0
|
||||
for ep_idx in tqdm.tqdm(range(num_episodes)):
|
||||
id_to = zarr_data.meta["episode_ends"][ep_idx]
|
||||
num_frames = id_to - id_from
|
||||
num_episodes = len(from_ids)
|
||||
|
||||
ep_dicts = []
|
||||
ep_ids = episodes if episodes else range(num_episodes)
|
||||
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
|
||||
from_idx = from_ids[selected_ep_idx]
|
||||
to_idx = to_ids[selected_ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
|
||||
# sanity check
|
||||
assert (episode_ids[id_from:id_to] == ep_idx).all()
|
||||
assert (episode_ids[from_idx:to_idx] == ep_idx).all()
|
||||
|
||||
# get image
|
||||
image = imgs[id_from:id_to]
|
||||
image = imgs[from_idx:to_idx]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
# get state
|
||||
state = states[id_from:id_to]
|
||||
state = states[from_idx:to_idx]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
@ -143,12 +152,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = out_dir / "tmp_images"
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = out_dir / "videos" / fname
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# clean temporary images directory
|
||||
|
@ -160,7 +169,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = agent_pos
|
||||
ep_dict["action"] = actions[id_from:id_to]
|
||||
ep_dict["action"] = actions[from_idx:to_idx]
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
|
@ -172,17 +181,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
# process first episode only
|
||||
if debug:
|
||||
break
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
return data_dict, episode_data_index
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video):
|
||||
|
@ -212,16 +215,22 @@ def to_hf_dataset(data_dict, video):
|
|||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 10
|
||||
|
||||
data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
|
|
|
@ -19,7 +19,6 @@ import logging
|
|||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
import zarr
|
||||
|
@ -29,6 +28,7 @@ from PIL import Image as PILImage
|
|||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
@ -59,23 +59,7 @@ def check_format(raw_dir) -> bool:
|
|||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
|
||||
# Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
|
||||
from numba import jit
|
||||
|
||||
@jit(nopython=True)
|
||||
def _get_episode_idxs(episode_ends):
|
||||
result = np.zeros((episode_ends[-1],), dtype=np.int64)
|
||||
start_idx = 0
|
||||
for episode_number, end_idx in enumerate(episode_ends):
|
||||
result[start_idx:end_idx] = episode_number
|
||||
start_idx = end_idx
|
||||
return result
|
||||
|
||||
return _get_episode_idxs(episode_ends)
|
||||
|
||||
|
||||
def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
zarr_path = raw_dir / "cup_in_the_wild.zarr"
|
||||
zarr_data = zarr.open(zarr_path, mode="r")
|
||||
|
||||
|
@ -92,39 +76,41 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
episode_ends = zarr_data["meta/episode_ends"][:]
|
||||
num_episodes = episode_ends.shape[0]
|
||||
|
||||
episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))
|
||||
|
||||
# We convert it in torch tensor later because the jit function does not support torch tensors
|
||||
episode_ends = torch.from_numpy(episode_ends)
|
||||
|
||||
# load data indices from which each episode starts and ends
|
||||
from_ids, to_ids = [], []
|
||||
from_idx = 0
|
||||
for to_idx in episode_ends:
|
||||
from_ids.append(from_idx)
|
||||
to_ids.append(to_idx)
|
||||
from_idx = to_idx
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
for ep_idx in tqdm.tqdm(range(num_episodes)):
|
||||
id_to = episode_ends[ep_idx]
|
||||
num_frames = id_to - id_from
|
||||
|
||||
# sanity heck
|
||||
assert (episode_ids[id_from:id_to] == ep_idx).all()
|
||||
ep_ids = episodes if episodes else range(num_episodes)
|
||||
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
|
||||
from_idx = from_ids[selected_ep_idx]
|
||||
to_idx = to_ids[selected_ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
|
||||
# TODO(rcadene): save temporary images of the episode?
|
||||
|
||||
state = states[id_from:id_to]
|
||||
state = states[from_idx:to_idx]
|
||||
|
||||
ep_dict = {}
|
||||
|
||||
# load 57MB of images in RAM (400x224x224x3 uint8)
|
||||
imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
|
||||
imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = out_dir / "tmp_images"
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = out_dir / "videos" / fname
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# clean temporary images directory
|
||||
|
@ -139,27 +125,18 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames)
|
||||
ep_dict["end_pose"] = end_pose[id_from:id_to]
|
||||
ep_dict["start_pos"] = start_pos[id_from:id_to]
|
||||
ep_dict["gripper_width"] = gripper_width[id_from:id_to]
|
||||
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
|
||||
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
|
||||
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
|
||||
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
id_from += num_frames
|
||||
|
||||
# process first episode only
|
||||
if debug:
|
||||
break
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
total_frames = id_from
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
return data_dict, episode_data_index
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video):
|
||||
|
@ -199,7 +176,13 @@ def to_hf_dataset(data_dict, video):
|
|||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
|
@ -212,9 +195,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
|
|||
"Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
|
||||
)
|
||||
|
||||
data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
|
|
|
@ -27,6 +27,7 @@ from PIL import Image as PILImage
|
|||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
@ -54,37 +55,42 @@ def check_format(raw_dir):
|
|||
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
|
||||
|
||||
|
||||
def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
pkl_path = raw_dir / "buffer.pkl"
|
||||
|
||||
with open(pkl_path, "rb") as f:
|
||||
pkl_data = pickle.load(f)
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
id_to = 0
|
||||
ep_idx = 0
|
||||
total_frames = pkl_data["actions"].shape[0]
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
id_to += 1
|
||||
|
||||
if not pkl_data["dones"][i]:
|
||||
# load data indices from which each episode starts and ends
|
||||
from_ids, to_ids = [], []
|
||||
from_idx, to_idx = 0, 0
|
||||
for done in pkl_data["dones"]:
|
||||
to_idx += 1
|
||||
if not done:
|
||||
continue
|
||||
from_ids.append(from_idx)
|
||||
to_ids.append(to_idx)
|
||||
from_idx = to_idx
|
||||
|
||||
num_frames = id_to - id_from
|
||||
num_episodes = len(from_ids)
|
||||
|
||||
image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to])
|
||||
ep_dicts = []
|
||||
ep_ids = episodes if episodes else range(num_episodes)
|
||||
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
|
||||
from_idx = from_ids[selected_ep_idx]
|
||||
to_idx = to_ids[selected_ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
|
||||
image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx])
|
||||
image = einops.rearrange(image, "b c h w -> b h w c")
|
||||
state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to])
|
||||
action = torch.tensor(pkl_data["actions"][id_from:id_to])
|
||||
state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx])
|
||||
action = torch.tensor(pkl_data["actions"][from_idx:to_idx])
|
||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to])
|
||||
# next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to])
|
||||
next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to])
|
||||
next_done = torch.tensor(pkl_data["dones"][id_from:id_to])
|
||||
# next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
|
||||
# next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
|
||||
next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx])
|
||||
next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx])
|
||||
|
||||
ep_dict = {}
|
||||
|
||||
|
@ -92,12 +98,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = out_dir / "tmp_images"
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = out_dir / "videos" / fname
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# clean temporary images directory
|
||||
|
@ -119,18 +125,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
ep_dict["next.done"] = next_done
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from = id_to
|
||||
ep_idx += 1
|
||||
|
||||
# process first episode only
|
||||
if debug:
|
||||
break
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
return data_dict, episode_data_index
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video):
|
||||
|
@ -161,16 +160,22 @@ def to_hf_dataset(data_dict, video):
|
|||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 15
|
||||
|
||||
data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
|
|
|
@ -0,0 +1,197 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, Sequence
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import Transform
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
|
||||
class RandomSubsetApply(Transform):
|
||||
"""Apply a random subset of N transformations from a list of transformations.
|
||||
|
||||
Args:
|
||||
transforms: list of transformations.
|
||||
p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
|
||||
If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
|
||||
have the same probability.
|
||||
n_subset: number of transformations to apply. If ``None``, all transforms are applied.
|
||||
Must be in [1, len(transforms)].
|
||||
random_order: apply transformations in a random order.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transforms: Sequence[Callable],
|
||||
p: list[float] | None = None,
|
||||
n_subset: int | None = None,
|
||||
random_order: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not isinstance(transforms, Sequence):
|
||||
raise TypeError("Argument transforms should be a sequence of callables")
|
||||
if p is None:
|
||||
p = [1] * len(transforms)
|
||||
elif len(p) != len(transforms):
|
||||
raise ValueError(
|
||||
f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
|
||||
)
|
||||
|
||||
if n_subset is None:
|
||||
n_subset = len(transforms)
|
||||
elif not isinstance(n_subset, int):
|
||||
raise TypeError("n_subset should be an int or None")
|
||||
elif not (1 <= n_subset <= len(transforms)):
|
||||
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
|
||||
|
||||
self.transforms = transforms
|
||||
total = sum(p)
|
||||
self.p = [prob / total for prob in p]
|
||||
self.n_subset = n_subset
|
||||
self.random_order = random_order
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
needs_unpacking = len(inputs) > 1
|
||||
|
||||
selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
|
||||
if not self.random_order:
|
||||
selected_indices = selected_indices.sort().values
|
||||
|
||||
selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
|
||||
for transform in selected_transforms:
|
||||
outputs = transform(*inputs)
|
||||
inputs = outputs if needs_unpacking else (outputs,)
|
||||
|
||||
return outputs
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
f"transforms={self.transforms}, "
|
||||
f"p={self.p}, "
|
||||
f"n_subset={self.n_subset}, "
|
||||
f"random_order={self.random_order}"
|
||||
)
|
||||
|
||||
|
||||
class SharpnessJitter(Transform):
|
||||
"""Randomly change the sharpness of an image or video.
|
||||
|
||||
Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
|
||||
While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image,
|
||||
SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
|
||||
augmentations as a result.
|
||||
|
||||
A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
|
||||
by a factor of 2.
|
||||
|
||||
If the input is a :class:`torch.Tensor`,
|
||||
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||
|
||||
Args:
|
||||
sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
|
||||
[max(0, 1 - sharpness), 1 + sharpness] or the given
|
||||
[min, max]. Should be non negative numbers.
|
||||
"""
|
||||
|
||||
def __init__(self, sharpness: float | Sequence[float]) -> None:
|
||||
super().__init__()
|
||||
self.sharpness = self._check_input(sharpness)
|
||||
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
sharpness[0] = max(sharpness[0], 0.0)
|
||||
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
|
||||
sharpness = [float(v) for v in sharpness]
|
||||
else:
|
||||
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
|
||||
|
||||
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
||||
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
def _generate_value(self, left: float, right: float) -> float:
|
||||
return torch.empty(1).uniform_(left, right).item()
|
||||
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def get_image_transforms(
|
||||
brightness_weight: float = 1.0,
|
||||
brightness_min_max: tuple[float, float] | None = None,
|
||||
contrast_weight: float = 1.0,
|
||||
contrast_min_max: tuple[float, float] | None = None,
|
||||
saturation_weight: float = 1.0,
|
||||
saturation_min_max: tuple[float, float] | None = None,
|
||||
hue_weight: float = 1.0,
|
||||
hue_min_max: tuple[float, float] | None = None,
|
||||
sharpness_weight: float = 1.0,
|
||||
sharpness_min_max: tuple[float, float] | None = None,
|
||||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
):
|
||||
def check_value(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
if len(min_max) != 2:
|
||||
raise ValueError(
|
||||
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
|
||||
)
|
||||
if weight < 0.0:
|
||||
raise ValueError(
|
||||
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
|
||||
)
|
||||
|
||||
check_value("brightness", brightness_weight, brightness_min_max)
|
||||
check_value("contrast", contrast_weight, contrast_min_max)
|
||||
check_value("saturation", saturation_weight, saturation_min_max)
|
||||
check_value("hue", hue_weight, hue_min_max)
|
||||
check_value("sharpness", sharpness_weight, sharpness_min_max)
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
if contrast_min_max is not None and contrast_weight > 0.0:
|
||||
weights.append(contrast_weight)
|
||||
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||
if saturation_min_max is not None and saturation_weight > 0.0:
|
||||
weights.append(saturation_weight)
|
||||
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||
if hue_min_max is not None and hue_weight > 0.0:
|
||||
weights.append(hue_weight)
|
||||
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||
if sharpness_min_max is not None and sharpness_weight > 0.0:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
|
||||
|
||||
n_subset = len(transforms)
|
||||
if max_num_transforms is not None:
|
||||
n_subset = min(n_subset, max_num_transforms)
|
||||
|
||||
if n_subset == 0:
|
||||
return v2.Identity()
|
||||
else:
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
|
@ -238,5 +238,6 @@ class Logger:
|
|||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
assert self._wandb is not None
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
|
|
@ -239,10 +239,8 @@ class DiffusionModel(nn.Module):
|
|||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||
|
||||
# run sampling
|
||||
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
actions = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
|
||||
# `horizon` steps worth of actions (from the first observation).
|
||||
actions = sample[..., : self.config.output_shapes["action"][0]]
|
||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.config.n_action_steps
|
||||
|
|
|
@ -147,7 +147,7 @@ class Normalize(nn.Module):
|
|||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min)
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
|
|
|
@ -57,7 +57,7 @@ class Policy(Protocol):
|
|||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
|
|
|
@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
self._prev_mean: torch.Tensor | None = None
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
|
|
@ -43,6 +43,40 @@ training:
|
|||
save_checkpoint: true
|
||||
num_workers: 4
|
||||
batch_size: ???
|
||||
image_transforms:
|
||||
# These transforms are all using standard torchvision.transforms.v2
|
||||
# You can find out how these transformations affect images here:
|
||||
# https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
|
||||
# We use a custom RandomSubsetApply container to sample them.
|
||||
# For each transform, the following parameters are available:
|
||||
# weight: This represents the multinomial probability (with no replacement)
|
||||
# used for sampling the transform. If the sum of the weights is not 1,
|
||||
# they will be normalized.
|
||||
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
|
||||
# (following uniform distribution) when it's applied.
|
||||
# Set this flag to `true` to enable transforms during training
|
||||
enable: false
|
||||
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
|
||||
# It's an integer in the interval [1, number of available transforms].
|
||||
max_num_transforms: 3
|
||||
# By default, transforms are applied in Torchvision's suggested order (shown below).
|
||||
# Set this to True to apply them in a random order.
|
||||
random_order: false
|
||||
brightness:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
contrast:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
saturation:
|
||||
weight: 1
|
||||
min_max: [0.5, 1.5]
|
||||
hue:
|
||||
weight: 1
|
||||
min_max: [-0.05, 0.05]
|
||||
sharpness:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
|
||||
eval:
|
||||
n_episodes: 1
|
||||
|
|
|
@ -13,39 +13,71 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Use this script to get a quick summary of your system config.
|
||||
It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
|
||||
"""
|
||||
|
||||
import platform
|
||||
|
||||
import huggingface_hub
|
||||
HAS_HF_HUB = True
|
||||
HAS_HF_DATASETS = True
|
||||
HAS_NP = True
|
||||
HAS_TORCH = True
|
||||
HAS_LEROBOT = True
|
||||
|
||||
# import dataset
|
||||
import numpy as np
|
||||
import torch
|
||||
try:
|
||||
import huggingface_hub
|
||||
except ImportError:
|
||||
HAS_HF_HUB = False
|
||||
|
||||
from lerobot import __version__ as version
|
||||
try:
|
||||
import datasets
|
||||
except ImportError:
|
||||
HAS_HF_DATASETS = False
|
||||
|
||||
pt_version = torch.__version__
|
||||
pt_cuda_available = torch.cuda.is_available()
|
||||
pt_cuda_available = torch.cuda.is_available()
|
||||
cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not None else "N/A"
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
HAS_NP = False
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
HAS_TORCH = False
|
||||
|
||||
try:
|
||||
import lerobot
|
||||
except ImportError:
|
||||
HAS_LEROBOT = False
|
||||
|
||||
|
||||
lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A"
|
||||
hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
|
||||
hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
|
||||
np_version = np.__version__ if HAS_NP else "N/A"
|
||||
|
||||
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
||||
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
||||
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||
|
||||
|
||||
# TODO(aliberts): refactor into an actual command `lerobot env`
|
||||
def display_sys_info() -> dict:
|
||||
"""Run this to get basic system info to help for tracking issues & bugs."""
|
||||
info = {
|
||||
"`lerobot` version": version,
|
||||
"`lerobot` version": lerobot_version,
|
||||
"Platform": platform.platform(),
|
||||
"Python version": platform.python_version(),
|
||||
"Huggingface_hub version": huggingface_hub.__version__,
|
||||
# TODO(aliberts): Add dataset when https://github.com/huggingface/lerobot/pull/73 is merged
|
||||
# "Dataset version": dataset.__version__,
|
||||
"Numpy version": np.__version__,
|
||||
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
|
||||
"Huggingface_hub version": hf_hub_version,
|
||||
"Dataset version": hf_datasets_version,
|
||||
"Numpy version": np_version,
|
||||
"PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
|
||||
"Cuda version": cuda_version,
|
||||
"Using GPU in script?": "<fill in>",
|
||||
"Using distributed or parallel set-up in script?": "<fill in>",
|
||||
# "Using distributed or parallel set-up in script?": "<fill in>",
|
||||
}
|
||||
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
|
||||
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
|
||||
print(format_dict(info))
|
||||
return info
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ from huggingface_hub import snapshot_download
|
|||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from PIL import Image as PILImage
|
||||
from torch import Tensor
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
@ -99,13 +99,13 @@ def rollout(
|
|||
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
|
||||
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
|
||||
environment termination/truncation).
|
||||
"don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
the first True is followed by True's all the way till the end. This can be used for masking
|
||||
extraneous elements from the sequences above.
|
||||
|
||||
Args:
|
||||
env: The batch of environments.
|
||||
policy: The policy.
|
||||
policy: The policy. Must be a PyTorch nn module.
|
||||
seeds: The environments are seeded once at the start of the rollout. If provided, this argument
|
||||
specifies the seeds for each of the environments.
|
||||
return_observations: Whether to include all observations in the returned rollout data. Observations
|
||||
|
@ -116,6 +116,7 @@ def rollout(
|
|||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
|
||||
device = get_device_from_parameters(policy)
|
||||
|
||||
# Reset the policy and environments.
|
||||
|
@ -209,7 +210,7 @@ def eval_policy(
|
|||
policy: torch.nn.Module,
|
||||
n_episodes: int,
|
||||
max_episodes_rendered: int = 0,
|
||||
video_dir: Path | None = None,
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
enable_progbar: bool = False,
|
||||
|
@ -221,7 +222,7 @@ def eval_policy(
|
|||
policy: The policy.
|
||||
n_episodes: The number of episodes to evaluate.
|
||||
max_episodes_rendered: Maximum number of episodes to render into videos.
|
||||
video_dir: Where to save rendered videos.
|
||||
videos_dir: Where to save rendered videos.
|
||||
return_episode_data: Whether to return episode data for online training. Incorporates the data into
|
||||
the "episodes" key of the returned dictionary.
|
||||
start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
|
||||
|
@ -231,6 +232,10 @@ def eval_policy(
|
|||
Returns:
|
||||
Dictionary with metrics and data regarding the rollouts.
|
||||
"""
|
||||
if max_episodes_rendered > 0 and not videos_dir:
|
||||
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
|
||||
|
||||
assert isinstance(policy, Policy)
|
||||
start = time.time()
|
||||
policy.eval()
|
||||
|
||||
|
@ -271,11 +276,16 @@ def eval_policy(
|
|||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
|
||||
if start_seed is None:
|
||||
seeds = None
|
||||
else:
|
||||
seeds = range(
|
||||
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
||||
)
|
||||
rollout_data = rollout(
|
||||
env,
|
||||
policy,
|
||||
seeds=seeds,
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
enable_progbar=enable_inner_progbar,
|
||||
|
@ -285,7 +295,8 @@ def eval_policy(
|
|||
# this won't be included).
|
||||
n_steps = rollout_data["done"].shape[1]
|
||||
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
|
||||
done_indices = torch.argmax(rollout_data["done"].to(int), axis=1) # (batch_size, rollout_steps)
|
||||
done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)
|
||||
|
||||
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
|
||||
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
|
||||
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
|
||||
|
@ -296,8 +307,12 @@ def eval_policy(
|
|||
max_rewards.extend(batch_max_rewards.tolist())
|
||||
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
|
||||
all_successes.extend(batch_successes.tolist())
|
||||
all_seeds.extend(seeds)
|
||||
if seeds:
|
||||
all_seeds.extend(seeds)
|
||||
else:
|
||||
all_seeds.append(None)
|
||||
|
||||
# FIXME: episode_data is either None or it doesn't exist
|
||||
if return_episode_data:
|
||||
this_episode_data = _compile_episode_data(
|
||||
rollout_data,
|
||||
|
@ -347,8 +362,9 @@ def eval_policy(
|
|||
):
|
||||
if n_episodes_rendered >= max_episodes_rendered:
|
||||
break
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
|
||||
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
|
||||
video_paths.append(str(video_path))
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
|
@ -503,22 +519,20 @@ def _compile_episode_data(
|
|||
}
|
||||
|
||||
|
||||
def eval(
|
||||
pretrained_policy_path: str | None = None,
|
||||
def main(
|
||||
pretrained_policy_path: Path | None = None,
|
||||
hydra_cfg_path: str | None = None,
|
||||
out_dir: str | None = None,
|
||||
config_overrides: list[str] | None = None,
|
||||
):
|
||||
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
|
||||
if hydra_cfg_path is None:
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
|
||||
if pretrained_policy_path is not None:
|
||||
hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
|
||||
else:
|
||||
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
|
||||
out_dir = (
|
||||
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
||||
)
|
||||
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
@ -534,10 +548,12 @@ def eval(
|
|||
|
||||
logging.info("Making policy.")
|
||||
if hydra_cfg_path is None:
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
|
||||
else:
|
||||
# Note: We need the dataset stats to pass to the policy's normalization modules.
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||
|
@ -546,7 +562,7 @@ def eval(
|
|||
policy,
|
||||
hydra_cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
videos_dir=Path(out_dir) / "videos",
|
||||
start_seed=hydra_cfg.seed,
|
||||
enable_progbar=True,
|
||||
enable_inner_progbar=True,
|
||||
|
@ -586,6 +602,13 @@ if __name__ == "__main__":
|
|||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
"Where to save the evaluation outputs. If not provided, outputs are saved in "
|
||||
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
|
@ -594,7 +617,7 @@ if __name__ == "__main__":
|
|||
args = parser.parse_args()
|
||||
|
||||
if args.pretrained_policy_name_or_path is None:
|
||||
eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
|
||||
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
||||
else:
|
||||
try:
|
||||
pretrained_policy_path = Path(
|
||||
|
@ -618,4 +641,8 @@ if __name__ == "__main__":
|
|||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
|
||||
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
|
||||
main(
|
||||
pretrained_policy_path=pretrained_policy_path,
|
||||
out_dir=args.out_dir,
|
||||
config_overrides=args.overrides,
|
||||
)
|
||||
|
|
|
@ -18,57 +18,39 @@ Use this script to convert your dataset into LeRobot dataset format and upload i
|
|||
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
||||
installation of neural net specific packages like pytorch, tensorflow, jax.
|
||||
|
||||
Example:
|
||||
Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
|
||||
```
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id pusht \
|
||||
--raw-dir data/pusht_raw \
|
||||
--raw-format pusht_zarr \
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id xarm_lift_medium \
|
||||
--raw-dir data/xarm_lift_medium_raw \
|
||||
--raw-format xarm_pkl \
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
--repo-id lerobot/xarm_lift_medium
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id aloha_sim_insertion_scripted \
|
||||
--raw-dir data/aloha_sim_insertion_scripted_raw \
|
||||
--raw-format aloha_hdf5 \
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
--repo-id lerobot/aloha_sim_insertion_scripted
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id umi_cup_in_the_wild \
|
||||
--raw-dir data/umi_cup_in_the_wild_raw \
|
||||
--raw-format umi_zarr \
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
--repo-id lerobot/umi_cup_in_the_wild
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import HfApi, create_branch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
|
@ -77,15 +59,15 @@ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_r
|
|||
from lerobot.common.datasets.utils import flatten_dict
|
||||
|
||||
|
||||
def get_from_raw_to_lerobot_format_fn(raw_format):
|
||||
def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
if raw_format == "pusht_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "umi_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_dora":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "dora_parquet":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "xarm_pkl":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||
else:
|
||||
|
@ -96,7 +78,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
|
|||
return from_raw_to_lerobot_format
|
||||
|
||||
|
||||
def save_meta_data(info, stats, episode_data_index, meta_data_dir):
|
||||
def save_meta_data(
|
||||
info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
|
||||
):
|
||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# save info
|
||||
|
@ -114,7 +98,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
|
|||
save_file(episode_data_index, ep_data_idx_path)
|
||||
|
||||
|
||||
def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
|
||||
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
|
||||
"""Expect all meta data files to be all stored in a single "meta_data" directory.
|
||||
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
|
||||
"""
|
||||
|
@ -128,7 +112,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
|
|||
)
|
||||
|
||||
|
||||
def push_videos_to_hub(repo_id, videos_dir, revision):
|
||||
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
|
||||
"""Expect mp4 files to be all stored in a single "videos" directory.
|
||||
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
|
||||
"""
|
||||
|
@ -144,39 +128,61 @@ def push_videos_to_hub(repo_id, videos_dir, revision):
|
|||
|
||||
|
||||
def push_dataset_to_hub(
|
||||
data_dir: Path,
|
||||
dataset_id: str,
|
||||
raw_format: str | None,
|
||||
community_id: str,
|
||||
revision: str,
|
||||
dry_run: bool,
|
||||
save_to_disk: bool,
|
||||
tests_data_dir: Path,
|
||||
save_tests_to_disk: bool,
|
||||
fps: int | None,
|
||||
video: bool,
|
||||
batch_size: int,
|
||||
num_workers: int,
|
||||
debug: bool,
|
||||
raw_dir: Path,
|
||||
raw_format: str,
|
||||
repo_id: str,
|
||||
push_to_hub: bool = True,
|
||||
local_dir: Path | None = None,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 8,
|
||||
episodes: list[int] | None = None,
|
||||
force_override: bool = False,
|
||||
cache_dir: Path = Path("/tmp"),
|
||||
tests_data_dir: Path | None = None,
|
||||
):
|
||||
repo_id = f"{community_id}/{dataset_id}"
|
||||
# Check repo_id is well formated
|
||||
if len(repo_id.split("/")) != 2:
|
||||
raise ValueError(
|
||||
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
|
||||
)
|
||||
user_id, dataset_id = repo_id.split("/")
|
||||
|
||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||
# Robustify when `raw_dir` is str instead of Path
|
||||
raw_dir = Path(raw_dir)
|
||||
if not raw_dir.exists():
|
||||
raise NotADirectoryError(
|
||||
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:"
|
||||
f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
|
||||
)
|
||||
|
||||
out_dir = data_dir / repo_id
|
||||
meta_data_dir = out_dir / "meta_data"
|
||||
videos_dir = out_dir / "videos"
|
||||
if local_dir:
|
||||
# Robustify when `local_dir` is str instead of Path
|
||||
local_dir = Path(local_dir)
|
||||
|
||||
tests_out_dir = tests_data_dir / repo_id
|
||||
tests_meta_data_dir = tests_out_dir / "meta_data"
|
||||
tests_videos_dir = tests_out_dir / "videos"
|
||||
# Send warning if local_dir isn't well formated
|
||||
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
if out_dir.exists():
|
||||
shutil.rmtree(out_dir)
|
||||
# Check we don't override an existing `local_dir` by mistake
|
||||
if local_dir.exists():
|
||||
if force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
else:
|
||||
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
|
||||
|
||||
if tests_out_dir.exists() and save_tests_to_disk:
|
||||
shutil.rmtree(tests_out_dir)
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
videos_dir = local_dir / "videos"
|
||||
else:
|
||||
# Temporary directory used to store images, videos, meta_data
|
||||
meta_data_dir = Path(cache_dir) / "meta_data"
|
||||
videos_dir = Path(cache_dir) / "videos"
|
||||
|
||||
# Download the raw dataset if available
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, dataset_id)
|
||||
|
||||
|
@ -185,14 +191,14 @@ def push_dataset_to_hub(
|
|||
raise NotImplementedError()
|
||||
# raw_format = auto_find_raw_format(raw_dir)
|
||||
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
raw_dir, videos_dir, fps, video, episodes
|
||||
)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
version=revision,
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
|
@ -200,102 +206,80 @@ def push_dataset_to_hub(
|
|||
)
|
||||
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
|
||||
|
||||
if save_to_disk:
|
||||
if local_dir:
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(out_dir / "train"))
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
||||
if not dry_run or save_to_disk:
|
||||
if push_to_hub or local_dir:
|
||||
# mandatory for upload
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
if not dry_run:
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
|
||||
|
||||
if push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
|
||||
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
push_videos_to_hub(repo_id, videos_dir, revision=revision)
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
if save_tests_to_disk:
|
||||
if tests_data_dir:
|
||||
# get the first episode
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
||||
|
||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
|
||||
test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
|
||||
|
||||
save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
|
||||
tests_meta_data = tests_data_dir / repo_id / "meta_data"
|
||||
save_meta_data(info, stats, episode_data_index, tests_meta_data)
|
||||
|
||||
# copy videos of first episode to tests directory
|
||||
episode_index = 0
|
||||
tests_videos_dir = tests_data_dir / repo_id / "videos"
|
||||
tests_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
for key in lerobot_dataset.video_frame_keys:
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||
|
||||
if not save_to_disk and out_dir.exists():
|
||||
# remove possible temporary files remaining in the output directory
|
||||
shutil.rmtree(out_dir)
|
||||
if local_dir is None:
|
||||
# clear cache
|
||||
shutil.rmtree(meta_data_dir)
|
||||
shutil.rmtree(videos_dir)
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
|
||||
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
|
||||
)
|
||||
# TODO(rcadene): add automatic detection of the format
|
||||
parser.add_argument(
|
||||
"--raw-format",
|
||||
type=str,
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
|
||||
required=True,
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--community-id",
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="Community or user ID under which the dataset will be hosted on the Hub.",
|
||||
required=True,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=CODEBASE_VERSION,
|
||||
help="Codebase version used to generate the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-to-disk",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Save the dataset in the directory specified by `--data-dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default="tests/data",
|
||||
help="Directory containing tests artifacts datasets.",
|
||||
help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-tests-to-disk",
|
||||
"--push-to-hub",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
|
||||
help="Upload to hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
|
@ -321,10 +305,21 @@ def main():
|
|||
help="Number of processes of Dataloader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-override",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Debug mode process the first episode only.",
|
||||
help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -24,6 +24,7 @@ import torch
|
|||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||
|
@ -150,6 +151,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
|||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
update_s = info["update_s"]
|
||||
dataloading_s = info["dataloading_s"]
|
||||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
|
@ -170,6 +172,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
|||
f"lr:{lr:0.1e}",
|
||||
# in seconds
|
||||
f"updt_s:{update_s:.3f}",
|
||||
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
|
||||
]
|
||||
logging.info(" ".join(log_items))
|
||||
|
||||
|
@ -290,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
eval_env = None
|
||||
if cfg.training.eval_freq > 0:
|
||||
logging.info("make_env")
|
||||
eval_env = make_env(cfg)
|
||||
|
@ -300,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
dataset_stats=offline_dataset.stats if not cfg.resume else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
)
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
# Create optimizer and scheduler
|
||||
# Temporary hack to move optimizer out of policy
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
@ -325,14 +329,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
|
||||
# Note: this helper will be used in offline and online training loops.
|
||||
def evaluate_and_checkpoint_if_needed(step):
|
||||
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
step_identifier = f"{step:0{_num_digits}d}"
|
||||
|
||||
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
assert eval_env is not None
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
|
@ -350,9 +358,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
policy,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
identifier=str(step).zfill(
|
||||
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
),
|
||||
identifier=step_identifier,
|
||||
)
|
||||
logging.info("Resume training")
|
||||
|
||||
|
@ -382,7 +388,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
for _ in range(step, cfg.training.offline_steps):
|
||||
if step == 0:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
@ -397,6 +406,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
use_amp=cfg.use_amp,
|
||||
)
|
||||
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
|
||||
|
||||
|
@ -406,7 +417,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
|
||||
step += 1
|
||||
|
||||
eval_env.close()
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
|
|
|
@ -66,28 +66,31 @@ import gc
|
|||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, episode_index):
|
||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator:
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch):
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
c, h, w = chw_float32_torch.shape
|
||||
|
@ -106,6 +109,7 @@ def visualize_dataset(
|
|||
ws_port: int = 9087,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
root: Path | None = None,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert (
|
||||
|
@ -113,7 +117,7 @@ def visualize_dataset(
|
|||
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
|
||||
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
|
@ -224,7 +228,8 @@ def main():
|
|||
help=(
|
||||
"Mode of viewing between 'local' or 'distant'. "
|
||||
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
||||
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
"'distant' creates a server on the distant machine where the data is stored. "
|
||||
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -245,8 +250,8 @@ def main():
|
|||
default=0,
|
||||
help=(
|
||||
"Save a .rrd file in the directory provided by `--output-dir`. "
|
||||
"It also deactivates the spawning of a viewer. ",
|
||||
"Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
|
||||
"It also deactivates the spawning of a viewer. "
|
||||
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Visualize effects of image transforms for a given configuration.
|
||||
|
||||
This script will generate examples of transformed images as they are output by LeRobot dataset.
|
||||
Additionally, each individual transform can be visualized separately as well as examples of combined transforms
|
||||
|
||||
|
||||
--- Usage Examples ---
|
||||
|
||||
Increase hue jitter
|
||||
```
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.hue.min_max=[-0.25,0.25]
|
||||
```
|
||||
|
||||
Increase brightness & brightness weight
|
||||
```
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.brightness.weight=10.0 \
|
||||
training.image_transforms.brightness.min_max=[1.0,2.0]
|
||||
```
|
||||
|
||||
Blur images and disable saturation & hue
|
||||
```
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.sharpness.weight=10.0 \
|
||||
training.image_transforms.sharpness.min_max=[0.0,1.0] \
|
||||
training.image_transforms.saturation.weight=0.0 \
|
||||
training.image_transforms.hue.weight=0.0
|
||||
```
|
||||
|
||||
Use all transforms with random order
|
||||
```
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.max_num_transforms=5 \
|
||||
training.image_transforms.random_order=true
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
from torchvision.transforms import ToPILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
|
||||
OUTPUT_DIR = Path("outputs/image_transforms")
|
||||
N_EXAMPLES = 5
|
||||
to_pil = ToPILImage()
|
||||
|
||||
|
||||
def save_config_all_transforms(cfg, original_frame, output_dir):
|
||||
tf = get_image_transforms(
|
||||
brightness_weight=cfg.brightness.weight,
|
||||
brightness_min_max=cfg.brightness.min_max,
|
||||
contrast_weight=cfg.contrast.weight,
|
||||
contrast_min_max=cfg.contrast.min_max,
|
||||
saturation_weight=cfg.saturation.weight,
|
||||
saturation_min_max=cfg.saturation.min_max,
|
||||
hue_weight=cfg.hue.weight,
|
||||
hue_min_max=cfg.hue.min_max,
|
||||
sharpness_weight=cfg.sharpness.weight,
|
||||
sharpness_min_max=cfg.sharpness.min_max,
|
||||
max_num_transforms=cfg.max_num_transforms,
|
||||
random_order=cfg.random_order,
|
||||
)
|
||||
|
||||
output_dir_all = output_dir / "all"
|
||||
output_dir_all.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(1, N_EXAMPLES + 1):
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
|
||||
|
||||
print("Combined transforms examples saved to:")
|
||||
print(f" {output_dir_all}")
|
||||
|
||||
|
||||
def save_config_single_transforms(cfg, original_frame, output_dir):
|
||||
transforms = [
|
||||
"brightness",
|
||||
"contrast",
|
||||
"saturation",
|
||||
"hue",
|
||||
"sharpness",
|
||||
]
|
||||
print("Individual transforms examples saved to:")
|
||||
for transform in transforms:
|
||||
kwargs = {
|
||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||
f"{transform}_min_max": cfg[f"{transform}"].min_max,
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
output_dir_single = output_dir / f"{transform}"
|
||||
output_dir_single.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(1, N_EXAMPLES + 1):
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
|
||||
|
||||
print(f" {output_dir_single}")
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def visualize_transforms(cfg):
|
||||
dataset = LeRobotDataset(cfg.dataset_repo_id)
|
||||
|
||||
output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get 1st frame from 1st camera of 1st episode
|
||||
original_frame = dataset[0][dataset.camera_keys[0]]
|
||||
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
|
||||
print("\nOriginal frame saved to:")
|
||||
print(f" {output_dir / 'original_frame.png'}.")
|
||||
|
||||
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
|
||||
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
visualize_transforms()
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a
|
||||
size 3686488
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d0e3b4bde97c34606536b655c1e6a23316c9157bd21dcbc73a97500fb985607f
|
||||
size 40551392
|
|
@ -0,0 +1,86 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
default_tf = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
|
||||
with seeded_context(1337):
|
||||
img_tf = default_tf(original_frame)
|
||||
|
||||
save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")
|
||||
|
||||
|
||||
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
|
||||
transforms = {
|
||||
"brightness": [(0.5, 0.5), (2.0, 2.0)],
|
||||
"contrast": [(0.5, 0.5), (2.0, 2.0)],
|
||||
"saturation": [(0.5, 0.5), (2.0, 2.0)],
|
||||
"hue": [(-0.25, -0.25), (0.25, 0.25)],
|
||||
"sharpness": [(0.5, 0.5), (2.0, 2.0)],
|
||||
}
|
||||
|
||||
frames = {"original_frame": original_frame}
|
||||
for transform, values in transforms.items():
|
||||
for min_max in values:
|
||||
kwargs = {
|
||||
f"{transform}_weight": 1.0,
|
||||
f"{transform}_min_max": min_max,
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
key = f"{transform}_{min_max[0]}_{min_max[1]}"
|
||||
frames[key] = tf(original_frame)
|
||||
|
||||
save_file(frames, output_dir / "single_transforms.safetensors")
|
||||
|
||||
|
||||
def main():
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
|
||||
output_dir = Path(ARTIFACT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
original_frame = dataset[0][dataset.camera_keys[0]]
|
||||
|
||||
save_single_transforms(original_frame, output_dir)
|
||||
save_default_config_transform(original_frame, output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,260 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
|
||||
|
||||
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
|
||||
|
||||
def load_png_to_tensor(path: Path):
|
||||
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img():
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID)
|
||||
return dataset[0][dataset.camera_keys[0]]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img_random():
|
||||
return torch.rand(3, 480, 640)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def color_jitters():
|
||||
return [
|
||||
v2.ColorJitter(brightness=0.5),
|
||||
v2.ColorJitter(contrast=0.5),
|
||||
v2.ColorJitter(saturation=0.5),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def single_transforms():
|
||||
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_transforms():
|
||||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
||||
|
||||
|
||||
def test_get_image_transforms_no_transform(img):
|
||||
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
||||
torch.testing.assert_close(tf_actual(img), img)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_brightness(img, min_max):
|
||||
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(brightness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_contrast(img, min_max):
|
||||
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(contrast=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_saturation(img, min_max):
|
||||
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(saturation=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
|
||||
def test_get_image_transforms_hue(img, min_max):
|
||||
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(hue=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_sharpness(img, min_max):
|
||||
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
|
||||
tf_expected = SharpnessJitter(sharpness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
|
||||
|
||||
def test_get_image_transforms_max_num_transforms(img):
|
||||
tf_actual = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
saturation_min_max=(0.5, 0.5),
|
||||
hue_min_max=(0.5, 0.5),
|
||||
sharpness_min_max=(0.5, 0.5),
|
||||
random_order=False,
|
||||
)
|
||||
tf_expected = v2.Compose(
|
||||
[
|
||||
v2.ColorJitter(brightness=(0.5, 0.5)),
|
||||
v2.ColorJitter(contrast=(0.5, 0.5)),
|
||||
v2.ColorJitter(saturation=(0.5, 0.5)),
|
||||
v2.ColorJitter(hue=(0.5, 0.5)),
|
||||
SharpnessJitter(sharpness=(0.5, 0.5)),
|
||||
]
|
||||
)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
|
||||
|
||||
@require_x86_64_kernel
|
||||
def test_get_image_transforms_random_order(img):
|
||||
out_imgs = []
|
||||
tf = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
saturation_min_max=(0.5, 0.5),
|
||||
hue_min_max=(0.5, 0.5),
|
||||
sharpness_min_max=(0.5, 0.5),
|
||||
random_order=True,
|
||||
)
|
||||
with seeded_context(1337):
|
||||
for _ in range(10):
|
||||
out_imgs.append(tf(img))
|
||||
|
||||
for i in range(1, len(out_imgs)):
|
||||
with pytest.raises(AssertionError):
|
||||
torch.testing.assert_close(out_imgs[0], out_imgs[i])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"transform, min_max_values",
|
||||
[
|
||||
("brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
||||
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
],
|
||||
)
|
||||
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
|
||||
for min_max in min_max_values:
|
||||
kwargs = {
|
||||
f"{transform}_weight": 1.0,
|
||||
f"{transform}_min_max": min_max,
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
actual = tf(img)
|
||||
key = f"{transform}_{min_max[0]}_{min_max[1]}"
|
||||
expected = single_transforms[key]
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility_default_config(img, default_transforms):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
default_tf = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
|
||||
with seeded_context(1337):
|
||||
actual = default_tf(img)
|
||||
|
||||
expected = default_transforms["default"]
|
||||
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
|
||||
def test_random_subset_apply_single_choice(p, img):
|
||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
|
||||
actual = random_choice(img)
|
||||
|
||||
p_horz, _ = p
|
||||
if p_horz:
|
||||
torch.testing.assert_close(actual, F.horizontal_flip(img))
|
||||
else:
|
||||
torch.testing.assert_close(actual, F.vertical_flip(img))
|
||||
|
||||
|
||||
def test_random_subset_apply_random_order(img):
|
||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
|
||||
# We can't really check whether the transforms are actually applied in random order. However,
|
||||
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
|
||||
# applies them in random order, we can use a fixed order to compute the expected value.
|
||||
actual = random_order(img)
|
||||
expected = v2.Compose(flips)(img)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_random_subset_apply_valid_transforms(color_jitters, img):
|
||||
transform = RandomSubsetApply(color_jitters)
|
||||
output = transform(img)
|
||||
assert output.shape == img.shape
|
||||
|
||||
|
||||
def test_random_subset_apply_probability_length_mismatch(color_jitters):
|
||||
with pytest.raises(ValueError):
|
||||
RandomSubsetApply(color_jitters, p=[0.5, 0.5])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_subset", [0, 5])
|
||||
def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
|
||||
with pytest.raises(ValueError):
|
||||
RandomSubsetApply(color_jitters, n_subset=n_subset)
|
||||
|
||||
|
||||
def test_sharpness_jitter_valid_range_tuple(img):
|
||||
tf = SharpnessJitter((0.1, 2.0))
|
||||
output = tf(img)
|
||||
assert output.shape == img.shape
|
||||
|
||||
|
||||
def test_sharpness_jitter_valid_range_float(img):
|
||||
tf = SharpnessJitter(0.5)
|
||||
output = tf(img)
|
||||
assert output.shape == img.shape
|
||||
|
||||
|
||||
def test_sharpness_jitter_invalid_range_min_negative():
|
||||
with pytest.raises(ValueError):
|
||||
SharpnessJitter((-0.1, 2.0))
|
||||
|
||||
|
||||
def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
with pytest.raises(ValueError):
|
||||
SharpnessJitter((2.0, 0.1))
|
|
@ -0,0 +1,352 @@
|
|||
"""
|
||||
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
|
||||
Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets,
|
||||
we skip them for now in our CI.
|
||||
|
||||
Example to run backward compatiblity tests locally:
|
||||
```
|
||||
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
|
||||
```
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
|
||||
from tests.utils import require_package_arg
|
||||
|
||||
|
||||
def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
|
||||
import zarr
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
|
||||
store = zarr.DirectoryStore(zarr_path)
|
||||
zarr_data = zarr.group(store=store)
|
||||
|
||||
zarr_data.create_dataset(
|
||||
"data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/img",
|
||||
shape=(num_frames, 96, 96, 3),
|
||||
chunks=(num_frames, 96, 96, 3),
|
||||
dtype=np.uint8,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
|
||||
)
|
||||
|
||||
zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
|
||||
zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
|
||||
zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
|
||||
zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
|
||||
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
|
||||
|
||||
store.close()
|
||||
|
||||
|
||||
def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
|
||||
import zarr
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
zarr_path = raw_dir / "cup_in_the_wild.zarr"
|
||||
store = zarr.DirectoryStore(zarr_path)
|
||||
zarr_data = zarr.group(store=store)
|
||||
|
||||
zarr_data.create_dataset(
|
||||
"data/camera0_rgb",
|
||||
shape=(num_frames, 96, 96, 3),
|
||||
chunks=(num_frames, 96, 96, 3),
|
||||
dtype=np.uint8,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/robot0_demo_end_pose",
|
||||
shape=(num_frames, 5),
|
||||
chunks=(num_frames, 5),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/robot0_demo_start_pose",
|
||||
shape=(num_frames, 5),
|
||||
chunks=(num_frames, 5),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/robot0_eef_rot_axis_angle",
|
||||
shape=(num_frames, 5),
|
||||
chunks=(num_frames, 5),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/robot0_gripper_width",
|
||||
shape=(num_frames, 5),
|
||||
chunks=(num_frames, 5),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
|
||||
)
|
||||
|
||||
zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
|
||||
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
|
||||
|
||||
store.close()
|
||||
|
||||
|
||||
def _mock_download_raw_xarm(raw_dir, num_frames=4):
|
||||
import pickle
|
||||
|
||||
dataset_dict = {
|
||||
"observations": {
|
||||
"rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
|
||||
"state": np.random.randn(num_frames, 4),
|
||||
},
|
||||
"actions": np.random.randn(num_frames, 3),
|
||||
"rewards": np.random.randn(num_frames),
|
||||
"masks": np.random.randn(num_frames),
|
||||
"dones": np.array([False, True, True, True]),
|
||||
}
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
pkl_path = raw_dir / "buffer.pkl"
|
||||
with open(pkl_path, "wb") as f:
|
||||
pickle.dump(dataset_dict, f)
|
||||
|
||||
|
||||
def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
|
||||
import h5py
|
||||
|
||||
for ep_idx in range(num_episodes):
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
|
||||
with h5py.File(str(path_h5), "w") as f:
|
||||
f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
|
||||
f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
|
||||
f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
|
||||
f.create_dataset(
|
||||
"observations/images/top",
|
||||
data=np.random.randint(
|
||||
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pandas
|
||||
|
||||
def write_parquet(key, timestamps, values):
|
||||
data = {
|
||||
"timestamp_utc": timestamps,
|
||||
key: values,
|
||||
}
|
||||
df = pandas.DataFrame(data)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")
|
||||
|
||||
episode_indices = [None, None, -1, None, None, -1, None, None, -1]
|
||||
episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
|
||||
frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]
|
||||
|
||||
cam_key = "observation.images.cam_high"
|
||||
timestamps = []
|
||||
actions = []
|
||||
states = []
|
||||
frames = []
|
||||
# `+ num_episodes`` for buffer frames associated to episode_index=-1
|
||||
for i, frame_idx in enumerate(frame_indices):
|
||||
t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
|
||||
action = np.random.randn(21).tolist()
|
||||
state = np.random.randn(21).tolist()
|
||||
ep_idx = episode_indices_mapping[i]
|
||||
frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
|
||||
timestamps.append(t_utc)
|
||||
actions.append(action)
|
||||
states.append(state)
|
||||
frames.append(frame)
|
||||
|
||||
write_parquet(cam_key, timestamps, frames)
|
||||
write_parquet("observation.state", timestamps, states)
|
||||
write_parquet("action", timestamps, actions)
|
||||
write_parquet("episode_index", timestamps, episode_indices)
|
||||
|
||||
# write fake mp4 file for each episode
|
||||
for ep_idx in range(num_episodes):
|
||||
imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)
|
||||
|
||||
tmp_imgs_dir = raw_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = raw_dir / "videos" / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
|
||||
def _mock_download_raw(raw_dir, repo_id):
|
||||
if "wrist_gripper" in repo_id:
|
||||
_mock_download_raw_dora(raw_dir)
|
||||
elif "aloha" in repo_id:
|
||||
_mock_download_raw_aloha(raw_dir)
|
||||
elif "pusht" in repo_id:
|
||||
_mock_download_raw_pusht(raw_dir)
|
||||
elif "xarm" in repo_id:
|
||||
_mock_download_raw_xarm(raw_dir)
|
||||
elif "umi" in repo_id:
|
||||
_mock_download_raw_umi(raw_dir)
|
||||
else:
|
||||
raise ValueError(repo_id)
|
||||
|
||||
|
||||
def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
|
||||
with pytest.raises(ValueError):
|
||||
push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")
|
||||
|
||||
|
||||
def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
||||
tmpdir = Path(tmpdir)
|
||||
out_dir = tmpdir / "out"
|
||||
raw_dir = tmpdir / "raw"
|
||||
# mkdir to skip download
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
with pytest.raises(ValueError):
|
||||
push_dataset_to_hub(
|
||||
raw_dir=raw_dir,
|
||||
raw_format="some_format",
|
||||
repo_id="user/dataset",
|
||||
local_dir=out_dir,
|
||||
force_override=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"required_packages, raw_format, repo_id",
|
||||
[
|
||||
(["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
|
||||
(None, "xarm_pkl", "lerobot/xarm_lift_medium"),
|
||||
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
|
||||
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
|
||||
(None, "dora_parquet", "cadene/wrist_gripper"),
|
||||
],
|
||||
)
|
||||
@require_package_arg
|
||||
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
|
||||
num_episodes = 3
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
raw_dir = tmpdir / f"{repo_id}_raw"
|
||||
_mock_download_raw(raw_dir, repo_id)
|
||||
|
||||
local_dir = tmpdir / repo_id
|
||||
|
||||
lerobot_dataset = push_dataset_to_hub(
|
||||
raw_dir=raw_dir,
|
||||
raw_format=raw_format,
|
||||
repo_id=repo_id,
|
||||
push_to_hub=False,
|
||||
local_dir=local_dir,
|
||||
force_override=False,
|
||||
cache_dir=tmpdir / "cache",
|
||||
)
|
||||
|
||||
# minimal generic tests on the local directory containing LeRobotDataset
|
||||
assert (local_dir / "meta_data" / "info.json").exists()
|
||||
assert (local_dir / "meta_data" / "stats.safetensors").exists()
|
||||
assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
|
||||
for i in range(num_episodes):
|
||||
for cam_key in lerobot_dataset.camera_keys:
|
||||
assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
|
||||
assert (local_dir / "train" / "dataset_info.json").exists()
|
||||
assert (local_dir / "train" / "state.json").exists()
|
||||
assert len(list((local_dir / "train").glob("*.arrow"))) > 0
|
||||
|
||||
# minimal generic tests on the item
|
||||
item = lerobot_dataset[0]
|
||||
assert "index" in item
|
||||
assert "episode_index" in item
|
||||
assert "timestamp" in item
|
||||
for cam_key in lerobot_dataset.camera_keys:
|
||||
assert cam_key in item
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw_format, repo_id",
|
||||
[
|
||||
# TODO(rcadene): add raw dataset test artifacts
|
||||
("pusht_zarr", "lerobot/pusht"),
|
||||
("xarm_pkl", "lerobot/xarm_lift_medium"),
|
||||
("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
|
||||
("umi_zarr", "lerobot/umi_cup_in_the_wild"),
|
||||
("dora_parquet", "cadene/wrist_gripper"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
"Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
|
||||
)
|
||||
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
|
||||
_, dataset_id = repo_id.split("/")
|
||||
|
||||
tmpdir = Path(tmpdir)
|
||||
raw_dir = tmpdir / f"{dataset_id}_raw"
|
||||
local_dir = tmpdir / repo_id
|
||||
|
||||
push_dataset_to_hub(
|
||||
raw_dir=raw_dir,
|
||||
raw_format=raw_format,
|
||||
repo_id=repo_id,
|
||||
push_to_hub=False,
|
||||
local_dir=local_dir,
|
||||
force_override=False,
|
||||
cache_dir=tmpdir / "cache",
|
||||
episodes=[0],
|
||||
)
|
||||
|
||||
ds_actual = LeRobotDataset(repo_id, root=tmpdir)
|
||||
ds_reference = LeRobotDataset(repo_id)
|
||||
|
||||
assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)
|
||||
|
||||
def check_same_items(item1, item2):
|
||||
assert item1.keys() == item2.keys(), "Keys mismatch"
|
||||
|
||||
for key in item1:
|
||||
if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
|
||||
assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
|
||||
else:
|
||||
assert item1[key] == item2[key], f"Mismatch found in key: {key}"
|
||||
|
||||
for i in range(len(ds_reference.hf_dataset)):
|
||||
item_reference = ds_reference.hf_dataset[i]
|
||||
item_actual = ds_actual.hf_dataset[i]
|
||||
check_same_items(item_reference, item_actual)
|
|
@ -13,6 +13,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.visualize_dataset import visualize_dataset
|
||||
|
@ -30,3 +32,20 @@ def test_visualize_dataset(tmpdir, repo_id):
|
|||
serve=False,
|
||||
)
|
||||
assert rrd_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
|
||||
def test_visualize_local_dataset(tmpdir, repo_id, root):
|
||||
rrd_path = visualize_dataset(
|
||||
repo_id,
|
||||
episode_index=0,
|
||||
batch_size=32,
|
||||
save=True,
|
||||
output_dir=tmpdir,
|
||||
root=root,
|
||||
)
|
||||
assert rrd_path.exists()
|
||||
|
|
|
@ -76,6 +76,7 @@ def require_env(func):
|
|||
"""
|
||||
Decorator that skips the test if the required environment package is not installed.
|
||||
As it need 'env_name' in args, it also checks whether it is provided as an argument.
|
||||
If 'env_name' is None, this check is skipped.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
|
@ -91,7 +92,7 @@ def require_env(func):
|
|||
|
||||
# Perform the package check
|
||||
package_name = f"gym_{env_name}"
|
||||
if not is_package_available(package_name):
|
||||
if env_name is not None and not is_package_available(package_name):
|
||||
pytest.skip(f"gym-{env_name} not installed")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
@ -99,6 +100,38 @@ def require_env(func):
|
|||
return wrapper
|
||||
|
||||
|
||||
def require_package_arg(func):
|
||||
"""
|
||||
Decorator that skips the test if the required package is not installed.
|
||||
This is similar to `require_env` but more general in that it can check any package (not just environments).
|
||||
As it need 'required_packages' in args, it also checks whether it is provided as an argument.
|
||||
If 'required_packages' is None, this check is skipped.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Determine if 'required_packages' is provided and extract its value
|
||||
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
|
||||
if "required_packages" in arg_names:
|
||||
# Get the index of 'required_packages' and retrieve the value from args
|
||||
index = arg_names.index("required_packages")
|
||||
required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
|
||||
else:
|
||||
raise ValueError("Function does not have 'required_packages' as an argument.")
|
||||
|
||||
if required_packages is None:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Perform the package check
|
||||
for package in required_packages:
|
||||
if not is_package_available(package):
|
||||
pytest.skip(f"{package} not installed")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_package(package_name):
|
||||
"""
|
||||
Decorator that skips the test if the specified package is not installed.
|
||||
|
|
Loading…
Reference in New Issue