Support openloop eval and serving

This commit is contained in:
ruanafan 2025-04-01 18:20:36 +08:00
parent 1c873df5c0
commit ae70f12378
11 changed files with 535 additions and 3 deletions

View File

@ -103,6 +103,14 @@ class LeRobotDatasetMetadata:
def load_metadata(self): def load_metadata(self):
self.info = load_info(self.root) self.info = load_info(self.root)
# self.info['features']['observation.state']['shape'] = (13,)
# self.info['features']['observation.state']['names'] = ['left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6',
# 'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum']
# self.info['features']['action']['shape'] = (13,)
# self.info['features']['action']['names'] = ['left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6',
# 'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp']
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks, self.task_to_task_index = load_tasks(self.root) self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root) self.episodes = load_episodes(self.root)
@ -110,7 +118,15 @@ class LeRobotDatasetMetadata:
self.stats = load_stats(self.root) self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else: else:
# print("v2.1 episode stat process")
self.episodes_stats = load_episodes_stats(self.root) self.episodes_stats = load_episodes_stats(self.root)
# print("episode stats: ", type(self.episodes_stats))
# for _, episode_stats in self.episodes_stats.items():
# for feature in ['observation.state', 'action']:
# for sub_feature in ['min', 'max', 'mean', 'std']:
# episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6,13], axis=0)
# print('after process', feature, sub_feature, episode_stats[feature][sub_feature].shape)
self.stats = aggregate_stats(list(self.episodes_stats.values())) self.stats = aggregate_stats(list(self.episodes_stats.values()))
def pull_from_repo( def pull_from_repo(
@ -731,6 +747,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_result = self._query_hf_dataset(query_indices) query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding} item = {**item, **padding}
for key, val in query_result.items(): for key, val in query_result.items():
# # remove gripper in original data
# if "observation.state" == key:
# keep_cols = [i for i in range(15) if i not in {6, 13}]
# # 切片操作删除指定列
# val = val[:, keep_cols]
# # print("emmmmmm in __getitem_ state", val.shape)
# if "action" == key:
# keep_cols = [i for i in range(15) if i not in {6, 13}]
# # 切片操作删除指定列
# val = val[:, keep_cols]
# # print("emmmmmm in __getitem_ action", val.shape)
item[key] = val item[key] = val
if len(self.meta.video_keys) > 0: if len(self.meta.video_keys) > 0:

View File

@ -420,6 +420,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
else: else:
continue continue
if 'depth' in key:
continue
policy_features[key] = PolicyFeature( policy_features[key] = PolicyFeature(
type=type, type=type,
shape=shape, shape=shape,

View File

@ -59,7 +59,8 @@ def convert_dataset(
num_workers: int = 4, num_workers: int = 4,
): ):
with SuppressWarnings(): with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) # dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
dataset = LeRobotDataset("move_reel_test_0322")
if (dataset.root / EPISODES_STATS_PATH).is_file(): if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink() (dataset.root / EPISODES_STATS_PATH).unlink()

View File

@ -168,6 +168,7 @@ class Normalize(nn.Module):
std = buffer["std"] std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean") assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std") assert not torch.isinf(std).any(), _no_stats_error_str("std")
print("process key:", key)
batch[key] = (batch[key] - mean) / (std + 1e-8) batch[key] = (batch[key] - mean) / (std + 1e-8)
elif norm_mode is NormalizationMode.MIN_MAX: elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"] min = buffer["min"]

View File

