273 lines
9.9 KiB
Python
273 lines
9.9 KiB
Python
import argparse # noqa: I001
|
|
from copy import deepcopy
|
|
from typing import Dict, Tuple
|
|
|
|
import cv2
|
|
|
|
# import torch.nn.functional as F # noqa: N812
|
|
import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
|
from tqdm import tqdm # type: ignore
|
|
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
|
|
|
|
def select_rect_roi(img):
    """
    Let the user draw a rectangular ROI on ``img`` in an OpenCV window.

    Interaction:
      - Click and drag to draw the rectangle; it is redrawn live while dragging.
      - Releasing the mouse button fixes the rectangle. A click without a drag
        (zero width or height) is discarded.
      - Press 'c' to confirm the selection (ignored until a non-empty rectangle exists).
      - Press 'r' to reset the selection.
      - Press ESC to cancel.

    Args:
        img: An image accepted by ``cv2.imshow`` / ``cv2.rectangle``. Only ``img.copy()``
            and ``img.shape[:2]`` (height, width) are used directly.

    Returns:
        A tuple (top, left, height, width) in pixel coordinates, with the corners
        clamped to the image bounds, or None if the selection was cancelled.
    """
    window = "Select ROI"

    # Pristine copy used as the background for every redraw.
    clone = img.copy()
    img_h, img_w = clone.shape[:2]
    working_img = clone.copy()

    roi = None  # Final ROI as (top, left, height, width)
    drawing = False
    ix, iy = -1, -1  # Coordinates of the initial click

    def clamp_point(x, y):
        # OpenCV can report positions outside the window while a drag is in
        # progress; keep both corners on the image so the ROI stays valid.
        return min(max(x, 0), img_w), min(max(y, 0), img_h)

    def corners(x0, y0, x1, y1):
        # Normalise two arbitrary drag points to (top, left, bottom, right),
        # independent of the drag direction.
        return min(y0, y1), min(x0, x1), max(y0, y1), max(x0, x1)

    def mouse_callback(event, x, y, flags, param):
        nonlocal ix, iy, drawing, roi, working_img
        x, y = clamp_point(x, y)

        if event == cv2.EVENT_LBUTTONDOWN:
            # Start drawing: record the anchor corner.
            drawing = True
            ix, iy = x, y

        elif event == cv2.EVENT_MOUSEMOVE and drawing:
            top, left, bottom, right = corners(ix, iy, x, y)
            # Preview on a clean copy so a previously fixed rectangle is not left on screen.
            preview = clone.copy()
            cv2.rectangle(preview, (left, top), (right, bottom), (0, 255, 0), 2)
            cv2.imshow(window, preview)

        elif event == cv2.EVENT_LBUTTONUP and drawing:
            # Only finish a rectangle that was actually started in this window.
            drawing = False
            top, left, bottom, right = corners(ix, iy, x, y)
            height, width = bottom - top, right - left
            working_img = clone.copy()
            if height > 0 and width > 0:
                roi = (top, left, height, width)
                cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
            else:
                # A plain click yields an empty rectangle, which cannot be cropped.
                roi = None
            cv2.imshow(window, working_img)

    cv2.namedWindow(window)
    try:
        cv2.setMouseCallback(window, mouse_callback)
        cv2.imshow(window, working_img)

        print("Instructions for ROI selection:")
        print(" - Click and drag to draw a rectangular ROI.")
        print(" - Press 'c' to confirm the selection.")
        print(" - Press 'r' to reset and draw again.")
        print(" - Press ESC to cancel the selection.")

        # Wait until the user confirms with 'c', resets with 'r', or cancels with ESC.
        while True:
            key = cv2.waitKey(1) & 0xFF
            if key == ord("c") and roi is not None:
                break
            elif key == ord("r"):
                # Reset: clear the ROI and restore the original image.
                working_img = clone.copy()
                roi = None
                cv2.imshow(window, working_img)
            elif key == 27:  # ESC key: cancel selection for this image
                roi = None
                break
    finally:
        # Always close the window, even if the loop is interrupted.
        cv2.destroyWindow(window)

    return roi
