feat(lekiwi): Make dataset recording work

This commit is contained in:
Steven Palma 2025-03-19 10:54:58 +01:00
parent e0d1b75408
commit 0da9063efd
No known key found for this signature in database
5 changed files with 144 additions and 51 deletions

View File

@ -1,16 +1,23 @@
import abc import abc
import enum
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import draccus import draccus
class RobotMode(enum.Enum):
TELEOP = 0
AUTO = 1
@dataclass(kw_only=True) @dataclass(kw_only=True)
class RobotConfig(draccus.ChoiceRegistry, abc.ABC): class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
# Allows to distinguish between different robots of the same type # Allows to distinguish between different robots of the same type
id: str | None = None id: str | None = None
# Directory to store calibration file # Directory to store calibration file
calibration_dir: Path | None = None calibration_dir: Path | None = None
robot_mode: RobotMode | None = None
@property @property
def type(self) -> str: def type(self) -> str:

View File

@ -11,10 +11,6 @@ class DaemonLeKiwiRobotConfig(RobotConfig):
port_zmq_cmd: int = 5555 port_zmq_cmd: int = 5555
port_zmq_observations: int = 5556 port_zmq_observations: int = 5556
id = "daemonlekiwi"
calibration_dir: str = ".cache/calibration/lekiwi"
teleop_keys: dict[str, str] = field( teleop_keys: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
# Movement # Movement

View File

@ -25,8 +25,9 @@ import zmq
from lerobot.common.constants import OBS_IMAGES, OBS_STATE from lerobot.common.constants import OBS_IMAGES, OBS_STATE
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError, InvalidActionError from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError, InvalidActionError
from lerobot.common.robots.config import RobotMode
from ..robot import Robot, RobotMode from ..robot import Robot
from .configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig from .configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig
@ -50,6 +51,7 @@ class DaemonLeKiwiRobot(Robot):
self.config = config self.config = config
self.id = config.id self.id = config.id
self.robot_type = config.type self.robot_type = config.type
self.robot_mode = config.robot_mode
self.remote_ip = config.remote_ip self.remote_ip = config.remote_ip
self.port_zmq_cmd = config.port_zmq_cmd self.port_zmq_cmd = config.port_zmq_cmd
@ -63,7 +65,9 @@ class DaemonLeKiwiRobot(Robot):
self.last_frames = {} self.last_frames = {}
self.last_present_speed = [0, 0, 0] self.last_present_speed = [0, 0, 0]
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
# TODO(Steven): Consider 32 instead
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float64)
# Define three speed levels and a current index # Define three speed levels and a current index
self.speed_levels = [ self.speed_levels = [
@ -81,12 +85,23 @@ class DaemonLeKiwiRobot(Robot):
# TODO(Steven): Get this from the data fetched? # TODO(Steven): Get this from the data fetched?
# TODO(Steven): Motor names are unknown for the Daemon # TODO(Steven): Motor names are unknown for the Daemon
# Or assume its size/metadata? # Or assume its size/metadata?
# return { return {
# "dtype": "float32", "dtype": "float64",
# "shape": (len(self.actuators),), "shape": (9,),
# "names": {"motors": list(self.actuators.motors)}, "names": {
# } "motors": [
pass "arm_shoulder_pan",
"arm_shoulder_lift",
"arm_elbow_flex",
"arm_wrist_flex",
"arm_wrist_roll",
"arm_gripper",
"base_left_wheel",
"base_right_wheel",
"base_back_wheel",
]
},
}
@property @property
def action_feature(self) -> dict: def action_feature(self) -> dict:
@ -97,15 +112,19 @@ class DaemonLeKiwiRobot(Robot):
# TODO(Steven): Get this from the data fetched? # TODO(Steven): Get this from the data fetched?
# TODO(Steven): Motor names are unknown for the Daemon # TODO(Steven): Motor names are unknown for the Daemon
# Or assume its size/metadata? # Or assume its size/metadata?
# cam_ft = {} cam_ft = {
# for cam_key, cam in self.cameras.items(): "front": {
# cam_ft[cam_key] = { "shape": (480, 640, 3),
# "shape": (cam.height, cam.width, cam.channels), "names": ["height", "width", "channels"],
# "names": ["height", "width", "channels"], "info": None,
# "info": None, },
# } "wrist": {
# return cam_ft "shape": (480, 640, 3),
pass "names": ["height", "width", "channels"],
"info": None,
},
}
return cam_ft
def connect(self) -> None: def connect(self) -> None:
"""Establishes ZMQ sockets with the remote mobile robot""" """Establishes ZMQ sockets with the remote mobile robot"""
@ -259,6 +278,7 @@ class DaemonLeKiwiRobot(Robot):
return (x_cmd, y_cmd, theta_cmd) return (x_cmd, y_cmd, theta_cmd)
# TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value # TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
# TODO(Steven): All this function needs to be refactored
def _get_data(self): def _get_data(self):
# Copied from robot_lekiwi.py # Copied from robot_lekiwi.py
"""Polls the video socket for up to 15 ms. If data arrives, decode only """Polls the video socket for up to 15 ms. If data arrives, decode only
@ -269,7 +289,7 @@ class DaemonLeKiwiRobot(Robot):
present_speed = [] present_speed = []
# TODO(Steven): Size is being assumed, is this safe? # TODO(Steven): Size is being assumed, is this safe?
remote_arm_state_tensor = torch.empty(6, dtype=torch.float32) remote_arm_state_tensor = torch.empty(6, dtype=torch.float64)
# Poll up to 15 ms # Poll up to 15 ms
poller = zmq.Poller() poller = zmq.Poller()
@ -317,7 +337,7 @@ class DaemonLeKiwiRobot(Robot):
if state_observation is not None and frames is not None: if state_observation is not None and frames is not None:
self.last_frames = frames self.last_frames = frames
remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float32) remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float64)
self.last_remote_arm_state = remote_arm_state_tensor self.last_remote_arm_state = remote_arm_state_tensor
present_speed = state_observation[OBS_STATE][6:] present_speed = state_observation[OBS_STATE][6:]
@ -351,7 +371,7 @@ class DaemonLeKiwiRobot(Robot):
frames, present_speed, remote_arm_state_tensor = self._get_data() frames, present_speed, remote_arm_state_tensor = self._get_data()
body_state = self._wheel_raw_to_body(present_speed) body_state = self._wheel_raw_to_body(present_speed)
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32) wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float64)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0) combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
obs_dict = {OBS_STATE: combined_state_tensor} obs_dict = {OBS_STATE: combined_state_tensor}
@ -361,9 +381,15 @@ class DaemonLeKiwiRobot(Robot):
if frame is None: if frame is None:
# TODO(Steven): Daemon doesn't know camera dimensions # TODO(Steven): Daemon doesn't know camera dimensions
logging.warning("Frame is None") logging.warning("Frame is None")
# frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8) frame = np.zeros((480, 640, 3), dtype=np.uint8)
obs_dict[cam_name] = torch.from_numpy(frame) obs_dict[cam_name] = torch.from_numpy(frame)
# TODO(Steven): Refactor this ugly thing
if OBS_IMAGES + ".wrist" not in obs_dict:
obs_dict[OBS_IMAGES + ".wrist"] = np.zeros(shape=(480, 640, 3))
if OBS_IMAGES + ".front" not in obs_dict:
obs_dict[OBS_IMAGES + ".front"] = np.zeros(shape=(640, 480, 3))
return obs_dict return obs_dict
def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray): def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray):
@ -425,7 +451,6 @@ class DaemonLeKiwiRobot(Robot):
) )
goal_pos: np.array = np.zeros(9) goal_pos: np.array = np.zeros(9)
if self.robot_mode is RobotMode.AUTO: if self.robot_mode is RobotMode.AUTO:
# TODO(Steven): Not yet implemented. The policy outputs might need a different conversion # TODO(Steven): Not yet implemented. The policy outputs might need a different conversion
raise Exception raise Exception
@ -441,7 +466,6 @@ class DaemonLeKiwiRobot(Robot):
# TODO(Steven): Assumes size and order is respected # TODO(Steven): Assumes size and order is respected
wheel_actions = [v for _, v in self._from_keyboard_to_wheel_action(action[6:]).items()] wheel_actions = [v for _, v in self._from_keyboard_to_wheel_action(action[6:]).items()]
goal_pos[6:] = wheel_actions goal_pos[6:] = wheel_actions
self.zmq_cmd_socket.send_string(json.dumps(goal_pos.tolist())) # action is in motor space self.zmq_cmd_socket.send_string(json.dumps(goal_pos.tolist())) # action is in motor space
return goal_pos return goal_pos

