Support openloop eval and serving
This commit is contained in:
parent
1c873df5c0
commit
ae70f12378
|
@ -103,6 +103,14 @@ class LeRobotDatasetMetadata:
|
|||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
# self.info['features']['observation.state']['shape'] = (13,)
|
||||
# self.info['features']['observation.state']['names'] = ['left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6',
|
||||
# 'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum']
|
||||
# self.info['features']['action']['shape'] = (13,)
|
||||
# self.info['features']['action']['names'] = ['left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6',
|
||||
# 'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp']
|
||||
|
||||
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
|
@ -110,7 +118,15 @@ class LeRobotDatasetMetadata:
|
|||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
# print("v2.1 episode stat process")
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
# print("episode stats: ", type(self.episodes_stats))
|
||||
# for _, episode_stats in self.episodes_stats.items():
|
||||
# for feature in ['observation.state', 'action']:
|
||||
# for sub_feature in ['min', 'max', 'mean', 'std']:
|
||||
# episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6,13], axis=0)
|
||||
# print('after process', feature, sub_feature, episode_stats[feature][sub_feature].shape)
|
||||
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
|
||||
def pull_from_repo(
|
||||
|
@ -731,6 +747,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
# # remove gripper in original data
|
||||
# if "observation.state" == key:
|
||||
# keep_cols = [i for i in range(15) if i not in {6, 13}]
|
||||
# # 切片操作删除指定列
|
||||
# val = val[:, keep_cols]
|
||||
# # print("emmmmmm in __getitem_ state", val.shape)
|
||||
# if "action" == key:
|
||||
# keep_cols = [i for i in range(15) if i not in {6, 13}]
|
||||
# # 切片操作删除指定列
|
||||
# val = val[:, keep_cols]
|
||||
# # print("emmmmmm in __getitem_ action", val.shape)
|
||||
item[key] = val
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
|
|
|
@ -420,6 +420,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||
else:
|
||||
continue
|
||||
|
||||
if 'depth' in key:
|
||||
continue
|
||||
|
||||
policy_features[key] = PolicyFeature(
|
||||
type=type,
|
||||
shape=shape,
|
||||
|
|
|
@ -59,7 +59,8 @@ def convert_dataset(
|
|||
num_workers: int = 4,
|
||||
):
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
# dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
dataset = LeRobotDataset("move_reel_test_0322")
|
||||
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
|
|
|
@ -168,6 +168,7 @@ class Normalize(nn.Module):
|
|||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
print("process key:", key)
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import einops
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import websockets.asyncio.server
|
||||
import websockets.frames
|
||||
|
||||
import lerobot.common.utils.msgpack_utils as msgpack_utils
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
class WebsocketPolicyServer:
    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.

    After a client connects, the server sends its metadata once and then answers
    every msgpack-encoded observation with a msgpack-encoded
    ``{"actions": ndarray}`` reply until the connection is closed.

    Currently only implements the `load` and `infer` methods.
    """

    # Images that are exactly this many rows high are padded with black rows
    # on the top and bottom (360 + 2 * 60 = 480).
    _PAD_FROM_HEIGHT = 360
    _PAD_ROWS = 60

    def __init__(
        self,
        policy: PreTrainedPolicy,
        host: str = "0.0.0.0",
        port: int = 8000,
        metadata: dict | None = None,
    ) -> None:
        """Store the policy and the network settings; nothing is opened yet.

        Args:
            policy: Policy whose ``select_action`` answers each request.
            host: Interface to bind to.
            port: TCP port to listen on.
            metadata: Sent to every client right after it connects
                (an empty dict when omitted).
        """
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}

    def serve_forever(self) -> None:
        """Run the server on a fresh asyncio event loop (blocking)."""
        asyncio.run(self.run())

    async def run(self):
        """Open the websocket server and serve connections until cancelled."""
        async with websockets.asyncio.server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
        ) as server:
            await server.serve_forever()

    def _policy_device(self):
        """Return the device of the policy's first parameter ("cuda" if it has none)."""
        first_param = next(self._policy.parameters(), None)
        # Fall back to the previously hard-coded device for parameter-free policies.
        return first_param.device if first_param is not None else "cuda"

    async def preprocess_observation(self, observations: dict) -> dict[str, torch.Tensor]:
        """Convert a raw client observation into batched policy inputs.

        Args:
            observations: Mapping with an ``"images"`` dict of ``(h, w, c)``
                uint8 arrays and a ``"state"`` array.

        Returns:
            A dict with one ``observation.images.<name>`` float32 tensor of
            shape ``(1, c, h, w)`` scaled to ``[0, 1]`` per image, plus
            ``observation.state`` as a float32 tensor with a leading batch
            dimension.

        Raises:
            ValueError: If an image is not a channel-last uint8 array.
        """
        batch: dict[str, torch.Tensor] = {}

        for name, raw_image in observations["images"].items():
            img = torch.from_numpy(raw_image)

            # Validate the *incoming* layout so the error reports the shape the
            # client actually sent.
            if img.ndim != 3:
                raise ValueError(f"expect channel last images, but instead got {img.shape=}")
            height, width, channels = img.shape
            if not (channels < height and channels < width):
                raise ValueError(f"expect channel last images, but instead got {img.shape=}")
            if img.dtype != torch.uint8:
                raise ValueError(f"expect torch.uint8, but instead {img.dtype=}")

            img = einops.rearrange(img, "h w c -> c h w").contiguous()
            if height == self._PAD_FROM_HEIGHT:
                # F.pad pads from the last dimension backwards:
                # (w_left, w_right, h_top, h_bottom).
                img = torch.nn.functional.pad(
                    img,
                    pad=(0, 0, self._PAD_ROWS, self._PAD_ROWS),
                    mode="constant",
                    value=0,
                )

            # Add the batch dimension and rescale uint8 [0, 255] to float32 [0, 1].
            img = img.unsqueeze(0).type(torch.float32)
            img /= 255

            key = f"observation.images.{name}"
            logging.debug("add a key: %s %s", key, tuple(img.shape))
            batch[key] = img

        state = torch.from_numpy(observations["state"]).float()
        batch["observation.state"] = state.unsqueeze(0)

        return batch

    async def _handler(self, websocket: websockets.asyncio.server.ServerConnection):
        """Serve one client: send metadata, then answer observations until it disconnects.

        Each request is expected to look like::

            {
                "images": {"cam_high": ndarray, "cam_right_wrist": ndarray},
                "state": ndarray,
                "prompt": "xxx text",
            }

        On an unexpected error the traceback is sent to the client as a text
        frame, the connection is closed with INTERNAL_ERROR and the exception
        is re-raised.
        """
        logging.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_utils.Packer()

        # A new connection is treated as a new episode: drop any actions or
        # observations the policy still has queued from a previous client.
        self._policy.reset()
        device = self._policy_device()

        await websocket.send(packer.pack(self._metadata))

        while True:
            try:
                obs = msgpack_utils.unpackb(await websocket.recv())

                obs = await self.preprocess_observation(obs)
                for key, value in obs.items():
                    if isinstance(value, torch.Tensor):
                        obs[key] = value.to(device, non_blocking=True)

                with torch.inference_mode():
                    action = self._policy.select_action(obs)
                logging.debug("inference once with action: %s", action)

                res = {"actions": action.cpu().numpy()}
                await websocket.send(packer.pack(res))
            except websockets.ConnectionClosed:
                logging.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise
|
|
@ -0,0 +1,43 @@
|
|||
import functools
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pack_array(obj):
    """msgpack ``default`` hook: encode NumPy arrays and scalars as tagged dicts.

    Arrays become ``{b"__ndarray__": True, b"data": <raw bytes>, b"dtype": <dtype str>,
    b"shape": <shape>}`` and NumPy scalars become ``{b"__npgeneric__": True,
    b"data": <python value>, b"dtype": <dtype str>}``. Any other object is
    returned unchanged.

    Raises:
        ValueError: For void ("V"), object ("O") or complex ("c") dtypes.
    """
    is_array = isinstance(obj, np.ndarray)
    is_scalar = isinstance(obj, np.generic)

    # Not a NumPy value: hand it back untouched.
    if not (is_array or is_scalar):
        return obj

    if obj.dtype.kind in ("V", "O", "c"):
        raise ValueError(f"Unsupported dtype: {obj.dtype}")

    dtype_str = obj.dtype.str
    if is_array:
        return {
            b"__ndarray__": True,
            b"data": obj.tobytes(),
            b"dtype": dtype_str,
            b"shape": obj.shape,
        }

    return {
        b"__npgeneric__": True,
        b"data": obj.item(),
        b"dtype": dtype_str,
    }
|
||||
|
||||
|
||||
def unpack_array(obj):
    """msgpack ``object_hook``: rebuild NumPy values from tagged dicts.

    A dict carrying ``b"__ndarray__"`` is turned back into an array from its
    ``b"data"`` / ``b"dtype"`` / ``b"shape"`` entries; a dict carrying
    ``b"__npgeneric__"`` is turned back into a NumPy scalar. Any other dict is
    returned unchanged.

    The returned array owns its memory and is writable. A view built directly
    on the received ``bytes`` would be read-only, which breaks in-place
    operations and is not supported by ``torch.from_numpy``.
    """
    if b"__ndarray__" in obj:
        view = np.ndarray(
            buffer=obj[b"data"],
            dtype=np.dtype(obj[b"dtype"]),
            shape=obj[b"shape"],
        )
        # Detach from the immutable message buffer.
        return view.copy()

    if b"__npgeneric__" in obj:
        return np.dtype(obj[b"dtype"]).type(obj[b"data"])

    return obj
|
||||
|
||||
|
||||
# msgpack entry points pre-bound to the NumPy-aware hooks defined above.

# Encoding side: `pack_array` is used as msgpack's `default` hook.
Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)

# Decoding side: `unpack_array` is used as msgpack's `object_hook`.
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
|
|
@ -0,0 +1,180 @@
|
|||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
from termcolor import colored
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Iterator
|
||||
|
||||
import wandb
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import (
|
||||
init_logging,
|
||||
get_safe_torch_device,
|
||||
)
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler that yields, in order, the frame indices of a single episode."""

    def __init__(self, dataset: lerobot_dataset.LeRobotDataset, episode_index: int):
        """Look up the episode's frame range in ``dataset.episode_data_index``.

        Args:
            dataset: Dataset providing the ``"from"`` / ``"to"`` frame index
                tables.
            episode_index: Position of the episode in those tables.
        """
        bounds = dataset.episode_data_index
        first_frame = bounds["from"][episode_index].item()
        end_frame = bounds["to"][episode_index].item()
        # `end_frame` is used as the exclusive upper bound of the range.
        self.frame_ids = range(first_frame, end_frame)

    def __iter__(self) -> Iterator:
        """Iterate over the episode's frame indices."""
        return iter(self.frame_ids)

    def __len__(self) -> int:
        """Number of frames in the episode."""
        return len(self.frame_ids)
|
||||
|
||||
@parser.wrap()
def main(cfg: TrainPipelineConfig):
    """Open-loop evaluation: compare predicted and recorded actions on one episode.

    Builds the dataset and policy from ``cfg``, replays frames of a single
    episode through ``policy.select_action`` and logs, for every action
    dimension, the difference between the predicted and the recorded action to
    Weights & Biases.

    Args:
        cfg: Training pipeline configuration (dataset, policy, seed, batch
            size, number of workers, ...).
    """
    init_logging()
    logging.info(pformat(dataclasses.asdict(cfg)))

    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(cfg.seed)

    wandb.init(
        name=cfg.job_name,
        config=dataclasses.asdict(cfg),
        project="lerobot",
    )

    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

    # Columns 6 and 13 of the recorded 15-dim state/action vectors are dropped
    # everywhere below, leaving 13 dimensions.
    dropped_dims = [6, 13]
    keep_cols = [i for i in range(15) if i not in dropped_dims]
    stat_names = ['min', 'max', 'mean', 'std']

    logging.info("Making dataset.")
    dataset = make_dataset(cfg)

    # Patch the dataset metadata so the policy is built for 13-dim state/action.
    dataset.meta.info['features']['observation.state']['shape'] = (13,)
    dataset.meta.info['features']['observation.state']['names'] = [
        'left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6',
        'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum',
    ]
    dataset.meta.info['features']['action']['shape'] = (13,)
    dataset.meta.info['features']['action']['names'] = [
        'left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6',
        'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp',
    ]

    # Remove the same two dimensions from the per-episode and aggregated statistics.
    for key, episode_stats in dataset.meta.episodes_stats.items():
        for feature in ['observation.state', 'action']:
            for sub_feature in stat_names:
                episode_stats[feature][sub_feature] = np.delete(
                    episode_stats[feature][sub_feature], dropped_dims, axis=0
                )
        dataset.meta.episodes_stats[key] = episode_stats
    for feature in stat_names:
        dataset.meta.stats['observation.state'][feature] = np.delete(
            dataset.meta.stats['observation.state'][feature], dropped_dims, axis=0
        )
        dataset.meta.stats['action'][feature] = np.delete(
            dataset.meta.stats['action'][feature], dropped_dims, axis=0
        )

    logging.info("Making policy.")
    policy = make_policy(
        cfg=cfg.policy,
        ds_meta=dataset.meta,
    )

    # TODO: support several episodes, read from the configuration
    episode_index = 0

    sampler = EpisodeSampler(dataset, episode_index)
    data_len = len(sampler)
    logging.info("entire data len: %s", data_len)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        sampler=sampler,
    )

    policy.eval()

    # Names used for the per-dimension wandb keys (13-dim galaxea vacuum layout).
    dim_names = [
        "arm_left_0", "arm_left_1", "arm_left_2", "arm_left_3", "arm_left_4", "arm_left_5",
        "arm_right_0", "arm_right_1", "arm_right_2", "arm_right_3", "arm_right_4", "arm_right_5",
        "arm_right_vacumn",
    ]

    front_key = 'observation.images.front'

    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
        dl_iter = cycle(dataloader)

        # Skip the first 50 batches.
        for _ in range(50):
            next(dl_iter)

        for _ in range(200):
            # Fetch the next batch.
            input_raw_data = next(dl_iter)

            # Depth images are not fed to the policy.
            if 'observation.images.depth' in input_raw_data:
                del input_raw_data['observation.images.depth']

            # Reduce 15-dim state/action to the 13 kept columns.
            if input_raw_data["observation.state"].shape[-1] == 15:
                input_raw_data["observation.state"] = input_raw_data["observation.state"][..., keep_cols]
            logging.debug("state shape %s", input_raw_data["observation.state"].shape)

            if input_raw_data["action"].shape[-1] == 15:
                input_raw_data["action"] = input_raw_data["action"][..., keep_cols]
            logging.debug("action shape %s", input_raw_data["action"].shape)

            for key in input_raw_data:
                if isinstance(input_raw_data[key], torch.Tensor):
                    input_raw_data[key] = input_raw_data[key].to(device, non_blocking=True)

            # The front camera resolution is 640x360; pad it to 640x480, which is
            # the resolution of the wrist camera. Only 360-row images are padded
            # so an already 480-row image is left untouched.
            front = input_raw_data.get(front_key)
            if front is not None and front.shape[-2] == 360:
                input_raw_data[front_key] = torch.nn.functional.pad(
                    front,
                    pad=(0, 0, 60, 60),  # no left/right padding, 60 rows on top and bottom
                    mode='constant',
                    value=0,  # black
                )
                logging.warning("pad the front images to %s", input_raw_data[front_key].shape)

            # Inference.
            infer_action = policy.select_action(input_raw_data)
            infer_action = infer_action[0].cpu().numpy()

            origin_action = input_raw_data["action"].cpu().numpy()[0][0]

            # Record the difference for every action dimension.
            diff = infer_action - origin_action
            log_dict = {}
            for dim in range(cfg.policy.output_features["action"].shape[0]):
                log_dict[f"{dim_names[dim]}_diff"] = diff[dim]

            wandb.log(log_dict)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
