feat(lekiwi): Make dataset recording work
This commit is contained in:
parent
e0d1b75408
commit
0da9063efd
|
@ -1,16 +1,23 @@
|
|||
import abc
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
class RobotMode(enum.Enum):
|
||||
TELEOP = 0
|
||||
AUTO = 1
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
# Allows to distinguish between different robots of the same type
|
||||
id: str | None = None
|
||||
# Directory to store calibration file
|
||||
calibration_dir: Path | None = None
|
||||
robot_mode: RobotMode | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
|
|
|
@ -11,10 +11,6 @@ class DaemonLeKiwiRobotConfig(RobotConfig):
|
|||
port_zmq_cmd: int = 5555
|
||||
port_zmq_observations: int = 5556
|
||||
|
||||
id = "daemonlekiwi"
|
||||
|
||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||
|
||||
teleop_keys: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
# Movement
|
||||
|
|
|
@ -25,8 +25,9 @@ import zmq
|
|||
|
||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError, InvalidActionError
|
||||
from lerobot.common.robots.config import RobotMode
|
||||
|
||||
from ..robot import Robot, RobotMode
|
||||
from ..robot import Robot
|
||||
from .configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig
|
||||
|
||||
|
||||
|
@ -50,6 +51,7 @@ class DaemonLeKiwiRobot(Robot):
|
|||
self.config = config
|
||||
self.id = config.id
|
||||
self.robot_type = config.type
|
||||
self.robot_mode = config.robot_mode
|
||||
|
||||
self.remote_ip = config.remote_ip
|
||||
self.port_zmq_cmd = config.port_zmq_cmd
|
||||
|
@ -63,7 +65,9 @@ class DaemonLeKiwiRobot(Robot):
|
|||
|
||||
self.last_frames = {}
|
||||
self.last_present_speed = [0, 0, 0]
|
||||
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
|
||||
|
||||
# TODO(Steven): Consider 32 instead
|
||||
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float64)
|
||||
|
||||
# Define three speed levels and a current index
|
||||
self.speed_levels = [
|
||||
|
@ -81,12 +85,23 @@ class DaemonLeKiwiRobot(Robot):
|
|||
# TODO(Steven): Get this from the data fetched?
|
||||
# TODO(Steven): Motor names are unknown for the Daemon
|
||||
# Or assume its size/metadata?
|
||||
# return {
|
||||
# "dtype": "float32",
|
||||
# "shape": (len(self.actuators),),
|
||||
# "names": {"motors": list(self.actuators.motors)},
|
||||
# }
|
||||
pass
|
||||
return {
|
||||
"dtype": "float64",
|
||||
"shape": (9,),
|
||||
"names": {
|
||||
"motors": [
|
||||
"arm_shoulder_pan",
|
||||
"arm_shoulder_lift",
|
||||
"arm_elbow_flex",
|
||||
"arm_wrist_flex",
|
||||
"arm_wrist_roll",
|
||||
"arm_gripper",
|
||||
"base_left_wheel",
|
||||
"base_right_wheel",
|
||||
"base_back_wheel",
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> dict:
|
||||
|
@ -97,15 +112,19 @@ class DaemonLeKiwiRobot(Robot):
|
|||
# TODO(Steven): Get this from the data fetched?
|
||||
# TODO(Steven): Motor names are unknown for the Daemon
|
||||
# Or assume its size/metadata?
|
||||
# cam_ft = {}
|
||||
# for cam_key, cam in self.cameras.items():
|
||||
# cam_ft[cam_key] = {
|
||||
# "shape": (cam.height, cam.width, cam.channels),
|
||||
# "names": ["height", "width", "channels"],
|
||||
# "info": None,
|
||||
# }
|
||||
# return cam_ft
|
||||
pass
|
||||
cam_ft = {
|
||||
"front": {
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
},
|
||||
"wrist": {
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
},
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
def connect(self) -> None:
|
||||
"""Establishes ZMQ sockets with the remote mobile robot"""
|
||||
|
@ -259,6 +278,7 @@ class DaemonLeKiwiRobot(Robot):
|
|||
return (x_cmd, y_cmd, theta_cmd)
|
||||
|
||||
# TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
|
||||
# TODO(Steven): All this function needs to be refactored
|
||||
def _get_data(self):
|
||||
# Copied from robot_lekiwi.py
|
||||
"""Polls the video socket for up to 15 ms. If data arrives, decode only
|
||||
|
@ -269,7 +289,7 @@ class DaemonLeKiwiRobot(Robot):
|
|||
present_speed = []
|
||||
|
||||
# TODO(Steven): Size is being assumed, is this safe?
|
||||
remote_arm_state_tensor = torch.empty(6, dtype=torch.float32)
|
||||
remote_arm_state_tensor = torch.empty(6, dtype=torch.float64)
|
||||
|
||||
# Poll up to 15 ms
|
||||
poller = zmq.Poller()
|
||||
|
@ -317,7 +337,7 @@ class DaemonLeKiwiRobot(Robot):
|
|||
if state_observation is not None and frames is not None:
|
||||
self.last_frames = frames
|
||||
|
||||
remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float32)
|
||||
remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float64)
|
||||
self.last_remote_arm_state = remote_arm_state_tensor
|
||||
|
||||
present_speed = state_observation[OBS_STATE][6:]
|
||||
|
@ -351,7 +371,7 @@ class DaemonLeKiwiRobot(Robot):
|
|||
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
||||
body_state = self._wheel_raw_to_body(present_speed)
|
||||
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float64)
|
||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
||||
|
||||
obs_dict = {OBS_STATE: combined_state_tensor}
|
||||
|
@ -361,9 +381,15 @@ class DaemonLeKiwiRobot(Robot):
|
|||
if frame is None:
|
||||
# TODO(Steven): Daemon doesn't know camera dimensions
|
||||
logging.warning("Frame is None")
|
||||
# frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
|
||||
frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
obs_dict[cam_name] = torch.from_numpy(frame)
|
||||
|
||||
# TODO(Steven): Refactor this ugly thing
|
||||
if OBS_IMAGES + ".wrist" not in obs_dict:
|
||||
obs_dict[OBS_IMAGES + ".wrist"] = np.zeros(shape=(480, 640, 3))
|
||||
if OBS_IMAGES + ".front" not in obs_dict:
|
||||
obs_dict[OBS_IMAGES + ".front"] = np.zeros(shape=(640, 480, 3))
|
||||
|
||||
return obs_dict
|
||||
|
||||
def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray):
|
||||
|
@ -425,7 +451,6 @@ class DaemonLeKiwiRobot(Robot):
|
|||
)
|
||||
|
||||
goal_pos: np.array = np.zeros(9)
|
||||
|
||||
if self.robot_mode is RobotMode.AUTO:
|
||||
# TODO(Steven): Not yet implemented. The policy outputs might need a different conversion
|
||||
raise Exception
|
||||
|
@ -441,7 +466,6 @@ class DaemonLeKiwiRobot(Robot):
|
|||
# TODO(Steven): Assumes size and order is respected
|
||||
wheel_actions = [v for _, v in self._from_keyboard_to_wheel_action(action[6:]).items()]
|
||||
goal_pos[6:] = wheel_actions
|
||||
|
||||
self.zmq_cmd_socket.send_string(json.dumps(goal_pos.tolist())) # action is in motor space
|
||||
|
||||
return goal_pos
|
||||
|
|
|
@ -1,13 +1,69 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.robots.config import RobotMode
|
||||
from lerobot.common.robots.lekiwi.configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig
|
||||
from lerobot.common.robots.lekiwi.daemon_lekiwi import DaemonLeKiwiRobot, RobotMode
|
||||
from lerobot.common.robots.lekiwi.daemon_lekiwi import DaemonLeKiwiRobot
|
||||
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.common.teleoperators.so100 import SO100Teleop, SO100TeleopConfig
|
||||
|
||||
DUMMY_FEATURES = {
|
||||
"observation.state": {
|
||||
"dtype": "float64",
|
||||
"shape": (9,),
|
||||
"names": {
|
||||
"motors": [
|
||||
"arm_shoulder_pan",
|
||||
"arm_shoulder_lift",
|
||||
"arm_elbow_flex",
|
||||
"arm_wrist_flex",
|
||||
"arm_wrist_roll",
|
||||
"arm_gripper",
|
||||
"base_left_wheel",
|
||||
"base_right_wheel",
|
||||
"base_back_wheel",
|
||||
]
|
||||
},
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float64",
|
||||
"shape": (9,),
|
||||
"names": {
|
||||
"motors": [
|
||||
"arm_shoulder_pan",
|
||||
"arm_shoulder_lift",
|
||||
"arm_elbow_flex",
|
||||
"arm_wrist_flex",
|
||||
"arm_wrist_roll",
|
||||
"arm_gripper",
|
||||
"base_left_wheel",
|
||||
"base_right_wheel",
|
||||
"base_back_wheel",
|
||||
]
|
||||
},
|
||||
},
|
||||
"observation.images.front": {
|
||||
"dtype": "image",
|
||||
"shape": (640, 480, 3),
|
||||
"names": [
|
||||
"width",
|
||||
"height",
|
||||
"channels",
|
||||
],
|
||||
},
|
||||
"observation.images.wrist": {
|
||||
"dtype": "image",
|
||||
"shape": (480, 640, 3),
|
||||
"names": [
|
||||
"width",
|
||||
"height",
|
||||
"channels",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
logging.info("Configuring Teleop Devices")
|
||||
|
@ -17,41 +73,58 @@ def main():
|
|||
keyboard_config = KeyboardTeleopConfig()
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
logging.info("Configuring LeKiwiRobot Daemon")
|
||||
robot_config = DaemonLeKiwiRobotConfig(
|
||||
id="daemonlekiwi", calibration_dir=".cache/calibration/lekiwi", robot_mode=RobotMode.TELEOP
|
||||
)
|
||||
robot = DaemonLeKiwiRobot(robot_config)
|
||||
|
||||
logging.info("Creating LeRobot Dataset")
|
||||
|
||||
# TODO(Steven): Check this creation
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="user/lekiwi",
|
||||
fps=10,
|
||||
features=DUMMY_FEATURES,
|
||||
)
|
||||
|
||||
logging.info("Connecting Teleop Devices")
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
logging.info("Configuring LeKiwiRobot Daemon")
|
||||
robot_config = DaemonLeKiwiRobotConfig()
|
||||
robot = DaemonLeKiwiRobot(robot_config)
|
||||
|
||||
logging.info("Connecting remote LeKiwiRobot")
|
||||
robot.connect()
|
||||
robot.robot_mode = RobotMode.TELEOP
|
||||
|
||||
logging.info("Starting LeKiwiRobot teleoperation")
|
||||
start = time.perf_counter()
|
||||
duration = 0
|
||||
while duration < 100:
|
||||
i = 0
|
||||
while i < 1000:
|
||||
arm_action = leader_arm.get_action()
|
||||
base_action = keyboard.get_action()
|
||||
action = np.append(arm_action, base_action) if base_action.size > 0 else arm_action
|
||||
_action_sent = robot.send_action(action)
|
||||
_observation = robot.get_observation()
|
||||
|
||||
# dataset.save(action_sent, obs)
|
||||
|
||||
# TODO(Steven): Deal with policy action space
|
||||
# robot.set_mode(RobotMode.AUTO)
|
||||
# policy_action = policy.get_action() # This might be in body frame, key space or smt else
|
||||
# robot.send_action(policy_action)
|
||||
duration = time.perf_counter() - start
|
||||
|
||||
action_sent = robot.send_action(action)
|
||||
observation = robot.get_observation()
|
||||
|
||||
frame = {"action": action_sent}
|
||||
frame.update(observation)
|
||||
frame.update({"task": "Dummy Task Dataset"})
|
||||
|
||||
logging.info("Saved a frame into the dataset")
|
||||
dataset.add_frame(frame)
|
||||
i += 1
|
||||
|
||||
dataset.save_episode()
|
||||
# dataset.push_to_hub()
|
||||
|
||||
logging.info("Disconnecting Teleop Devices and LeKiwiRobot Daemon")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
|
||||
logging.info("Finished LeKiwiRobot cleanly")
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import abc
|
||||
from pathlib import Path
|
||||
import enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
|
@ -12,11 +10,6 @@ from lerobot.common.motors import MotorCalibration
|
|||
from .config import RobotConfig
|
||||
|
||||
|
||||
class RobotMode(enum.Enum):
|
||||
TELEOP = 0
|
||||
AUTO = 1
|
||||
|
||||
|
||||
# TODO(aliberts): action/obs typing such as Generic[ObsType, ActType] similar to gym.Env ?
|
||||
# https://github.com/Farama-Foundation/Gymnasium/blob/3287c869f9a48d99454306b0d4b4ec537f0f35e3/gymnasium/core.py#L23
|
||||
class Robot(abc.ABC):
|
||||
|
@ -28,8 +21,8 @@ class Robot(abc.ABC):
|
|||
|
||||
def __init__(self, config: RobotConfig):
|
||||
self.robot_type = self.name
|
||||
self.robot_mode: RobotMode | None = None
|
||||
self.id = config.id
|
||||
self.robot_mode = config.robot_mode
|
||||
self.calibration_dir = (
|
||||
Path(config.calibration_dir)
|
||||
if config.calibration_dir
|
||||
|
|
Loading…
Reference in New Issue