Add select_action_chunk method in PreTrainedPolicy
This commit is contained in:
parent
485affb658
commit
509b9223f6
|
@ -106,6 +106,20 @@ class ACTPolicy(PreTrainedPolicy):
|
|||
else:
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select the entire action chunk given environment observations.
|
||||
|
||||
This method wraps `select_action_chunk` in order to return all actions the policy infers at a time.
|
||||
Resuse the `select_action` method in order to reuse the open source lerobot, but its performance may
|
||||
decrease a bit.
|
||||
"""
|
||||
first_action = self.select_action(batch)
|
||||
result_tensor = torch.cat([first_action] + list(self._action_queue), dim=0)
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
return result_tensor
|
||||
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
|
|
@ -143,6 +143,19 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
@torch.no_grad
|
||||
def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select the entire action chunk given environment observations.
|
||||
|
||||
This method wraps `select_action_chunk` in order to return all actions the policy infers at a time.
|
||||
Resuse the `select_action` method in order to reuse the open source lerobot, but its performance may
|
||||
decrease a bit.
|
||||
"""
|
||||
first_action = self.select_action(batch)
|
||||
result_tensor = torch.cat([first_action] + list(self._action_queue), dim=0)
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
return result_tensor
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
|
|
@ -197,3 +197,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||
with caching.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -94,9 +94,10 @@ class WebsocketPolicyServer:
|
|||
if isinstance(obs[key], torch.Tensor):
|
||||
obs[key] = obs[key].to("cuda", non_blocking=True)
|
||||
|
||||
with torch.inference_mode():
|
||||
action = self._policy.select_action(obs)
|
||||
print("inference once with action:", action)
|
||||
with torch.inference_mode():
|
||||
action = self._policy.select_action_chunk(obs)
|
||||
action = action.squeeze(0)
|
||||
print("inference once with action:", action.shape, action)
|
||||
res = {"actions": action.cpu().numpy()}
|
||||
await websocket.send(packer.pack(res))
|
||||
except websockets.ConnectionClosed:
|
||||
|
|
|
@ -38,6 +38,6 @@ if isinstance(response, str):
|
|||
exit()
|
||||
|
||||
infer_result = msgpack_utils.unpackb(response)
|
||||
print(infer_result)
|
||||
print(infer_result['actions'].shape, infer_result)
|
||||
assert len(infer_result['actions'][0]) == len(input['state'])
|
||||
|
||||
|
|
Loading…
Reference in New Issue