@ -0,0 +1,111 @@
import asyncio
import logging
import traceback
import einops
import numpy as np
import torch
import websockets.asyncio.server
import websockets.frames
import lerobot.common.utils.msgpack_utils as msgpack_utils
from lerobot.common.policies.pretrained import PreTrainedPolicy
class WebsocketPolicyServer:
    """Serve a policy over the websocket protocol.

    On every connection the server first sends its metadata, then answers each
    received observation message with ``{"actions": <numpy array>}``.
    See websocket_client_policy.py for a client implementation.

    Args:
        policy: The policy whose ``select_action`` is called for every request.
        host: Interface to bind to.
        port: TCP port to listen on.
        metadata: Dictionary sent to each client right after it connects.
        device: Device the observations are moved to before inference. When
            omitted, the device of the policy's first parameter is used, and
            CUDA (if available) or CPU is the fallback when the policy exposes
            no parameters.
    """

    def __init__(
        self,
        policy: PreTrainedPolicy,
        host: str = "0.0.0.0",
        port: int = 8000,
        metadata: dict | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}
        self._device = torch.device(device) if device is not None else self._infer_device(policy)

    @staticmethod
    def _infer_device(policy) -> torch.device:
        """Return the device of the policy's first parameter, with a safe fallback."""
        try:
            return next(policy.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def serve_forever(self) -> None:
        """Block the calling thread and serve until the event loop is stopped."""
        asyncio.run(self.run())

    async def run(self):
        """Start the websocket server and serve connections until cancelled."""
        async with websockets.asyncio.server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
        ) as server:
            await server.serve_forever()

    async def preprocess_observation(self, observations: dict) -> dict[str, torch.Tensor]:
        """Convert a client observation into the batch dictionary fed to the policy.

        Args:
            observations: Mapping with an ``"images"`` dict of ``(h, w, c)`` uint8
                arrays and a ``"state"`` array.

        Returns:
            A dict with one ``observation.images.<name>`` float32 tensor of shape
            ``(1, c, h, w)`` scaled to [0, 1] per image, plus ``observation.state``
            as a float32 tensor with a leading batch dimension.

        Raises:
            ValueError: If an image is not channel-last or is not uint8.
        """
        batch: dict[str, torch.Tensor] = {}
        for name, frame in observations["images"].items():
            img = torch.from_numpy(frame)
            if img.dtype != torch.uint8:
                raise ValueError(f"expect torch.uint8 image for '{name}', but instead got {img.dtype=}")
            img = einops.rearrange(img, "h w c -> c h w").contiguous()
            c, h, w = img.shape
            if h == 360:
                # Pad 360-row images with 60 black rows on top and bottom (-> 480 rows);
                # the width is left untouched.
                img = torch.nn.functional.pad(img, pad=(0, 0, 60, 60), mode="constant", value=0)
                c, h, w = img.shape
            if not (c < h and c < w):
                raise ValueError(
                    f"expect channel last (h, w, c) images for '{name}', but instead got shape {tuple(frame.shape)}"
                )
            # Add the batch dimension and scale to [0, 1].
            img = img.unsqueeze(0).type(torch.float32)
            img /= 255
            key = f"observation.images.{name}"
            logging.debug("add a key: %s %s", key, tuple(img.shape))
            batch[key] = img
        state = torch.from_numpy(observations["state"]).float()
        batch["observation.state"] = state.unsqueeze(0)
        return batch

    async def _handler(self, websocket: websockets.asyncio.server.ServerConnection):
        """Serve one client connection until it closes or an error occurs.

        Expected request payload (msgpack encoded), for example::

            {
                "images": {"cam_high": ndarray, "cam_right_wrist": ndarray},
                "state": ndarray,
                "prompt": "xxx text",
            }

        On an unexpected error the traceback is sent to the client as a text
        frame, the connection is closed with INTERNAL_ERROR and the exception is
        re-raised.
        """
        logging.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_utils.Packer()
        # NOTE(review): select_action may keep per-episode state inside the policy;
        # consider resetting the policy when a new connection opens — confirm
        # against the policy implementation.
        await websocket.send(packer.pack(self._metadata))
        while True:
            try:
                obs = msgpack_utils.unpackb(await websocket.recv())
                obs = await self.preprocess_observation(obs)
                for key in obs:
                    if isinstance(obs[key], torch.Tensor):
                        obs[key] = obs[key].to(self._device, non_blocking=True)
                with torch.inference_mode():
                    action = self._policy.select_action(obs)
                logging.debug("inference once with action: %s", action)
                res = {"actions": action.cpu().numpy()}
                await websocket.send(packer.pack(res))
            except websockets.ConnectionClosed:
                logging.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise

