refactor(robots): update lekiwi for the latest motor bus api

chore(teleop): Add missing abstract methods to keyboard implementation

refactor(robots): update lekiwi client and host code for the new api

chore(config): update host lekiwi ip in client config

chore(examples): move application scripts to the examples directory

fix(motors): missing type check condition in set_half_turn_homings

fix(robots): fix assumption in calibrate() for robots with more than just an arm

fix(robot): change Mode to Operating_Mode in configure write for lekiwi

fix(robots): make sure message is display in calibrate() method lekiwi

fix(robots): no need for .tolist() in lekiwi host app

fix(teleop): fix is_connected in teleoperator keyboard

fix(teleop): always display calibration message in so100

fix(robots): fix send_action in lekiwi_client

debug(examples): configuration for lekiwi client app

fix(robots): fix send_action in lekiwi client part 2

refactor(robots): use dicts in lekiwi for get_obs and send_action

dbg(robots): check sent action wheels lekiwi

debug(robots): fix overflow base commands

debug(robots): fix how we deal with negative values lekiwi

debug(robots): lekiwi sign degrees fix

fix(robots): right motors id in lekiwi host

chore(doc): update todos
This commit is contained in:
Steven Palma 2025-04-04 14:26:46 +02:00
parent 4311b39e73
commit 4ec2ef575f
No known key found for this signature in database
11 changed files with 181 additions and 135 deletions

View File

