From efb1982eecc316aee61e957f4288de20559709a6 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 3 Feb 2025 17:48:35 +0000 Subject: [PATCH] Added crop_dataset_roi.py that allows you to load a lerobotdataset -> crop its images -> create a new lerobot dataset with the cropped and resized images. Co-authored-by: Adil Zouitine --- lerobot/common/robot_devices/control_utils.py | 6 +- lerobot/scripts/server/crop_dataset_roi.py | 264 ++++++++++++++++++ lerobot/scripts/server/crop_roi.py | 148 ---------- 3 files changed, 268 insertions(+), 150 deletions(-) create mode 100644 lerobot/scripts/server/crop_dataset_roi.py delete mode 100644 lerobot/scripts/server/crop_roi.py diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 10cb9f5c..f88f6d3e 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -36,7 +36,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f def log_dt(shortname, dt_val_s): nonlocal log_items, fps - info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)" + info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" if fps is not None: actual_fps = 1 / dt_val_s if actual_fps < fps - 1: @@ -335,7 +335,9 @@ def reset_environment(robot, events, reset_time_s): def reset_follower_position(robot: Robot, target_position): current_position = robot.follower_arms["main"].read("Present_Position") - trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30)) # NOTE: 30 is just an aribtrary number + trajectory = torch.from_numpy( + np.linspace(current_position, target_position, 30) + ) # NOTE: 30 is just an aribtrary number for pose in trajectory: robot.send_action(pose) busy_wait(0.015) diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py new file mode 100644 index 00000000..8d7d7ebf --- /dev/null +++ b/lerobot/scripts/server/crop_dataset_roi.py @@ -0,0 +1,264 @@ +import argparse # noqa: I001 +from copy import deepcopy +from typing import Dict, Tuple + +import cv2 + +# import torch.nn.functional as F # noqa: N812 +import torchvision.transforms.functional as F # type: ignore # noqa: N812 +from tqdm import tqdm # type: ignore + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + +def select_rect_roi(img): + """ + Allows the user to draw a rectangular ROI on the image. + + The user must click and drag to draw the rectangle. + - While dragging, the rectangle is dynamically drawn. + - On mouse button release, the rectangle is fixed. + - Press 'c' to confirm the selection. + - Press 'r' to reset the selection. + - Press ESC to cancel. + + Returns: + A tuple (top, left, height, width) representing the rectangular ROI, + or None if no valid ROI is selected. + """ + # Create a working copy of the image + clone = img.copy() + working_img = clone.copy() + + roi = None # Will store the final ROI as (top, left, height, width) + drawing = False + ix, iy = -1, -1 # Initial click coordinates + + def mouse_callback(event, x, y, flags, param): + nonlocal ix, iy, drawing, roi, working_img + + if event == cv2.EVENT_LBUTTONDOWN: + # Start drawing: record starting coordinates + drawing = True + ix, iy = x, y + + elif event == cv2.EVENT_MOUSEMOVE: + if drawing: + # Compute the top-left and bottom-right corners regardless of drag direction + top = min(iy, y) + left = min(ix, x) + bottom = max(iy, y) + right = max(ix, x) + # Show a temporary image with the current rectangle drawn + temp = working_img.copy() + cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", temp) + + elif event == cv2.EVENT_LBUTTONUP: + # Finish drawing + drawing = False + top = min(iy, y) + left = min(ix, x) + bottom = max(iy, y) + right = max(ix, x) + height = bottom - top + width = right - left + roi = (top, left, height, width) # (top, left, height, width) + # Draw the final rectangle on the working image and display it + working_img = clone.copy() + cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", working_img) + + # Create the window and set the callback + cv2.namedWindow("Select ROI") + cv2.setMouseCallback("Select ROI", mouse_callback) + cv2.imshow("Select ROI", working_img) + + print("Instructions for ROI selection:") + print(" - Click and drag to draw a rectangular ROI.") + print(" - Press 'c' to confirm the selection.") + print(" - Press 'r' to reset and draw again.") + print(" - Press ESC to cancel the selection.") + + # Wait until the user confirms with 'c', resets with 'r', or cancels with ESC + while True: + key = cv2.waitKey(1) & 0xFF + # Confirm ROI if one has been drawn + if key == ord("c") and roi is not None: + break + # Reset: clear the ROI and restore the original image + elif key == ord("r"): + working_img = clone.copy() + roi = None + cv2.imshow("Select ROI", working_img) + # Cancel selection for this image + elif key == 27: # ESC key + roi = None + break + + cv2.destroyWindow("Select ROI") + return roi + + +def select_square_roi_for_images(images: dict) -> dict: + """ + For each image in the provided dictionary, open a window to allow the user + to select a rectangular ROI. Returns a dictionary mapping each key to a tuple + (top, left, height, width) representing the ROI. + + Parameters: + images (dict): Dictionary where keys are identifiers and values are OpenCV images. + + Returns: + dict: Mapping of image keys to the selected rectangular ROI. + """ + selected_rois = {} + + for key, img in images.items(): + if img is None: + print(f"Image for key '{key}' is None, skipping.") + continue + + print(f"\nSelect rectangular ROI for image with key: '{key}'") + roi = select_rect_roi(img) + + if roi is None: + print(f"No valid ROI selected for '{key}'.") + else: + selected_rois[key] = roi + print(f"ROI for '{key}': {roi}") + + return selected_rois + + +def get_image_from_lerobot_dataset(dataset: LeRobotDataset): + """ + Find the first row in the dataset and extract the image in order to be used for the crop. + """ + row = dataset[0] + image_dict = {} + for k in row: + if "image" in k: + image_dict[k] = deepcopy(row[k]) + return image_dict + + +def convert_lerobot_dataset_to_cropper_lerobot_dataset( + original_dataset: LeRobotDataset, + crop_params_dict: Dict[str, Tuple[int, int, int, int]], + new_repo_id: str, + new_dataset_root: str, + resize_size: Tuple[int, int] = (128, 128), +) -> LeRobotDataset: + """ + Converts an existing LeRobotDataset by iterating over its episodes and frames, + applying cropping and resizing to image observations, and saving a new dataset + with the transformed data. + + Args: + original_dataset (LeRobotDataset): The source dataset. + crop_params_dict (Dict[str, Tuple[int, int, int, int]]): + A dictionary mapping observation keys to crop parameters (top, left, height, width). + new_repo_id (str): Repository id for the new dataset. + new_dataset_root (str): The root directory where the new dataset will be written. + resize_size (Tuple[int, int], optional): The target size (height, width) after cropping. + Defaults to (128, 128). + + Returns: + LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped + and resized. + """ + # 1. Create a new (empty) LeRobotDataset for writing. + new_dataset = LeRobotDataset.create( + repo_id=new_repo_id, + fps=original_dataset.fps, + root=new_dataset_root, + robot_type=original_dataset.meta.robot_type, + features=original_dataset.meta.info["features"], + use_videos=len(original_dataset.meta.video_keys) > 0, + ) + + # Update the metadata for every image key that will be cropped: + # (Here we simply set the shape to be the final resize_size.) + for key in crop_params_dict: + if key in new_dataset.meta.info["features"]: + new_dataset.meta.info["features"][key]["shape"] = list(resize_size) + + # 2. Process each episode in the original dataset. + episodes_info = original_dataset.meta.episodes + # (Sort episodes by episode_index for consistency.) + episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"]) + + for ep in tqdm(episodes_info[:3], desc="Processing episodes"): + ep_index = ep.pop("episode_index") + # Use the first task from the episode metadata (or "unknown" if not provided) + task = ep["tasks"][0] if ep.get("tasks") else "unknown" + + # Reset the episode buffer in the new dataset (this will store frames for one episode). + new_dataset.episode_buffer = new_dataset.create_episode_buffer(episode_index=ep_index) + + # 3. Filter and process all frames belonging to this episode. + # Here we loop over the entire dataset and select the frames with the matching episode_index. + # (Depending on the dataset size, you might want a more efficient method.) + ep_frames = [sample for sample in original_dataset if sample["episode_index"] == ep_index] + + for sample in tqdm(ep_frames): + sample.pop("episode_index") + sample.pop("frame_index") + # Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable) + new_sample = sample.copy() + # Loop over each observation key that should be cropped/resized. + for key, params in crop_params_dict.items(): + if key in new_sample: + top, left, height, width = params + # Apply crop then resize. + cropped = F.crop(new_sample[key], top, left, height, width) + resized = F.resize(cropped, resize_size) + new_sample[key] = resized + # Add the transformed frame to the new dataset. + new_dataset.add_frame(new_sample) + + # 4. Save the episode (this writes the parquet file and image files). + new_dataset.save_episode(task, encode_videos=True) + + # Optionally, consolidate the new dataset to compute statistics and update video info. + new_dataset.consolidate(run_compute_stats=True, keep_image_files=True) + + return new_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.") + parser.add_argument( + "--repo-id", + type=str, + default="lerobot", + help="The repository id of the LeRobot dataset to process.", + ) + parser.add_argument( + "--root", + type=str, + default=None, + help="The root directory of the LeRobot dataset.", + ) + args = parser.parse_args() + + dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root) + + images = get_image_from_lerobot_dataset(dataset) + images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()} + images = {k: (v * 255).astype("uint8") for k, v in images.items()} + + rois = select_square_roi_for_images(images) + + # Print the selected rectangular ROIs + print("\nSelected Rectangular Regions of Interest (top, left, height, width):") + for key, roi in rois.items(): + print(f"{key}: {roi}") + croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset( + original_dataset=dataset, + crop_params_dict=rois, + new_repo_id=args.repo_id + "_cropped_resized", + new_dataset_root="data/" + args.repo_id + "_cropped_resized", + resize_size=(128, 128), + ) diff --git a/lerobot/scripts/server/crop_roi.py b/lerobot/scripts/server/crop_roi.py deleted file mode 100644 index f00f3eb6..00000000 --- a/lerobot/scripts/server/crop_roi.py +++ /dev/null @@ -1,148 +0,0 @@ -import cv2 - -from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera - - -def select_square_roi(img): - """ - Allows the user to draw a square ROI on the image. - - The user must click and drag to draw the square. - - While dragging, the square is dynamically drawn. - - On mouse button release, the square is fixed. - - Press 'c' to confirm the selection. - - Press 'r' to reset the selection. - - Press ESC to cancel. - - Returns: - A tuple (top, left, height, width) representing the square ROI, - or None if no valid ROI is selected. - """ - # Create a working copy of the image - clone = img.copy() - working_img = clone.copy() - - roi = None # Will store the final ROI as (top, left, side, side) - drawing = False - ix, iy = -1, -1 # Initial click coordinates - - def mouse_callback(event, x, y, flags, param): - nonlocal ix, iy, drawing, roi, working_img - - if event == cv2.EVENT_LBUTTONDOWN: - # Start drawing: record starting coordinates - drawing = True - ix, iy = x, y - - elif event == cv2.EVENT_MOUSEMOVE: - if drawing: - # Compute side length as the minimum of horizontal/vertical drags - side = min(abs(x - ix), abs(y - iy)) - # Determine the direction to draw (in case of dragging to top/left) - dx = side if x >= ix else -side - dy = side if y >= iy else -side - # Show a temporary image with the current square drawn - temp = working_img.copy() - cv2.rectangle(temp, (ix, iy), (ix + dx, iy + dy), (0, 255, 0), 2) - cv2.imshow("Select ROI", temp) - - elif event == cv2.EVENT_LBUTTONUP: - # Finish drawing - drawing = False - side = min(abs(x - ix), abs(y - iy)) - dx = side if x >= ix else -side - dy = side if y >= iy else -side - # Normalize coordinates: (top, left) is the minimum of the two points - x1 = min(ix, ix + dx) - y1 = min(iy, iy + dy) - roi = (y1, x1, side, side) # (top, left, height, width) - # Draw the final square on the working image and display it - working_img = clone.copy() - cv2.rectangle(working_img, (ix, iy), (ix + dx, iy + dy), (0, 255, 0), 2) - cv2.imshow("Select ROI", working_img) - - # Create the window and set the callback - cv2.namedWindow("Select ROI") - cv2.setMouseCallback("Select ROI", mouse_callback) - cv2.imshow("Select ROI", working_img) - - print("Instructions for ROI selection:") - print(" - Click and drag to draw a square ROI.") - print(" - Press 'c' to confirm the selection.") - print(" - Press 'r' to reset and draw again.") - print(" - Press ESC to cancel the selection.") - - # Wait until the user confirms with 'c', resets with 'r', or cancels with ESC - while True: - key = cv2.waitKey(1) & 0xFF - # Confirm ROI if one has been drawn - if key == ord("c") and roi is not None: - break - # Reset: clear the ROI and restore the original image - elif key == ord("r"): - working_img = clone.copy() - roi = None - cv2.imshow("Select ROI", working_img) - # Cancel selection for this image - elif key == 27: # ESC key - roi = None - break - - cv2.destroyWindow("Select ROI") - return roi - - -def select_square_roi_for_images(images: dict) -> dict: - """ - For each image in the provided dictionary, open a window to allow the user - to select a square ROI. Returns a dictionary mapping each key to a tuple - (top, left, height, width) representing the ROI. - - Parameters: - images (dict): Dictionary where keys are identifiers and values are OpenCV images. - - Returns: - dict: Mapping of image keys to the selected square ROI. - """ - selected_rois = {} - - for key, img in images.items(): - if img is None: - print(f"Image for key '{key}' is None, skipping.") - continue - - print(f"\nSelect square ROI for image with key: '{key}'") - roi = select_square_roi(img) - - if roi is None: - print(f"No valid ROI selected for '{key}'.") - else: - selected_rois[key] = roi - print(f"ROI for '{key}': {roi}") - - return selected_rois - - -if __name__ == "__main__": - # Example usage: - # Replace 'image1.jpg' and 'image2.jpg' with valid paths to your image files. - fps = [5, 30] - cameras = [OpenCVCamera(i, fps=fps[i], width=640, height=480, mock=False) for i in range(2)] - [camera.connect() for camera in cameras] - - image_keys = ["image_" + str(i) for i in range(len(cameras))] - - images = {image_keys[i]: cameras[i].read() for i in range(len(cameras))} - - # Verify images loaded correctly - for key, img in images.items(): - if img is None: - raise ValueError(f"Failed to load image for key '{key}'. Check the file path.") - - # Let the user select a square ROI for each image - rois = select_square_roi_for_images(images) - - # Print the selected square ROIs - print("\nSelected Square Regions of Interest (top, left, height, width):") - for key, roi in rois.items(): - print(f"{key}: {roi}")