refactor(robots): update lekiwi client and host code for the new api
This commit is contained in:
parent
6c198d004c
commit
22a15ff755
|
@ -14,13 +14,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
@ -192,12 +189,7 @@ class LeKiwi(Robot):
|
|||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
frame = cam.async_read()
|
||||
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
|
||||
if ret:
|
||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8")
|
||||
else:
|
||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = ""
|
||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
@ -229,15 +221,17 @@ class LeKiwi(Robot):
|
|||
present_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in arm_goal_pos.items()}
|
||||
arm_safe_goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
arm_goal_pos = arm_safe_goal_pos
|
||||
|
||||
# Send goal position to the actuators
|
||||
self.bus.sync_write("Goal_Position", arm_safe_goal_pos)
|
||||
self.bus.sync_write("Goal_Position", arm_goal_pos)
|
||||
self.bus.sync_write("Goal_Speed", base_goal_vel)
|
||||
|
||||
return {**arm_safe_goal_pos, **base_goal_vel}
|
||||
return {**arm_goal_pos, **base_goal_vel}
|
||||
|
||||
def stop_base(self):
|
||||
self.bus.sync_write("Goal_Speed", {name: 0 for name in self.base_motors}, num_retry=5)
|
||||
# TODO(Steven): Check this warning
|
||||
self.bus.sync_write("Goal_Speed", dict.fromkeys(self.base_motors, 0), num_retry=5)
|
||||
logger.info("Base motors stopped")
|
||||
|
||||
def disconnect(self):
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -125,6 +126,14 @@ class LeKiwiClient(Robot):
|
|||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
pass
|
||||
|
||||
def connect(self) -> None:
|
||||
"""Establishes ZMQ sockets with the remote mobile robot"""
|
||||
|
||||
|
@ -354,7 +363,7 @@ class LeKiwiClient(Robot):
|
|||
|
||||
# TODO(Steven): The returned space is different from the get_observation of LeKiwi
|
||||
# This returns body-frames velocities instead of wheel pos/speeds
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
|
@ -418,6 +427,9 @@ class LeKiwiClient(Robot):
|
|||
|
||||
return self._body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
||||
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
# TODO(Steven): This assumes this call is always called from a keyboard teleop command
|
||||
# TODO(Steven): Doing this mapping in here adds latecy between send_action and movement from the user perspective.
|
||||
# t0: get teleop_cmd
|
||||
|
@ -430,7 +442,7 @@ class LeKiwiClient(Robot):
|
|||
# t2': send_action(motor_cmd)
|
||||
# t3': execute motor_cmd
|
||||
# t3'-t2' << t3-t1
|
||||
def send_action(self, action: np.ndarray) -> np.ndarray:
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
||||
|
||||
Args:
|
||||
|
|
|
@ -14,8 +14,6 @@
|
|||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.robots.config import RobotMode
|
||||
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
|
@ -115,7 +113,7 @@ def main():
|
|||
while i < 1000:
|
||||
arm_action = leader_arm.get_action()
|
||||
base_action = keyboard.get_action()
|
||||
action = np.append(arm_action, base_action) if base_action.size > 0 else arm_action
|
||||
action = {**arm_action, **base_action} if base_action.size > 0 else arm_action
|
||||
|
||||
# TODO(Steven): Deal with policy action space
|
||||
# robot.set_mode(RobotMode.AUTO)
|
||||
|
@ -125,8 +123,7 @@ def main():
|
|||
action_sent = robot.send_action(action)
|
||||
observation = robot.get_observation()
|
||||
|
||||
frame = {"action": action_sent}
|
||||
frame.update(observation)
|
||||
frame = {**action_sent, **observation}
|
||||
frame.update({"task": "Dummy Task Dataset"})
|
||||
|
||||
logging.info("Saved a frame into the dataset")
|
||||
|
|
|
@ -14,14 +14,15 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import zmq
|
||||
|
||||
from lerobot.common.constants import OBS_STATE
|
||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .config_lekiwi import LeKiwiConfig
|
||||
from .lekiwi import LeKiwi
|
||||
|
@ -69,7 +70,7 @@ def main():
|
|||
loop_start_time = time.time()
|
||||
try:
|
||||
msg = remote_agent.zmq_cmd_socket.recv_string(zmq.NOBLOCK)
|
||||
data = np.array(json.loads(msg))
|
||||
data = dict(json.loads(msg))
|
||||
_action_sent = robot.send_action(data)
|
||||
last_cmd_time = time.time()
|
||||
except zmq.Again:
|
||||
|
@ -85,6 +86,18 @@ def main():
|
|||
|
||||
last_observation = robot.get_observation()
|
||||
last_observation[OBS_STATE] = last_observation[OBS_STATE].tolist()
|
||||
|
||||
# Encode ndarrays to base64 strings
|
||||
for cam_key, _ in robot.cameras.items():
|
||||
ret, buffer = cv2.imencode(
|
||||
".jpg", last_observation[f"{OBS_IMAGES}.{cam_key}"], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
|
||||
)
|
||||
if ret:
|
||||
last_observation[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8")
|
||||
else:
|
||||
last_observation[f"{OBS_IMAGES}.{cam_key}"] = ""
|
||||
|
||||
# Send the observation to the remote agent
|
||||
remote_agent.zmq_observation_socket.send_string(json.dumps(last_observation))
|
||||
|
||||
# Ensure a short sleep to avoid overloading the CPU.
|
||||
|
|
Loading…
Reference in New Issue