[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-03-31 16:10:00 +00:00
parent c774bbe522
commit 7983baf4fc
5 changed files with 35 additions and 26 deletions

@@ -369,8 +369,17 @@ class SACPolicy(
         ).sum()
         return critics_loss
 
-    def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None):
+    def compute_loss_grasp_critic(
+        self,
+        observations,
+        actions,
+        rewards,
+        next_observations,
+        done,
+        observation_features=None,
+        next_observation_features=None,
+        complementary_info=None,
+    ):
         batch_size = rewards.shape[0]
         grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
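The reflowed signature does not change behaviour: the grasp critic still turns the last action dimension (assumed here to be the gripper command in [-1, 1]) into a discrete class index, as in the context line above. A minimal standalone sketch of that mapping, using a made-up action batch, could look like this:

    import torch

    # Hypothetical batch: the last column is the continuous gripper command.
    actions = torch.tensor([[0.2, -1.0], [0.1, 0.0], [0.3, 1.0]])
    grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
    print(grasp_actions)  # tensor([0, 1, 2])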
@@ -632,9 +641,7 @@ class GraspCritic(nn.Module):
         self.parameters_to_optimize += list(self.output_layer.parameters())
 
     def forward(
-        self,
-        observations: torch.Tensor,
-        observation_features: torch.Tensor | None = None
+        self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device

@@ -291,9 +291,7 @@ class ReplayBuffer:
                 if isinstance(value, torch.Tensor):
                     shape = value.shape if value.ndim > 0 else (1,)
                     self.complementary_info_storage[key] = torch.zeros(
-                        (self.capacity, *shape),
-                        dtype=value.dtype,
-                        device=self.storage_device
+                        (self.capacity, *shape), dtype=value.dtype, device=self.storage_device
                     )
 
                 # Store the value
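The ReplayBuffer change above only reflows the lazy pre-allocation of per-key complementary-info storage. A small sketch of that allocation pattern, with an illustrative capacity, device, and key name (none of which come from the real config), is:

    import torch

    capacity, storage_device = 8, "cpu"  # illustrative values
    complementary_info_storage = {}

    value = torch.tensor(0.5)  # e.g. a scalar per-step entry
    shape = value.shape if value.ndim > 0 else (1,)
    complementary_info_storage["grasp_penalty"] = torch.zeros(
        (capacity, *shape), dtype=value.dtype, device=storage_device
    )
    print(complementary_info_storage["grasp_penalty"].shape)  # torch.Size([8, 1])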
@@ -366,7 +364,7 @@ class ReplayBuffer:
         # Add complementary_info to batch if it exists
         batch_complementary_info = {}
-        if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys:
+        if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
             for key in self.complementary_info_keys:
                 if key in self.complementary_info_storage:
                     batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)

@@ -334,7 +334,7 @@ class GamepadController(InputController):
                 return 0.0  # No change
         except pygame.error:
-            logging.error(f"Error reading gamepad. Is it still connected?")
+            logging.error("Error reading gamepad. Is it still connected?")
             return 0.0

@@ -1083,7 +1083,9 @@ class GripperPenaltyWrapper(gym.Wrapper):
     def step(self, action):
         observation, reward, terminated, truncated, info = self.env.step(action)
 
-        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9):
+        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (
+            action[-1] > 0.5 and self.last_gripper_pos < 0.9
+        ):
             info["grasp_penalty"] = self.penalty
         else:
             info["grasp_penalty"] = 0.0

@@ -762,7 +762,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
-    optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr)
+    optimizer_grasp_critic = torch.optim.Adam(
+        params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
+    )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
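The wrapped torch.optim.Adam call above keeps the existing behaviour: each component gets its own optimizer, and the grasp critic takes its learning rate from the policy config. A hedged sketch of the overall pattern, with dummy modules, made-up learning rates, and illustrative dict keys (the real key names are not shown in this hunk), might be:

    import torch
    from torch import nn

    # Dummy stand-ins for the actor, critic ensemble, and grasp critic.
    actor, critic_ensemble, grasp_critic = nn.Linear(4, 2), nn.Linear(4, 1), nn.Linear(4, 3)
    log_alpha = torch.zeros(1, requires_grad=True)

    optimizers = {
        "actor": torch.optim.Adam(params=actor.parameters(), lr=3e-4),
        "critic": torch.optim.Adam(params=critic_ensemble.parameters(), lr=3e-4),
        "grasp_critic": torch.optim.Adam(params=grasp_critic.parameters(), lr=3e-4),
        "temperature": torch.optim.Adam(params=[log_alpha], lr=3e-4),
    }
    lr_scheduler = None  # no scheduler is created in this function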