import tyro
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.utils.utils import (
|
||||
init_logging,
|
||||
get_safe_torch_device,
|
||||
)
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.common.serving.websocket_policy_server import WebsocketPolicyServer
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
|
||||
class PolicyType(enum.Enum):
    """Policy type identifiers accepted in ``Checkpoint.type``.

    NOTE: ``main`` in this file only loads ACT and DIFFUSION checkpoints.
    """

    ACT = "act"
    DIFFUSION = "diffusion"
    PI0 = "pi0"
|
||||
|
||||
@dataclasses.dataclass
class Checkpoint:
    """Where to find a trained policy and which kind of policy it is."""

    # Directory of the pretrained policy
    # (e.g., "outputs/train/act_move_reel_0322_nodepth/checkpoints/040000/pretrained_model").
    path: str

    # Policy type name; `main` compares it with the `PolicyType` values ("act", "diffusion").
    type: str
|
||||
|
||||
@dataclasses.dataclass
class Args:
    """Arguments for the serve_policy script."""

    # If provided, will be used in case the "prompt" key is not present in the data, or if the
    # model doesn't have a default prompt.
    # NOTE(review): `main` in this file does not read this field.
    default_prompt: str | None = None

    # Port to serve the policy on.
    port: int = 8000

    # Record the policy's behavior for debugging.
    # NOTE(review): recording is commented out in `main`, so this flag currently has no effect.
    record: bool = False

    # Checkpoint to load and serve.
    # NOTE(review): `Checkpoint` has required fields, so calling `Args()` without `policy`
    # raises TypeError; presumably the CLI parser requires `--policy.path` / `--policy.type`
    # instead of invoking the factory -- confirm.
    policy: Checkpoint = dataclasses.field(default_factory=Checkpoint)
