From 6366c7f46e9da3dc2cdb3d4260b981e260e12f2b Mon Sep 17 00:00:00 2001
From: Simon Alibert <simon.alibert@huggingface.co>
Date: Wed, 27 Nov 2024 11:11:54 +0100
Subject: [PATCH] WIP

---
 .../common/robot_devices/cameras/reachy2.py   | 48 ++++++++++++++-----
 lerobot/common/robot_devices/control_utils.py |  2 +-
 .../common/robot_devices/robots/reachy2.py    | 46 ++++++++++++------
 lerobot/configs/robot/reachy2.yaml            | 15 +++++-
 lerobot/scripts/control_robot.py              | 21 ++++----
 5 files changed, 92 insertions(+), 40 deletions(-)

diff --git a/lerobot/common/robot_devices/cameras/reachy2.py b/lerobot/common/robot_devices/cameras/reachy2.py
index c581c3e1..24040034 100644
--- a/lerobot/common/robot_devices/cameras/reachy2.py
+++ b/lerobot/common/robot_devices/cameras/reachy2.py
@@ -2,8 +2,9 @@
 Wrapper for Reachy2 camera from sdk
 """
 
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 
+import cv2
 import numpy as np
 from reachy2_sdk.media.camera import CameraView
 from reachy2_sdk.media.camera_manager import CameraManager
@@ -18,6 +19,14 @@ class ReachyCameraConfig:
     rotation: int | None = None
     mock: bool = False
 
+    def __post_init__(self):
+        if self.color_mode not in ["rgb", "bgr"]:
+            raise ValueError(
+                f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
+            )
+
+        self.channels = 3
+
 
 class ReachyCamera:
     def __init__(
@@ -29,8 +38,18 @@ class ReachyCamera:
         config: ReachyCameraConfig | None = None,
         **kwargs,
     ):
+        if config is None:
+            config = ReachyCameraConfig()
+
+        # Overwrite config arguments using kwargs
+        config = replace(config, **kwargs)
+
         self.host = host
         self.port = port
+        self.width = config.width
+        self.height = config.height
+        self.channels = config.channels
+        self.fps = config.fps
         self.image_type = image_type
         self.name = name
         self.config = config
@@ -48,21 +67,24 @@ class ReachyCamera:
         if not self.is_connected:
             self.connect()
 
+        frame = None
+
         if self.name == "teleop" and hasattr(self.cam_manager, "teleop"):
             if self.image_type == "left":
-                return self.cam_manager.teleop.get_frame(CameraView.LEFT)
-                # return self.cam_manager.teleop.get_compressed_frame(CameraView.LEFT)
+                frame = self.cam_manager.teleop.get_frame(CameraView.LEFT)
             elif self.image_type == "right":
-                return self.cam_manager.teleop.get_frame(CameraView.RIGHT)
-                # return self.cam_manager.teleop.get_compressed_frame(CameraView.RIGHT)
-            else:
-                return None
+                frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT)
         elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
             if self.image_type == "depth":
-                return self.cam_manager.depth.get_depth_frame()
+                frame = self.cam_manager.depth.get_depth_frame()
             elif self.image_type == "rgb":
-                return self.cam_manager.depth.get_frame()
-                # return self.cam_manager.depth.get_compressed_frame()
-            else:
-                return None
-        return None
+                frame = self.cam_manager.depth.get_frame()
+
+        if frame is None:
+            return None
+
+        if frame is not None and self.config.color_mode == "rgb":
+            img, timestamp = frame
+            frame = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB), timestamp)
+
+        return frame
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index 87690217..b2d54a66 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -46,7 +46,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
     log_dt("dt", dt_s)
 
     # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
-    if not robot.robot_type.startswith(("stretch", "Reachy")):
+    if not robot.robot_type.lower().startswith(("stretch", "reachy")):
         for name in robot.leader_arms:
             key = f"read_leader_{name}_pos_dt_s"
             if key in robot.logs:
diff --git a/lerobot/common/robot_devices/robots/reachy2.py b/lerobot/common/robot_devices/robots/reachy2.py
index 2667f3e2..d048d6f9 100644
--- a/lerobot/common/robot_devices/robots/reachy2.py
+++ b/lerobot/common/robot_devices/robots/reachy2.py
@@ -18,10 +18,11 @@ import time
 from copy import copy
 from dataclasses import dataclass, field, replace
 
+import numpy as np
 import torch
 from reachy2_sdk import ReachySDK
 
-from lerobot.common.robot_devices.cameras.utils import Camera
+from lerobot.common.robot_devices.cameras.reachy2 import ReachyCamera
 
 REACHY_MOTORS = [
     "neck_yaw.pos",
@@ -52,8 +53,9 @@ REACHY_MOTORS = [
 @dataclass
 class ReachyRobotConfig:
     robot_type: str | None = "reachy2"
-    cameras: dict[str, Camera] = field(default_factory=lambda: {})
+    cameras: dict[str, ReachyCamera] = field(default_factory=lambda: {})
     ip_address: str | None = "172.17.135.207"
+    # ip_address: str | None = "192.168.0.197"
     # ip_address: str | None = "localhost"
 
 
@@ -74,10 +76,8 @@ class ReachyRobot:
         self.is_connected = False
         self.teleop = None
         self.logs = {}
-        self.reachy: ReachySDK = ReachySDK(host=config.ip_address)
-        self.reachy.turn_on()
-        self.is_connected = True  # at init Reachy2 is in fact connected...
-        self.mobile_base_available = self.reachy.mobile_base is not None
+        self.reachy = None
+        self.mobile_base_available = False
 
         self.state_keys = None
         self.action_keys = None
@@ -96,16 +96,19 @@ class ReachyRobot:
 
     @property
     def motor_features(self) -> dict:
+        motors = REACHY_MOTORS
+        # if self.mobile_base_available:
+        #     motors += REACHY_MOBILE_BASE
         return {
             "action": {
                 "dtype": "float32",
-                "shape": (len(REACHY_MOTORS),),
-                "names": REACHY_MOTORS,
+                "shape": (len(motors),),
+                "names": motors,
             },
             "observation.state": {
                 "dtype": "float32",
-                "shape": (len(REACHY_MOTORS),),
-                "names": REACHY_MOTORS,
+                "shape": (len(motors),),
+                "names": motors,
             },
         }
 
@@ -114,14 +117,16 @@ class ReachyRobot:
         return {**self.motor_features, **self.camera_features}
 
     def connect(self) -> None:
+        self.reachy = ReachySDK(host=self.config.ip_address)
         print("Connecting to Reachy")
-        self.reachy.is_connected = self.reachy.connect()
+        self.reachy.connect()
+        self.is_connected = self.reachy.is_connected
         if not self.is_connected:
             print(
                 f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists."
             )
             raise ConnectionError()
-        self.reachy.turn_on()
+        # self.reachy.turn_on()
         print(self.cameras)
         if self.cameras is not None:
             for name in self.cameras:
@@ -133,6 +138,8 @@ class ReachyRobot:
             print("Could not connect to the cameras, check that all cameras are plugged-in.")
             raise ConnectionError()
 
+        self.mobile_base_available = self.reachy.mobile_base is not None
+
     def run_calibration(self):
         pass
 
@@ -169,8 +176,14 @@ class ReachyRobot:
             action["mobile_base_x.vel"] = last_cmd_vel["x"]
             action["mobile_base_y.vel"] = last_cmd_vel["y"]
             action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
+        else:
+            action["mobile_base_x.vel"] = 0
+            action["mobile_base_y.vel"] = 0
+            action["mobile_base_theta.vel"] = 0
 
-        action = torch.as_tensor(list(action.values()))
+        dtype = self.motor_features["action"]["dtype"]
+        action = np.array(list(action.values()), dtype=dtype)
+        # action = torch.as_tensor(list(action.values()))
 
         obs_dict = self.capture_observation()
         action_dict = {}
@@ -224,7 +237,9 @@ class ReachyRobot:
             if self.state_keys is None:
                 self.state_keys = list(state)
 
-            state = torch.as_tensor(list(state.values()))
+            dtype = self.motor_features["observation.state"]["dtype"]
+            state = np.array(list(state.values()), dtype=dtype)
+            # state = torch.as_tensor(list(state.values()))
 
             # Capture images from cameras
             images = {}
@@ -233,6 +248,7 @@ class ReachyRobot:
                 images[name] = self.cameras[name].read()  # Reachy cameras read() is not blocking?
                 # print(f'name: {name} img: {images[name]}')
                 if images[name] is not None:
+                    # images[name] = copy(images[name][0])  # seems like I need to copy?
                     images[name] = torch.from_numpy(copy(images[name][0]))  # seems like I need to copy?
                     self.logs[f"read_camera_{name}_dt_s"] = images[name][1]  # full timestamp, TODO dt
 
@@ -295,7 +311,7 @@ class ReachyRobot:
         print("Disconnecting")
         self.is_connected = False
         print("Turn off")
-        self.reachy.turn_off_smoothly()
+        # self.reachy.turn_off_smoothly()
         # self.reachy.turn_off()
         print("\t turn off done")
         self.reachy.disconnect()
diff --git a/lerobot/configs/robot/reachy2.yaml b/lerobot/configs/robot/reachy2.yaml
index 5ec9c23c..4fbfb75e 100644
--- a/lerobot/configs/robot/reachy2.yaml
+++ b/lerobot/configs/robot/reachy2.yaml
@@ -12,9 +12,13 @@ cameras:
   head_left:
     _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
     name: teleop
-    host: 172.17.135.207
+    host: 172.17.134.85
+    # host: 192.168.0.197
     # host: localhost
     port: 50065
+    fps: 30
+    width: 960
+    height: 720
     image_type: left
   # head_right:
   #   _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
@@ -22,6 +26,9 @@ cameras:
   #   host: 172.17.135.207
   #   port: 50065
   #   image_type: right
+  #   fps: 30
+  #   width: 960
+  #   height: 720
   # torso_rgb:
   #   _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
   #   name: depth
@@ -29,9 +36,15 @@ cameras:
   #   # host: localhost
   #   port: 50065
   #   image_type: rgb
+  #   fps: 30
+  #   width: 1280
+  #   height: 720
   # torso_depth:
   #   _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
   #   name: depth
   #   host: 172.17.135.207
   #   port: 50065
   #   image_type: depth
+  #   fps: 30
+  #   width: 1280
+  #   height: 720
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index ad73eef4..3eac60ea 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -191,7 +191,7 @@ def teleoperate(
 @safe_disconnect
 def record(
     robot: Robot,
-    root: str,
+    root: Path,
     repo_id: str,
     single_task: str,
     pretrained_policy_name_or_path: str | None = None,
@@ -204,6 +204,7 @@ def record(
     video: bool = True,
     run_compute_stats: bool = True,
     push_to_hub: bool = True,
+    tags: list[str] | None = None,
     num_image_writer_processes: int = 0,
     num_image_writer_threads_per_camera: int = 4,
     display_cameras: bool = True,
@@ -331,7 +332,7 @@ def record(
     dataset.consolidate(run_compute_stats)
 
     if push_to_hub:
-        dataset.push_to_hub()
+        dataset.push_to_hub(tags=tags)
 
     log_say("Exiting", play_sounds)
     return dataset
@@ -427,7 +428,7 @@ if __name__ == "__main__":
     parser_record.add_argument(
         "--root",
         type=Path,
-        default="data",
+        default=None,
         help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
     )
     parser_record.add_argument(
@@ -436,6 +437,12 @@ if __name__ == "__main__":
         default="lerobot/test",
         help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
     )
+    parser_record.add_argument(
+        "--resume",
+        type=int,
+        default=0,
+        help="Resume recording on an existing dataset.",
+    )
     parser_record.add_argument(
         "--warmup-time-s",
         type=int,
@@ -494,12 +501,6 @@ if __name__ == "__main__":
             "Not enough threads might cause low camera fps."
         ),
     )
-    parser_record.add_argument(
-        "--force-override",
-        type=int,
-        default=0,
-        help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
-    )
     parser_record.add_argument(
         "-p",
         "--pretrained-policy-name-or-path",
@@ -523,7 +524,7 @@ if __name__ == "__main__":
     parser_replay.add_argument(
         "--root",
         type=Path,
-        default="data",
+        default=None,
         help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
     )
     parser_replay.add_argument(