Support openloop eval and serving

This commit is contained in:
ruanafan 2025-04-01 18:20:36 +08:00
parent 1c873df5c0
commit ae70f12378
11 changed files with 535 additions and 3 deletions

View File

@ -103,6 +103,14 @@ class LeRobotDatasetMetadata:
def load_metadata(self):
self.info = load_info(self.root)
# self.info['features']['observation.state']['shape'] = (13,)
# self.info['features']['observation.state']['names'] = ['left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6',
# 'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum']
# self.info['features']['action']['shape'] = (13,)
# self.info['features']['action']['names'] = ['left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6',
# 'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp']
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)
@ -110,7 +118,15 @@ class LeRobotDatasetMetadata:
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else:
# print("v2.1 episode stat process")
self.episodes_stats = load_episodes_stats(self.root)
# print("episode stats: ", type(self.episodes_stats))
# for _, episode_stats in self.episodes_stats.items():
# for feature in ['observation.state', 'action']:
# for sub_feature in ['min', 'max', 'mean', 'std']:
# episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6,13], axis=0)
# print('after process', feature, sub_feature, episode_stats[feature][sub_feature].shape)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
def pull_from_repo(
@ -731,6 +747,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
# # remove gripper in original data
# if "observation.state" == key:
# keep_cols = [i for i in range(15) if i not in {6, 13}]
# # 切片操作删除指定列
# val = val[:, keep_cols]
# # print("emmmmmm in __getitem_ state", val.shape)
# if "action" == key:
# keep_cols = [i for i in range(15) if i not in {6, 13}]
# # 切片操作删除指定列
# val = val[:, keep_cols]
# # print("emmmmmm in __getitem_ action", val.shape)
item[key] = val
if len(self.meta.video_keys) > 0:

View File

@ -420,6 +420,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
else:
continue
if 'depth' in key:
continue
policy_features[key] = PolicyFeature(
type=type,
shape=shape,

View File

@ -59,7 +59,8 @@ def convert_dataset(
num_workers: int = 4,
):
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
# dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
dataset = LeRobotDataset("move_reel_test_0322")
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()

View File

@ -168,6 +168,7 @@ class Normalize(nn.Module):
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
print("process key:", key)
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]

View File

@ -0,0 +1,111 @@
import asyncio
import logging
import traceback
import einops
import numpy as np
import torch
import websockets.asyncio.server
import websockets.frames
import lerobot.common.utils.msgpack_utils as msgpack_utils
from lerobot.common.policies.pretrained import PreTrainedPolicy
class WebsocketPolicyServer:
    """Serve a policy over the websocket protocol.

    See websocket_client_policy.py for a client implementation.
    Currently only implements the `load` and `infer` methods.

    Each client message is a msgpack-encoded observation dict. The server
    answers with a msgpack-encoded ``{"actions": ndarray}`` dict produced by
    ``policy.select_action``. On an unexpected error the traceback is sent to
    the client as a text frame and the connection is closed.
    """

    def __init__(
        self,
        policy: PreTrainedPolicy,
        host: str = "0.0.0.0",
        port: int = 8000,
        metadata: dict | None = None,
    ) -> None:
        """Store the policy and the listening address.

        Args:
            policy: Policy whose ``select_action`` is called for every request.
            host: Interface to bind to.
            port: TCP port to listen on.
            metadata: Dict sent to each client right after it connects;
                defaults to an empty dict.
        """
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}

    def serve_forever(self) -> None:
        """Run the server in a fresh asyncio event loop until interrupted."""
        asyncio.run(self.run())

    async def run(self):
        """Start the websocket server and serve connections indefinitely."""
        async with websockets.asyncio.server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
        ) as server:
            await server.serve_forever()

    def _inference_device(self) -> torch.device:
        """Return the device observations must be moved to before inference.

        Uses the device of the policy's first parameter so the server also
        works on CPU-only hosts or with a policy placed on a non-default GPU.
        Falls back to CUDA when available (the previous hard-coded behaviour),
        otherwise CPU, if the policy exposes no parameters.
        """
        parameters = getattr(self._policy, "parameters", None)
        if callable(parameters):
            first_param = next(iter(parameters()), None)
            if first_param is not None:
                return first_param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    async def preprocess_observation(self, observations: dict) -> dict[str, torch.Tensor]:
        """Convert a raw client observation into policy-ready tensors.

        Args:
            observations: Dict with an ``"images"`` mapping of camera name to
                an ``(h, w, c)`` uint8 numpy array, and a ``"state"`` numpy
                array.

        Returns:
            Dict with one ``observation.images.<camera>`` float32 tensor of
            shape ``(1, c, h, w)`` scaled to ``[0, 1]`` per camera, plus
            ``observation.state`` as a float32 tensor with a leading batch
            dimension.

        Raises:
            AssertionError: If an image is not channel-last or not uint8.
        """
        images = observations['images']
        return_observations = {}
        for imgkey, img in images.items():
            img = torch.from_numpy(img)
            img = einops.rearrange(img, "h w c -> c h w").contiguous()
            c, h, w = img.shape
            if h == 360:
                # Pad 360-row images with 60 black rows on top and bottom (-> 480 rows).
                img = torch.nn.functional.pad(img, pad=(0, 0, 60, 60), mode='constant', value=0)
                c, h, w = img.shape

            assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
            assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

            img = img.unsqueeze(0)
            img = img.type(torch.float32)
            img /= 255

            imgkey = f"observation.images.{imgkey}"
            logging.debug("add a key: %s %s", imgkey, img.shape)
            return_observations[imgkey] = img

        state = torch.from_numpy(observations["state"]).float()
        return_observations["observation.state"] = state.unsqueeze(0)
        return return_observations

    async def _handler(self, websocket: websockets.asyncio.server.ServerConnection):
        """Handle one client connection: send metadata, then answer requests."""
        logging.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_utils.Packer()
        device = self._inference_device()

        await websocket.send(packer.pack(self._metadata))

        while True:
            try:
                # example
                # obs = {
                #     "images": {
                #         "cam_high": numpy.NDArray,
                #         "cam_right_wrist": numpy.NDArray,
                #     },
                #     "state": numpy.NDarray,
                #     "prompt": "xxx text"
                # }
                obs = msgpack_utils.unpackb(await websocket.recv())
                obs = await self.preprocess_observation(obs)
                for key, value in obs.items():
                    if isinstance(value, torch.Tensor):
                        obs[key] = value.to(device, non_blocking=True)

                with torch.inference_mode():
                    action = self._policy.select_action(obs)
                logging.debug("inference once with action: %s", action)
                res = {"actions": action.cpu().numpy()}
                await websocket.send(packer.pack(res))
            except websockets.ConnectionClosed:
                logging.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                # Forward the traceback as a text frame so the client can tell
                # an error apart from a (binary) msgpack response.
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise

