remove EMA from DP

Alexander Soare 2024-05-05 08:18:13 +01:00
parent d747195c57
commit ee55d28afd
5 changed files with 2 additions and 105 deletions


@@ -118,15 +118,6 @@ class DiffusionConfig:
     # Inference
     num_inference_steps: int | None = None
 
-    # ---
-    # TODO(alexander-soare): Remove these from the policy config.
-    use_ema: bool = True
-    ema_update_after_step: int = 0
-    ema_min_alpha: float = 0.0
-    ema_max_alpha: float = 0.9999
-    ema_inv_gamma: float = 1.0
-    ema_power: float = 0.75
-
     def __post_init__(self):
         """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
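The fields dropped here are the EMA hyperparameters that were only consumed by the DiffusionEMA helper removed later in this commit. If they were to survive outside the policy config, per the TODO above, they might live in a small standalone container along these lines; EMAConfig, its placement, and its field names are illustrative assumptions, not part of this commit:

from dataclasses import dataclass

@dataclass
class EMAConfig:
    """Hypothetical standalone container for the EMA hyperparameters removed from DiffusionConfig."""
    update_after_step: int = 0   # was ema_update_after_step
    min_alpha: float = 0.0       # was ema_min_alpha
    max_alpha: float = 0.9999    # was ema_max_alpha
    inv_gamma: float = 1.0       # was ema_inv_gamma
    power: float = 0.75          # was ema_power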


@@ -3,12 +3,8 @@
 TODO(alexander-soare):
   - Remove reliance on Robomimic for SpatialSoftmax.
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
-  - Move EMA out of policy.
-  - Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy.
-  - One more pass on comments and documentation.
 """
 
-import copy
 import math
 from collections import deque
 from typing import Callable
@@ -21,7 +17,6 @@ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from huggingface_hub import PyTorchModelHubMixin
 from robomimic.models.base_nets import SpatialSoftmax
 from torch import Tensor, nn
-from torch.nn.modules.batchnorm import _BatchNorm
 
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
@@ -71,13 +66,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         self.diffusion = DiffusionModel(config)
 
-        # TODO(alexander-soare): This should probably be managed outside of the policy class.
-        self.ema_diffusion = None
-        self.ema = None
-        if self.config.use_ema:
-            self.ema_diffusion = copy.deepcopy(self.diffusion)
-            self.ema = DiffusionEMA(config, model=self.ema_diffusion)
-
     def reset(self):
         """
         Clear observation and action queues. Should be called on `env.reset()`
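For reference, the deleted constructor block built the EMA shadow model as a frozen deep copy of the denoiser. Managed outside the policy, the same setup might look roughly like the sketch below; make_ema_tracker and the surrounding training-script context are assumptions, not code from this commit:

import copy
from torch import nn

def make_ema_tracker(model: nn.Module) -> nn.Module:
    """Create a frozen copy of `model` whose weights will track an exponential moving average."""
    ema_model = copy.deepcopy(model)
    ema_model.eval()                 # EMA weights are only ever used for evaluation
    ema_model.requires_grad_(False)  # never touched by the optimizer
    return ema_model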
@@ -109,9 +97,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
         "horizon" may not the best name to describe what the variable actually means, because this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
-
-        Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
-        weights.
         """
         assert "observation.image" in batch
         assert "observation.state" in batch
@@ -123,10 +108,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
-            if not self.training and self.ema_diffusion is not None:
-                actions = self.ema_diffusion.generate_actions(batch)
-            else:
-                actions = self.diffusion.generate_actions(batch)
+            actions = self.diffusion.generate_actions(batch)
 
             # TODO(rcadene): make above methods return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -612,67 +594,3 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
         out = self.conv2(out)
         out = out + self.residual_conv(x)
         return out
-
-
-class DiffusionEMA:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(self, config: DiffusionConfig, model: nn.Module):
-        """
-        @crowsonkb's notes on EMA Warmup:
-            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models
-            you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999
-            at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999
-            at 10K steps, 0.9999 at 215.4k steps).
-
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_alpha (float): The minimum EMA decay rate. Default: 0.
-        """
-        self.averaged_model = model
-        self.averaged_model.eval()
-        self.averaged_model.requires_grad_(False)
-
-        self.update_after_step = config.ema_update_after_step
-        self.inv_gamma = config.ema_inv_gamma
-        self.power = config.ema_power
-        self.min_alpha = config.ema_min_alpha
-        self.max_alpha = config.ema_max_alpha
-
-        self.alpha = 0.0
-        self.optimization_step = 0
-
-    def get_decay(self, optimization_step):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
-
-        if step <= 0:
-            return 0.0
-
-        return max(self.min_alpha, min(value, self.max_alpha))
-
-    @torch.no_grad()
-    def step(self, new_model):
-        self.alpha = self.get_decay(self.optimization_step)
-
-        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
-            # Iterate over immediate parameters only.
-            for param, ema_param in zip(
-                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
-            ):
-                if isinstance(param, dict):
-                    raise RuntimeError("Dict parameter not supported")
-                if isinstance(module, _BatchNorm) or not param.requires_grad:
-                    # Copy BatchNorm parameters, and non-trainable parameters directly.
-                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
-                else:
-                    ema_param.mul_(self.alpha)
-                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
-
-        self.optimization_step += 1
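The warmup schedule implemented by the removed get_decay is decay = 1 - (1 + step / inv_gamma) ** (-power), clipped to [min_alpha, max_alpha]. Below is a minimal standalone sketch of that schedule (the function name is illustrative, and the update_after_step offset is dropped) which reproduces the figures quoted in the crowsonkb note:

def ema_warmup_decay(step: int, inv_gamma: float = 1.0, power: float = 0.75,
                     min_alpha: float = 0.0, max_alpha: float = 0.9999) -> float:
    """Same schedule as the removed DiffusionEMA.get_decay, without the config plumbing."""
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_alpha, min(value, max_alpha))

# With power=3/4 the decay reaches ~0.999 around step 10,000, as the docstring states:
print(round(ema_warmup_decay(10_000, power=0.75), 4))      # ~0.999
# With power=2/3 it reaches ~0.999 around step 31,600 and ~0.9999 near 1,000,000:
print(round(ema_warmup_decay(31_600, power=2 / 3), 4))     # ~0.999
print(round(ema_warmup_decay(1_000_000, power=2 / 3), 4))  # ~0.9999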


@@ -91,12 +91,3 @@ policy:
   # Inference
   num_inference_steps: 100
-
-  # ---
-  # TODO(alexander-soare): Remove these from the policy config.
-  use_ema: true
-  ema_update_after_step: 0
-  ema_min_alpha: 0.0
-  ema_max_alpha: 0.9999
-  ema_inv_gamma: 1.0
-  ema_power: 0.75


@@ -121,7 +121,7 @@ def rollout(
     max_steps = env.call("_max_episode_steps")[0]
     progbar = trange(
         max_steps,
-        desc=f"Running rollout with {max_steps} steps (maximum) per rollout",
+        desc=f"Running rollout with at most {max_steps} steps",
         disable=not enable_progbar,
         leave=False,
     )


@@ -89,9 +89,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     if lr_scheduler is not None:
         lr_scheduler.step()
 
-    if hasattr(policy, "ema") and policy.ema is not None:
-        policy.ema.step(policy.diffusion)
-
     if isinstance(policy, PolicyWithUpdate):
         # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
         policy.update()
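With the explicit policy.ema.step(...) call gone, the generic PolicyWithUpdate hook kept above is the remaining place where a policy can refresh internally averaged weights after each optimizer step. A rough sketch of how an EMA update could be driven through that hook follows; the class, its update body, and the attribute names are assumptions for illustration, not code from this commit:

import torch

class PolicyWithEMAUpdate:
    """Illustrative policy fragment: `update()` is called once per optimization step by the training loop."""

    def __init__(self, model, ema_model, decay: float = 0.999):
        self.model = model          # weights updated by the optimizer
        self.ema_model = ema_model  # frozen copy tracking the moving average
        self.decay = decay

    @torch.no_grad()
    def update(self):
        # Standard EMA update: ema <- decay * ema + (1 - decay) * current
        for p, ema_p in zip(self.model.parameters(), self.ema_model.parameters(), strict=True):
            ema_p.mul_(self.decay).add_(p.data, alpha=1 - self.decay)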