Added crop_dataset_roi.py that allows you to load a lerobotdataset -> crop its images -> create a new lerobot dataset with the cropped and resized images.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi 2025-02-03 17:48:35 +00:00
parent 2211209be5
commit efb1982eec
3 changed files with 268 additions and 150 deletions

View File

@ -36,7 +36,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
def log_dt(shortname, dt_val_s):
nonlocal log_items, fps
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
if fps is not None:
actual_fps = 1 / dt_val_s
if actual_fps < fps - 1:
@ -335,7 +335,9 @@ def reset_environment(robot, events, reset_time_s):
def reset_follower_position(robot: Robot, target_position):
current_position = robot.follower_arms["main"].read("Present_Position")
trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30)) # NOTE: 30 is just an aribtrary number
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 30)
) # NOTE: 30 is just an aribtrary number
for pose in trajectory:
robot.send_action(pose)
busy_wait(0.015)

View File

@ -0,0 +1,264 @@
import argparse # noqa: I001
from copy import deepcopy
from typing import Dict, Tuple
import cv2
# import torch.nn.functional as F # noqa: N812
import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def select_rect_roi(img):
"""
Allows the user to draw a rectangular ROI on the image.
The user must click and drag to draw the rectangle.
- While dragging, the rectangle is dynamically drawn.
- On mouse button release, the rectangle is fixed.
- Press 'c' to confirm the selection.
- Press 'r' to reset the selection.
- Press ESC to cancel.
Returns:
A tuple (top, left, height, width) representing the rectangular ROI,
or None if no valid ROI is selected.
"""
# Create a working copy of the image
clone = img.copy()
working_img = clone.copy()
roi = None # Will store the final ROI as (top, left, height, width)
drawing = False
ix, iy = -1, -1 # Initial click coordinates
def mouse_callback(event, x, y, flags, param):
nonlocal ix, iy, drawing, roi, working_img
if event == cv2.EVENT_LBUTTONDOWN:
# Start drawing: record starting coordinates
drawing = True
ix, iy = x, y
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
# Compute the top-left and bottom-right corners regardless of drag direction
top = min(iy, y)
left = min(ix, x)
bottom = max(iy, y)
right = max(ix, x)
# Show a temporary image with the current rectangle drawn
temp = working_img.copy()
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", temp)
elif event == cv2.EVENT_LBUTTONUP:
# Finish drawing
drawing = False
top = min(iy, y)
left = min(ix, x)
bottom = max(iy, y)
right = max(ix, x)
height = bottom - top
width = right - left
roi = (top, left, height, width) # (top, left, height, width)
# Draw the final rectangle on the working image and display it
working_img = clone.copy()
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", working_img)
# Create the window and set the callback
cv2.namedWindow("Select ROI")
cv2.setMouseCallback("Select ROI", mouse_callback)
cv2.imshow("Select ROI", working_img)
print("Instructions for ROI selection:")
print(" - Click and drag to draw a rectangular ROI.")
print(" - Press 'c' to confirm the selection.")
print(" - Press 'r' to reset and draw again.")
print(" - Press ESC to cancel the selection.")
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
while True:
key = cv2.waitKey(1) & 0xFF
# Confirm ROI if one has been drawn
if key == ord("c") and roi is not None:
break
# Reset: clear the ROI and restore the original image
elif key == ord("r"):
working_img = clone.copy()
roi = None
cv2.imshow("Select ROI", working_img)
# Cancel selection for this image
elif key == 27: # ESC key
roi = None
break
cv2.destroyWindow("Select ROI")
return roi
def select_square_roi_for_images(images: dict) -> dict:
"""
For each image in the provided dictionary, open a window to allow the user
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
(top, left, height, width) representing the ROI.
Parameters:
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
Returns:
dict: Mapping of image keys to the selected rectangular ROI.
"""
selected_rois = {}
for key, img in images.items():
if img is None:
print(f"Image for key '{key}' is None, skipping.")
continue
print(f"\nSelect rectangular ROI for image with key: '{key}'")
roi = select_rect_roi(img)
if roi is None:
print(f"No valid ROI selected for '{key}'.")
else:
selected_rois[key] = roi
print(f"ROI for '{key}': {roi}")
return selected_rois
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
"""
Find the first row in the dataset and extract the image in order to be used for the crop.
"""
row = dataset[0]
image_dict = {}
for k in row:
if "image" in k:
image_dict[k] = deepcopy(row[k])
return image_dict
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset: LeRobotDataset,
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
new_repo_id: str,
new_dataset_root: str,
resize_size: Tuple[int, int] = (128, 128),
) -> LeRobotDataset:
"""
Converts an existing LeRobotDataset by iterating over its episodes and frames,
applying cropping and resizing to image observations, and saving a new dataset
with the transformed data.
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
and resized.
"""
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
fps=original_dataset.fps,
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
use_videos=len(original_dataset.meta.video_keys) > 0,
)
# Update the metadata for every image key that will be cropped:
# (Here we simply set the shape to be the final resize_size.)
for key in crop_params_dict:
if key in new_dataset.meta.info["features"]:
new_dataset.meta.info["features"][key]["shape"] = list(resize_size)
# 2. Process each episode in the original dataset.
episodes_info = original_dataset.meta.episodes
# (Sort episodes by episode_index for consistency.)
episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
for ep in tqdm(episodes_info[:3], desc="Processing episodes"):
ep_index = ep.pop("episode_index")
# Use the first task from the episode metadata (or "unknown" if not provided)
task = ep["tasks"][0] if ep.get("tasks") else "unknown"
# Reset the episode buffer in the new dataset (this will store frames for one episode).
new_dataset.episode_buffer = new_dataset.create_episode_buffer(episode_index=ep_index)
# 3. Filter and process all frames belonging to this episode.
# Here we loop over the entire dataset and select the frames with the matching episode_index.
# (Depending on the dataset size, you might want a more efficient method.)
ep_frames = [sample for sample in original_dataset if sample["episode_index"] == ep_index]
for sample in tqdm(ep_frames):
sample.pop("episode_index")
sample.pop("frame_index")
# Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
new_sample = sample.copy()
# Loop over each observation key that should be cropped/resized.
for key, params in crop_params_dict.items():
if key in new_sample:
top, left, height, width = params
# Apply crop then resize.
cropped = F.crop(new_sample[key], top, left, height, width)
resized = F.resize(cropped, resize_size)
new_sample[key] = resized
# Add the transformed frame to the new dataset.
new_dataset.add_frame(new_sample)
# 4. Save the episode (this writes the parquet file and image files).
new_dataset.save_episode(task, encode_videos=True)
# Optionally, consolidate the new dataset to compute statistics and update video info.
new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)
return new_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser.add_argument(
"--repo-id",
type=str,
default="lerobot",
help="The repository id of the LeRobot dataset to process.",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="The root directory of the LeRobot dataset.",
)
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
rois = select_square_roi_for_images(images)
# Print the selected rectangular ROIs
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")
croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=args.repo_id + "_cropped_resized",
new_dataset_root="data/" + args.repo_id + "_cropped_resized",
resize_size=(128, 128),
)

