backup wip

This commit is contained in:
Alexander Soare 2024-06-11 07:36:44 +01:00
parent 42e6a5e9b3
commit a931d45993
4 changed files with 119 additions and 22 deletions

View File

@@ -102,7 +102,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
}
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], update_queue: bool = False) -> Tensor:
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
@@ -128,18 +128,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
self._queues = populate_queues(self._queues, batch)
# if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
# self._queues["action"].extend(actions.transpose(0, 1))
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
# action = self._queues["action"].popleft()
# return action
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
@@ -243,9 +243,8 @@ class DiffusionModel(nn.Module):
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
# end = start + self.config.n_action_steps
# actions = actions[:, start:end]
actions[:, start:]
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
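A tiny worked example of the slicing above, with made-up sizes (not necessarily the config's actual values): with n_obs_steps = 2, horizon = 16 and n_action_steps = 8, the generated trajectory covers 16 steps and we keep the 8 steps starting at the current observation's index.

import torch

n_obs_steps, horizon, n_action_steps = 2, 16, 8
trajectory = torch.randn(1, horizon, 2)  # (batch, horizon, action_dim)
start = n_obs_steps - 1  # index of the current observation within the trajectory
end = start + n_action_steps
actions = trajectory[:, start:end]  # keeps indices 1..8
assert actions.shape == (1, n_action_steps, 2)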

View File

@@ -2,6 +2,7 @@
fps: 10
env:
name: pusht
task: PushT-v0

View File

@@ -44,8 +44,10 @@ https://huggingface.co/lerobot/diffusion_pusht/tree/main.
import argparse
import json
import logging
import math
import threading
import time
from collections import deque
from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
@@ -75,6 +77,8 @@ from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
LATENCY = True
def rollout(
env: gym.vector.VectorEnv,
@@ -121,7 +125,7 @@ def rollout(
# Reset the policy and environments.
policy.reset()
observation, info = env.reset(seed=seeds)
gym_observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)
@@ -131,7 +135,6 @@
all_successes = []
all_dones = []
step = 0
# Keep track of which environments are done.
done = np.array([False] * env.num_envs)
max_steps = env.call("_max_episode_steps")[0]
@@ -141,9 +144,19 @@
disable=not enable_progbar,
leave=False,
)
first_step_done_for_latency_logic = False
# If we are simulating latency, store a queue of actions along with their supposed eta (relative to now).
# For example, if we do inference to get an action and it takes 200 ms to run, we store 0.2. Every loop
# iteration we decrement this by 1 / fps.
action_queue: deque[tuple[np.ndarray, float]] = deque()
# If we are simulating latency, we will keep track of the number of clock cycles that we missed because
# the policy was not done with inference on time.
n_dropped_cycles = 0
while not np.all(done):
start_policy_time = time.perf_counter()
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
observation = preprocess_observation(gym_observation)
if return_observations:
all_observations.append(deepcopy(observation))
@@ -156,8 +169,90 @@
action = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
policy_latency = time.perf_counter() - start_policy_time
if LATENCY:
# Note: We use this for the rendering frame rate, but also the clock frequency discussed below.
fps = env.unwrapped.metadata["render_fps"]
# Make some assumptions about the setup we are simulating:
# 1. Suppose that observations and actions happen on the rising edge of the same clock. Therefore,
# any action that is chosen based on an observation can't be executed till the next rising
# edge. This means we have at least one clock cycle of delay (but it could be more if
# inference takes longer than a clock cycle).
# 2. Suppose that we do NOT have parallelism for processing multiple steps' worth of observations
# at once. The policy can only take in a new observation once it is done processing the last
# one. If a policy is busy when an observation comes in, we miss the opportunity to process it.
# 3. If we miss computing an action for a clock cycle (maybe there was high inference latency), we
# adopt the strategy of repeating the most recently executed action.
# To be clear, we aren't actually simulating these phenomena explicitly (none of this code runs
# asynchronously). We are fudging the simulation by measuring inference time and delaying the
# application of actions by some number of environment steps.
# Note that this set of assumptions is by no means general. In real-world settings we may have
# separate clocks for actions and observations, and they may have different frequencies and phases.
if not first_step_done_for_latency_logic:
# For the first step, we pretend the starting state was frozen in time and the policy had
# unlimited time to decide on a first move.
action_queue.append((action, -1)) # -1 indicating we already have the action ready to go
first_step_done_for_latency_logic = True
continue
else:
# Figure out how many cycles the action should be delayed by: ceil(policy_latency / period).
# For example, if n_delay_cycles == 3, it means that it will supposedly take between 2 and 3
# clock cycles to predict the action based on the current observation. To account for this,
# we'll end up using it in a future loop iteration (3 iterations into the future to be
# precise).
n_delay_cycles = int(math.ceil(policy_latency * fps))
# First, we need to decide if we are even allowed to use the action. Assumption #2 from
# above says that the policy can't process the current observation if it is still working
# on a previous one (in reality we DID process it, but we may have to discard the result).
if (pending_count := sum(item[-1] > 0 for item in action_queue)) > 0:
# At least 1 of the items in the queue is (supposedly) still waiting to be processed!
n_dropped_cycles += 1
# In fact, we should hope that ONLY 1 item has a positive value (otherwise it means we
# have supposedly allowed the policy to run concurrently for different steps).
assert pending_count == 1
elif n_delay_cycles == len(action_queue):
# In this case, we just append onto the queue. For example, consider that we currently
# have a queue:
# [
# action to be executed now,
# action to be executed in 1 clock cycle,
# action to be executed in 2 clock cycles,
# ]
# For brevity, we will write:
# action to be executed in... [0, 1, 2] ... clock cycles.
# And consider n_delay_cycles == 3. After appending onto the queue we'll have:
# action to be executed in... [0, 1, 2, 3] ... clock cycles.
action_queue.append((action, policy_latency))
elif n_delay_cycles > len(action_queue):
# This means that the latency may have been unusually high this time (or we are near
# the start of the rollout). We can't add this action to the queue without forming a gap.
# To fill the gap, we'll have to make do with repeating the last queued action.
# For example, consider that we currently have a queue:
# action to be executed in ... [0] ... clock cycles
# And consider that we have n_delay_cycles == 3. We would then copy the first action twice
# and append the currently predicted action onto the queue to get:
# action to be executed in ... [0, 1, 2, 3] ... clock cycles
# ^ ^ ^
# copies of 0 ┴──┘ └─ new action
while len(action_queue) < n_delay_cycles:
action_queue.append((action_queue[-1][0].copy(), action_queue[-1][1]))
action_queue.append((action, policy_latency))
elif n_delay_cycles < len(action_queue):
# This means something has gone wrong with this logic! Recall assumption #2. The policy
# can't concurrently process observations from different steps.
raise AssertionError("Something went wrong with the latency accounting logic.")
# Get the next action in the queue.
action, policy_latency_ = action_queue.popleft()
# Dev assertion. It should be that the policy was done producing this action in "the past".
assert policy_latency_ <= 0
# Shift all latencies by one clock cycle.
for i, item in enumerate(action_queue):
action_queue[i] = (item[0], item[-1] - 1 / fps)
# Apply the next action.
observation, reward, terminated, truncated, info = env.step(action)
gym_observation, reward, terminated, truncated, info = env.step(action)
if render_callback is not None:
render_callback(env)
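To make the bookkeeping above easier to follow, here is a standalone trace of the same queue logic, decoupled from the environment and the policy. The fps and latency values are made up and the actions are just labelled strings; the branches mirror the (action, eta) deque, the gap-filling, and the dropped-cycle counting in the rollout loop.

import math
from collections import deque

fps = 10  # control clock frequency (Hz); one period is 1 / fps = 0.1 s
latencies = [0.05, 0.23, 0.08, 0.31, 0.07]  # made-up per-step inference times (seconds)

action_queue: deque[tuple[str, float]] = deque()
n_dropped_cycles = 0

for step, policy_latency in enumerate(latencies):
    action = f"a{step}"  # stand-in for the action predicted from this step's observation
    if step == 0:
        # First step: pretend the policy had unlimited time, so the action is already ready.
        action_queue.append((action, -1.0))
        continue
    n_delay_cycles = math.ceil(policy_latency * fps)
    if sum(eta > 0 for _, eta in action_queue) > 0:
        # The policy is (supposedly) still busy with an earlier observation:
        # this prediction is discarded and we count a dropped cycle.
        n_dropped_cycles += 1
    elif n_delay_cycles == len(action_queue):
        action_queue.append((action, policy_latency))
    elif n_delay_cycles > len(action_queue):
        # Latency was unusually high: fill the gap by repeating the last queued action.
        while len(action_queue) < n_delay_cycles:
            action_queue.append(action_queue[-1])
        action_queue.append((action, policy_latency))
    else:
        # n_delay_cycles < len(action_queue) would contradict assumption #2.
        raise AssertionError("Something went wrong with the latency accounting logic.")
    # "Execute" the next queued action, then advance the clock by one period.
    executed, eta = action_queue.popleft()
    assert eta <= 0, "the executed action should already have been ready"
    action_queue = deque((a, e - 1 / fps) for a, e in action_queue)
    print(f"step {step}: execute {executed}, queue depth {len(action_queue)}, dropped {n_dropped_cycles}")

With these numbers, the 0.23 s inference started at step 1 keeps the simulated policy busy through steps 2 and 3, so those two predictions are discarded and n_dropped_cycles ends at 2.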
@@ -176,16 +271,18 @@ def rollout(
all_dones.append(torch.from_numpy(done))
all_successes.append(torch.tensor(successes))
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar_postfix = {"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
if LATENCY:
progbar_postfix.update({"n_dropped_cycles": n_dropped_cycles})
progbar.set_postfix(progbar_postfix)
progbar.update()
# Track the final observation.
if return_observations:
observation = preprocess_observation(observation)
observation = preprocess_observation(gym_observation)
all_observations.append(deepcopy(observation))
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.

View File

@@ -82,7 +82,7 @@ def rollout(env: gym.vector.VectorEnv, policy: Policy, seed: int | None = None):
start = time.time()
# If we have fewer than some number of actions left in the queue, we need to start working on
# producing the next chunk.
if len(actions_queue) < 25 and not thread.is_alive():
if len(actions_queue) < 2 and not thread.is_alive():
thread = threading.Thread(target=run_policy)
thread.start()
# # Process the observation that we have right now to decide an action.
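For context on the threshold change above (25 → 2 queued actions), here is a rough, self-contained sketch of the pattern in this script: a background thread refills the action queue while the main loop keeps consuming it, and a new inference thread is only started once the queue runs low and no inference is already in flight. The function body, chunk size, and timings below are made-up stand-ins, not this script's actual code.

import threading
import time
from collections import deque

actions_queue: deque[int] = deque()
queue_lock = threading.Lock()


def run_policy() -> None:
    # Stand-in for model inference: pretend it takes a while, then hand over a chunk of actions.
    time.sleep(0.2)
    with queue_lock:
        actions_queue.extend(range(8))


# Produce an initial chunk before the control loop starts.
thread = threading.Thread(target=run_policy)
thread.start()
thread.join()

for step in range(30):
    # If we have fewer than some number of actions left, start producing the next chunk in the
    # background (never more than one inference thread at a time).
    if len(actions_queue) < 2 and not thread.is_alive():
        thread = threading.Thread(target=run_policy)
        thread.start()
    with queue_lock:
        action = actions_queue.popleft() if actions_queue else None
    if action is None:
        # The policy fell behind; a real controller might repeat the previous action instead.
        time.sleep(0.01)
        continue
    time.sleep(1 / 10)  # pretend to step the environment at 10 fps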