View File

@ -0,0 +1,43 @@
import functools
import msgpack
import numpy as np
def pack_array(obj):
    """msgpack ``default`` hook that encodes numpy arrays and scalars as dicts.

    Arrays become ``{b"__ndarray__": True, b"data", b"dtype", b"shape"}`` and
    numpy scalars become ``{b"__npgeneric__": True, b"data", b"dtype"}``.
    Any other object is returned unchanged.

    Raises:
        ValueError: If a numpy array/scalar has a void, object or complex dtype.
    """
    is_array = isinstance(obj, np.ndarray)
    is_scalar = isinstance(obj, np.generic)

    if (is_array or is_scalar) and obj.dtype.kind in ("V", "O", "c"):
        raise ValueError(f"Unsupported dtype: {obj.dtype}")

    if is_array:
        encoded = {b"__ndarray__": True}
        encoded[b"data"] = obj.tobytes()
        encoded[b"dtype"] = obj.dtype.str
        encoded[b"shape"] = obj.shape
        return encoded

    if is_scalar:
        encoded = {b"__npgeneric__": True}
        encoded[b"data"] = obj.item()
        encoded[b"dtype"] = obj.dtype.str
        return encoded

    return obj
def unpack_array(obj):
    """msgpack ``object_hook`` that rebuilds numpy values from tagged dicts.

    A dict tagged ``b"__ndarray__"`` is turned into an ndarray viewing the
    stored bytes; a dict tagged ``b"__npgeneric__"`` is turned into a numpy
    scalar of the stored dtype. Untagged dicts are returned unchanged.
    """
    if b"__ndarray__" in obj:
        array_dtype = np.dtype(obj[b"dtype"])
        return np.ndarray(shape=obj[b"shape"], dtype=array_dtype, buffer=obj[b"data"])

    if b"__npgeneric__" in obj:
        scalar_type = np.dtype(obj[b"dtype"]).type
        return scalar_type(obj[b"data"])

    return obj