View File

@ -1,148 +0,0 @@
import cv2
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
def select_square_roi(img):
"""
Allows the user to draw a square ROI on the image.
The user must click and drag to draw the square.
- While dragging, the square is dynamically drawn.
- On mouse button release, the square is fixed.
- Press 'c' to confirm the selection.
- Press 'r' to reset the selection.
- Press ESC to cancel.
Returns:
A tuple (top, left, height, width) representing the square ROI,
or None if no valid ROI is selected.
"""
# Create a working copy of the image
clone = img.copy()
working_img = clone.copy()
roi = None # Will store the final ROI as (top, left, side, side)
drawing = False
ix, iy = -1, -1 # Initial click coordinates
def mouse_callback(event, x, y, flags, param):
nonlocal ix, iy, drawing, roi, working_img
if event == cv2.EVENT_LBUTTONDOWN:
# Start drawing: record starting coordinates
drawing = True
ix, iy = x, y
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
# Compute side length as the minimum of horizontal/vertical drags
side = min(abs(x - ix), abs(y - iy))
# Determine the direction to draw (in case of dragging to top/left)
dx = side if x >= ix else -side
dy = side if y >= iy else -side
# Show a temporary image with the current square drawn
temp = working_img.copy()
cv2.rectangle(temp, (ix, iy), (ix + dx, iy + dy), (0, 255, 0), 2)
cv2.imshow("Select ROI", temp)
elif event == cv2.EVENT_LBUTTONUP:
# Finish drawing
drawing = False
side = min(abs(x - ix), abs(y - iy))
dx = side if x >= ix else -side
dy = side if y >= iy else -side
# Normalize coordinates: (top, left) is the minimum of the two points
x1 = min(ix, ix + dx)
y1 = min(iy, iy + dy)
roi = (y1, x1, side, side) # (top, left, height, width)
# Draw the final square on the working image and display it
working_img = clone.copy()
cv2.rectangle(working_img, (ix, iy), (ix + dx, iy + dy), (0, 255, 0), 2)
cv2.imshow("Select ROI", working_img)
# Create the window and set the callback
cv2.namedWindow("Select ROI")
cv2.setMouseCallback("Select ROI", mouse_callback)
cv2.imshow("Select ROI", working_img)
print("Instructions for ROI selection:")
print(" - Click and drag to draw a square ROI.")
print(" - Press 'c' to confirm the selection.")
print(" - Press 'r' to reset and draw again.")
print(" - Press ESC to cancel the selection.")
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
while True:
key = cv2.waitKey(1) & 0xFF
# Confirm ROI if one has been drawn
if key == ord("c") and roi is not None:
break
# Reset: clear the ROI and restore the original image
elif key == ord("r"):
working_img = clone.copy()
roi = None
cv2.imshow("Select ROI", working_img)
# Cancel selection for this image
elif key == 27: # ESC key
roi = None
break
cv2.destroyWindow("Select ROI")
return roi
def select_square_roi_for_images(images: dict) -> dict:
"""
For each image in the provided dictionary, open a window to allow the user
to select a square ROI. Returns a dictionary mapping each key to a tuple
(top, left, height, width) representing the ROI.
Parameters:
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
Returns:
dict: Mapping of image keys to the selected square ROI.
"""
selected_rois = {}
for key, img in images.items():
if img is None:
print(f"Image for key '{key}' is None, skipping.")
continue
print(f"\nSelect square ROI for image with key: '{key}'")
roi = select_square_roi(img)
if roi is None:
print(f"No valid ROI selected for '{key}'.")
else:
selected_rois[key] = roi
print(f"ROI for '{key}': {roi}")
return selected_rois
if __name__ == "__main__":
# Example usage:
# Replace 'image1.jpg' and 'image2.jpg' with valid paths to your image files.
fps = [5, 30]
cameras = [OpenCVCamera(i, fps=fps[i], width=640, height=480, mock=False) for i in range(2)]
[camera.connect() for camera in cameras]
image_keys = ["image_" + str(i) for i in range(len(cameras))]
images = {image_keys[i]: cameras[i].read() for i in range(len(cameras))}
# Verify images loaded correctly
for key, img in images.items():
if img is None:
raise ValueError(f"Failed to load image for key '{key}'. Check the file path.")
# Let the user select a square ROI for each image
rois = select_square_roi_for_images(images)
# Print the selected square ROIs
print("\nSelected Square Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")