|
||||
|
||||
def main(args: Args) -> None:
    """Load the requested policy checkpoint and serve it over a websocket.

    Args:
        args: Parsed command-line arguments; ``args.policy`` selects the
            checkpoint directory and policy type, ``args.port`` the port.

    Raises:
        ValueError: If ``args.policy.type`` is not a policy type this script
            can load.
    """
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(1000)

    logging.info("Serving with arguments: %s", args)

    # Policy classes this script knows how to load, keyed by their CLI name.
    policy_classes = {
        PolicyType.ACT.value: ACTPolicy,
        PolicyType.DIFFUSION.value: DiffusionPolicy,
    }
    policy_cls = policy_classes.get(args.policy.type)
    if policy_cls is None:
        # Previously an unknown type fell through and failed later with an
        # UnboundLocalError on `policy`.
        supported = ", ".join(sorted(policy_classes))
        raise ValueError(
            f"Unsupported policy type {args.policy.type!r}; supported types: {supported}"
        )

    # Per the original author's note the loaded policy is already on its
    # device and in eval mode -- TODO confirm against `from_pretrained`.
    policy = policy_cls.from_pretrained(args.policy.path)

    # Recording of the policy's behavior (`args.record`) is not implemented yet.

    server = WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",
        port=args.port,
        metadata={},
    )
    server.serve_forever()


