WIP
parent b2e5f7fe2d
commit ec1efc64b4
@@ -20,8 +20,10 @@ The majority of changes here involve removing unused code, unifying naming, and
"""

import math
from collections import deque
import threading
import time
from itertools import chain
from threading import Thread
from typing import Callable

import einops
@@ -36,6 +38,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d

from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import TemporalQueue


class ACTPolicy(
@@ -87,22 +90,23 @@ class ACTPolicy(
            self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)

        self.reset()
        self.thread = None

        # TODO(rcadene): Add delta timestamps in policy
        FPS = 10  # noqa: N806
        self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]

    def reset(self):
        """This should be called whenever the environment is reset."""
        if self.config.temporal_ensemble_coeff is not None:
            self.temporal_ensembler.reset()
        else:
            self._action_queue = deque([], maxlen=self.config.n_action_steps)
        # TODO(rcadene): set proper maxlen
        self._obs_queue = TemporalQueue(maxlen=1)
        self._action_queue = TemporalQueue(maxlen=200)

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations.

        This method wraps `select_actions` in order to return one action at a time for execution in the
        environment. It works by managing the actions in a queue and only calling `select_actions` when the
        queue is empty.
        """
    def inference(self, batch: dict[str, Tensor]) -> Tensor:
        self.eval()

        batch = self.normalize_inputs(batch)
@@ -118,18 +122,47 @@ class ACTPolicy(
            action = self.temporal_ensembler.update(actions)
            return action

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
            actions = self.model(batch)[0][:, : self.config.n_action_steps]
        actions = self.model(batch)[0][:, : self.config.n_action_steps]

            # TODO(rcadene): make _forward return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]
        # TODO(rcadene): make _forward return output dictionary?
        actions = self.unnormalize_outputs({"action": actions})["action"]
        return actions

            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    def inference_loop(self):
        prev_timestamp = None
        while not self.stop_event.is_set():
            last_observation, last_timestamp = self._obs_queue.get_latest()

            if prev_timestamp is not None and prev_timestamp == last_timestamp:
                # in case inference ran faster than recording/adding a new observation in the queue
                time.sleep(0.1)
                continue

            pred_action_sequence = self.inference(last_observation)

            for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
                self._action_queue.add(action, last_timestamp + delta_ts)

            prev_timestamp = last_timestamp

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        present_time = time.time()
        self._obs_queue.add(batch, present_time)

        if self.thread is None:
            self.stop_event = threading.Event()
            self.thread = Thread(target=self.inference_loop, args=())
            self.thread.daemon = True
            self.thread.start()

        next_action = None
        while next_action is None:
            try:
                next_action = self._action_queue.get(present_time)
            except ValueError:
                time.sleep(0.1)  # no action available at this present time, we wait a bit

        return next_action

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
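In rough terms, the hunk above turns action selection into a producer/consumer pair: `select_action` timestamps each observation and adds it to `_obs_queue`, a daemon thread running `inference_loop` predicts a chunk and schedules its actions at `last_timestamp + i / FPS`, and `select_action` then fetches whichever scheduled action sits nearest to the present time. A minimal sketch of that timestamp bookkeeping, with illustrative numbers that are not part of the diff:

# Illustrative numbers only; FPS and n_action_steps mirror the values hardcoded in __init__.
FPS = 10
n_action_steps = 10
delta_timestamps = [i / FPS for i in range(n_action_steps)]

obs_timestamp = 100.00  # time.time() recorded when the observation was queued
scheduled = [obs_timestamp + dt for dt in delta_timestamps]
print(scheduled[:3])  # [100.0, 100.1, 100.2]

# A caller asking for an action ~0.25 s after that observation gets the entry
# scheduled at 100.2: the nearest timestamp, within the 1/5 s tolerance that
# TemporalQueue.get enforces in the next hunk.
present_time = 100.25
nearest = min(scheduled, key=lambda ts: abs(ts - present_time))
print(nearest)  # 100.2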
@@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque

import torch
from torch import nn
@@ -47,3 +49,33 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
    Note: assumes that all parameters have the same dtype.
    """
    return next(iter(module.parameters())).dtype


class TemporalQueue:
    def __init__(self, maxlen):
        # TODO(rcadene): set proper maxlen
        self.items = deque(maxlen=maxlen)
        self.timestamps = deque(maxlen=maxlen)

    def add(self, item, timestamp):
        self.items.append(item)
        self.timestamps.append(timestamp)

    def get_latest(self):
        return self.items[-1], self.timestamps[-1]

    def get(self, timestamp):
        import numpy as np

        timestamps = np.array(list(self.timestamps))
        distances = np.abs(timestamps - timestamp)
        nearest_idx = distances.argmin()

        # print(float(distances[nearest_idx]))
        if float(distances[nearest_idx]) > 1 / 5:
            raise ValueError()

        return self.items[nearest_idx], self.timestamps[nearest_idx]

    def __len__(self):
        return len(self.items)
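As a quick reference, here is a minimal sketch of how the TemporalQueue added above behaves; the values are illustrative, and the import path assumes the class lives in `lerobot.common.policies.utils` as in this patch:

from lerobot.common.policies.utils import TemporalQueue

q = TemporalQueue(maxlen=200)
q.add("a0", timestamp=10.0)
q.add("a1", timestamp=10.1)
q.add("a2", timestamp=10.2)

print(q.get_latest())  # ("a2", 10.2): most recently added item
print(q.get(10.12))    # ("a1", 10.1): nearest timestamp, ~0.02 s away
print(len(q))          # 3

try:
    q.get(11.0)        # nearest entry is ~0.8 s away, beyond the 1/5 s tolerance
except ValueError:
    print("no entry close enough to the requested timestamp")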
@@ -0,0 +1,45 @@
import torch

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
from tests.utils import DEFAULT_CONFIG_PATH


def main(env_name, policy_name, extra_overrides):
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            "device=cpu",
        ]
        + extra_overrides,
    )
    set_global_seed(1337)
    dataset = make_dataset(cfg)
    policy = make_policy(cfg, dataset_stats=dataset.stats)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=1,
        shuffle=False,
    )
    batch = next(iter(dataloader))

    obs = {}
    for k in batch:
        if k.startswith("observation"):
            obs[k] = batch[k]

    actions = policy.inference(obs)

    action, timestamp = policy.select_action(obs)

    print(actions[0])
    print(action)


if __name__ == "__main__":
    main("aloha", "act", ["policy.n_action_steps=10"])