View File

@ -1,13 +1,69 @@
import logging import logging
import time
import numpy as np import numpy as np
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.robots.config import RobotMode
from lerobot.common.robots.lekiwi.configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig from lerobot.common.robots.lekiwi.configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig
from lerobot.common.robots.lekiwi.daemon_lekiwi import DaemonLeKiwiRobot, RobotMode from lerobot.common.robots.lekiwi.daemon_lekiwi import DaemonLeKiwiRobot
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.common.teleoperators.so100 import SO100Teleop, SO100TeleopConfig from lerobot.common.teleoperators.so100 import SO100Teleop, SO100TeleopConfig
DUMMY_FEATURES = {
"observation.state": {
"dtype": "float64",
"shape": (9,),
"names": {
"motors": [
"arm_shoulder_pan",
"arm_shoulder_lift",
"arm_elbow_flex",
"arm_wrist_flex",
"arm_wrist_roll",
"arm_gripper",
"base_left_wheel",
"base_right_wheel",
"base_back_wheel",
]
},
},
"action": {
"dtype": "float64",
"shape": (9,),
"names": {
"motors": [
"arm_shoulder_pan",
"arm_shoulder_lift",
"arm_elbow_flex",
"arm_wrist_flex",
"arm_wrist_roll",
"arm_gripper",
"base_left_wheel",
"base_right_wheel",
"base_back_wheel",
]
},
},
"observation.images.front": {
"dtype": "image",
"shape": (640, 480, 3),
"names": [
"width",
"height",
"channels",
],
},
"observation.images.wrist": {
"dtype": "image",
"shape": (480, 640, 3),
"names": [
"width",
"height",
"channels",
],
},
}
def main(): def main():
logging.info("Configuring Teleop Devices") logging.info("Configuring Teleop Devices")
@ -17,41 +73,58 @@ def main():
keyboard_config = KeyboardTeleopConfig() keyboard_config = KeyboardTeleopConfig()
keyboard = KeyboardTeleop(keyboard_config) keyboard = KeyboardTeleop(keyboard_config)
logging.info("Configuring LeKiwiRobot Daemon")
robot_config = DaemonLeKiwiRobotConfig(
id="daemonlekiwi", calibration_dir=".cache/calibration/lekiwi", robot_mode=RobotMode.TELEOP
)
robot = DaemonLeKiwiRobot(robot_config)
logging.info("Creating LeRobot Dataset")
# TODO(Steven): Check this creation
dataset = LeRobotDataset.create(
repo_id="user/lekiwi",
fps=10,
features=DUMMY_FEATURES,
)
logging.info("Connecting Teleop Devices") logging.info("Connecting Teleop Devices")
leader_arm.connect() leader_arm.connect()
keyboard.connect() keyboard.connect()
logging.info("Configuring LeKiwiRobot Daemon")
robot_config = DaemonLeKiwiRobotConfig()
robot = DaemonLeKiwiRobot(robot_config)
logging.info("Connecting remote LeKiwiRobot") logging.info("Connecting remote LeKiwiRobot")
robot.connect() robot.connect()
robot.robot_mode = RobotMode.TELEOP
logging.info("Starting LeKiwiRobot teleoperation") logging.info("Starting LeKiwiRobot teleoperation")
start = time.perf_counter() i = 0
duration = 0 while i < 1000:
while duration < 100:
arm_action = leader_arm.get_action() arm_action = leader_arm.get_action()
base_action = keyboard.get_action() base_action = keyboard.get_action()
action = np.append(arm_action, base_action) if base_action.size > 0 else arm_action action = np.append(arm_action, base_action) if base_action.size > 0 else arm_action
_action_sent = robot.send_action(action)
_observation = robot.get_observation()
# dataset.save(action_sent, obs)
# TODO(Steven): Deal with policy action space # TODO(Steven): Deal with policy action space
# robot.set_mode(RobotMode.AUTO) # robot.set_mode(RobotMode.AUTO)
# policy_action = policy.get_action() # This might be in body frame, key space or smt else # policy_action = policy.get_action() # This might be in body frame, key space or smt else
# robot.send_action(policy_action) # robot.send_action(policy_action)
duration = time.perf_counter() - start
action_sent = robot.send_action(action)
observation = robot.get_observation()
frame = {"action": action_sent}
frame.update(observation)
frame.update({"task": "Dummy Task Dataset"})
logging.info("Saved a frame into the dataset")
dataset.add_frame(frame)
i += 1
dataset.save_episode()
# dataset.push_to_hub()
logging.info("Disconnecting Teleop Devices and LeKiwiRobot Daemon") logging.info("Disconnecting Teleop Devices and LeKiwiRobot Daemon")
robot.disconnect() robot.disconnect()
leader_arm.disconnect() leader_arm.disconnect()
keyboard.disconnect() keyboard.disconnect()
logging.info("Finished LeKiwiRobot cleanly") logging.info("Finished LeKiwiRobot cleanly")

