Modified crop_dataset_roi interface to automatically write the cropped parameters to a json file in the meta of the dataset
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
c9e50bb9b1
commit
36711d766a
|
@ -84,6 +84,7 @@ class LeRobotDatasetMetadata:
|
||||||
|
|
||||||
# Load metadata
|
# Load metadata
|
||||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||||
|
if not self.local_files_only:
|
||||||
self.pull_from_repo(allow_patterns="meta/")
|
self.pull_from_repo(allow_patterns="meta/")
|
||||||
self.info = load_info(self.root)
|
self.info = load_info(self.root)
|
||||||
self.stats = load_stats(self.root)
|
self.stats = load_stats(self.root)
|
||||||
|
@ -537,8 +538,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
]
|
]
|
||||||
files += video_files
|
files += video_files
|
||||||
|
|
||||||
# HACK: UNCOMMENT IF YOU REVIEW THAT, PLEASE SUGGEST TO UNCOMMENT
|
if not self.local_files_only:
|
||||||
logging.warning("HACK: WE COMMENT THIS LINE, IF SOMETHING IS WEIRD WITH DATASETS UNCOMMENT")
|
|
||||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||||
|
|
||||||
def load_hf_dataset(self) -> datasets.Dataset:
|
def load_hf_dataset(self) -> datasets.Dataset:
|
||||||
|
|
|
@ -77,8 +77,11 @@ policy:
|
||||||
mean: [0.485, 0.456, 0.406]
|
mean: [0.485, 0.456, 0.406]
|
||||||
std: [0.229, 0.224, 0.225]
|
std: [0.229, 0.224, 0.225]
|
||||||
observation.state:
|
observation.state:
|
||||||
min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
|
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
||||||
max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
|
max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
||||||
|
|
||||||
|
# min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
|
||||||
|
# max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
|
||||||
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
|
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
|
||||||
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
|
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
|
||||||
|
|
||||||
|
@ -96,7 +99,7 @@ policy:
|
||||||
# Neural networks.
|
# Neural networks.
|
||||||
image_encoder_hidden_dim: 32
|
image_encoder_hidden_dim: 32
|
||||||
# discount: 0.99
|
# discount: 0.99
|
||||||
discount: 0.80
|
discount: 0.97
|
||||||
temperature_init: 1.0
|
temperature_init: 1.0
|
||||||
num_critics: 2 #10
|
num_critics: 2 #10
|
||||||
camera_number: 2
|
camera_number: 2
|
||||||
|
|
|
@ -181,10 +181,10 @@ class ReplayBuffer:
|
||||||
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
|
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
|
||||||
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
|
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
|
||||||
action = action.to(self.storage_device)
|
action = action.to(self.storage_device)
|
||||||
if complementary_info is not None:
|
# if complementary_info is not None:
|
||||||
complementary_info = {
|
# complementary_info = {
|
||||||
key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
|
# key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
|
||||||
}
|
# }
|
||||||
|
|
||||||
if len(self.memory) < self.capacity:
|
if len(self.memory) < self.capacity:
|
||||||
self.memory.append(None)
|
self.memory.append(None)
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import argparse # noqa: I001
|
import argparse # noqa: I001
|
||||||
|
import json
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
from pathlib import Path
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
# import torch.nn.functional as F # noqa: N812
|
# import torch.nn.functional as F # noqa: N812
|
||||||
|
@ -237,19 +238,27 @@ if __name__ == "__main__":
|
||||||
default=None,
|
default=None,
|
||||||
help="The root directory of the LeRobot dataset.",
|
help="The root directory of the LeRobot dataset.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--crop-params-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The path to the JSON file containing the ROIs.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=False)
|
local_files_only = args.root is not None
|
||||||
|
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
|
||||||
|
|
||||||
images = get_image_from_lerobot_dataset(dataset)
|
images = get_image_from_lerobot_dataset(dataset)
|
||||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||||
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
||||||
|
|
||||||
# rois = select_square_roi_for_images(images)
|
if args.crop_params_path is None:
|
||||||
rois = {
|
rois = select_square_roi_for_images(images)
|
||||||
"observation.images.front": [102, 43, 358, 523],
|
else:
|
||||||
"observation.images.side": [92, 123, 379, 349],
|
with open(args.crop_params_path, "r") as f:
|
||||||
}
|
rois = json.load(f)
|
||||||
|
|
||||||
# rois = {
|
# rois = {
|
||||||
# "observation.images.side": (92, 123, 379, 349),
|
# "observation.images.side": (92, 123, 379, 349),
|
||||||
# "observation.images.front": (109, 37, 361, 557),
|
# "observation.images.front": (109, 37, 361, 557),
|
||||||
|
@ -263,10 +272,20 @@ if __name__ == "__main__":
|
||||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||||
for key, roi in rois.items():
|
for key, roi in rois.items():
|
||||||
print(f"{key}: {roi}")
|
print(f"{key}: {roi}")
|
||||||
|
|
||||||
|
new_repo_id = args.repo_id + "_cropped_resized"
|
||||||
|
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||||
|
|
||||||
croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||||
original_dataset=dataset,
|
original_dataset=dataset,
|
||||||
crop_params_dict=rois,
|
crop_params_dict=rois,
|
||||||
new_repo_id=args.repo_id + "_cropped_resized",
|
new_repo_id=new_repo_id,
|
||||||
new_dataset_root="data/" + args.repo_id + "_cropped_resized",
|
new_dataset_root=new_dataset_root,
|
||||||
resize_size=(128, 128),
|
resize_size=(128, 128),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
meta_dir = new_dataset_root / "meta"
|
||||||
|
meta_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
with open(meta_dir / "crop_params.json", "w") as f:
|
||||||
|
json.dump(rois, f, indent=4)
|
||||||
|
|
Loading…
Reference in New Issue