From 48e47e97df0796860a4bf4cdf3eb5ab73b0b24a4 Mon Sep 17 00:00:00 2001
From: Steven Palma
Date: Mon, 7 Apr 2025 11:04:39 +0200
Subject: [PATCH] refactor(robots): use dicts in lekiwi for get_obs and
 send_action

---
 lerobot/common/robots/lekiwi/lekiwi.py        |  2 +-
 lerobot/common/robots/lekiwi/lekiwi_client.py | 52 +++++++++----------
 lerobot/common/robots/lekiwi/lekiwi_host.py   |  6 +--
 3 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/lerobot/common/robots/lekiwi/lekiwi.py b/lerobot/common/robots/lekiwi/lekiwi.py
index a9633b24..52045606 100644
--- a/lerobot/common/robots/lekiwi/lekiwi.py
+++ b/lerobot/common/robots/lekiwi/lekiwi.py
@@ -199,7 +199,7 @@ class LeKiwi(Robot):
         # Capture images from cameras
         for cam_key, cam in self.cameras.items():
             start = time.perf_counter()
-            obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
+            obs_dict[OBS_IMAGES][cam_key] = cam.async_read()
             dt_ms = (time.perf_counter() - start) * 1e3
             logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
 
diff --git a/lerobot/common/robots/lekiwi/lekiwi_client.py b/lerobot/common/robots/lekiwi/lekiwi_client.py
index bb2fefe7..34e2b40b 100644
--- a/lerobot/common/robots/lekiwi/lekiwi_client.py
+++ b/lerobot/common/robots/lekiwi/lekiwi_client.py
@@ -63,10 +63,10 @@ class LeKiwiClient(Robot):
         self.zmq_observation_socket = None
 
         self.last_frames = {}
-        self.last_present_speed = [0, 0, 0]
+        self.last_present_speed = {"x_cmd": 0.0, "y_cmd": 0.0, "theta_cmd": 0.0}
 
         # TODO(Steven): Move everything to 32 instead
-        self.last_remote_arm_state = torch.zeros(6, dtype=torch.float64)
+        self.last_remote_arm_state = {}
 
         # Define three speed levels and a current index
         self.speed_levels = [
@@ -250,8 +250,8 @@ class LeKiwiClient(Robot):
 
     # Copied from robot_lekiwi MobileManipulator class
     def _wheel_raw_to_body(
-        self, wheel_raw: np.array, wheel_radius: float = 0.05, base_radius: float = 0.125
-    ) -> tuple:
+        self, wheel_raw: dict[str, Any], wheel_radius: float = 0.05, base_radius: float = 0.125
+    ) -> dict[str, Any]:
         """
         Convert wheel raw command feedback back into body-frame velocities.
 
@@ -267,8 +267,9 @@ class LeKiwiClient(Robot):
           theta_cmd   : Rotational velocity in deg/s.
         """
 
+        # TODO(Steven): No check is done for dict keys
         # Convert each raw command back to an angular speed in deg/s.
-        wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(r)) for r in wheel_raw])
+        wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(v)) for v in wheel_raw.values()])
         # Convert from deg/s to rad/s.
         wheel_radps = wheel_degps * (np.pi / 180.0)
         # Compute each wheel’s linear speed (m/s) from its angular speed.
@@ -283,22 +284,21 @@ class LeKiwiClient(Robot):
         velocity_vector = m_inv.dot(wheel_linear_speeds)
         x_cmd, y_cmd, theta_rad = velocity_vector
         theta_cmd = theta_rad * (180.0 / np.pi)
-        return (x_cmd, y_cmd, theta_cmd)
+        return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": theta_cmd}
 
     # TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
     # TODO(Steven): All this function needs to be refactored
-    # TODO(Steven): Fix this next
     def _get_data(self):
         # Copied from robot_lekiwi.py
         """Polls the video socket for up to 15 ms. If data arrives, decode only
         the *latest* message, returning frames, speed, and arm state. If
         nothing arrives for any field, use the last known values."""
 
         frames = {}
-        present_speed = []
+        present_speed = {}
 
         # TODO(Steven): Size is being assumed, is this safe?
-        remote_arm_state_tensor = torch.empty(6, dtype=torch.float64)
+        remote_arm_state_tensor = {}
 
         # Poll up to 15 ms
         poller = zmq.Poller()
@@ -327,11 +327,10 @@
         # Decode only the final message
         try:
             observation = json.loads(last_msg)
-            observation[OBS_STATE] = np.array(observation[OBS_STATE])
 
             # TODO(Steven): Consider getting directly the item with observation[OBS_STATE]
-            state_observation = {k: v for k, v in observation.items() if k.startswith(OBS_STATE)}
-            image_observation = {k: v for k, v in observation.items() if k.startswith(OBS_IMAGES)}
+            state_observation = observation[OBS_STATE]
+            image_observation = observation[OBS_IMAGES]
 
             # Convert images
             for cam_name, image_b64 in image_observation.items():
@@ -342,14 +341,17 @@
                 if frame_candidate is not None:
                     frames[cam_name] = frame_candidate
 
+            # TODO(Steven): Should we really ignore the arm state if the image is None?
             # If remote_arm_state is None and frames is None there is no message then use the previous message
             if state_observation is not None and frames is not None:
                 self.last_frames = frames
 
-                remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float64)
+                # TODO(Steven): Do we really need the casting here?
+                # TODO(Steven): hard-coded names of expected keys, not good
+                remote_arm_state_tensor = {k: v for k, v in state_observation.items() if k.startswith("arm")}
                 self.last_remote_arm_state = remote_arm_state_tensor
 
-                present_speed = state_observation[OBS_STATE][6:]
+                present_speed = {k: v for k, v in state_observation.items() if k.startswith("base")}
                 self.last_present_speed = present_speed
             else:
                 frames = self.last_frames
@@ -373,15 +375,18 @@
         if not self._is_connected:
             raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
 
-        obs_dict = {}
+        # TODO(Steven): remove hard-coded cam name
+        # This is needed at init for when there are no comms
+        obs_dict = {
+            OBS_IMAGES: {"wrist": np.zeros(shape=(480, 640, 3)), "front": np.zeros(shape=(640, 480, 3))}
+        }
 
         frames, present_speed, remote_arm_state_tensor = self._get_data()
         body_state = self._wheel_raw_to_body(present_speed)
 
-        body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2])  # Convert x,y to mm/s
-        wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float64)
-        combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
+        # TODO(Steven): output is dict[str, Any] and we multiply by 1000.0. This should be more explicit and specify the expected type instead of Any
+        body_state_mm = {k: v * 1000.0 for k, v in body_state.items()}  # Convert x,y to mm/s
 
-        obs_dict = {OBS_STATE: combined_state_tensor}
+        obs_dict[OBS_STATE] = {**remote_arm_state_tensor, **body_state_mm}
         # Loop over each configured camera
         for cam_name, frame in frames.items():
@@ -389,14 +394,9 @@
             # TODO(Steven): Daemon doesn't know camera dimensions
             logging.warning("Frame is None")
             frame = np.zeros((480, 640, 3), dtype=np.uint8)
-        obs_dict[cam_name] = torch.from_numpy(frame)
-
-        # TODO(Steven): Refactor this ugly thing (needed for when there are not comms at init)
-        if OBS_IMAGES + ".wrist" not in obs_dict:
-            obs_dict[OBS_IMAGES + ".wrist"] = np.zeros(shape=(480, 640, 3))
-        if OBS_IMAGES + ".front" not in obs_dict:
-            obs_dict[OBS_IMAGES + ".front"] = np.zeros(shape=(640, 480, 3))
+        obs_dict[OBS_IMAGES][cam_name] = torch.from_numpy(frame)
+        print("obs_dict", obs_dict)
 
         return obs_dict
 
     def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray):
diff --git a/lerobot/common/robots/lekiwi/lekiwi_host.py b/lerobot/common/robots/lekiwi/lekiwi_host.py
index 631295b6..56228ac9 100644
--- a/lerobot/common/robots/lekiwi/lekiwi_host.py
+++ b/lerobot/common/robots/lekiwi/lekiwi_host.py
@@ -89,12 +89,12 @@ def main():
         # Encode ndarrays to base64 strings
         for cam_key, _ in robot.cameras.items():
             ret, buffer = cv2.imencode(
-                ".jpg", last_observation[f"{OBS_IMAGES}.{cam_key}"], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
+                ".jpg", last_observation[OBS_IMAGES][cam_key], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
             )
             if ret:
-                last_observation[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8")
+                last_observation[OBS_IMAGES][cam_key] = base64.b64encode(buffer).decode("utf-8")
             else:
-                last_observation[f"{OBS_IMAGES}.{cam_key}"] = ""
+                last_observation[OBS_IMAGES][cam_key] = ""
 
         # Send the observation to the remote agent
         remote_agent.zmq_observation_socket.send_string(json.dumps(last_observation))
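
Two self-contained sketches follow for anyone reviewing the dict refactor who wants to exercise the new interfaces outside the robot stack.

First, the dict-in/dict-out conversion in _wheel_raw_to_body. The sketch below mirrors the pipeline visible in the hunk (raw ticks -> deg/s -> rad/s -> wheel linear speed -> body velocities) and returns the same {"x_cmd", "y_cmd", "theta_cmd"} keys. The raw-to-deg/s scale and the wheel mounting angles are illustrative assumptions, not the values used by LeKiwiClient._raw_to_degps or baked into the real m_inv, and like the TODO in the hunk notes, it relies on the dict's insertion order matching the physical wheel order.

import numpy as np

def wheel_raw_to_body_sketch(
    wheel_raw: dict[str, int],
    wheel_radius: float = 0.05,
    base_radius: float = 0.125,
) -> dict[str, float]:
    # Hypothetical encoding: one raw tick == 0.1 deg/s. The real conversion
    # is LeKiwiClient._raw_to_degps; as in the patch, no check is done on
    # the dict keys, only their order matters.
    wheel_degps = np.array([0.1 * v for v in wheel_raw.values()])
    wheel_radps = wheel_degps * (np.pi / 180.0)       # deg/s -> rad/s
    wheel_linear_speeds = wheel_radps * wheel_radius  # rad/s -> m/s at the rim

    # Assumed mounting angles for a three-omniwheel base; the real geometry
    # is what m_inv encodes in lekiwi_client.py.
    angles = np.radians([0.0, 120.0, 240.0])
    m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
    x_cmd, y_cmd, theta_rad = np.linalg.inv(m).dot(wheel_linear_speeds)

    # Same keys the patch introduces, so the result stays compatible with
    # last_present_speed and the body_state consumers in get_observation.
    return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": np.degrees(theta_rad)}

print(wheel_raw_to_body_sketch({"left_wheel": 120, "back_wheel": 120, "right_wheel": 120}))

With equal raw ticks on all three wheels the x/y components cancel and only theta_cmd is non-zero, which is a quick sanity check that the inverted kinematic matrix behaves as a pure rotation.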
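Second, the observation payload. After this patch LeKiwi.get_observation and lekiwi_host.py exchange a single nested dict, with scalar state under OBS_STATE and base64-encoded JPEG frames under OBS_IMAGES, instead of flat dotted keys like f"{OBS_IMAGES}.{cam_key}". The sketch below builds a minimal payload the way the host does and rebuilds the old dotted image keys for a consumer that has not migrated yet. The constant values and the state/camera key names are stand-ins for illustration; the real constants come from lerobot and the real state keys from the arm and base motors.

import base64
import json

import cv2
import numpy as np

OBS_STATE = "observation.state"    # assumed value of the lerobot constant
OBS_IMAGES = "observation.images"  # assumed value of the lerobot constant

# Host side: encode a frame and nest it under OBS_IMAGES, as in lekiwi_host.py.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
observation = {
    OBS_STATE: {"arm_shoulder_pan.pos": 0.0, "base_x_cmd": 0.0},  # hypothetical keys
    OBS_IMAGES: {"front": base64.b64encode(buffer).decode("utf-8") if ret else ""},
}
wire_msg = json.dumps(observation)  # what zmq_observation_socket.send_string ships

# Client side: rebuild the pre-patch dotted image keys from the nested payload.
decoded = json.loads(wire_msg)
flat = {OBS_STATE: decoded[OBS_STATE]}
for cam_key, image_b64 in decoded[OBS_IMAGES].items():
    flat[f"{OBS_IMAGES}.{cam_key}"] = image_b64
print(sorted(flat))  # ['observation.images.front', 'observation.state']

Note that the nested shape is also what lets _get_data split state by key prefix ("arm" vs "base") instead of slicing a fixed-size array, which removes the hard-coded 6-joint assumption behind the old torch.empty(6, ...) buffers.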