backup wip

parent 42e6a5e9b3
commit a931d45993
@@ -102,7 +102,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         }

     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor], update_queue: bool = False) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.

         This method handles caching a history of observations and an action trajectory generated by the
@@ -128,18 +128,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):

         self._queues = populate_queues(self._queues, batch)

-        # if len(self._queues["action"]) == 0:
-        # stack n latest observations from the queue
-        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
-        actions = self.diffusion.generate_actions(batch)
-
-        # TODO(rcadene): make above methods return output dictionary?
-        actions = self.unnormalize_outputs({"action": actions})["action"]
-        return actions
-        # self._queues["action"].extend(actions.transpose(0, 1))
-
-        # action = self._queues["action"].popleft()
-        # return action
+        if len(self._queues["action"]) == 0:
+            # stack n latest observations from the queue
+            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
+            actions = self.diffusion.generate_actions(batch)
+
+            # TODO(rcadene): make above methods return output dictionary?
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+
+            self._queues["action"].extend(actions.transpose(0, 1))
+
+        action = self._queues["action"].popleft()
+        return action

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
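For context, the queue-based version of `select_action` caches a generated action trajectory and serves one action per call, while the other side of this hunk short-circuits the queue and returns the whole trajectory. A minimal self-contained sketch of the queue-based pattern (the `generate_actions` stub and the sizes are illustrative stand-ins, not the real diffusion model):

```python
from collections import deque

import torch

n_obs_steps, n_action_steps = 2, 8

def generate_actions(obs_stack: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for self.diffusion.generate_actions(batch).
    return torch.zeros(obs_stack.shape[0], n_action_steps, 2)  # (batch, horizon, action_dim)

obs_queue: deque[torch.Tensor] = deque(maxlen=n_obs_steps)
action_queue: deque[torch.Tensor] = deque(maxlen=n_action_steps)

def select_action(obs: torch.Tensor) -> torch.Tensor:
    obs_queue.append(obs)
    if len(action_queue) == 0:
        # Stack the n latest observations into (batch, n_obs_steps, obs_dim).
        obs_stack = torch.stack(list(obs_queue), dim=1)
        actions = generate_actions(obs_stack)
        # transpose(0, 1) makes the horizon the leading dim so the queue holds
        # one (batch, action_dim) tensor per future step.
        action_queue.extend(actions.transpose(0, 1))
    return action_queue.popleft()

# One policy call serves the next n_action_steps environment steps.
for _ in range(3):
    a = select_action(torch.zeros(1, 4))  # (batch=1, obs_dim=4)
    print(a.shape)  # torch.Size([1, 2])
```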
@@ -243,9 +243,8 @@ class DiffusionModel(nn.Module):

        # Extract `n_action_steps` steps worth of actions (from the current observation).
        start = n_obs_steps - 1
-        # end = start + self.config.n_action_steps
-        # actions = actions[:, start:end]
-        actions[:, start:]
+        end = start + self.config.n_action_steps
+        actions = actions[:, start:end]

        return actions
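Both variants slice the predicted trajectory starting at the step aligned with the current observation; the difference is whether the slice is capped at `n_action_steps`. (Note the uncapped wip line, `actions[:, start:]`, also discarded its result, since it was never assigned.) A tiny worked example of the capped slice, with invented shapes:

```python
import torch

horizon, n_obs_steps, n_action_steps = 16, 2, 8
# (batch=1, horizon, action_dim=1), with each action labeled by its step index.
actions = torch.arange(horizon).reshape(1, horizon, 1)

start = n_obs_steps - 1          # 1: the action aligned with the current observation
end = start + n_action_steps     # 9
actions = actions[:, start:end]  # keeps steps 1..8
print(actions.shape)             # torch.Size([1, 8, 1])
```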
@@ -2,6 +2,7 @@

 fps: 10

 env:
   name: pusht
   task: PushT-v0
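For reference, `fps: 10` means a control period of 1 / 10 = 0.1 s, which is also the clock the latency simulation below counts in. A hypothetical 0.25 s inference time therefore spans ceil(0.25 × 10) = 3 clock cycles, mirroring the `math.ceil(policy_latency * fps)` expression in the eval changes:

```python
import math

fps = 10
policy_latency = 0.25  # hypothetical inference time in seconds
n_delay_cycles = int(math.ceil(policy_latency * fps))  # 3 clock cycles of delay
```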
@@ -44,8 +44,10 @@ https://huggingface.co/lerobot/diffusion_pusht/tree/main.
 import argparse
 import json
 import logging
+import math
+import threading
 import time
 from collections import deque
 from contextlib import nullcontext
 from copy import deepcopy
 from datetime import datetime as dt
@@ -75,6 +77,8 @@ from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.io_utils import write_video
 from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed

+LATENCY = True
+

 def rollout(
     env: gym.vector.VectorEnv,
@@ -121,7 +125,7 @@ def rollout(
     # Reset the policy and environments.
     policy.reset()

-    observation, info = env.reset(seed=seeds)
+    gym_observation, info = env.reset(seed=seeds)
     if render_callback is not None:
         render_callback(env)

@@ -131,7 +135,6 @@ def rollout(
     all_successes = []
     all_dones = []

     step = 0
     # Keep track of which environments are done.
     done = np.array([False] * env.num_envs)
     max_steps = env.call("_max_episode_steps")[0]
@@ -141,9 +144,19 @@ def rollout(
         disable=not enable_progbar,
         leave=False,
     )
+    first_step_done_for_latency_logic = False
+    # If we are simulating latency, store a queue of actions along with their supposed eta (relative to now).
+    # For example, if we do inference to get an action and it takes 200 ms to run, we store 0.2. Every loop
+    # iteration we decrement this by 1 / fps.
+    action_queue: deque[tuple[np.ndarray, float]] = deque()
+    # If we are simulating latency, we will keep track of the number of clock cycles that we missed because
+    # the policy was not done with inference on time.
+    n_dropped_cycles = 0
     while not np.all(done):
+        start_policy_time = time.perf_counter()
+
         # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
-        observation = preprocess_observation(observation)
+        observation = preprocess_observation(gym_observation)
         if return_observations:
             all_observations.append(deepcopy(observation))

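The eta bookkeeping introduced here is simple: each queued entry stores how many seconds of inference time it supposedly still needs, and every control step subtracts one clock period. A small illustration (the 0.2 s value is made up):

```python
from collections import deque

import numpy as np

fps = 10
# (action, eta): eta = 0.2 means inference supposedly finishes 200 ms from now.
action_queue: deque[tuple[np.ndarray, float]] = deque([(np.zeros(2), 0.2)])
# One control step later, every eta shrinks by the clock period 1 / fps:
action_queue = deque((a, eta - 1 / fps) for a, eta in action_queue)
print(action_queue[0][1])  # 0.1 -> executable once eta <= 0
```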
@@ -156,8 +169,90 @@ def rollout(
         action = action.to("cpu").numpy()
         assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"

+        policy_latency = time.perf_counter() - start_policy_time
+
+        if LATENCY:
+            # Note: We use this for the rendering frame rate, but also the clock frequency discussed below.
+            fps = env.unwrapped.metadata["render_fps"]
+            # Make some assumptions about the setup we are simulating:
+            # 1. Suppose that observations and actions happen on the rising edge of the same clock. Therefore,
+            #    any action that is chosen based on an observation can't be executed till the next rising
+            #    edge. This means we have at least one clock cycle of delay (but it could be more if
+            #    inference takes longer than a clock cycle).
+            # 2. Suppose that we do NOT have parallelism for processing multiple steps' worth of observations
+            #    at once. The policy can only take in a new observation once it is done processing the last
+            #    one. If a policy is busy when an observation comes in, we miss the opportunity to process it.
+            # 3. If we miss computing an action for a clock cycle (maybe there was high inference latency), we
+            #    adopt the strategy of repeating the most recently executed action.
+            # To be clear, we aren't actually simulating these phenomena explicitly (none of this code runs
+            # asynchronously). We are fudging the simulation by measuring inference time and delaying the
+            # application of actions by some number of environment steps.
+            # Note that this set of assumptions is by no means general. In real world settings we may have
+            # a separate clock for action and observation and they may have different frequencies and phases.
+            if not first_step_done_for_latency_logic:
+                # For the first step, we pretend the starting state was frozen in time and the policy had
+                # unlimited time to decide on a first move.
+                action_queue.append((action, -1))  # -1 indicating we already have the action ready to go
+                first_step_done_for_latency_logic = True
+                continue
+            else:
+                # Figure out how many cycles the action should be delayed by: ceil(policy_latency / period).
+                # For example, if n_delay_cycles == 3, it means that it will supposedly take between 2 and 3
+                # clock cycles to predict the action based on the current observation. To account for this,
+                # we'll end up using it in a future loop iteration (3 iterations into the future to be
+                # precise).
+                n_delay_cycles = int(math.ceil(policy_latency * fps))
+                # First, we need to decide if we are even allowed to use the action. Assumption #2 from
+                # above says that the policy can't process the current observation if it is still working
+                # on a previous one (in reality we DID process it, but we may have to discard the result).
+                if (pending_count := sum(item[-1] > 0 for item in action_queue)) > 0:
+                    # At least 1 of the items in the queue is (supposedly) still waiting to be processed!
+                    n_dropped_cycles += 1
+                    # In fact, we should hope that ONLY 1 item has a positive value (otherwise it means we
+                    # have supposedly allowed the policy to run concurrently for different steps).
+                    assert pending_count == 1
+                elif n_delay_cycles == len(action_queue):
+                    # In this case, we just append onto the queue. For example, consider that we currently
+                    # have a queue:
+                    # [
+                    #     action to be executed now,
+                    #     action to be executed in 1 clock cycle,
+                    #     action to be executed in 2 clock cycles,
+                    # ]
+                    # For brevity, we will write:
+                    #     action to be executed in... [0, 1, 2] ... clock cycles.
+                    # And consider n_delay_cycles == 3. After appending onto the queue we'll have:
+                    #     action to be executed in... [0, 1, 2, 3] ... clock cycles.
+                    action_queue.append((action, policy_latency))
+                elif n_delay_cycles > len(action_queue):
+                    # This means that the latency may have been unusually high this time (or we are near
+                    # the start of the rollout). We can't add this action to the queue without forming a gap.
+                    # To fill the gap, we'll have to make do with repeating the last queued action.
+                    # For example, consider that we currently have a queue:
+                    #     action to be executed in... [0] ... clock cycles
+                    # And consider that we have n_delay_cycles == 3. We would then copy the first action twice
+                    # and append the currently predicted action onto the queue to get:
+                    #     action to be executed in... [0, 1, 2, 3] ... clock cycles
+                    #                                     ^  ^  ^
+                    #                 copies of 0 ────────┴──┘  └─ new action
+                    while len(action_queue) < n_delay_cycles:
+                        action_queue.append((action_queue[-1][0].copy(), action_queue[-1][1]))
+                    action_queue.append((action, policy_latency))
+                elif n_delay_cycles < len(action_queue):
+                    # This means something has gone wrong with this logic! Recall assumption #2. The policy
+                    # can't concurrently process observations from different steps.
+                    raise AssertionError("Something went wrong with the latency accounting logic.")
+
+            # Get the next action in the queue.
+            action, policy_latency_ = action_queue.popleft()
+            # Dev assertion. It should be that the policy was done producing this action in "the past".
+            assert policy_latency_ <= 0
+            # Shift all latencies by one clock cycle.
+            for i, item in enumerate(action_queue):
+                action_queue[i] = (item[0], item[-1] - 1 / fps)

         # Apply the next action.
-        observation, reward, terminated, truncated, info = env.step(action)
+        gym_observation, reward, terminated, truncated, info = env.step(action)
         if render_callback is not None:
             render_callback(env)

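To see the three queue branches (drop, plain append, gap-fill) without running the simulator, here is a stripped-down, self-contained re-creation of the bookkeeping. Inference latency is faked with a fixed sequence rather than measured, actions are plain floats rather than arrays, and the final assertion branch is omitted:

```python
import math
from collections import deque

fps = 10        # control clock: one cycle per 1 / fps = 0.1 s
period = 1 / fps

# Each entry is (action, eta): eta <= 0 means inference finished "in the past".
action_queue: deque[tuple[float, float]] = deque()
action_queue.append((0.0, -1.0))  # first action: pretend time was frozen
n_dropped_cycles = 0

# Fake per-step inference latencies in seconds; step 2 is unusually slow.
for step, policy_latency in enumerate([0.05, 0.08, 0.25, 0.07, 0.06]):
    new_action = float(step + 1)
    n_delay_cycles = math.ceil(policy_latency * fps)
    if sum(eta > 0 for _, eta in action_queue) > 0:
        # The policy was "busy", so this cycle's result must be discarded.
        n_dropped_cycles += 1
    elif n_delay_cycles == len(action_queue):
        action_queue.append((new_action, policy_latency))
    elif n_delay_cycles > len(action_queue):
        while len(action_queue) < n_delay_cycles:  # fill the gap by repeating
            action_queue.append(action_queue[-1])
        action_queue.append((new_action, policy_latency))
    action, eta = action_queue.popleft()           # execute whatever is due now
    assert eta <= 0
    action_queue = deque((a, e - period) for a, e in action_queue)
    print(f"step {step}: execute {action}, dropped so far: {n_dropped_cycles}")
```

Running this, the slow step 2 forces the gap-fill branch (action 2.0 is repeated), and steps 3 and 4 are dropped because action 3.0 is still "pending" when their observations arrive.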
@@ -176,16 +271,18 @@ def rollout(
         all_dones.append(torch.from_numpy(done))
         all_successes.append(torch.tensor(successes))

         step += 1
         running_success_rate = (
             einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
         )
-        progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
+        progbar_postfix = {"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
+        if LATENCY:
+            progbar_postfix.update({"n_dropped_cycles": n_dropped_cycles})
+        progbar.set_postfix(progbar_postfix)
         progbar.update()

     # Track the final observation.
     if return_observations:
-        observation = preprocess_observation(observation)
+        observation = preprocess_observation(gym_observation)
         all_observations.append(deepcopy(observation))

     # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
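The success-rate readout reduces per-step success flags to per-episode success with an `any` reduction before averaging. The same `einops.reduce` call on a toy tensor (values invented):

```python
import einops
import torch

# (batch=3 episodes, n=4 steps): per-step success flags.
all_successes = torch.tensor([
    [False, False, True, True],
    [False, False, False, False],
    [False, True, False, False],
])
per_episode = einops.reduce(all_successes, "b n -> b", "any")  # tensor([True, False, True])
running_success_rate = per_episode.float().mean()              # tensor(0.6667)
```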
@@ -82,7 +82,7 @@ def rollout(env: gym.vector.VectorEnv, policy: Policy, seed: int | None = None):
         start = time.time()
         # If we have less than some number of actions left in the queue, we need to start working on producing
         # the next chunk.
-        if len(actions_queue) < 25 and not thread.is_alive():
+        if len(actions_queue) < 2 and not thread.is_alive():
             thread = threading.Thread(target=run_policy)
             thread.start()
         # # Process the observation that we have right now to decide an action.
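This second rollout variant prototypes genuinely asynchronous inference rather than the fudged delays above: a background thread refills `actions_queue` whenever it runs low, while the main loop keeps stepping the environment. A minimal sketch of the pattern (the fake `run_policy`, chunk size, and timings are placeholders, not the real policy):

```python
import threading
import time
from collections import deque

actions_queue: deque[list[float]] = deque()
lock = threading.Lock()

def run_policy() -> None:
    # Placeholder inference: pretend a chunk of 8 actions takes 150 ms to produce.
    time.sleep(0.15)
    with lock:
        actions_queue.extend([[0.0, 0.0]] * 8)

thread = threading.Thread(target=run_policy)
thread.start()

last_action = [0.0, 0.0]
for step in range(20):
    # Refill before the queue drains so inference overlaps with stepping the env.
    if len(actions_queue) < 2 and not thread.is_alive():
        thread = threading.Thread(target=run_policy)
        thread.start()
    with lock:
        if actions_queue:
            last_action = actions_queue.popleft()  # else: repeat the last action
    time.sleep(0.05)  # stand-in for env.step(last_action) at 20 Hz
```

Repeating the last action when the queue is starved matches the drop strategy used by the simulated-latency path (assumption #3 above).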