refactor(robots): use dicts in lekiwi for get_obs and send_action

Steven Palma 2025-04-07 11:04:39 +02:00
parent 5c2566a4a9
commit 48e47e97df
No known key found for this signature in database
3 changed files with 30 additions and 30 deletions


@@ -199,7 +199,7 @@ class LeKiwi(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
obs_dict[OBS_IMAGES][cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
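For orientation, a minimal sketch of the key layout this hunk moves to, assuming OBS_IMAGES is the dotted-prefix string constant used elsewhere in the codebase and that the camera name is illustrative:

```python
import numpy as np

# Assumed constant value; the real one lives in the project's constants module.
OBS_IMAGES = "observation.images"

# Old layout: one flat dict with a dotted key per camera.
obs_flat = {f"{OBS_IMAGES}.front": np.zeros((480, 640, 3), dtype=np.uint8)}

# New layout: all camera frames grouped under a nested dict keyed by camera name.
obs_nested = {OBS_IMAGES: {"front": np.zeros((480, 640, 3), dtype=np.uint8)}}

assert obs_nested[OBS_IMAGES]["front"].shape == (480, 640, 3)
```

Note that with the nested form, get_observation presumably initializes obs_dict[OBS_IMAGES] = {} before the camera loop; that line is outside this hunk.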


@@ -63,10 +63,10 @@ class LeKiwiClient(Robot):
self.zmq_observation_socket = None
self.last_frames = {}
self.last_present_speed = [0, 0, 0]
self.last_present_speed = {"x_cmd": 0.0, "y_cmd": 0.0, "theta_cmd": 0.0}
# TODO(Steven): Move everything to 32 instead
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float64)
self.last_remote_arm_state = {}
# Define three speed levels and a current index
self.speed_levels = [
@@ -250,8 +250,8 @@ class LeKiwiClient(Robot):
# Copied from robot_lekiwi MobileManipulator class
def _wheel_raw_to_body(
self, wheel_raw: np.array, wheel_radius: float = 0.05, base_radius: float = 0.125
) -> tuple:
self, wheel_raw: dict[str, Any], wheel_radius: float = 0.05, base_radius: float = 0.125
) -> dict[str, Any]:
"""
Convert wheel raw command feedback back into body-frame velocities.
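As a rough, self-contained illustration of what _wheel_raw_to_body computes, here is a sketch of the deg/s-to-body-velocity conversion for a three-wheel omnidirectional base. The 0°/120°/240° mounting angles and the helper name are assumptions for the example, not the class's actual geometry constants:

```python
import numpy as np

def wheel_degps_to_body(wheel_degps, wheel_radius=0.05, base_radius=0.125):
    """Hypothetical helper: three wheel speeds in deg/s -> body-frame velocities."""
    angles = np.radians([0.0, 120.0, 240.0])  # assumed wheel mounting angles
    # Each row maps body velocity (x, y, theta) to one wheel's linear speed.
    m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
    wheel_linear = np.radians(wheel_degps) * wheel_radius          # deg/s -> m/s at the rim
    x_cmd, y_cmd, theta_rad = np.linalg.inv(m).dot(wheel_linear)   # invert the kinematics
    return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": np.degrees(theta_rad)}

# Equal wheel speeds should come out as pure rotation (x_cmd and y_cmd near zero).
print(wheel_degps_to_body(np.array([90.0, 90.0, 90.0])))
```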
@@ -267,8 +267,9 @@ class LeKiwiClient(Robot):
theta_cmd : Rotational velocity in deg/s.
"""
# TODO(Steven): No check is done for dict keys
# Convert each raw command back to an angular speed in deg/s.
wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(r)) for r in wheel_raw])
wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(v)) for _, v in wheel_raw.items()])
# Convert from deg/s to rad/s.
wheel_radps = wheel_degps * (np.pi / 180.0)
# Compute each wheel's linear speed (m/s) from its angular speed.
@@ -283,11 +284,10 @@ class LeKiwiClient(Robot):
velocity_vector = m_inv.dot(wheel_linear_speeds)
x_cmd, y_cmd, theta_rad = velocity_vector
theta_cmd = theta_rad * (180.0 / np.pi)
return (x_cmd, y_cmd, theta_cmd)
return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": theta_cmd}
# TODO(Steven): This is flaky; for example, if we receive a state but fail to decode the image, we will not update any value
# TODO(Steven): All this function needs to be refactored
# TODO(Steven): Fix this next
def _get_data(self):
# Copied from robot_lekiwi.py
"""Polls the video socket for up to 15 ms. If data arrives, decode only
@@ -295,10 +295,10 @@ class LeKiwiClient(Robot):
nothing arrives for any field, use the last known values."""
frames = {}
present_speed = []
present_speed = {}
# TODO(Steven): Size is being assumed, is this safe?
remote_arm_state_tensor = torch.empty(6, dtype=torch.float64)
remote_arm_state_tensor = {}
# Poll up to 15 ms
poller = zmq.Poller()
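The surrounding _get_data logic follows the usual "poll briefly, then drain and keep only the newest message" ZMQ pattern. A standalone sketch, assuming a PULL socket and a placeholder endpoint:

```python
import zmq

context = zmq.Context()
socket = context.socket(zmq.PULL)          # assumed socket type
socket.connect("tcp://127.0.0.1:5555")     # placeholder endpoint

poller = zmq.Poller()
poller.register(socket, zmq.POLLIN)

last_msg = None
if poller.poll(15):                        # wait at most 15 ms for data
    # Drain everything queued and keep only the most recent message.
    while True:
        try:
            last_msg = socket.recv_string(zmq.NOBLOCK)
        except zmq.Again:
            break
```

Draining the queue is what keeps the client acting on the latest observation instead of a stale backlog.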
@@ -327,11 +327,10 @@ class LeKiwiClient(Robot):
# Decode only the final message
try:
observation = json.loads(last_msg)
observation[OBS_STATE] = np.array(observation[OBS_STATE])
# TODO(Steven): Consider getting directly the item with observation[OBS_STATE]
state_observation = {k: v for k, v in observation.items() if k.startswith(OBS_STATE)}
image_observation = {k: v for k, v in observation.items() if k.startswith(OBS_IMAGES)}
state_observation = observation[OBS_STATE]
image_observation = observation[OBS_IMAGES]
# Convert images
for cam_name, image_b64 in image_observation.items():
@@ -342,14 +341,17 @@ class LeKiwiClient(Robot):
if frame_candidate is not None:
frames[cam_name] = frame_candidate
# TODO(Steven): Should we really ignore the arm state if the image is None?
# If remote_arm_state is None and frames is None, there is no new message, so use the previous values
if state_observation is not None and frames is not None:
self.last_frames = frames
remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float64)
# TODO(Steven): Do we really need the casting here?
# TODO(Steven): hard-coded name of expected keys, not good
remote_arm_state_tensor = {k: v for k, v in state_observation.items() if k.startswith("arm")}
self.last_remote_arm_state = remote_arm_state_tensor
present_speed = state_observation[OBS_STATE][6:]
present_speed = {k: v for k, v in state_observation.items() if k.startswith("base")}
self.last_present_speed = present_speed
else:
frames = self.last_frames
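For reference, decoding one of the base64-encoded JPEG frames back into an ndarray, roughly as the conversion loop above does (the helper name is illustrative):

```python
import base64

import cv2
import numpy as np

def decode_frame(image_b64: str):
    """Return an HxWx3 uint8 image, or None if decoding fails."""
    jpg_bytes = base64.b64decode(image_b64)
    buf = np.frombuffer(jpg_bytes, dtype=np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)  # cv2 returns None on bad data
```

A None result is why the loop only stores frame_candidate when it is not None.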
@@ -373,15 +375,18 @@ class LeKiwiClient(Robot):
if not self._is_connected:
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
obs_dict = {}
# TODO(Steven): remove hard-coded cam name
# This is needed at init for when there's no comms
obs_dict = {
OBS_IMAGES: {"wrist": np.zeros(shape=(480, 640, 3)), "front": np.zeros(shape=(640, 480, 3))}
}
frames, present_speed, remote_arm_state_tensor = self._get_data()
body_state = self._wheel_raw_to_body(present_speed)
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float64)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
# TODO(Steven): out is dict[str, Any] and we multiply by 1000.0. This should be more explicit and specify the expected type instead of Any
body_state_mm = {k: v * 1000.0 for k, v in body_state.items()} # Convert x,y to mm/s
obs_dict = {OBS_STATE: combined_state_tensor}
obs_dict[OBS_STATE] = {**remote_arm_state_tensor, **body_state_mm}
# Loop over each configured camera
for cam_name, frame in frames.items():
@@ -389,14 +394,9 @@ class LeKiwiClient(Robot):
# TODO(Steven): Daemon doesn't know camera dimensions
logging.warning("Frame is None")
frame = np.zeros((480, 640, 3), dtype=np.uint8)
obs_dict[cam_name] = torch.from_numpy(frame)
# TODO(Steven): Refactor this ugly thing (needed for when there are not comms at init)
if OBS_IMAGES + ".wrist" not in obs_dict:
obs_dict[OBS_IMAGES + ".wrist"] = np.zeros(shape=(480, 640, 3))
if OBS_IMAGES + ".front" not in obs_dict:
obs_dict[OBS_IMAGES + ".front"] = np.zeros(shape=(640, 480, 3))
obs_dict[OBS_IMAGES][cam_name] = torch.from_numpy(frame)
print("obs_dict", obs_dict)
return obs_dict
def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray):
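Taken together, the client's observation now comes back as a nested structure roughly like the one below. The arm joint key name is a placeholder; the real keys are whatever the host puts in its state message:

```python
import numpy as np
import torch

# Assumed constant values for the observation prefixes.
OBS_STATE = "observation.state"
OBS_IMAGES = "observation.images"

obs_dict = {
    OBS_STATE: {
        "arm_shoulder_pan.pos": 0.0,  # hypothetical arm joint entry
        "x_cmd": 120.0,               # base linear velocity, mm/s
        "y_cmd": 0.0,                 # base linear velocity, mm/s
        "theta_cmd": 5.0,             # base rotation, deg/s
    },
    OBS_IMAGES: {
        "front": torch.from_numpy(np.zeros((640, 480, 3), dtype=np.uint8)),
        "wrist": torch.from_numpy(np.zeros((480, 640, 3), dtype=np.uint8)),
    },
}
```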


@@ -89,12 +89,12 @@ def main():
# Encode ndarrays to base64 strings
for cam_key, _ in robot.cameras.items():
ret, buffer = cv2.imencode(
".jpg", last_observation[f"{OBS_IMAGES}.{cam_key}"], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
".jpg", last_observation[OBS_IMAGES][cam_key], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
)
if ret:
last_observation[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8")
last_observation[OBS_IMAGES][cam_key] = base64.b64encode(buffer).decode("utf-8")
else:
last_observation[f"{OBS_IMAGES}.{cam_key}"] = ""
last_observation[OBS_IMAGES][cam_key] = ""
# Send the observation to the remote agent
remote_agent.zmq_observation_socket.send_string(json.dumps(last_observation))
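End to end, the host-side change amounts to encoding each frame under the nested images dict and serializing the whole observation as JSON before send_string ships it. A condensed sketch with a placeholder frame (socket setup omitted):

```python
import base64
import json

import cv2
import numpy as np

OBS_IMAGES = "observation.images"   # assumed constant value

# Placeholder observation with one fake camera frame.
last_observation = {OBS_IMAGES: {"front": np.zeros((480, 640, 3), dtype=np.uint8)}}

for cam_key, frame in last_observation[OBS_IMAGES].items():
    ok, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
    last_observation[OBS_IMAGES][cam_key] = (
        base64.b64encode(buffer).decode("utf-8") if ok else ""
    )

payload = json.dumps(last_observation)  # nested dict of strings serializes directly
print(len(payload))
```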