View File

@ -0,0 +1,43 @@
import functools
import msgpack
import numpy as np
def pack_array(obj):
    """Encode numpy arrays and numpy scalars as msgpack-friendly dictionaries.

    Intended as the ``default`` hook of a msgpack packer: anything that is not
    a numpy array or numpy scalar is returned unchanged.

    Raises:
        ValueError: If the numpy dtype kind is void ("V"), object ("O") or
            complex ("c").
    """
    is_array = isinstance(obj, np.ndarray)
    is_scalar = isinstance(obj, np.generic)
    if not (is_array or is_scalar):
        return obj

    if obj.dtype.kind in ("V", "O", "c"):
        raise ValueError(f"Unsupported dtype: {obj.dtype}")

    dtype_str = obj.dtype.str
    if is_array:
        # Raw buffer plus the dtype/shape needed to rebuild the array.
        encoded = {b"__ndarray__": True}
        encoded[b"data"] = obj.tobytes()
        encoded[b"dtype"] = dtype_str
        encoded[b"shape"] = obj.shape
        return encoded

    # numpy scalar: store the plain Python value together with its dtype.
    encoded = {b"__npgeneric__": True}
    encoded[b"data"] = obj.item()
    encoded[b"dtype"] = dtype_str
    return encoded
def unpack_array(obj):
    """Rebuild numpy arrays and scalars from the dictionaries made by ``pack_array``.

    Intended as the ``object_hook`` of a msgpack unpacker: dictionaries without
    one of the marker keys are returned unchanged.
    """
    if b"__ndarray__" in obj:
        array_dtype = np.dtype(obj[b"dtype"])
        # The array is created directly on top of the received byte buffer.
        return np.ndarray(shape=obj[b"shape"], dtype=array_dtype, buffer=obj[b"data"])

    if b"__npgeneric__" in obj:
        scalar_type = np.dtype(obj[b"dtype"]).type
        return scalar_type(obj[b"data"])

    return obj
# msgpack entry points pre-configured with the numpy hooks defined in this module:
# the packing side passes `pack_array` as msgpack's `default` hook, the unpacking
# side passes `unpack_array` as msgpack's `object_hook`.
Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)

View File

