From 36711d766a77eba30591aeca3ff303c228df1895 Mon Sep 17 00:00:00 2001
From: Michel Aractingi <michel.aractingi@huggingface.co>
Date: Fri, 14 Feb 2025 12:32:45 +0100
Subject: [PATCH] Modified crop_dataset_roi interface to automatically write
 the crop parameters to a JSON file in the dataset's meta directory

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
---
 lerobot/common/datasets/lerobot_dataset.py |  8 ++---
 lerobot/configs/policy/sac_real.yaml       |  9 ++++--
 lerobot/scripts/server/buffer.py           |  8 ++---
 lerobot/scripts/server/crop_dataset_roi.py | 37 ++++++++++++++++------
 4 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 5278987b..000b0bcb 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -84,7 +84,8 @@ class LeRobotDatasetMetadata:
 
         # Load metadata
         (self.root / "meta").mkdir(exist_ok=True, parents=True)
-        self.pull_from_repo(allow_patterns="meta/")
+        if not self.local_files_only:
+            self.pull_from_repo(allow_patterns="meta/")
         self.info = load_info(self.root)
         self.stats = load_stats(self.root)
         self.tasks = load_tasks(self.root)
@@ -537,9 +538,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 ]
                 files += video_files
 
-        # HACK: UNCOMMENT IF YOU REVIEW THAT, PLEASE SUGGEST TO UNCOMMENT
-        logging.warning("HACK: WE COMMENT THIS LINE, IF SOMETHING IS WEIRD WITH DATASETS UNCOMMENT")
-        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+        if not self.local_files_only:
+            self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
 
     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
diff --git a/lerobot/configs/policy/sac_real.yaml b/lerobot/configs/policy/sac_real.yaml
index 14a63713..5d248aef 100644
--- a/lerobot/configs/policy/sac_real.yaml
+++ b/lerobot/configs/policy/sac_real.yaml
@@ -77,8 +77,11 @@ policy:
       mean: [0.485, 0.456, 0.406]
       std: [0.229, 0.224, 0.225]
     observation.state:
-      min: [-87.09961,     62.402344,    67.23633,     36.035156,    77.34375,0.53691274] 
-      max: [58.183594,   131.83594,    145.98633,     82.08984,     78.22266, 0.60402685]
+      min: [-77.08008,     56.25,        60.55664,     19.511719,   0., -0.63829786]
+      max: [ 7.215820e+01,  1.5398438e+02,  1.6075195e+02,  9.3251953e+01, 0., -1.4184397e-01]
+
+      # min: [-87.09961,     62.402344,    67.23633,     36.035156,    77.34375,0.53691274] 
+      # max: [58.183594,   131.83594,    145.98633,     82.08984,     78.22266, 0.60402685]
       # min: [-88.50586,  23.81836, 0.87890625, -32.16797, 78.66211,   0.53691274]
       # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156,  88.18792]
 
@@ -96,7 +99,7 @@ policy:
   # Neural networks.
   image_encoder_hidden_dim: 32
   # discount: 0.99
-  discount: 0.80
+  discount: 0.97
   temperature_init: 1.0
   num_critics: 2 #10
   camera_number: 2
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index fb463762..dcfc259c 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -181,10 +181,10 @@ class ReplayBuffer:
         state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
         next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
         action = action.to(self.storage_device)
-        if complementary_info is not None:
-            complementary_info = {
-                key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
-            }
+        # if complementary_info is not None:
+        #     complementary_info = {
+        #         key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
+        #     }
 
         if len(self.memory) < self.capacity:
             self.memory.append(None)
diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py
index 5b534a46..172eb22c 100644
--- a/lerobot/scripts/server/crop_dataset_roi.py
+++ b/lerobot/scripts/server/crop_dataset_roi.py
@@ -1,7 +1,8 @@
 import argparse  # noqa: I001
+import json
 from copy import deepcopy
 from typing import Dict, Tuple
-
+from pathlib import Path
 import cv2
 
 # import torch.nn.functional as F  # noqa: N812
@@ -237,19 +238,27 @@ if __name__ == "__main__":
         default=None,
         help="The root directory of the LeRobot dataset.",
     )
+    parser.add_argument(
+        "--crop-params-path",
+        type=str,
+        default=None,
+        help="The path to the JSON file containing the ROIs.",
+    )
     args = parser.parse_args()
 
-    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=False)
+    local_files_only = args.root is not None
+    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
 
     images = get_image_from_lerobot_dataset(dataset)
     images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
     images = {k: (v * 255).astype("uint8") for k, v in images.items()}
 
-    # rois = select_square_roi_for_images(images)
-    rois = {
-        "observation.images.front": [102, 43, 358, 523],
-        "observation.images.side": [92, 123, 379, 349],
-    }
+    if args.crop_params_path is None:
+        rois = select_square_roi_for_images(images)
+    else:
+        with open(args.crop_params_path, "r") as f:
+            rois = json.load(f)
+
     # rois = {
     #     "observation.images.side": (92, 123, 379, 349),
     #     "observation.images.front": (109, 37, 361, 557),
@@ -263,10 +272,20 @@ if __name__ == "__main__":
     print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
     for key, roi in rois.items():
         print(f"{key}: {roi}")
+
+    new_repo_id = args.repo_id + "_cropped_resized"
+    new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
+
     croped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
         original_dataset=dataset,
         crop_params_dict=rois,
-        new_repo_id=args.repo_id + "_cropped_resized",
-        new_dataset_root="data/" + args.repo_id + "_cropped_resized",
+        new_repo_id=new_repo_id,
+        new_dataset_root=new_dataset_root,
         resize_size=(128, 128),
     )
+
+    meta_dir = new_dataset_root / "meta"
+    meta_dir.mkdir(exist_ok=True)
+
+    with open(meta_dir / "crop_params.json", "w") as f:
+        json.dump(rois, f, indent=4)