[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
808cf63221
commit
c05e4835d0
|
@ -14,7 +14,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from collections import deque
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
|
|
|
@ -99,36 +99,37 @@ class SGDConfig(OptimizerConfig):
|
|||
@dataclass
|
||||
class MultiAdamConfig(OptimizerConfig):
|
||||
"""Configuration for multiple Adam optimizers with different parameter groups.
|
||||
|
||||
|
||||
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
|
||||
|
||||
|
||||
Args:
|
||||
lr: Default learning rate (used if not specified for a group)
|
||||
weight_decay: Default weight decay (used if not specified for a group)
|
||||
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
|
||||
grad_clip_norm: Gradient clipping norm
|
||||
"""
|
||||
|
||||
lr: float = 1e-3
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
|
||||
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Build multiple Adam optimizers.
|
||||
|
||||
|
||||
Args:
|
||||
params_dict: Dictionary mapping parameter group names to lists of parameters
|
||||
The keys should match the keys in optimizer_groups
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter group names to their optimizers
|
||||
"""
|
||||
optimizers = {}
|
||||
|
||||
|
||||
for name, params in params_dict.items():
|
||||
# Get group-specific hyperparameters or use defaults
|
||||
group_config = self.optimizer_groups.get(name, {})
|
||||
|
||||
|
||||
# Create optimizer with merged parameters (defaults + group-specific)
|
||||
optimizer_kwargs = {
|
||||
"lr": group_config.get("lr", self.lr),
|
||||
|
@ -136,15 +137,17 @@ class MultiAdamConfig(OptimizerConfig):
|
|||
"eps": group_config.get("eps", 1e-5),
|
||||
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
||||
}
|
||||
|
||||
|
||||
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
|
||||
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
|
||||
def save_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> None:
|
||||
"""Save optimizer state to disk.
|
||||
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to save the optimizer state.
|
||||
|
@ -173,11 +176,11 @@ def load_optimizer_state(
|
|||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""Load optimizer state from disk.
|
||||
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to load the optimizer state from.
|
||||
|
||||
|
||||
Returns:
|
||||
The updated optimizer(s) with loaded state.
|
||||
"""
|
||||
|
@ -201,7 +204,7 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
|
|||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
|
||||
|
||||
# Handle case where 'state' key might not exist (for newly created optimizers)
|
||||
if "state" in state:
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
|
|
@ -24,11 +24,11 @@ from lerobot.common.envs.configs import EnvConfig
|
|||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="hilserl_classifier")
|
||||
|
|
|
@ -82,8 +82,10 @@ def create_stats_buffers(
|
|||
if stats and key in stats:
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" not in stats[key] or "std" not in stats[key]:
|
||||
raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
|
||||
|
||||
raise ValueError(
|
||||
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
|
||||
)
|
||||
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
|
@ -96,12 +98,16 @@ def create_stats_buffers(
|
|||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
|
||||
|
||||
raise ValueError(
|
||||
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
|
||||
)
|
||||
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" not in stats[key] or "max" not in stats[key]:
|
||||
raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
|
||||
|
||||
raise ValueError(
|
||||
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
|
||||
)
|
||||
|
||||
if isinstance(stats[key]["min"], np.ndarray):
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
|
@ -110,7 +116,9 @@ def create_stats_buffers(
|
|||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["min"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
|
||||
raise ValueError(
|
||||
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
|
||||
)
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
|
|
@ -19,7 +19,7 @@ from dataclasses import dataclass, field
|
|||
|
||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -897,7 +897,6 @@ if __name__ == "__main__":
|
|||
# for j in range(i + 1, num_critics):
|
||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import parser
|
||||
|
||||
|
|
|
@ -115,11 +115,13 @@ class WandBLogger:
|
|||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
|
||||
def log_dict(
|
||||
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
|
||||
):
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
if step is None and custom_step_key is None:
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
|
||||
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
|
||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||
|
@ -142,10 +144,7 @@ class WandBLogger:
|
|||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if (
|
||||
self._wandb_custom_step_key is not None
|
||||
and k in self._wandb_custom_step_key
|
||||
):
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
|
@ -160,7 +159,6 @@ class WandBLogger:
|
|||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
|
|
|
@ -34,7 +34,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
|
|||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
|
|
|
@ -47,7 +47,7 @@ from lerobot.scripts.server.buffer import (
|
|||
python_object_to_bytes,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
from lerobot.scripts.server.network_utils import (
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
|
@ -444,7 +444,7 @@ def receive_policy(
|
|||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file)
|
||||
logging.info(f"Actor receive policy process logging initialized")
|
||||
logging.info("Actor receive policy process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
|
|
|
@ -515,7 +515,7 @@ class ReplayBuffer:
|
|||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
|
||||
|
||||
# Add task field which is required by LeRobotDataset
|
||||
frame_dict["task"] = task_name
|
||||
|
||||
|
|
|
@ -701,10 +701,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvWrapperConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig
|
||||
|
||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||
parser.add_argument(
|
||||
|
|
|
@ -5,14 +5,15 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.control_utils import is_headless
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
from lerobot.configs import parser
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.configs import parser
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
|
||||
follower_port = "/dev/tty.usbmodem58760431631"
|
||||
leader_port = "/dev/tty.usbmodem58760433331"
|
||||
|
||||
|
||||
def find_joint_bounds(
|
||||
robot,
|
||||
control_time_s=30,
|
||||
|
@ -85,21 +86,22 @@ def find_ee_bounds(
|
|||
def make_robot(robot_type="so100"):
|
||||
"""
|
||||
Create a robot instance using the appropriate robot config class.
|
||||
|
||||
|
||||
Args:
|
||||
robot_type: Robot type string (e.g., "so100", "koch", "aloha")
|
||||
|
||||
|
||||
Returns:
|
||||
Robot instance
|
||||
"""
|
||||
|
||||
|
||||
# Get the appropriate robot config class based on robot_type
|
||||
robot_config = RobotConfig.get_choice_class(robot_type)(mock=False)
|
||||
robot_config.leader_arms["main"].port = leader_port
|
||||
robot_config.follower_arms["main"].port = follower_port
|
||||
|
||||
|
||||
return make_robot_from_config(robot_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create argparse for script-specific arguments
|
||||
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict
|
||||
|
@ -125,14 +127,14 @@ if __name__ == "__main__":
|
|||
|
||||
# Only parse known args, leaving robot config args for Hydra if used
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Create robot with the appropriate config
|
||||
robot = make_robot(args.robot_type)
|
||||
|
||||
|
||||
if args.mode == "joint":
|
||||
find_joint_bounds(robot, args.control_time_s)
|
||||
elif args.mode == "ee":
|
||||
find_ee_bounds(robot, args.control_time_s)
|
||||
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
robot.disconnect()
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
|
@ -10,10 +9,8 @@ import torch
|
|||
from mani_skill.utils.wrappers.record import RecordEpisode
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
|
||||
from lerobot.common.envs.configs import ManiskillEnvConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(
|
||||
|
@ -53,9 +50,6 @@ def preprocess_maniskill_observation(
|
|||
return return_observations
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env, device: torch.device = "cuda"):
|
||||
super().__init__(env)
|
||||
|
@ -122,6 +116,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
|||
|
||||
class BatchCompatibleWrapper(gym.ObservationWrapper):
|
||||
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
|
@ -136,6 +131,7 @@ class BatchCompatibleWrapper(gym.ObservationWrapper):
|
|||
|
||||
class TimeLimitWrapper(gym.Wrapper):
|
||||
"""Adds a time limit to the environment based on fps and control_time."""
|
||||
|
||||
def __init__(self, env, control_time_s, fps):
|
||||
super().__init__(env)
|
||||
self.control_time_s = control_time_s
|
||||
|
@ -146,10 +142,10 @@ class TimeLimitWrapper(gym.Wrapper):
|
|||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
self.current_step += 1
|
||||
|
||||
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
terminated = True
|
||||
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
|
@ -190,18 +186,18 @@ def make_maniskill(
|
|||
save_video=True,
|
||||
video_fps=30,
|
||||
)
|
||||
|
||||
|
||||
# Add observation and image processing
|
||||
env = ManiSkillObservationWrapper(env, device=cfg.device)
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
||||
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
|
||||
env.unwrapped.metadata["render_fps"] = cfg.fps
|
||||
|
||||
|
||||
# Add compatibility wrappers
|
||||
env = ManiSkillCompat(env)
|
||||
env = ManiSkillActionWrapper(env)
|
||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
|
||||
|
||||
|
||||
return env
|
||||
|
||||
|
||||
|
@ -210,29 +206,29 @@ def main(cfg: ManiskillEnvConfig):
|
|||
"""Main function to run the ManiSkill environment."""
|
||||
# Create the ManiSkill environment
|
||||
env = make_maniskill(cfg, n_envs=1)
|
||||
|
||||
|
||||
# Reset the environment
|
||||
obs, info = env.reset()
|
||||
|
||||
|
||||
# Run a simple interaction loop
|
||||
sum_reward = 0
|
||||
for i in range(100):
|
||||
# Sample a random action
|
||||
action = env.action_space.sample()
|
||||
|
||||
|
||||
# Step the environment
|
||||
start_time = time.perf_counter()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
step_time = time.perf_counter() - start_time
|
||||
sum_reward += reward
|
||||
# Log information
|
||||
|
||||
|
||||
# Reset if episode terminated
|
||||
if terminated or truncated:
|
||||
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
|
||||
sum_reward = 0
|
||||
obs, info = env.reset()
|
||||
|
||||
|
||||
# Close the environment
|
||||
env.close()
|
||||
|
||||
|
@ -243,6 +239,10 @@ def main(cfg: ManiskillEnvConfig):
|
|||
|
||||
if __name__ == "__main__":
|
||||
import draccus
|
||||
|
||||
config = ManiskillEnvConfig()
|
||||
draccus.set_config_type("json")
|
||||
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
|
||||
draccus.dump(
|
||||
config=config,
|
||||
stream=open(file="run_config.json", mode="w"),
|
||||
)
|
||||
|
|
|
@ -32,7 +32,6 @@ from tqdm import tqdm
|
|||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
|
|
|
@ -134,15 +134,15 @@ def test_multi_adam_configuration(base_params_dict, config_params, expected_valu
|
|||
# Create config with the given parameters
|
||||
config = MultiAdamConfig(**config_params)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
|
||||
# Verify optimizer count and keys
|
||||
assert len(optimizers) == len(expected_values)
|
||||
assert set(optimizers.keys()) == set(expected_values.keys())
|
||||
|
||||
|
||||
# Check that all optimizers are Adam instances
|
||||
for opt in optimizers.values():
|
||||
assert isinstance(opt, torch.optim.Adam)
|
||||
|
||||
|
||||
# Verify hyperparameters for each optimizer
|
||||
for name, expected in expected_values.items():
|
||||
optimizer = optimizers[name]
|
||||
|
@ -166,7 +166,7 @@ def multi_optimizers(base_params_dict):
|
|||
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
|
||||
# Verify that directories were created for each optimizer
|
||||
for name in multi_optimizers.keys():
|
||||
assert (tmp_path / name).is_dir()
|
||||
|
@ -185,10 +185,10 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
|
|||
multi_optimizers[name].step()
|
||||
# Zero gradients for next steps
|
||||
multi_optimizers[name].zero_grad()
|
||||
|
||||
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
|
||||
# Create new optimizers with the same config
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
|
@ -199,16 +199,13 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
|
|||
},
|
||||
)
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
|
||||
# Verify state dictionaries match
|
||||
for name in multi_optimizers.keys():
|
||||
torch.testing.assert_close(
|
||||
multi_optimizers[name].state_dict(),
|
||||
loaded_optimizers[name].state_dict()
|
||||
)
|
||||
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
|
||||
|
||||
|
||||
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
||||
|
@ -223,25 +220,23 @@ def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
|||
},
|
||||
)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
|
||||
# Save optimizer states without any backward pass (empty state)
|
||||
save_optimizer_state(optimizers, tmp_path)
|
||||
|
||||
|
||||
# Create new optimizers with the same config
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
|
||||
# Verify hyperparameters match even with empty state
|
||||
for name, optimizer in optimizers.items():
|
||||
assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
|
||||
assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
|
||||
assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
|
||||
|
||||
|
||||
# Verify state dictionaries match (they will be empty)
|
||||
torch.testing.assert_close(
|
||||
optimizer.state_dict()["param_groups"],
|
||||
loaded_optimizers[name].state_dict()["param_groups"]
|
||||
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue