From ae70f12378f4c7b94cad5bfdba3c51e24c7ab381 Mon Sep 17 00:00:00 2001 From: ruanafan Date: Tue, 1 Apr 2025 18:20:36 +0800 Subject: [PATCH] Support openloop eval and serving --- lerobot/common/datasets/lerobot_dataset.py | 27 +++ lerobot/common/datasets/utils.py | 3 + .../v21/convert_dataset_v20_to_v21.py | 3 +- lerobot/common/policies/normalize.py | 1 + .../common/serving/websocket_policy_server.py | 111 +++++++++++ lerobot/common/utils/msgpack_utils.py | 43 +++++ lerobot/scripts/openloop_eval.py | 180 ++++++++++++++++++ lerobot/scripts/serving_policy.py | 81 ++++++++ lerobot/scripts/train.py | 45 ++++- pyproject.toml | 1 + tests/serving/test_websocket_serving.py | 43 +++++ 11 files changed, 535 insertions(+), 3 deletions(-) create mode 100644 lerobot/common/serving/websocket_policy_server.py create mode 100644 lerobot/common/utils/msgpack_utils.py create mode 100644 lerobot/scripts/openloop_eval.py create mode 100644 lerobot/scripts/serving_policy.py create mode 100644 tests/serving/test_websocket_serving.py diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 6ef955dd..c03ca2f0 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -103,6 +103,14 @@ class LeRobotDatasetMetadata: def load_metadata(self): self.info = load_info(self.root) + # self.info['features']['observation.state']['shape'] = (13,) + # self.info['features']['observation.state']['names'] = ['left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6', + # 'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum'] + # self.info['features']['action']['shape'] = (13,) + # self.info['features']['action']['names'] = ['left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6', + # 'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp'] + + check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks, self.task_to_task_index = load_tasks(self.root) self.episodes = load_episodes(self.root) @@ -110,7 +118,15 @@ class LeRobotDatasetMetadata: self.stats = load_stats(self.root) self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) else: + # print("v2.1 episode stat process") self.episodes_stats = load_episodes_stats(self.root) + # print("episode stats: ", type(self.episodes_stats)) + # for _, episode_stats in self.episodes_stats.items(): + # for feature in ['observation.state', 'action']: + # for sub_feature in ['min', 'max', 'mean', 'std']: + # episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6,13], axis=0) + # print('after process', feature, sub_feature, episode_stats[feature][sub_feature].shape) + self.stats = aggregate_stats(list(self.episodes_stats.values())) def pull_from_repo( @@ -731,6 +747,17 @@ class LeRobotDataset(torch.utils.data.Dataset): query_result = self._query_hf_dataset(query_indices) item = {**item, **padding} for key, val in query_result.items(): + # # remove gripper in original data + # if "observation.state" == key: + # keep_cols = [i for i in range(15) if i not in {6, 13}] + # # 切片操作删除指定列 + # val = val[:, keep_cols] + # # print("emmmmmm in __getitem_ state", val.shape) + # if "action" == key: + # keep_cols = [i for i in range(15) if i not in {6, 13}] + # # 切片操作删除指定列 + # val = val[:, keep_cols] + # # print("emmmmmm in __getitem_ action", val.shape) item[key] = val if len(self.meta.video_keys) > 0: diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 7e297b35..4cc1f3e7 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -419,6 +419,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea type = FeatureType.ACTION else: continue + + if 'depth' in key: + continue policy_features[key] = PolicyFeature( type=type, diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py index 176d16d0..7e847061 100644 --- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -59,7 +59,8 @@ def convert_dataset( num_workers: int = 4, ): with SuppressWarnings(): - dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) + # dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) + dataset = LeRobotDataset("move_reel_test_0322") if (dataset.root / EPISODES_STATS_PATH).is_file(): (dataset.root / EPISODES_STATS_PATH).unlink() diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index b3255ec1..27e72c11 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -168,6 +168,7 @@ class Normalize(nn.Module): std = buffer["std"] assert not torch.isinf(mean).any(), _no_stats_error_str("mean") assert not torch.isinf(std).any(), _no_stats_error_str("std") + print("process key:", key) batch[key] = (batch[key] - mean) / (std + 1e-8) elif norm_mode is NormalizationMode.MIN_MAX: min = buffer["min"] diff --git a/lerobot/common/serving/websocket_policy_server.py b/lerobot/common/serving/websocket_policy_server.py new file mode 100644 index 00000000..2660515e --- /dev/null +++ b/lerobot/common/serving/websocket_policy_server.py @@ -0,0 +1,111 @@ +import asyncio +import logging +import traceback +import einops + +import numpy as np +import torch +import websockets.asyncio.server +import websockets.frames + +import lerobot.common.utils.msgpack_utils as msgpack_utils +from lerobot.common.policies.pretrained import PreTrainedPolicy + +class WebsocketPolicyServer: + """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation. + + Currently only implements the `load` and `infer` methods. + """ + + def __init__( + self, + policy: PreTrainedPolicy, + host: str = "0.0.0.0", + port: int = 8000, + metadata: dict | None = None, + ) -> None: + self._policy = policy + self._host = host + self._port = port + self._metadata = metadata or {} + + def serve_forever(self) -> None: + asyncio.run(self.run()) + + async def run(self): + async with websockets.asyncio.server.serve( + self._handler, + self._host, + self._port, + compression=None, + max_size=None, + ) as server: + await server.serve_forever() + + async def preprocess_observation(self, observations: dict) -> dict[str, torch.Tensor]: + images = observations['images'] + return_observations = {} + + for imgkey, img in images.items(): + img = torch.from_numpy(img) + img = einops.rearrange(img, "h w c -> c h w").contiguous() + c, h, w = img.shape + if h == 360: + img = torch.nn.functional.pad(img,pad=(0, 0, 60, 60), mode='constant', value=0) + c, h, w = img.shape + + assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" + assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" + img = img.unsqueeze(0) + img = img.type(torch.float32) + img /= 255 + + imgkey = f"observation.images.{imgkey}" + print('add a key: ', imgkey, img.shape) + return_observations[imgkey] = img + + return_observations["observation.state"] = torch.from_numpy(observations["state"]).float() + return_observations["observation.state"] = return_observations["observation.state"].unsqueeze(0) + + return return_observations + + + async def _handler(self, websocket: websockets.asyncio.server.ServerConnection): + logging.info(f"Connection from {websocket.remote_address} opened") + packer = msgpack_utils.Packer() + + await websocket.send(packer.pack(self._metadata)) + + while True: + try: + # example + # obs = { + # "images": { + # "cam_high": numpy.NDArray, + # "cam_right_wrist": numpy.NDArray, + # }, + # "state": numpy.NDarray, + # "prompt": "xxx text" + # } + obs = msgpack_utils.unpackb(await websocket.recv()) + + obs = await self.preprocess_observation(obs) + for key in obs: + if isinstance(obs[key], torch.Tensor): + obs[key] = obs[key].to("cuda", non_blocking=True) + + with torch.inference_mode(): + action = self._policy.select_action(obs) + print("inference once with action:", action) + res = {"actions": action.cpu().numpy()} + await websocket.send(packer.pack(res)) + except websockets.ConnectionClosed: + logging.info(f"Connection from {websocket.remote_address} closed") + break + except Exception: + await websocket.send(traceback.format_exc()) + await websocket.close( + code=websockets.frames.CloseCode.INTERNAL_ERROR, + reason="Internal server error. Traceback included in previous frame.", + ) + raise diff --git a/lerobot/common/utils/msgpack_utils.py b/lerobot/common/utils/msgpack_utils.py new file mode 100644 index 00000000..8546a881 --- /dev/null +++ b/lerobot/common/utils/msgpack_utils.py @@ -0,0 +1,43 @@ +import functools + +import msgpack +import numpy as np + + +def pack_array(obj): + if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"): + raise ValueError(f"Unsupported dtype: {obj.dtype}") + + if isinstance(obj, np.ndarray): + return { + b"__ndarray__": True, + b"data": obj.tobytes(), + b"dtype": obj.dtype.str, + b"shape": obj.shape, + } + + if isinstance(obj, np.generic): + return { + b"__npgeneric__": True, + b"data": obj.item(), + b"dtype": obj.dtype.str, + } + + return obj + + +def unpack_array(obj): + if b"__ndarray__" in obj: + return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"]) + + if b"__npgeneric__" in obj: + return np.dtype(obj[b"dtype"]).type(obj[b"data"]) + + return obj + + +Packer = functools.partial(msgpack.Packer, default=pack_array) +packb = functools.partial(msgpack.packb, default=pack_array) + +Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array) +unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array) \ No newline at end of file diff --git a/lerobot/scripts/openloop_eval.py b/lerobot/scripts/openloop_eval.py new file mode 100644 index 00000000..901d7b66 --- /dev/null +++ b/lerobot/scripts/openloop_eval.py @@ -0,0 +1,180 @@ +from contextlib import nullcontext +from pprint import pformat +from termcolor import colored +import dataclasses +import logging +from typing import Iterator + +import wandb +import numpy as np +import torch + +import lerobot.common.datasets.lerobot_dataset as lerobot_dataset +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.utils import cycle +from lerobot.common.policies.factory import make_policy +from lerobot.common.utils.utils import ( + init_logging, + get_safe_torch_device, +) +from lerobot.common.utils.random_utils import set_seed +from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig + +class EpisodeSampler(torch.utils.data.Sampler): + def __init__(self, dataset: lerobot_dataset.LeRobotDataset, episode_index: int): + from_idx = dataset.episode_data_index["from"][episode_index].item() + to_idx = dataset.episode_data_index["to"][episode_index].item() + self.frame_ids = range(from_idx, to_idx) + + def __iter__(self) -> Iterator: + return iter(self.frame_ids) + + def __len__(self) -> int: + return len(self.frame_ids) + +@parser.wrap() +def main(cfg: TrainPipelineConfig): + init_logging() + logging.info(pformat(dataclasses.asdict(cfg))) + + device = get_safe_torch_device(cfg.policy.device, log=True) + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + set_seed(cfg.seed) + + wandb.init( + name=cfg.job_name, + config=dataclasses.asdict(cfg), + project="lerobot", + ) + + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + + logging.info("Making dataset.") + dataset = make_dataset(cfg) + dataset.meta.info['features']['observation.state']['shape'] = (13,) + dataset.meta.info['features']['observation.state']['names'] = ['left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6', + 'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum'] + dataset.meta.info['features']['action']['shape'] = (13,) + dataset.meta.info['features']['action']['names'] = ['left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6', + 'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp'] + + for key, episode_stats in dataset.meta.episodes_stats.items(): + for feature in ['observation.state', 'action']: + for sub_feature in ['min', 'max', 'mean', 'std']: + episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6,13], axis=0) + dataset.meta.episodes_stats[key] = episode_stats + for feature in ['min', 'max', 'mean', 'std']: + dataset.meta.stats['observation.state'][feature] = np.delete(dataset.meta.stats['observation.state'][feature], [6,13], axis=0) + dataset.meta.stats['action'][feature] = np.delete(dataset.meta.stats['action'][feature], [6,13], axis=0) + + + logging.info("Making policy.") + policy = make_policy( + cfg=cfg.policy, + ds_meta=dataset.meta, + ) + + # TODO: 支持多条,从配置读取 + episode_index = 0 + + sampler = EpisodeSampler( + dataset, episode_index + ) + data_len = len(sampler) + print("entire data len:", data_len) + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=cfg.num_workers, + batch_size=cfg.batch_size, + sampler=sampler, + ) + + policy.eval() + + # aloha + # dim_names = [ + # "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5", + # "arm_left_gripper_0", "arm_left_gripper_1", + # "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5", + # "arm_right_gripper_0", "arm_right_gripper_1", + # ] + + # galaxea vacumn - 13 dim + dim_names = [ + "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5", + "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5", + "arm_right_vacumn", + ] + # galaxea vacumn - 15 dim + # dim_names = [ + # "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5", "arm_left_gripper", + # "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5","arm_right_gripper", + # "arm_right_vacumn", + # ] + + with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): + # infer_action = policy.forward(input_raw_data) + dl_iter = cycle(dataloader) + + # skip the first few frames + for i in range(50): + a = next(dl_iter) + + + for step in range(200): + # 准备数据 + input_raw_data = next(dl_iter) + + # shape记录 galaxea + # observation.images.front: torch.Size([batch, 3, 360, 640]) + # observation.images.wrist_right: torch.Size([batch, 3, 480, 640]) + # observation.state torch.Size([batch, 15]) + if 'observation.images.depth' in input_raw_data: + del input_raw_data['observation.images.depth'] + + if input_raw_data["observation.state"].shape[-1] == 15: + keep_cols = [i for i in range(15) if i not in {6, 13}] + # 切片操作删除指定列 + input_raw_data["observation.state"] = input_raw_data["observation.state"][..., keep_cols] + print("state shape", input_raw_data["observation.state"].shape) + + if input_raw_data["action"].shape[-1] == 15: + keep_cols = [i for i in range(15) if i not in {6, 13}] + input_raw_data["action"] = input_raw_data["action"][..., keep_cols] + print("action shape", input_raw_data["action"].shape) + + for key in input_raw_data: + if isinstance(input_raw_data[key], torch.Tensor): + input_raw_data[key] = input_raw_data[key].to(device, non_blocking=True) + + # the front camera resolution is 640x360, pad it to 640x480 which is the resolution of the wrist camera + input_raw_data['observation.images.front'] = torch.nn.functional.pad( + input_raw_data['observation.images.front'], + pad=(0, 0, 60, 60), # 左右不填充,上下各填充60 + mode='constant', + value=0 # black + ) + logging.warning("pad the front images to ", input_raw_data['observation.images.front'].shape) + + # 推理 + infer_action = policy.select_action(input_raw_data) + infer_action = infer_action[0].cpu().numpy() + + + origin_action = input_raw_data["action"].cpu().numpy()[0][0] + + # 记录每个action dim的差异 + diff = infer_action - origin_action + log_dict = {} + for dim in range(cfg.policy.output_features["action"].shape[0]): + key = f"{dim_names[dim]}_diff" + log_dict[key] = diff[dim] + + wandb.log(log_dict) + + +if __name__ == "__main__": + main() + diff --git a/lerobot/scripts/serving_policy.py b/lerobot/scripts/serving_policy.py new file mode 100644 index 00000000..79fa2304 --- /dev/null +++ b/lerobot/scripts/serving_policy.py @@ -0,0 +1,81 @@ +import dataclasses +import enum +import logging +import socket +import tyro + +import torch + +from lerobot.common.utils.utils import ( + init_logging, + get_safe_torch_device, +) +from lerobot.common.utils.random_utils import set_seed +from lerobot.configs import parser +from lerobot.configs.eval import EvalPipelineConfig +from lerobot.common.serving.websocket_policy_server import WebsocketPolicyServer +from lerobot.common.policies.act.modeling_act import ACTPolicy +from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy + + +class PolicyType(enum.Enum): + """Supported environments.""" + + ACT = "act" + DIFFUSION = "diffusion" + PI0 = "pi0" + +@dataclasses.dataclass +class Checkpoint: + """Load a policy from a trained checkpoint.""" + + # Checkpoint directory (e.g., "outputs/train/act_move_reel_0322_nodepth/checkpoints/040000/pretrained_mode"). + path: str + + # policy type + type: str + +@dataclasses.dataclass +class Args: + """Arguments for the serve_policy script.""" + # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default + # prompt. + default_prompt: str | None = None + + # Port to serve the policy on. + port: int = 8000 + # Record the policy's behavior for debugging. + record: bool = False + + # Specifies how to load the policy. If not provided, the default policy for the environment will be used. + policy: Checkpoint = dataclasses.field(default_factory=Checkpoint) + +def main(args: Args) -> None: + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + set_seed(1000) + + print(args) + + # policy has been in device and evaluated + if args.policy.type == PolicyType.ACT.value: + policy = ACTPolicy.from_pretrained(args.policy.path) + elif args.policy.type == PolicyType.DIFFUSION.value: + policy = DiffusionPolicy.from_pretrained(args.policy.path) + + # Record the policy's behavior. + # if args.record: + # policy = _policy.PolicyRecorder(policy, "policy_records") + + server = WebsocketPolicyServer( + policy=policy, + host="0.0.0.0", + port=args.port, + metadata={}, + ) + server.serve_forever() + + +if __name__ == "__main__": + init_logging() + main(tyro.cli(Args)) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 0de247be..6d2cf4d2 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -19,6 +19,7 @@ from contextlib import nullcontext from pprint import pformat from typing import Any +import numpy as np import torch from termcolor import colored from torch.amp import GradScaler @@ -126,7 +127,24 @@ def train(cfg: TrainPipelineConfig): logging.info("Creating dataset") dataset = make_dataset(cfg) - + dataset.meta.info['features']['observation.state']['shape'] = (13,) + dataset.meta.info['features']['observation.state']['names'] = ['left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6', + 'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum'] + dataset.meta.info['features']['action']['shape'] = (13,) + dataset.meta.info['features']['action']['names'] = ['left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6', + 'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp'] + + for key, episode_stats in dataset.meta.episodes_stats.items(): + for feature in ['observation.state', 'action']: + for sub_feature in ['min', 'max', 'mean', 'std']: + episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6,13], axis=0) + dataset.meta.episodes_stats[key] = episode_stats + for feature in ['min', 'max', 'mean', 'std']: + dataset.meta.stats['observation.state'][feature] = np.delete(dataset.meta.stats['observation.state'][feature], [6,13], axis=0) + dataset.meta.stats['action'][feature] = np.delete(dataset.meta.stats['action'][feature], [6,13], axis=0) + + + # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. @@ -208,7 +226,30 @@ def train(cfg: TrainPipelineConfig): for key in batch: if isinstance(batch[key], torch.Tensor): batch[key] = batch[key].to(device, non_blocking=True) - + + # remove unused depth info + if 'observation.images.depth' in batch: + del batch['observation.images.depth'] + + # remove redundant joint states + if batch["observation.state"].shape[-1] == 15: + keep_cols = [i for i in range(15) if i not in {6, 13}] + # 切片操作删除指定列 + batch["observation.state"] = batch["observation.state"][..., keep_cols] + + # remove redundant joint states + if batch["action"].shape[-1] == 15: + keep_cols = [i for i in range(15) if i not in {6, 13}] + batch["action"] = batch["action"][..., keep_cols] + + # the front camera resolution is 640x360, pad it to 640x480 which is the resolution of the wrist camera + batch['observation.images.front'] = torch.nn.functional.pad( + batch['observation.images.front'], + pad=(0, 0, 60, 60), # 左右不填充,上下各填充60 + mode='constant', + value=0 # black + ) + train_tracker, output_dict = update_policy( train_tracker, policy, diff --git a/pyproject.toml b/pyproject.toml index 6b9b6802..e06b9856 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] pi0 = ["transformers>=4.48.0"] +serving = ["websockets==14.1", "msgpack==1.1.0", "tyro"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'", diff --git a/tests/serving/test_websocket_serving.py b/tests/serving/test_websocket_serving.py new file mode 100644 index 00000000..77459686 --- /dev/null +++ b/tests/serving/test_websocket_serving.py @@ -0,0 +1,43 @@ +import logging +import time + +import numpy as np +import websockets.sync.client + +import lerobot.common.utils.msgpack_utils as msgpack_utils + +input = { + "state": np.ones((13,)), + "images": { + # input images from client has spec h w c (client) + "front": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8), + "wrist_right": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8), + }, + "prompt": "do something", +} + +url = "ws://127.0.0.1:8000" +packer = msgpack_utils.Packer() + +logging.info(f"Waiting for server at {url}...") +while True: + try: + conn = websockets.sync.client.connect(url, compression=None, max_size=None) + metadata = msgpack_utils.unpackb(conn.recv()) + break + except ConnectionRefusedError: + logging.info("Still waiting for server...") + time.sleep(5) + +data = packer.pack(input) +conn.send(data) +response = conn.recv() +if isinstance(response, str): + # we're expecting bytes; if the server sends a string, it's an error. + print(f"Error in inference server:\n{response}") + exit() + +infer_result = msgpack_utils.unpackb(response) +print(infer_result) +assert len(infer_result['actions'][0]) == len(input['state']) +