Add multitask support, refactor conversion script

Simon Alibert 2024-10-13 21:21:40 +02:00
parent 8bd406e607
commit cf633344be
1 changed file with 149 additions and 92 deletions


@@ -3,34 +3,70 @@ This script will help you convert any LeRobot dataset already pushed to the hub
 2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
 for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
 
-If your dataset contains a single task, you can provide it directly via the CLI with the '--task' option (see
-examples below).
+We support 3 different scenarios for these tasks:
+    1. Single task dataset: all episodes of your dataset have the same single task.
+    2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
+       one episode to the next.
+    3. Multi task episodes: episodes of your dataset may each contain several different tasks.
 
-If your dataset is a multi-task dataset, TODO
 
-In any case, keep in mind that there should only be one task per episode. Multi-task episodes are not
-supported for now.
+# 1. Single task dataset
+If your dataset contains a single task, you can simply provide it directly via the CLI with the
+'--single-task' option (see examples below).
 
-Usage examples
+Examples:
 
-Single-task dataset:
 ```bash
-python convert_dataset_16_to_20.py \
+python convert_dataset_v1_to_v2.py \
     --repo-id lerobot/aloha_sim_insertion_human_image \
-    --task "Insert the peg into the socket." \
+    --single-task "Insert the peg into the socket." \
     --robot-config lerobot/configs/robot/aloha.yaml \
     --local-dir data
 ```
 
 ```bash
-python convert_dataset_16_to_20.py \
+python convert_dataset_v1_to_v2.py \
     --repo-id aliberts/koch_tutorial \
-    --task "Pick the Lego block and drop it in the box on the right." \
+    --single-task "Pick the Lego block and drop it in the box on the right." \
     --robot-config lerobot/configs/robot/koch.yaml \
     --local-dir data
 ```
 
-Multi-task dataset:
+# 2. Single task episodes
+If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
+
+- If your dataset already contains a language instruction column in its parquet file, you can simply provide
+  this column's name with the '--tasks-col' arg.
+
+    Example:
+    ```bash
+    python convert_dataset_v1_to_v2.py \
+        --repo-id lerobot/stanford_kuka_multimodal_dataset \
+        --tasks-col "language_instruction" \
+        --local-dir data
+    ```
+
+- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
+  '--tasks-path' arg. This file should have the following structure where keys correspond to each
+  episode_index in the dataset, and values are the language instruction for that episode.
+
+    Example:
+    ```json
+    {
+        "0": "Do something",
+        "1": "Do something else",
+        "2": "Do something",
+        "3": "Go there",
+        ...
+    }
+    ```
+
+# 3. Multi task episodes
+If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
+parquet file, and you must provide this column's name with the '--tasks-col' arg.
+
 TODO
 """
@@ -39,13 +75,13 @@ import contextlib
 import json
 import math
 import subprocess
-from io import BytesIO
 from pathlib import Path
 
-import pyarrow as pa
+import datasets
 import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import torch
+from datasets import Dataset
 from huggingface_hub import HfApi
 from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
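
The module docstring above describes a `--tasks-path` JSON file mapping each episode_index to a language instruction. A minimal sketch of producing such a file is shown below; the episode descriptions and output path are placeholders, and the snippet is illustrative rather than part of this commit.

```python
import json
from pathlib import Path

# Placeholder per-episode instructions; keys are episode_index values as strings.
tasks_by_episode = {
    "0": "Pick the Lego block and drop it in the box on the right.",
    "1": "Pick the Lego block and drop it in the box on the left.",
}

tasks_path = Path("data/tasks.json")
tasks_path.parent.mkdir(parents=True, exist_ok=True)
with open(tasks_path, "w") as f:
    json.dump(tasks_by_episode, f, indent=4)

# The file can then be passed to the converter with: --tasks-path data/tasks.json
```
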
@@ -123,15 +159,14 @@ def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
         torch.testing.assert_close(stats_json[key], stats[key])
 
 
-def get_keys(table: pa.Table) -> dict[str, list]:
-    table_metadata = json.loads(table.schema.metadata[b"huggingface"].decode("utf-8"))
+def get_keys(dataset: Dataset) -> dict[str, list]:
     sequence_keys, image_keys, video_keys = [], [], []
-    for key, val in table_metadata["info"]["features"].items():
-        if val["_type"] == "Sequence":
+    for key, ft in dataset.features.items():
+        if isinstance(ft, datasets.Sequence):
             sequence_keys.append(key)
-        elif val["_type"] == "Image":
+        elif isinstance(ft, datasets.Image):
             image_keys.append(key)
-        elif val["_type"] == "VideoFrame":
+        elif ft._type == "VideoFrame":
             video_keys.append(key)
 
     return {
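
The refactored `get_keys` above dispatches on `datasets` feature types instead of parsing the raw parquet schema metadata. A small sketch of that dispatch on a toy dataset; the column names and values are made up for illustration:

```python
import datasets
from PIL import Image

# Toy dataset with one Sequence column and one Image column (illustrative only).
features = datasets.Features(
    {
        "observation.state": datasets.Sequence(datasets.Value("float32")),
        "observation.image": datasets.Image(),
    }
)
ds = datasets.Dataset.from_dict(
    {
        "observation.state": [[0.0, 1.0], [2.0, 3.0]],
        "observation.image": [Image.new("RGB", (4, 4)), Image.new("RGB", (4, 4))],
    },
    features=features,
)

for key, ft in ds.features.items():
    # Same checks as get_keys(): the feature objects are inspected directly.
    print(key, isinstance(ft, datasets.Sequence), isinstance(ft, datasets.Image))
```
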
@@ -141,55 +176,49 @@ def get_keys(table: pa.Table) -> dict[str, list]:
     }
 
 
-def remove_hf_metadata_features(table: pa.Table, features: list[str]) -> pa.Table:
-    # HACK
-    schema = table.schema
-    # decode bytes dict
-    table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
-    for key in features:
-        table_metadata["info"]["features"].pop(key)
-
-    # re-encode bytes dict
-    table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
-    new_schema = schema.with_metadata(table_metadata)
-    return table.replace_schema_metadata(new_schema.metadata)
-
-
-def add_hf_metadata_features(table: pa.Table, features: dict[str, dict]) -> pa.Table:
-    # HACK
-    schema = table.schema
-    # decode bytes dict
-    table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
-    for key, val in features.items():
-        table_metadata["info"]["features"][key] = val
-
-    # re-encode bytes dict
-    table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
-    new_schema = schema.with_metadata(table_metadata)
-    return table.replace_schema_metadata(new_schema.metadata)
-
-
-def remove_videoframe_from_table(table: pa.Table, image_columns: list) -> pa.Table:
-    table = table.drop(image_columns)
-    table = remove_hf_metadata_features(table, image_columns)
-    return table
-
-
-def add_tasks(table: pa.Table, tasks_by_episodes: dict) -> pa.Table:
-    tasks_index = pa.array([tasks_by_episodes.get(key.as_py(), None) for key in table["episode_index"]])
-    table = table.append_column("task_index", tasks_index)
-    hf_feature = {"task_index": {"dtype": "int64", "_type": "Value"}}
-    table = add_hf_metadata_features(table, hf_feature)
-    return table
+def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
+    df = dataset.to_pandas()
+    tasks = list(set(tasks_by_episodes.values()))
+    tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
+    episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
+    df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
+
+    features = dataset.features
+    features["task_index"] = datasets.Value(dtype="int64")
+    dataset = Dataset.from_pandas(df, features=features, split="train")
+    return dataset, tasks
+
+
+def add_task_index_from_tasks_col(
+    dataset: Dataset, tasks_col: str
+) -> tuple[Dataset, dict[str, list[str]], list[str]]:
+    df = dataset.to_pandas()
+
+    # HACK: This is to clean some of the instructions in our version of Open X datasets
+    prefix_to_clean = "tf.Tensor(b'"
+    suffix_to_clean = "', shape=(), dtype=string)"
+    df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
+
+    # Create task_index col
+    tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
+    tasks = df[tasks_col].unique().tolist()
+    tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
+    df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
+
+    # Build the dataset back from df
+    features = dataset.features
+    features["task_index"] = datasets.Value(dtype="int64")
+    dataset = Dataset.from_pandas(df, features=features, split="train")
+    dataset = dataset.remove_columns(tasks_col)
+
+    return dataset, tasks, tasks_by_episode
 
 
 def split_parquet_by_episodes(
-    table: pa.Table, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
+    dataset: Dataset, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
 ) -> list:
     (output_dir / "data").mkdir(exist_ok=True, parents=True)
-    if len(keys["video"]) > 0:
-        table = remove_videoframe_from_table(table, keys["video"])
-
+    table = dataset.remove_columns(keys["video"])._data.table
     episode_lengths = []
     for episode_index in sorted(episode_indices):
         # Write each episode_index to a new parquet file
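
The episode-to-`task_index` mapping that `add_task_index_by_episodes` builds can be seen on a toy dataframe; the episode layout and task strings below are placeholders, but the mapping logic mirrors the function above:

```python
import pandas as pd

# Toy frame: 5 frames spread over 2 episodes (placeholder values).
df = pd.DataFrame({"episode_index": [0, 0, 1, 1, 1]})
tasks_by_episodes = {0: "Insert the peg into the socket.", 1: "Wipe the table."}

tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)

# Every frame of an episode now carries that episode's task_index.
print(df)
print(tasks_to_task_index)
```
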
@@ -330,11 +359,10 @@ def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
     return video_shapes
 
 
-def get_image_shapes(table: pa.Table, image_keys: list) -> dict:
+def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
     image_shapes = {}
     for img_key in image_keys:
-        image_bytes = table[img_key][0].as_py()  # Assuming first row
-        image = Image.open(BytesIO(image_bytes["bytes"]))
+        image = dataset[0][img_key]  # Assuming first row
         channels = get_image_pixel_channels(image)
         image_shapes[img_key] = {
             "width": image.width,
@@ -352,8 +380,9 @@ def get_generic_motor_names(sequence_shapes: dict) -> dict:
 def convert_dataset(
     repo_id: str,
     local_dir: Path,
-    tasks: dict,
-    tasks_by_episodes: dict | None = None,
+    single_task: str | None = None,
+    tasks_path: Path | None = None,
+    tasks_col: Path | None = None,
     robot_config: dict | None = None,
 ):
     v1_6_dir = local_dir / V1_6 / repo_id
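
Given the new signature above, the converter can also be called programmatically; a sketch with placeholder values, assuming the script is importable as `convert_dataset_v1_to_v2`. Exactly one of `single_task`, `tasks_path`, or `tasks_col` should be set, matching the mutually exclusive CLI group added further down, and note that the call performs Hub operations on the target repo.

```python
from pathlib import Path

from convert_dataset_v1_to_v2 import convert_dataset  # assumed import path

# Placeholder arguments; provide exactly one of single_task / tasks_path / tasks_col.
convert_dataset(
    repo_id="lerobot/aloha_sim_insertion_human_image",
    local_dir=Path("data"),
    single_task="Insert the peg into the socket.",
)
```
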
@@ -367,29 +396,40 @@ def convert_dataset(
     )
 
     metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json")
-
-    table = pq.read_table(v1_6_dir / "data")
-    keys = get_keys(table)
+    dataset = datasets.load_dataset("parquet", data_dir=v1_6_dir / "data", split="train")
+    keys = get_keys(dataset)
 
     # Episodes
-    episode_indices = sorted(table["episode_index"].unique().to_pylist())
+    episode_indices = sorted(dataset.unique("episode_index"))
     total_episodes = len(episode_indices)
     assert episode_indices == list(range(total_episodes))
 
     # Tasks
-    if tasks_by_episodes is None:  # Single task dataset
-        tasks_by_episodes = {ep_idx: 0 for ep_idx in episode_indices}
+    if single_task:
+        tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
+        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
+        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
+    elif tasks_path:
+        tasks_by_episodes = load_json(tasks_path)
+        tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
+        # tasks = list(set(tasks_by_episodes.values()))
+        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
+        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
+    elif tasks_col:
+        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
+    else:
+        raise ValueError
 
-    assert set(tasks) == set(tasks_by_episodes.values())
-    table = add_tasks(table, tasks_by_episodes)
-    write_json(tasks, v2_0_dir / "meta" / "tasks.json")
+    assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
+    task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
+    write_json(task_json, v2_0_dir / "meta" / "tasks.json")
 
     # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(table, keys, total_episodes, episode_indices, v2_0_dir)
+    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v2_0_dir)
 
     # Shapes
-    sequence_shapes = {key: len(table[key][0]) for key in keys["sequence"]}
-    image_shapes = get_image_shapes(table, keys["image"]) if len(keys["image"]) > 0 else {}
+    sequence_shapes = {key: len(dataset[key][0]) for key in keys["sequence"]}
+    image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
     if len(keys["video"]) > 0:
         assert metadata_v1_6.get("video", False)
         videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"])
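
The `task_json` list written above gives the on-disk shape of the new `meta/tasks.json`: one entry per unique task, indexed by `task_index`. A sketch with placeholder task strings:

```python
tasks = ["Insert the peg into the socket.", "Wipe the table."]  # placeholder task strings
task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
print(task_json)
# [{'task_index': 0, 'task': 'Insert the peg into the socket.'},
#  {'task_index': 1, 'task': 'Wipe the table.'}]
```
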
@@ -416,11 +456,12 @@ def convert_dataset(
     for key in sequence_shapes:
         assert len(names[key]) == sequence_shapes[key]
 
-    # Episodes info
+    # Episodes
     episodes = [
-        {"index": ep_idx, "task": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
+        {"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
         for ep_idx in episode_indices
     ]
+    write_json(episodes, v2_0_dir / "meta" / "episodes.json")
 
     # Assemble metadata v2.0
     metadata_v2_0 = {
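
Likewise, each entry of the new `meta/episodes.json` written above records the episode index, its task list, and its length; a sketch with placeholder lengths and task strings:

```python
episode_lengths = [212, 198]  # placeholder frame counts per episode
tasks_by_episodes = {0: "Insert the peg into the socket.", 1: "Insert the peg into the socket."}
episodes = [
    {"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
    for ep_idx in sorted(tasks_by_episodes)
]
print(episodes)
# [{'episode_index': 0, 'tasks': ['Insert the peg into the socket.'], 'length': 212}, ...]
```
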
@@ -437,11 +478,17 @@ def convert_dataset(
         "shapes": {**sequence_shapes, **video_shapes, **image_shapes},
         "names": names,
         "videos": videos_info,
-        "episodes": episodes,
     }
     write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
     convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")
 
+    #### TODO: delete
+    repo_id = f"aliberts/{repo_id.split('/')[1]}"
+    # if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
+    #     hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
+    hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
+    ####
+
     with contextlib.suppress(EntryNotFoundError):
         hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")
@@ -455,6 +502,13 @@ def convert_dataset(
         repo_type="dataset",
         revision="main",
     )
+    hub_api.upload_folder(
+        repo_id=repo_id,
+        path_in_repo="videos",
+        folder_path=v1_6_dir / "videos",
+        repo_type="dataset",
+        revision="main",
+    )
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="meta",
@@ -463,7 +517,6 @@ def convert_dataset(
         revision="main",
     )
 
-    metadata_v2_0.pop("episodes")
     card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
     push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
     create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset")
@@ -478,12 +531,13 @@ def convert_dataset(
 # - [X] Add robot_type
 # - [X] Add splits
 # - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
+# - [X] Handle multitask datasets
 # - [/] Add sanity checks (encoding, shapes)
-# - [ ] Handle multitask datasets
 
 
 def main():
     parser = argparse.ArgumentParser()
+    task_args = parser.add_mutually_exclusive_group(required=True)
 
     parser.add_argument(
         "--repo-id",
@@ -491,11 +545,20 @@ def main():
         required=True,
         help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
     )
-    parser.add_argument(
-        "--task",
+    task_args.add_argument(
+        "--single-task",
         type=str,
-        required=True,
-        help="A short but accurate description of the task performed in the dataset.",
+        help="A short but accurate description of the single task performed in the dataset.",
+    )
+    task_args.add_argument(
+        "--tasks-col",
+        type=str,
+        help="The name of the column containing language instructions",
+    )
+    task_args.add_argument(
+        "--tasks-path",
+        type=Path,
+        help="The path to a .json file containing one language instruction for each episode_index",
     )
     parser.add_argument(
         "--robot-config",
@@ -517,19 +580,13 @@ def main():
     )
 
     args = parser.parse_args()
-    if args.local_dir is None:
+    if not args.local_dir:
         args.local_dir = Path(f"/tmp/{args.repo_id}")
 
-    tasks = {0: args.task}
-    del args.task
-
-    if args.robot_config is not None:
-        robot_config = parse_robot_config(args.robot_config, args.robot_overrides)
-    else:
-        robot_config = None
-
+    robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
     del args.robot_config, args.robot_overrides
 
-    convert_dataset(**vars(args), tasks=tasks, robot_config=robot_config)
+    convert_dataset(**vars(args), robot_config=robot_config)
 
 
 if __name__ == "__main__":