View File

@ -1,7 +1,5 @@
import abc import abc
from pathlib import Path from pathlib import Path
import enum
from pathlib import Path
from typing import Any from typing import Any
import draccus import draccus
@ -12,11 +10,6 @@ from lerobot.common.motors import MotorCalibration
from .config import RobotConfig from .config import RobotConfig
class RobotMode(enum.Enum):
TELEOP = 0
AUTO = 1
# TODO(aliberts): action/obs typing such as Generic[ObsType, ActType] similar to gym.Env ? # TODO(aliberts): action/obs typing such as Generic[ObsType, ActType] similar to gym.Env ?
# https://github.com/Farama-Foundation/Gymnasium/blob/3287c869f9a48d99454306b0d4b4ec537f0f35e3/gymnasium/core.py#L23 # https://github.com/Farama-Foundation/Gymnasium/blob/3287c869f9a48d99454306b0d4b4ec537f0f35e3/gymnasium/core.py#L23
class Robot(abc.ABC): class Robot(abc.ABC):
@ -28,8 +21,8 @@ class Robot(abc.ABC):
def __init__(self, config: RobotConfig): def __init__(self, config: RobotConfig):
self.robot_type = self.name self.robot_type = self.name
self.robot_mode: RobotMode | None = None
self.id = config.id self.id = config.id
self.robot_mode = config.robot_mode
self.calibration_dir = ( self.calibration_dir = (
Path(config.calibration_dir) Path(config.calibration_dir)
if config.calibration_dir if config.calibration_dir