From f5cfd9fd481c2d8443eba1995a86c6384f51b18a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 31 Mar 2025 16:10:00 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 lerobot/common/policies/sac/modeling_sac.py  | 23 ++++++++++++-------
 lerobot/scripts/server/buffer.py             | 12 ++++------
 .../server/end_effector_control_utils.py     | 10 ++++----
 lerobot/scripts/server/gym_manipulator.py    | 10 ++++----
 lerobot/scripts/server/learner_server.py     |  6 +++--
 5 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index bd74c65b..95ea3928 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -198,7 +198,7 @@ class SACPolicy(
 
     def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
         """Forward pass through a grasp critic network
-        
+
         Args:
             observations: Dictionary of observations
             use_target: If True, use target critics, otherwise use ensemble critics
@@ -254,7 +254,7 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
            )
-            
+
         if model == "grasp_critic":
             # Extract grasp_critic-specific components
             complementary_info: dict[str, Tensor] = batch["complementary_info"]
@@ -307,7 +307,7 @@ class SACPolicy(
                     param.data * self.config.critic_target_update_weight
                     + target_param.data * (1.0 - self.config.critic_target_update_weight)
                 )
-    
+
     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
 
@@ -369,8 +369,17 @@ class SACPolicy(
             ).sum()
         return critics_loss
 
-    def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None):
-        
+    def compute_loss_grasp_critic(
+        self,
+        observations,
+        actions,
+        rewards,
+        next_observations,
+        done,
+        observation_features=None,
+        next_observation_features=None,
+        complementary_info=None,
+    ):
         batch_size = rewards.shape[0]
         grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
 
@@ -632,9 +641,7 @@ class GraspCritic(nn.Module):
         self.parameters_to_optimize += list(self.output_layer.parameters())
 
     def forward(
-        self,
-        observations: torch.Tensor,
-        observation_features: torch.Tensor | None = None
+        self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 1fbc8803..bf65b1ec 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -248,7 +248,7 @@ class ReplayBuffer:
         # Initialize complementary_info storage
         self.complementary_info_keys = []
         self.complementary_info_storage = {}
-        
+
         self.initialized = True
 
     def __len__(self):
@@ -291,11 +291,9 @@ class ReplayBuffer:
                     if isinstance(value, torch.Tensor):
                         shape = value.shape if value.ndim > 0 else (1,)
                         self.complementary_info_storage[key] = torch.zeros(
-                            (self.capacity, *shape),
-                            dtype=value.dtype,
-                            device=self.storage_device
+                            (self.capacity, *shape), dtype=value.dtype, device=self.storage_device
                         )
-                
+
                 # Store the value
                 if key in self.complementary_info_storage:
                     if isinstance(value, torch.Tensor):
@@ -304,7 +302,7 @@ class ReplayBuffer:
                         # For non-tensor values (like grasp_penalty)
                         self.complementary_info_storage[key][self.position] = torch.tensor(
                             value, device=self.storage_device
-                            )
+                        )
 
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
@@ -366,7 +364,7 @@ class ReplayBuffer:
 
         # Add complementary_info to batch if it exists
         batch_complementary_info = {}
-        if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys:
+        if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
             for key in self.complementary_info_keys:
                 if key in self.complementary_info_storage:
                     batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py
index f272426d..3b2cfa90 100644
--- a/lerobot/scripts/server/end_effector_control_utils.py
+++ b/lerobot/scripts/server/end_effector_control_utils.py
@@ -311,7 +311,7 @@ class GamepadController(InputController):
         except pygame.error:
             logging.error("Error reading gamepad. Is it still connected?")
             return 0.0, 0.0, 0.0
-    
+
     def get_gripper_action(self):
         """
         Get gripper action using L3/R3 buttons.
@@ -319,12 +319,12 @@ class GamepadController(InputController):
         Press right stick (R3) to close the gripper.
         """
         import pygame
-        
+
         try:
             # Check if buttons are pressed
             l3_pressed = self.joystick.get_button(9)
             r3_pressed = self.joystick.get_button(10)
-            
+
             # Determine action based on button presses
             if r3_pressed:
                 return 1.0  # Close gripper
@@ -332,9 +332,9 @@ class GamepadController(InputController):
                 return -1.0  # Open gripper
             else:
                 return 0.0  # No change
-                
+
         except pygame.error:
-            logging.error(f"Error reading gamepad. Is it still connected?")
+            logging.error("Error reading gamepad. Is it still connected?")
             return 0.0
 
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 26ed1991..ac3bbb0a 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -1077,18 +1077,20 @@ class GripperPenaltyWrapper(gym.Wrapper):
 
     def reset(self, **kwargs):
         obs, info = self.env.reset(**kwargs)
-        self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper
+        self.last_gripper_pos = obs["observation.state"][0, 0]  # first idx for the gripper
         return obs, info
 
     def step(self, action):
         observation, reward, terminated, truncated, info = self.env.step(action)
 
-        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9):
+        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (
+            action[-1] > 0.5 and self.last_gripper_pos < 0.9
+        ):
             info["grasp_penalty"] = self.penalty
         else:
             info["grasp_penalty"] = 0.0
 
-        self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper
+        self.last_gripper_pos = observation["observation.state"][0, 0]  # first idx for the gripper
         return observation, reward, terminated, truncated, info
 
@@ -1167,7 +1169,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
-    env= GripperPenaltyWrapper(env=env)
+    env = GripperPenaltyWrapper(env=env)
 
     return env
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index f79e8d57..0f760dc5 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -417,7 +417,7 @@ def add_actor_information_and_train(
             )
             optimizers["grasp_critic"].step()
-            
+
             policy.update_target_networks()
             policy.update_grasp_target_networks()
 
@@ -762,7 +762,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
-    optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr)
+    optimizer_grasp_critic = torch.optim.Adam(
+        params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
+    )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {