Remove EMA model from Diffusion Policy (#134)
This commit is contained in:
parent
d747195c57
commit
f3bba0270d
|
@ -118,15 +118,6 @@ class DiffusionConfig:
|
||||||
# Inference
|
# Inference
|
||||||
num_inference_steps: int | None = None
|
num_inference_steps: int | None = None
|
||||||
|
|
||||||
# ---
|
|
||||||
# TODO(alexander-soare): Remove these from the policy config.
|
|
||||||
use_ema: bool = True
|
|
||||||
ema_update_after_step: int = 0
|
|
||||||
ema_min_alpha: float = 0.0
|
|
||||||
ema_max_alpha: float = 0.9999
|
|
||||||
ema_inv_gamma: float = 1.0
|
|
||||||
ema_power: float = 0.75
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Input validation (not exhaustive)."""
|
"""Input validation (not exhaustive)."""
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
|
|
|
@ -3,12 +3,8 @@
|
||||||
TODO(alexander-soare):
|
TODO(alexander-soare):
|
||||||
- Remove reliance on Robomimic for SpatialSoftmax.
|
- Remove reliance on Robomimic for SpatialSoftmax.
|
||||||
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||||
- Move EMA out of policy.
|
|
||||||
- Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy.
|
|
||||||
- One more pass on comments and documentation.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
@ -21,7 +17,6 @@ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
from robomimic.models.base_nets import SpatialSoftmax
|
from robomimic.models.base_nets import SpatialSoftmax
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
|
||||||
|
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
|
@ -71,13 +66,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
|
|
||||||
self.diffusion = DiffusionModel(config)
|
self.diffusion = DiffusionModel(config)
|
||||||
|
|
||||||
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
|
||||||
self.ema_diffusion = None
|
|
||||||
self.ema = None
|
|
||||||
if self.config.use_ema:
|
|
||||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
|
||||||
self.ema = DiffusionEMA(config, model=self.ema_diffusion)
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
Clear observation and action queues. Should be called on `env.reset()`
|
Clear observation and action queues. Should be called on `env.reset()`
|
||||||
|
@ -109,9 +97,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
|
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
|
||||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||||
|
|
||||||
Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
|
|
||||||
weights.
|
|
||||||
"""
|
"""
|
||||||
assert "observation.image" in batch
|
assert "observation.image" in batch
|
||||||
assert "observation.state" in batch
|
assert "observation.state" in batch
|
||||||
|
@ -123,10 +108,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
if len(self._queues["action"]) == 0:
|
if len(self._queues["action"]) == 0:
|
||||||
# stack n latest observations from the queue
|
# stack n latest observations from the queue
|
||||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||||
if not self.training and self.ema_diffusion is not None:
|
actions = self.diffusion.generate_actions(batch)
|
||||||
actions = self.ema_diffusion.generate_actions(batch)
|
|
||||||
else:
|
|
||||||
actions = self.diffusion.generate_actions(batch)
|
|
||||||
|
|
||||||
# TODO(rcadene): make above methods return output dictionary?
|
# TODO(rcadene): make above methods return output dictionary?
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
@ -612,67 +594,3 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||||
out = self.conv2(out)
|
out = self.conv2(out)
|
||||||
out = out + self.residual_conv(x)
|
out = out + self.residual_conv(x)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DiffusionEMA:
|
|
||||||
"""
|
|
||||||
Exponential Moving Average of models weights
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: DiffusionConfig, model: nn.Module):
|
|
||||||
"""
|
|
||||||
@crowsonkb's notes on EMA Warmup:
|
|
||||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models
|
|
||||||
you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999
|
|
||||||
at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999
|
|
||||||
at 10K steps, 0.9999 at 215.4k steps).
|
|
||||||
Args:
|
|
||||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
|
||||||
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
|
||||||
min_alpha (float): The minimum EMA decay rate. Default: 0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.averaged_model = model
|
|
||||||
self.averaged_model.eval()
|
|
||||||
self.averaged_model.requires_grad_(False)
|
|
||||||
|
|
||||||
self.update_after_step = config.ema_update_after_step
|
|
||||||
self.inv_gamma = config.ema_inv_gamma
|
|
||||||
self.power = config.ema_power
|
|
||||||
self.min_alpha = config.ema_min_alpha
|
|
||||||
self.max_alpha = config.ema_max_alpha
|
|
||||||
|
|
||||||
self.alpha = 0.0
|
|
||||||
self.optimization_step = 0
|
|
||||||
|
|
||||||
def get_decay(self, optimization_step):
|
|
||||||
"""
|
|
||||||
Compute the decay factor for the exponential moving average.
|
|
||||||
"""
|
|
||||||
step = max(0, optimization_step - self.update_after_step - 1)
|
|
||||||
value = 1 - (1 + step / self.inv_gamma) ** -self.power
|
|
||||||
|
|
||||||
if step <= 0:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
return max(self.min_alpha, min(value, self.max_alpha))
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def step(self, new_model):
|
|
||||||
self.alpha = self.get_decay(self.optimization_step)
|
|
||||||
|
|
||||||
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
|
|
||||||
# Iterate over immediate parameters only.
|
|
||||||
for param, ema_param in zip(
|
|
||||||
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
|
|
||||||
):
|
|
||||||
if isinstance(param, dict):
|
|
||||||
raise RuntimeError("Dict parameter not supported")
|
|
||||||
if isinstance(module, _BatchNorm) or not param.requires_grad:
|
|
||||||
# Copy BatchNorm parameters, and non-trainable parameters directly.
|
|
||||||
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
|
||||||
else:
|
|
||||||
ema_param.mul_(self.alpha)
|
|
||||||
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
|
|
||||||
|
|
||||||
self.optimization_step += 1
|
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
# @package _global_
|
# @package _global_
|
||||||
|
|
||||||
|
# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
|
||||||
|
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
|
||||||
|
# https://github.com/huggingface/lerobot/pull/134 for more details.
|
||||||
|
|
||||||
seed: 100000
|
seed: 100000
|
||||||
dataset_repo_id: lerobot/pusht
|
dataset_repo_id: lerobot/pusht
|
||||||
|
|
||||||
|
@ -91,12 +95,3 @@ policy:
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
num_inference_steps: 100
|
num_inference_steps: 100
|
||||||
|
|
||||||
# ---
|
|
||||||
# TODO(alexander-soare): Remove these from the policy config.
|
|
||||||
use_ema: true
|
|
||||||
ema_update_after_step: 0
|
|
||||||
ema_min_alpha: 0.0
|
|
||||||
ema_max_alpha: 0.9999
|
|
||||||
ema_inv_gamma: 1.0
|
|
||||||
ema_power: 0.75
|
|
||||||
|
|
|
@ -121,7 +121,7 @@ def rollout(
|
||||||
max_steps = env.call("_max_episode_steps")[0]
|
max_steps = env.call("_max_episode_steps")[0]
|
||||||
progbar = trange(
|
progbar = trange(
|
||||||
max_steps,
|
max_steps,
|
||||||
desc=f"Running rollout with {max_steps} steps (maximum) per rollout",
|
desc=f"Running rollout with at most {max_steps} steps",
|
||||||
disable=not enable_progbar,
|
disable=not enable_progbar,
|
||||||
leave=False,
|
leave=False,
|
||||||
)
|
)
|
||||||
|
|
|
@ -89,9 +89,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
if hasattr(policy, "ema") and policy.ema is not None:
|
|
||||||
policy.ema.step(policy.diffusion)
|
|
||||||
|
|
||||||
if isinstance(policy, PolicyWithUpdate):
|
if isinstance(policy, PolicyWithUpdate):
|
||||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||||
policy.update()
|
policy.update()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "absl-py"
|
name = "absl-py"
|
||||||
|
@ -2407,7 +2407,6 @@ optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
files = [
|
files = [
|
||||||
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
|
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
|
||||||
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
|
|
||||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
|
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
|
||||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
|
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
|
||||||
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
|
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
|
||||||
|
@ -2428,7 +2427,6 @@ files = [
|
||||||
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
|
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
|
||||||
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
|
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
|
||||||
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
|
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
|
||||||
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
|
|
||||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
|
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
|
||||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
|
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
|
||||||
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
|
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -88,14 +88,8 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env_policies = [
|
# Instructions: include the policies that you want to save artifacts for here. Please make sure to revert
|
||||||
# ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
|
# your changes when you are done.
|
||||||
(
|
env_policies = []
|
||||||
"pusht",
|
|
||||||
"diffusion",
|
|
||||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
|
||||||
),
|
|
||||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
|
||||||
]
|
|
||||||
for env, policy, extra_overrides in env_policies:
|
for env, policy, extra_overrides in env_policies:
|
||||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
||||||
|
|
|
@ -249,6 +249,17 @@ def test_normalize(insert_temporal_dim):
|
||||||
# pass if it's run on another platform due to floating point errors
|
# pass if it's run on another platform due to floating point errors
|
||||||
@require_x86_64_kernel
|
@require_x86_64_kernel
|
||||||
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||||
|
"""
|
||||||
|
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
||||||
|
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
||||||
|
include a report on what changed and how that affected the outputs.
|
||||||
|
2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and
|
||||||
|
add the policies you want to update the test artifacts for.
|
||||||
|
3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
|
||||||
|
4. Check that this test now passes.
|
||||||
|
5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state.
|
||||||
|
6. Remember to stage and commit the resulting changes to `tests/data`.
|
||||||
|
"""
|
||||||
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
||||||
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
||||||
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
|
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
|
||||||
|
|
Loading…
Reference in New Issue