ready for review
commit 6d0a45a97d
parent 5666ec3ec7
@@ -32,8 +32,6 @@ policy = DiffusionPolicy(
     cfg=cfg.policy,
     cfg_device=cfg.device,
     cfg_noise_scheduler=cfg.noise_scheduler,
-    cfg_rgb_model=cfg.rgb_model,
-    cfg_obs_encoder=cfg.obs_encoder,
     cfg_optimizer=cfg.optimizer,
     cfg_ema=cfg.ema,
     **cfg.policy,
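Note: `**cfg.policy` unpacks the whole `policy` config group into keyword arguments, so any key not listed explicitly still reaches the constructor. A minimal sketch of that unpacking with OmegaConf (the keys are made up for illustration, not from this repo):

```python
from omegaconf import OmegaConf

# Hypothetical config group; keys are illustrative only.
cfg = OmegaConf.create({"policy": {"horizon": 16, "n_action_steps": 8}})

def make_policy(horizon, n_action_steps, **kwargs):
    return horizon, n_action_steps

# DictConfig implements the mapping protocol, so ** unpacking works directly.
assert make_policy(**cfg.policy) == (16, 8)
```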
@@ -213,7 +213,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return action[: self.n_action_steps]
 
     def __call__(self, *args, **kwargs) -> dict:
-        # TODO(now): Temporary bridge until we know what to do about the `update` method.
+        # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
         return self.update(*args, **kwargs)
 
     def _preprocess_batch(
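Context for the bridge: overriding `__call__` directly, instead of letting `nn.Module.__call__` dispatch to `forward`, skips forward pre/post hooks, which is why the comment flags it as temporary. A toy sketch of the pattern (the class here is illustrative, not the repo's ActionChunkingTransformerPolicy):

```python
import torch
from torch import nn

class BridgedPolicy(nn.Module):
    """Illustrative only: route module calls to `update` instead of `forward`."""

    def __call__(self, *args, **kwargs) -> dict:
        # Bypasses nn.Module.__call__, so registered forward hooks will NOT run.
        return self.update(*args, **kwargs)

    def update(self, batch: dict) -> dict:
        return {"loss": torch.tensor(0.0)}

policy = BridgedPolicy()
info = policy({"observation.state": torch.zeros(1, 4)})  # routed to update()
```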
@@ -10,7 +10,6 @@ logger = logging.getLogger(__name__)
 
 
 class _SinusoidalPosEmb(nn.Module):
-    # TODO(now): consolidate?
     def __init__(self, dim):
         super().__init__()
         self.dim = dim
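For reference, the `forward` that usually accompanies this constructor maps a batch of scalar timesteps to `dim`-sized embeddings of sines and cosines at geometrically spaced frequencies. A sketch of the common open-source implementation (this file's exact body may differ):

```python
import math
import torch
from torch import nn

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        angles = x[:, None].float() * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

emb = SinusoidalPosEmb(128)(torch.arange(4))  # shape (4, 128)
```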
@@ -10,18 +10,6 @@ from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters
 
 
 class DiffusionUnetImagePolicy(nn.Module):
-    """
-    TODO(now): Add DDIM scheduler.
-
-    Changes: TODO(now)
-    - Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
-      needed. Code for a general observation encoder can be found at:
-      https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
-    - Uses the observation as global conditioning for the Unet by default.
-    - Does not do any inpainting (which would be applicable if the observation were not used to condition
-      the Unet).
-    """
-
     def __init__(
         self,
         cfg,
@@ -87,7 +75,7 @@ class DiffusionUnetImagePolicy(nn.Module):
                 torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                 global_cond=global_cond,
             )
-            # Compute previous image: x_t -> x_t-1 # TODO(now): Is this right?
+            # Compute previous image: x_t -> x_t-1
            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
 
         return sample
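The loop above is standard DDPM ancestral sampling: start from Gaussian noise and repeatedly call the scheduler's `step` to move from x_t to x_t-1. A self-contained sketch with a stand-in for the conditioned U-Net (the shapes and timestep count are illustrative, not the repo's settings):

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=100)
scheduler.set_timesteps(100)

sample = torch.randn(1, 16, 2)  # (batch, horizon, action_dim), pure noise

def unet(x, t):  # stand-in for the globally conditioned U-Net
    return torch.zeros_like(x)

for t in scheduler.timesteps:
    model_output = unet(sample, t)
    # Compute previous sample: x_t -> x_t-1
    sample = scheduler.step(model_output, t, sample).prev_sample
```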
@@ -6,7 +6,7 @@ from collections import deque
 import hydra
 import torch
 from diffusers.optimization import get_scheduler
-from torch import Tensor, nn
+from torch import nn
 
 from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.utils import populate_queues
@ -43,7 +43,6 @@ class DiffusionPolicy(nn.Module):
|
|||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
# TODO(now): In-house this.
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
|
||||
self.diffusion = DiffusionUnetImagePolicy(
|
||||
|
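`hydra.utils.instantiate` builds an object from a config whose `_target_` key names the class; the remaining keys become constructor kwargs. A sketch of what such a noise-scheduler config could look like (the values are illustrative, not the repo's defaults):

```python
import hydra
from omegaconf import OmegaConf

cfg_noise_scheduler = OmegaConf.create(
    {
        "_target_": "diffusers.DDPMScheduler",
        "num_train_timesteps": 100,
        "beta_schedule": "squaredcos_cap_v2",
    }
)
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)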
@@ -103,45 +102,35 @@
             "action": deque(maxlen=self.n_action_steps),
         }
 
-    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
-        """A forward pass through the DNN part of this policy with optional loss computation."""
-        return self.select_action(batch)
-
     @torch.no_grad
     def select_action(self, batch, **_):
         """
         Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
-        # TODO(now): Handle a batch
         """
         assert "observation.image" in batch
         assert "observation.state" in batch
-        assert len(batch) == 2  # TODO(now): Does this not have a batch dim?
+        assert len(batch) == 2
 
         self._queues = populate_queues(self._queues, batch)
 
         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
-            actions = self._generate_actions(batch)
+            if not self.training and self.ema_diffusion is not None:
+                actions = self.ema_diffusion.generate_actions(batch)
+            else:
+                actions = self.diffusion.generate_actions(batch)
             self._queues["action"].extend(actions.transpose(0, 1))
 
         action = self._queues["action"].popleft()
         return action
 
-    def _generate_actions(self, batch):
-        if not self.training and self.ema_diffusion is not None:
-            return self.ema_diffusion.generate_actions(batch)
-        else:
-            return self.diffusion.generate_actions(batch)
-
-    def update(self, batch, **_):
-        """Run the model in train mode, compute the loss, and do an optimization step."""
+    def forward(self, batch, **_):
         start_time = time.time()
 
         self.diffusion.train()
 
-        loss = self.compute_loss(batch)
-
-        loss = self.diffusion.compute_loss(batch)
+        loss = self.diffusion.compute_loss(batch)
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
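Two things happen in this hunk: `_generate_actions` is inlined into `select_action`, and `update` is renamed to `forward` so the policy is trained through the standard module call. The queueing itself implements receding-horizon control: sample a chunk of `n_action_steps` actions only when the queue is empty, then serve one action per call. A stripped-down sketch of that logic (the zero tensor stands in for the diffusion sampler):

```python
from collections import deque
import torch

n_action_steps = 8
queue = deque(maxlen=n_action_steps)

def select_action(batch: dict) -> torch.Tensor:
    if len(queue) == 0:
        # Stand-in for the expensive generate_actions(batch): (batch, horizon, dim).
        actions = torch.zeros(1, n_action_steps, 2)
        # Transpose so extend() iterates over the horizon dimension.
        queue.extend(actions.transpose(0, 1))
    return queue.popleft()
```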
@@ -166,9 +155,6 @@
 
         return info
 
-    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
-        return self.diffusion.compute_loss(batch)
-
     def save(self, fp):
         torch.save(self.state_dict(), fp)
 
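With the `compute_loss` wrapper gone, training calls `self.diffusion.compute_loss` directly inside `forward`. The optimization step around the lines shown above presumably follows the usual pattern (a sketch; the clip value and returned bookkeeping are assumptions, not this commit's code):

```python
import torch

def training_step(diffusion, optimizer, batch, grad_clip_norm=10.0):
    diffusion.train()
    loss = diffusion.compute_loss(batch)
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(
        diffusion.parameters(), grad_clip_norm, error_if_nonfinite=False
    )
    optimizer.step()
    optimizer.zero_grad()
    return {"loss": loss.item(), "grad_norm": float(grad_norm)}
```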
@@ -18,7 +18,6 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
     """Get a module's device by checking one of its parameters.
 
     Note: assumes that all parameters have the same device
-    TODO(now): Add test.
     """
     return next(iter(module.parameters())).device
 

@@ -27,6 +26,5 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
     """Get a module's parameter dtype by checking one of its parameters.
 
     Note: assumes that all parameters have the same dtype.
-    TODO(now): Add test.
     """
     return next(iter(module.parameters())).dtype
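The removed `TODO(now): Add test.` markers asked for coverage of these helpers. A minimal pytest-style sketch of what such a test could look like (the test itself is not part of this commit):

```python
import torch
from torch import nn

from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters

def test_get_device_and_dtype_from_parameters():
    module = nn.Linear(2, 2).to(dtype=torch.float64)
    assert get_device_from_parameters(module) == torch.device("cpu")
    assert get_dtype_from_parameters(module) == torch.float64
```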
@@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy.update(batch, step=step)
+        train_info = policy(batch, step=step)
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
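Calling `policy(batch, step=step)` works because `nn.Module.__call__` dispatches to `forward` (and runs any registered hooks), which the `update` -> `forward` rename in DiffusionPolicy enables. A toy illustration:

```python
import torch
from torch import nn

class TinyPolicy(nn.Module):
    def forward(self, batch: dict, **kwargs) -> dict:
        return {"loss": batch["x"].sum()}

policy = TinyPolicy()
train_info = policy({"x": torch.ones(3)}, step=0)  # __call__ -> forward
```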