Add multitask support, refactor conversion script
commit cf633344be
parent 8bd406e607
@@ -3,34 +3,70 @@ This script will help you convert any LeRobot dataset already pushed to the hub
 2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
 for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.

-If your dataset contains a single task, you can provide it directly via the CLI with the '--task' option (see
-examples below).
+We support 3 different scenarios for these tasks:
+1. Single task dataset: all episodes of your dataset have the same single task.
+2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
+one episode to the next.
+3. Multi task episodes: episodes of your dataset may each contain several different tasks.

-If your dataset is a multi-task dataset, TODO

-In any case, keep in mind that there should only be one task per episode. Multi-task episodes are not
-supported for now.
+# 1. Single task dataset
+If your dataset contains a single task, you can simply provide it directly via the CLI with the
+'--single-task' option (see examples below).

-Usage examples
+Examples:

-Single-task dataset:
 ```bash
-python convert_dataset_16_to_20.py \
+python convert_dataset_v1_to_v2.py \
 --repo-id lerobot/aloha_sim_insertion_human_image \
---task "Insert the peg into the socket." \
+--single-task "Insert the peg into the socket." \
 --robot-config lerobot/configs/robot/aloha.yaml \
 --local-dir data
 ```

 ```bash
-python convert_dataset_16_to_20.py \
+python convert_dataset_v1_to_v2.py \
 --repo-id aliberts/koch_tutorial \
---task "Pick the Lego block and drop it in the box on the right." \
+--single-task "Pick the Lego block and drop it in the box on the right." \
 --robot-config lerobot/configs/robot/koch.yaml \
 --local-dir data
 ```

-Multi-task dataset:
+# 2. Single task episodes
+If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
+
+- If your dataset already contains a language instruction column in its parquet file, you can simply provide
+this column's name with the '--tasks-col' arg.
+
+Example:
+
+```bash
+python convert_dataset_v1_to_v2.py \
+--repo-id lerobot/stanford_kuka_multimodal_dataset \
+--tasks-col "language_instruction" \
+--local-dir data
+```
+
+- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
+'--tasks-path' arg. This file should have the following structure where keys correspond to each
+episode_index in the dataset, and values are the language instruction for that episode.
+
+Example:
+
+```json
+{
+    "0": "Do something",
+    "1": "Do something else",
+    "2": "Do something",
+    "3": "Go there",
+    ...
+}
+```
+
+# 3. Multi task episodes
+If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
+parquet file, and you must provide this column's name with the '--tasks-col' arg.
 TODO
 """

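For the '--tasks-path' option documented above, the expected file is simply a JSON object mapping stringified episode indices to instructions. A minimal sketch of producing such a file; the helper name `write_tasks_json` and the example instructions are illustrative and not part of this commit:

```python
import json
from pathlib import Path


def write_tasks_json(episode_tasks: dict[int, str], out_path: Path) -> None:
    # Keys are episode_index values (strings in JSON), values are the plain-English
    # instruction for that episode, matching the structure shown in the docstring.
    out_path.write_text(json.dumps({str(idx): task for idx, task in episode_tasks.items()}, indent=4))


write_tasks_json({0: "Do something", 1: "Do something else"}, Path("tasks.json"))
```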
@@ -39,13 +75,13 @@ import contextlib
 import json
 import math
 import subprocess
-from io import BytesIO
 from pathlib import Path

-import pyarrow as pa
+import datasets
 import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import torch
+from datasets import Dataset
 from huggingface_hub import HfApi
 from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
@@ -123,15 +159,14 @@ def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
         torch.testing.assert_close(stats_json[key], stats[key])


-def get_keys(table: pa.Table) -> dict[str, list]:
-    table_metadata = json.loads(table.schema.metadata[b"huggingface"].decode("utf-8"))
+def get_keys(dataset: Dataset) -> dict[str, list]:
     sequence_keys, image_keys, video_keys = [], [], []
-    for key, val in table_metadata["info"]["features"].items():
-        if val["_type"] == "Sequence":
+    for key, ft in dataset.features.items():
+        if isinstance(ft, datasets.Sequence):
             sequence_keys.append(key)
-        elif val["_type"] == "Image":
+        elif isinstance(ft, datasets.Image):
             image_keys.append(key)
-        elif val["_type"] == "VideoFrame":
+        elif ft._type == "VideoFrame":
             video_keys.append(key)

     return {
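For context on the new type checks in get_keys(): with the datasets library, column feature types can be inspected directly instead of decoding the Arrow schema metadata by hand. A rough standalone illustration on made-up column names (the VideoFrame case is a LeRobot-specific feature and is omitted here):

```python
import datasets

# Toy feature schema; the column names are invented for the example.
features = datasets.Features(
    {
        "observation.state": datasets.Sequence(datasets.Value("float32")),
        "observation.image": datasets.Image(),
        "episode_index": datasets.Value("int64"),
    }
)

for key, ft in features.items():
    if isinstance(ft, datasets.Sequence):
        print(f"{key}: sequence feature")
    elif isinstance(ft, datasets.Image):
        print(f"{key}: image feature")
```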
@@ -141,55 +176,49 @@ def get_keys(table: pa.Table) -> dict[str, list]:
     }


-def remove_hf_metadata_features(table: pa.Table, features: list[str]) -> pa.Table:
-    # HACK
-    schema = table.schema
-    # decode bytes dict
-    table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
-    for key in features:
-        table_metadata["info"]["features"].pop(key)
+def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
+    df = dataset.to_pandas()
+    tasks = list(set(tasks_by_episodes.values()))
+    tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
+    episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
+    df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)

-    # re-encode bytes dict
-    table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
-    new_schema = schema.with_metadata(table_metadata)
-    return table.replace_schema_metadata(new_schema.metadata)
+    features = dataset.features
+    features["task_index"] = datasets.Value(dtype="int64")
+    dataset = Dataset.from_pandas(df, features=features, split="train")
+    return dataset, tasks


-def add_hf_metadata_features(table: pa.Table, features: dict[str, dict]) -> pa.Table:
-    # HACK
-    schema = table.schema
-    # decode bytes dict
-    table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
-    for key, val in features.items():
-        table_metadata["info"]["features"][key] = val
+def add_task_index_from_tasks_col(
+    dataset: Dataset, tasks_col: str
+) -> tuple[Dataset, dict[str, list[str]], list[str]]:
+    df = dataset.to_pandas()

-    # re-encode bytes dict
-    table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
-    new_schema = schema.with_metadata(table_metadata)
-    return table.replace_schema_metadata(new_schema.metadata)
+    # HACK: This is to clean some of the instructions in our version of Open X datasets
+    prefix_to_clean = "tf.Tensor(b'"
+    suffix_to_clean = "', shape=(), dtype=string)"
+    df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)

+    # Create task_index col
+    tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
+    tasks = df[tasks_col].unique().tolist()
+    tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
+    df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)

-def remove_videoframe_from_table(table: pa.Table, image_columns: list) -> pa.Table:
-    table = table.drop(image_columns)
-    table = remove_hf_metadata_features(table, image_columns)
-    return table
+    # Build the dataset back from df
+    features = dataset.features
+    features["task_index"] = datasets.Value(dtype="int64")
+    dataset = Dataset.from_pandas(df, features=features, split="train")
+    dataset = dataset.remove_columns(tasks_col)

+    return dataset, tasks, tasks_by_episode

-def add_tasks(table: pa.Table, tasks_by_episodes: dict) -> pa.Table:
-    tasks_index = pa.array([tasks_by_episodes.get(key.as_py(), None) for key in table["episode_index"]])
-    table = table.append_column("task_index", tasks_index)
-    hf_feature = {"task_index": {"dtype": "int64", "_type": "Value"}}
-    table = add_hf_metadata_features(table, hf_feature)
-    return table


 def split_parquet_by_episodes(
-    table: pa.Table, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
+    dataset: Dataset, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
 ) -> list:
     (output_dir / "data").mkdir(exist_ok=True, parents=True)
-    if len(keys["video"]) > 0:
-        table = remove_videoframe_from_table(table, keys["video"])
-
+    table = dataset.remove_columns(keys["video"])._data.table
     episode_lengths = []
     for episode_index in sorted(episode_indices):
         # Write each episode_index to a new parquet file
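The core of add_task_index_by_episodes() is a plain pandas mapping from episode index to task string to integer task_index. A toy walk-through on invented data, mirroring the logic in the hunk above:

```python
import pandas as pd

# Invented episode -> task mapping and a tiny frame-level table.
tasks_by_episodes = {0: "Pick the cube", 1: "Push the cube", 2: "Pick the cube"}
df = pd.DataFrame({"episode_index": [0, 0, 1, 2]})

tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}

# Every frame of an episode gets that episode's task_index.
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
print(df)
```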
@@ -330,11 +359,10 @@ def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
     return video_shapes


-def get_image_shapes(table: pa.Table, image_keys: list) -> dict:
+def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
     image_shapes = {}
     for img_key in image_keys:
-        image_bytes = table[img_key][0].as_py()  # Assuming first row
-        image = Image.open(BytesIO(image_bytes["bytes"]))
+        image = dataset[0][img_key]  # Assuming first row
         channels = get_image_pixel_channels(image)
         image_shapes[img_key] = {
             "width": image.width,
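The refactor above leans on the datasets Image feature decoding rows straight to PIL images, so width and height can be read without going through raw bytes. A quick self-contained check of that assumed behavior (column name and size are made up):

```python
import datasets
from PIL import Image

ds = datasets.Dataset.from_dict(
    {"observation.image": [Image.new("RGB", (96, 96))]},
    features=datasets.Features({"observation.image": datasets.Image()}),
)
img = ds[0]["observation.image"]  # decoded PIL image, not raw bytes
print(img.width, img.height)  # 96 96
```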
@@ -352,8 +380,9 @@ def get_generic_motor_names(sequence_shapes: dict) -> dict:
 def convert_dataset(
     repo_id: str,
     local_dir: Path,
-    tasks: dict,
-    tasks_by_episodes: dict | None = None,
+    single_task: str | None = None,
+    tasks_path: Path | None = None,
+    tasks_col: Path | None = None,
     robot_config: dict | None = None,
 ):
     v1_6_dir = local_dir / V1_6 / repo_id
@@ -367,29 +396,40 @@ def convert_dataset(
     )

     metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json")
-    table = pq.read_table(v1_6_dir / "data")
-    keys = get_keys(table)
+    dataset = datasets.load_dataset("parquet", data_dir=v1_6_dir / "data", split="train")
+    keys = get_keys(dataset)

     # Episodes
-    episode_indices = sorted(table["episode_index"].unique().to_pylist())
+    episode_indices = sorted(dataset.unique("episode_index"))
     total_episodes = len(episode_indices)
     assert episode_indices == list(range(total_episodes))

     # Tasks
-    if tasks_by_episodes is None:  # Single task dataset
-        tasks_by_episodes = {ep_idx: 0 for ep_idx in episode_indices}
+    if single_task:
+        tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
+        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
+        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
+    elif tasks_path:
+        tasks_by_episodes = load_json(tasks_path)
+        tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
+        # tasks = list(set(tasks_by_episodes.values()))
+        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
+        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
+    elif tasks_col:
+        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
+    else:
+        raise ValueError

-    assert set(tasks) == set(tasks_by_episodes.values())
-    table = add_tasks(table, tasks_by_episodes)
-    write_json(tasks, v2_0_dir / "meta" / "tasks.json")
+    assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
+    task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
+    write_json(task_json, v2_0_dir / "meta" / "tasks.json")

     # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(table, keys, total_episodes, episode_indices, v2_0_dir)
+    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v2_0_dir)

     # Shapes
-    sequence_shapes = {key: len(table[key][0]) for key in keys["sequence"]}
-    image_shapes = get_image_shapes(table, keys["image"]) if len(keys["image"]) > 0 else {}
+    sequence_shapes = {key: len(dataset[key][0]) for key in keys["sequence"]}
+    image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
     if len(keys["video"]) > 0:
         assert metadata_v1_6.get("video", False)
         videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"])
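For reference, the meta/tasks.json written here is now a list of {task_index, task} entries rather than a plain dict keyed by index. A small sketch of its shape, with task strings borrowed from the CLI examples in the docstring:

```python
import json

# Task strings borrowed from the docstring examples above.
tasks = ["Insert the peg into the socket.", "Pick the Lego block and drop it in the box on the right."]
task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
print(json.dumps(task_json, indent=4))  # the structure written to meta/tasks.json
```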
@@ -416,11 +456,12 @@ def convert_dataset(
     for key in sequence_shapes:
         assert len(names[key]) == sequence_shapes[key]

-    # Episodes info
+    # Episodes
     episodes = [
-        {"index": ep_idx, "task": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
+        {"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
         for ep_idx in episode_indices
     ]
+    write_json(episodes, v2_0_dir / "meta" / "episodes.json")

     # Assemble metadata v2.0
     metadata_v2_0 = {
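Likewise, each entry of the new meta/episodes.json carries the episode's task list alongside its index and length; roughly the following shape (values invented for illustration):

```python
# Rough shape of one meta/episodes.json entry after this change (invented values).
episode_entry = {
    "episode_index": 0,
    "tasks": ["Insert the peg into the socket."],
    "length": 300,
}
```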
@@ -437,11 +478,17 @@ def convert_dataset(
         "shapes": {**sequence_shapes, **video_shapes, **image_shapes},
         "names": names,
         "videos": videos_info,
-        "episodes": episodes,
     }
     write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
     convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")

+    #### TODO: delete
+    repo_id = f"aliberts/{repo_id.split('/')[1]}"
+    # if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
+    # hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
+    hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
+    ####
+
     with contextlib.suppress(EntryNotFoundError):
         hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")

@@ -455,6 +502,13 @@ def convert_dataset(
         repo_type="dataset",
         revision="main",
     )
+    hub_api.upload_folder(
+        repo_id=repo_id,
+        path_in_repo="videos",
+        folder_path=v1_6_dir / "videos",
+        repo_type="dataset",
+        revision="main",
+    )
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="meta",
@@ -463,7 +517,6 @@ def convert_dataset(
         revision="main",
     )

-    metadata_v2_0.pop("episodes")
     card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
     push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
     create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset")
@@ -478,12 +531,13 @@ def convert_dataset(
 # - [X] Add robot_type
 # - [X] Add splits
 # - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
+# - [X] Handle multitask datasets
 # - [/] Add sanity checks (encoding, shapes)
-# - [ ] Handle multitask datasets


 def main():
     parser = argparse.ArgumentParser()
+    task_args = parser.add_mutually_exclusive_group(required=True)

     parser.add_argument(
         "--repo-id",
@@ -491,11 +545,20 @@ def main():
         required=True,
         help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
     )
-    parser.add_argument(
-        "--task",
+    task_args.add_argument(
+        "--single-task",
         type=str,
-        required=True,
-        help="A short but accurate description of the task performed in the dataset.",
+        help="A short but accurate description of the single task performed in the dataset.",
+    )
+    task_args.add_argument(
+        "--tasks-col",
+        type=str,
+        help="The name of the column containing language instructions",
+    )
+    task_args.add_argument(
+        "--tasks-path",
+        type=Path,
+        help="The path to a .json file containing one language instruction for each episode_index",
     )
     parser.add_argument(
         "--robot-config",
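The three new task options are registered on a required mutually exclusive group, so argparse itself enforces that exactly one of them is passed. A stripped-down, standalone sketch of that behavior:

```python
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
task_args = parser.add_mutually_exclusive_group(required=True)
task_args.add_argument("--single-task", type=str)
task_args.add_argument("--tasks-col", type=str)
task_args.add_argument("--tasks-path", type=Path)

# OK: exactly one of the task options is given.
print(parser.parse_args(["--single-task", "Insert the peg into the socket."]))
# Passing two of them (or none) makes argparse exit with a usage error:
# parser.parse_args(["--single-task", "x", "--tasks-col", "language_instruction"])
```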
@@ -517,19 +580,13 @@ def main():
     )

     args = parser.parse_args()
-    if args.local_dir is None:
+    if not args.local_dir:
         args.local_dir = Path(f"/tmp/{args.repo_id}")

-    tasks = {0: args.task}
-    del args.task
-
-    if args.robot_config is not None:
-        robot_config = parse_robot_config(args.robot_config, args.robot_overrides)
-    else:
-        robot_config = None
+    robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
     del args.robot_config, args.robot_overrides

-    convert_dataset(**vars(args), tasks=tasks, robot_config=robot_config)
+    convert_dataset(**vars(args), robot_config=robot_config)


 if __name__ == "__main__":