From 09e15025ccc8751f3a11690755f94514129c2c04 Mon Sep 17 00:00:00 2001 From: Ville Kuosmanen Date: Mon, 24 Mar 2025 18:32:02 +0000 Subject: [PATCH 1/2] feat: add task annotations overrider (#11) --- .../common/datasets/annotation_overrider.py | 196 ++++++++++++++++++ lerobot/scripts/visualize_dataset_html.py | 6 + 2 files changed, 202 insertions(+) create mode 100644 lerobot/common/datasets/annotation_overrider.py diff --git a/lerobot/common/datasets/annotation_overrider.py b/lerobot/common/datasets/annotation_overrider.py new file mode 100644 index 00000000..431d399d --- /dev/null +++ b/lerobot/common/datasets/annotation_overrider.py @@ -0,0 +1,196 @@ +import json +from pathlib import Path +from typing import Dict, Any, Union, List, Optional + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset + + +class TaskAnnotationOverrider: + """ + A class to override task annotations in LeRobotDataset or MultiLeRobotDataset + without modifying the original dataset files. + + This allows users to use datasets shared by others but with customized task descriptions. + """ + + def __init__(self, annotation_file: Union[str, Path]): + """ + Initialize the TaskAnnotationOverrider with an annotation file. + + Args: + annotation_file (Union[str, Path]): Path to the JSON file containing task annotation overrides. + The file should have the following structure: + { + "repo_id1": { + "0": "new task description for task 0", + "1": "new task description for task 1" + }, + "repo_id2": { + "0": "new task description for task 0" + } + } + """ + self.annotation_file = Path(annotation_file) + self.load_annotations() + + def load_annotations(self) -> None: + """Load task annotations from the specified JSON file.""" + if not self.annotation_file.exists(): + raise FileNotFoundError(f"Annotation file not found: {self.annotation_file}") + + with open(self.annotation_file, 'r') as f: + self.annotations = json.load(f) + + def apply_overrides(self, dataset: Union[LeRobotDataset, MultiLeRobotDataset]) -> None: + """ + Apply task annotation overrides to a LeRobotDataset or MultiLeRobotDataset. + + Args: + dataset (Union[LeRobotDataset, MultiLeRobotDataset]): The dataset to override task annotations for. + """ + if isinstance(dataset, MultiLeRobotDataset): + self._apply_overrides_multi(dataset) + elif isinstance(dataset, LeRobotDataset): + self._apply_overrides_single(dataset) + else: + raise TypeError(f"Unsupported dataset type: {type(dataset)}") + + def _apply_overrides_single(self, dataset: LeRobotDataset) -> None: + """ + Apply task annotation overrides to a single LeRobotDataset. + + Args: + dataset (LeRobotDataset): The dataset to override task annotations for. + """ + repo_id = dataset.repo_id + if repo_id not in self.annotations: + print(f"No annotations found for repository: {repo_id}") + return + + repo_annotations = self.annotations[repo_id] + modified = False + + # Create mapping of old task descriptions to new ones + task_description_mapping = {} + + # Update the tasks in the metadata + for task_idx_str, new_description in repo_annotations.items(): + task_idx = int(task_idx_str) + + # Check if the task index exists in the dataset + if task_idx >= dataset.meta.total_tasks: + print(f"Warning: Task index {task_idx} not found in dataset {repo_id}") + continue + + # Update the task description + if task_idx in dataset.meta.tasks: + original_description = dataset.meta.tasks[task_idx] + dataset.meta.tasks[task_idx] = new_description + task_description_mapping[original_description] = new_description + modified = True + else: + print(f"Warning: Task index {task_idx} not found in dataset {repo_id}") + + # If any annotations were modified, update the dataset + if modified: + self._update_task_index_mapping(dataset) + self._update_episode_tasks(dataset, task_description_mapping) + self._update_item_access(dataset) + + def _apply_overrides_multi(self, multi_dataset: MultiLeRobotDataset) -> None: + """ + Apply task annotation overrides to each LeRobotDataset in a MultiLeRobotDataset. + + Args: + multi_dataset (MultiLeRobotDataset): The multi-dataset to override task annotations for. + """ + for dataset in multi_dataset._datasets: + self._apply_overrides_single(dataset) + + def _update_task_index_mapping(self, dataset: LeRobotDataset) -> None: + """ + Update the task_to_task_index mapping to reflect the new task descriptions. + + Args: + dataset (LeRobotDataset): The dataset to update the mapping for. + """ + # Rebuild the task_to_task_index mapping + dataset.meta.task_to_task_index = { + task_desc: task_idx for task_idx, task_desc in dataset.meta.tasks.items() + } + + def _update_episode_tasks(self, dataset: LeRobotDataset, task_mapping: Dict[str, str]) -> None: + """ + Update the task descriptions in the episodes metadata. + + Args: + dataset (LeRobotDataset): The dataset to update. + task_mapping (Dict[str, str]): Mapping from original to new task descriptions. + """ + for ep_idx, episode in dataset.meta.episodes.items(): + if "tasks" in episode: + # Replace task descriptions in the episode's tasks list + updated_tasks = [] + for task in episode["tasks"]: + if task in task_mapping: + updated_tasks.append(task_mapping[task]) + else: + updated_tasks.append(task) + + if updated_tasks != episode["tasks"]: + episode["tasks"] = updated_tasks + + def _update_item_access(self, dataset: LeRobotDataset) -> None: + """ + Modify the __getitem__ method of the dataset to use the updated task annotations. + + Args: + dataset (LeRobotDataset): The dataset to update. + """ + # Since task lookup is done in __getitem__, no need to modify anything else here + # The overridden tasks in dataset.meta.tasks will be used automatically + pass + + def save_overridden_metadata(self, dataset: LeRobotDataset, output_path: Optional[Path] = None) -> None: + """ + Save the overridden metadata to a new file without modifying the original dataset. + + This allows saving the modified tasks as a separate file that can be loaded later. + + Args: + dataset (LeRobotDataset): The dataset with overridden task annotations. + output_path (Optional[Path]): The path to save the overridden metadata. + If None, will save to dataset.root/meta/tasks_overridden.json + """ + if output_path is None: + output_path = dataset.root / "meta" / "tasks_overridden.json" + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Save the overridden tasks + tasks_dict = {str(idx): desc for idx, desc in dataset.meta.tasks.items()} + with open(output_path, 'w') as f: + json.dump(tasks_dict, f, indent=2) + + # Optionally, also save the updated episodes + episodes_path = output_path.parent / "episodes_overridden.json" + with open(episodes_path, 'w') as f: + json.dump(dataset.meta.episodes, f, indent=2) + +# Usage example: +# +# from lerobot.common.datasets.task_annotation_overrider import TaskAnnotationOverrider +# +# # Load a dataset +# dataset = LeRobotDataset("jpata/so100_pick_place_tangerine") +# +# # Create and apply overrides +# overrider = TaskAnnotationOverrider("task_annotations.json") +# overrider.apply_overrides(dataset) +# +# # Now dataset has updated task descriptions +# # The original dataset files are unchanged +# +# # When using the dataset, task descriptions will be updated +# sample = dataset[0] +# print(sample["task"]) # This will use the overridden task description diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 0fc21a8f..e11048d0 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -464,6 +464,7 @@ def main(): root = kwargs.pop("root") tolerance_s = kwargs.pop("tolerance_s") + dataset = None if repo_id: dataset = ( @@ -471,6 +472,11 @@ def main(): if not load_from_hf_hub else get_dataset_info(repo_id) ) + + # test annotation overrider + from lerobot.common.datasets.annotation_overrider import TaskAnnotationOverrider + overrider = TaskAnnotationOverrider("data/annotation_overrides.json") + overrider.apply_overrides(dataset) visualize_dataset_html(dataset, **vars(args)) From fe187254c43393e0099dcb4998282ef898853125 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Mar 2025 18:37:07 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/datasets/annotation_overrider.py | 79 ++++++++++--------- lerobot/scripts/visualize_dataset_html.py | 4 +- 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/lerobot/common/datasets/annotation_overrider.py b/lerobot/common/datasets/annotation_overrider.py index 431d399d..81b36180 100644 --- a/lerobot/common/datasets/annotation_overrider.py +++ b/lerobot/common/datasets/annotation_overrider.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, Any, Union, List, Optional +from typing import Dict, Optional, Union from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset @@ -9,14 +9,14 @@ class TaskAnnotationOverrider: """ A class to override task annotations in LeRobotDataset or MultiLeRobotDataset without modifying the original dataset files. - + This allows users to use datasets shared by others but with customized task descriptions. """ - + def __init__(self, annotation_file: Union[str, Path]): """ Initialize the TaskAnnotationOverrider with an annotation file. - + Args: annotation_file (Union[str, Path]): Path to the JSON file containing task annotation overrides. The file should have the following structure: @@ -32,19 +32,19 @@ class TaskAnnotationOverrider: """ self.annotation_file = Path(annotation_file) self.load_annotations() - + def load_annotations(self) -> None: """Load task annotations from the specified JSON file.""" if not self.annotation_file.exists(): raise FileNotFoundError(f"Annotation file not found: {self.annotation_file}") - - with open(self.annotation_file, 'r') as f: + + with open(self.annotation_file) as f: self.annotations = json.load(f) - + def apply_overrides(self, dataset: Union[LeRobotDataset, MultiLeRobotDataset]) -> None: """ Apply task annotation overrides to a LeRobotDataset or MultiLeRobotDataset. - + Args: dataset (Union[LeRobotDataset, MultiLeRobotDataset]): The dataset to override task annotations for. """ @@ -54,11 +54,11 @@ class TaskAnnotationOverrider: self._apply_overrides_single(dataset) else: raise TypeError(f"Unsupported dataset type: {type(dataset)}") - + def _apply_overrides_single(self, dataset: LeRobotDataset) -> None: """ Apply task annotation overrides to a single LeRobotDataset. - + Args: dataset (LeRobotDataset): The dataset to override task annotations for. """ @@ -66,22 +66,22 @@ class TaskAnnotationOverrider: if repo_id not in self.annotations: print(f"No annotations found for repository: {repo_id}") return - + repo_annotations = self.annotations[repo_id] modified = False - + # Create mapping of old task descriptions to new ones task_description_mapping = {} - + # Update the tasks in the metadata for task_idx_str, new_description in repo_annotations.items(): task_idx = int(task_idx_str) - + # Check if the task index exists in the dataset if task_idx >= dataset.meta.total_tasks: print(f"Warning: Task index {task_idx} not found in dataset {repo_id}") continue - + # Update the task description if task_idx in dataset.meta.tasks: original_description = dataset.meta.tasks[task_idx] @@ -90,27 +90,27 @@ class TaskAnnotationOverrider: modified = True else: print(f"Warning: Task index {task_idx} not found in dataset {repo_id}") - + # If any annotations were modified, update the dataset if modified: self._update_task_index_mapping(dataset) self._update_episode_tasks(dataset, task_description_mapping) self._update_item_access(dataset) - + def _apply_overrides_multi(self, multi_dataset: MultiLeRobotDataset) -> None: """ Apply task annotation overrides to each LeRobotDataset in a MultiLeRobotDataset. - + Args: multi_dataset (MultiLeRobotDataset): The multi-dataset to override task annotations for. """ for dataset in multi_dataset._datasets: self._apply_overrides_single(dataset) - + def _update_task_index_mapping(self, dataset: LeRobotDataset) -> None: """ Update the task_to_task_index mapping to reflect the new task descriptions. - + Args: dataset (LeRobotDataset): The dataset to update the mapping for. """ @@ -118,11 +118,11 @@ class TaskAnnotationOverrider: dataset.meta.task_to_task_index = { task_desc: task_idx for task_idx, task_desc in dataset.meta.tasks.items() } - + def _update_episode_tasks(self, dataset: LeRobotDataset, task_mapping: Dict[str, str]) -> None: """ Update the task descriptions in the episodes metadata. - + Args: dataset (LeRobotDataset): The dataset to update. task_mapping (Dict[str, str]): Mapping from original to new task descriptions. @@ -136,61 +136,62 @@ class TaskAnnotationOverrider: updated_tasks.append(task_mapping[task]) else: updated_tasks.append(task) - + if updated_tasks != episode["tasks"]: episode["tasks"] = updated_tasks def _update_item_access(self, dataset: LeRobotDataset) -> None: """ Modify the __getitem__ method of the dataset to use the updated task annotations. - + Args: dataset (LeRobotDataset): The dataset to update. """ # Since task lookup is done in __getitem__, no need to modify anything else here # The overridden tasks in dataset.meta.tasks will be used automatically pass - + def save_overridden_metadata(self, dataset: LeRobotDataset, output_path: Optional[Path] = None) -> None: """ Save the overridden metadata to a new file without modifying the original dataset. - + This allows saving the modified tasks as a separate file that can be loaded later. - + Args: dataset (LeRobotDataset): The dataset with overridden task annotations. - output_path (Optional[Path]): The path to save the overridden metadata. + output_path (Optional[Path]): The path to save the overridden metadata. If None, will save to dataset.root/meta/tasks_overridden.json """ if output_path is None: output_path = dataset.root / "meta" / "tasks_overridden.json" - + output_path.parent.mkdir(parents=True, exist_ok=True) - + # Save the overridden tasks tasks_dict = {str(idx): desc for idx, desc in dataset.meta.tasks.items()} - with open(output_path, 'w') as f: + with open(output_path, "w") as f: json.dump(tasks_dict, f, indent=2) - + # Optionally, also save the updated episodes episodes_path = output_path.parent / "episodes_overridden.json" - with open(episodes_path, 'w') as f: + with open(episodes_path, "w") as f: json.dump(dataset.meta.episodes, f, indent=2) + # Usage example: -# +# # from lerobot.common.datasets.task_annotation_overrider import TaskAnnotationOverrider -# +# # # Load a dataset # dataset = LeRobotDataset("jpata/so100_pick_place_tangerine") -# +# # # Create and apply overrides # overrider = TaskAnnotationOverrider("task_annotations.json") # overrider.apply_overrides(dataset) -# +# # # Now dataset has updated task descriptions # # The original dataset files are unchanged -# +# # # When using the dataset, task descriptions will be updated # sample = dataset[0] # print(sample["task"]) # This will use the overridden task description diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index e11048d0..df7251bb 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -464,7 +464,6 @@ def main(): root = kwargs.pop("root") tolerance_s = kwargs.pop("tolerance_s") - dataset = None if repo_id: dataset = ( @@ -472,9 +471,10 @@ def main(): if not load_from_hf_hub else get_dataset_info(repo_id) ) - + # test annotation overrider from lerobot.common.datasets.annotation_overrider import TaskAnnotationOverrider + overrider = TaskAnnotationOverrider("data/annotation_overrides.json") overrider.apply_overrides(dataset)