ready for review

This commit is contained in:
Alexander Soare 2024-04-12 11:36:52 +01:00
parent 5666ec3ec7
commit 6d0a45a97d
7 changed files with 11 additions and 42 deletions

View File

@ -32,8 +32,6 @@ policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
**cfg.policy,

View File

@ -213,7 +213,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
return action[: self.n_action_steps]
def __call__(self, *args, **kwargs) -> dict:
# TODO(now): Temporary bridge until we know what to do about the `update` method.
# TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
return self.update(*args, **kwargs)
def _preprocess_batch(

View File

@ -10,7 +10,6 @@ logger = logging.getLogger(__name__)
class _SinusoidalPosEmb(nn.Module):
# TODO(now): consolidate?
def __init__(self, dim):
super().__init__()
self.dim = dim

View File

@ -10,18 +10,6 @@ from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_
class DiffusionUnetImagePolicy(nn.Module):
"""
TODO(now): Add DDIM scheduler.
Changes: TODO(now)
- Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
needed. Code for a general observation encoder can be found at:
https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
- Uses the observation as global conditioning for the Unet by default.
- Does not do any inpainting (which would be applicable if the observation were not used to condition
the Unet).
"""
def __init__(
self,
cfg,
@ -87,7 +75,7 @@ class DiffusionUnetImagePolicy(nn.Module):
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1 # TODO(now): Is this right?
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample

View File

@ -6,7 +6,7 @@ from collections import deque
import hydra
import torch
from diffusers.optimization import get_scheduler
from torch import Tensor, nn
from torch import nn
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.utils import populate_queues
@ -43,7 +43,6 @@ class DiffusionPolicy(nn.Module):
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
# TODO(now): In-house this.
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
self.diffusion = DiffusionUnetImagePolicy(
@ -103,45 +102,35 @@ class DiffusionPolicy(nn.Module):
"action": deque(maxlen=self.n_action_steps),
}
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
"""A forward pass through the DNN part of this policy with optional loss computation."""
return self.select_action(batch)
@torch.no_grad
def select_action(self, batch, **_):
"""
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
# TODO(now): Handle a batch
"""
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2 # TODO(now): Does this not have a batch dim?
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
actions = self._generate_actions(batch)
if not self.training and self.ema_diffusion is not None:
actions = self.ema_diffusion.generate_actions(batch)
else:
actions = self.diffusion.generate_actions(batch)
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
return action
def _generate_actions(self, batch):
if not self.training and self.ema_diffusion is not None:
return self.ema_diffusion.generate_actions(batch)
else:
return self.diffusion.generate_actions(batch)
def update(self, batch, **_):
"""Run the model in train mode, compute the loss, and do an optimization step."""
def forward(self, batch, **_):
start_time = time.time()
self.diffusion.train()
loss = self.compute_loss(batch)
loss = self.diffusion.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
@ -166,9 +155,6 @@ class DiffusionPolicy(nn.Module):
return info
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
return self.diffusion.compute_loss(batch)
def save(self, fp):
torch.save(self.state_dict(), fp)

View File

@ -18,7 +18,6 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
"""Get a module's device by checking one of its parameters.
Note: assumes that all parameters have the same device
TODO(now): Add test.
"""
return next(iter(module.parameters())).device
@ -27,6 +26,5 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
"""Get a module's parameter dtype by checking one of its parameters.
Note: assumes that all parameters have the same dtype.
TODO(now): Add test.
"""
return next(iter(module.parameters())).dtype

View File

@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy.update(batch, step=step)
train_info = policy(batch, step=step)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0: