Remi Cadene 2024-09-30 15:22:13 +02:00
parent b2e5f7fe2d
commit ec1efc64b4
3 changed files with 129 additions and 19 deletions
lerobot/common/policies/act/modeling_act.py

@@ -20,8 +20,10 @@ The majority of changes here involve removing unused code, unifying naming, and
 """

 import math
 from collections import deque
+import threading
+import time
 from itertools import chain
+from threading import Thread
 from typing import Callable

 import einops
@@ -36,6 +38,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d

 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.utils import TemporalQueue


 class ACTPolicy(
@@ -87,22 +90,23 @@ class ACTPolicy(
             self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)

         self.reset()

+        self.thread = None
+
+        # TODO(rcadene): Add delta timestamps in policy
+        FPS = 10  # noqa: N806
+        # Relative timestamps (in seconds) at which each action of a predicted chunk
+        # should be executed, e.g. [0.0, 0.1, 0.2, ...] at 10 FPS.
+        self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]

     def reset(self):
         """This should be called whenever the environment is reset."""
         if self.config.temporal_ensemble_coeff is not None:
             self.temporal_ensembler.reset()
         else:
-            self._action_queue = deque([], maxlen=self.config.n_action_steps)
+            # TODO(rcadene): set proper maxlen
+            self._obs_queue = TemporalQueue(maxlen=1)
+            self._action_queue = TemporalQueue(maxlen=200)

     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
+    def inference(self, batch: dict[str, Tensor]) -> Tensor:
         self.eval()
         batch = self.normalize_inputs(batch)
@@ -118,18 +122,47 @@ class ACTPolicy(
             action = self.temporal_ensembler.update(actions)
             return action

-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
+        actions = self.model(batch)[0][:, : self.config.n_action_steps]

-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        # TODO(rcadene): make _forward return output dictionary?
+        actions = self.unnormalize_outputs({"action": actions})["action"]
+        return actions

-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
+    def inference_loop(self):
+        prev_timestamp = None
+        while not self.stop_event.is_set():
+            last_observation, last_timestamp = self._obs_queue.get_latest()
+
+            if prev_timestamp is not None and prev_timestamp == last_timestamp:
+                # Inference ran faster than a new observation could be recorded/added to the queue.
+                time.sleep(0.1)
+                continue
+
+            pred_action_sequence = self.inference(last_observation)
+
+            # `inference` returns a (batch_size, n_action_steps, action_dim) tensor. Assume
+            # batch_size == 1 and queue each action at the timestamp it should be executed.
+            for action, delta_ts in zip(pred_action_sequence[0], self.delta_timestamps, strict=True):
+                self._action_queue.add(action, last_timestamp + delta_ts)
+
+            prev_timestamp = last_timestamp
+
+    def select_action(self, batch: dict[str, Tensor]) -> tuple[Tensor, float]:
+        present_time = time.time()
+        self._obs_queue.add(batch, present_time)
+
+        # Lazily start the background inference thread on the first call.
+        if self.thread is None:
+            self.stop_event = threading.Event()
+            self.thread = Thread(target=self.inference_loop, args=())
+            self.thread.daemon = True
+            self.thread.start()
+
+        # Block until an action timestamped close to `present_time` is available.
+        next_action = None
+        while next_action is None:
+            try:
+                next_action = self._action_queue.get(present_time)
+            except ValueError:
+                time.sleep(0.1)  # no action available at this present time, we wait a bit
+
+        return next_action

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""

lerobot/common/policies/utils.py

@@ -13,6 +13,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import deque
+
+import numpy as np
 import torch
 from torch import nn
@@ -47,3 +49,33 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
     Note: assumes that all parameters have the same dtype.
     """
     return next(iter(module.parameters())).dtype
+
+
+class TemporalQueue:
+    """Bounded buffer of (item, timestamp) pairs, queryable by nearest timestamp.
+
+    Note: shared between the control thread and the inference thread without locking.
+    """
+
+    def __init__(self, maxlen):
+        # TODO(rcadene): set proper maxlen
+        self.items = deque(maxlen=maxlen)
+        self.timestamps = deque(maxlen=maxlen)
+
+    def add(self, item, timestamp):
+        self.items.append(item)
+        self.timestamps.append(timestamp)
+
+    def get_latest(self):
+        return self.items[-1], self.timestamps[-1]
+
+    def get(self, timestamp):
+        timestamps = np.array(list(self.timestamps))
+        distances = np.abs(timestamps - timestamp)
+        nearest_idx = distances.argmin()
+
+        # Reject the match if the nearest item is more than 0.2 s away from the query.
+        if float(distances[nearest_idx]) > 1 / 5:
+            raise ValueError("no item close enough to the requested timestamp")
+
+        return self.items[nearest_idx], self.timestamps[nearest_idx]
+
+    def __len__(self):
+        return len(self.items)
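To make the timestamp-matching semantics concrete, a minimal sketch with made-up values (not part of the commit):

queue = TemporalQueue(maxlen=200)
queue.add("a", 1.0)
queue.add("b", 1.1)

queue.get_latest()  # ("b", 1.1)
queue.get(1.15)     # ("b", 1.1): nearest item is 0.05 s away, within the 0.2 s tolerance
queue.get(2.0)      # raises ValueError: nearest item is 0.9 s away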

test2.py (new file)

@@ -0,0 +1,45 @@
import torch

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
from tests.utils import DEFAULT_CONFIG_PATH


def main(env_name, policy_name, extra_overrides):
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            "device=cpu",
        ]
        + extra_overrides,
    )

    set_global_seed(1337)
    dataset = make_dataset(cfg)
    policy = make_policy(cfg, dataset_stats=dataset.stats)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=1,
        shuffle=False,
    )
    batch = next(iter(dataloader))

    # Keep only the observation keys; the policy must not see the ground-truth actions.
    obs = {k: v for k, v in batch.items() if k.startswith("observation")}

    # Synchronous path: run the model once and get the full action chunk.
    actions = policy.inference(obs)

    # Asynchronous path: enqueue the observation and fetch the action nearest to now.
    action, timestamp = policy.select_action(obs)

    print(actions[0])
    print(action)


if __name__ == "__main__":
    main("aloha", "act", ["policy.n_action_steps=10"])