[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-03-28 17:20:38 +00:00
parent 808cf63221
commit c05e4835d0
16 changed files with 93 additions and 91 deletions

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
-from collections import deque
import gymnasium as gym

View File

@@ -99,36 +99,37 @@ class SGDConfig(OptimizerConfig):
@dataclass
class MultiAdamConfig(OptimizerConfig):
    """Configuration for multiple Adam optimizers with different parameter groups.
    This creates a dictionary of Adam optimizers, each with its own hyperparameters.
    Args:
        lr: Default learning rate (used if not specified for a group)
        weight_decay: Default weight decay (used if not specified for a group)
        optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
        grad_clip_norm: Gradient clipping norm
    """
    lr: float = 1e-3
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0
    optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
    def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
        """Build multiple Adam optimizers.
        Args:
            params_dict: Dictionary mapping parameter group names to lists of parameters
                The keys should match the keys in optimizer_groups
        Returns:
            Dictionary mapping parameter group names to their optimizers
        """
        optimizers = {}
        for name, params in params_dict.items():
            # Get group-specific hyperparameters or use defaults
            group_config = self.optimizer_groups.get(name, {})
            # Create optimizer with merged parameters (defaults + group-specific)
            optimizer_kwargs = {
                "lr": group_config.get("lr", self.lr),
@@ -136,15 +137,17 @@ class MultiAdamConfig(OptimizerConfig):
                "eps": group_config.get("eps", 1e-5),
                "weight_decay": group_config.get("weight_decay", self.weight_decay),
            }
            optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
        return optimizers
-def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
+def save_optimizer_state(
+    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
+) -> None:
    """Save optimizer state to disk.
    Args:
        optimizer: Either a single optimizer or a dictionary of optimizers.
        save_dir: Directory to save the optimizer state.
@@ -173,11 +176,11 @@ def load_optimizer_state(
    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
    """Load optimizer state from disk.
    Args:
        optimizer: Either a single optimizer or a dictionary of optimizers.
        save_dir: Directory to load the optimizer state from.
    Returns:
        The updated optimizer(s) with loaded state.
    """
@@ -201,7 +204,7 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
    current_state_dict = optimizer.state_dict()
    flat_state = load_file(save_dir / OPTIMIZER_STATE)
    state = unflatten_dict(flat_state)
    # Handle case where 'state' key might not exist (for newly created optimizers)
    if "state" in state:
        loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}

View File

@@ -24,11 +24,11 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
-from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType

View File

@@ -1,10 +1,9 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
-from typing import Dict, List
+from typing import List
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, PolicyFeature
@PreTrainedConfig.register_subclass(name="hilserl_classifier")

View File

@@ -82,8 +82,10 @@ def create_stats_buffers(
        if stats and key in stats:
            if norm_mode is NormalizationMode.MEAN_STD:
                if "mean" not in stats[key] or "std" not in stats[key]:
-                    raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
+                    raise ValueError(
+                        f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
+                    )
                if isinstance(stats[key]["mean"], np.ndarray):
                    buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
                    buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
@@ -96,12 +98,16 @@ def create_stats_buffers(
                    buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
                else:
                    type_ = type(stats[key]["mean"])
-                    raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
+                    raise ValueError(
+                        f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
+                    )
            elif norm_mode is NormalizationMode.MIN_MAX:
                if "min" not in stats[key] or "max" not in stats[key]:
-                    raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
+                    raise ValueError(
+                        f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
+                    )
                if isinstance(stats[key]["min"], np.ndarray):
                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
@@ -110,7 +116,9 @@ def create_stats_buffers(
                    buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
                else:
                    type_ = type(stats[key]["min"])
-                    raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
+                    raise ValueError(
+                        f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
+                    )
        stats_buffers[key] = buffer
    return stats_buffers
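
The validation above expects per-key stats entries shaped roughly like the following; the keys and values here are illustrative, not taken from the commit.

import numpy as np

# MEAN_STD-normalized features need "mean" and "std"; MIN_MAX-normalized features
# need "min" and "max". Values must be np.ndarray or torch.Tensor.
stats = {
    "observation.state": {
        "mean": np.zeros(6, dtype=np.float32),
        "std": np.ones(6, dtype=np.float32),
    },
    "action": {
        "min": -np.ones(6, dtype=np.float32),
        "max": np.ones(6, dtype=np.float32),
    },
}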

View File

@@ -19,7 +19,7 @@ from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.configs.types import NormalizationMode
@dataclass

View File

@@ -897,7 +897,6 @@ if __name__ == "__main__":
# for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
-import draccus
from lerobot.configs import parser

View File

@@ -115,11 +115,13 @@ class WandBLogger:
        artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
        self._wandb.log_artifact(artifact)
-    def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
+    def log_dict(
+        self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
+    ):
        if mode not in {"train", "eval"}:
            raise ValueError(mode)
        if step is None and custom_step_key is None:
            raise ValueError("Either step or custom_step_key must be provided.")
        # NOTE: This is not simple. The wandb step must always increase monotonically, and it
        # increases with each wandb.log call, but in the case of asynchronous RL for example,
@@ -142,10 +144,7 @@ class WandBLogger:
                continue
            # Do not log the custom step key itself.
-            if (
-                self._wandb_custom_step_key is not None
-                and k in self._wandb_custom_step_key
-            ):
+            if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
                continue
            if custom_step_key is not None:
@@ -160,7 +159,6 @@ class WandBLogger:
            self._wandb.log(data={f"{mode}/{k}": v}, step=step)
    def log_video(self, video_path: str, step: int, mode: str = "train"):
        if mode not in {"train", "eval"}:
            raise ValueError(mode)
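
The NOTE above refers to wandb's requirement that the step passed to wandb.log never decreases. One common workaround, sketched here with plain wandb calls (the metric names are hypothetical and this is not WandBLogger's actual implementation), is to register a custom step metric and log against it:

import wandb

run = wandb.init(project="hil-serl-demo")  # hypothetical project name

# Declare a custom step axis so asynchronously produced metrics do not have to
# share wandb's global, strictly increasing step counter.
wandb.define_metric("train/interaction_step")
wandb.define_metric("train/*", step_metric="train/interaction_step")

# Metrics logged with the custom step can arrive out of order relative to other
# wandb.log calls without tripping the monotonic-step constraint.
wandb.log({"train/reward": 1.25, "train/interaction_step": 10_000})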

View File

@@ -34,7 +34,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class TrainPipelineConfig(HubMixin):
    dataset: DatasetConfig | None = None  # NOTE: In RL, we don't need a dataset
    env: envs.EnvConfig | None = None
    policy: PreTrainedConfig | None = None
    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.

View File

@@ -47,7 +47,7 @@ from lerobot.scripts.server.buffer import (
    python_object_to_bytes,
    transitions_to_bytes,
)
-from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
+from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.scripts.server.network_utils import (
    receive_bytes_in_chunks,
    send_bytes_in_chunks,
@@ -444,7 +444,7 @@ def receive_policy(
    # Initialize logging with explicit log file
    init_logging(log_file=log_file)
-    logging.info(f"Actor receive policy process logging initialized")
+    logging.info("Actor receive policy process logging initialized")
    # Setup process handlers to handle shutdown signal
    # But use shutdown event from the main process

View File

@@ -515,7 +515,7 @@ class ReplayBuffer:
            frame_dict["action"] = self.actions[actual_idx].cpu()
            frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
            frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
            # Add task field which is required by LeRobotDataset
            frame_dict["task"] = task_name

View File

@@ -701,10 +701,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
if __name__ == "__main__":
+    from lerobot.common.envs.configs import EEActionSpaceConfig, EnvWrapperConfig, HILSerlRobotEnvConfig
    from lerobot.common.robot_devices.robots.configs import RobotConfig
    from lerobot.common.robot_devices.robots.utils import make_robot_from_config
    from lerobot.scripts.server.gym_manipulator import make_robot_env
-    from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig
    parser = argparse.ArgumentParser(description="Test end-effector control")
    parser.add_argument(

View File

@@ -5,14 +5,15 @@ import cv2
import numpy as np
from lerobot.common.robot_devices.control_utils import is_headless
-from lerobot.common.robot_devices.robots.utils import make_robot_from_config
-from lerobot.scripts.server.kinematics import RobotKinematics
-from lerobot.configs import parser
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics
follower_port = "/dev/tty.usbmodem58760431631"
leader_port = "/dev/tty.usbmodem58760433331"
def find_joint_bounds(
    robot,
    control_time_s=30,
@@ -85,21 +86,22 @@ def find_ee_bounds(
def make_robot(robot_type="so100"):
    """
    Create a robot instance using the appropriate robot config class.
    Args:
        robot_type: Robot type string (e.g., "so100", "koch", "aloha")
    Returns:
        Robot instance
    """
    # Get the appropriate robot config class based on robot_type
    robot_config = RobotConfig.get_choice_class(robot_type)(mock=False)
    robot_config.leader_arms["main"].port = leader_port
    robot_config.follower_arms["main"].port = follower_port
    return make_robot_from_config(robot_config)
if __name__ == "__main__":
    # Create argparse for script-specific arguments
    parser = argparse.ArgumentParser(add_help=False)  # Set add_help=False to avoid conflict
@@ -125,14 +127,14 @@ if __name__ == "__main__":
    # Only parse known args, leaving robot config args for Hydra if used
    args = parser.parse_args()
    # Create robot with the appropriate config
    robot = make_robot(args.robot_type)
    if args.mode == "joint":
        find_joint_bounds(robot, args.control_time_s)
    elif args.mode == "ee":
        find_ee_bounds(robot, args.control_time_s)
    if robot.is_connected:
        robot.disconnect()

View File

@@ -1,7 +1,6 @@
import logging
import time
-from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Tuple
+from typing import Any
import einops
import gymnasium as gym
@@ -10,10 +9,8 @@ import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
-from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
+from lerobot.common.envs.configs import ManiskillEnvConfig
from lerobot.configs import parser
-from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
def preprocess_maniskill_observation( def preprocess_maniskill_observation(
@@ -53,9 +50,6 @@ def preprocess_maniskill_observation(
    return return_observations
class ManiSkillObservationWrapper(gym.ObservationWrapper):
    def __init__(self, env, device: torch.device = "cuda"):
        super().__init__(env)
@@ -122,6 +116,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
class BatchCompatibleWrapper(gym.ObservationWrapper):
    """Ensures observations are batch-compatible by adding a batch dimension if necessary."""
    def __init__(self, env):
        super().__init__(env)
@@ -136,6 +131,7 @@ class BatchCompatibleWrapper(gym.ObservationWrapper):
class TimeLimitWrapper(gym.Wrapper):
    """Adds a time limit to the environment based on fps and control_time."""
    def __init__(self, env, control_time_s, fps):
        super().__init__(env)
        self.control_time_s = control_time_s
@@ -146,10 +142,10 @@ class TimeLimitWrapper(gym.Wrapper):
    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1
        if self.current_step >= self.max_episode_steps:
            terminated = True
        return obs, reward, terminated, truncated, info
    def reset(self, *, seed=None, options=None):
@@ -190,18 +186,18 @@ def make_maniskill(
        save_video=True,
        video_fps=30,
    )
    # Add observation and image processing
    env = ManiSkillObservationWrapper(env, device=cfg.device)
    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
    env._max_episode_steps = env.max_episode_steps = cfg.episode_length
    env.unwrapped.metadata["render_fps"] = cfg.fps
    # Add compatibility wrappers
    env = ManiSkillCompat(env)
    env = ManiSkillActionWrapper(env)
    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
    return env
@@ -210,29 +206,29 @@ def main(cfg: ManiskillEnvConfig):
    """Main function to run the ManiSkill environment."""
    # Create the ManiSkill environment
    env = make_maniskill(cfg, n_envs=1)
    # Reset the environment
    obs, info = env.reset()
    # Run a simple interaction loop
    sum_reward = 0
    for i in range(100):
        # Sample a random action
        action = env.action_space.sample()
        # Step the environment
        start_time = time.perf_counter()
        obs, reward, terminated, truncated, info = env.step(action)
        step_time = time.perf_counter() - start_time
        sum_reward += reward
        # Log information
        # Reset if episode terminated
        if terminated or truncated:
            logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
            sum_reward = 0
            obs, info = env.reset()
    # Close the environment
    env.close()
@@ -243,6 +239,10 @@ def main(cfg: ManiskillEnvConfig):
if __name__ == "__main__":
    import draccus
    config = ManiskillEnvConfig()
    draccus.set_config_type("json")
-    draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
+    draccus.dump(
+        config=config,
+        stream=open(file="run_config.json", mode="w"),
+    )
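
The TimeLimitWrapper whose step() appears a few hunks above can be read as a small stand-alone gymnasium wrapper. The sketch below fills in reset() and the max_episode_steps computation, which are not shown in this diff, so treat those parts as assumptions.

import gymnasium as gym


class SimpleTimeLimit(gym.Wrapper):
    """Terminate an episode after control_time_s seconds of interaction at the given fps."""

    def __init__(self, env, control_time_s: float, fps: int):
        super().__init__(env)
        self.control_time_s = control_time_s
        self.fps = fps
        # Assumed: the step budget is control time multiplied by control frequency.
        self.max_episode_steps = int(control_time_s * fps)
        self.current_step = 0

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1
        if self.current_step >= self.max_episode_steps:
            terminated = True
        return obs, reward, terminated, truncated, info

    def reset(self, *, seed=None, options=None):
        # Assumed: the step counter resets with the episode.
        self.current_step = 0
        return self.env.reset(seed=seed, options=options)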

View File

@@ -32,7 +32,6 @@ from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
    format_big_number,

View File

@@ -134,15 +134,15 @@ def test_multi_adam_configuration(base_params_dict, config_params, expected_valu
    # Create config with the given parameters
    config = MultiAdamConfig(**config_params)
    optimizers = config.build(base_params_dict)
    # Verify optimizer count and keys
    assert len(optimizers) == len(expected_values)
    assert set(optimizers.keys()) == set(expected_values.keys())
    # Check that all optimizers are Adam instances
    for opt in optimizers.values():
        assert isinstance(opt, torch.optim.Adam)
    # Verify hyperparameters for each optimizer
    for name, expected in expected_values.items():
        optimizer = optimizers[name]
@@ -166,7 +166,7 @@ def multi_optimizers(base_params_dict):
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
    # Save optimizer states
    save_optimizer_state(multi_optimizers, tmp_path)
    # Verify that directories were created for each optimizer
    for name in multi_optimizers.keys():
        assert (tmp_path / name).is_dir()
@@ -185,10 +185,10 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
        multi_optimizers[name].step()
        # Zero gradients for next steps
        multi_optimizers[name].zero_grad()
    # Save optimizer states
    save_optimizer_state(multi_optimizers, tmp_path)
    # Create new optimizers with the same config
    config = MultiAdamConfig(
        lr=1e-3,
@@ -199,16 +199,13 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
        },
    )
    new_optimizers = config.build(base_params_dict)
    # Load optimizer states
    loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
    # Verify state dictionaries match
    for name in multi_optimizers.keys():
-        torch.testing.assert_close(
-            multi_optimizers[name].state_dict(),
-            loaded_optimizers[name].state_dict()
-        )
+        torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
@@ -223,25 +220,23 @@ def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
        },
    )
    optimizers = config.build(base_params_dict)
    # Save optimizer states without any backward pass (empty state)
    save_optimizer_state(optimizers, tmp_path)
    # Create new optimizers with the same config
    new_optimizers = config.build(base_params_dict)
    # Load optimizer states
    loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
    # Verify hyperparameters match even with empty state
    for name, optimizer in optimizers.items():
        assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
        assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
        assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
        # Verify state dictionaries match (they will be empty)
        torch.testing.assert_close(
-            optimizer.state_dict()["param_groups"],
-            loaded_optimizers[name].state_dict()["param_groups"]
+            optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
        )