Add select_action_chunk method in PreTrainedPolicy

This commit is contained in:
ruanafan 2025-04-07 17:41:22 +08:00
parent 485affb658
commit 509b9223f6
5 changed files with 41 additions and 4 deletions

View File

@ -106,6 +106,20 @@ class ACTPolicy(PreTrainedPolicy):
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@torch.no_grad()
def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
    """Select the entire action chunk given environment observations.

    This method wraps `select_action` in order to return all actions the policy
    infers at a time: the first action comes from the `select_action` call (which
    runs inference and fills `self._action_queue`), and the remaining actions are
    drained from the queue. Reusing `select_action` keeps this compatible with the
    upstream lerobot implementation, at the cost of a small performance overhead.

    Returns:
        A tensor of the full action chunk, concatenated along dim 0
        (the first action followed by the queued remainder).
    """
    # Runs inference for a fresh chunk and pops its first action.
    first_action = self.select_action(batch)
    # The queue now holds the rest of the chunk; concatenate it after the first action.
    result_tensor = torch.cat([first_action] + list(self._action_queue), dim=0)
    # Reset the queue so the next call triggers a fresh inference instead of
    # replaying stale actions.
    self._action_queue = deque([], maxlen=self.config.n_action_steps)
    return result_tensor
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

View File

@ -143,6 +143,19 @@ class DiffusionPolicy(PreTrainedPolicy):
action = self._queues["action"].popleft()
return action
@torch.no_grad()
def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
    """Select the entire action chunk given environment observations.

    This method wraps `select_action` in order to return all actions the policy
    infers at a time: the first action comes from the `select_action` call (which
    runs inference and fills the action queue), and the remaining actions are
    drained from the queue. Reusing `select_action` keeps this compatible with the
    upstream lerobot implementation, at the cost of a small performance overhead.

    Returns:
        A tensor of the full action chunk, concatenated along dim 0
        (the first action followed by the queued remainder).
    """
    # Runs inference for a fresh chunk and pops its first action.
    first_action = self.select_action(batch)
    # BUGFIX: DiffusionPolicy keeps its action queue in `self._queues["action"]`,
    # not `self._action_queue` (which was copy-pasted from ACTPolicy and does not
    # exist on this class).
    result_tensor = torch.cat([first_action] + list(self._queues["action"]), dim=0)
    # Reset the queue so the next call triggers a fresh inference instead of
    # replaying stale actions.
    self._queues["action"] = deque([], maxlen=self.config.n_action_steps)
    return result_tensor
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)

View File

@ -197,3 +197,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
with caching.
"""
raise NotImplementedError
@abc.abstractmethod
def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
    """Return the entire chunk of actions the policy infers for the given observations.

    Unlike `select_action`, which returns a single action per call, implementations
    return all actions of the current inference chunk at once. When the model uses
    a history of observations, or outputs a sequence of actions, this method deals
    with caching.
    """
    raise NotImplementedError

View File

@ -94,9 +94,10 @@ class WebsocketPolicyServer:
if isinstance(obs[key], torch.Tensor):
obs[key] = obs[key].to("cuda", non_blocking=True)
with torch.inference_mode():
action = self._policy.select_action(obs)
print("inference once with action:", action)
with torch.inference_mode():
action = self._policy.select_action_chunk(obs)
action = action.squeeze(0)
print("inference once with action:", action.shape, action)
res = {"actions": action.cpu().numpy()}
await websocket.send(packer.pack(res))
except websockets.ConnectionClosed:

View File

@ -38,6 +38,6 @@ if isinstance(response, str):
exit()
infer_result = msgpack_utils.unpackb(response)
print(infer_result)
print(infer_result['actions'].shape, infer_result)
assert len(infer_result['actions'][0]) == len(input['state'])