From ee55d28afd4a668da0a94bdc59557b44726e480b Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Sun, 5 May 2024 08:18:13 +0100
Subject: [PATCH] remove EMA from DP

---
 .../diffusion/configuration_diffusion.py      |  9 --
 .../policies/diffusion/modeling_diffusion.py  | 84 +------------------
 lerobot/configs/policy/diffusion.yaml         |  9 --
 lerobot/scripts/eval.py                       |  2 +-
 lerobot/scripts/train.py                      |  3 -
 5 files changed, 2 insertions(+), 105 deletions(-)

diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index b5188488..73fabefa 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -118,15 +118,6 @@ class DiffusionConfig:
     # Inference
     num_inference_steps: int | None = None
 
-    # ---
-    # TODO(alexander-soare): Remove these from the policy config.
-    use_ema: bool = True
-    ema_update_after_step: int = 0
-    ema_min_alpha: float = 0.0
-    ema_max_alpha: float = 0.9999
-    ema_inv_gamma: float = 1.0
-    ema_power: float = 0.75
-
     def __post_init__(self):
         """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index c639e2f9..f5f64d80 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -3,12 +3,8 @@
 TODO(alexander-soare):
   - Remove reliance on Robomimic for SpatialSoftmax.
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
-  - Move EMA out of policy.
-  - Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy.
-  - One more pass on comments and documentation.
 """
 
-import copy
 import math
 from collections import deque
 from typing import Callable
@@ -21,7 +17,6 @@ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from huggingface_hub import PyTorchModelHubMixin
 from robomimic.models.base_nets import SpatialSoftmax
 from torch import Tensor, nn
-from torch.nn.modules.batchnorm import _BatchNorm
 
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
@@ -71,13 +66,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
 
         self.diffusion = DiffusionModel(config)
 
-        # TODO(alexander-soare): This should probably be managed outside of the policy class.
-        self.ema_diffusion = None
-        self.ema = None
-        if self.config.use_ema:
-            self.ema_diffusion = copy.deepcopy(self.diffusion)
-            self.ema = DiffusionEMA(config, model=self.ema_diffusion)
-
     def reset(self):
         """
         Clear observation and action queues. Should be called on `env.reset()`
@@ -109,9 +97,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
         "horizon" may not be the best name to describe what the variable actually means, because this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
-
-        Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
-        weights.
         """
         assert "observation.image" in batch
         assert "observation.state" in batch
@@ -123,10 +108,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
-            if not self.training and self.ema_diffusion is not None:
-                actions = self.ema_diffusion.generate_actions(batch)
-            else:
-                actions = self.diffusion.generate_actions(batch)
+            actions = self.diffusion.generate_actions(batch)
 
             # TODO(rcadene): make above methods return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -612,67 +594,3 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
         out = self.conv2(out)
         out = out + self.residual_conv(x)
         return out
-
-
-class DiffusionEMA:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(self, config: DiffusionConfig, model: nn.Module):
-        """
-        @crowsonkb's notes on EMA Warmup:
-        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models
-        you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999
-        at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999
-        at 10K steps, 0.9999 at 215.4k steps).
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_alpha (float): The minimum EMA decay rate. Default: 0.
-        """
-
-        self.averaged_model = model
-        self.averaged_model.eval()
-        self.averaged_model.requires_grad_(False)
-
-        self.update_after_step = config.ema_update_after_step
-        self.inv_gamma = config.ema_inv_gamma
-        self.power = config.ema_power
-        self.min_alpha = config.ema_min_alpha
-        self.max_alpha = config.ema_max_alpha
-
-        self.alpha = 0.0
-        self.optimization_step = 0
-
-    def get_decay(self, optimization_step):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
-
-        if step <= 0:
-            return 0.0
-
-        return max(self.min_alpha, min(value, self.max_alpha))
-
-    @torch.no_grad()
-    def step(self, new_model):
-        self.alpha = self.get_decay(self.optimization_step)
-
-        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
-            # Iterate over immediate parameters only.
-            for param, ema_param in zip(
-                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
-            ):
-                if isinstance(param, dict):
-                    raise RuntimeError("Dict parameter not supported")
-                if isinstance(module, _BatchNorm) or not param.requires_grad:
-                    # Copy BatchNorm parameters, and non-trainable parameters directly.
-                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
-                else:
-                    ema_param.mul_(self.alpha)
-                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
-
-        self.optimization_step += 1
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index aa90afdf..a443d9b8 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -91,12 +91,3 @@ policy:
 
   # Inference
   num_inference_steps: 100
-
-  # ---
-  # TODO(alexander-soare): Remove these from the policy config.
-  use_ema: true
-  ema_update_after_step: 0
-  ema_min_alpha: 0.0
-  ema_max_alpha: 0.9999
-  ema_inv_gamma: 1.0
-  ema_power: 0.75
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index e3afac41..e9aa3041 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -121,7 +121,7 @@ def rollout(
     max_steps = env.call("_max_episode_steps")[0]
     progbar = trange(
         max_steps,
-        desc=f"Running rollout with {max_steps} steps (maximum) per rollout",
+        desc=f"Running rollout with at most {max_steps} steps",
         disable=not enable_progbar,
         leave=False,
     )
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index f58dbd06..6cbc8265 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -89,9 +89,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     if lr_scheduler is not None:
         lr_scheduler.step()
 
-    if hasattr(policy, "ema") and policy.ema is not None:
-        policy.ema.step(policy.diffusion)
-
    if isinstance(policy, PolicyWithUpdate):
         # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
         policy.update()
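
Notes (not part of the patch):

For the record, the removed DiffusionEMA warmed its decay factor up as decay(step) = 1 - (1 + step / inv_gamma) ** -power,
clamped to [min_alpha, max_alpha]. A standalone sketch of that schedule, with a hypothetical ema_decay
helper whose defaults are copied from the removed config fields:

    def ema_decay(
        optimization_step: int,
        update_after_step: int = 0,
        inv_gamma: float = 1.0,
        power: float = 0.75,
        min_alpha: float = 0.0,
        max_alpha: float = 0.9999,
    ) -> float:
        # Same logic as the removed DiffusionEMA.get_decay: no averaging until
        # warmup begins, then a polynomial ramp clamped to [min_alpha, max_alpha].
        step = max(0, optimization_step - update_after_step - 1)
        if step <= 0:
            return 0.0
        value = 1 - (1 + step / inv_gamma) ** -power
        return max(min_alpha, min(value, max_alpha))

    # With the defaults this config used (power=0.75), the decay reaches 0.999 at
    # step 10_000, matching @crowsonkb's notes quoted in the deleted docstring.
    print(round(ema_decay(10_000), 6))  # 0.999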
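
With "Move EMA out of policy" resolved by deletion, the remaining direction from the removed TODO comment
("This should probably be managed outside of the policy class") is that, if EMA weights are wanted again,
the training loop should own them rather than the policy. A minimal sketch of that shape using stock
PyTorch utilities (assumes torch >= 2.0 for get_ema_avg_fn; the nn.Linear and toy objective stand in for
DiffusionPolicy and the real update_policy step):

    import torch
    from torch import nn
    from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn

    policy = nn.Linear(4, 2)  # stand-in for DiffusionPolicy
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)

    # The EMA copy lives alongside the optimizer, outside the policy class. A fixed
    # decay is used here; the warmup schedule above could be swapped in via a custom avg_fn.
    ema_policy = AveragedModel(policy, avg_fn=get_ema_avg_fn(decay=0.9999))

    for _ in range(100):
        loss = policy(torch.randn(8, 4)).pow(2).mean()  # toy objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_policy.update_parameters(policy)  # one EMA step per optimizer step

    ema_policy.eval()  # evaluate or checkpoint the averaged weights from here

Keeping the averaging outside the policy is what lets the hasattr(policy, "ema") special case in
update_policy disappear above, while TDMPC-style internal buffers keep using the generic
PolicyWithUpdate hook.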