ready for review
This commit is contained in:
parent
5666ec3ec7
commit
6d0a45a97d
|
@ -32,8 +32,6 @@ policy = DiffusionPolicy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
cfg_device=cfg.device,
|
cfg_device=cfg.device,
|
||||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||||
cfg_rgb_model=cfg.rgb_model,
|
|
||||||
cfg_obs_encoder=cfg.obs_encoder,
|
|
||||||
cfg_optimizer=cfg.optimizer,
|
cfg_optimizer=cfg.optimizer,
|
||||||
cfg_ema=cfg.ema,
|
cfg_ema=cfg.ema,
|
||||||
**cfg.policy,
|
**cfg.policy,
|
||||||
|
|
|
@ -213,7 +213,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
return action[: self.n_action_steps]
|
return action[: self.n_action_steps]
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs) -> dict:
|
def __call__(self, *args, **kwargs) -> dict:
|
||||||
# TODO(now): Temporary bridge until we know what to do about the `update` method.
|
# TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
|
||||||
return self.update(*args, **kwargs)
|
return self.update(*args, **kwargs)
|
||||||
|
|
||||||
def _preprocess_batch(
|
def _preprocess_batch(
|
||||||
|
|
|
@ -10,7 +10,6 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class _SinusoidalPosEmb(nn.Module):
|
class _SinusoidalPosEmb(nn.Module):
|
||||||
# TODO(now): consolidate?
|
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
|
@ -10,18 +10,6 @@ from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_
|
||||||
|
|
||||||
|
|
||||||
class DiffusionUnetImagePolicy(nn.Module):
|
class DiffusionUnetImagePolicy(nn.Module):
|
||||||
"""
|
|
||||||
TODO(now): Add DDIM scheduler.
|
|
||||||
|
|
||||||
Changes: TODO(now)
|
|
||||||
- Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
|
|
||||||
needed. Code for a general observation encoder can be found at:
|
|
||||||
https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
|
|
||||||
- Uses the observation as global conditioning for the Unet by default.
|
|
||||||
- Does not do any inpainting (which would be applicable if the observation were not used to condition
|
|
||||||
the Unet).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cfg,
|
cfg,
|
||||||
|
@ -87,7 +75,7 @@ class DiffusionUnetImagePolicy(nn.Module):
|
||||||
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
||||||
global_cond=global_cond,
|
global_cond=global_cond,
|
||||||
)
|
)
|
||||||
# Compute previous image: x_t -> x_t-1 # TODO(now): Is this right?
|
# Compute previous image: x_t -> x_t-1
|
||||||
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
|
@ -6,7 +6,7 @@ from collections import deque
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from torch import Tensor, nn
|
from torch import nn
|
||||||
|
|
||||||
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||||
from lerobot.common.policies.utils import populate_queues
|
from lerobot.common.policies.utils import populate_queues
|
||||||
|
@ -43,7 +43,6 @@ class DiffusionPolicy(nn.Module):
|
||||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
self._queues = None
|
self._queues = None
|
||||||
|
|
||||||
# TODO(now): In-house this.
|
|
||||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||||
|
|
||||||
self.diffusion = DiffusionUnetImagePolicy(
|
self.diffusion = DiffusionUnetImagePolicy(
|
||||||
|
@ -103,45 +102,35 @@ class DiffusionPolicy(nn.Module):
|
||||||
"action": deque(maxlen=self.n_action_steps),
|
"action": deque(maxlen=self.n_action_steps),
|
||||||
}
|
}
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
|
|
||||||
"""A forward pass through the DNN part of this policy with optional loss computation."""
|
|
||||||
return self.select_action(batch)
|
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def select_action(self, batch, **_):
|
def select_action(self, batch, **_):
|
||||||
"""
|
"""
|
||||||
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
|
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
|
||||||
# TODO(now): Handle a batch
|
|
||||||
"""
|
"""
|
||||||
assert "observation.image" in batch
|
assert "observation.image" in batch
|
||||||
assert "observation.state" in batch
|
assert "observation.state" in batch
|
||||||
assert len(batch) == 2 # TODO(now): Does this not have a batch dim?
|
assert len(batch) == 2
|
||||||
|
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
if len(self._queues["action"]) == 0:
|
if len(self._queues["action"]) == 0:
|
||||||
# stack n latest observations from the queue
|
# stack n latest observations from the queue
|
||||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||||
actions = self._generate_actions(batch)
|
if not self.training and self.ema_diffusion is not None:
|
||||||
|
actions = self.ema_diffusion.generate_actions(batch)
|
||||||
|
else:
|
||||||
|
actions = self.diffusion.generate_actions(batch)
|
||||||
self._queues["action"].extend(actions.transpose(0, 1))
|
self._queues["action"].extend(actions.transpose(0, 1))
|
||||||
|
|
||||||
action = self._queues["action"].popleft()
|
action = self._queues["action"].popleft()
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def _generate_actions(self, batch):
|
def forward(self, batch, **_):
|
||||||
if not self.training and self.ema_diffusion is not None:
|
|
||||||
return self.ema_diffusion.generate_actions(batch)
|
|
||||||
else:
|
|
||||||
return self.diffusion.generate_actions(batch)
|
|
||||||
|
|
||||||
def update(self, batch, **_):
|
|
||||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
self.diffusion.train()
|
self.diffusion.train()
|
||||||
|
|
||||||
loss = self.compute_loss(batch)
|
loss = self.diffusion.compute_loss(batch)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
@ -166,9 +155,6 @@ class DiffusionPolicy(nn.Module):
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
|
||||||
return self.diffusion.compute_loss(batch)
|
|
||||||
|
|
||||||
def save(self, fp):
|
def save(self, fp):
|
||||||
torch.save(self.state_dict(), fp)
|
torch.save(self.state_dict(), fp)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,6 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||||
"""Get a module's device by checking one of its parameters.
|
"""Get a module's device by checking one of its parameters.
|
||||||
|
|
||||||
Note: assumes that all parameters have the same device
|
Note: assumes that all parameters have the same device
|
||||||
TODO(now): Add test.
|
|
||||||
"""
|
"""
|
||||||
return next(iter(module.parameters())).device
|
return next(iter(module.parameters())).device
|
||||||
|
|
||||||
|
@ -27,6 +26,5 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
||||||
"""Get a module's parameter dtype by checking one of its parameters.
|
"""Get a module's parameter dtype by checking one of its parameters.
|
||||||
|
|
||||||
Note: assumes that all parameters have the same dtype.
|
Note: assumes that all parameters have the same dtype.
|
||||||
TODO(now): Add test.
|
|
||||||
"""
|
"""
|
||||||
return next(iter(module.parameters())).dtype
|
return next(iter(module.parameters())).dtype
|
||||||
|
|
|
@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||||
|
|
||||||
train_info = policy.update(batch, step=step)
|
train_info = policy(batch, step=step)
|
||||||
|
|
||||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||||
if step % cfg.log_freq == 0:
|
if step % cfg.log_freq == 0:
|
||||||
|
|
Loading…
Reference in New Issue