Rename openx to droid + Improve all (not tested)

Remi Cadene 2025-03-18 16:28:09 +00:00
parent 7866c1f7d1
commit 1a5c1ef9c7
14 changed files with 1241 additions and 2709 deletions


@@ -0,0 +1,143 @@
# Port DROID 1.0.1 dataset to LeRobotDataset
## Download
TODO
It will take 2 TB of local disk space.
## Port on a single computer
First, install tensorflow dataset utilities to read from raw files:
```bash
pip install tensorflow
pip install tensorflow_datasets
```
Then run this script to start porting the dataset:
```bash
python examples/port_datasets/droid_rlds/port_droid.py \
--raw-dir /your/data/droid/1.0.1 \
--repo-id your_id/droid_1.0.1 \
--push-to-hub
```
The ported dataset will take 400 GB of local disk space.
As usual, your LeRobotDataset will be stored in your huggingface/lerobot cache folder.
WARNING: porting the dataset locally takes about 7 days and uploading takes another 3 days, so you will need to parallelize over multiple nodes on a slurm cluster.
NOTE: For development, run this script to start porting a shard:
```bash
python examples/port_datasets/droid_rlds/port_droid.py \
--raw-dir /your/data/droid/1.0.1 \
--repo-id your_id/droid_1.0.1 \
--num-shards 2048 \
--shard-index 0
```
## Port over SLURM
### 1. Port one shard per job
First, install slurm utilities from Hugging Face:
```bash
pip install datatrove
```
Then run this script to start porting shards of the dataset:
```bash
python examples/port_datasets/droid_rlds/slurm_port_shards.py \
--raw-dir /your/data/droid/1.0.1 \
--repo-id your_id/droid_1.0.1 \
--logs-dir /your/logs \
--job-name port_droid \
--partition your_partition \
--workers 2048 \
--cpus-per-task 8 \
--mem-per-cpu 1950M
```
**Note on how to set your command line arguments**
Regarding `--partition`, find yours by running:
```bash
sinfo --format="%R"
```
and select the CPU partition if you have one. No GPU needed.
Regarding `--workers`, it is the number of slurm jobs you will launch in parallel. 2048 is the maximum, since there are 2048 shards in DROID. This many parallel jobs will certainly max out your cluster.
Regarding `--cpus-per-task` and `--mem-per-cpu`, the defaults give each job ~16 GB of RAM (8 * 1950M), which is recommended for loading the raw frames, and 8 CPUs, which help parallelize the encoding of the frames.
Find the number of CPUs and the memory per node of your partition by running:
```bash
sinfo -N -p your_partition -h -o "%N cpus=%c mem=%m"
```
**Useful commands to check progress and debug**
Check if your jobs are running:
```bash
squeue -u $USER
```
You should see a list of job ids like `15125385_155`, where `15125385` is the slurm job id and `155` is the worker index. The output of this worker is written in real time to `/your/logs/job_name/slurm_jobs/15125385_155.out`. For instance, you can inspect the content of this file by running `less /your/logs/job_name/slurm_jobs/15125385_155.out`.
Check the progress of your jobs by running:
```bash
jobs_status /your/logs
```
If it's not at 100% and no slurm jobs are running anymore, it means that some of them failed. Inspect the logs by running:
```bash
failed_logs /your/logs/job_name
```
If there is an issue in the code, you can fix it in debug mode with `--slurm 0`, which runs the pipeline sequentially and lets you set breakpoints:
```bash
python examples/port_datasets/droid_rlds/slurm_port_shards.py --slurm 0 ...
```
You can then relaunch the same command, which will skip the jobs that already completed:
```bash
python examples/port_datasets/droid_rlds/slurm_port_shards.py --slurm 1 ...
```
Once all jobs are completed, you will have one dataset per shard (e.g. `droid_1.0.1_world_2048_rank_1594`) saved on disk in your `/lerobot/home/dir/your_id` directory. You can find your `/lerobot/home/dir` by running:
```bash
python -c "from lerobot.common.constants import HF_LEROBOT_HOME;print(HF_LEROBOT_HOME)"
```
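As a quick sanity check (a minimal sketch, assuming the default shard naming shown above), you can count the per-shard datasets on disk; the result should be 2048:
```bash
ls /lerobot/home/dir/your_id | grep -c "droid_1.0.1_world_2048"
```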
### 2. Aggregate all shards
Run this script to start aggregation:
```bash
python examples/port_datasets/droid_rlds/slurm_aggregate_shards.py \
--repo-id your_id/droid_1.0.1 \
--logs-dir /your/logs \
--job-name aggr_droid \
--partition your_partition \
--workers 2048 \
--cpus-per-task 8 \
--mem-per-cpu 1950M
```
Once all jobs are completed, you will have the aggregated dataset in your `/lerobot/home/dir/your_id/droid_1.0.1` directory.
### 3. Upload dataset
Run this script to start uploading:
```bash
python examples/port_datasets/droid_rlds/slurm_upload.py \
--repo-id your_id/droid_1.0.1 \
--logs-dir /your/logs \
--job-name upload_droid \
--partition your_partition \
--workers 50 \
--cpus-per-task 4 \
--mem-per-cpu 1950M
```
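Once the upload completes, a minimal sanity check (it only verifies that the metadata of the uploaded dataset can be loaded) is:
```bash
python -c "from lerobot.common.datasets.lerobot_dataset import LeRobotDataset; print(LeRobotDataset('your_id/droid_1.0.1').meta.total_episodes)"
```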


@@ -0,0 +1,410 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import time
from pathlib import Path
import numpy as np
import tensorflow_datasets as tfds
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
DROID_SHARDS = 2048
DROID_FPS = 15
DROID_ROBOT_TYPE = "Franka"
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
DROID_FEATURES = {
# true on first step of the episode
"is_first": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# true on last step of the episode
"is_last": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# true on last step of the episode if it is a terminal step, True for demos
"is_terminal": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# language_instruction is also stored as "task" to follow LeRobot standard
"language_instruction": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"language_instruction_2": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"language_instruction_3": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"observation.state.gripper_position": {
"dtype": "float32",
"shape": (1,),
"names": {
"axes": ["gripper"],
},
},
"observation.state.cartesian_position": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"observation.state.joint_position": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
# Add this new feature to follow LeRobot standard of using joint position + gripper
"observation.state": {
"dtype": "float32",
"shape": (8,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
},
},
# Initially called wrist_image_left
"observation.images.wrist_left": {
"dtype": "video",
"shape": (180, 320, 3),
"names": [
"height",
"width",
"channels",
],
},
# Initially called exterior_image_1_left
"observation.images.exterior_1_left": {
"dtype": "video",
"shape": (180, 320, 3),
"names": [
"height",
"width",
"channels",
],
},
# Initially called exterior_image_2_left
"observation.images.exterior_2_left": {
"dtype": "video",
"shape": (180, 320, 3),
"names": [
"height",
"width",
"channels",
],
},
"action.gripper_position": {
"dtype": "float32",
"shape": (1,),
"names": {
"axes": ["gripper"],
},
},
"action.gripper_velocity": {
"dtype": "float32",
"shape": (1,),
"names": {
"axes": ["gripper"],
},
},
"action.cartesian_position": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"action.cartesian_velocity": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"action.joint_position": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
},
},
"action.joint_velocity": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
},
},
# This feature was called "action" in the RLDS dataset and consists of [6x cartesian velocities (x, y, z, roll, pitch, yaw), 1x gripper position]
"action.original": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
},
},
# Add this new feature to follow LeRobot standard of using joint position + gripper
"action": {
"dtype": "float32",
"shape": (8,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
},
},
"discount": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
# Meta data that are the same for all frames in the episode
"task_category": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"building": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"collector_id": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"date": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"camera_extrinsics.wrist_left": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"camera_extrinsics.exterior_1_left": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"camera_extrinsics.exterior_2_left": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"is_episode_successful": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
}
def is_episode_successful(tf_episode_metadata):
# Adapted from: https://github.com/droid-dataset/droid_policy_learning/blob/dd1020eb20d981f90b5ff07dc80d80d5c0cb108b/robomimic/utils/rlds_utils.py#L8
return "/success/" in tf_episode_metadata["file_path"].numpy().decode()
def generate_lerobot_frames(tf_episode):
m = tf_episode["episode_metadata"]
frame_meta = {
"task_category": m["building"].numpy().decode(),
"building": m["building"].numpy().decode(),
"collector_id": m["collector_id"].numpy().decode(),
"date": m["date"].numpy().decode(),
"camera_extrinsics.wrist_left": m["extrinsics_wrist_cam"].numpy(),
"camera_extrinsics.exterior_1_left": m["extrinsics_exterior_cam_1"].numpy(),
"camera_extrinsics.exterior_2_left": m["extrinsics_exterior_cam_2"].numpy(),
"is_episode_successful": np.array([is_episode_successful(m)]),
}
for f in tf_episode["steps"]:
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
frame = {
"is_first": np.array([f["is_first"].numpy()]),
"is_last": np.array([f["is_last"].numpy()]),
"is_terminal": np.array([f["is_terminal"].numpy()]),
"language_instruction": f["language_instruction"].numpy().decode(),
"language_instruction_2": f["language_instruction_2"].numpy().decode(),
"language_instruction_3": f["language_instruction_3"].numpy().decode(),
"observation.state.gripper_position": f["observation"]["gripper_position"].numpy(),
"observation.state.cartesian_position": f["observation"]["cartesian_position"].numpy(),
"observation.state.joint_position": f["observation"]["joint_position"].numpy(),
"observation.images.wrist_left": f["observation"]["wrist_image_left"].numpy(),
"observation.images.exterior_1_left": f["observation"]["exterior_image_1_left"].numpy(),
"observation.images.exterior_2_left": f["observation"]["exterior_image_2_left"].numpy(),
"action.gripper_position": f["action_dict"]["gripper_position"].numpy(),
"action.gripper_velocity": f["action_dict"]["gripper_velocity"].numpy(),
"action.cartesian_position": f["action_dict"]["cartesian_position"].numpy(),
"action.cartesian_velocity": f["action_dict"]["cartesian_velocity"].numpy(),
"action.joint_position": f["action_dict"]["joint_position"].numpy(),
"action.joint_velocity": f["action_dict"]["joint_velocity"].numpy(),
"discount": np.array([f["discount"].numpy()]),
"reward": np.array([f["reward"].numpy()]),
"action.original": f["action"].numpy(),
}
# language_instruction is also stored as "task" to follow LeRobot standard
frame["task"] = frame["language_instruction"]
# Add this new feature to follow LeRobot standard of using joint position + gripper
frame["observation.state"] = np.concatenate(
[frame["observation.state.joint_position"], frame["observation.state.gripper_position"]]
)
frame["action"] = np.concatenate([frame["action.joint_position"], frame["action.gripper_position"]])
# Meta data that are the same for all frames in the episode
frame.update(frame_meta)
# Cast fp64 to fp32
for key in frame:
if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
frame[key] = frame[key].astype(np.float32)
yield frame
def port_droid(
raw_dir: Path,
repo_id: str = None,
push_to_hub: bool = False,
num_shards: int | None = None,
shard_index: int | None = None,
):
dataset_name = raw_dir.parent.name
version = raw_dir.name
data_dir = raw_dir.parent.parent
builder = tfds.builder(f"{dataset_name}/{version}", data_dir=data_dir, version="")
if num_shards is not None:
tfds_num_shards = builder.info.splits["train"].num_shards
if tfds_num_shards != DROID_SHARDS:
raise ValueError(
f"Number of shards of Droid dataset is expected to be {DROID_SHARDS} but is {tfds_num_shards}."
)
if num_shards != tfds_num_shards:
raise ValueError(
f"We only shard over the fixed number of shards provided by tensorflow dataset ({tfds_num_shards}), but {num_shards} shards provided instead."
)
if shard_index >= tfds_num_shards:
raise ValueError(
f"Shard index is greater than the num of shards ({shard_index} >= {num_shards})."
)
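# tfds sub-split syntax: "train[Kshard]" loads only the K-th record file of the train split (here, one of the 2048 DROID shards).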
raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
else:
raw_dataset = builder.as_dataset(split="train")
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
robot_type=DROID_ROBOT_TYPE,
fps=DROID_FPS,
features=DROID_FEATURES,
)
start_time = time.time()
num_episodes = raw_dataset.cardinality().numpy().item()
logging.info(f"Number of episodes {num_episodes}")
for episode_index, episode in enumerate(raw_dataset):
logging.info(f"{episode_index} / {num_episodes} episodes processed")
elapsed_time = time.time() - start_time
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
logging.info(f"It has been {d} days, {h} hours, {m} minutes, {s:.3f} seconds")
for frame in generate_lerobot_frames(episode):
lerobot_dataset.add_frame(frame)
lerobot_dataset.save_episode()
logging.info("Save_episode")
if push_to_hub:
lerobot_dataset.push_to_hub(
# Add openx tag, since it belongs to the openx collection of datasets
tags=["openx"],
private=False,
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
)
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Upload to hub.",
)
parser.add_argument(
"--num-shards",
type=int,
default=None,
help="Number of shards. Can be either None to load the full dataset, or 2048 to load one of the 2048 tensorflow dataset files.",
)
parser.add_argument(
"--shard-index",
type=int,
default=None,
help="Index of the shard. Can be either None to load the full dataset, or in [0,2047] to load one of the 2048 tensorflow dataset files.",
)
args = parser.parse_args()
port_droid(**vars(args))
if __name__ == "__main__":
main()


@@ -0,0 +1,287 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import tqdm
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.common.datasets.aggregate import validate_all_metadata
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import write_episode, write_episode_stats, write_info, write_task
from lerobot.common.utils.utils import init_logging
class AggregateDatasets(PipelineStep):
def __init__(
self,
repo_ids: list[str],
aggregated_repo_id: str,
):
super().__init__()
self.repo_ids = repo_ids
self.aggr_repo_id = aggregated_repo_id
self.create_aggr_dataset()
def create_aggr_dataset(self):
init_logging()
logging.info("Start aggregate_datasets")
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in self.repo_ids]
fps, robot_type, features = validate_all_metadata(all_metadata)
# Create resulting dataset folder
aggr_meta = LeRobotDatasetMetadata.create(
repo_id=self.aggr_repo_id,
fps=fps,
robot_type=robot_type,
features=features,
)
logging.info("Find all tasks")
# find all tasks, deduplicate them, create new task indices for each dataset
# indexed by dataset index
datasets_task_index_to_aggr_task_index = {}
aggr_task_index = 0
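# Illustrative example: if two shards both contain the task "pick up the cube", their local task indices both map to the same aggregated task index.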
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")):
task_index_to_aggr_task_index = {}
for task_index, task in meta.tasks.items():
if task not in aggr_meta.task_to_task_index:
# add the task to aggr tasks mappings
aggr_meta.tasks[aggr_task_index] = task
aggr_meta.task_to_task_index[task] = aggr_task_index
aggr_task_index += 1
# add task_index anyway
task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
logging.info("Prepare copy data and videos")
datasets_ep_idx_to_aggr_ep_idx = {}
datasets_aggr_episode_index_shift = {}
aggr_episode_index_shift = 0
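# Each shard keeps its episode order; shard k's episode indices are shifted by the total episode count of shards 0..k-1.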
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Prepare copy data and videos")):
ep_idx_to_aggr_ep_idx = {}
for episode_index in range(meta.total_episodes):
aggr_episode_index = episode_index + aggr_episode_index_shift
ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index
datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
# populate episodes
for episode_index, episode_dict in meta.episodes.items():
aggr_episode_index = episode_index + aggr_episode_index_shift
episode_dict["episode_index"] = aggr_episode_index
aggr_meta.episodes[aggr_episode_index] = episode_dict
# populate episodes_stats
for episode_index, episode_stats in meta.episodes_stats.items():
aggr_episode_index = episode_index + aggr_episode_index_shift
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
# populate info
aggr_meta.info["total_episodes"] += meta.total_episodes
aggr_meta.info["total_frames"] += meta.total_frames
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
aggr_episode_index_shift += meta.total_episodes
logging.info("Write meta data")
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
# create a new episodes jsonl with updated episode_index using write_episode
for episode_dict in tqdm.tqdm(aggr_meta.episodes.values(), desc="Write episodes"):
write_episode(episode_dict, aggr_meta.root)
# create a new episode_stats jsonl with updated episode_index using write_episode_stats
for episode_index, episode_stats in tqdm.tqdm(
aggr_meta.episodes_stats.items(), desc="Write episodes stats"
):
write_episode_stats(episode_index, episode_stats, aggr_meta.root)
# create a new task jsonl with updated episode_index using write_task
for task_index, task in tqdm.tqdm(aggr_meta.tasks.items(), desc="Write tasks"):
write_task(task_index, task, aggr_meta.root)
write_info(aggr_meta.info, aggr_meta.root)
self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
logging.info("Meta data done writing!")
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
import shutil
import pandas as pd
from lerobot.common.datasets.aggregate import get_update_episode_and_task_func
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.utils.utils import init_logging
init_logging()
aggr_meta = LeRobotDatasetMetadata(self.aggr_repo_id)
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in self.repo_ids]
if world_size != len(all_metadata):
raise ValueError(f"World size ({world_size}) must match the number of datasets to aggregate ({len(all_metadata)}).")
dataset_index = rank
meta = all_metadata[dataset_index]
aggr_episode_index_shift = self.datasets_aggr_episode_index_shift[dataset_index]
logging.info("Copy data")
for episode_index in range(meta.total_episodes):
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
data_path = meta.root / meta.get_data_file_path(episode_index)
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
# update episode_index and task_index
df = pd.read_parquet(data_path)
update_row_func = get_update_episode_and_task_func(
aggr_episode_index_shift, self.datasets_task_index_to_aggr_task_index[dataset_index]
)
df = df.apply(update_row_func, axis=1)
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(aggr_data_path)
logging.info("Copy videos")
for episode_index in range(meta.total_episodes):
aggr_episode_index = episode_index + aggr_episode_index_shift
for vid_key in meta.video_keys:
video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(video_path, aggr_video_path)
# copy_command = f"cp {video_path} {aggr_video_path} &"
# subprocess.Popen(copy_command, shell=True)
logging.info("Done!")
def make_aggregate_executor(
repo_ids, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
AggregateDatasets(repo_ids, repo_id),
],
"logging_dir": str(logs_dir),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": DROID_SHARDS,
"workers": workers,
"time": "08:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": DROID_SHARDS,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
)
parser.add_argument(
"--logs-dir",
type=str,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="port_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
)
parser.add_argument(
"--workers",
type=int,
default=2048,
help="Number of slurm workers. It should be less than the maximum number of shards.",
)
parser.add_argument(
"--partition",
type=str,
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
)
parser.add_argument(
"--cpus-per-task",
type=int,
default=8,
help="Number of cpus that each slurm worker will use.",
)
parser.add_argument(
"--mem-per-cpu",
type=str,
default="1950M",
help="Memory per cpu that each worker will use.",
)
args = parser.parse_args()
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
repo_ids = [f"{args.repo_id}_world_{DROID_SHARDS}_rank_{rank}" for rank in range(DROID_SHARDS)]
aggregate_executor = make_aggregate_executor(repo_ids, **kwargs)
aggregate_executor.run()
if __name__ == "__main__":
main()


@@ -0,0 +1,161 @@
import argparse
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
def validate_shard(repo_id):
"""Sanity check that ensure meta data can be loaded and all files are present."""
meta = LeRobotDatasetMetadata(repo_id)
if meta.total_episodes == 0:
raise ValueError("Number of episodes is 0.")
for ep_idx in range(meta.total_episodes):
data_path = meta.root / meta.get_data_file_path(ep_idx)
if not data_path.exists():
raise ValueError(f"Parquet file is missing in: {data_path}")
for vid_key in meta.video_keys:
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
if not vid_path.exists():
raise ValueError(f"Video file is missing in: {vid_path}")
class PortDroidShards(PipelineStep):
def __init__(
self,
raw_dir: Path | str,
repo_id: str = None,
):
super().__init__()
self.raw_dir = Path(raw_dir)
self.repo_id = repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from examples.port_datasets.droid_rlds.port_droid import port_droid
from lerobot.common.utils.utils import init_logging
init_logging()
disable_progress_bars()
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
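# e.g. "your_id/droid_1.0.1_world_2048_rank_1594", matching the shard naming used in the README.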
port_droid(
self.raw_dir,
shard_repo_id,
push_to_hub=False,
num_shards=world_size,
shard_index=rank,
)
validate_shard(shard_repo_id)
def make_port_executor(
raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
PortDroidShards(raw_dir, repo_id),
],
"logging_dir": str(logs_dir),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": DROID_SHARDS,
"workers": workers,
"time": "08:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": 1,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
)
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
)
parser.add_argument(
"--logs-dir",
type=str,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="port_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
)
parser.add_argument(
"--workers",
type=int,
default=2048,
help="Number of slurm workers. It should be less than the maximum number of shards.",
)
parser.add_argument(
"--partition",
type=str,
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
)
parser.add_argument(
"--cpus-per-task",
type=int,
default=8,
help="Number of cpus that each slurm worker will use.",
)
parser.add_argument(
"--mem-per-cpu",
type=str,
default="1950M",
help="Memory per cpu that each worker will use.",
)
args = parser.parse_args()
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
port_executor = make_port_executor(**kwargs)
port_executor.run()
if __name__ == "__main__":
main()


@@ -0,0 +1,230 @@
import argparse
import logging
import os
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import create_lerobot_dataset_card
class UploadDataset(PipelineStep):
def __init__(
self,
repo_id: str,
branch: str | None = None,
revision: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
private: bool = False,
**card_kwargs,
):
super().__init__()
self.repo_id = repo_id
self.branch = branch
self.tags = tags
self.license = license
self.private = private
self.card_kwargs = card_kwargs
self.revision = revision if revision else CODEBASE_VERSION
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
logging.warning(
'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
"variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
)
self.create_repo()
def create_repo(self):
hub_api = HfApi()
meta = LeRobotDatasetMetadata(self.repo_id)
hub_api.create_repo(
repo_id=self.repo_id,
private=self.private,
repo_type="dataset",
exist_ok=True,
)
if self.branch:
hub_api.create_branch(
repo_id=self.repo_id,
branch=self.branch,
revision=self.revision,
repo_type="dataset",
exist_ok=True,
)
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch):
card = create_lerobot_dataset_card(
tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=self.branch)
def list_files_recursively(directory):
base_path = Path(directory)
return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]
meta = LeRobotDatasetMetadata(self.repo_id)
self.file_paths = list_files_recursively(meta.root)
self.file_paths = sorted(self.file_paths)
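# Sorting makes the file order deterministic, so every rank computes the same chunking in run().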
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
import random
import time
from itertools import islice
from huggingface_hub import CommitOperationAdd, create_commit, preupload_lfs_files
from huggingface_hub.utils import HfHubHTTPError
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.utils.utils import init_logging
BASE_DELAY = 1.0 # noqa: N806
MAX_RETRIES = 24 # noqa: N806
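# create_commit below is retried with exponential backoff plus random jitter (BASE_DELAY * 2**retries + up to 2s) to survive commit races between concurrent workers.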
init_logging()
def chunked(lst, n):
it = iter(lst)
return [list(islice(it, size)) for size in [len(lst) // n + (i < len(lst) % n) for i in range(n)]]
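# Splits lst into n contiguous chunks whose lengths differ by at most one, e.g. chunked(list(range(5)), 2) -> [[0, 1, 2], [3, 4]].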
chunks = chunked(self.file_paths, world_size)
file_paths = chunks[rank]
if len(file_paths) == 0:
raise ValueError(f"No files assigned to rank {rank} (world_size={world_size}).")
meta = LeRobotDatasetMetadata(self.repo_id)
additions = [
CommitOperationAdd(path_in_repo=path, path_or_fileobj=meta.root / path) for path in file_paths
]
logging.info(f"Uploading {','.join(file_paths)} to the hub...")
preupload_lfs_files(
repo_id=self.repo_id, repo_type="dataset", additions=additions, revision=self.branch
)
logging.info(f"Upload of {','.join(file_paths)} to the hub complete!")
retries = 0
while True:
try:
create_commit(
self.repo_id,
repo_type="dataset",
operations=additions,
commit_message=f"DataTrove upload ({len(additions)} files)",
revision=self.branch,
)
break
except HfHubHTTPError as e:
if "A commit has happened since" in e.server_message:
if retries >= MAX_RETRIES:
logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
raise e
logging.info("Commit creation race condition issue. Waiting...")
time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
retries += 1
else:
raise e
def make_upload_executor(
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
UploadDataset(repo_id),
],
"logging_dir": str(logs_dir),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": DROID_SHARDS,
"workers": workers,
"time": "08:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": DROID_SHARDS,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
)
parser.add_argument(
"--logs-dir",
type=str,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="port_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
)
parser.add_argument(
"--workers",
type=int,
default=2048,
help="Number of slurm workers. It should be less than the maximum number of shards.",
)
parser.add_argument(
"--partition",
type=str,
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
)
parser.add_argument(
"--cpus-per-task",
type=int,
default=8,
help="Number of cpus that each slurm worker will use.",
)
parser.add_argument(
"--mem-per-cpu",
type=str,
default="1950M",
help="Memory per cpu that each worker will use.",
)
args = parser.parse_args()
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
upload_executor = make_upload_executor(**kwargs)
upload_executor.run()
if __name__ == "__main__":
main()


@@ -1,326 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
NOTE: Install `tensorflow` and `tensorflow_datasets` before running this script.
```bash
pip install tensorflow
pip install tensorflow_datasets
```
Example:
```bash
python examples/port_datasets/openx_rlds.py \
--raw-dir /fsx/mustafa_shukor/droid \
--repo-id cadene/droid \
--use-videos \
--push-to-hub
```
"""
import argparse
import logging
import re
import time
from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
np.set_printoptions(precision=2)
def transform_raw_dataset(episode, dataset_name):
traj = next(iter(episode["steps"].batch(episode["steps"].cardinality())))
if dataset_name in OXE_STANDARDIZATION_TRANSFORMS:
traj = OXE_STANDARDIZATION_TRANSFORMS[dataset_name](traj)
if dataset_name in OXE_DATASET_CONFIGS:
state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
else:
state_obs_keys = [None for _ in range(8)]
proprio = tf.concat(
[
(
tf.zeros((tf.shape(traj["action"])[0], 1), dtype=tf.float32) # padding
if key is None
else tf.cast(traj["observation"][key], tf.float32)
)
for key in state_obs_keys
],
axis=1,
)
traj.update(
{
"proprio": proprio,
"task": traj.pop("language_instruction"),
"action": tf.cast(traj["action"], tf.float32),
}
)
episode["steps"] = traj
return episode
def generate_features_from_raw(dataset_name: str, builder: tfds.core.DatasetBuilder, use_videos: bool = True):
state_names = [f"motor_{i}" for i in range(8)]
if dataset_name in OXE_DATASET_CONFIGS:
state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"]
if state_encoding == StateEncoding.POS_EULER:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
if "libero" in dataset_name:
state_names = [
"x",
"y",
"z",
"roll",
"pitch",
"yaw",
"gripper",
"gripper",
] # 2D gripper state
elif state_encoding == StateEncoding.POS_QUAT:
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
DEFAULT_FEATURES = {
"observation.state": {
"dtype": "float32",
"shape": (8,),
"names": {"motors": state_names},
},
"action": {
"dtype": "float32",
"shape": (7,),
"names": {"motors": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]},
},
}
obs = builder.info.features["steps"]["observation"]
features = {
f"observation.images.{key}": {
"dtype": "video" if use_videos else "image",
"shape": value.shape,
"names": ["height", "width", "rgb"],
}
for key, value in obs.items()
if "depth" not in key and any(x in key for x in ["image", "rgb"])
}
return {**features, **DEFAULT_FEATURES}
def save_as_lerobot_dataset(
dataset_name: str,
lerobot_dataset: LeRobotDataset,
raw_dataset: tf.data.Dataset,
num_shards: int | None = None,
shard_index: int | None = None,
):
start_time = time.time()
total_num_episodes = raw_dataset.cardinality().numpy().item()
logging.info(f"Total number of episodes {total_num_episodes}")
if num_shards is not None:
sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
logging.info(f"{sharded_num_episodes=}")
num_episodes = sharded_num_episodes
iter_ = iter(sharded_dataset)
else:
num_episodes = total_num_episodes
iter_ = iter(raw_dataset)
if num_episodes <= 0:
raise ValueError(f"Number of episodes is {num_episodes}, but needs to be positive.")
for episode_index in range(num_episodes):
logging.info(f"{episode_index} / {num_episodes} episodes processed")
elapsed_time = time.time() - start_time
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
logging.info(f"It has been {d} days, {h} hours, {m} minutes, {s:.3f} seconds")
episode = next(iter_)
logging.info("next")
episode = transform_raw_dataset(episode, dataset_name)
traj = episode["steps"]
for i in range(traj["action"].shape[0]):
image_dict = {
f"observation.images.{key}": value[i].numpy()
for key, value in traj["observation"].items()
if "depth" not in key and any(x in key for x in ["image", "rgb"])
}
lerobot_dataset.add_frame(
{
**image_dict,
"observation.state": traj["proprio"][i].numpy(),
"action": traj["action"][i].numpy(),
"task": traj["task"][i].numpy().decode(),
}
)
lerobot_dataset.save_episode()
logging.info("save_episode")
def create_lerobot_dataset(
raw_dir: Path,
repo_id: str = None,
push_to_hub: bool = False,
fps: int = None,
robot_type: str = None,
use_videos: bool = True,
image_writer_process: int = 5,
image_writer_threads: int = 10,
num_shards: int | None = None,
shard_index: int | None = None,
):
last_part = raw_dir.name
if re.match(r"^\d+\.\d+\.\d+$", last_part):
version = last_part
dataset_name = raw_dir.parent.name
data_dir = raw_dir.parent.parent
else:
version = ""
dataset_name = last_part
data_dir = raw_dir.parent
builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
features = generate_features_from_raw(dataset_name, builder, use_videos)
if num_shards is not None:
if num_shards != builder.info.splits["train"].num_shards:
raise ValueError()
if shard_index >= builder.info.splits["train"].num_shards:
raise ValueError()
raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
else:
raw_dataset = builder.as_dataset(split="train")
if fps is None:
if dataset_name in OXE_DATASET_CONFIGS:
fps = OXE_DATASET_CONFIGS[dataset_name]["control_frequency"]
else:
fps = 10
if robot_type is None:
if dataset_name in OXE_DATASET_CONFIGS:
robot_type = OXE_DATASET_CONFIGS[dataset_name]["robot_type"]
robot_type = robot_type.lower().replace(" ", "_").replace("-", "_")
else:
robot_type = "unknown"
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
robot_type=robot_type,
fps=fps,
use_videos=use_videos,
features=features,
image_writer_threads=image_writer_threads,
image_writer_processes=image_writer_process,
)
save_as_lerobot_dataset(
dataset_name,
lerobot_dataset,
raw_dataset,
)
if push_to_hub:
assert repo_id is not None
tags = []
if dataset_name in OXE_DATASET_CONFIGS:
tags.append("openx")
lerobot_dataset.push_to_hub(
tags=tags,
private=False,
push_videos=True,
license="apache-2.0",
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
)
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Upload to hub.",
)
parser.add_argument(
"--robot-type",
type=str,
default=None,
help="Robot type of this dataset.",
)
parser.add_argument(
"--fps",
type=int,
default=None,
help="Frame rate used to collect videos. Default fps equals to the control frequency of the robot.",
)
parser.add_argument(
"--use-videos",
action="store_true",
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
)
parser.add_argument(
"--image-writer-process",
type=int,
default=0,
help="Number of processes of image writer for saving images.",
)
parser.add_argument(
"--image-writer-threads",
type=int,
default=8,
help="Number of threads per process of image writer for saving images.",
)
args = parser.parse_args()
# droid_dir = Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene/droid")
# if droid_dir.exists():
# shutil.rmtree(droid_dir)
create_lerobot_dataset(**vars(args))
if __name__ == "__main__":
main()


@@ -1,52 +0,0 @@
from pathlib import Path
import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
def main():
repo_id = "cadene/droid"
datetime = "2025-02-22_11-23-54"
port_log_dir = Path(f"/fsx/remi_cadene/logs/{datetime}_port_openx_droid")
compl_dir = port_log_dir / "completions"
paths = list(compl_dir.glob("*"))
total_items = len(paths)
# Use tqdm with the total parameter
wrong_completions = []
error_messages = []
for i, path in tqdm.tqdm(enumerate(paths), total=total_items):
try:
rank = path.name.lstrip("0")
if rank == "":
rank = 0
meta = LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_{rank}")
last_episode_index = meta.total_episodes - 1
last_ep_data_path = meta.root / meta.get_data_file_path(last_episode_index)
if not last_ep_data_path.exists():
raise ValueError(path)
for vid_key in meta.video_keys:
last_ep_vid_path = meta.root / meta.get_video_file_path(last_episode_index, vid_key)
if not last_ep_vid_path.exists():
raise ValueError(path)
except Exception as e:
error_messages.append(str(e))
wrong_completions.append(path)
for path, error_msg in zip(wrong_completions, error_messages, strict=False):
print(path)
print(error_msg)
print()
# path.unlink()
print(f"Error {len(wrong_completions)} / {total_items}")
if __name__ == "__main__":
main()


@@ -1,310 +0,0 @@
import datetime as dt
import logging
import os
import random
import time
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import CommitOperationAdd, HfApi, create_commit, preupload_lfs_files
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.utils import HfHubHTTPError
from lerobot.common.datasets.aggregate import aggregate_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import create_lerobot_dataset_card
BASE_DELAY = 0.1
MAX_RETRIES = 12
class PortOpenXDataset(PipelineStep):
def __init__(
self,
raw_dir: Path | str,
repo_id: str = None,
image_writer_process: int = 0,
image_writer_threads: int = 8,
):
super().__init__()
self.raw_dir = Path(raw_dir)
self.repo_id = repo_id
self.image_writer_process = image_writer_process
self.image_writer_threads = image_writer_threads
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from examples.port_datasets.openx_rlds import create_lerobot_dataset
from examples.port_datasets.openx_utils.test import display_slurm_info, display_system_info
from lerobot.common.utils.utils import init_logging
init_logging()
disable_progress_bars()
display_system_info()
display_slurm_info()
create_lerobot_dataset(
self.raw_dir,
f"{self.repo_id}_world_{world_size}_rank_{rank}",
image_writer_process=self.image_writer_process,
image_writer_threads=self.image_writer_threads,
push_to_hub=False,
num_shards=world_size,
shard_index=rank,
)
class AggregateDatasets(PipelineStep):
def __init__(
self,
repo_ids: list[str],
aggregated_repo_id: str,
):
super().__init__()
self.repo_ids = repo_ids
self.aggregated_repo_id = aggregated_repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
aggregate_datasets(self.repo_ids, self.aggregated_repo_id)
class UploadDataset(PipelineStep):
def __init__(
self,
repo_id: str,
branch: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
private: bool = False,
**card_kwargs,
):
super().__init__()
self.repo_id = repo_id
self.branch = branch
self.tags = tags
self.license = license
self.private = private
self.card_kwargs = card_kwargs
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
logging.warning(
'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
"variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
)
self._repo_init = False
def _create_repo(self, hub_api):
hub_api.create_repo(
repo_id=self.repo_id,
private=self.private,
repo_type="dataset",
exist_ok=True,
)
if self.branch:
hub_api.create_branch(
repo_id=self.repo_id,
branch=self.branch,
revision=self.revision,
repo_type="dataset",
exist_ok=True,
)
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch):
card = create_lerobot_dataset_card(
tags=self.tags, dataset_info=self.meta.info, license=license, **self.card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=self.branch)
def run(self, data=None, rank: int = 0, world_size: int = 1):
from lerobot.common.utils.utils import init_logging
init_logging()
meta = LeRobotDatasetMetadata(self.repo_id)
# TODO: list files, shard files, upload meta data for rank=0
filenames = []
raise NotImplementedError()
hub_api = HfApi()
if not self._repo_init:
self._create_repo(hub_api)
self._repo_init = True
additions = [
CommitOperationAdd(path_in_repo=filename, path_or_fileobj=meta.root / filename)
for filename in filenames
]
logging.info(f"Uploading {','.join(filenames)} to the hub...")
preupload_lfs_files(
repo_id=self.repo_id, repo_type="dataset", additions=additions, revision=self.revision
)
logging.info(f"Upload of {','.join(filenames)} to the hub complete!")
# if self.cleanup:
# for filename in filenames:
# self.local_working_dir.rm(filename)
self.operations.extend(additions)
def close(self, rank: int = 0):
filelist = list(self.output_mg.get_open_files().keys())
super().close()
if filelist:
logging.info(f"Starting upload of {len(filelist)} files to {self.dataset}")
self.upload_files(*filelist)
retries = 0
while True:
try:
create_commit(
self.repo_id,
repo_type="dataset",
operations=self.operations,
commit_message=f"DataTrove upload ({len(self.operations)} files)",
revision=self.revision,
)
break
except HfHubHTTPError as e:
if "A commit has happened since" in e.server_message:
if retries >= MAX_RETRIES:
logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
raise e
logging.info("Commit creation race condition issue. Waiting...")
time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
retries += 1
else:
raise e
def make_port_executor(raw_dir, repo_id, port_job_name, port_log_dir, slurm=True):
kwargs = {
"pipeline": [
PortOpenXDataset(raw_dir, repo_id),
],
"logging_dir": str(port_log_dir),
}
if slurm:
kwargs.update(
{
"job_name": port_job_name,
"tasks": 2048,
"workers": 20,
"time": "08:00:00",
"partition": "hopper-cpu",
"cpus_per_task": 24,
"mem_per_cpu_gb": 2,
"max_array_launch_parallel": True,
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": 1,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def make_aggregate_executor(
repo_ids, aggr_repo_id, port_job_name, aggregate_log_dir, depends=None, slurm=True
):
kwargs = {
"pipeline": [
AggregateDatasets(repo_ids, aggr_repo_id),
],
"logging_dir": str(aggregate_log_dir),
"tasks": 1,
"workers": 1,
}
if depends:
kwargs["depends"] = depends
if slurm:
kwargs.update(
{
"job_name": port_job_name,
"time": "08:00:00",
"partition": "hopper-cpu",
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
executor = LocalPipelineExecutor(**kwargs)
return executor
def make_upload_executor(repo_id, upload_job_name, upload_log_dir, depends=None, slurm=True):
kwargs = {
"pipeline": [
UploadDataset(repo_id),
],
"logging_dir": str(upload_log_dir),
"tasks": 1,
"workers": 1,
}
if depends:
kwargs["depends"] = depends
if slurm:
kwargs.update(
{
"job_name": upload_job_name,
"time": "08:00:00",
"partition": "hopper-cpu",
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
executor = LocalPipelineExecutor(**kwargs)
return executor
def main(slurm=True):
# breakpoint()
# for dir_ in Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene").glob("droid_world*"):
# shutil.rmtree(dir_)
world = 2048
raw_dir = "/fsx/mustafa_shukor/droid"
port_job_name = "port_openx_droid"
aggregate_job_name = "aggregate_openx_droid"
upload_job_name = "upload_openx_droid"
logs_dir = Path("/fsx/remi_cadene/logs")
repo_id = "cadene/droid"
now = dt.datetime.now()
datetime = f"{now:%Y-%m-%d}_{now:%H-%M-%S}"
# datetime = "2025-02-22_11-17-00"
port_log_dir = logs_dir / f"{datetime}_{port_job_name}"
aggregate_log_dir = logs_dir / f"{datetime}_{aggregate_job_name}"
upload_log_dir = logs_dir / f"{datetime}_{upload_job_name}"
port_executor = make_port_executor(raw_dir, repo_id, port_job_name, port_log_dir, slurm)
port_executor.run()
repo_ids = [f"{repo_id}_{datetime}_world_{world}_rank_{rank}" for rank in range(world)]
aggregate_executor = make_aggregate_executor(
repo_ids, repo_id, aggregate_job_name, aggregate_log_dir, port_executor, slurm
)
aggregate_executor.run()
upload_executor = make_upload_executor(
repo_id, upload_job_name, upload_log_dir, aggregate_executor, slurm
)
upload_executor.run()
if __name__ == "__main__":
main()


@@ -1,854 +0,0 @@
"""
Adapt from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/configs.py
configs.py
Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment.
Configuration adopts the following structure:
image_obs_keys:
primary: primary external RGB
secondary: secondary external RGB
wrist: wrist RGB
depth_obs_keys:
primary: primary external depth
secondary: secondary external depth
wrist: wrist depth
# Always 8-dim =>> changes based on `StateEncoding`
state_obs_keys:
StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
StateEncoding.JOINT: Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
state_encoding: Type of `StateEncoding`
action_encoding: Type of action encoding (e.g., EEF Position vs. Joint Position)
"""
from enum import IntEnum
from typing import Dict
import tensorflow as tf
def zero_action_filter(traj: Dict) -> bool:
"""
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
"""
DROID_Q01 = tf.convert_to_tensor(
[
-0.7776297926902771,
-0.5803514122962952,
-0.5795090794563293,
-0.6464047729969025,
-0.7041108310222626,
-0.8895104378461838,
]
)
DROID_Q99 = tf.convert_to_tensor(
[
0.7597932070493698,
0.5726242214441299,
0.7351000607013702,
0.6705610305070877,
0.6464948207139969,
0.8897542208433151,
]
)
DROID_NORM_0_ACT = (
2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
)
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
# Defines Proprioceptive State Encoding Schemes
class StateEncoding(IntEnum):
# fmt: off
NONE = -1 # No Proprioceptive State
POS_EULER = 1 # EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
POS_QUAT = 2 # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
JOINT = 3 # Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
JOINT_BIMANUAL = 4 # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ])
# fmt: on
# Defines Action Encoding Schemes
class ActionEncoding(IntEnum):
# fmt: off
EEF_POS = 1 # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1)
JOINT_POS = 2 # Joint Delta Position (7) + Gripper Open/Close (1)
JOINT_POS_BIMANUAL = 3 # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ])
EEF_R6 = 4 # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1)
# fmt: on
# === Individual Dataset Configs ===
OXE_DATASET_CONFIGS = {
"fractal20220817_data": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["base_pose_tool_reached", "gripper_closed"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Google Robot",
},
"kuka": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [
"clip_function_input/base_pose_tool_reached",
"gripper_closed",
],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Kuka iiwa",
},
"bridge_oxe": { # Version of Bridge V2 in Open X-Embodiment mixture
"image_obs_keys": {"primary": "image", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "WidowX",
},
"bridge_orig": { # Original version of Bridge V2 from project website
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "WidowX",
},
"bridge_dataset": { # Original version of Bridge V2 from project website
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "WidowX",
},
"taco_play": {
"image_obs_keys": {
"primary": "rgb_static",
"secondary": None,
"wrist": "rgb_gripper",
},
"depth_obs_keys": {
"primary": "depth_static",
"secondary": None,
"wrist": "depth_gripper",
},
"state_obs_keys": ["state_eef", None, "state_gripper"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 15,
"robot_type": "Franka",
},
"jaco_play": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "image_wrist",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state_eef", None, "state_gripper"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Jaco 2",
},
"berkeley_cable_routing": {
"image_obs_keys": {
"primary": "image",
"secondary": "top_image",
"wrist": "wrist45_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["robot_state", None],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"roboturk": {
"image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Sawyer",
},
"nyu_door_opening_surprising_effectiveness": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Hello Stretch",
},
"viola": {
"image_obs_keys": {
"primary": "agentview_rgb",
"secondary": None,
"wrist": "eye_in_hand_rgb",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_states", "gripper_states"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"berkeley_autolab_ur5": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "hand_image",
},
"depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "UR5",
},
"toto": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 30,
"robot_type": "Franka",
},
"language_table": {
"image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["effector_translation", None, None, None, None, None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "xArm",
},
"columbia_cairlab_pusht_real": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["robot_state", None, None, None, None, None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "UR5",
},
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["ee_position", "ee_orientation", None],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Kuka iiwa",
},
"nyu_rot_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "xArm",
},
"stanford_hydra_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"austin_buds_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"nyu_franka_play_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": "image_additional_view",
"wrist": None,
},
"depth_obs_keys": {
"primary": "depth",
"secondary": "depth_additional_view",
"wrist": None,
},
"state_obs_keys": ["eef_state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Franka",
},
"maniskill_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {
"primary": "depth",
"secondary": None,
"wrist": "wrist_depth",
},
"state_obs_keys": ["tcp_pose", "gripper_state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"furniture_bench_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"cmu_franka_exploration_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "highres_image",
"secondary": None,
"wrist": None,
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_state", None],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 2,
"robot_type": "xArm",
},
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "xArm",
},
"austin_sailor_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"austin_sirius_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"bc_z": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [
"present/xyz",
"present/axis_angle",
None,
"present/sensed_close",
],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Google Robot",
},
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "PR2",
},
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "PR2",
},
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": "image2",
"wrist": "hand_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["end_effector_pose", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "xArm",
},
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["pose_r", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "xArm Bimanual",
},
"robo_net": {
"image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 1,
"robot_type": "Multi-Robot",
},
"berkeley_mvp_converted_externally_to_rlds": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["pose", "gripper"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.JOINT_POS,
"control_frequency": 5,
"robot_type": "xArm",
},
"berkeley_rpt_converted_externally_to_rlds": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_pos", "gripper"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.JOINT_POS,
"control_frequency": 30,
"robot_type": "Franka",
},
"kaist_nonprehensile_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"stanford_mask_vit_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": None,
"robot_type": "Sawyer",
},
"tokyo_u_lsmo_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Cobotta",
},
"dlr_sara_pour_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "DLR SARA",
},
"dlr_sara_grid_clamp_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "DLR SARA",
},
"dlr_edan_shared_control_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "DLR EDAN",
},
"asu_table_top_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 12.5,
"robot_type": "UR5",
},
"stanford_robocook_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
"depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"imperialcollege_sawyer_wrist_cam": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, "state"],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Sawyer",
},
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_state", "gripper_state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"uiuc_d3field": {
"image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
"depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 1,
"robot_type": "Kinova Gen3",
},
"utaustin_mutex": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"berkeley_fanuc_manipulation": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_state", None, "gripper_state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Fanuc Mate",
},
"cmu_playing_with_food": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "finger_vision_1",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"cmu_play_fusion": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"cmu_stretch": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Hello Stretch",
},
"berkeley_gnm_recon": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Jackal",
},
"berkeley_gnm_cory_hall": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "RC Car",
},
"berkeley_gnm_sac_son": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "TurtleBot 2",
},
# NOTE: modified
"droid": {
"image_obs_keys": {
"primary": "exterior_image_1_left",
"secondary": "exterior_image_2_left",
"wrist": "wrist_image_left",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 15,
"robot_type": "Franka",
"aux_kwargs": {
"dataset_frame_transform_kwargs": {
"chunk_filter_fn": zero_action_filter,
},
},
},
"fmb_dataset": {
"image_obs_keys": {
"primary": "image_side_1",
"secondary": "image_side_2",
"wrist": "image_wrist_1",
},
"depth_obs_keys": {
"primary": "image_side_1_depth",
"secondary": "image_side_2_depth",
"wrist": "image_wrist_1_depth",
},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
# NOTE: modified
"dobbe": {
"image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3.75,
"robot_type": "Hello Stretch",
},
"roboset": {
"image_obs_keys": {
"primary": "image_left",
"secondary": "image_right",
"wrist": "image_wrist",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.JOINT_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"rh20t": {
"image_obs_keys": {
"primary": "image_front",
"secondary": "image_side_right",
"wrist": "image_wrist",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Flexiv",
},
### T-DROID datasets
"tdroid_carrot_in_bowl": { # "put carrot in bowl" task, 50 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_pour_corn_in_pot": { # "pour corn from red bonawl into steel pot" task, 50 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_move_object_onto_plate": { # "move <object> onto plate" task, 150 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_knock_object_over": { # "knock <object> over" task, 70 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_cover_object_with_towel": { # "cover <object> with towel" task, 45 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
### DROID Finetuning datasets
"droid_wipe": {
"image_obs_keys": {
"primary": "exterior_image_2_left",
"secondary": None,
"wrist": "wrist_image_left",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 15,
"robot_type": "Franka",
},
# NOTE: modified
### LIBERO datasets (modified versions)
"libero_spatial_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"libero_object_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"libero_goal_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"libero_10_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
}
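# Minimal lookup sketch (`summarize_config` is a hypothetical helper, not part
# of any existing API; assumes this module's names are in scope):
def summarize_config(name: str) -> dict:
    """Collect the camera keys and encodings declared for one dataset."""
    cfg = OXE_DATASET_CONFIGS[name]
    return {
        "rgb_keys": [k for k in cfg["image_obs_keys"].values() if k is not None],
        "state_encoding": cfg["state_encoding"],
        "action_encoding": cfg["action_encoding"],
        "fps": cfg["control_frequency"],
    }

# e.g. summarize_config("droid")["rgb_keys"]
# -> ["exterior_image_1_left", "exterior_image_2_left", "wrist_image_left"]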

View File

@@ -1,30 +0,0 @@
#!/bin/bash
#SBATCH --job-name=openx_rlds
#SBATCH --partition=hopper-cpu
#SBATCH --requeue
#SBATCH --time=00:01:00
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --output=/fsx/%u/slurm/%j-%x.out
OUTPUT_DIR="/fsx/${USER}/slurm/${SLURM_JOB_NAME}-${SLURM_JOB_ID}-tasks"
mkdir -p $OUTPUT_DIR
# Function to run a task and redirect output to a separate file
run_task() {
local task_id=$1
local output_file="${OUTPUT_DIR}/task-${task_id}-${SLURM_JOB_ID}.out"
# Run the task and redirect output
python examples/port_datasets/openx_utils/test.py > $output_file 2>&1
}
echo $SBATCH_OUTPUT
# node has 380850M and 96 cpus
trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1
echo "Starting job"
# note the "&" to start srun as a background thread
srun python examples/port_datasets/openx_utils/test.py &
# wait for signals...
wait

View File

@@ -1,54 +0,0 @@
import os
import psutil
def display_system_info():
# Get the number of CPUs
num_cpus = psutil.cpu_count(logical=True)
print(f"Number of CPUs: {num_cpus}")
# Get memory information
memory_info = psutil.virtual_memory()
total_memory = memory_info.total / (1024**3) # Convert bytes to GB
available_memory = memory_info.available / (1024**3) # Convert bytes to GB
used_memory = memory_info.used / (1024**3) # Convert bytes to GB
print(f"Total Memory: {total_memory:.2f} GB")
print(f"Available Memory: {available_memory:.2f} GB")
print(f"Used Memory: {used_memory:.2f} GB")
def display_slurm_info():
# Get SLURM job ID
job_id = os.getenv("SLURM_JOB_ID")
print(f"SLURM Job ID: {job_id}")
# Get SLURM job name
job_name = os.getenv("SLURM_JOB_NAME")
print(f"SLURM Job Name: {job_name}")
# Get the number of tasks
num_tasks = os.getenv("SLURM_NTASKS")
print(f"Number of Tasks: {num_tasks}")
# Get the number of nodes
num_nodes = os.getenv("SLURM_NNODES")
print(f"Number of Nodes: {num_nodes}")
# Get the number of CPUs per task
cpus_per_task = os.getenv("SLURM_CPUS_PER_TASK")
print(f"CPUs per Task: {cpus_per_task}")
# Get the node list
node_list = os.getenv("SLURM_NODELIST")
print(f"Node List: {node_list}")
# Get the task ID (only available within an srun task)
task_id = os.getenv("SLURM_PROCID")
print(f"Task ID: {task_id}")
if __name__ == "__main__":
display_system_info()
display_slurm_info()

View File

@@ -1,76 +0,0 @@
"""
Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py
"""
from typing import Any, Dict
import tensorflow as tf
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts gripper actions from continuous to binary values (0 and 1).
    We exploit the fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
values based on the state that is reached _after_ those intermediate values.
In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
chunk of intermediate values as the last action in the trajectory.
The `scan_fn` implements the following logic:
new_actions = np.empty_like(actions)
carry = actions[-1]
for i in reversed(range(actions.shape[0])):
if in_between_mask[i]:
carry = carry
else:
carry = float(open_mask[i])
new_actions[i] = carry
"""
open_mask, closed_mask = actions > 0.95, actions < 0.05
in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
is_open_float = tf.cast(open_mask, tf.float32)
def scan_fn(carry, i):
return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
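# Worked example (a sketch, assuming TF 2.x eager execution):
#   binarize_gripper_actions(tf.constant([1.0, 0.9, 0.5, 0.02, 0.0]))
#   -> [1.0, 0.0, 0.0, 0.0, 0.0]
# The intermediate 0.9 and 0.5 are relabeled to 0.0 because the state reached
# *after* them is "closed"; the reverse tf.scan propagates that future state.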
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
return 1 - actions
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
    Assumes that the first relative gripper action is not redundant (i.e. not a close when already closed)!
"""
    # Note =>> in thresholded_actions: +1 = opening, -1 = closing, 0 = no change
opening_mask, closing_mask = actions < -0.1, actions > 0.1
thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))
def scan_fn(carry, i):
return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])
# If no relative grasp, assumes open for whole trajectory
start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
start = tf.cond(start == 0, lambda: 1, lambda: start)
# Note =>> -1 for closed, 1 for open
new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5
return new_actions
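# Worked example (a sketch, assuming TF 2.x eager execution):
#   rel2abs_gripper_actions(tf.constant([0.0, 0.0, 0.9, 0.0, -0.8]))
#   -> [1.0, 1.0, 0.0, 0.0, 1.0]
# The first nonzero command is a close (0.9 > 0.1), so the trajectory is taken
# to start open; each later timestep holds the last commanded absolute state.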
# === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)
return traj_truncated
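# Shape note (sketch): given T timesteps, the returned trajectory has T - 1
# steps, with action[t, :6] = state[t + 1, :6] - state[t, :6] (the motion that
# was actually realized) and the logged gripper command kept in action[t, -1].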

View File

@@ -1,997 +0,0 @@
"""
Adapted from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/transforms.py
transforms.py
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
Transforms adopt the following structure:
Input: Dictionary of *batched* features (i.e., has leading time dimension)
Output: Dictionary `step` =>> {
"observation": {
<image_keys, depth_image_keys>
State (in chosen state representation)
},
"action": Action (in chosen action representation),
"language_instruction": str
}
"""
from typing import Any, Dict
import tensorflow as tf
from examples.port_datasets.openx_utils.transform_utils import (
binarize_gripper_actions,
invert_gripper_actions,
rel2abs_gripper_actions,
relabel_bridge_actions,
)
def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
def rand_swap_exterior_images(img1, img2):
"""
Randomly swaps the two exterior images (for training with single exterior input).
"""
return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dR,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
rand_swap_exterior_images(
trajectory["observation"]["exterior_image_1_left"],
trajectory["observation"]["exterior_image_2_left"],
)
)
# trajectory["observation"]["proprio"] = tf.concat(
# (
# trajectory["observation"]["cartesian_position"],
# trajectory["observation"]["gripper_position"],
# ),
# axis=-1,
# )
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"]
return trajectory
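# Resulting layout (sketch): trajectory["action"] is
#   [dx, dy, dz, droll, dpitch, dyaw, 1 - gripper_position]  # 7 dims
# matching ActionEncoding.EEF_POS, with the gripper flipped so that +1 = open
# (assuming `gripper_position` uses 1 = closed).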
def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dR,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
    Applies to the version of Bridge V2 included in the Open X-Embodiment mixture.
    Note =>> In the original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory:
if key == "traj_metadata":
continue
elif key in ["observation", "action"]:
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
    Applies to the original version of Bridge V2 from the official project website.
    Note =>> In the original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory:
if key == "traj_metadata":
continue
elif key == "observation":
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
return trajectory
def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
# decode compressed state
eef_value = tf.io.decode_compressed(
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
compression_type="ZLIB",
)
eef_value = tf.io.decode_raw(eef_value, tf.float32)
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
gripper_value = tf.io.decode_compressed(
trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
)
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
trajectory["action"] = trajectory["action"]["rel_actions_world"]
    # clip gripper action to [0, 1], +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
:, -1:
]
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
tf.zeros_like(trajectory["action"]["world_vector"]),
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert absolute gripper action, +1 = open, 0 = close
gripper_action = invert_gripper_actions(
tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
gripper_action = tf.clip_by_value(gripper_action, 0, 1)
gripper_action = invert_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth")
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# default to "open" gripper
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.ones_like(trajectory["action"][:, :1]),
),
axis=-1,
)
# decode language instruction
instruction_bytes = trajectory["observation"]["instruction"]
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
:, 0
]
return trajectory
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
trajectory["action"]["gripper_closedness_action"][:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
trajectory["action"] = trajectory["action"][..., :7]
return trajectory
def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
trajectory["observation"]["state"][:, 7:10],
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
trajectory["observation"]["depth_additional_view"] = tf.cast(
trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
)
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]
# clip gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, -8:-2],
tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
return trajectory
def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["state"][:, :7],
trajectory["observation"]["state"][:, -1:],
),
axis=-1,
)
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["future/xyz_residual"][:, :3],
trajectory["action"]["future/axis_angle_residual"][:, :3],
invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def utokyo_xarm_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., -7:]
return trajectory
def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :4],
tf.zeros_like(trajectory["observation"]["state"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["end_effector_pose"][:, :4],
tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def dlr_sara_pour_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
return trajectory
def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
return trajectory
def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, 7:8],
),
axis=-1,
)
return trajectory
def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]
# dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"],
invert_gripper_actions(trajectory["observation"]["gripper_state"]),
),
axis=-1,
)
return trajectory
def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
trajectory["action"][:, -4:],
),
axis=-1,
)
return trajectory
def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["position"],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
trajectory["observation"]["yaw"],
),
axis=-1,
)
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # every input feature is batched, i.e. has a leading batch dimension
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["eef_pose"],
trajectory["observation"]["state_gripper_pose"][..., None],
),
axis=-1,
)
return trajectory
def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # every input feature is batched, i.e. has a leading batch dimension
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def roboset_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # every input feature is batched, i.e. has a leading batch dimension
trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
# gripper action is in -1...1 --> clip to 0...1, flip
gripper_action = trajectory["action"][:, -1:]
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :7],
gripper_action,
),
axis=-1,
)
return trajectory


def rh20t_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["tcp_base"],
            tf.cast(trajectory["action"]["gripper"][:, None], tf.float32),
        ),
        axis=-1,
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["tcp_base"],
            trajectory["observation"]["gripper_width"][..., None],
        ),
        axis=-1,
    )
    return trajectory


def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        [
            trajectory["action"][:, :6],
            binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
        ],
        axis=1,
    )
    trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
    return trajectory


def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # gripper action is in -1 (open)...1 (close) --> clip to 0...1, flip --> +1 = open, 0 = close
    gripper_action = trajectory["action"][:, -1:]
    gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
    trajectory["action"] = tf.concat(
        [
            trajectory["action"][:, :6],
            gripper_action,
        ],
        axis=1,
    )
    trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:]  # 2D gripper state
    return trajectory
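

# NOTE: the transforms above rely on the gripper-convention helpers
# `invert_gripper_actions` and `binarize_gripper_actions`, defined earlier in this
# file. A minimal sketch of their assumed behavior (not the verbatim
# implementations; the real binarization logic may carry state across timesteps):
#
#     def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
#         # flip the open/close convention of actions already clipped to [0, 1]
#         return 1.0 - actions
#
#     def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
#         # threshold continuous gripper commands to {0.0, 1.0}
#         return tf.where(actions > 0.5, tf.ones_like(actions), tf.zeros_like(actions))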


# === Registry ===
OXE_STANDARDIZATION_TRANSFORMS = {
    "bridge_oxe": bridge_oxe_dataset_transform,
    "bridge_orig": bridge_orig_dataset_transform,
    "bridge_dataset": bridge_orig_dataset_transform,
    "ppgm": ppgm_dataset_transform,
    "ppgm_static": ppgm_dataset_transform,
    "ppgm_wrist": ppgm_dataset_transform,
    "fractal20220817_data": rt1_dataset_transform,
    "kuka": kuka_dataset_transform,
    "taco_play": taco_play_dataset_transform,
    "jaco_play": jaco_play_dataset_transform,
    "berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
    "roboturk": roboturk_dataset_transform,
    "nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
    "viola": viola_dataset_transform,
    "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
    "toto": toto_dataset_transform,
    "language_table": language_table_dataset_transform,
    "columbia_cairlab_pusht_real": pusht_dataset_transform,
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
    "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
    "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
    "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
    "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
    "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
    "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
    "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
    "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
    "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
    "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
    "bc_z": bc_z_dataset_transform,
    "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
    "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform,
    "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
    "robo_net": robo_net_dataset_transform,
    "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
    "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
    "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
    "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
    "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
    "dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform,
    "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
    "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
    "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
    "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
    "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
    "uiuc_d3field": uiuc_d3field_dataset_transform,
    "utaustin_mutex": utaustin_mutex_dataset_transform,
    "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
    "cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
    "cmu_play_fusion": playfusion_dataset_transform,
    "cmu_stretch": cmu_stretch_dataset_transform,
    "berkeley_gnm_recon": gnm_dataset_transform,
    "berkeley_gnm_cory_hall": gnm_dataset_transform,
    "berkeley_gnm_sac_son": gnm_dataset_transform,
    "droid": droid_baseact_transform,
    "fmb_dataset": fmb_dataset_transform,
    "dobbe": dobbe_dataset_transform,
    "roboset": roboset_dataset_transform,
    "rh20t_rlds": rh20t_dataset_transform,
    ### T-DROID datasets
    "tdroid_carrot_in_bowl": tdroid_dataset_transform,
    "tdroid_pour_corn_in_pot": tdroid_dataset_transform,
    "tdroid_flip_pot_upright": tdroid_dataset_transform,
    "tdroid_move_object_onto_plate": tdroid_dataset_transform,
    "tdroid_knock_object_over": tdroid_dataset_transform,
    "tdroid_cover_object_with_towel": tdroid_dataset_transform,
    ### DROID Finetuning datasets
    "droid_wipe": droid_finetuning_transform,
    ### LIBERO datasets (modified versions)
    "libero_spatial_no_noops": libero_dataset_transform,
    "libero_object_no_noops": libero_dataset_transform,
    "libero_goal_no_noops": libero_dataset_transform,
    "libero_10_no_noops": libero_dataset_transform,
}
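

# Usage sketch (hypothetical caller code, assuming `raw_trajectory` is the nested
# dict of batched tensors yielded by the RLDS reader for a given dataset):
#
#     transform = OXE_STANDARDIZATION_TRANSFORMS["cmu_stretch"]
#     standardized_trajectory = transform(raw_trajectory)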

View File

@@ -1,5 +1,5 @@
 import logging
-import subprocess
+import shutil

 import pandas as pd
 import tqdm
@@ -16,7 +16,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
     robot_type = all_metadata[0].robot_type
     features = all_metadata[0].features

-    for meta in tqdm.tqdm(all_metadata):
+    for meta in tqdm.tqdm(all_metadata, desc="Validate all meta data"):
         if fps != meta.fps:
             raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
         if robot_type != meta.robot_type:
@@ -41,7 +41,7 @@ def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_
 def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
-    logging.info("start aggregate_datasets")
+    logging.info("Start aggregate_datasets")

     all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
@@ -56,12 +56,12 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
         root=aggr_root,
     )

-    logging.info("find all tasks")
+    logging.info("Find all tasks")
     # find all tasks, deduplicate them, create new task indices for each dataset
     # indexed by dataset index
     datasets_task_index_to_aggr_task_index = {}
     aggr_task_index = 0
-    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata)):
+    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")):
         task_index_to_aggr_task_index = {}
         for task_index, task in meta.tasks.items():
@@ -76,9 +76,9 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
         datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index

-    logging.info("cp data and videos")
+    logging.info("Copy data and videos")
     aggr_episode_index_shift = 0
-    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata)):
+    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Copy data and videos")):
         # cp data
         for episode_index in range(meta.total_episodes):
             aggr_episode_index = episode_index + aggr_episode_index_shift
@@ -102,10 +102,10 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
             video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
             aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
             aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
-            # shutil.copy(video_path, aggr_video_path)
+            shutil.copy(video_path, aggr_video_path)
-            copy_command = f"cp {video_path} {aggr_video_path} &"
-            subprocess.Popen(copy_command, shell=True)
+            # copy_command = f"cp {video_path} {aggr_video_path} &"
+            # subprocess.Popen(copy_command, shell=True)

         # populate episodes
         for episode_index, episode_dict in meta.episodes.items():