Add multitask support, refactor conversion script

Simon Alibert 2024-10-13 21:21:40 +02:00
parent 8bd406e607
commit cf633344be
1 changed file with 149 additions and 92 deletions


@@ -3,34 +3,70 @@ This script will help you convert any LeRobot dataset already pushed to the hub
2.0. You will be required to provide the 'tasks', which are short but accurate descriptions in plain English
of each of the tasks performed in the dataset. This will allow you to easily train models with task-conditioning.
If your dataset contains a single task, you can provide it directly via the CLI with the '--task' option (see
examples below).
We support 3 different scenarios for these tasks:
1. Single task dataset: all episodes of your dataset have the same single task.
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
one episode to the next.
3. Multi task episodes: episodes of your dataset may each contain several different tasks.
If your dataset is a multi-task dataset, TODO
In any case, keep in mind that there should only be one task per episode. Multi-task episodes are not
supported for now.
# 1. Single task dataset
If your dataset contains a single task, you can simply provide it directly via the CLI with the
'--single-task' option (see examples below).
Usage examples
Examples:
Single-task dataset:
```bash
python convert_dataset_16_to_20.py \
python convert_dataset_v1_to_v2.py \
--repo-id lerobot/aloha_sim_insertion_human_image \
--task "Insert the peg into the socket." \
--single-task "Insert the peg into the socket." \
--robot-config lerobot/configs/robot/aloha.yaml \
--local-dir data
```
```bash
python convert_dataset_16_to_20.py \
python convert_dataset_v1_to_v2.py \
--repo-id aliberts/koch_tutorial \
--task "Pick the Lego block and drop it in the box on the right." \
--single-task "Pick the Lego block and drop it in the box on the right." \
--robot-config lerobot/configs/robot/koch.yaml \
--local-dir data
```
Multi-task dataset:
# 2. Single task episodes
If your dataset contains several tasks but each episode performs only one of them, you have two options to
provide the tasks to this script:
- If your dataset already contains a language instruction column in its parquet file, you can simply provide
this column's name with the '--tasks-col' arg.
Example:
```bash
python convert_dataset_v1_to_v2.py \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
'--tasks-path' arg. This file should have the following structure, where keys correspond to each
episode_index in the dataset and values are the language instruction for that episode (see the usage sketch
after the example below).
Example:
```json
{
"0": "Do something",
"1": "Do something else",
"2": "Do something",
"3": "Go there",
...
}
```
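A possible invocation for the '--tasks-path' option (the repo id and file path below are placeholders, not real
datasets):
```bash
python convert_dataset_v1_to_v2.py \
    --repo-id your-name/your_dataset \
    --tasks-path path/to/tasks.json \
    --local-dir data
```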
# 3. Multi task episodes
If you have multiple tasks per episode, your dataset should contain a language instruction column in its
parquet file, and you must provide this column's name with the '--tasks-col' arg (see the sketch below).
TODO
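A possible invocation for this scenario, assuming the instruction column of your dataset is named
'language_instruction' (the repo id below is a placeholder):
```bash
python convert_dataset_v1_to_v2.py \
    --repo-id your-name/your_multitask_dataset \
    --tasks-col "language_instruction" \
    --local-dir data
```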
"""
@@ -39,13 +75,13 @@ import contextlib
import json
import math
import subprocess
from io import BytesIO
from pathlib import Path
import pyarrow as pa
import datasets
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError
from PIL import Image
@@ -123,15 +159,14 @@ def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
torch.testing.assert_close(stats_json[key], stats[key])
def get_keys(table: pa.Table) -> dict[str, list]:
table_metadata = json.loads(table.schema.metadata[b"huggingface"].decode("utf-8"))
def get_keys(dataset: Dataset) -> dict[str, list]:
sequence_keys, image_keys, video_keys = [], [], []
for key, val in table_metadata["info"]["features"].items():
if val["_type"] == "Sequence":
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Sequence):
sequence_keys.append(key)
elif val["_type"] == "Image":
elif isinstance(ft, datasets.Image):
image_keys.append(key)
elif val["_type"] == "VideoFrame":
elif ft._type == "VideoFrame":
video_keys.append(key)
return {
@@ -141,55 +176,49 @@ def get_keys(table: pa.Table) -> dict[str, list]:
}
def remove_hf_metadata_features(table: pa.Table, features: list[str]) -> pa.Table:
# HACK
schema = table.schema
# decode bytes dict
table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
for key in features:
table_metadata["info"]["features"].pop(key)
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
# re-encode bytes dict
table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
new_schema = schema.with_metadata(table_metadata)
return table.replace_schema_metadata(new_schema.metadata)
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
return dataset, tasks
def add_hf_metadata_features(table: pa.Table, features: dict[str, dict]) -> pa.Table:
# HACK
schema = table.schema
# decode bytes dict
table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
for key, val in features.items():
table_metadata["info"]["features"][key] = val
def add_task_index_from_tasks_col(
dataset: Dataset, tasks_col: str
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
df = dataset.to_pandas()
# re-encode bytes dict
table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
new_schema = schema.with_metadata(table_metadata)
return table.replace_schema_metadata(new_schema.metadata)
# HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
# Create task_index col
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
def remove_videoframe_from_table(table: pa.Table, image_columns: list) -> pa.Table:
table = table.drop(image_columns)
table = remove_hf_metadata_features(table, image_columns)
return table
# Build the dataset back from df
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
dataset = dataset.remove_columns(tasks_col)
def add_tasks(table: pa.Table, tasks_by_episodes: dict) -> pa.Table:
tasks_index = pa.array([tasks_by_episodes.get(key.as_py(), None) for key in table["episode_index"]])
table = table.append_column("task_index", tasks_index)
hf_feature = {"task_index": {"dtype": "int64", "_type": "Value"}}
table = add_hf_metadata_features(table, hf_feature)
return table
return dataset, tasks, tasks_by_episode
def split_parquet_by_episodes(
table: pa.Table, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
dataset: Dataset, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
) -> list:
(output_dir / "data").mkdir(exist_ok=True, parents=True)
if len(keys["video"]) > 0:
table = remove_videoframe_from_table(table, keys["video"])
table = dataset.remove_columns(keys["video"])._data.table
episode_lengths = []
for episode_index in sorted(episode_indices):
# Write each episode_index to a new parquet file
@@ -330,11 +359,10 @@ def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
return video_shapes
def get_image_shapes(table: pa.Table, image_keys: list) -> dict:
def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
image_shapes = {}
for img_key in image_keys:
image_bytes = table[img_key][0].as_py() # Assuming first row
image = Image.open(BytesIO(image_bytes["bytes"]))
image = dataset[0][img_key] # Assuming first row
channels = get_image_pixel_channels(image)
image_shapes[img_key] = {
"width": image.width,
@@ -352,8 +380,9 @@ def get_generic_motor_names(sequence_shapes: dict) -> dict:
def convert_dataset(
repo_id: str,
local_dir: Path,
tasks: dict,
tasks_by_episodes: dict | None = None,
single_task: str | None = None,
tasks_path: Path | None = None,
tasks_col: Path | None = None,
robot_config: dict | None = None,
):
v1_6_dir = local_dir / V1_6 / repo_id
@@ -367,29 +396,40 @@ def convert_dataset(
)
metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json")
table = pq.read_table(v1_6_dir / "data")
keys = get_keys(table)
dataset = datasets.load_dataset("parquet", data_dir=v1_6_dir / "data", split="train")
keys = get_keys(dataset)
# Episodes
episode_indices = sorted(table["episode_index"].unique().to_pylist())
episode_indices = sorted(dataset.unique("episode_index"))
total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes))
# Tasks
if tasks_by_episodes is None: # Single task dataset
tasks_by_episodes = {ep_idx: 0 for ep_idx in episode_indices}
if single_task:
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
# tasks = list(set(tasks_by_episodes.values()))
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_col:
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
else:
raise ValueError
assert set(tasks) == set(tasks_by_episodes.values())
table = add_tasks(table, tasks_by_episodes)
write_json(tasks, v2_0_dir / "meta" / "tasks.json")
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_json(task_json, v2_0_dir / "meta" / "tasks.json")
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(table, keys, total_episodes, episode_indices, v2_0_dir)
episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v2_0_dir)
# Shapes
sequence_shapes = {key: len(table[key][0]) for key in keys["sequence"]}
image_shapes = get_image_shapes(table, keys["image"]) if len(keys["image"]) > 0 else {}
sequence_shapes = {key: len(dataset[key][0]) for key in keys["sequence"]}
image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
if len(keys["video"]) > 0:
assert metadata_v1_6.get("video", False)
videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"])
@@ -416,11 +456,12 @@ def convert_dataset(
for key in sequence_shapes:
assert len(names[key]) == sequence_shapes[key]
# Episodes info
# Episodes
episodes = [
{"index": ep_idx, "task": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
{"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices
]
write_json(episodes, v2_0_dir / "meta" / "episodes.json")
# Assemble metadata v2.0
metadata_v2_0 = {
@@ -437,11 +478,17 @@ def convert_dataset(
"shapes": {**sequence_shapes, **video_shapes, **image_shapes},
"names": names,
"videos": videos_info,
"episodes": episodes,
}
write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")
#### TODO: delete
repo_id = f"aliberts/{repo_id.split('/')[1]}"
# if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
# hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
####
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")
@@ -455,6 +502,13 @@ def convert_dataset(
repo_type="dataset",
revision="main",
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="videos",
folder_path=v1_6_dir / "videos",
repo_type="dataset",
revision="main",
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="meta",
@@ -463,7 +517,6 @@ def convert_dataset(
revision="main",
)
metadata_v2_0.pop("episodes")
card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset")
@@ -478,12 +531,13 @@ def convert_dataset(
# - [X] Add robot_type
# - [X] Add splits
# - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
# - [X] Handle multitask datasets
# - [/] Add sanity checks (encoding, shapes)
# - [ ] Handle multitask datasets
def main():
parser = argparse.ArgumentParser()
task_args = parser.add_mutually_exclusive_group(required=True)
parser.add_argument(
"--repo-id",
@@ -491,11 +545,20 @@ def main():
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--task",
task_args.add_argument(
"--single-task",
type=str,
required=True,
help="A short but accurate description of the task performed in the dataset.",
help="A short but accurate description of the single task performed in the dataset.",
)
task_args.add_argument(
"--tasks-col",
type=str,
help="The name of the column containing language instructions",
)
task_args.add_argument(
"--tasks-path",
type=Path,
help="The path to a .json file containing one language instruction for each episode_index",
)
parser.add_argument(
"--robot-config",
@@ -517,19 +580,13 @@ def main():
)
args = parser.parse_args()
if args.local_dir is None:
if not args.local_dir:
args.local_dir = Path(f"/tmp/{args.repo_id}")
tasks = {0: args.task}
del args.task
if args.robot_config is not None:
robot_config = parse_robot_config(args.robot_config, args.robot_overrides)
else:
robot_config = None
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
del args.robot_config, args.robot_overrides
convert_dataset(**vars(args), tasks=tasks, robot_config=robot_config)
convert_dataset(**vars(args), robot_config=robot_config)
if __name__ == "__main__":