# numpy-aware drop-in replacements for the msgpack entry points: packing uses
# pack_array as the `default` hook, unpacking uses unpack_array as the
# `object_hook`.
Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)

View File

@ -0,0 +1,180 @@
from contextlib import nullcontext
from pprint import pformat
from termcolor import colored
import dataclasses
import logging
from typing import Iterator
import wandb
import numpy as np
import torch
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import (
init_logging,
get_safe_torch_device,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler that yields, in order, the frame indices of a single episode."""

    def __init__(self, dataset: lerobot_dataset.LeRobotDataset, episode_index: int):
        """Look up the episode's frame range in ``dataset.episode_data_index``.

        The "from" entry is the first frame yielded; the "to" entry is
        excluded.
        """
        bounds = dataset.episode_data_index
        first_frame = bounds["from"][episode_index].item()
        end_frame = bounds["to"][episode_index].item()
        self.frame_ids = range(first_frame, end_frame)

    def __iter__(self) -> Iterator:
        """Iterate over the episode's frame indices in ascending order."""
        return iter(self.frame_ids)

    def __len__(self) -> int:
        """Return the number of frames in the episode."""
        return len(self.frame_ids)
@parser.wrap()
def main(cfg: TrainPipelineConfig):
    """Run an open-loop evaluation of a policy on one dataset episode.

    Frames of episode 0 are fed to ``policy.select_action``; for each step the
    per-dimension difference between the predicted action and the action
    recorded in the dataset is logged to wandb.

    Args:
        cfg: Pipeline configuration (dataset, policy, seed, batch size, ...).
    """
    init_logging()
    logging.info(pformat(dataclasses.asdict(cfg)))

    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(cfg.seed)

    wandb.init(
        name=cfg.job_name,
        config=dataclasses.asdict(cfg),
        project="lerobot",
    )

    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

    logging.info("Making dataset.")
    dataset = make_dataset(cfg)

    # Present the dataset as 13-dim (indices 6 and 13 removed) so the policy
    # is built with the reduced state/action size.
    # NOTE(review): unlike the per-batch trimming below, this metadata/stats
    # trimming is unconditional and assumes the stored stats are 15-dim.
    dataset.meta.info['features']['observation.state']['shape'] = (13,)
    dataset.meta.info['features']['observation.state']['names'] = [
        'left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6',
        'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum',
    ]
    dataset.meta.info['features']['action']['shape'] = (13,)
    dataset.meta.info['features']['action']['names'] = [
        'left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6',
        'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp',
    ]
    for key, episode_stats in dataset.meta.episodes_stats.items():
        for feature in ['observation.state', 'action']:
            for sub_feature in ['min', 'max', 'mean', 'std']:
                episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6, 13], axis=0)
        dataset.meta.episodes_stats[key] = episode_stats
    for feature in ['min', 'max', 'mean', 'std']:
        dataset.meta.stats['observation.state'][feature] = np.delete(dataset.meta.stats['observation.state'][feature], [6, 13], axis=0)
        dataset.meta.stats['action'][feature] = np.delete(dataset.meta.stats['action'][feature], [6, 13], axis=0)

    logging.info("Making policy.")
    policy = make_policy(
        cfg=cfg.policy,
        ds_meta=dataset.meta,
    )

    # TODO: support multiple episodes, read from the config
    episode_index = 0
    sampler = EpisodeSampler(dataset, episode_index)
    data_len = len(sampler)
    print("entire data len:", data_len)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        sampler=sampler,
    )

    policy.eval()

    # aloha
    # dim_names = [
    #     "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5",
    #     "arm_left_gripper_0", "arm_left_gripper_1",
    #     "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5",
    #     "arm_right_gripper_0", "arm_right_gripper_1",
    # ]

    # galaxea vacumn - 13 dim
    dim_names = [
        "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5",
        "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5",
        "arm_right_vacumn",
    ]

    # galaxea vacumn - 15 dim
    # dim_names = [
    #     "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5", "arm_left_gripper",
    #     "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5","arm_right_gripper",
    #     "arm_right_vacumn",
    # ]

    # Columns kept when a 15-dim state/action is reduced to 13 dims.
    keep_cols = [i for i in range(15) if i not in {6, 13}]
    action_dim = cfg.policy.output_features["action"].shape[0]

    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
        dl_iter = cycle(dataloader)
        # skip the first few frames
        for _ in range(50):
            next(dl_iter)

        for _ in range(200):
            # Fetch the next batch.
            input_raw_data = next(dl_iter)
            # Shapes recorded for the galaxea dataset:
            # observation.images.front: torch.Size([batch, 3, 360, 640])
            # observation.images.wrist_right: torch.Size([batch, 3, 480, 640])
            # observation.state torch.Size([batch, 15])
            if 'observation.images.depth' in input_raw_data:
                del input_raw_data['observation.images.depth']

            if input_raw_data["observation.state"].shape[-1] == 15:
                # Drop the unused columns by index.
                input_raw_data["observation.state"] = input_raw_data["observation.state"][..., keep_cols]
                print("state shape", input_raw_data["observation.state"].shape)
            if input_raw_data["action"].shape[-1] == 15:
                input_raw_data["action"] = input_raw_data["action"][..., keep_cols]
                print("action shape", input_raw_data["action"].shape)

            for key in input_raw_data:
                if isinstance(input_raw_data[key], torch.Tensor):
                    input_raw_data[key] = input_raw_data[key].to(device, non_blocking=True)

            # the front camera resolution is 640x360, pad it to 640x480 which is the resolution of the wrist camera
            input_raw_data['observation.images.front'] = torch.nn.functional.pad(
                input_raw_data['observation.images.front'],
                pad=(0, 0, 60, 60),  # no left/right padding, 60 rows on top and bottom
                mode='constant',
                value=0,  # black
            )
            # The shape must go through a %s placeholder: passing it as a bare
            # extra argument makes the logging module report a formatting error.
            logging.warning("pad the front images to %s", input_raw_data['observation.images.front'].shape)

            # Inference.
            infer_action = policy.select_action(input_raw_data)
            infer_action = infer_action[0].cpu().numpy()
            origin_action = input_raw_data["action"].cpu().numpy()[0][0]

            # Log the difference for each action dimension.
            diff = infer_action - origin_action
            log_dict = {f"{dim_names[dim]}_diff": diff[dim] for dim in range(action_dim)}
            wandb.log(log_dict)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,81 @@