@ -14,16 +14,12 @@
import logging
import numpy as np
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.robots.config import RobotMode
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.common.teleoperators.so100 import SO100Leader, SO100LeaderConfig
from .config_lekiwi import LeKiwiClientConfig
from .lekiwi_client import LeKiwiClient
DUMMY_FEATURES = {
"observation.state": {
"dtype": "float64",
@ -82,26 +78,24 @@ DUMMY_FEATURES = {
def main():
logging.info("Configuring Teleop Devices")
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem58760429271")
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem58760434171")
leader_arm = SO100Leader(leader_arm_config)
keyboard_config = KeyboardTeleopConfig()
keyboard = KeyboardTeleop(keyboard_config)
logging.info("Configuring LeKiwi Client")
robot_config = LeKiwiClientConfig(
id="daemonlekiwi", calibration_dir=".cache/calibration/lekiwi", robot_mode=RobotMode.TELEOP
)
robot_config = LeKiwiClientConfig(id="lekiwi", robot_mode=RobotMode.TELEOP)
robot = LeKiwiClient(robot_config)
logging.info("Creating LeRobot Dataset")
# TODO(Steven): Check this creation
dataset = LeRobotDataset.create(
repo_id="user/lekiwi",
fps=10,
features=DUMMY_FEATURES,
)
# # TODO(Steven): Check this creation
# dataset = LeRobotDataset.create(
# repo_id="user/lekiwi2",
# fps=10,
# features=DUMMY_FEATURES,
# )
logging.info("Connecting Teleop Devices")
leader_arm.connect()
@ -110,30 +104,32 @@ def main():
logging.info("Connecting remote LeKiwi")
robot.connect()
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
logging.error("Failed to connect to all devices")
return
logging.info("Starting LeKiwi teleoperation")
i = 0
while i < 1000:
arm_action = leader_arm.get_action()
base_action = keyboard.get_action()
action = np.append(arm_action, base_action) if base_action.size > 0 else arm_action
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
# TODO(Steven): Deal with policy action space
# robot.set_mode(RobotMode.AUTO)
# policy_action = policy.get_action() # This might be in body frame, key space or smt else
# robot.send_action(policy_action)
action_sent = robot.send_action(action)
observation = robot.get_observation()
frame = {"action": action_sent}
frame.update(observation)
frame = {**action_sent, **observation}
frame.update({"task": "Dummy Task Dataset"})
logging.info("Saved a frame into the dataset")
dataset.add_frame(frame)
# dataset.add_frame(frame)
i += 1
dataset.save_episode()
# dataset.save_episode()
# dataset.push_to_hub()
logging.info("Disconnecting Teleop Devices and LeKiwi Client")

View File

@ -48,5 +48,5 @@ default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
# calibration dir
default_calibration_path = HF_LEROBOT_HOME / ".calibration"
default_calibration_path = HF_LEROBOT_HOME / "calibration"
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()

View File

@ -543,7 +543,7 @@ class MotorsBus(abc.ABC):
motors = self.names
elif isinstance(motors, (str, int)):
motors = [motors]
else:
elif not isinstance(motors, list):
raise TypeError(motors)
self.reset_calibration(motors)
@ -603,6 +603,7 @@ class MotorsBus(abc.ABC):
min_ = self.calibration[name].range_min
max_ = self.calibration[name].range_max
bounded_val = min(max_, max(min_, val))
# TODO(Steven): normalization can go boom if max_ == min_, we should add a check probably in record_ranges_of_motions (which probably indicates the user forgot to move a motor)
if self.motors[name].norm_mode is MotorNormMode.RANGE_M100_100:
normalized_values[id_] = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
elif self.motors[name].norm_mode is MotorNormMode.RANGE_0_100:
@ -822,8 +823,11 @@ class MotorsBus(abc.ABC):
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
if isinstance(values, int):
ids_values = {id_: values for id_ in self.ids}
if isinstance(values, int): # TODO(Steven): wouldn't this be instead isinstance(values, Value)?
ids_values = {
id_: values for id_ in self.ids
} # TODO(Steven): And then cast it here to an int if it is not possible to write a float
# TODO(Steven): Consider also doing: ids_values=dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
ids_values = {self._get_motor_id(motor): val for motor, val in values.items()}
else:

View File

@ -48,7 +48,7 @@ class LeKiwiConfig(RobotConfig):
@dataclass
class LeKiwiClientConfig(RobotConfig):
# Network Configuration
remote_ip: str = "172.18.133.90"
remote_ip: str = "172.18.129.208"
port_zmq_cmd: int = 5555
port_zmq_observations: int = 5556

View File

@ -14,13 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import logging
import time
from typing import Any
import cv2
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@ -64,8 +61,8 @@ class LeKiwi(Robot):
"arm_gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
# base
"base_left_wheel": Motor(7, "sts3215", MotorNormMode.RANGE_M100_100),
"base_back_wheel": Motor(8, "sts3215", MotorNormMode.RANGE_M100_100),
"base_right_wheel": Motor(9, "sts3215", MotorNormMode.RANGE_M100_100),
"base_right_wheel": Motor(8, "sts3215", MotorNormMode.RANGE_M100_100),
"base_back_wheel": Motor(9, "sts3215", MotorNormMode.RANGE_M100_100),
},
calibration=self.calibration,
)
@ -119,23 +116,33 @@ class LeKiwi(Robot):
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
# TODO(Steven): I think we should extend this to give the user the option of re-calibrate
# calibrate(recalibrate: bool = False) -> None:
# If true, then we overwrite the previous calibration file with new values
def calibrate(self) -> None:
logger.info(f"\nRunning calibration of {self}")
motors = self.arm_motors + self.base_motors
self.bus.disable_torque(self.arm_motors)
for name in self.arm_motors:
self.bus.write("Operating_Mode", name, OperatingMode.POSITION.value)
input("Move robot to the middle of its range of motion and press ENTER....")
homing_offsets = self.bus.set_half_turn_homings(self.arm_motors)
homing_offsets = self.bus.set_half_turn_homings(motors)
full_turn_motor = "arm_wrist_roll"
unknown_range_motors = [name for name in self.arm_motors if name != full_turn_motor]
logger.info(
# TODO(Steven): Might be worth to do this also in other robots but it should be added in the docs
full_turn_motor = [
motor for motor in motors if any(keyword in motor for keyword in ["wheel", "gripper"])
]
unknown_range_motors = [motor for motor in motors if motor not in full_turn_motor]
print(
f"Move all arm joints except '{full_turn_motor}' sequentially through their "
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
)
range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors)
for name in [*self.base_motors, full_turn_motor]:
for name in full_turn_motor:
range_mins[name] = 0
range_maxes[name] = 4095
@ -159,7 +166,7 @@ class LeKiwi(Robot):
# and torque can be safely disabled to run calibration.
self.bus.disable_torque(self.arm_motors)
for name in self.arm_motors:
self.bus.write("Mode", name, OperatingMode.POSITION.value)
self.bus.write("Operating_Mode", name, OperatingMode.POSITION.value)
# Set P_Coefficient to lower value to avoid shakiness (Default is 32)
self.bus.write("P_Coefficient", name, 16)
# Set I_Coefficient and D_Coefficient to default value 0 and 32
@ -171,15 +178,15 @@ class LeKiwi(Robot):
self.bus.write("Acceleration", name, 254)
for name in self.base_motors:
self.bus.write("Mode", name, OperatingMode.VELOCITY.value)
self.bus.write("Operating_Mode", name, OperatingMode.VELOCITY.value)
self.bus.enable_torque()
self.bus.enable_torque() # TODO(Steven): Operation has failed with: ConnectionError: Failed to write 'Lock' on id_=6 with '1' after 1 tries. [TxRxResult] Incorrect status packet!
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict = {}
obs_dict = {OBS_IMAGES: {}}
# Read actuators position for arm and vel for base
start = time.perf_counter()
@ -192,12 +199,7 @@ class LeKiwi(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
frame = cam.async_read()
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
if ret:
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8")
else:
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = ""
obs_dict[OBS_IMAGES][cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
@ -229,15 +231,19 @@ class LeKiwi(Robot):
present_pos = self.bus.sync_read("Present_Position", self.arm_motors)
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in arm_goal_pos.items()}
arm_safe_goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
arm_goal_pos = arm_safe_goal_pos
# TODO(Steven): Message fetching failed: Magnitude 34072 exceeds 32767 (max for sign_bit_index=15)
# TODO(Steven): IMO, this should be a check in the motor bus
# Send goal position to the actuators
self.bus.sync_write("Goal_Position", arm_safe_goal_pos)
self.bus.sync_write("Goal_Position", arm_goal_pos)
self.bus.sync_write("Goal_Speed", base_goal_vel)
return {**arm_safe_goal_pos, **base_goal_vel}
return {**arm_goal_pos, **base_goal_vel}
def stop_base(self):
self.bus.sync_write("Goal_Speed", {name: 0 for name in self.base_motors}, num_retry=5)
self.bus.sync_write("Goal_Speed", dict.fromkeys(self.base_motors, 0), num_retry=5)
logger.info("Base motors stopped")
def disconnect(self):

View File

@ -15,6 +15,7 @@
import base64
import json
import logging
from typing import Any
import cv2
import numpy as np
@ -40,6 +41,7 @@ from .config_lekiwi import LeKiwiClientConfig
# 2. Adding it into the robot implementation
# (meaning the policy might be needed to be train
# over the teleop action space)
# TODO(Steven): Check if we can move everything to 32 instead
class LeKiwiClient(Robot):
config_class = LeKiwiClientConfig
name = "lekiwi_client"
@ -62,10 +64,9 @@ class LeKiwiClient(Robot):
self.zmq_observation_socket = None
self.last_frames = {}
self.last_present_speed = [0, 0, 0]
self.last_present_speed = {"x_cmd": 0.0, "y_cmd": 0.0, "theta_cmd": 0.0}
# TODO(Steven): Move everything to 32 instead
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float64)
self.last_remote_arm_state = {}
# Define three speed levels and a current index
self.speed_levels = [
@ -75,7 +76,7 @@ class LeKiwiClient(Robot):
]
self.speed_index = 0 # Start at slow
self.is_connected = False
self._is_connected = False
self.logs = {}
@property
@ -108,7 +109,7 @@ class LeKiwiClient(Robot):
@property
def camera_features(self) -> dict[str, dict]:
# TODO(Steven): Get this from the data fetched?
# TODO(Steven): Motor names are unknown for the Daemon
# TODO(Steven): camera names are unknown for the Daemon
# Or assume its size/metadata?
# TODO(Steven): Check consistency of image sizes
cam_ft = {
@ -125,10 +126,18 @@ class LeKiwiClient(Robot):
}
return cam_ft
@property
def is_connected(self) -> bool:
return self._is_connected
@property
def is_calibrated(self) -> bool:
pass
def connect(self) -> None:
"""Establishes ZMQ sockets with the remote mobile robot"""
if self.is_connected:
if self._is_connected:
raise DeviceAlreadyConnectedError(
"LeKiwi Daemon is already connected. Do not run `robot.connect()` twice."
)
@ -144,37 +153,32 @@ class LeKiwiClient(Robot):
self.zmq_observation_socket.connect(zmq_observations_locator)
self.zmq_observation_socket.setsockopt(zmq.CONFLATE, 1)
self.is_connected = True
self._is_connected = True
def calibrate(self) -> None:
# TODO(Steven): Nothing to calibrate.
# Consider triggering calibrate() on the remote mobile robot?
# Although this would require a more complex comms schema
logging.warning("LeKiwiClient has nothing to calibrate.")
return
# Consider moving these static functions out of the class
# Copied from robot_lekiwi MobileManipulator class
# Copied from robot_lekiwi MobileManipulator class* (before the refactor)
@staticmethod
def _degps_to_raw(degps: float) -> int:
steps_per_deg = 4096.0 / 360.0
speed_in_steps = abs(degps) * steps_per_deg
speed_in_steps = degps * steps_per_deg
speed_int = int(round(speed_in_steps))
# Cap the value to fit within signed 16-bit range (-32768 to 32767)
if speed_int > 0x7FFF:
speed_int = 0x7FFF
if degps < 0:
return speed_int | 0x8000
else:
return speed_int & 0x7FFF
speed_int = 0x7FFF # 32767 -> maximum positive value
elif speed_int < -0x8000:
speed_int = -0x8000 # -32768 -> minimum negative value
return speed_int
# Copied from robot_lekiwi MobileManipulator class
@staticmethod
def _raw_to_degps(raw_speed: int) -> float:
steps_per_deg = 4096.0 / 360.0
magnitude = raw_speed & 0x7FFF
magnitude = raw_speed
degps = magnitude / steps_per_deg
if raw_speed & 0x8000:
degps = -degps
return degps
# Copied from robot_lekiwi MobileManipulator class
@ -237,12 +241,13 @@ class LeKiwiClient(Robot):
# Convert each wheels angular speed (deg/s) to a raw integer.
wheel_raw = [LeKiwiClient._degps_to_raw(deg) for deg in wheel_degps]
# TODO(Steven): remove hard-coded names
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
# Copied from robot_lekiwi MobileManipulator class
def _wheel_raw_to_body(
self, wheel_raw: np.array, wheel_radius: float = 0.05, base_radius: float = 0.125
) -> tuple:
self, wheel_raw: dict[str, Any], wheel_radius: float = 0.05, base_radius: float = 0.125
) -> dict[str, Any]:
"""
Convert wheel raw command feedback back into body-frame velocities.
@ -258,8 +263,9 @@ class LeKiwiClient(Robot):
theta_cmd : Rotational velocity in deg/s.
"""
# TODO(Steven): No check is done for dict keys
# Convert each raw command back to an angular speed in deg/s.
wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(r)) for r in wheel_raw])
wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(v)) for _, v in wheel_raw.items()])
# Convert from deg/s to rad/s.
wheel_radps = wheel_degps * (np.pi / 180.0)
# Compute each wheels linear speed (m/s) from its angular speed.
@ -274,7 +280,7 @@ class LeKiwiClient(Robot):
velocity_vector = m_inv.dot(wheel_linear_speeds)
x_cmd, y_cmd, theta_rad = velocity_vector
theta_cmd = theta_rad * (180.0 / np.pi)
return (x_cmd, y_cmd, theta_cmd)
return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": theta_cmd}
# TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
# TODO(Steven): All this function needs to be refactored
@ -285,10 +291,9 @@ class LeKiwiClient(Robot):
nothing arrives for any field, use the last known values."""
frames = {}
present_speed = []
present_speed = {}
# TODO(Steven): Size is being assumed, is this safe?
remote_arm_state_tensor = torch.empty(6, dtype=torch.float64)
remote_arm_state_tensor = {}
# Poll up to 15 ms
poller = zmq.Poller()
@ -317,11 +322,9 @@ class LeKiwiClient(Robot):
# Decode only the final message
try:
observation = json.loads(last_msg)
observation[OBS_STATE] = np.array(observation[OBS_STATE])
# TODO(Steven): Consider getting directly the item with observation[OBS_STATE]
state_observation = {k: v for k, v in observation.items() if k.startswith(OBS_STATE)}
image_observation = {k: v for k, v in observation.items() if k.startswith(OBS_IMAGES)}
state_observation = observation[OBS_STATE]
image_observation = observation[OBS_IMAGES]
# Convert images
for cam_name, image_b64 in image_observation.items():
@ -332,14 +335,16 @@ class LeKiwiClient(Robot):
if frame_candidate is not None:
frames[cam_name] = frame_candidate
# TODO(Steven): Should we really ignore the arm state if the image is None?
# If remote_arm_state is None and frames is None there is no message then use the previous message
if state_observation is not None and frames is not None:
self.last_frames = frames
remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float64)
# TODO(Steven): hard-coded name of expected keys, not good
remote_arm_state_tensor = {k: v for k, v in state_observation.items() if k.startswith("arm")}
self.last_remote_arm_state = remote_arm_state_tensor
present_speed = state_observation[OBS_STATE][6:]
present_speed = {k: v for k, v in state_observation.items() if k.startswith("base")}
self.last_present_speed = present_speed
else:
frames = self.last_frames
@ -354,38 +359,35 @@ class LeKiwiClient(Robot):
# TODO(Steven): The returned space is different from the get_observation of LeKiwi
# This returns body-frames velocities instead of wheel pos/speeds
def get_observation(self) -> dict[str, np.ndarray]:
def get_observation(self) -> dict[str, Any]:
"""
Capture observations from the remote robot: current follower arm positions,
present wheel speeds (converted to body-frame velocities: x, y, theta),
and a camera frame. Receives over ZMQ, translate to body-frame vel
"""
if not self.is_connected:
if not self._is_connected:
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
obs_dict = {}
# TODO(Steven): remove hard-coded cam name
# This is needed at init for when there's no comms
obs_dict = {
OBS_IMAGES: {"wrist": np.zeros(shape=(480, 640, 3)), "front": np.zeros(shape=(640, 480, 3))}
}
frames, present_speed, remote_arm_state_tensor = self._get_data()
body_state = self._wheel_raw_to_body(present_speed)
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float64)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
# TODO(Steven): output isdict[str,Any] and we multiply by 1000.0. This should be more explicit and specify the expected type instead of Any
body_state_mm = {k: v * 1000.0 for k, v in body_state.items()} # Convert x,y to mm/s
obs_dict = {OBS_STATE: combined_state_tensor}
obs_dict[OBS_STATE] = {**remote_arm_state_tensor, **body_state_mm}
# Loop over each configured camera
for cam_name, frame in frames.items():
if frame is None:
# TODO(Steven): Daemon doesn't know camera dimensions
# TODO(Steven): Daemon doesn't know camera dimensions (hard-coded for now), consider at least getting them from state features
logging.warning("Frame is None")
frame = np.zeros((480, 640, 3), dtype=np.uint8)
obs_dict[cam_name] = torch.from_numpy(frame)
# TODO(Steven): Refactor this ugly thing (needed for when there are not comms at init)
if OBS_IMAGES + ".wrist" not in obs_dict:
obs_dict[OBS_IMAGES + ".wrist"] = np.zeros(shape=(480, 640, 3))
if OBS_IMAGES + ".front" not in obs_dict:
obs_dict[OBS_IMAGES + ".front"] = np.zeros(shape=(640, 480, 3))
obs_dict[OBS_IMAGES][cam_name] = torch.from_numpy(frame)
return obs_dict
@ -415,9 +417,11 @@ class LeKiwiClient(Robot):
theta_cmd += theta_speed
if self.teleop_keys["rotate_right"] in pressed_keys:
theta_cmd -= theta_speed
return self._body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
def configure(self):
pass
# TODO(Steven): This assumes this call is always called from a keyboard teleop command
# TODO(Steven): Doing this mapping in here adds latecy between send_action and movement from the user perspective.
# t0: get teleop_cmd
@ -430,7 +434,7 @@ class LeKiwiClient(Robot):
# t2': send_action(motor_cmd)
# t3': execute motor_cmd
# t3'-t2' << t3-t1
def send_action(self, action: np.ndarray) -> np.ndarray:
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
Args:
@ -442,28 +446,40 @@ class LeKiwiClient(Robot):
Returns:
np.ndarray: the action sent to the motors, potentially clipped.
"""
if not self.is_connected:
if not self._is_connected:
raise DeviceNotConnectedError(
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
)
goal_pos: np.array = np.zeros(9)
if self.robot_mode is RobotMode.AUTO:
# TODO(Steven): Not yet implemented. The policy outputs might need a different conversion
raise Exception
goal_pos = {}
# TODO(Steven): This assumes teleop mode is always used with keyboard. Tomorrow we could teleop with another device ... ?
if self.robot_mode is RobotMode.TELEOP:
if action.size < 6:
logging.error("Action should include at least the 6 states of the leader arm")
motors_name = self.state_feature.get("names").get("motors")
common_keys = [
key for key in action if key in (motor.replace("arm_", "") for motor in motors_name)
]
# TODO(Steven): I don't like this
if len(common_keys) < 6:
logging.error("Action should include at least the states of the leader arm")
raise InvalidActionError
goal_pos[:6] = action[:6]
if action.size > 6:
# TODO(Steven): Assumes size and order is respected
wheel_actions = [v for _, v in self._from_keyboard_to_wheel_action(action[6:]).items()]
goal_pos[6:] = wheel_actions
self.zmq_cmd_socket.send_string(json.dumps(goal_pos.tolist())) # action is in motor space
arm_actions = {"arm_" + arm_motor: action[arm_motor] for arm_motor in common_keys}
goal_pos = arm_actions
if len(action) > 6:
keyboard_keys = np.array(list(set(action.keys()) - set(common_keys)))
wheel_actions = {
"base_" + k: v for k, v in self._from_keyboard_to_wheel_action(keyboard_keys).items()
}
goal_pos = {**arm_actions, **wheel_actions}
self.zmq_cmd_socket.send_string(json.dumps(goal_pos)) # action is in motor space
return goal_pos
@ -474,15 +490,14 @@ class LeKiwiClient(Robot):
def disconnect(self):
"""Cleans ZMQ comms"""
if not self.is_connected:
if not self._is_connected:
raise DeviceNotConnectedError(
"LeKiwi is not connected. You need to run `robot.connect()` before disconnecting."
)
# TODO(Steven): Consider sending a stop to the remote mobile robot. Although this would need a moore complex comms schema
self.zmq_observation_socket.close()
self.zmq_cmd_socket.close()
self.zmq_context.term()
self.is_connected = False
self._is_connected = False
def __del__(self):
if getattr(self, "is_connected", False):

