Modified crop_dataset_roi interface to automatically write the cropped parameters to a json file in the meta of the dataset
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
c9e50bb9b1
commit
36711d766a
|
@ -84,7 +84,8 @@ class LeRobotDatasetMetadata:
|
|||
|
||||
# Load metadata
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
if not self.local_files_only:
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.info = load_info(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
self.tasks = load_tasks(self.root)
|
||||
|
@ -537,9 +538,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
]
|
||||
files += video_files
|
||||
|
||||
# HACK: UNCOMMENT IF YOU REVIEW THAT, PLEASE SUGGEST TO UNCOMMENT
|
||||
logging.warning("HACK: WE COMMENT THIS LINE, IF SOMETHING IS WEIRD WITH DATASETS UNCOMMENT")
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
if not self.local_files_only:
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
|
|
|
@ -77,8 +77,11 @@ policy:
|
|||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
observation.state:
|
||||
min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
|
||||
max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
|
||||
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
||||
max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
||||
|
||||
# min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
|
||||
# max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
|
||||
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
|
||||
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
|
||||
|
||||
|
@ -96,7 +99,7 @@ policy:
|
|||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
# discount: 0.99
|
||||
discount: 0.80
|
||||
discount: 0.97
|
||||
temperature_init: 1.0
|
||||
num_critics: 2 #10
|
||||
camera_number: 2
|
||||
|
|
|
@ -181,10 +181,10 @@ class ReplayBuffer:
|
|||
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
|
||||
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
|
||||
action = action.to(self.storage_device)
|
||||
if complementary_info is not None:
|
||||
complementary_info = {
|
||||
key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
|
||||
}
|
||||
# if complementary_info is not None:
|
||||
# complementary_info = {
|
||||
# key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
|
||||
# }
|
||||
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(None)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import argparse # noqa: I001
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
|
||||
# import torch.nn.functional as F # noqa: N812
|
||||
|
@ -237,19 +238,27 @@ if __name__ == "__main__":
|
|||
default=None,
|
||||
help="The root directory of the LeRobot dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-params-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the JSON file containing the ROIs.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=False)
|
||||
local_files_only = args.root is not None
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
||||
|
||||
# rois = select_square_roi_for_images(images)
|
||||
rois = {
|
||||
"observation.images.front": [102, 43, 358, 523],
|
||||
"observation.images.side": [92, 123, 379, 349],
|
||||
}
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path, "r") as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# rois = {
|
||||
# "observation.images.side": (92, 123, 379, 349),
|
||||
# "observation.images.front": (109, 37, 361, 557),
|
||||
|
@ -263,10 +272,20 @@ if __name__ == "__main__":
|
|||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||
for key, roi in rois.items():
|
||||
print(f"{key}: {roi}")
|
||||
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=args.repo_id + "_cropped_resized",
|
||||
new_dataset_root="data/" + args.repo_id + "_cropped_resized",
|
||||
new_repo_id=new_repo_id,
|
||||
new_dataset_root=new_dataset_root,
|
||||
resize_size=(128, 128),
|
||||
)
|
||||
|
||||
meta_dir = new_dataset_root / "meta"
|
||||
meta_dir.mkdir(exist_ok=True)
|
||||
|
||||
with open(meta_dir / "crop_params.json", "w") as f:
|
||||
json.dump(rois, f, indent=4)
|
||||
|
|
Loading…
Reference in New Issue