[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Parent commit: c774bbe522
This commit:   7983baf4fc

@@ -198,7 +198,7 @@ class SACPolicy(

    def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
        """Forward pass through a grasp critic network

        Args:
            observations: Dictionary of observations
            use_target: If True, use target critics, otherwise use ensemble critics

@@ -254,7 +254,7 @@ class SACPolicy(

                observation_features=observation_features,
                next_observation_features=next_observation_features,
            )

        if model == "grasp_critic":
            # Extract grasp_critic-specific components
            complementary_info: dict[str, Tensor] = batch["complementary_info"]

@@ -307,7 +307,7 @@ class SACPolicy(

                param.data * self.config.critic_target_update_weight
                + target_param.data * (1.0 - self.config.critic_target_update_weight)
            )

    def update_temperature(self):
        self.temperature = self.log_alpha.exp().item()
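
Note: the expression in this hunk is the body of a Polyak (soft) target-network update. A minimal standalone sketch of the same pattern follows; the function name, `tau`, and the module arguments are illustrative, not part of this diff:

```python
import torch
from torch import nn


def soft_update(target_net: nn.Module, online_net: nn.Module, tau: float) -> None:
    # Blend online weights into the target: theta_target <- tau * theta + (1 - tau) * theta_target
    with torch.no_grad():
        for target_param, param in zip(target_net.parameters(), online_net.parameters()):
            target_param.data.copy_(
                param.data * tau + target_param.data * (1.0 - tau)
            )
```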

@@ -369,8 +369,17 @@ class SACPolicy(

        ).sum()
        return critics_loss

-    def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None):
+    def compute_loss_grasp_critic(
+        self,
+        observations,
+        actions,
+        rewards,
+        next_observations,
+        done,
+        observation_features=None,
+        next_observation_features=None,
+        complementary_info=None,
+    ):
        batch_size = rewards.shape[0]
        grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
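
Note: the last context line above shifts the gripper command in {-1, 0, 1} into class indices {0, 1, 2} for the grasp critic. A tiny self-contained illustration of that mapping (the example tensor values are made up):

```python
import torch

# Hypothetical batch: the last column holds the gripper command in {-1, 0, 1}.
actions = torch.tensor([[0.2, -1.0], [0.1, 0.0], [0.5, 1.0]])
grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)
print(grasp_actions)  # tensor([0, 1, 2])
```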

@@ -632,9 +641,7 @@ class GraspCritic(nn.Module):

        self.parameters_to_optimize += list(self.output_layer.parameters())

    def forward(
-        self,
-        observations: torch.Tensor,
-        observation_features: torch.Tensor | None = None
+        self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
    ) -> torch.Tensor:
        device = get_device_from_parameters(self)
        # Move each tensor in observations to device
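
Note: `get_device_from_parameters` is a helper imported elsewhere in the file. A common way to implement this idiom (the actual helper in the codebase may differ) is to read the device off the module's first parameter:

```python
import torch
from torch import nn


def get_device_from_parameters(module: nn.Module) -> torch.device:
    # Infer the device of a module from its first parameter.
    return next(iter(module.parameters())).device
```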

@@ -248,7 +248,7 @@ class ReplayBuffer:

        # Initialize complementary_info storage
        self.complementary_info_keys = []
        self.complementary_info_storage = {}

        self.initialized = True

    def __len__(self):

@@ -291,11 +291,9 @@ class ReplayBuffer:

                    if isinstance(value, torch.Tensor):
                        shape = value.shape if value.ndim > 0 else (1,)
                        self.complementary_info_storage[key] = torch.zeros(
-                            (self.capacity, *shape),
-                            dtype=value.dtype,
-                            device=self.storage_device
+                            (self.capacity, *shape), dtype=value.dtype, device=self.storage_device
                        )

                # Store the value
                if key in self.complementary_info_storage:
                    if isinstance(value, torch.Tensor):
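
Note: the hunk above lazily allocates a per-key tensor of shape (capacity, *value.shape) the first time a complementary_info key is seen. A stripped-down sketch of that pattern outside the class; the capacity, device, and example value are placeholders:

```python
import torch

capacity, storage_device = 8, "cpu"
complementary_info_storage: dict[str, torch.Tensor] = {}

key, value = "grasp_penalty", torch.tensor(-0.05)  # scalar -> stored with shape (1,)
if key not in complementary_info_storage:
    shape = value.shape if value.ndim > 0 else (1,)
    complementary_info_storage[key] = torch.zeros(
        (capacity, *shape), dtype=value.dtype, device=storage_device
    )
print(complementary_info_storage[key].shape)  # torch.Size([8, 1])
```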

@@ -304,7 +302,7 @@ class ReplayBuffer:

                        # For non-tensor values (like grasp_penalty)
                        self.complementary_info_storage[key][self.position] = torch.tensor(
                            value, device=self.storage_device
                        )

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

@@ -366,7 +364,7 @@ class ReplayBuffer:

        # Add complementary_info to batch if it exists
        batch_complementary_info = {}
-        if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys:
+        if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
            for key in self.complementary_info_keys:
                if key in self.complementary_info_storage:
                    batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
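
Note: during sampling, the indices `idx` gather the stored per-transition extras onto the training device. A minimal sketch of that gather step, using made-up storage and indices:

```python
import torch

storage = {"grasp_penalty": torch.zeros(8, 1)}  # (capacity, *shape), as allocated above
idx = torch.tensor([0, 3, 5])                   # indices produced by the sampler
device = "cpu"

batch_complementary_info = {key: tensor[idx].to(device) for key, tensor in storage.items()}
print(batch_complementary_info["grasp_penalty"].shape)  # torch.Size([3, 1])
```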

@@ -311,7 +311,7 @@ class GamepadController(InputController):

        except pygame.error:
            logging.error("Error reading gamepad. Is it still connected?")
            return 0.0, 0.0, 0.0

    def get_gripper_action(self):
        """
        Get gripper action using L3/R3 buttons.

@@ -319,12 +319,12 @@ class GamepadController(InputController):

        Press right stick (R3) to close the gripper.
        """
        import pygame

        try:
            # Check if buttons are pressed
            l3_pressed = self.joystick.get_button(9)
            r3_pressed = self.joystick.get_button(10)

            # Determine action based on button presses
            if r3_pressed:
                return 1.0  # Close gripper

@@ -332,9 +332,9 @@ class GamepadController(InputController):

                return -1.0  # Open gripper
            else:
                return 0.0  # No change

        except pygame.error:
-            logging.error(f"Error reading gamepad. Is it still connected?")
+            logging.error("Error reading gamepad. Is it still connected?")
            return 0.0

@@ -1077,18 +1077,20 @@ class GripperPenaltyWrapper(gym.Wrapper):

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.last_gripper_pos = obs["observation.state"][0, 0]  # first idx for the gripper
        return obs, info

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)

-        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9):
+        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (
+            action[-1] > 0.5 and self.last_gripper_pos < 0.9
+        ):
            info["grasp_penalty"] = self.penalty
        else:
            info["grasp_penalty"] = 0.0

        self.last_gripper_pos = observation["observation.state"][0, 0]  # first idx for the gripper
        return observation, reward, terminated, truncated, info
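
Note: the reformatted condition fires the penalty when the gripper command in the last action dimension points against the previously observed gripper position, using 0.9 as the threshold. A small standalone check of the branch with hypothetical values (the penalty value is a placeholder, not taken from this diff):

```python
penalty = -0.1  # placeholder; the wrapper uses its own self.penalty

for last_gripper_pos, cmd in [(1.0, -1.0), (0.0, 1.0), (1.0, 1.0), (0.0, -1.0)]:
    fires = (cmd < -0.5 and last_gripper_pos > 0.9) or (cmd > 0.5 and last_gripper_pos < 0.9)
    print(f"pos={last_gripper_pos} cmd={cmd} -> grasp_penalty={penalty if fires else 0.0}")
```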

@@ -1167,7 +1169,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:

    if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
        env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
    env = BatchCompitableWrapper(env=env)
-    env= GripperPenaltyWrapper(env=env)
+    env = GripperPenaltyWrapper(env=env)

    return env

@@ -417,7 +417,7 @@ def add_actor_information_and_train(

            )

            optimizers["grasp_critic"].step()

            policy.update_target_networks()
            policy.update_grasp_target_networks()

@@ -762,7 +762,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):

        lr=cfg.policy.actor_lr,
    )
    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
-    optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr)
+    optimizer_grasp_critic = torch.optim.Adam(
+        params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
+    )
    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
    lr_scheduler = None
    optimizers = {
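
Note: this factory builds one Adam optimizer per learnable component (actor, critic ensemble, grasp critic, temperature), each with its own learning rate. A self-contained sketch of the same pattern with stand-in modules and learning rates; none of these values come from the diff:

```python
import torch
from torch import nn

actor = nn.Linear(4, 2)
critic_ensemble = nn.Linear(4, 1)
grasp_critic = nn.Linear(4, 3)
log_alpha = torch.zeros(1, requires_grad=True)

optimizers = {
    "actor": torch.optim.Adam(params=actor.parameters(), lr=3e-4),
    "critic": torch.optim.Adam(params=critic_ensemble.parameters(), lr=3e-4),
    "grasp_critic": torch.optim.Adam(params=grasp_critic.parameters(), lr=3e-4),
    "temperature": torch.optim.Adam(params=[log_alpha], lr=3e-4),
}
lr_scheduler = None
```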