Modified the crop_dataset_roi interface to automatically write the crop parameters to a JSON file in the dataset's meta/ directory

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi 2025-02-14 12:32:45 +01:00
parent c9e50bb9b1
commit 36711d766a
4 changed files with 42 additions and 20 deletions

View File

@@ -84,7 +84,8 @@ class LeRobotDatasetMetadata:
# Load metadata
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
if not self.local_files_only:
self.pull_from_repo(allow_patterns="meta/")
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
@@ -537,9 +538,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
]
files += video_files
# HACK: UNCOMMENT IF YOU REVIEW THAT, PLEASE SUGGEST TO UNCOMMENT
logging.warning("HACK: WE COMMENT THIS LINE, IF SOMETHING IS WEIRD WITH DATASETS UNCOMMENT")
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
if not self.local_files_only:
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""

View File

@@ -77,8 +77,11 @@ policy:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.state:
min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
# min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
# max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
@@ -96,7 +99,7 @@ policy:
# Neural networks.
image_encoder_hidden_dim: 32
# discount: 0.99
discount: 0.80
discount: 0.97
temperature_init: 1.0
num_critics: 2 #10
camera_number: 2

View File

@@ -181,10 +181,10 @@ class ReplayBuffer:
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
action = action.to(self.storage_device)
if complementary_info is not None:
complementary_info = {
key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
}
# if complementary_info is not None:
# complementary_info = {
# key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
# }
if len(self.memory) < self.capacity:
self.memory.append(None)

View File

@@ -1,7 +1,8 @@
import argparse # noqa: I001
import json
from copy import deepcopy
from typing import Dict, Tuple
from pathlib import Path
import cv2
# import torch.nn.functional as F # noqa: N812
@@ -237,19 +238,27 @@ if __name__ == "__main__":
default=None,
help="The root directory of the LeRobot dataset.",
)
parser.add_argument(
"--crop-params-path",
type=str,
default=None,
help="The path to the JSON file containing the ROIs.",
)
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=False)
local_files_only = args.root is not None
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
# rois = select_square_roi_for_images(images)
rois = {
"observation.images.front": [102, 43, 358, 523],
"observation.images.side": [92, 123, 379, 349],
}
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path, "r") as f:
rois = json.load(f)
# rois = {
# "observation.images.side": (92, 123, 379, 349),
# "observation.images.front": (109, 37, 361, 557),
@@ -263,10 +272,20 @@ if __name__ == "__main__":
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")
new_repo_id = args.repo_id + "_cropped_resized"
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=args.repo_id + "_cropped_resized",
new_dataset_root="data/" + args.repo_id + "_cropped_resized",
new_repo_id=new_repo_id,
new_dataset_root=new_dataset_root,
resize_size=(128, 128),
)
meta_dir = new_dataset_root / "meta"
meta_dir.mkdir(exist_ok=True)
with open(meta_dir / "crop_params.json", "w") as f:
json.dump(rois, f, indent=4)