|
|
|
|
|
|
def select_square_roi_for_images(images: dict) -> dict:
    """
    Interactively pick one rectangular ROI per image.

    Despite the function name, the regions are arbitrary rectangles, not squares.
    Each image is shown in turn through ``select_rect_roi``. Entries whose image
    is None are skipped, and images whose selection is cancelled are left out of
    the result.

    Parameters:
        images (dict): Maps an identifier to an OpenCV-compatible image (or None).

    Returns:
        dict: Maps each identifier with a confirmed selection to its
        (top, left, height, width) tuple.
    """
    rois_by_key = {}

    for name, image in images.items():
        if image is None:
            print(f"Image for key '{name}' is None, skipping.")
            continue

        print(f"\nSelect rectangular ROI for image with key: '{name}'")
        selection = select_rect_roi(image)

        if selection is not None:
            rois_by_key[name] = selection
            print(f"ROI for '{name}': {selection}")
        else:
            print(f"No valid ROI selected for '{name}'.")

    return rois_by_key
|
|
|
|
|
|
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
    """
    Return deep copies of the image entries of the dataset's first sample.

    Every key of ``dataset[0]`` whose name contains the substring "image" is kept.
    The values are deep-copied, so later modifications do not affect the sample
    object returned by the dataset. The result is used to preview frames when
    choosing crop regions.
    """
    first_sample = dataset[0]
    return {name: deepcopy(first_sample[name]) for name in first_sample if "image" in name}
|
|
|
|
|
|
def _resized_feature_shape(feature: dict, resize_size: Tuple[int, int]) -> list:
    """Return ``feature``'s shape with its spatial axes replaced by ``resize_size``."""
    height, width = resize_size
    shape = list(feature.get("shape", ()))
    names = feature.get("names")
    # NOTE(review): assumes image features may carry per-axis ``names`` that include
    # "height" and "width" (e.g. ["height", "width", "channels"]); confirm against the
    # dataset's info.json. When they do, the non-spatial axes (channels) are preserved.
    if (
        isinstance(names, (list, tuple))
        and len(names) == len(shape)
        and "height" in names
        and "width" in names
    ):
        names = list(names)
        shape[names.index("height")] = height
        shape[names.index("width")] = width
        return shape
    # Fallback when no usable axis names exist: record only the spatial size.
    return list(resize_size)


def _episode_task_lookup(episodes_info) -> Dict[int, str]:
    """Map each episode_index to its first task, or "unknown" when none is recorded."""
    return {
        int(episode["episode_index"]): (episode["tasks"][0] if episode.get("tasks") else "unknown")
        for episode in episodes_info
    }


def convert_lerobot_dataset_to_cropper_lerobot_dataset(
    original_dataset: LeRobotDataset,
    crop_params_dict: Dict[str, Tuple[int, int, int, int]],
    new_repo_id: str,
    new_dataset_root: str,
    resize_size: Tuple[int, int] = (128, 128),
    push_to_hub: bool = True,
) -> LeRobotDataset:
    """
    Copy a LeRobotDataset frame by frame, cropping and resizing selected image observations.

    Frames are read from ``original_dataset`` in order. An episode is saved to the new
    dataset each time ``episode_index`` changes, and the last episode is saved after the
    loop. Each episode is saved with the first task listed for it in
    ``original_dataset.meta.episodes``, or "unknown" when none is listed. Afterwards the
    new dataset is consolidated (statistics computed) and, optionally, pushed to the hub.

    Args:
        original_dataset (LeRobotDataset): The source dataset. Its metadata is not modified.
        crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
            Maps observation keys to crop parameters (top, left, height, width).
        new_repo_id (str): Repository id for the new dataset.
        new_dataset_root (str): Root directory where the new dataset is written.
        resize_size (Tuple[int, int], optional): Target (height, width) after cropping.
            Defaults to (128, 128).
        push_to_hub (bool, optional): Whether to push the new dataset to the hub after
            consolidation. Defaults to True (the previous, unconditional behaviour).

    Returns:
        LeRobotDataset: The new dataset, in which the specified image observations have
        been cropped and resized.
    """
    # 1. Create a new (empty) dataset for writing. The feature spec is deep-copied because
    #    the shape edits below would otherwise mutate the source dataset's metadata dicts
    #    if ``create`` keeps a reference to what it is given.
    new_dataset = LeRobotDataset.create(
        repo_id=new_repo_id,
        fps=original_dataset.fps,
        root=new_dataset_root,
        robot_type=original_dataset.meta.robot_type,
        features=deepcopy(original_dataset.meta.info["features"]),
        use_videos=len(original_dataset.meta.video_keys) > 0,
    )

    # Record the post-resize shape for every image key that will be cropped.
    new_features = new_dataset.meta.info["features"]
    for key in crop_params_dict:
        if key in new_features:
            new_features[key]["shape"] = _resized_feature_shape(new_features[key], resize_size)

    # 2. Look up each episode's own task instead of reusing the first episode's task.
    task_by_episode = _episode_task_lookup(original_dataset.meta.episodes)

    # Starts as None (not 0) so no empty episode is saved when the first index is not 0.
    current_episode = None
    for sample in tqdm(original_dataset):
        episode_index = int(sample.pop("episode_index"))
        if current_episode is not None and episode_index != current_episode:
            # The previous episode is complete: flush it with its own task.
            new_dataset.save_episode(task_by_episode.get(current_episode, "unknown"), encode_videos=True)
        current_episode = episode_index
        sample.pop("frame_index")

        # ``sample`` is a fresh dict per index (already mutated by the pops above), so the
        # transformed images are written into it directly.
        for key, (top, left, height, width) in crop_params_dict.items():
            if key in sample:
                cropped = F.crop(sample[key], top, left, height, width)
                sample[key] = F.resize(cropped, resize_size)

        new_dataset.add_frame(sample)

    # Save the last episode, unless the source dataset produced no frames at all.
    if current_episode is not None:
        new_dataset.save_episode(task_by_episode.get(current_episode, "unknown"), encode_videos=True)

    # Compute statistics and finalise the dataset on disk.
    new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)

    if push_to_hub:
        new_dataset.push_to_hub(tags=None)

    return new_dataset
