[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-03-28 17:20:38 +00:00
parent 808cf63221
commit c05e4835d0
16 changed files with 93 additions and 91 deletions

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib import importlib
from collections import deque
import gymnasium as gym import gymnasium as gym

View File

@ -108,6 +108,7 @@ class MultiAdamConfig(OptimizerConfig):
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
grad_clip_norm: Gradient clipping norm grad_clip_norm: Gradient clipping norm
""" """
lr: float = 1e-3 lr: float = 1e-3
weight_decay: float = 0.0 weight_decay: float = 0.0
grad_clip_norm: float = 10.0 grad_clip_norm: float = 10.0
@ -142,7 +143,9 @@ class MultiAdamConfig(OptimizerConfig):
return optimizers return optimizers
def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None: def save_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> None:
"""Save optimizer state to disk. """Save optimizer state to disk.
Args: Args:

View File

@ -24,11 +24,11 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType from lerobot.configs.types import FeatureType

View File

@ -1,10 +1,9 @@
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Dict, List from typing import List
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, PolicyFeature
@PreTrainedConfig.register_subclass(name="hilserl_classifier") @PreTrainedConfig.register_subclass(name="hilserl_classifier")

View File

@ -82,7 +82,9 @@ def create_stats_buffers(
if stats and key in stats: if stats and key in stats:
if norm_mode is NormalizationMode.MEAN_STD: if norm_mode is NormalizationMode.MEAN_STD:
if "mean" not in stats[key] or "std" not in stats[key]: if "mean" not in stats[key] or "std" not in stats[key]:
raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization") raise ValueError(
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
)
if isinstance(stats[key]["mean"], np.ndarray): if isinstance(stats[key]["mean"], np.ndarray):
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
@ -96,11 +98,15 @@ def create_stats_buffers(
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
else: else:
type_ = type(stats[key]["mean"]) type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.") raise ValueError(
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
)
elif norm_mode is NormalizationMode.MIN_MAX: elif norm_mode is NormalizationMode.MIN_MAX:
if "min" not in stats[key] or "max" not in stats[key]: if "min" not in stats[key] or "max" not in stats[key]:
raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization") raise ValueError(
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
)
if isinstance(stats[key]["min"], np.ndarray): if isinstance(stats[key]["min"], np.ndarray):
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
@ -110,7 +116,9 @@ def create_stats_buffers(
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else: else:
type_ = type(stats[key]["min"]) type_ = type(stats[key]["min"])
raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.") raise ValueError(
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
)
stats_buffers[key] = buffer stats_buffers[key] = buffer
return stats_buffers return stats_buffers

View File

@ -19,7 +19,7 @@ from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import MultiAdamConfig from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import NormalizationMode
@dataclass @dataclass

View File

@ -897,7 +897,6 @@ if __name__ == "__main__":
# for j in range(i + 1, num_critics): # for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item() # diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}") # print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
import draccus
from lerobot.configs import parser from lerobot.configs import parser

View File

@ -115,11 +115,13 @@ class WandBLogger:
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact) self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None): def log_dict(
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
):
if mode not in {"train", "eval"}: if mode not in {"train", "eval"}:
raise ValueError(mode) raise ValueError(mode)
if step is None and custom_step_key is None: if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.") raise ValueError("Either step or custom_step_key must be provided.")
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it # NOTE: This is not simple. Wandb step is it must always monotonically increase and it
# increases with each wandb.log call, but in the case of asynchronous RL for example, # increases with each wandb.log call, but in the case of asynchronous RL for example,
@ -142,10 +144,7 @@ class WandBLogger:
continue continue
# Do not log the custom step key itself. # Do not log the custom step key itself.
if ( if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
self._wandb_custom_step_key is not None
and k in self._wandb_custom_step_key
):
continue continue
if custom_step_key is not None: if custom_step_key is not None:
@ -160,7 +159,6 @@ class WandBLogger:
self._wandb.log(data={f"{mode}/{k}": v}, step=step) self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"): def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}: if mode not in {"train", "eval"}:
raise ValueError(mode) raise ValueError(mode)

View File

@ -34,7 +34,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
@dataclass @dataclass
class TrainPipelineConfig(HubMixin): class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
env: envs.EnvConfig | None = None env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None policy: PreTrainedConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true. # Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.

View File