View File

@ -14,14 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
import logging
import time
import numpy as np
import cv2
import zmq
from lerobot.common.constants import OBS_STATE
from lerobot.common.constants import OBS_IMAGES
from .config_lekiwi import LeKiwiConfig
from .lekiwi import LeKiwi
@ -69,7 +70,7 @@ def main():
loop_start_time = time.time()
try:
msg = remote_agent.zmq_cmd_socket.recv_string(zmq.NOBLOCK)
data = np.array(json.loads(msg))
data = dict(json.loads(msg))
_action_sent = robot.send_action(data)
last_cmd_time = time.time()
except zmq.Again:
@ -84,7 +85,18 @@ def main():
robot.stop_base()
last_observation = robot.get_observation()
last_observation[OBS_STATE] = last_observation[OBS_STATE].tolist()
# Encode ndarrays to base64 strings
for cam_key, _ in robot.cameras.items():
ret, buffer = cv2.imencode(
".jpg", last_observation[OBS_IMAGES][cam_key], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
)
if ret:
last_observation[OBS_IMAGES][cam_key] = base64.b64encode(buffer).decode("utf-8")
else:
last_observation[OBS_IMAGES][cam_key] = ""
# Send the observation to the remote agent
remote_agent.zmq_observation_socket.send_string(json.dumps(last_observation))
# Ensure a short sleep to avoid overloading the CPU.

