commit ec1efc64b4
parent b2e5f7fe2d
Author: Remi Cadene
Date: 2024-09-30 15:22:13 +02:00

3 changed files with 129 additions and 19 deletions


@@ -20,8 +20,10 @@ The majority of changes here involve removing unused code, unifying naming, and
 """

 import math
-from collections import deque
+import threading
+import time
 from itertools import chain
+from threading import Thread
 from typing import Callable

 import einops
@@ -36,6 +38,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.utils import TemporalQueue


 class ACTPolicy(
@@ -87,22 +90,23 @@ class ACTPolicy(
         self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)

         self.reset()

+        self.thread = None
+
+        # TODO(rcadene): Add delta timestamps in policy
+        FPS = 10  # noqa: N806
+        self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
+
     def reset(self):
         """This should be called whenever the environment is reset."""
         if self.config.temporal_ensemble_coeff is not None:
             self.temporal_ensembler.reset()
         else:
-            self._action_queue = deque([], maxlen=self.config.n_action_steps)
+            # TODO(rcadene): set proper maxlen
+            self._obs_queue = TemporalQueue(maxlen=1)
+            self._action_queue = TemporalQueue(maxlen=200)

     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
+    def inference(self, batch: dict[str, Tensor]) -> Tensor:
         self.eval()

         batch = self.normalize_inputs(batch)
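The new `delta_timestamps` attribute gives each action in a predicted chunk a fixed time offset from the observation it was computed from, assuming a hard-coded control rate of 10 FPS. A minimal sketch of the arithmetic (values illustrative, not from the commit):

    FPS = 10
    n_action_steps = 5
    delta_timestamps = [i / FPS for i in range(n_action_steps)]  # [0.0, 0.1, 0.2, 0.3, 0.4]

    # Stamping a chunk predicted from an observation captured at t = 100.0 s:
    observation_timestamp = 100.0
    action_timestamps = [observation_timestamp + dt for dt in delta_timestamps]
    # -> [100.0, 100.1, 100.2, 100.3, 100.4]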
@@ -118,18 +122,47 @@ class ACTPolicy(
             action = self.temporal_ensembler.update(actions)
             return action

-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
-
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
-
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
+        actions = self.model(batch)[0][:, : self.config.n_action_steps]
+
+        # TODO(rcadene): make _forward return output dictionary?
+        actions = self.unnormalize_outputs({"action": actions})["action"]
+        return actions
+
+    def inference_loop(self):
+        prev_timestamp = None
+        while not self.stop_event.is_set():
+            last_observation, last_timestamp = self._obs_queue.get_latest()
+
+            if prev_timestamp is not None and prev_timestamp == last_timestamp:
+                # In case inference ran faster than recording/adding a new observation to the queue
+                time.sleep(0.1)
+                continue
+
+            pred_action_sequence = self.inference(last_observation)
+
+            for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
+                self._action_queue.add(action, last_timestamp + delta_ts)
+
+            prev_timestamp = last_timestamp
+
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+        present_time = time.time()
+        self._obs_queue.add(batch, present_time)
+
+        if self.thread is None:
+            self.stop_event = threading.Event()
+            self.thread = Thread(target=self.inference_loop, args=())
+            self.thread.daemon = True
+            self.thread.start()
+
+        next_action = None
+        while next_action is None:
+            try:
+                next_action = self._action_queue.get(present_time)
+            except ValueError:
+                time.sleep(0.1)  # no action available at this present time, we wait a bit
+
+        return next_action

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""


@@ -13,6 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import deque
+
 import torch
 from torch import nn
@@ -47,3 +49,33 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
     Note: assumes that all parameters have the same dtype.
     """
     return next(iter(module.parameters())).dtype
+
+
+class TemporalQueue:
+    def __init__(self, maxlen):
+        # TODO(rcadene): set proper maxlen
+        self.items = deque(maxlen=maxlen)
+        self.timestamps = deque(maxlen=maxlen)
+
+    def add(self, item, timestamp):
+        self.items.append(item)
+        self.timestamps.append(timestamp)
+
+    def get_latest(self):
+        return self.items[-1], self.timestamps[-1]
+
+    def get(self, timestamp):
+        import numpy as np
+
+        timestamps = np.array(list(self.timestamps))
+        distances = np.abs(timestamps - timestamp)
+        nearest_idx = distances.argmin()
+
+        if float(distances[nearest_idx]) > 1 / 5:
+            raise ValueError()
+
+        return self.items[nearest_idx], self.timestamps[nearest_idx]
+
+    def __len__(self):
+        return len(self.items)
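A short usage sketch of the class above (illustrative values, assuming `TemporalQueue` is in scope): `get` does a nearest-neighbor lookup over the stored timestamps and raises `ValueError` when the closest entry is more than 1/5 s away, which is exactly the condition `select_action` catches before sleeping and retrying.

    queue = TemporalQueue(maxlen=10)
    queue.add("a", timestamp=1.0)
    queue.add("b", timestamp=1.1)

    print(queue.get_latest())  # ('b', 1.1)
    print(queue.get(1.04))     # ('a', 1.0): nearest entry, within the 0.2 s tolerance

    try:
        queue.get(5.0)  # nearest entry is 3.9 s away, beyond 1/5 s
    except ValueError:
        print("no item close enough to the requested timestamp")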

test2.py (new file)

@@ -0,0 +1,45 @@
+import torch
+
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils.utils import init_hydra_config, set_global_seed
+from tests.utils import DEFAULT_CONFIG_PATH
+
+
+def main(env_name, policy_name, extra_overrides):
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[
+            f"env={env_name}",
+            f"policy={policy_name}",
+            "device=cpu",
+        ]
+        + extra_overrides,
+    )
+
+    set_global_seed(1337)
+
+    dataset = make_dataset(cfg)
+    policy = make_policy(cfg, dataset_stats=dataset.stats)
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=1,
+        shuffle=False,
+    )
+
+    batch = next(iter(dataloader))
+
+    obs = {}
+    for k in batch:
+        if k.startswith("observation"):
+            obs[k] = batch[k]
+
+    actions = policy.inference(obs)
+    action, timestamp = policy.select_action(obs)
+
+    print(actions[0])
+    print(action)
+
+
+if __name__ == "__main__":
+    main("aloha", "act", ["policy.n_action_steps=10"])