@ -47,7 +47,7 @@ from lerobot.scripts.server.buffer import (
python_object_to_bytes, python_object_to_bytes,
transitions_to_bytes, transitions_to_bytes,
) )
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.scripts.server.network_utils import ( from lerobot.scripts.server.network_utils import (
receive_bytes_in_chunks, receive_bytes_in_chunks,
send_bytes_in_chunks, send_bytes_in_chunks,
@ -444,7 +444,7 @@ def receive_policy(
# Initialize logging with explicit log file # Initialize logging with explicit log file
init_logging(log_file=log_file) init_logging(log_file=log_file)
logging.info(f"Actor receive policy process logging initialized") logging.info("Actor receive policy process logging initialized")
# Setup process handlers to handle shutdown signal # Setup process handlers to handle shutdown signal
# But use shutdown event from the main process # But use shutdown event from the main process

View File

@ -701,10 +701,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
if __name__ == "__main__": if __name__ == "__main__":
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvWrapperConfig, HILSerlRobotEnvConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import make_robot_env from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig
parser = argparse.ArgumentParser(description="Test end-effector control") parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument( parser.add_argument(

View File

@ -5,14 +5,15 @@ import cv2
import numpy as np import numpy as np
from lerobot.common.robot_devices.control_utils import is_headless from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.kinematics import RobotKinematics
from lerobot.configs import parser
from lerobot.common.robot_devices.robots.configs import RobotConfig from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics
follower_port = "/dev/tty.usbmodem58760431631" follower_port = "/dev/tty.usbmodem58760431631"
leader_port = "/dev/tty.usbmodem58760433331" leader_port = "/dev/tty.usbmodem58760433331"
def find_joint_bounds( def find_joint_bounds(
robot, robot,
control_time_s=30, control_time_s=30,
@ -100,6 +101,7 @@ def make_robot(robot_type="so100"):
return make_robot_from_config(robot_config) return make_robot_from_config(robot_config)
if __name__ == "__main__": if __name__ == "__main__":
# Create argparse for script-specific arguments # Create argparse for script-specific arguments
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict

View File

@ -1,7 +1,6 @@
import logging import logging
import time import time
from dataclasses import dataclass, field from typing import Any
from typing import Any, Dict, Optional, Tuple
import einops import einops
import gymnasium as gym import gymnasium as gym
@ -10,10 +9,8 @@ import torch
from mani_skill.utils.wrappers.record import RecordEpisode from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig from lerobot.common.envs.configs import ManiskillEnvConfig
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
def preprocess_maniskill_observation( def preprocess_maniskill_observation(
@ -53,9 +50,6 @@ def preprocess_maniskill_observation(
return return_observations return return_observations
class ManiSkillObservationWrapper(gym.ObservationWrapper): class ManiSkillObservationWrapper(gym.ObservationWrapper):
def __init__(self, env, device: torch.device = "cuda"): def __init__(self, env, device: torch.device = "cuda"):
super().__init__(env) super().__init__(env)
@ -122,6 +116,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
class BatchCompatibleWrapper(gym.ObservationWrapper): class BatchCompatibleWrapper(gym.ObservationWrapper):
"""Ensures observations are batch-compatible by adding a batch dimension if necessary.""" """Ensures observations are batch-compatible by adding a batch dimension if necessary."""
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
@ -136,6 +131,7 @@ class BatchCompatibleWrapper(gym.ObservationWrapper):
class TimeLimitWrapper(gym.Wrapper): class TimeLimitWrapper(gym.Wrapper):
"""Adds a time limit to the environment based on fps and control_time.""" """Adds a time limit to the environment based on fps and control_time."""
def __init__(self, env, control_time_s, fps): def __init__(self, env, control_time_s, fps):
super().__init__(env) super().__init__(env)
self.control_time_s = control_time_s self.control_time_s = control_time_s
@ -243,6 +239,10 @@ def main(cfg: ManiskillEnvConfig):
if __name__ == "__main__": if __name__ == "__main__":
import draccus import draccus
config = ManiskillEnvConfig() config = ManiskillEnvConfig()
draccus.set_config_type("json") draccus.set_config_type("json")
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), ) draccus.dump(
config=config,
stream=open(file="run_config.json", mode="w"),
)

View File

@ -32,7 +32,6 @@ from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,

View File

@ -205,10 +205,7 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
# Verify state dictionaries match # Verify state dictionaries match
for name in multi_optimizers.keys(): for name in multi_optimizers.keys():
torch.testing.assert_close( torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
multi_optimizers[name].state_dict(),
loaded_optimizers[name].state_dict()
)
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path): def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
@ -241,7 +238,5 @@ def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
# Verify state dictionaries match (they will be empty) # Verify state dictionaries match (they will be empty)
torch.testing.assert_close( torch.testing.assert_close(
optimizer.state_dict()["param_groups"], optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
loaded_optimizers[name].state_dict()["param_groups"]
) )