WIP

parent b2e5f7fe2d
commit ec1efc64b4

lerobot/common/policies/act/modeling_act.py

@@ -20,8 +20,10 @@ The majority of changes here involve removing unused code, unifying naming, and
 """
 
 import math
-from collections import deque
+import threading
+import time
 from itertools import chain
+from threading import Thread
 from typing import Callable
 
 import einops
@@ -36,6 +38,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.utils import TemporalQueue
 
 
 class ACTPolicy(
@@ -87,22 +90,23 @@ class ACTPolicy(
             self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
 
         self.reset()
+        self.thread = None
 
+        # TODO(rcadene): Add delta timestamps in policy
+        FPS = 10  # noqa: N806
+        self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
+
     def reset(self):
         """This should be called whenever the environment is reset."""
         if self.config.temporal_ensemble_coeff is not None:
             self.temporal_ensembler.reset()
         else:
-            self._action_queue = deque([], maxlen=self.config.n_action_steps)
+            # TODO(rcadene): set proper maxlen
+            self._obs_queue = TemporalQueue(maxlen=1)
+            self._action_queue = TemporalQueue(maxlen=200)
 
     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
+    def inference(self, batch: dict[str, Tensor]) -> Tensor:
         self.eval()
 
         batch = self.normalize_inputs(batch)
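
A quick illustration of the timing grid introduced above (not part of the commit): with the hard-coded FPS = 10 and n_action_steps = 10 (the value the new test script at the bottom forces), delta_timestamps spaces the predicted action chunk 0.1 s apart:

# Illustration only, assuming n_action_steps = 10.
FPS = 10
delta_timestamps = [i / FPS for i in range(10)]
print(delta_timestamps)  # [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

In the next hunk, inference_loop adds these offsets to the observation's timestamp, so each queued action is tagged with the wall-clock time at which it is meant to be executed.
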
@@ -118,18 +122,47 @@ class ACTPolicy(
             action = self.temporal_ensembler.update(actions)
             return action
 
-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
+        actions = self.model(batch)[0][:, : self.config.n_action_steps]
 
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        # TODO(rcadene): make _forward return output dictionary?
+        actions = self.unnormalize_outputs({"action": actions})["action"]
+        return actions
 
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
+    def inference_loop(self):
+        prev_timestamp = None
+        while not self.stop_event.is_set():
+            last_observation, last_timestamp = self._obs_queue.get_latest()
+
+            if prev_timestamp is not None and prev_timestamp == last_timestamp:
+                # in case inference ran faster than recording/adding a new observation in the queue
+                time.sleep(0.1)
+                continue
+
+            pred_action_sequence = self.inference(last_observation)
+
+            for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
+                self._action_queue.add(action, last_timestamp + delta_ts)
+
+            prev_timestamp = last_timestamp
+
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+        present_time = time.time()
+        self._obs_queue.add(batch, present_time)
+
+        if self.thread is None:
+            self.stop_event = threading.Event()
+            self.thread = Thread(target=self.inference_loop, args=())
+            self.thread.daemon = True
+            self.thread.start()
+
+        next_action = None
+        while next_action is None:
+            try:
+                next_action = self._action_queue.get(present_time)
+            except ValueError:
+                time.sleep(0.1)  # no action available at this present time, we wait a bit
+
+        return next_action
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
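
The hunk above makes action selection asynchronous: select_action stamps the incoming observation with time.time(), lazily starts a daemon thread running inference_loop, and then polls the timestamped action queue until an action close to the present time is available. A minimal control-loop sketch of how a caller might use it (illustration only; get_observation and send_to_robot are hypothetical stand-ins for the environment/robot plumbing, and policy is an already constructed ACTPolicy, e.g. via make_policy as in the new script further below):

import time

CONTROL_FPS = 10  # assumed control rate, matching the hard-coded FPS in __init__

policy.reset()
while True:
    obs = get_observation()  # hypothetical: dict of "observation.*" tensors
    # Blocks until the queue holds an action whose timestamp is near time.time().
    action, timestamp = policy.select_action(obs)
    send_to_robot(action)  # hypothetical actuator call
    time.sleep(1 / CONTROL_FPS)

Note that, because TemporalQueue.get returns an (item, timestamp) pair, the new select_action returns a timestamped action rather than a bare tensor, which is also how the test script at the bottom of this commit unpacks it.
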
lerobot/common/policies/utils.py

@@ -13,6 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import deque
+
 import torch
 from torch import nn
 
@@ -47,3 +49,33 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
     Note: assumes that all parameters have the same dtype.
     """
     return next(iter(module.parameters())).dtype
+
+
+class TemporalQueue:
+    def __init__(self, maxlen):
+        # TODO(rcadene): set proper maxlen
+        self.items = deque(maxlen=maxlen)
+        self.timestamps = deque(maxlen=maxlen)
+
+    def add(self, item, timestamp):
+        self.items.append(item)
+        self.timestamps.append(timestamp)
+
+    def get_latest(self):
+        return self.items[-1], self.timestamps[-1]
+
+    def get(self, timestamp):
+        import numpy as np
+
+        timestamps = np.array(list(self.timestamps))
+        distances = np.abs(timestamps - timestamp)
+        nearest_idx = distances.argmin()
+
+        # print(float(distances[nearest_idx]))
+        if float(distances[nearest_idx]) > 1 / 5:
+            raise ValueError()
+
+        return self.items[nearest_idx], self.timestamps[nearest_idx]
+
+    def __len__(self):
+        return len(self.items)
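
TemporalQueue, as added above, is a small timestamp-indexed buffer: add appends an item together with its timestamp, get_latest returns the newest entry, and get(t) returns the entry whose timestamp is nearest to t, raising ValueError when even the nearest one is more than 0.2 s (1/5 s) away. A minimal usage sketch (illustration only, with made-up timestamps):

from lerobot.common.policies.utils import TemporalQueue

queue = TemporalQueue(maxlen=200)

# Store three items stamped 0.1 s apart.
queue.add("a", 10.0)
queue.add("b", 10.1)
queue.add("c", 10.2)

print(queue.get_latest())  # ("c", 10.2)
print(queue.get(10.12))    # ("b", 10.1), the nearest timestamp within the 0.2 s tolerance
print(len(queue))          # 3

try:
    queue.get(11.0)  # nearest entry is 0.8 s away, so it is rejected
except ValueError:
    print("no item close enough to the requested time")
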
new file

@@ -0,0 +1,45 @@
+import torch
+
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils.utils import init_hydra_config, set_global_seed
+from tests.utils import DEFAULT_CONFIG_PATH
+
+
+def main(env_name, policy_name, extra_overrides):
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[
+            f"env={env_name}",
+            f"policy={policy_name}",
+            "device=cpu",
+        ]
+        + extra_overrides,
+    )
+    set_global_seed(1337)
+    dataset = make_dataset(cfg)
+    policy = make_policy(cfg, dataset_stats=dataset.stats)
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=1,
+        shuffle=False,
+    )
+    batch = next(iter(dataloader))
+
+    obs = {}
+    for k in batch:
+        if k.startswith("observation"):
+            obs[k] = batch[k]
+
+    actions = policy.inference(obs)
+
+    action, timestamp = policy.select_action(obs)
+
+    print(actions[0])
+    print(action)
+
+
+if __name__ == "__main__":
+    main("aloha", "act", ["policy.n_action_steps=10"])
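
The new script is a small manual check of the code paths added above. It builds the aloha/act config on CPU with policy.n_action_steps=10, constructs the dataset and policy, takes a single dataloader batch, and keeps only the observation.* keys. It then calls policy.inference(obs) to get the full predicted chunk, and policy.select_action(obs), which starts the background inference thread and returns the timestamped action nearest to the current wall-clock time, printing both so the queued action can be compared against the first predicted chunk.
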