|
|
|
|
|
|
if __name__ == "__main__":
    import json

    # Crop regions used when neither --crop-params nor --interactive is given.
    # Values are (top, left, height, width) in pixels, tuned for one specific camera setup.
    default_rois = {
        "observation.images.front": [102, 43, 358, 523],
        "observation.images.side": [92, 123, 379, 349],
    }

    parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
    parser.add_argument(
        "--repo-id",
        type=str,
        default="lerobot",
        help="The repository id of the LeRobot dataset to process.",
    )
    parser.add_argument(
        "--root",
        type=str,
        default=None,
        help="The root directory of the LeRobot dataset.",
    )
    parser.add_argument(
        "--crop-params",
        type=json.loads,
        default=None,
        help="JSON object mapping image keys to [top, left, height, width]. "
        "Defaults to the built-in ROIs.",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Draw the ROIs by hand on the first frame of each camera (overrides --crop-params).",
    )
    parser.add_argument(
        "--resize-size",
        type=int,
        nargs=2,
        default=[128, 128],
        metavar=("HEIGHT", "WIDTH"),
        help="Output size after cropping.",
    )
    args = parser.parse_args()

    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=True)

    if args.interactive:
        images = get_image_from_lerobot_dataset(dataset)
        # CHW float tensors -> HWC uint8 arrays for OpenCV. Clamping to [0, 1] prevents
        # uint8 wrap-around for values slightly out of range.
        images = {
            k: (v.cpu().clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype("uint8")
            for k, v in images.items()
        }
        # NOTE(review): assumes the dataset yields RGB frames; OpenCV displays BGR,
        # so 3-channel images are converted for a faithful preview.
        images = {
            k: cv2.cvtColor(v, cv2.COLOR_RGB2BGR) if v.ndim == 3 and v.shape[-1] == 3 else v
            for k, v in images.items()
        }
        rois = select_square_roi_for_images(images)
    elif args.crop_params is not None:
        rois = args.crop_params
    else:
        rois = default_rois

    if not rois:
        parser.error("No ROI selected; nothing to crop.")

    print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
    for key, roi in rois.items():
        print(f"{key}: {roi}")

    output_name = f"{args.repo_id}_cropped_resized"
    convert_lerobot_dataset_to_cropper_lerobot_dataset(
        original_dataset=dataset,
        crop_params_dict=rois,
        new_repo_id=output_name,
        new_dataset_root=f"data/{output_name}",
        resize_size=tuple(args.resize_size),
    )
|