refactor(robots): use dicts in lekiwi for get_obs and send_action
This commit is contained in:
parent
5c2566a4a9
commit
48e47e97df
|
@ -199,7 +199,7 @@ class LeKiwi(Robot):
|
|||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
|
||||
obs_dict[OBS_IMAGES][cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
|
|
@ -63,10 +63,10 @@ class LeKiwiClient(Robot):
|
|||
self.zmq_observation_socket = None
|
||||
|
||||
self.last_frames = {}
|
||||
self.last_present_speed = [0, 0, 0]
|
||||
self.last_present_speed = {"x_cmd": 0.0, "y_cmd": 0.0, "theta_cmd": 0.0}
|
||||
|
||||
# TODO(Steven): Move everything to 32 instead
|
||||
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float64)
|
||||
self.last_remote_arm_state = {}
|
||||
|
||||
# Define three speed levels and a current index
|
||||
self.speed_levels = [
|
||||
|
@ -250,8 +250,8 @@ class LeKiwiClient(Robot):
|
|||
|
||||
# Copied from robot_lekiwi MobileManipulator class
|
||||
def _wheel_raw_to_body(
|
||||
self, wheel_raw: np.array, wheel_radius: float = 0.05, base_radius: float = 0.125
|
||||
) -> tuple:
|
||||
self, wheel_raw: dict[str, Any], wheel_radius: float = 0.05, base_radius: float = 0.125
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert wheel raw command feedback back into body-frame velocities.
|
||||
|
||||
|
@ -267,8 +267,9 @@ class LeKiwiClient(Robot):
|
|||
theta_cmd : Rotational velocity in deg/s.
|
||||
"""
|
||||
|
||||
# TODO(Steven): No check is done for dict keys
|
||||
# Convert each raw command back to an angular speed in deg/s.
|
||||
wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(r)) for r in wheel_raw])
|
||||
wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(v)) for _, v in wheel_raw.items()])
|
||||
# Convert from deg/s to rad/s.
|
||||
wheel_radps = wheel_degps * (np.pi / 180.0)
|
||||
# Compute each wheel’s linear speed (m/s) from its angular speed.
|
||||
|
@ -283,11 +284,10 @@ class LeKiwiClient(Robot):
|
|||
velocity_vector = m_inv.dot(wheel_linear_speeds)
|
||||
x_cmd, y_cmd, theta_rad = velocity_vector
|
||||
theta_cmd = theta_rad * (180.0 / np.pi)
|
||||
return (x_cmd, y_cmd, theta_cmd)
|
||||
return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": theta_cmd}
|
||||
|
||||
# TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
|
||||
# TODO(Steven): All this function needs to be refactored
|
||||
# TODO(Steven): Fix this next
|
||||
def _get_data(self):
|
||||
# Copied from robot_lekiwi.py
|
||||
"""Polls the video socket for up to 15 ms. If data arrives, decode only
|
||||
|
@ -295,10 +295,10 @@ class LeKiwiClient(Robot):
|
|||
nothing arrives for any field, use the last known values."""
|
||||
|
||||
frames = {}
|
||||
present_speed = []
|
||||
present_speed = {}
|
||||
|
||||
# TODO(Steven): Size is being assumed, is this safe?
|
||||
remote_arm_state_tensor = torch.empty(6, dtype=torch.float64)
|
||||
remote_arm_state_tensor = {}
|
||||
|
||||
# Poll up to 15 ms
|
||||
poller = zmq.Poller()
|
||||
|
@ -327,11 +327,10 @@ class LeKiwiClient(Robot):
|
|||
# Decode only the final message
|
||||
try:
|
||||
observation = json.loads(last_msg)
|
||||
observation[OBS_STATE] = np.array(observation[OBS_STATE])
|
||||
|
||||
# TODO(Steven): Consider getting directly the item with observation[OBS_STATE]
|
||||
state_observation = {k: v for k, v in observation.items() if k.startswith(OBS_STATE)}
|
||||
image_observation = {k: v for k, v in observation.items() if k.startswith(OBS_IMAGES)}
|
||||
state_observation = observation[OBS_STATE]
|
||||
image_observation = observation[OBS_IMAGES]
|
||||
|
||||
# Convert images
|
||||
for cam_name, image_b64 in image_observation.items():
|
||||
|
@ -342,14 +341,17 @@ class LeKiwiClient(Robot):
|
|||
if frame_candidate is not None:
|
||||
frames[cam_name] = frame_candidate
|
||||
|
||||
# TODO(Steven): Should we really ignore the arm state if the image is None?
|
||||
# If remote_arm_state is None and frames is None there is no message then use the previous message
|
||||
if state_observation is not None and frames is not None:
|
||||
self.last_frames = frames
|
||||
|
||||
remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float64)
|
||||
# TODO(Steven): Do we really need the casting here?
|
||||
# TODO(Steven): hard-coded name of expected keys, not good
|
||||
remote_arm_state_tensor = {k: v for k, v in state_observation.items() if k.startswith("arm")}
|
||||
self.last_remote_arm_state = remote_arm_state_tensor
|
||||
|
||||
present_speed = state_observation[OBS_STATE][6:]
|
||||
present_speed = {k: v for k, v in state_observation.items() if k.startswith("base")}
|
||||
self.last_present_speed = present_speed
|
||||
else:
|
||||
frames = self.last_frames
|
||||
|
@ -373,15 +375,18 @@ class LeKiwiClient(Robot):
|
|||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
|
||||
|
||||
obs_dict = {}
|
||||
# TODO(Steven): remove hard-coded cam name
|
||||
# This is needed at init for when there's no comms
|
||||
obs_dict = {
|
||||
OBS_IMAGES: {"wrist": np.zeros(shape=(480, 640, 3)), "front": np.zeros(shape=(640, 480, 3))}
|
||||
}
|
||||
|
||||
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
||||
body_state = self._wheel_raw_to_body(present_speed)
|
||||
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float64)
|
||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
||||
# TODO(Steven): out isdict[str,Any] and we multiply by 1000.0. This should be more explicit and specify the expected type instead of Any
|
||||
body_state_mm = {k: v * 1000.0 for k, v in body_state.items()} # Convert x,y to mm/s
|
||||
|
||||
obs_dict = {OBS_STATE: combined_state_tensor}
|
||||
obs_dict[OBS_STATE] = {**remote_arm_state_tensor, **body_state_mm}
|
||||
|
||||
# Loop over each configured camera
|
||||
for cam_name, frame in frames.items():
|
||||
|
@ -389,14 +394,9 @@ class LeKiwiClient(Robot):
|
|||
# TODO(Steven): Daemon doesn't know camera dimensions
|
||||
logging.warning("Frame is None")
|
||||
frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
obs_dict[cam_name] = torch.from_numpy(frame)
|
||||
|
||||
# TODO(Steven): Refactor this ugly thing (needed for when there are not comms at init)
|
||||
if OBS_IMAGES + ".wrist" not in obs_dict:
|
||||
obs_dict[OBS_IMAGES + ".wrist"] = np.zeros(shape=(480, 640, 3))
|
||||
if OBS_IMAGES + ".front" not in obs_dict:
|
||||
obs_dict[OBS_IMAGES + ".front"] = np.zeros(shape=(640, 480, 3))
|
||||
obs_dict[OBS_IMAGES][cam_name] = torch.from_numpy(frame)
|
||||
|
||||
print("obs_dict", obs_dict)
|
||||
return obs_dict
|
||||
|
||||
def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray):
|
||||
|
|
|
@ -89,12 +89,12 @@ def main():
|
|||
# Encode ndarrays to base64 strings
|
||||
for cam_key, _ in robot.cameras.items():
|
||||
ret, buffer = cv2.imencode(
|
||||
".jpg", last_observation[f"{OBS_IMAGES}.{cam_key}"], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
|
||||
".jpg", last_observation[OBS_IMAGES][cam_key], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
|
||||
)
|
||||
if ret:
|
||||
last_observation[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8")
|
||||
last_observation[OBS_IMAGES][cam_key] = base64.b64encode(buffer).decode("utf-8")
|
||||
else:
|
||||
last_observation[f"{OBS_IMAGES}.{cam_key}"] = ""
|
||||
last_observation[OBS_IMAGES][cam_key] = ""
|
||||
|
||||
# Send the observation to the remote agent
|
||||
remote_agent.zmq_observation_socket.send_string(json.dumps(last_observation))
|
||||
|
|
Loading…
Reference in New Issue