View File

@ -19,8 +19,7 @@ import os
import sys
import time
from queue import Queue
import numpy as np
from typing import Any
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@ -59,7 +58,7 @@ class KeyboardTeleop(Teleoperator):
self.event_queue = Queue()
self.current_pressed = {}
self.listener = None
self.is_connected = False
self._is_connected = False
self.logs = {}
@property
@ -75,14 +74,22 @@ class KeyboardTeleop(Teleoperator):
def feedback_feature(self) -> dict:
return {}
@property
def is_connected(self) -> bool:
return self._is_connected
@property
def is_calibrated(self) -> bool:
pass
def connect(self) -> None:
# TODO(Steven): Consider instead of raising a warning and then returning the status
# if self.is_connected:
# if self._is_connected:
# logging.warning(
# "ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
# )
# return self.is_connected
if self.is_connected:
# return self._is_connected
if self._is_connected:
raise DeviceAlreadyConnectedError(
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
)
@ -90,24 +97,24 @@ class KeyboardTeleop(Teleoperator):
if PYNPUT_AVAILABLE:
logging.info("pynput is available - enabling local keyboard listener.")
self.listener = keyboard.Listener(
on_press=self.on_press,
on_release=self.on_release,
on_press=self._on_press,
on_release=self._on_release,
)
self.listener.start()
else:
logging.info("pynput not available - skipping local keyboard listener.")
self.listener = None
self.is_connected = True
self._is_connected = True
def calibrate(self) -> None:
pass
def on_press(self, key):
def _on_press(self, key):
if hasattr(key, "char"):
self.event_queue.put((key.char, True))
def on_release(self, key):
def _on_release(self, key):
if hasattr(key, "char"):
self.event_queue.put((key.char, False))
if key == keyboard.Key.esc:
@ -119,10 +126,13 @@ class KeyboardTeleop(Teleoperator):
key_char, is_pressed = self.event_queue.get_nowait()
self.current_pressed[key_char] = is_pressed
def get_action(self) -> np.ndarray:
def configure(self):
pass
def get_action(self) -> dict[str, Any]:
before_read_t = time.perf_counter()
if not self.is_connected:
if not self._is_connected:
raise DeviceNotConnectedError(
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
)
@ -133,17 +143,17 @@ class KeyboardTeleop(Teleoperator):
action = {key for key, val in self.current_pressed.items() if val}
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
return np.array(list(action))
return dict.fromkeys(action, None)
def send_feedback(self, feedback: np.ndarray) -> None:
def send_feedback(self, feedback: dict[str, Any]) -> None:
pass
def disconnect(self) -> None:
if not self.is_connected:
if not self._is_connected:
raise DeviceNotConnectedError(
"KeyboardTeleop is not connected. You need to run `robot.connect()` before `disconnect()`."
)
if self.listener is not None:
self.listener.stop()
self.is_connected = False
self._is_connected = False

View File

@ -95,7 +95,7 @@ class SO100Leader(Teleoperator):
full_turn_motor = "wrist_roll"
unknown_range_motors = [name for name in self.arm.names if name != full_turn_motor]
logger.info(
print(
f"Move all joints except '{full_turn_motor}' sequentially through their "
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
)

View File

@ -45,6 +45,9 @@ class Teleoperator(abc.ABC):
def is_connected(self) -> bool:
pass
# TODO(Steven): I think connect() should return a bool, such that the client/application code can check if the device connected successfully
# if not device.connect():
# raise DeviceNotConnectedError(f"{device} failed to connect")
@abc.abstractmethod
def connect(self) -> None:
"""Connects to the teleoperator."""