Add select_action_chunk method in PreTrainedPolicy
parent 485affb658
commit 509b9223f6
@@ -106,6 +106,20 @@ class ACTPolicy(PreTrainedPolicy):
         else:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)
 
+    @torch.no_grad
+    def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Select the entire action chunk given environment observations.
+
+        This method wraps `select_action` in order to return all the actions the policy infers at once.
+        Reusing `select_action` keeps the upstream open-source lerobot code unchanged, though
+        performance may decrease a bit.
+        """
+        first_action = self.select_action(batch)
+        result_tensor = torch.cat([first_action] + list(self._action_queue), dim=0)
+        self._action_queue = deque([], maxlen=self.config.n_action_steps)
+        return result_tensor
+
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
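For intuition, here is a self-contained sketch of the queue-draining pattern the added method relies on. `DummyChunkPolicy` and its shapes are made up for illustration; only the `select_action_chunk` body mirrors the diff:

# Sketch of the select_action_chunk pattern (illustrative only, not repo code).
from collections import deque

import torch


class DummyChunkPolicy:
    n_action_steps = 3

    def __init__(self):
        self._action_queue = deque([], maxlen=self.n_action_steps)

    def select_action(self, batch):
        # Stand-in for model inference: refill the queue with a fresh chunk of
        # (batch_size=1, action_dim=2) actions whenever it runs empty.
        if len(self._action_queue) == 0:
            self._action_queue.extend(torch.randn(self.n_action_steps, 1, 2))
        return self._action_queue.popleft()

    def select_action_chunk(self, batch):
        # Same pattern as the diff: pop the first action (which refills the
        # queue), concatenate it with the remainder, then reset the queue.
        first_action = self.select_action(batch)
        result = torch.cat([first_action] + list(self._action_queue), dim=0)
        self._action_queue = deque([], maxlen=self.n_action_steps)
        return result


policy = DummyChunkPolicy()
print(policy.select_action_chunk({}).shape)  # torch.Size([3, 2])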
@@ -143,6 +143,19 @@ class DiffusionPolicy(PreTrainedPolicy):
         action = self._queues["action"].popleft()
         return action
 
+    @torch.no_grad
+    def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Select the entire action chunk given environment observations.
+
+        This method wraps `select_action` in order to return all the actions the policy infers at once.
+        Reusing `select_action` keeps the upstream open-source lerobot code unchanged, though
+        performance may decrease a bit.
+        """
+        first_action = self.select_action(batch)
+        result_tensor = torch.cat([first_action] + list(self._queues["action"]), dim=0)
+        self._queues["action"] = deque([], maxlen=self.config.n_action_steps)
+        return result_tensor
+
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
@@ -197,3 +197,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         with caching.
         """
         raise NotImplementedError
+
+    @abc.abstractmethod
+    def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Return the entire chunk of actions the policy infers for the current observations.
+
+        When the model uses a history of observations, or outputs a sequence of actions, this method deals
+        with caching.
+        """
+        raise NotImplementedError
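Since `select_action_chunk` is declared with `@abc.abstractmethod`, any concrete policy that does not implement it can no longer be instantiated. A cut-down sketch of that contract; `PolicyStub` and `IncompletePolicy` are hypothetical names, not repo code:

import abc

import torch
from torch import Tensor, nn


class PolicyStub(nn.Module, abc.ABC):
    """Cut-down stand-in for PreTrainedPolicy, just to show the abc contract."""

    @abc.abstractmethod
    def select_action(self, batch: dict[str, Tensor]) -> Tensor: ...

    @abc.abstractmethod
    def select_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: ...


class IncompletePolicy(PolicyStub):
    # Implements select_action but forgets select_action_chunk.
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        return torch.zeros(1, 6)


try:
    IncompletePolicy()
except TypeError as err:
    print(err)  # Can't instantiate abstract class IncompletePolicy ...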
@@ -95,8 +95,9 @@ class WebsocketPolicyServer:
                         obs[key] = obs[key].to("cuda", non_blocking=True)
 
                     with torch.inference_mode():
-                        action = self._policy.select_action(obs)
-                    print("inference once with action:", action)
+                        action = self._policy.select_action_chunk(obs)
+                    action = action.squeeze(0)
+                    print("inference once with action:", action.shape, action)
                     res = {"actions": action.cpu().numpy()}
                     await websocket.send(packer.pack(res))
             except websockets.ConnectionClosed:
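A hypothetical minimal standalone server showing the same receive, infer, pack loop as this hunk; `FakePolicy`, the port, and the payload layout are assumptions, and plain `msgpack` stands in for the repo's helpers:

# Sketch of the server-side loop (assumptions as noted above).
import asyncio

import msgpack
import torch
import websockets


class FakePolicy:
    def select_action_chunk(self, obs):
        # Stand-in for the real policy: a (1, n_action_steps, action_dim) chunk.
        return torch.zeros(1, 10, 6)


policy = FakePolicy()
packer = msgpack.Packer()


async def handler(websocket):
    async for message in websocket:
        obs = msgpack.unpackb(message)
        with torch.inference_mode():
            action = policy.select_action_chunk(obs)
        action = action.squeeze(0)  # drop the batch dim, as in the diff
        res = {"actions": action.cpu().numpy().tolist()}
        await websocket.send(packer.pack(res))


async def main():
    async with websockets.serve(handler, "localhost", 8765):
        await asyncio.Future()  # run until cancelled


if __name__ == "__main__":
    asyncio.run(main())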
@@ -38,6 +38,6 @@ if isinstance(response, str):
     exit()
 
 infer_result = msgpack_utils.unpackb(response)
-print(infer_result)
+print(infer_result['actions'].shape, infer_result)
 assert len(infer_result['actions'][0]) == len(input['state'])
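And a hypothetical one-shot client to match the server sketch; the URL and the observation payload are made up, and `msgpack` again stands in for the repo's `msgpack_utils`:

# Sketch of a one-shot client (assumptions as noted above).
import msgpack
from websockets.sync.client import connect

input = {"state": [0.0] * 6}  # made-up observation, named to match the snippet
with connect("ws://localhost:8765") as ws:
    ws.send(msgpack.packb(input))
    infer_result = msgpack.unpackb(ws.recv())

print(len(infer_result["actions"]), "actions in the chunk")
assert len(infer_result["actions"][0]) == len(input["state"])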