From 509b9223f6b4a771e0b949b7063b26b46ee0f2ee Mon Sep 17 00:00:00 2001 From: ruanafan Date: Mon, 7 Apr 2025 17:41:22 +0800 Subject: [PATCH] Add select_action_chunk method in PreTrainedPolicy --- lerobot/common/policies/act/modeling_act.py | 14 ++++++++++++++ .../policies/diffusion/modeling_diffusion.py | 13 +++++++++++++ lerobot/common/policies/pretrained.py | 9 +++++++++ lerobot/common/serving/websocket_policy_server.py | 7 ++++--- tests/serving/test_websocket_serving.py | 2 +- 5 files changed, 41 insertions(+), 4 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 72d4df03..4d526b69 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -106,6 +106,20 @@ class ACTPolicy(PreTrainedPolicy): else: self._action_queue = deque([], maxlen=self.config.n_action_steps) + @torch.no_grad + def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Select the entire action chunk given environment observations. + + This method wraps `select_action` in order to return all actions the policy infers at a time. + Reuse the `select_action` method in order to reuse the open source lerobot, but its performance may + decrease a bit. + """ + first_action = self.select_action(batch) + result_tensor = torch.cat([first_action] + list(self._action_queue), dim=0) + self._action_queue = deque([], maxlen=self.config.n_action_steps) + return result_tensor + + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. 
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 9ecadcb0..d1966329 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -143,6 +143,19 @@ class DiffusionPolicy(PreTrainedPolicy): action = self._queues["action"].popleft() return action + @torch.no_grad + def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Select the entire action chunk given environment observations. + + This method wraps `select_action` in order to return all actions the policy infers at a time. + Reuse the `select_action` method in order to reuse the open source lerobot, but its performance may + decrease a bit. + """ + first_action = self.select_action(batch) + result_tensor = torch.cat([first_action] + list(self._queues["action"]), dim=0) + self._queues["action"] = deque([], maxlen=self.config.n_action_steps) + return result_tensor + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py index da4ef157..6c883d08 100644 --- a/lerobot/common/policies/pretrained.py +++ b/lerobot/common/policies/pretrained.py @@ -197,3 +197,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): with caching. """ raise NotImplementedError + + @abc.abstractmethod + def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Return the entire chunk of actions inferred for the given observations (potentially in batch mode). + + When the model uses a history of observations, or outputs a sequence of actions, this method deals + with caching. 
+ """ + raise NotImplementedError diff --git a/lerobot/common/serving/websocket_policy_server.py b/lerobot/common/serving/websocket_policy_server.py index 2660515e..bf9e4807 100644 --- a/lerobot/common/serving/websocket_policy_server.py +++ b/lerobot/common/serving/websocket_policy_server.py @@ -94,9 +94,10 @@ class WebsocketPolicyServer: if isinstance(obs[key], torch.Tensor): obs[key] = obs[key].to("cuda", non_blocking=True) - with torch.inference_mode(): - action = self._policy.select_action(obs) - print("inference once with action:", action) + with torch.inference_mode(): + action = self._policy.select_action_chunk(obs) + action = action.squeeze(0) + print("inference once with action:", action.shape, action) res = {"actions": action.cpu().numpy()} await websocket.send(packer.pack(res)) except websockets.ConnectionClosed: diff --git a/tests/serving/test_websocket_serving.py b/tests/serving/test_websocket_serving.py index 77459686..2cb82589 100644 --- a/tests/serving/test_websocket_serving.py +++ b/tests/serving/test_websocket_serving.py @@ -38,6 +38,6 @@ if isinstance(response, str): exit() infer_result = msgpack_utils.unpackb(response) -print(infer_result) +print(infer_result['actions'].shape, infer_result) assert len(infer_result['actions'][0]) == len(input['state'])