backup wip

This commit is contained in:
Alexander Soare 2024-04-16 16:31:44 +01:00
parent 43a614c173
commit 23be5e1e7b
4 changed files with 24 additions and 17 deletions

View File

@ -55,7 +55,7 @@ while not done:
for batch in dataloader:
for k in batch:
batch[k] = batch[k].to(device, non_blocking=True)
info = policy(batch)
info = policy.update(batch)
if step % log_freq == 0:
num_samples = (step + 1) * cfg.batch_size
loss = info["loss"]

View File

@ -161,6 +161,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
if self.cfg.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
return self.select_action(self, batch)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""Select a single action given environment observations.
@ -172,23 +175,17 @@ class ActionChunkingTransformerPolicy(nn.Module):
if len(self._action_queue) == 0:
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
# shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
self._action_queue.extend(self._select_actions(batch).transpose(0, 1))
return self._action_queue.popleft()
@torch.no_grad
def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""Use the action chunking transformer to generate a sequence of actions."""
self.eval()
self._preprocess_batch(batch, add_obs_steps_dim=True)
action = self.forward(batch, return_loss=False)
return action[: self.cfg.n_action_steps]
def __call__(self, *args, **kwargs) -> dict:
# TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
return self.update(*args, **kwargs)
def _preprocess_batch(
self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
) -> dict[str, Tensor]:
@ -216,9 +213,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
# the image index dimension.
def update(self, batch, **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
def compute_loss(self, batch, **_) -> float:
self._preprocess_batch(batch)
self.train()
@ -230,6 +225,12 @@ class ActionChunkingTransformerPolicy(nn.Module):
assert batch_size % num_slices == 0
loss = self.forward(batch, return_loss=True)["loss"]
return loss
def update(self, batch, **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
loss = self.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(

View File

@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy(batch, step=step)
train_info = policy.update(batch, step=step)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0:
@ -313,7 +313,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy(batch, step)
train_info = policy.update(batch, step)
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

View File

@ -4,11 +4,13 @@ import torch
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.protocol import Policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@ -27,6 +29,7 @@ def test_policy(env_name, policy_name, extra_overrides):
"""
Tests:
- Making the policy object.
- Checking that the policy follows the correct protocol.
- Updating the policy.
- Using the policy to select actions at inference time.
- Test the action can be applied to the policy
@ -38,10 +41,14 @@ def test_policy(env_name, policy_name, extra_overrides):
f"policy={policy_name}",
f"device={DEVICE}",
]
+ extra_overrides
+ extra_overrides,
)
# Check that we can make the policy object.
policy = make_policy(cfg)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
# Check that we run select_actions and get the appropriate output.
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=2)
@ -62,7 +69,7 @@ def test_policy(env_name, policy_name, extra_overrides):
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy
policy(batch, step=0)
policy.update(batch, step=0)
# reset the policy and environment
policy.reset()
@ -83,4 +90,3 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test step through policy
env.step(action)