if __name__ == "__main__":
    init_logging()
    main(tyro.cli(Args))
|
|
@ -19,6 +19,7 @@ from contextlib import nullcontext
|
|||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
|
@ -126,6 +127,23 @@ def train(cfg: TrainPipelineConfig):
|
|||
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
dataset.meta.info['features']['observation.state']['shape'] = (13,)
|
||||
dataset.meta.info['features']['observation.state']['names'] = ['left_arm_1', 'left_arm_2', 'left_arm_3', 'left_arm_4', 'left_arm_5', 'left_arm_6',
|
||||
'right_arm_1', 'right_arm_2', 'right_arm_3', 'right_arm_4', 'right_arm_5', 'right_arm_6', 'vacuum']
|
||||
dataset.meta.info['features']['action']['shape'] = (13,)
|
||||
dataset.meta.info['features']['action']['names'] = ['left_arm_exp_1', 'left_arm_exp_2', 'left_arm_exp_3', 'left_arm_exp_4', 'left_arm_exp_5', 'left_arm_exp_6',
|
||||
'right_arm_exp_1', 'right_arm_exp_2', 'right_arm_exp_3', 'right_arm_exp_4', 'right_arm_exp_5', 'right_arm_exp_6', 'vacuum_exp']
|
||||
|
||||
for key, episode_stats in dataset.meta.episodes_stats.items():
|
||||
for feature in ['observation.state', 'action']:
|
||||
for sub_feature in ['min', 'max', 'mean', 'std']:
|
||||
episode_stats[feature][sub_feature] = np.delete(episode_stats[feature][sub_feature], [6,13], axis=0)
|
||||
dataset.meta.episodes_stats[key] = episode_stats
|
||||
for feature in ['min', 'max', 'mean', 'std']:
|
||||
dataset.meta.stats['observation.state'][feature] = np.delete(dataset.meta.stats['observation.state'][feature], [6,13], axis=0)
|
||||
dataset.meta.stats['action'][feature] = np.delete(dataset.meta.stats['action'][feature], [6,13], axis=0)
|
||||
|
||||
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
|
@ -209,6 +227,29 @@ def train(cfg: TrainPipelineConfig):
|
|||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
# remove unused depth info
|
||||
if 'observation.images.depth' in batch:
|
||||
del batch['observation.images.depth']
|
||||
|
||||
# remove redundant joint states
|
||||
if batch["observation.state"].shape[-1] == 15:
|
||||
keep_cols = [i for i in range(15) if i not in {6, 13}]
|
||||
# 切片操作删除指定列
|
||||
batch["observation.state"] = batch["observation.state"][..., keep_cols]
|
||||
|
||||
# remove redundant joint states
|
||||
if batch["action"].shape[-1] == 15:
|
||||
keep_cols = [i for i in range(15) if i not in {6, 13}]
|
||||
batch["action"] = batch["action"][..., keep_cols]
|
||||
|
||||
# the front camera resolution is 640x360, pad it to 640x480 which is the resolution of the wrist camera
|
||||
batch['observation.images.front'] = torch.nn.functional.pad(
|
||||
batch['observation.images.front'],
|
||||
pad=(0, 0, 60, 60), # 左右不填充,上下各填充60
|
||||
mode='constant',
|
||||
value=0 # black
|
||||
)
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
|
|
|
@ -85,6 +85,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
|||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
serving = ["websockets==14.1", "msgpack==1.1.0", "tyro"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
stretch = [
|
||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import websockets.sync.client
|
||||
|
||||
import lerobot.common.utils.msgpack_utils as msgpack_utils
|
||||
|
||||
# Make the progress messages below visible when this file is run as a script.
logging.basicConfig(level=logging.INFO)

# Dummy observation in the layout the policy server expects.
observation = {
    "state": np.ones((13,)),
    "images": {
        # images sent by the client are channel-last: (h, w, c)
        "front": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
        "wrist_right": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
    },
    "prompt": "do something",
}

url = "ws://127.0.0.1:8000"
packer = msgpack_utils.Packer()

logging.info(f"Waiting for server at {url}...")
while True:
    try:
        conn = websockets.sync.client.connect(url, compression=None, max_size=None)
        break
    except ConnectionRefusedError:
        logging.info("Still waiting for server...")
        time.sleep(5)

try:
    # The server sends its metadata right after the connection is opened.
    metadata = msgpack_utils.unpackb(conn.recv())

    conn.send(packer.pack(observation))
    response = conn.recv()
    if isinstance(response, str):
        # we're expecting bytes; if the server sends a string, it's an error.
        # SystemExit with a message prints it to stderr and exits with status 1.
        raise SystemExit(f"Error in inference server:\n{response}")

    infer_result = msgpack_utils.unpackb(response)
    print(infer_result)
    # One action value is expected per state dimension.
    assert len(infer_result['actions'][0]) == len(observation['state'])
finally:
    # Always release the socket, including on errors.
    conn.close()
|
||||
|
Loading…
Reference in New Issue