diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 418863a1..be75396f 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -20,8 +20,10 @@ The majority of changes here involve removing unused code, unifying naming, and
 """
 
 import math
-from collections import deque
+import threading
+import time
 from itertools import chain
+from threading import Thread
 from typing import Callable
 
 import einops
@@ -36,6 +38,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.utils import TemporalQueue
 
 
 class ACTPolicy(
@@ -87,22 +90,23 @@
             self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
 
         self.reset()
+        self.thread = None
+
+        # TODO(rcadene): Add delta timestamps in policy
+        FPS = 10  # noqa: N806
+        self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
 
     def reset(self):
         """This should be called whenever the environment is reset."""
         if self.config.temporal_ensemble_coeff is not None:
             self.temporal_ensembler.reset()
         else:
-            self._action_queue = deque([], maxlen=self.config.n_action_steps)
+            # TODO(rcadene): set proper maxlen
+            self._obs_queue = TemporalQueue(maxlen=1)
+            self._action_queue = TemporalQueue(maxlen=200)
 
     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
+    def inference(self, batch: dict[str, Tensor]) -> Tensor:
         self.eval()
 
         batch = self.normalize_inputs(batch)
@@ -118,18 +122,47 @@ class ACTPolicy(
             action = self.temporal_ensembler.update(actions)
             return action
 
-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
+        actions = self.model(batch)[0][:, : self.config.n_action_steps]
 
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        # TODO(rcadene): make _forward return output dictionary?
+        actions = self.unnormalize_outputs({"action": actions})["action"]
+        return actions
 
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
+    def inference_loop(self):
+        prev_timestamp = None
+        while not self.stop_event.is_set():
+            last_observation, last_timestamp = self._obs_queue.get_latest()
+
+            if prev_timestamp is not None and prev_timestamp == last_timestamp:
+                # in case inference ran faster than recording/adding a new observation in the queue
+                time.sleep(0.1)
+                continue
+
+            pred_action_sequence = self.inference(last_observation)
+
+            for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
+                self._action_queue.add(action, last_timestamp + delta_ts)
+
+            prev_timestamp = last_timestamp
+
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+        present_time = time.time()
+        self._obs_queue.add(batch, present_time)
+
+        if self.thread is None:
+            self.stop_event = threading.Event()
+            self.thread = Thread(target=self.inference_loop, args=())
+            self.thread.daemon = True
+            self.thread.start()
+
+        next_action = None
+        while next_action is None:
+            try:
+                next_action = self._action_queue.get(present_time)
+            except ValueError:
+                time.sleep(0.1)  # no action available at this present time, we wait a bit
+
+        return next_action
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py
index 5a62daa2..beb1f1d4 100644
--- a/lerobot/common/policies/utils.py
+++ b/lerobot/common/policies/utils.py
@@ -13,6 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import deque
+
 import torch
 from torch import nn
 
@@ -47,3 +49,33 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
     Note: assumes that all parameters have the same dtype.
     """
     return next(iter(module.parameters())).dtype
+
+
+class TemporalQueue:
+    def __init__(self, maxlen):
+        # TODO(rcadene): set proper maxlen
+        self.items = deque(maxlen=maxlen)
+        self.timestamps = deque(maxlen=maxlen)
+
+    def add(self, item, timestamp):
+        self.items.append(item)
+        self.timestamps.append(timestamp)
+
+    def get_latest(self):
+        return self.items[-1], self.timestamps[-1]
+
+    def get(self, timestamp):
+        import numpy as np
+
+        timestamps = np.array(list(self.timestamps))
+        distances = np.abs(timestamps - timestamp)
+        nearest_idx = distances.argmin()
+
+        # print(float(distances[nearest_idx]))
+        if float(distances[nearest_idx]) > 1 / 5:
+            raise ValueError()
+
+        return self.items[nearest_idx], self.timestamps[nearest_idx]
+
+    def __len__(self):
+        return len(self.items)
diff --git a/test2.py b/test2.py
new file mode 100644
index 00000000..50f8932f
--- /dev/null
+++ b/test2.py
@@ -0,0 +1,45 @@
+import torch
+
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils.utils import init_hydra_config, set_global_seed
+from tests.utils import DEFAULT_CONFIG_PATH
+
+
+def main(env_name, policy_name, extra_overrides):
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[
+            f"env={env_name}",
+            f"policy={policy_name}",
+            "device=cpu",
+        ]
+        + extra_overrides,
+    )
+    set_global_seed(1337)
+    dataset = make_dataset(cfg)
+    policy = make_policy(cfg, dataset_stats=dataset.stats)
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=1,
+        shuffle=False,
+    )
+    batch = next(iter(dataloader))
+
+    obs = {}
+    for k in batch:
+        if k.startswith("observation"):
+            obs[k] = batch[k]
+
+    actions = policy.inference(obs)
+
+    action, timestamp = policy.select_action(obs)
+
+    print(actions[0])
+    print(action)
+
+
+if __name__ == "__main__":
+    main("aloha", "act", ["policy.n_action_steps=10"])