@ -0,0 +1,180 @@
from contextlib import nullcontext
from pprint import pformat
from termcolor import colored
import dataclasses
import logging
from typing import Iterator
import wandb
import numpy as np
import torch
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import (
init_logging,
get_safe_torch_device,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler that walks through the frame indices of a single episode in order.

    The bounds are read from ``dataset.episode_data_index``: the ``"from"`` entry
    is the first index and the ``"to"`` entry is used as the exclusive stop.
    """

    def __init__(self, dataset: lerobot_dataset.LeRobotDataset, episode_index: int):
        bounds = dataset.episode_data_index
        first_frame = bounds["from"][episode_index].item()
        stop_frame = bounds["to"][episode_index].item()
        self.frame_ids = range(first_frame, stop_frame)

    def __iter__(self) -> Iterator:
        frames = self.frame_ids
        return iter(frames)

    def __len__(self) -> int:
        frames = self.frame_ids
        return len(frames)
@parser.wrap()
def main(cfg: TrainPipelineConfig):
    """Run an open-loop evaluation of a policy on one recorded episode.

    The dataset metadata and statistics are first reduced from 15 to 13
    state/action dimensions (indices 6 and 13 are dropped). Then, after 50
    batches are skipped, 200 batches of episode 0 are fed to
    ``policy.select_action`` and the per-dimension difference between the
    predicted and the recorded action is logged to wandb.

    Args:
        cfg: Training pipeline configuration; the policy, dataset, seed,
            batch size and worker count are read from it.
    """
    init_logging()
    logging.info(pformat(dataclasses.asdict(cfg)))

    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(cfg.seed)

    wandb.init(
        name=cfg.job_name,
        config=dataclasses.asdict(cfg),
        project="lerobot",
    )
    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

    logging.info("Making dataset.")
    dataset = make_dataset(cfg)

    # Describe the 13-dimensional state/action layout the policy is built for.
    dataset.meta.info['features']['observation.state']['shape'] = (13,)
    dataset.meta.info['features']['observation.state']['names'] = [
        'left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6',
        'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum',
    ]
    dataset.meta.info['features']['action']['shape'] = (13,)
    dataset.meta.info['features']['action']['names'] = [
        'left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6',
        'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp',
    ]

    # Drop entries 6 and 13 from every per-episode and aggregated statistic so the
    # statistics match the 13-dimensional layout above. The per-episode dicts are
    # mutated in place.
    for episode_stats in dataset.meta.episodes_stats.values():
        for feature in ['observation.state', 'action']:
            for sub_feature in ['min', 'max', 'mean', 'std']:
                episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6, 13], axis=0)

    for feature in ['min', 'max', 'mean', 'std']:
        dataset.meta.stats['observation.state'][feature] = np.delete(dataset.meta.stats['observation.state'][feature], [6, 13], axis=0)
        dataset.meta.stats['action'][feature] = np.delete(dataset.meta.stats['action'][feature], [6, 13], axis=0)

    logging.info("Making policy.")
    policy = make_policy(
        cfg=cfg.policy,
        ds_meta=dataset.meta,
    )

    # TODO: support multiple episodes, read the indices from the config.
    episode_index = 0
    sampler = EpisodeSampler(dataset, episode_index)
    data_len = len(sampler)
    print("entire data len:", data_len)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        sampler=sampler,
    )

    policy.eval()

    # aloha
    # dim_names = [
    #     "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5",
    #     "arm_left_gripper_0", "arm_left_gripper_1",
    #     "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5",
    #     "arm_right_gripper_0", "arm_right_gripper_1",
    # ]
    # galaxea vacumn - 13 dim
    dim_names = [
        "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5",
        "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5",
        "arm_right_vacumn",
    ]
    # galaxea vacumn - 15 dim
    # dim_names = [
    #     "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5", "arm_left_gripper",
    #     "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5", "arm_right_gripper",
    #     "arm_right_vacumn",
    # ]

    # Columns kept when a 15-dimensional state/action is reduced to 13 dimensions.
    keep_cols = [i for i in range(15) if i not in {6, 13}]

    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
        dl_iter = cycle(dataloader)
        # Skip the first few batches.
        for _ in range(50):
            next(dl_iter)

        for step in range(200):
            # Prepare the data.
            input_raw_data = next(dl_iter)
            # Recorded shapes (galaxea):
            # observation.images.front: torch.Size([batch, 3, 360, 640])
            # observation.images.wrist_right: torch.Size([batch, 3, 480, 640])
            # observation.state torch.Size([batch, 15])
            if 'observation.images.depth' in input_raw_data:
                del input_raw_data['observation.images.depth']

            if input_raw_data["observation.state"].shape[-1] == 15:
                # Drop the selected columns by indexing the last dimension.
                input_raw_data["observation.state"] = input_raw_data["observation.state"][..., keep_cols]
                print("state shape", input_raw_data["observation.state"].shape)

            if input_raw_data["action"].shape[-1] == 15:
                input_raw_data["action"] = input_raw_data["action"][..., keep_cols]
                print("action shape", input_raw_data["action"].shape)

            for key in input_raw_data:
                if isinstance(input_raw_data[key], torch.Tensor):
                    input_raw_data[key] = input_raw_data[key].to(device, non_blocking=True)

            # The front camera resolution is 640x360; pad it to 640x480, which is the
            # resolution of the wrist camera.
            input_raw_data['observation.images.front'] = torch.nn.functional.pad(
                input_raw_data['observation.images.front'],
                pad=(0, 0, 60, 60),  # no padding left/right, 60 rows on top and bottom
                mode='constant',
                value=0,  # black
            )
            # The shape is passed as a lazy %-argument; without a placeholder the
            # logging module reports a formatting error instead of the message.
            logging.warning("pad the front images to %s", input_raw_data['observation.images.front'].shape)

            # Inference.
            infer_action = policy.select_action(input_raw_data)
            infer_action = infer_action[0].cpu().numpy()
            origin_action = input_raw_data["action"].cpu().numpy()[0][0]

            # Record the difference for every action dimension.
            diff = infer_action - origin_action
            log_dict = {}
            for dim in range(cfg.policy.output_features["action"].shape[0]):
                key = f"{dim_names[dim]}_diff"
                log_dict[key] = diff[dim]
            wandb.log(log_dict)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,81 @@
import dataclasses
import enum
import logging
import socket
import tyro
import torch
from lerobot.common.utils.utils import (
init_logging,
get_safe_torch_device,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.common.serving.websocket_policy_server import WebsocketPolicyServer
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
class PolicyType(enum.Enum):
    """Policy types named by the serving script; the values are the strings compared against ``Checkpoint.type``."""

    ACT = "act"
    DIFFUSION = "diffusion"
    # NOTE(review): declared here, but main() in this file has no loading branch for it.
    PI0 = "pi0"
@dataclasses.dataclass
class Checkpoint:
    """Load a policy from a trained checkpoint."""

    # Checkpoint directory (e.g., "outputs/train/act_move_reel_0322_nodepth/checkpoints/040000/pretrained_model").
    path: str
    # Policy type as a string; main() compares it against the PolicyType values.
    type: str
@dataclasses.dataclass
class Args:
    """Arguments for the serve_policy script."""

    # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
    # prompt.
    # NOTE(review): not read by main() in this file — confirm whether it is still needed.
    default_prompt: str | None = None

    # Port to serve the policy on.
    port: int = 8000
    # Record the policy's behavior for debugging.
    # NOTE(review): the recording code in main() is commented out, so this flag currently has no effect.
    record: bool = False

    # Specifies how to load the policy (checkpoint path and policy type).
    # NOTE(review): Checkpoint declares no field defaults, so calling this default factory directly
    # raises TypeError; path and type are effectively required — confirm how the CLI parser treats it.
    policy: Checkpoint = dataclasses.field(default_factory=Checkpoint)
def main(args: Args) -> None:
    """Load the requested policy from its checkpoint and serve it over websockets.

    Args:
        args: Parsed command-line arguments; ``args.policy.type`` selects the
            policy class and ``args.policy.path`` the checkpoint directory.

    Raises:
        ValueError: If ``args.policy.type`` has no loading branch here.
    """
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(1000)
    print(args)

    # NOTE(review): from_pretrained is presumed to return the policy already on its
    # device and in eval mode — confirm against the policy implementation.
    if args.policy.type == PolicyType.ACT.value:
        policy = ACTPolicy.from_pretrained(args.policy.path)
    elif args.policy.type == PolicyType.DIFFUSION.value:
        policy = DiffusionPolicy.from_pretrained(args.policy.path)
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        supported = (PolicyType.ACT.value, PolicyType.DIFFUSION.value)
        raise ValueError(
            f"Unsupported policy type {args.policy.type!r}; supported types are {supported}."
        )

    # Record the policy's behavior.
    # if args.record:
    #     policy = _policy.PolicyRecorder(policy, "policy_records")

    server = WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",
        port=args.port,
        metadata={},
    )
    server.serve_forever()


if __name__ == "__main__":
    init_logging()
    main(tyro.cli(Args))

View File

@ -19,6 +19,7 @@ from contextlib import nullcontext
from pprint import pformat from pprint import pformat
from typing import Any from typing import Any
import numpy as np
import torch import torch
from termcolor import colored from termcolor import colored
from torch.amp import GradScaler from torch.amp import GradScaler
@ -126,6 +127,23 @@ def train(cfg: TrainPipelineConfig):
logging.info("Creating dataset") logging.info("Creating dataset")
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
dataset.meta.info['features']['observation.state']['shape'] = (13,)
dataset.meta.info['features']['observation.state']['names'] = ['left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6',
'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum']
dataset.meta.info['features']['action']['shape'] = (13,)
dataset.meta.info['features']['action']['names'] = ['left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6',
'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp']
for key, episode_stats in dataset.meta.episodes_stats.items():
for feature in ['observation.state', 'action']:
for sub_feature in ['min', 'max', 'mean', 'std']:
episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6,13], axis=0)
dataset.meta.episodes_stats[key] = episode_stats
for feature in ['min', 'max', 'mean', 'std']:
dataset.meta.stats['observation.state'][feature] = np.delete(dataset.meta.stats['observation.state'][feature], [6,13], axis=0)
dataset.meta.stats['action'][feature] = np.delete(dataset.meta.stats['action'][feature], [6,13], axis=0)
# Create environment used for evaluating checkpoints during training on simulation data. # Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py, # On real-world data, no need to create an environment as evaluations are done outside train.py,
@ -209,6 +227,29 @@ def train(cfg: TrainPipelineConfig):
if isinstance(batch[key], torch.Tensor): if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True) batch[key] = batch[key].to(device, non_blocking=True)
# remove unused depth info
if 'observation.images.depth' in batch:
del batch['observation.images.depth']
# remove redundant joint states
if batch["observation.state"].shape[-1] == 15:
keep_cols = [i for i in range(15) if i not in {6, 13}]
# 切片操作删除指定列
batch["observation.state"] = batch["observation.state"][..., keep_cols]
# remove redundant joint states
if batch["action"].shape[-1] == 15:
keep_cols = [i for i in range(15) if i not in {6, 13}]
batch["action"] = batch["action"][..., keep_cols]
# the front camera resolution is 640x360, pad it to 640x480 which is the resolution of the wrist camera
batch['observation.images.front'] = torch.nn.functional.pad(
batch['observation.images.front'],
pad=(0, 0, 60, 60), # 左右不填充上下各填充60
mode='constant',
value=0 # black
)
train_tracker, output_dict = update_policy( train_tracker, output_dict = update_policy(
train_tracker, train_tracker,
policy, policy,

View File

@ -85,6 +85,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
pi0 = ["transformers>=4.48.0"] pi0 = ["transformers>=4.48.0"]
serving = ["websockets==14.1", "msgpack==1.1.0", "tyro"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [ stretch = [
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'", "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",

View File

@ -0,0 +1,43 @@
import logging
import time
import numpy as np
import websockets.sync.client
import lerobot.common.utils.msgpack_utils as msgpack_utils
# Smoke test for the websocket policy server: connect, send one synthetic
# observation and check that the returned action has one entry per state dimension.

# Without a configured handler the INFO messages below would never be shown.
logging.basicConfig(level=logging.INFO)

observation = {
    "state": np.ones((13,)),
    "images": {
        # Images sent by the client are laid out as (h, w, c).
        "front": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
        "wrist_right": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
    },
    "prompt": "do something",
}

url = "ws://127.0.0.1:8000"
packer = msgpack_utils.Packer()

logging.info(f"Waiting for server at {url}...")
while True:
    try:
        conn = websockets.sync.client.connect(url, compression=None, max_size=None)
        # The server sends its metadata as the first frame of every connection.
        metadata = msgpack_utils.unpackb(conn.recv())
        break
    except ConnectionRefusedError:
        logging.info("Still waiting for server...")
        time.sleep(5)

try:
    conn.send(packer.pack(observation))
    response = conn.recv()
    if isinstance(response, str):
        # We're expecting bytes; if the server sends a string, it's an error.
        # Exit with a non-zero status so callers can detect the failure.
        raise SystemExit(f"Error in inference server:\n{response}")

    infer_result = msgpack_utils.unpackb(response)
    print(infer_result)
    assert len(infer_result['actions'][0]) == len(observation['state'])
finally:
    # Always release the connection, also when the check above fails.
    conn.close()