From 6d0a45a97d0d04d324023fe5a1b50815085b14c4 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Fri, 12 Apr 2024 11:36:52 +0100
Subject: [PATCH] ready for review

---
 examples/3_train_policy.py                   |  2 --
 lerobot/common/policies/act/policy.py        |  2 +-
 .../diffusion/model/conditional_unet1d.py    |  1 -
 .../model/diffusion_unet_image_policy.py     | 14 +--------
 lerobot/common/policies/diffusion/policy.py  | 30 +++++--------------
 lerobot/common/policies/utils.py             |  2 --
 lerobot/scripts/train.py                     |  2 +-
 7 files changed, 11 insertions(+), 42 deletions(-)

diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 238f953d..d2fff13b 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -32,8 +32,6 @@ policy = DiffusionPolicy(
     cfg=cfg.policy,
     cfg_device=cfg.device,
     cfg_noise_scheduler=cfg.noise_scheduler,
-    cfg_rgb_model=cfg.rgb_model,
-    cfg_obs_encoder=cfg.obs_encoder,
     cfg_optimizer=cfg.optimizer,
     cfg_ema=cfg.ema,
     **cfg.policy,
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index 821b0196..24667795 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -213,7 +213,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return action[: self.n_action_steps]
 
     def __call__(self, *args, **kwargs) -> dict:
-        # TODO(now): Temporary bridge until we know what to do about the `update` method.
+        # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
         return self.update(*args, **kwargs)
 
     def _preprocess_batch(
diff --git a/lerobot/common/policies/diffusion/model/conditional_unet1d.py b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
index 5c43d488..c3dcc198 100644
--- a/lerobot/common/policies/diffusion/model/conditional_unet1d.py
+++ b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
@@ -10,7 +10,6 @@ logger = logging.getLogger(__name__)
 
 
 class _SinusoidalPosEmb(nn.Module):
-    # TODO(now): consolidate?
     def __init__(self, dim):
         super().__init__()
         self.dim = dim
diff --git a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
index 92928c70..3e7727f3 100644
--- a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
+++ b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
@@ -10,18 +10,6 @@ from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_
 
 
 class DiffusionUnetImagePolicy(nn.Module):
-    """
-    TODO(now): Add DDIM scheduler.
-
-    Changes: TODO(now)
-    - Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
-      needed. Code for a general observation encoder can be found at:
-      https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
-    - Uses the observation as global conditioning for the Unet by default.
-    - Does not do any inpainting (which would be applicable if the observation were not used to condition
-      the Unet).
-    """
-
     def __init__(
         self,
         cfg,
@@ -87,7 +75,7 @@
                 torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                 global_cond=global_cond,
             )
-            # Compute previous image: x_t -> x_t-1 # TODO(now): Is this right?
+            # Compute previous image: x_t -> x_t-1
             sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
 
         return sample
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index b1713869..f88e2f25 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -6,7 +6,7 @@ from collections import deque
 import hydra
 import torch
 from diffusers.optimization import get_scheduler
-from torch import Tensor, nn
+from torch import nn
 
 from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.utils import populate_queues
@@ -43,7 +43,6 @@
         # queues are populated during rollout of the policy, they contain the n latest observations and actions
         self._queues = None
 
-        # TODO(now): In-house this.
         noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
 
         self.diffusion = DiffusionUnetImagePolicy(
@@ -103,45 +102,35 @@
             "action": deque(maxlen=self.n_action_steps),
         }
 
-    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
-        """A forward pass through the DNN part of this policy with optional loss computation."""
-        return self.select_action(batch)
-
     @torch.no_grad
     def select_action(self, batch, **_):
         """
         Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
-        # TODO(now): Handle a batch
         """
         assert "observation.image" in batch
         assert "observation.state" in batch
-        assert len(batch) == 2  # TODO(now): Does this not have a batch dim?
+        assert len(batch) == 2
 
         self._queues = populate_queues(self._queues, batch)
 
         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
-            actions = self._generate_actions(batch)
+            if not self.training and self.ema_diffusion is not None:
+                actions = self.ema_diffusion.generate_actions(batch)
+            else:
+                actions = self.diffusion.generate_actions(batch)
             self._queues["action"].extend(actions.transpose(0, 1))
 
         action = self._queues["action"].popleft()
         return action
 
-    def _generate_actions(self, batch):
-        if not self.training and self.ema_diffusion is not None:
-            return self.ema_diffusion.generate_actions(batch)
-        else:
-            return self.diffusion.generate_actions(batch)
-
-    def update(self, batch, **_):
-        """Run the model in train mode, compute the loss, and do an optimization step."""
+    def forward(self, batch, **_):
         start_time = time.time()
 
         self.diffusion.train()
 
-        loss = self.compute_loss(batch)
-
+        loss = self.diffusion.compute_loss(batch)
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -166,9 +155,6 @@
 
         return info
 
-    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
-        return self.diffusion.compute_loss(batch)
-
     def save(self, fp):
         torch.save(self.state_dict(), fp)
 
diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py
index 9d4b42f0..b23c1336 100644
--- a/lerobot/common/policies/utils.py
+++ b/lerobot/common/policies/utils.py
@@ -18,7 +18,6 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
     """Get a module's device by checking one of its parameters.
 
     Note: assumes that all parameters have the same device
-    TODO(now): Add test.
""" return next(iter(module.parameters())).device @@ -27,6 +26,5 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype: """Get a module's parameter dtype by checking one of its parameters. Note: assumes that all parameters have the same dtype. - TODO(now): Add test. """ return next(iter(module.parameters())).dtype diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 300a8617..5ff6538d 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy.update(batch, step=step) + train_info = policy(batch, step=step) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: