feat(lekiwi): Make dataset recording work
This commit is contained in:
parent
e0d1b75408
commit
0da9063efd
|
@ -1,16 +1,23 @@
|
||||||
import abc
|
import abc
|
||||||
|
import enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
|
|
||||||
|
class RobotMode(enum.Enum):
|
||||||
|
TELEOP = 0
|
||||||
|
AUTO = 1
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
|
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||||
# Allows to distinguish between different robots of the same type
|
# Allows to distinguish between different robots of the same type
|
||||||
id: str | None = None
|
id: str | None = None
|
||||||
# Directory to store calibration file
|
# Directory to store calibration file
|
||||||
calibration_dir: Path | None = None
|
calibration_dir: Path | None = None
|
||||||
|
robot_mode: RobotMode | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
|
|
|
@ -11,10 +11,6 @@ class DaemonLeKiwiRobotConfig(RobotConfig):
|
||||||
port_zmq_cmd: int = 5555
|
port_zmq_cmd: int = 5555
|
||||||
port_zmq_observations: int = 5556
|
port_zmq_observations: int = 5556
|
||||||
|
|
||||||
id = "daemonlekiwi"
|
|
||||||
|
|
||||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
|
||||||
|
|
||||||
teleop_keys: dict[str, str] = field(
|
teleop_keys: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
# Movement
|
# Movement
|
||||||
|
|
|
@ -25,8 +25,9 @@ import zmq
|
||||||
|
|
||||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError, InvalidActionError
|
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError, InvalidActionError
|
||||||
|
from lerobot.common.robots.config import RobotMode
|
||||||
|
|
||||||
from ..robot import Robot, RobotMode
|
from ..robot import Robot
|
||||||
from .configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig
|
from .configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,6 +51,7 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.id = config.id
|
self.id = config.id
|
||||||
self.robot_type = config.type
|
self.robot_type = config.type
|
||||||
|
self.robot_mode = config.robot_mode
|
||||||
|
|
||||||
self.remote_ip = config.remote_ip
|
self.remote_ip = config.remote_ip
|
||||||
self.port_zmq_cmd = config.port_zmq_cmd
|
self.port_zmq_cmd = config.port_zmq_cmd
|
||||||
|
@ -63,7 +65,9 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
|
|
||||||
self.last_frames = {}
|
self.last_frames = {}
|
||||||
self.last_present_speed = [0, 0, 0]
|
self.last_present_speed = [0, 0, 0]
|
||||||
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
|
|
||||||
|
# TODO(Steven): Consider 32 instead
|
||||||
|
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float64)
|
||||||
|
|
||||||
# Define three speed levels and a current index
|
# Define three speed levels and a current index
|
||||||
self.speed_levels = [
|
self.speed_levels = [
|
||||||
|
@ -81,12 +85,23 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
# TODO(Steven): Get this from the data fetched?
|
# TODO(Steven): Get this from the data fetched?
|
||||||
# TODO(Steven): Motor names are unknown for the Daemon
|
# TODO(Steven): Motor names are unknown for the Daemon
|
||||||
# Or assume its size/metadata?
|
# Or assume its size/metadata?
|
||||||
# return {
|
return {
|
||||||
# "dtype": "float32",
|
"dtype": "float64",
|
||||||
# "shape": (len(self.actuators),),
|
"shape": (9,),
|
||||||
# "names": {"motors": list(self.actuators.motors)},
|
"names": {
|
||||||
# }
|
"motors": [
|
||||||
pass
|
"arm_shoulder_pan",
|
||||||
|
"arm_shoulder_lift",
|
||||||
|
"arm_elbow_flex",
|
||||||
|
"arm_wrist_flex",
|
||||||
|
"arm_wrist_roll",
|
||||||
|
"arm_gripper",
|
||||||
|
"base_left_wheel",
|
||||||
|
"base_right_wheel",
|
||||||
|
"base_back_wheel",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_feature(self) -> dict:
|
def action_feature(self) -> dict:
|
||||||
|
@ -97,15 +112,19 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
# TODO(Steven): Get this from the data fetched?
|
# TODO(Steven): Get this from the data fetched?
|
||||||
# TODO(Steven): Motor names are unknown for the Daemon
|
# TODO(Steven): Motor names are unknown for the Daemon
|
||||||
# Or assume its size/metadata?
|
# Or assume its size/metadata?
|
||||||
# cam_ft = {}
|
cam_ft = {
|
||||||
# for cam_key, cam in self.cameras.items():
|
"front": {
|
||||||
# cam_ft[cam_key] = {
|
"shape": (480, 640, 3),
|
||||||
# "shape": (cam.height, cam.width, cam.channels),
|
"names": ["height", "width", "channels"],
|
||||||
# "names": ["height", "width", "channels"],
|
"info": None,
|
||||||
# "info": None,
|
},
|
||||||
# }
|
"wrist": {
|
||||||
# return cam_ft
|
"shape": (480, 640, 3),
|
||||||
pass
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cam_ft
|
||||||
|
|
||||||
def connect(self) -> None:
|
def connect(self) -> None:
|
||||||
"""Establishes ZMQ sockets with the remote mobile robot"""
|
"""Establishes ZMQ sockets with the remote mobile robot"""
|
||||||
|
@ -259,6 +278,7 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
return (x_cmd, y_cmd, theta_cmd)
|
return (x_cmd, y_cmd, theta_cmd)
|
||||||
|
|
||||||
# TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
|
# TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
|
||||||
|
# TODO(Steven): All this function needs to be refactored
|
||||||
def _get_data(self):
|
def _get_data(self):
|
||||||
# Copied from robot_lekiwi.py
|
# Copied from robot_lekiwi.py
|
||||||
"""Polls the video socket for up to 15 ms. If data arrives, decode only
|
"""Polls the video socket for up to 15 ms. If data arrives, decode only
|
||||||
|
@ -269,7 +289,7 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
present_speed = []
|
present_speed = []
|
||||||
|
|
||||||
# TODO(Steven): Size is being assumed, is this safe?
|
# TODO(Steven): Size is being assumed, is this safe?
|
||||||
remote_arm_state_tensor = torch.empty(6, dtype=torch.float32)
|
remote_arm_state_tensor = torch.empty(6, dtype=torch.float64)
|
||||||
|
|
||||||
# Poll up to 15 ms
|
# Poll up to 15 ms
|
||||||
poller = zmq.Poller()
|
poller = zmq.Poller()
|
||||||
|
@ -317,7 +337,7 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
if state_observation is not None and frames is not None:
|
if state_observation is not None and frames is not None:
|
||||||
self.last_frames = frames
|
self.last_frames = frames
|
||||||
|
|
||||||
remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float32)
|
remote_arm_state_tensor = torch.tensor(state_observation[OBS_STATE][:6], dtype=torch.float64)
|
||||||
self.last_remote_arm_state = remote_arm_state_tensor
|
self.last_remote_arm_state = remote_arm_state_tensor
|
||||||
|
|
||||||
present_speed = state_observation[OBS_STATE][6:]
|
present_speed = state_observation[OBS_STATE][6:]
|
||||||
|
@ -351,7 +371,7 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
||||||
body_state = self._wheel_raw_to_body(present_speed)
|
body_state = self._wheel_raw_to_body(present_speed)
|
||||||
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
||||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float64)
|
||||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
||||||
|
|
||||||
obs_dict = {OBS_STATE: combined_state_tensor}
|
obs_dict = {OBS_STATE: combined_state_tensor}
|
||||||
|
@ -361,9 +381,15 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
if frame is None:
|
if frame is None:
|
||||||
# TODO(Steven): Daemon doesn't know camera dimensions
|
# TODO(Steven): Daemon doesn't know camera dimensions
|
||||||
logging.warning("Frame is None")
|
logging.warning("Frame is None")
|
||||||
# frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
|
frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||||
obs_dict[cam_name] = torch.from_numpy(frame)
|
obs_dict[cam_name] = torch.from_numpy(frame)
|
||||||
|
|
||||||
|
# TODO(Steven): Refactor this ugly thing
|
||||||
|
if OBS_IMAGES + ".wrist" not in obs_dict:
|
||||||
|
obs_dict[OBS_IMAGES + ".wrist"] = np.zeros(shape=(480, 640, 3))
|
||||||
|
if OBS_IMAGES + ".front" not in obs_dict:
|
||||||
|
obs_dict[OBS_IMAGES + ".front"] = np.zeros(shape=(640, 480, 3))
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray):
|
def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray):
|
||||||
|
@ -425,7 +451,6 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
)
|
)
|
||||||
|
|
||||||
goal_pos: np.array = np.zeros(9)
|
goal_pos: np.array = np.zeros(9)
|
||||||
|
|
||||||
if self.robot_mode is RobotMode.AUTO:
|
if self.robot_mode is RobotMode.AUTO:
|
||||||
# TODO(Steven): Not yet implemented. The policy outputs might need a different conversion
|
# TODO(Steven): Not yet implemented. The policy outputs might need a different conversion
|
||||||
raise Exception
|
raise Exception
|
||||||
|
@ -441,7 +466,6 @@ class DaemonLeKiwiRobot(Robot):
|
||||||
# TODO(Steven): Assumes size and order is respected
|
# TODO(Steven): Assumes size and order is respected
|
||||||
wheel_actions = [v for _, v in self._from_keyboard_to_wheel_action(action[6:]).items()]
|
wheel_actions = [v for _, v in self._from_keyboard_to_wheel_action(action[6:]).items()]
|
||||||
goal_pos[6:] = wheel_actions
|
goal_pos[6:] = wheel_actions
|
||||||
|
|
||||||
self.zmq_cmd_socket.send_string(json.dumps(goal_pos.tolist())) # action is in motor space
|
self.zmq_cmd_socket.send_string(json.dumps(goal_pos.tolist())) # action is in motor space
|
||||||
|
|
||||||
return goal_pos
|
return goal_pos
|
||||||
|
|
|
@ -1,13 +1,69 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.robots.config import RobotMode
|
||||||
from lerobot.common.robots.lekiwi.configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig
|
from lerobot.common.robots.lekiwi.configuration_daemon_lekiwi import DaemonLeKiwiRobotConfig
|
||||||
from lerobot.common.robots.lekiwi.daemon_lekiwi import DaemonLeKiwiRobot, RobotMode
|
from lerobot.common.robots.lekiwi.daemon_lekiwi import DaemonLeKiwiRobot
|
||||||
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||||
from lerobot.common.teleoperators.so100 import SO100Teleop, SO100TeleopConfig
|
from lerobot.common.teleoperators.so100 import SO100Teleop, SO100TeleopConfig
|
||||||
|
|
||||||
|
DUMMY_FEATURES = {
|
||||||
|
"observation.state": {
|
||||||
|
"dtype": "float64",
|
||||||
|
"shape": (9,),
|
||||||
|
"names": {
|
||||||
|
"motors": [
|
||||||
|
"arm_shoulder_pan",
|
||||||
|
"arm_shoulder_lift",
|
||||||
|
"arm_elbow_flex",
|
||||||
|
"arm_wrist_flex",
|
||||||
|
"arm_wrist_roll",
|
||||||
|
"arm_gripper",
|
||||||
|
"base_left_wheel",
|
||||||
|
"base_right_wheel",
|
||||||
|
"base_back_wheel",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"dtype": "float64",
|
||||||
|
"shape": (9,),
|
||||||
|
"names": {
|
||||||
|
"motors": [
|
||||||
|
"arm_shoulder_pan",
|
||||||
|
"arm_shoulder_lift",
|
||||||
|
"arm_elbow_flex",
|
||||||
|
"arm_wrist_flex",
|
||||||
|
"arm_wrist_roll",
|
||||||
|
"arm_gripper",
|
||||||
|
"base_left_wheel",
|
||||||
|
"base_right_wheel",
|
||||||
|
"base_back_wheel",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"observation.images.front": {
|
||||||
|
"dtype": "image",
|
||||||
|
"shape": (640, 480, 3),
|
||||||
|
"names": [
|
||||||
|
"width",
|
||||||
|
"height",
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"observation.images.wrist": {
|
||||||
|
"dtype": "image",
|
||||||
|
"shape": (480, 640, 3),
|
||||||
|
"names": [
|
||||||
|
"width",
|
||||||
|
"height",
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
logging.info("Configuring Teleop Devices")
|
logging.info("Configuring Teleop Devices")
|
||||||
|
@ -17,41 +73,58 @@ def main():
|
||||||
keyboard_config = KeyboardTeleopConfig()
|
keyboard_config = KeyboardTeleopConfig()
|
||||||
keyboard = KeyboardTeleop(keyboard_config)
|
keyboard = KeyboardTeleop(keyboard_config)
|
||||||
|
|
||||||
|
logging.info("Configuring LeKiwiRobot Daemon")
|
||||||
|
robot_config = DaemonLeKiwiRobotConfig(
|
||||||
|
id="daemonlekiwi", calibration_dir=".cache/calibration/lekiwi", robot_mode=RobotMode.TELEOP
|
||||||
|
)
|
||||||
|
robot = DaemonLeKiwiRobot(robot_config)
|
||||||
|
|
||||||
|
logging.info("Creating LeRobot Dataset")
|
||||||
|
|
||||||
|
# TODO(Steven): Check this creation
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
repo_id="user/lekiwi",
|
||||||
|
fps=10,
|
||||||
|
features=DUMMY_FEATURES,
|
||||||
|
)
|
||||||
|
|
||||||
logging.info("Connecting Teleop Devices")
|
logging.info("Connecting Teleop Devices")
|
||||||
leader_arm.connect()
|
leader_arm.connect()
|
||||||
keyboard.connect()
|
keyboard.connect()
|
||||||
|
|
||||||
logging.info("Configuring LeKiwiRobot Daemon")
|
|
||||||
robot_config = DaemonLeKiwiRobotConfig()
|
|
||||||
robot = DaemonLeKiwiRobot(robot_config)
|
|
||||||
|
|
||||||
logging.info("Connecting remote LeKiwiRobot")
|
logging.info("Connecting remote LeKiwiRobot")
|
||||||
robot.connect()
|
robot.connect()
|
||||||
robot.robot_mode = RobotMode.TELEOP
|
|
||||||
|
|
||||||
logging.info("Starting LeKiwiRobot teleoperation")
|
logging.info("Starting LeKiwiRobot teleoperation")
|
||||||
start = time.perf_counter()
|
i = 0
|
||||||
duration = 0
|
while i < 1000:
|
||||||
while duration < 100:
|
|
||||||
arm_action = leader_arm.get_action()
|
arm_action = leader_arm.get_action()
|
||||||
base_action = keyboard.get_action()
|
base_action = keyboard.get_action()
|
||||||
action = np.append(arm_action, base_action) if base_action.size > 0 else arm_action
|
action = np.append(arm_action, base_action) if base_action.size > 0 else arm_action
|
||||||
_action_sent = robot.send_action(action)
|
|
||||||
_observation = robot.get_observation()
|
|
||||||
|
|
||||||
# dataset.save(action_sent, obs)
|
|
||||||
|
|
||||||
# TODO(Steven): Deal with policy action space
|
# TODO(Steven): Deal with policy action space
|
||||||
# robot.set_mode(RobotMode.AUTO)
|
# robot.set_mode(RobotMode.AUTO)
|
||||||
# policy_action = policy.get_action() # This might be in body frame, key space or smt else
|
# policy_action = policy.get_action() # This might be in body frame, key space or smt else
|
||||||
# robot.send_action(policy_action)
|
# robot.send_action(policy_action)
|
||||||
duration = time.perf_counter() - start
|
|
||||||
|
action_sent = robot.send_action(action)
|
||||||
|
observation = robot.get_observation()
|
||||||
|
|
||||||
|
frame = {"action": action_sent}
|
||||||
|
frame.update(observation)
|
||||||
|
frame.update({"task": "Dummy Task Dataset"})
|
||||||
|
|
||||||
|
logging.info("Saved a frame into the dataset")
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
# dataset.push_to_hub()
|
||||||
|
|
||||||
logging.info("Disconnecting Teleop Devices and LeKiwiRobot Daemon")
|
logging.info("Disconnecting Teleop Devices and LeKiwiRobot Daemon")
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
leader_arm.disconnect()
|
leader_arm.disconnect()
|
||||||
keyboard.disconnect()
|
keyboard.disconnect()
|
||||||
|
|
||||||
logging.info("Finished LeKiwiRobot cleanly")
|
logging.info("Finished LeKiwiRobot cleanly")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
import abc
|
import abc
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
@ -12,11 +10,6 @@ from lerobot.common.motors import MotorCalibration
|
||||||
from .config import RobotConfig
|
from .config import RobotConfig
|
||||||
|
|
||||||
|
|
||||||
class RobotMode(enum.Enum):
|
|
||||||
TELEOP = 0
|
|
||||||
AUTO = 1
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): action/obs typing such as Generic[ObsType, ActType] similar to gym.Env ?
|
# TODO(aliberts): action/obs typing such as Generic[ObsType, ActType] similar to gym.Env ?
|
||||||
# https://github.com/Farama-Foundation/Gymnasium/blob/3287c869f9a48d99454306b0d4b4ec537f0f35e3/gymnasium/core.py#L23
|
# https://github.com/Farama-Foundation/Gymnasium/blob/3287c869f9a48d99454306b0d4b4ec537f0f35e3/gymnasium/core.py#L23
|
||||||
class Robot(abc.ABC):
|
class Robot(abc.ABC):
|
||||||
|
@ -28,8 +21,8 @@ class Robot(abc.ABC):
|
||||||
|
|
||||||
def __init__(self, config: RobotConfig):
|
def __init__(self, config: RobotConfig):
|
||||||
self.robot_type = self.name
|
self.robot_type = self.name
|
||||||
self.robot_mode: RobotMode | None = None
|
|
||||||
self.id = config.id
|
self.id = config.id
|
||||||
|
self.robot_mode = config.robot_mode
|
||||||
self.calibration_dir = (
|
self.calibration_dir = (
|
||||||
Path(config.calibration_dir)
|
Path(config.calibration_dir)
|
||||||
if config.calibration_dir
|
if config.calibration_dir
|
||||||
|
|
Loading…
Reference in New Issue