import dataclasses
import enum
import logging
import socket
import tyro
import torch
from lerobot.common.utils.utils import (
init_logging,
get_safe_torch_device,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.common.serving.websocket_policy_server import WebsocketPolicyServer
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
class PolicyType(enum.Enum):
    """Policy types that can be named when serving a checkpoint."""
    ACT = "act"
    DIFFUSION = "diffusion"
    # NOTE(review): declared here, but main() in this file has no loading branch for it.
    PI0 = "pi0"
@dataclasses.dataclass
class Checkpoint:
    """Load a policy from a trained checkpoint."""
    # Checkpoint directory (e.g., "outputs/train/act_move_reel_0322_nodepth/checkpoints/040000/pretrained_mode").
    path: str
    # Policy type; main() compares this string against the PolicyType values ("act", "diffusion").
    type: str
@dataclasses.dataclass
class Args:
    """Arguments for the serve_policy script."""
    # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
    # prompt.
    # NOTE(review): not read by main() in this file.
    default_prompt: str | None = None
    # Port to serve the policy on.
    port: int = 8000
    # Record the policy's behavior for debugging.
    # NOTE(review): not read by main() in this file (the recorder is commented out there).
    record: bool = False
    # Specifies how to load the policy (checkpoint path and policy type).
    # NOTE(review): Checkpoint has no field defaults, so this factory raises TypeError if it is
    # actually called (e.g. a bare `Args()`); presumably the CLI parser treats it as a required
    # nested group - confirm.
    policy: Checkpoint = dataclasses.field(default_factory=Checkpoint)
def main(args: Args) -> None:
    """Load the requested policy checkpoint and serve it over websockets.

    Args:
        args: Parsed command-line arguments; ``args.policy.type`` selects the
            policy class and ``args.policy.path`` the checkpoint directory.

    Raises:
        ValueError: If ``args.policy.type`` is not a policy type this script
            can load (currently "act" and "diffusion").
    """
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(1000)

    print(args)

    # policy has been in device and evaluated
    if args.policy.type == PolicyType.ACT.value:
        policy = ACTPolicy.from_pretrained(args.policy.path)
    elif args.policy.type == PolicyType.DIFFUSION.value:
        policy = DiffusionPolicy.from_pretrained(args.policy.path)
    else:
        # Fail early with a clear message instead of hitting an unbound
        # `policy` variable further down.
        supported = (PolicyType.ACT.value, PolicyType.DIFFUSION.value)
        raise ValueError(
            f"Unsupported policy type {args.policy.type!r}; expected one of {supported}."
        )

    # Record the policy's behavior.
    # if args.record:
    #     policy = _policy.PolicyRecorder(policy, "policy_records")

    server = WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",
        port=args.port,
        metadata={},
    )
    server.serve_forever()


