refactor(robots): use dicts in lekiwi for get_obs and send_action
commit 48e47e97df
parent 5c2566a4a9
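The pattern repeated in every hunk below: observations move from flat, dot-composed string keys to nested dicts keyed by the OBS_STATE / OBS_IMAGES constants. A minimal sketch of the shape change (the constant values and the camera/joint names are illustrative assumptions, not taken from this diff):

    import numpy as np

    OBS_STATE = "observation.state"    # assumed constant values
    OBS_IMAGES = "observation.images"
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in camera frame

    # Before: one flat dict, cameras addressed by composed string keys.
    obs = {OBS_STATE: [0.0] * 9}  # 6 arm joints + 3 base velocities
    obs[f"{OBS_IMAGES}.front"] = frame

    # After: nested sub-dicts. OBS_IMAGES must already map to a dict before
    # the camera loop can assign obs[OBS_IMAGES][cam_key].
    obs = {OBS_STATE: {"arm_shoulder_pan.pos": 0.0, "base_x_cmd": 0.0}, OBS_IMAGES: {}}
    obs[OBS_IMAGES]["front"] = frame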
@@ -199,7 +199,7 @@ class LeKiwi(Robot):
         # Capture images from cameras
         for cam_key, cam in self.cameras.items():
             start = time.perf_counter()
-            obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
+            obs_dict[OBS_IMAGES][cam_key] = cam.async_read()
             dt_ms = (time.perf_counter() - start) * 1e3
             logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

@@ -63,10 +63,10 @@ class LeKiwiClient(Robot):
         self.zmq_observation_socket = None

         self.last_frames = {}
-        self.last_present_speed = [0, 0, 0]
+        self.last_present_speed = {"x_cmd": 0.0, "y_cmd": 0.0, "theta_cmd": 0.0}

         # TODO(Steven): Move everything to 32 instead
-        self.last_remote_arm_state = torch.zeros(6, dtype=torch.float64)
+        self.last_remote_arm_state = {}

         # Define three speed levels and a current index
         self.speed_levels = [
@@ -250,8 +250,8 @@ class LeKiwiClient(Robot):

     # Copied from robot_lekiwi MobileManipulator class
     def _wheel_raw_to_body(
-        self, wheel_raw: np.array, wheel_radius: float = 0.05, base_radius: float = 0.125
-    ) -> tuple:
+        self, wheel_raw: dict[str, Any], wheel_radius: float = 0.05, base_radius: float = 0.125
+    ) -> dict[str, Any]:
         """
         Convert wheel raw command feedback back into body-frame velocities.

@@ -267,8 +267,9 @@ class LeKiwiClient(Robot):
             theta_cmd : Rotational velocity in deg/s.
         """

+        # TODO(Steven): No check is done for dict keys
         # Convert each raw command back to an angular speed in deg/s.
-        wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(r)) for r in wheel_raw])
+        wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(v)) for _, v in wheel_raw.items()])
         # Convert from deg/s to rad/s.
         wheel_radps = wheel_degps * (np.pi / 180.0)
         # Compute each wheel’s linear speed (m/s) from its angular speed.
@@ -283,11 +284,10 @@ class LeKiwiClient(Robot):
         velocity_vector = m_inv.dot(wheel_linear_speeds)
         x_cmd, y_cmd, theta_rad = velocity_vector
         theta_cmd = theta_rad * (180.0 / np.pi)
-        return (x_cmd, y_cmd, theta_cmd)
+        return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": theta_cmd}

     # TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
     # TODO(Steven): All this function needs to be refactored
-    # TODO(Steven): Fix this next
     def _get_data(self):
         # Copied from robot_lekiwi.py
         """Polls the video socket for up to 15 ms. If data arrives, decode only
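For context, _wheel_raw_to_body (the two hunks above) inverts the standard three-omniwheel kinematic model, but the diff only shows fragments of its body. A self-contained sketch of the whole conversion follows; the raw-to-deg/s scale and the wheel mounting angles are assumptions, the real code uses LeKiwiClient._raw_to_degps and its own geometry:

    import numpy as np

    def wheel_raw_to_body_sketch(
        wheel_raw: dict[str, int], wheel_radius: float = 0.05, base_radius: float = 0.125
    ) -> dict[str, float]:
        # Raw ticks -> deg/s (assumed linear scale; stand-in for _raw_to_degps).
        wheel_degps = np.array([0.1 * v for v in wheel_raw.values()])
        # deg/s -> rad/s -> linear rim speed in m/s.
        wheel_linear_speeds = wheel_degps * (np.pi / 180.0) * wheel_radius
        # Forward model m maps body (x, y, theta) to wheel linear speeds; invert it.
        angles = np.radians([0.0, 120.0, 240.0])  # assumed wheel mounting angles
        m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
        x_cmd, y_cmd, theta_rad = np.linalg.inv(m).dot(wheel_linear_speeds)
        return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": theta_rad * 180.0 / np.pi}

    print(wheel_raw_to_body_sketch({"left_wheel": 100, "back_wheel": 100, "right_wheel": 100}))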
@@ -295,10 +295,10 @@ class LeKiwiClient(Robot):
         nothing arrives for any field, use the last known values."""

         frames = {}
-        present_speed = []
+        present_speed = {}

         # TODO(Steven): Size is being assumed, is this safe?
-        remote_arm_state_tensor = torch.empty(6, dtype=torch.float64)
+        remote_arm_state_tensor = {}

         # Poll up to 15 ms
         poller = zmq.Poller()
@@ -327,11 +327,10 @@ class LeKiwiClient(Robot):
             # Decode only the final message
             try:
                 observation = json.loads(last_msg)
-                observation[OBS_STATE] = np.array(observation[OBS_STATE])

                 # TODO(Steven): Consider getting directly the item with observation[OBS_STATE]
-                state_observation = {k: v for k, v in observation.items() if k.startswith(OBS_STATE)}
-                image_observation = {k: v for k, v in observation.items() if k.startswith(OBS_IMAGES)}
+                state_observation = observation[OBS_STATE]
+                image_observation = observation[OBS_IMAGES]

                 # Convert images
                 for cam_name, image_b64 in image_observation.items():
@@ -342,14 +341,17 @@ class LeKiwiClient(Robot):
                 if frame_candidate is not None:
                     frames[cam_name] = frame_candidate

+            # TODO(Steven): Should we really ignore the arm state if the image is None?
             # If remote_arm_state is None and frames is None there is no message then use the previous message
             if state_observation is not None and frames is not None:
                 self.last_frames = frames

-                remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float64)
+                # TODO(Steven): Do we really need the casting here?
+                # TODO(Steven): hard-coded name of expected keys, not good
+                remote_arm_state_tensor = {k: v for k, v in state_observation.items() if k.startswith("arm")}
                 self.last_remote_arm_state = remote_arm_state_tensor

-                present_speed = state_observation[OBS_STATE][6:]
+                present_speed = {k: v for k, v in state_observation.items() if k.startswith("base")}
                 self.last_present_speed = present_speed
             else:
                 frames = self.last_frames
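The two startswith comprehensions above split one flat state dict into arm and base parts purely by key prefix, which is exactly what the new "hard-coded name of expected keys" TODO flags as fragile. A tiny worked example with hypothetical key names (the real names come from the host and are not in this diff):

    state_observation = {
        "arm_shoulder_pan.pos": 12.0,
        "arm_gripper.pos": 45.0,
        "base_left_wheel": 100,
    }
    arm = {k: v for k, v in state_observation.items() if k.startswith("arm")}
    base = {k: v for k, v in state_observation.items() if k.startswith("base")}
    # arm  == {"arm_shoulder_pan.pos": 12.0, "arm_gripper.pos": 45.0}
    # base == {"base_left_wheel": 100}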
@@ -373,15 +375,18 @@ class LeKiwiClient(Robot):
         if not self._is_connected:
             raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")

-        obs_dict = {}
+        # TODO(Steven): remove hard-coded cam name
+        # This is needed at init for when there's no comms
+        obs_dict = {
+            OBS_IMAGES: {"wrist": np.zeros(shape=(480, 640, 3)), "front": np.zeros(shape=(640, 480, 3))}
+        }

         frames, present_speed, remote_arm_state_tensor = self._get_data()
         body_state = self._wheel_raw_to_body(present_speed)
-        body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2])  # Convert x,y to mm/s
-        wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float64)
-        combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
+        # TODO(Steven): out is dict[str, Any] and we multiply by 1000.0. This should be more explicit and specify the expected type instead of Any
+        body_state_mm = {k: v * 1000.0 for k, v in body_state.items()}  # Convert x,y to mm/s

-        obs_dict = {OBS_STATE: combined_state_tensor}
+        obs_dict[OBS_STATE] = {**remote_arm_state_tensor, **body_state_mm}

         # Loop over each configured camera
         for cam_name, frame in frames.items():
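One behavioral subtlety in the hunk above, hinted at by the adjacent TODO: the old tuple scaled only x and y to mm/s and left theta in deg/s, while the new comprehension multiplies every value, theta_cmd included, by 1000.0. If the old behavior is what's intended, a key-aware sketch would be:

    # Sketch: scale only the linear components; leave theta_cmd in deg/s.
    body_state_mm = {k: v * 1000.0 if k in ("x_cmd", "y_cmd") else v for k, v in body_state.items()}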
@@ -389,14 +394,9 @@ class LeKiwiClient(Robot):
                 # TODO(Steven): Daemon doesn't know camera dimensions
                 logging.warning("Frame is None")
                 frame = np.zeros((480, 640, 3), dtype=np.uint8)
-            obs_dict[cam_name] = torch.from_numpy(frame)
+            obs_dict[OBS_IMAGES][cam_name] = torch.from_numpy(frame)

-        # TODO(Steven): Refactor this ugly thing (needed for when there are not comms at init)
-        if OBS_IMAGES + ".wrist" not in obs_dict:
-            obs_dict[OBS_IMAGES + ".wrist"] = np.zeros(shape=(480, 640, 3))
-        if OBS_IMAGES + ".front" not in obs_dict:
-            obs_dict[OBS_IMAGES + ".front"] = np.zeros(shape=(640, 480, 3))
-
+        print("obs_dict", obs_dict)
         return obs_dict

     def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray):
@@ -89,12 +89,12 @@ def main():
         # Encode ndarrays to base64 strings
         for cam_key, _ in robot.cameras.items():
             ret, buffer = cv2.imencode(
-                ".jpg", last_observation[f"{OBS_IMAGES}.{cam_key}"], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
+                ".jpg", last_observation[OBS_IMAGES][cam_key], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
             )
             if ret:
-                last_observation[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8")
+                last_observation[OBS_IMAGES][cam_key] = base64.b64encode(buffer).decode("utf-8")
             else:
-                last_observation[f"{OBS_IMAGES}.{cam_key}"] = ""
+                last_observation[OBS_IMAGES][cam_key] = ""

         # Send the observation to the remote agent
         remote_agent.zmq_observation_socket.send_string(json.dumps(last_observation))
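The host-side encoding above pairs with the client-side decode in _get_data (seen only as frame_candidate in the earlier hunks). A self-contained round-trip sketch of the transport, assuming plain JPEG-in-base64-in-JSON as used here:

    import base64
    import json

    import cv2
    import numpy as np

    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in camera frame

    # Host side: JPEG-compress, then base64 so the bytes survive json.dumps.
    ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
    image_b64 = base64.b64encode(buffer).decode("utf-8") if ret else ""
    wire = json.dumps({"observation.images": {"front": image_b64}})

    # Client side: reverse the steps. cv2.imdecode returns None on a corrupt
    # payload, which is why _get_data checks frame_candidate before storing it.
    payload = json.loads(wire)
    jpg = base64.b64decode(payload["observation.images"]["front"])
    frame_candidate = cv2.imdecode(np.frombuffer(jpg, dtype=np.uint8), cv2.IMREAD_COLOR)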