From eb44a06a9b51e4066f436215456b4109e9a14ce3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 28 Mar 2025 17:20:38 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 lerobot/common/envs/factory.py                      |  1 -
 lerobot/common/optim/optimizers.py                  | 31 ++++++++-------
 lerobot/common/policies/factory.py                  |  2 +-
 .../classifier/configuration_classifier.py          |  5 +--
 lerobot/common/policies/normalize.py                | 22 +++++++----
 lerobot/common/policies/sac/configuration_sac.py    |  2 +-
 lerobot/common/policies/sac/modeling_sac.py         |  1 -
 lerobot/common/utils/wandb_utils.py                 | 12 +++---
 lerobot/configs/train.py                            |  2 +-
 lerobot/scripts/server/actor_server.py              |  4 +-
 lerobot/scripts/server/buffer.py                    |  2 +-
 .../server/end_effector_control_utils.py            |  2 +-
 lerobot/scripts/server/find_joint_limits.py         | 24 ++++++------
 lerobot/scripts/server/maniskill_manipulator.py     | 38 +++++++++----------
 lerobot/scripts/train_hilserl_classifier.py         |  1 -
 tests/optim/test_optimizers.py                      | 35 ++++++++---------
 16 files changed, 93 insertions(+), 91 deletions(-)

diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index e53ad945..8450f84b 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-from collections import deque
 
 import gymnasium as gym
 
diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py
index bfccd19a..8a5b1803 100644
--- a/lerobot/common/optim/optimizers.py
+++ b/lerobot/common/optim/optimizers.py
@@ -99,36 +99,37 @@ class SGDConfig(OptimizerConfig):
 @dataclass
 class MultiAdamConfig(OptimizerConfig):
     """Configuration for multiple Adam optimizers with different parameter groups.
-
+
     This creates a dictionary of Adam optimizers, each with its own hyperparameters.
-
+
     Args:
         lr: Default learning rate (used if not specified for a group)
         weight_decay: Default weight decay (used if not specified for a group)
         optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
         grad_clip_norm: Gradient clipping norm
     """
+
     lr: float = 1e-3
     weight_decay: float = 0.0
     grad_clip_norm: float = 10.0
     optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
-
+
     def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
         """Build multiple Adam optimizers.
-
+
         Args:
             params_dict: Dictionary mapping parameter group names to lists of parameters
                 The keys should match the keys in optimizer_groups
-
+
         Returns:
             Dictionary mapping parameter group names to their optimizers
         """
         optimizers = {}
-
+
         for name, params in params_dict.items():
             # Get group-specific hyperparameters or use defaults
             group_config = self.optimizer_groups.get(name, {})
-
+
             # Create optimizer with merged parameters (defaults + group-specific)
             optimizer_kwargs = {
                 "lr": group_config.get("lr", self.lr),
@@ -136,15 +137,17 @@ class MultiAdamConfig(OptimizerConfig):
                 "eps": group_config.get("eps", 1e-5),
                 "weight_decay": group_config.get("weight_decay", self.weight_decay),
             }
-
+
             optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
-
+
         return optimizers
 
 
-def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
+def save_optimizer_state(
+    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
+) -> None:
     """Save optimizer state to disk.
-
+
     Args:
         optimizer: Either a single optimizer or a dictionary of optimizers.
         save_dir: Directory to save the optimizer state.
@@ -173,11 +176,11 @@ def load_optimizer_state(
     optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
 ) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
     """Load optimizer state from disk.
-
+
     Args:
         optimizer: Either a single optimizer or a dictionary of optimizers.
         save_dir: Directory to load the optimizer state from.
-
+
     Returns:
         The updated optimizer(s) with loaded state.
     """
@@ -201,7 +204,7 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
     current_state_dict = optimizer.state_dict()
     flat_state = load_file(save_dir / OPTIMIZER_STATE)
     state = unflatten_dict(flat_state)
-
+
     # Handle case where 'state' key might not exist (for newly created optimizers)
     if "state" in state:
         loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
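For orientation on the MultiAdamConfig hunks above, here is a minimal usage sketch of build(); the group names mirror the tests later in this patch, while the Linear layers and hyperparameter values are illustrative stand-ins:

    import torch

    from lerobot.common.optim.optimizers import MultiAdamConfig

    # Illustrative parameter groups; the dict keys passed to build() must
    # match the keys in optimizer_groups.
    actor = torch.nn.Linear(4, 2)
    critic = torch.nn.Linear(4, 1)

    config = MultiAdamConfig(
        lr=1e-3,  # default, used when a group does not override it
        optimizer_groups={
            "actor": {"lr": 3e-4},
            "critic": {"lr": 1e-3, "weight_decay": 1e-5},
        },
    )
    optimizers = config.build(
        {"actor": list(actor.parameters()), "critic": list(critic.parameters())}
    )
    assert optimizers["actor"].defaults["lr"] == 3e-4  # group override wins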
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 6ac082d7..64f83e73 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -24,12 +24,12 @@
 from lerobot.common.envs.configs import EnvConfig
 from lerobot.common.envs.utils import env_to_policy_features
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
 from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
-from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import FeatureType
diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py
index 00688931..d04c189b 100644
--- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py
+++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py
@@ -1,10 +1,9 @@
-from dataclasses import dataclass, field
-from typing import Dict, List
+from dataclasses import dataclass
+from typing import List
 
 from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
 from lerobot.common.optim.schedulers import LRSchedulerConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, PolicyFeature
 
 
 @PreTrainedConfig.register_subclass(name="hilserl_classifier")
diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py
index 433387f9..5e8af4a9 100644
--- a/lerobot/common/policies/normalize.py
+++ b/lerobot/common/policies/normalize.py
@@ -82,8 +82,10 @@ def create_stats_buffers(
         if stats and key in stats:
             if norm_mode is NormalizationMode.MEAN_STD:
                 if "mean" not in stats[key] or "std" not in stats[key]:
-                    raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
-
+                    raise ValueError(
+                        f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
+                    )
+
                 if isinstance(stats[key]["mean"], np.ndarray):
                     buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
                     buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
@@ -96,12 +98,16 @@ def create_stats_buffers(
                     buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
                 else:
                     type_ = type(stats[key]["mean"])
-                    raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
-
+                    raise ValueError(
+                        f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
+                    )
+
             elif norm_mode is NormalizationMode.MIN_MAX:
                 if "min" not in stats[key] or "max" not in stats[key]:
-                    raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
-
+                    raise ValueError(
+                        f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
+                    )
+
                 if isinstance(stats[key]["min"], np.ndarray):
                     buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
                     buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
@@ -110,7 +116,9 @@ def create_stats_buffers(
                     buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
                 else:
                     type_ = type(stats[key]["min"])
-                    raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
+                    raise ValueError(
+                        f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
+                    )
 
         stats_buffers[key] = buffer
     return stats_buffers
diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 0d2c3765..906a3bed 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -19,7 +19,7 @@ from dataclasses import dataclass, field
 
 from lerobot.common.optim.optimizers import MultiAdamConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.configs.types import NormalizationMode
 
 
 @dataclass
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 385efacc..6827da6a 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -897,7 +897,6 @@ if __name__ == "__main__":
     #     for j in range(i + 1, num_critics):
     #         diff = torch.abs(q_values[i] - q_values[j]).mean().item()
     #         print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
-
     import draccus
 
     from lerobot.configs import parser
diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py
index b62ff140..db8911d5 100644
--- a/lerobot/common/utils/wandb_utils.py
+++ b/lerobot/common/utils/wandb_utils.py
@@ -115,11 +115,13 @@ class WandBLogger:
             artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
         self._wandb.log_artifact(artifact)
 
-    def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
+    def log_dict(
+        self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
+    ):
         if mode not in {"train", "eval"}:
             raise ValueError(mode)
         if step is None and custom_step_key is None:
-            raise ValueError("Either step or custom_step_key must be provided.")
+            raise ValueError("Either step or custom_step_key must be provided.")
 
         # NOTE: This is not simple. The wandb step must always monotonically increase, and it
         # increases with each wandb.log call, but in the case of asynchronous RL for example,
@@ -142,10 +144,7 @@ class WandBLogger:
                 continue
 
             # Do not log the custom step key itself.
-            if (
-                self._wandb_custom_step_key is not None
-                and k in self._wandb_custom_step_key
-            ):
+            if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
                 continue
 
             if custom_step_key is not None:
@@ -160,7 +159,6 @@ class WandBLogger:
 
             self._wandb.log(data={f"{mode}/{k}": v}, step=step)
 
-
     def log_video(self, video_path: str, step: int, mode: str = "train"):
         if mode not in {"train", "eval"}:
             raise ValueError(mode)
diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py
index 02a9edd6..7de0f879 100644
--- a/lerobot/configs/train.py
+++ b/lerobot/configs/train.py
@@ -34,7 +34,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
 
 @dataclass
 class TrainPipelineConfig(HubMixin):
-    dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
+    dataset: DatasetConfig | None = None  # NOTE: In RL, we don't need a dataset
     env: envs.EnvConfig | None = None
     policy: PreTrainedConfig | None = None
     # Set `dir` to where you would like to save all of the run outputs. If you run another training session
     # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
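A quick usage note on the reformatted WandBLogger.log_dict signature above, as a hedged sketch; constructing the logger itself is outside this patch, so `logger` is assumed to be an already-built WandBLogger:

    # Either `step` or `custom_step_key` must be provided, otherwise
    # log_dict raises ValueError.
    logger.log_dict({"loss": 0.42, "grad_norm": 1.3}, step=100, mode="train")

    # With a custom step key, e.g. in asynchronous RL where a single
    # monotonically increasing wandb step cannot be shared across processes:
    logger.log_dict(
        {"interaction_step": 5000, "reward": 1.0},
        custom_step_key="interaction_step",
        mode="train",
    )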
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index eda6d314..b388f62e 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -47,7 +47,7 @@ from lerobot.scripts.server.buffer import (
     python_object_to_bytes,
     transitions_to_bytes,
 )
-from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
+from lerobot.scripts.server.gym_manipulator import make_robot_env
 from lerobot.scripts.server.network_utils import (
     receive_bytes_in_chunks,
     send_bytes_in_chunks,
@@ -444,7 +444,7 @@ def receive_policy(
 
     # Initialize logging with explicit log file
     init_logging(log_file=log_file)
-    logging.info(f"Actor receive policy process logging initialized")
+    logging.info("Actor receive policy process logging initialized")
 
     # Setup process handlers to handle shutdown signal
     # But use shutdown event from the main process
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 0e25253f..20787568 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -515,7 +515,7 @@ class ReplayBuffer:
             frame_dict["action"] = self.actions[actual_idx].cpu()
             frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
             frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
-
+
             # Add task field which is required by LeRobotDataset
             frame_dict["task"] = task_name
 
diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py
index 5173809d..3bd927b4 100644
--- a/lerobot/scripts/server/end_effector_control_utils.py
+++ b/lerobot/scripts/server/end_effector_control_utils.py
@@ -701,10 +701,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
 
 
 if __name__ == "__main__":
+    from lerobot.common.envs.configs import EEActionSpaceConfig, EnvWrapperConfig, HILSerlRobotEnvConfig
     from lerobot.common.robot_devices.robots.configs import RobotConfig
     from lerobot.common.robot_devices.robots.utils import make_robot_from_config
     from lerobot.scripts.server.gym_manipulator import make_robot_env
-    from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig
 
     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(
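The buffer.py hunk above sits inside the code that converts replay-buffer entries into LeRobotDataset frames; as a sketch, one frame looks roughly like this, with placeholder tensor values (the 6-dim action and the task string are illustrative, the field names come from the hunk):

    import torch

    frame_dict = {
        "action": torch.zeros(6).cpu(),
        "next.reward": torch.tensor([0.0], dtype=torch.float32).cpu(),
        "next.done": torch.tensor([False], dtype=torch.bool).cpu(),
        # Task field is required by LeRobotDataset.
        "task": "pick_cube",
    }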
diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py
index f34c8f9f..d40caf5a 100644
--- a/lerobot/scripts/server/find_joint_limits.py
+++ b/lerobot/scripts/server/find_joint_limits.py
@@ -5,14 +5,15 @@
 import cv2
 import numpy as np
 
 from lerobot.common.robot_devices.control_utils import is_headless
-from lerobot.common.robot_devices.robots.utils import make_robot_from_config
-from lerobot.scripts.server.kinematics import RobotKinematics
-from lerobot.configs import parser
 from lerobot.common.robot_devices.robots.configs import RobotConfig
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+from lerobot.configs import parser
+from lerobot.scripts.server.kinematics import RobotKinematics
 
 follower_port = "/dev/tty.usbmodem58760431631"
 leader_port = "/dev/tty.usbmodem58760433331"
 
+
 def find_joint_bounds(
     robot,
     control_time_s=30,
@@ -85,21 +86,22 @@ def find_ee_bounds(
 def make_robot(robot_type="so100"):
     """
     Create a robot instance using the appropriate robot config class.
-
+
     Args:
         robot_type: Robot type string (e.g., "so100", "koch", "aloha")
-
+
     Returns:
         Robot instance
     """
-
+
     # Get the appropriate robot config class based on robot_type
     robot_config = RobotConfig.get_choice_class(robot_type)(mock=False)
     robot_config.leader_arms["main"].port = leader_port
     robot_config.follower_arms["main"].port = follower_port
-
+
     return make_robot_from_config(robot_config)
 
+
 if __name__ == "__main__":
     # Create argparse for script-specific arguments
     parser = argparse.ArgumentParser(add_help=False)  # Set add_help=False to avoid conflict
@@ -125,14 +127,14 @@ if __name__ == "__main__":
 
     # Only parse known args, leaving robot config args for Hydra if used
     args = parser.parse_args()
-
+
     # Create robot with the appropriate config
     robot = make_robot(args.robot_type)
-
+
     if args.mode == "joint":
         find_joint_bounds(robot, args.control_time_s)
     elif args.mode == "ee":
         find_ee_bounds(robot, args.control_time_s)
-
+
     if robot.is_connected:
-        robot.disconnect()
\ No newline at end of file
+        robot.disconnect()
diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py
index 2ad7c661..e10b8766 100644
--- a/lerobot/scripts/server/maniskill_manipulator.py
+++ b/lerobot/scripts/server/maniskill_manipulator.py
@@ -1,7 +1,6 @@
 import logging
 import time
-from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Tuple
+from typing import Any
 
 import einops
 import gymnasium as gym
@@ -10,10 +9,8 @@ import torch
 from mani_skill.utils.wrappers.record import RecordEpisode
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
 
-from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
+from lerobot.common.envs.configs import ManiskillEnvConfig
 from lerobot.configs import parser
-from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
 
 
 def preprocess_maniskill_observation(
@@ -53,9 +50,6 @@ def preprocess_maniskill_observation(
     return return_observations
 
 
-
-
-
 class ManiSkillObservationWrapper(gym.ObservationWrapper):
     def __init__(self, env, device: torch.device = "cuda"):
         super().__init__(env)
@@ -122,6 +116,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
 
 class BatchCompatibleWrapper(gym.ObservationWrapper):
     """Ensures observations are batch-compatible by adding a batch dimension if necessary."""
+
     def __init__(self, env):
         super().__init__(env)
 
@@ -136,6 +131,7 @@ class BatchCompatibleWrapper(gym.ObservationWrapper):
 
 class TimeLimitWrapper(gym.Wrapper):
     """Adds a time limit to the environment based on fps and control_time."""
+
     def __init__(self, env, control_time_s, fps):
         super().__init__(env)
         self.control_time_s = control_time_s
@@ -146,10 +142,10 @@ class TimeLimitWrapper(gym.Wrapper):
     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
         self.current_step += 1
-
+
         if self.current_step >= self.max_episode_steps:
             terminated = True
-
+
         return obs, reward, terminated, truncated, info
 
     def reset(self, *, seed=None, options=None):
@@ -190,18 +186,18 @@ def make_maniskill(
         save_video=True,
         video_fps=30,
     )
-
+
     # Add observation and image processing
     env = ManiSkillObservationWrapper(env, device=cfg.device)
     env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
     env._max_episode_steps = env.max_episode_steps = cfg.episode_length
     env.unwrapped.metadata["render_fps"] = cfg.fps
-
+
     # Add compatibility wrappers
     env = ManiSkillCompat(env)
     env = ManiSkillActionWrapper(env)
     env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
-
+
     return env
 
 
@@ -210,29 +206,29 @@ def main(cfg: ManiskillEnvConfig):
     """Main function to run the ManiSkill environment."""
     # Create the ManiSkill environment
     env = make_maniskill(cfg, n_envs=1)
-
+
     # Reset the environment
     obs, info = env.reset()
-
+
     # Run a simple interaction loop
     sum_reward = 0
     for i in range(100):
         # Sample a random action
         action = env.action_space.sample()
-
+
         # Step the environment
         start_time = time.perf_counter()
         obs, reward, terminated, truncated, info = env.step(action)
         step_time = time.perf_counter() - start_time
         sum_reward += reward
         # Log information
-
+
         # Reset if episode terminated
         if terminated or truncated:
             logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
             sum_reward = 0
             obs, info = env.reset()
-
+
     # Close the environment
     env.close()
@@ -243,6 +239,10 @@ def main(cfg: ManiskillEnvConfig):
 
 if __name__ == "__main__":
     import draccus
+
     config = ManiskillEnvConfig()
     draccus.set_config_type("json")
-    draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
\ No newline at end of file
+    draccus.dump(
+        config=config,
+        stream=open(file="run_config.json", mode="w"),
+    )
diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py
index a69b2b3c..d2927dbe 100644
--- a/lerobot/scripts/train_hilserl_classifier.py
+++ b/lerobot/scripts/train_hilserl_classifier.py
@@ -32,7 +32,6 @@ from tqdm import tqdm
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
-
 from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
 from lerobot.common.utils.utils import (
     format_big_number,
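To see the TimeLimitWrapper behavior reformatted above in isolation: it forces `terminated` once a step budget derived from fps and control_time_s is spent (presumably fps * control_time_s, as the initializer body is only partly visible in this hunk). A sketch against a stock gymnasium env; CartPole is an illustrative stand-in for a ManiSkill env:

    import gymnasium as gym

    from lerobot.scripts.server.maniskill_manipulator import TimeLimitWrapper

    env = TimeLimitWrapper(gym.make("CartPole-v1"), control_time_s=10, fps=30)
    obs, info = env.reset()
    steps, done = 0, False
    while not done:
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        done = terminated or truncated
        steps += 1
    assert steps <= 10 * 30  # terminated is forced at the step budget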
diff --git a/tests/optim/test_optimizers.py b/tests/optim/test_optimizers.py
index 8f431ce5..ee2fe80e 100644
--- a/tests/optim/test_optimizers.py
+++ b/tests/optim/test_optimizers.py
@@ -134,15 +134,15 @@ def test_multi_adam_configuration(base_params_dict, config_params, expected_valu
     # Create config with the given parameters
     config = MultiAdamConfig(**config_params)
     optimizers = config.build(base_params_dict)
-
+
     # Verify optimizer count and keys
     assert len(optimizers) == len(expected_values)
     assert set(optimizers.keys()) == set(expected_values.keys())
-
+
     # Check that all optimizers are Adam instances
     for opt in optimizers.values():
         assert isinstance(opt, torch.optim.Adam)
-
+
     # Verify hyperparameters for each optimizer
     for name, expected in expected_values.items():
         optimizer = optimizers[name]
@@ -166,7 +166,7 @@ def multi_optimizers(base_params_dict):
 def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
     # Save optimizer states
     save_optimizer_state(multi_optimizers, tmp_path)
-
+
     # Verify that directories were created for each optimizer
     for name in multi_optimizers.keys():
         assert (tmp_path / name).is_dir()
@@ -185,10 +185,10 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
             multi_optimizers[name].step()
             # Zero gradients for next steps
             multi_optimizers[name].zero_grad()
-
+
     # Save optimizer states
     save_optimizer_state(multi_optimizers, tmp_path)
-
+
     # Create new optimizers with the same config
     config = MultiAdamConfig(
         lr=1e-3,
@@ -199,16 +199,13 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
         },
     )
     new_optimizers = config.build(base_params_dict)
-
+
     # Load optimizer states
     loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
-
+
     # Verify state dictionaries match
     for name in multi_optimizers.keys():
-        torch.testing.assert_close(
-            multi_optimizers[name].state_dict(),
-            loaded_optimizers[name].state_dict()
-        )
+        torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
 
 
 def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
@@ -223,25 +220,23 @@ def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
         },
     )
     optimizers = config.build(base_params_dict)
-
+
     # Save optimizer states without any backward pass (empty state)
     save_optimizer_state(optimizers, tmp_path)
-
+
     # Create new optimizers with the same config
     new_optimizers = config.build(base_params_dict)
-
+
     # Load optimizer states
     loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
-
+
     # Verify hyperparameters match even with empty state
     for name, optimizer in optimizers.items():
         assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
         assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
         assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
-
+
     # Verify state dictionaries match (they will be empty)
     torch.testing.assert_close(
-        optimizer.state_dict()["param_groups"],
-        loaded_optimizers[name].state_dict()["param_groups"]
+        optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
     )
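The tests above boil down to a save/load round trip for the per-group optimizer state; condensed into a standalone sketch (the parameter names and save directory are illustrative):

    from pathlib import Path

    import torch

    from lerobot.common.optim.optimizers import (
        MultiAdamConfig,
        load_optimizer_state,
        save_optimizer_state,
    )

    params = {
        "actor": [torch.nn.Parameter(torch.randn(4))],
        "critic": [torch.nn.Parameter(torch.randn(4))],
    }
    config = MultiAdamConfig(optimizer_groups={"actor": {"lr": 1e-4}, "critic": {"lr": 1e-3}})
    optimizers = config.build(params)

    # Take one step so the Adam state (exp_avg buffers, step counters) is non-empty.
    for name, opt in optimizers.items():
        params[name][0].grad = torch.ones_like(params[name][0])
        opt.step()
        opt.zero_grad()

    save_dir = Path("/tmp/optim_state")  # illustrative location
    save_dir.mkdir(parents=True, exist_ok=True)
    save_optimizer_state(optimizers, save_dir)
    restored = load_optimizer_state(config.build(params), save_dir)
    for name in optimizers:
        torch.testing.assert_close(optimizers[name].state_dict(), restored[name].state_dict())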