Cleanup, fix load_tasks
commit 835ab5a81b (parent f96773de10)
@@ -80,6 +80,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        # TODO(aliberts): remove this part as we'll be using task_index
         elif isinstance(first_item, str):
            # TODO (michel-aractingi): add str2embedding via language tokenizer
            # For now we leave this part up to the user to choose how to address
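For context, `hf_transform_to_torch` is the per-batch transform applied to the underlying Hugging Face dataset; the new TODO refers to the string branch that currently passes task descriptions through untouched. Below is a minimal usage sketch: the import path, column names, and `set_transform` registration are assumptions for illustration and are not shown in this diff.

```python
# Minimal sketch (assumed import path and column names): the transform is registered
# on the HF dataset so PIL image columns come back as torch tensors, while string
# columns (the task descriptions the new TODO refers to) pass through unchanged.
import numpy as np
from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage

from lerobot.common.datasets.utils import hf_transform_to_torch  # assumed module path

features = Features({"observation.image": Image(), "task": Value("string")})
ds = Dataset.from_dict(
    {
        "observation.image": [PILImage.fromarray(np.zeros((4, 4, 3), dtype=np.uint8))],
        "task": ["Insert the peg into the socket."],
    },
    features=features,
)
ds.set_transform(hf_transform_to_torch)

print(type(ds[0]["observation.image"]))  # torch.Tensor, channel-first, scaled to [0, 1]
print(ds[0]["task"])                     # still a plain string
```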
@@ -96,13 +97,13 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
 
 
 @cache
-def get_hub_safe_version(repo_id: str, version: str) -> str:
+def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
     num_version = float(version.strip("v"))
-    if num_version < 2:
+    if num_version < 2 and enforce_v2:
         raise ValueError(
             f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
             format with v2.0 that is not backward compatible. Please use our conversion script
-            first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
+            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
         )
     api = HfApi()
     dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
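The new `enforce_v2` flag lets a caller resolve a hub revision for a pre-2.0 dataset without hitting the v2-only guard, which the v1-to-v2 conversion script needs. A minimal call sketch; the import path is an assumption and the repo id is only an example.

```python
# Sketch only; the import path is assumed from where the diffed functions appear to live.
from lerobot.common.datasets.utils import get_hub_safe_version

# Default behaviour: requesting a v1.x dataset still raises ValueError.
# get_hub_safe_version("lerobot/aloha_sim_insertion_human_image", "v1.6")

# With enforce_v2=False, a v1.x revision can be resolved, e.g. by the conversion script.
revision = get_hub_safe_version("lerobot/aloha_sim_insertion_human_image", "v1.6", enforce_v2=False)
```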
@@ -192,7 +193,9 @@ def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
         repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
     )
     with open(fpath) as f:
-        return json.load(f)
+        tasks = json.load(f)
+
+    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
 
 
 def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
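This is the `load_tasks` fix from the commit title: instead of returning the raw JSON payload, the function now returns a `{task_index: task}` mapping. A small sketch of the transformation, assuming `meta/tasks.json` holds a list of entries with `task_index` and `task` keys (the format implied by the comprehension above); the example data is made up.

```python
# Illustrative contents of meta/tasks.json after json.load (hypothetical values).
tasks = [
    {"task_index": 1, "task": "Pick the Lego block and drop it in the box on the right."},
    {"task_index": 0, "task": "Insert the peg into the socket."},
]

# Same comprehension as in the diff: sort by task_index and build an index -> description map.
task_map = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
print(task_map[0])  # "Insert the peg into the socket."
print(task_map[1])  # "Pick the Lego block and drop it in the box on the right."
```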
@@ -3,13 +3,18 @@ This script will help you convert any LeRobot dataset already pushed to the hub
 2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
 for each of the tasks performed in the dataset. This will allow you to easily train models with task-conditioning.
 
-We support 3 different scenarios for these tasks:
+We support 3 different scenarios for these tasks (see instructions below):
 1. Single task dataset: all episodes of your dataset have the same single task.
 2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
 one episode to the next.
 3. Multi task episodes: episodes of your dataset may each contain several different tasks.
 
 
+You can also provide a robot config .yaml file (not mandatory) to this script via the option
+'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
+recorded with. For now, only Aloha/Koch type robots are supported with this option.
+
+
 # 1. Single task dataset
 If your dataset contains a single task, you can simply provide it directly via the CLI with the
 '--single-task' option.
@@ -17,7 +22,7 @@ If your dataset contains a single task, you can simply provide it directly via the
 Examples:
 
 ```bash
-python convert_dataset_v1_to_v2.py \
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
     --repo-id lerobot/aloha_sim_insertion_human_image \
     --single-task "Insert the peg into the socket." \
     --robot-config lerobot/configs/robot/aloha.yaml \
@@ -25,7 +30,7 @@ python convert_dataset_v1_to_v2.py \
 ```
 
 ```bash
-python convert_dataset_v1_to_v2.py \
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
     --repo-id aliberts/koch_tutorial \
     --single-task "Pick the Lego block and drop it in the box on the right." \
     --robot-config lerobot/configs/robot/koch.yaml \
@@ -42,7 +47,7 @@ If your dataset is a multi-task dataset, you have two options to provide the tasks
 Example:
 
 ```bash
-python convert_dataset_v1_to_v2.py \
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
     --repo-id lerobot/stanford_kuka_multimodal_dataset \
     --tasks-col "language_instruction" \
     --local-dir data
@@ -71,7 +76,7 @@ parquet file, and you must provide this column's name with the '--tasks-col' argument
 Example:
 
 ```bash
-python convert_dataset_v1_to_v2.py \
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
     --repo-id lerobot/stanford_kuka_multimodal_dataset \
     --tasks-col "language_instruction" \
     --local-dir data
@@ -321,6 +326,7 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
     hub_api = HfApi()
     videos_info_dict = {"videos_path": VIDEO_PATH}
     for vid_key in video_keys:
+        # Assumes first episode
         video_path = VIDEO_PATH.format(video_key=vid_key, episode_index=0)
         video_path = hub_api.hf_hub_download(
             repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
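The added comment makes the assumption explicit: video metadata is probed from episode 0 only, on the premise that every episode of a given camera key shares the same encoding. A toy illustration of the `VIDEO_PATH.format(...)` call; the template string here is hypothetical, the real `VIDEO_PATH` constant is defined elsewhere in the script.

```python
# Hypothetical template, for illustration only (not the repository's actual constant).
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"

# Only episode 0 is formatted and downloaded to read per-key video info.
print(VIDEO_PATH.format(video_key="observation.images.top", episode_index=0))
# -> videos/observation.images.top_episode_000000.mp4
```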
@@ -437,7 +443,7 @@ def convert_dataset(
     episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v20_dir)
 
     # Shapes
-    sequence_shapes = {key: len(dataset[key][0]) for key in keys["sequence"]}
+    sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
     image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
     if len(keys["video"]) > 0:
         assert metadata_v1.get("video", False)
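The new expression reads the declared length of each `Sequence` feature from the dataset schema instead of materializing the first row just to measure it. A small self-contained sketch of that `datasets` behaviour, with a made-up feature name:

```python
from datasets import Dataset, Features, Sequence, Value

# Hypothetical 6-dof action column declared with a fixed length.
features = Features({"action": Sequence(Value("float32"), length=6)})
ds = Dataset.from_dict({"action": [[0.0] * 6, [1.0] * 6]}, features=features)

print(ds.features["action"].length)  # 6, read straight from the schema
print(len(ds["action"][0]))          # 6 as well, but requires decoding a row first
```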
@@ -479,6 +485,7 @@ def convert_dataset(
         "data_path": PARQUET_PATH,
         "robot_type": robot_type,
         "total_episodes": total_episodes,
+        "total_frames": len(dataset),
         "total_tasks": len(tasks),
         "fps": metadata_v1["fps"],
         "splits": {"train": f"0:{total_episodes}"},
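With this addition the v2.0 metadata records the total frame count taken directly from the loaded HF dataset. A sketch of the resulting fragment with placeholder values, showing only the keys visible in this hunk (the real dict contains more fields):

```python
# Placeholder values for illustration only.
total_episodes = 50
info_fragment = {
    "robot_type": "koch",                        # example value
    "total_episodes": total_episodes,
    "total_frames": 25_000,                      # len(dataset): new in this commit
    "total_tasks": 1,
    "fps": 30,
    "splits": {"train": f"0:{total_episodes}"},
}
print(info_fragment["total_frames"])  # 25000
```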
@@ -512,13 +519,6 @@ def convert_dataset(
         repo_type="dataset",
         revision="main",
     )
-    hub_api.upload_folder(
-        repo_id=repo_id,
-        path_in_repo="videos",
-        folder_path=v1x_dir / "videos",
-        repo_type="dataset",
-        revision="main",
-    )
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="meta",