if __name__ == "__main__":
    init_logging()
    main(tyro.cli(Args))

View File

@ -19,6 +19,7 @@ from contextlib import nullcontext
from pprint import pformat
from typing import Any
import numpy as np
import torch
from termcolor import colored
from torch.amp import GradScaler
@ -126,6 +127,23 @@ def train(cfg: TrainPipelineConfig):
logging.info("Creating dataset")
dataset = make_dataset(cfg)
dataset.meta.info['features']['observation.state']['shape'] = (13,)
dataset.meta.info['features']['observation.state']['names'] = ['left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6',
'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum']
dataset.meta.info['features']['action']['shape'] = (13,)
dataset.meta.info['features']['action']['names'] = ['left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6',
'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp']
for key, episode_stats in dataset.meta.episodes_stats.items():
for feature in ['observation.state', 'action']:
for sub_feature in ['min', 'max', 'mean', 'std']:
episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6,13], axis=0)
dataset.meta.episodes_stats[key] = episode_stats
for feature in ['min', 'max', 'mean', 'std']:
dataset.meta.stats['observation.state'][feature] = np.delete(dataset.meta.stats['observation.state'][feature], [6,13], axis=0)
dataset.meta.stats['action'][feature] = np.delete(dataset.meta.stats['action'][feature], [6,13], axis=0)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@ -209,6 +227,29 @@ def train(cfg: TrainPipelineConfig):
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
# remove unused depth info
if 'observation.images.depth' in batch:
del batch['observation.images.depth']
# remove redundant joint states
if batch["observation.state"].shape[-1] == 15:
keep_cols = [i for i in range(15) if i not in {6, 13}]
# 切片操作删除指定列
batch["observation.state"] = batch["observation.state"][..., keep_cols]
# remove redundant joint states
if batch["action"].shape[-1] == 15:
keep_cols = [i for i in range(15) if i not in {6, 13}]
batch["action"] = batch["action"][..., keep_cols]
# the front camera resolution is 640x360, pad it to 640x480 which is the resolution of the wrist camera
batch['observation.images.front'] = torch.nn.functional.pad(
batch['observation.images.front'],
pad=(0, 0, 60, 60), # 左右不填充上下各填充60
mode='constant',
value=0 # black
)
train_tracker, output_dict = update_policy(
train_tracker,
policy,

View File

@ -85,6 +85,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
pi0 = ["transformers>=4.48.0"]
serving = ["websockets==14.1", "msgpack==1.1.0", "tyro"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",

View File

@ -0,0 +1,43 @@
import logging
import time
import numpy as np
import websockets.sync.client
import lerobot.common.utils.msgpack_utils as msgpack_utils
# Smoke test for the websocket policy server: connect, send one synthetic
# observation, and check the returned action has the same length as the state.

# Without this the root logger stays at WARNING and the info messages below
# are silently dropped.
logging.basicConfig(level=logging.INFO)

# Named `observation` rather than `input` to avoid shadowing the builtin.
observation = {
    "state": np.ones((13,)),
    "images": {
        # input images from client has spec h w c (client)
        "front": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
        "wrist_right": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
    },
    "prompt": "do something",
}

url = "ws://127.0.0.1:8000"
packer = msgpack_utils.Packer()

logging.info(f"Waiting for server at {url}...")
while True:
    try:
        conn = websockets.sync.client.connect(url, compression=None, max_size=None)
        # The server sends its metadata dict first.
        metadata = msgpack_utils.unpackb(conn.recv())
        break
    except ConnectionRefusedError:
        logging.info("Still waiting for server...")
        time.sleep(5)

try:
    conn.send(packer.pack(observation))
    response = conn.recv()
finally:
    # Always release the connection, even if the exchange fails.
    conn.close()

if isinstance(response, str):
    # we're expecting bytes; if the server sends a string, it's an error.
    print(f"Error in inference server:\n{response}")
    # Exit with a non-zero status so callers can detect the failure.
    raise SystemExit(1)

infer_result = msgpack_utils.unpackb(response)
print(infer_result)
assert len(infer_result['actions'][0]) == len(observation['state'])