refactor(robots): update lekiwi client and host code for the new api
This commit is contained in:
parent
6c198d004c
commit
22a15ff755
|
@ -14,13 +14,10 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
|
@ -192,12 +189,7 @@ class LeKiwi(Robot):
|
||||||
# Capture images from cameras
|
# Capture images from cameras
|
||||||
for cam_key, cam in self.cameras.items():
|
for cam_key, cam in self.cameras.items():
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
frame = cam.async_read()
|
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
|
||||||
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
|
|
||||||
if ret:
|
|
||||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8")
|
|
||||||
else:
|
|
||||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = ""
|
|
||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
|
@ -229,15 +221,17 @@ class LeKiwi(Robot):
|
||||||
present_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
present_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
||||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in arm_goal_pos.items()}
|
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in arm_goal_pos.items()}
|
||||||
arm_safe_goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
arm_safe_goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||||
|
arm_goal_pos = arm_safe_goal_pos
|
||||||
|
|
||||||
# Send goal position to the actuators
|
# Send goal position to the actuators
|
||||||
self.bus.sync_write("Goal_Position", arm_safe_goal_pos)
|
self.bus.sync_write("Goal_Position", arm_goal_pos)
|
||||||
self.bus.sync_write("Goal_Speed", base_goal_vel)
|
self.bus.sync_write("Goal_Speed", base_goal_vel)
|
||||||
|
|
||||||
return {**arm_safe_goal_pos, **base_goal_vel}
|
return {**arm_goal_pos, **base_goal_vel}
|
||||||
|
|
||||||
def stop_base(self):
|
def stop_base(self):
|
||||||
self.bus.sync_write("Goal_Speed", {name: 0 for name in self.base_motors}, num_retry=5)
|
# TODO(Steven): Check this warning
|
||||||
|
self.bus.sync_write("Goal_Speed", dict.fromkeys(self.base_motors, 0), num_retry=5)
|
||||||
logger.info("Base motors stopped")
|
logger.info("Base motors stopped")
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -125,6 +126,14 @@ class LeKiwiClient(Robot):
|
||||||
}
|
}
|
||||||
return cam_ft
|
return cam_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
def connect(self) -> None:
|
def connect(self) -> None:
|
||||||
"""Establishes ZMQ sockets with the remote mobile robot"""
|
"""Establishes ZMQ sockets with the remote mobile robot"""
|
||||||
|
|
||||||
|
@ -354,7 +363,7 @@ class LeKiwiClient(Robot):
|
||||||
|
|
||||||
# TODO(Steven): The returned space is different from the get_observation of LeKiwi
|
# TODO(Steven): The returned space is different from the get_observation of LeKiwi
|
||||||
# This returns body-frames velocities instead of wheel pos/speeds
|
# This returns body-frames velocities instead of wheel pos/speeds
|
||||||
def get_observation(self) -> dict[str, np.ndarray]:
|
def get_observation(self) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Capture observations from the remote robot: current follower arm positions,
|
Capture observations from the remote robot: current follower arm positions,
|
||||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||||
|
@ -418,6 +427,9 @@ class LeKiwiClient(Robot):
|
||||||
|
|
||||||
return self._body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
return self._body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
||||||
|
|
||||||
|
def configure(self):
|
||||||
|
pass
|
||||||
|
|
||||||
# TODO(Steven): This assumes this call is always called from a keyboard teleop command
|
# TODO(Steven): This assumes this call is always called from a keyboard teleop command
|
||||||
# TODO(Steven): Doing this mapping in here adds latecy between send_action and movement from the user perspective.
|
# TODO(Steven): Doing this mapping in here adds latecy between send_action and movement from the user perspective.
|
||||||
# t0: get teleop_cmd
|
# t0: get teleop_cmd
|
||||||
|
@ -430,7 +442,7 @@ class LeKiwiClient(Robot):
|
||||||
# t2': send_action(motor_cmd)
|
# t2': send_action(motor_cmd)
|
||||||
# t3': execute motor_cmd
|
# t3': execute motor_cmd
|
||||||
# t3'-t2' << t3-t1
|
# t3'-t2' << t3-t1
|
||||||
def send_action(self, action: np.ndarray) -> np.ndarray:
|
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -14,8 +14,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.robots.config import RobotMode
|
from lerobot.common.robots.config import RobotMode
|
||||||
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||||
|
@ -115,7 +113,7 @@ def main():
|
||||||
while i < 1000:
|
while i < 1000:
|
||||||
arm_action = leader_arm.get_action()
|
arm_action = leader_arm.get_action()
|
||||||
base_action = keyboard.get_action()
|
base_action = keyboard.get_action()
|
||||||
action = np.append(arm_action, base_action) if base_action.size > 0 else arm_action
|
action = {**arm_action, **base_action} if base_action.size > 0 else arm_action
|
||||||
|
|
||||||
# TODO(Steven): Deal with policy action space
|
# TODO(Steven): Deal with policy action space
|
||||||
# robot.set_mode(RobotMode.AUTO)
|
# robot.set_mode(RobotMode.AUTO)
|
||||||
|
@ -125,8 +123,7 @@ def main():
|
||||||
action_sent = robot.send_action(action)
|
action_sent = robot.send_action(action)
|
||||||
observation = robot.get_observation()
|
observation = robot.get_observation()
|
||||||
|
|
||||||
frame = {"action": action_sent}
|
frame = {**action_sent, **observation}
|
||||||
frame.update(observation)
|
|
||||||
frame.update({"task": "Dummy Task Dataset"})
|
frame.update({"task": "Dummy Task Dataset"})
|
||||||
|
|
||||||
logging.info("Saved a frame into the dataset")
|
logging.info("Saved a frame into the dataset")
|
||||||
|
|
|
@ -14,14 +14,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import cv2
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from lerobot.common.constants import OBS_STATE
|
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
from .config_lekiwi import LeKiwiConfig
|
from .config_lekiwi import LeKiwiConfig
|
||||||
from .lekiwi import LeKiwi
|
from .lekiwi import LeKiwi
|
||||||
|
@ -69,7 +70,7 @@ def main():
|
||||||
loop_start_time = time.time()
|
loop_start_time = time.time()
|
||||||
try:
|
try:
|
||||||
msg = remote_agent.zmq_cmd_socket.recv_string(zmq.NOBLOCK)
|
msg = remote_agent.zmq_cmd_socket.recv_string(zmq.NOBLOCK)
|
||||||
data = np.array(json.loads(msg))
|
data = dict(json.loads(msg))
|
||||||
_action_sent = robot.send_action(data)
|
_action_sent = robot.send_action(data)
|
||||||
last_cmd_time = time.time()
|
last_cmd_time = time.time()
|
||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
|
@ -85,6 +86,18 @@ def main():
|
||||||
|
|
||||||
last_observation = robot.get_observation()
|
last_observation = robot.get_observation()
|
||||||
last_observation[OBS_STATE] = last_observation[OBS_STATE].tolist()
|
last_observation[OBS_STATE] = last_observation[OBS_STATE].tolist()
|
||||||
|
|
||||||
|
# Encode ndarrays to base64 strings
|
||||||
|
for cam_key, _ in robot.cameras.items():
|
||||||
|
ret, buffer = cv2.imencode(
|
||||||
|
".jpg", last_observation[f"{OBS_IMAGES}.{cam_key}"], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
|
||||||
|
)
|
||||||
|
if ret:
|
||||||
|
last_observation[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8")
|
||||||
|
else:
|
||||||
|
last_observation[f"{OBS_IMAGES}.{cam_key}"] = ""
|
||||||
|
|
||||||
|
# Send the observation to the remote agent
|
||||||
remote_agent.zmq_observation_socket.send_string(json.dumps(last_observation))
|
remote_agent.zmq_observation_socket.send_string(json.dumps(last_observation))
|
||||||
|
|
||||||
# Ensure a short sleep to avoid overloading the CPU.
|
# Ensure a short sleep to avoid overloading the CPU.
|
||||||
|
|
Loading…
Reference in New Issue