fixes for gripper penalty

Michel Aractingi 2025-04-07 16:51:17 +02:00
parent 90a30ed319
commit 06de182448
5 changed files with 12 additions and 7 deletions

View File

@@ -428,6 +428,7 @@ class SACPolicy(
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()

gripper_penalties: Tensor | None = None
if complementary_info is not None:
    gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
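For readers outside the diff context, here is a minimal, self-contained sketch of the pattern this hunk relies on: reading an optional per-sample gripper penalty out of complementary_info and defaulting to zeros when the batch does not carry one. The helper name and the downstream use (shifting a grasp-critic reward) are illustrative assumptions, not the repository's exact code.

```python
import torch
from torch import Tensor


def get_gripper_penalties(complementary_info: dict | None, batch_size: int, device: str = "cpu") -> Tensor:
    """Return per-sample gripper penalties, falling back to zeros when absent."""
    if complementary_info is not None:
        penalties = complementary_info.get("gripper_penalty")
        if penalties is not None:
            return penalties.to(device)
    return torch.zeros(batch_size, device=device)


# Illustrative use: shift the reward seen by a (hypothetical) discrete gripper critic.
rewards = torch.tensor([1.0, 0.0, 0.0])
info = {"gripper_penalty": torch.tensor([0.0, -0.02, -0.02])}
shaped_rewards = rewards + get_gripper_penalties(info, batch_size=rewards.shape[0])
print(shaped_rewards)  # tensor([ 1.0000, -0.0200, -0.0200])
```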

View File

@@ -281,7 +281,7 @@ def act_with_policy(
for key, tensor in obs.items():
    if torch.isnan(tensor).any():
        logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")

list_transition_to_send_to_learner.append(
    Transition(
        state=obs,
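The actor-side counterpart, sketched below under the assumption that the gripper penalty reported by the environment rides to the learner inside the transition's complementary_info. The plain-dict stand-in for Transition and every field other than state are illustrative.

```python
import logging

import torch


def warn_on_nan_observations(obs: dict[str, torch.Tensor], interaction_step: int) -> None:
    """Log any NaN entries in the observation dict before it is shipped to the learner."""
    for key, tensor in obs.items():
        if torch.isnan(tensor).any():
            logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")


obs = {"observation.state": torch.zeros(6)}
warn_on_nan_observations(obs, interaction_step=0)

# Plain-dict stand-in for the repository's Transition: the penalty from the env
# wrapper travels in complementary_info rather than being baked into reward.
transition = {
    "state": obs,
    "reward": 0.0,
    "complementary_info": {"gripper_penalty": -0.02},
}
```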

View File

@@ -269,7 +269,7 @@ class ReplayBuffer:
    self.complementary_info[key] = torch.empty(
        (self.capacity, *value_shape), device=self.storage_device
    )
elif isinstance(value, (int, float)):
elif isinstance(value, (int, float, bool)):
    # Handle scalar values similar to reward
    self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else:
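A compact sketch of the allocation rule this hunk extends: tensor-valued complementary_info entries get a (capacity, *shape) buffer, while scalar int/float/bool entries get a flat (capacity,) buffer laid out like rewards. Note that bool is a subclass of int in Python, so the added bool case mainly documents intent; the helper below is a simplified stand-in for the ReplayBuffer internals.

```python
import torch


def preallocate_info_buffer(value, capacity: int, storage_device: str = "cpu") -> torch.Tensor:
    """Allocate a replay-capacity-sized buffer for one complementary_info key."""
    if isinstance(value, torch.Tensor):
        # One row per transition, matching the tensor's own shape.
        return torch.empty((capacity, *value.shape), device=storage_device)
    elif isinstance(value, (int, float, bool)):
        # Scalars (e.g. a gripper penalty or a flag) are stored like rewards.
        return torch.empty((capacity,), device=storage_device)
    raise ValueError(f"Unsupported complementary_info value type: {type(value)}")


print(preallocate_info_buffer(torch.zeros(3), capacity=8).shape)  # torch.Size([8, 3])
print(preallocate_info_buffer(-0.01, capacity=8).shape)           # torch.Size([8])
```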

View File

@@ -800,18 +800,21 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
        else:
            gripper_action = action[-1]

        obs, reward, terminated, truncated, info = self.env.step(action)
        grasp_reward = self.reward(reward, gripper_action)
        gripper_penalty = self.reward(reward, gripper_action)
        if self.gripper_penalty_in_reward:
            reward += grasp_reward
            reward += gripper_penalty
        else:
            info["grasp_reward"] = grasp_reward
            info["gripper_penalty"] = gripper_penalty
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        self.last_gripper_state = None
        return super().reset(**kwargs)
        obs, info = super().reset(**kwargs)
        if self.gripper_penalty_in_reward:
            info["gripper_penalty"] = 0.0
        return obs, info


class GripperActionWrapper(gym.ActionWrapper):
    def __init__(self, env, quantization_threshold: float = 0.2):
@@ -1192,7 +1195,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
    env = GripperActionWrapper(
        env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
    )
    env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
    env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)

    if cfg.wrapper.ee_action_space_params is not None:
        env = EEActionWrapper(
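Putting the wrapper changes together, here is a hypothetical, self-contained version of a gripper-penalty wrapper with the same routing as the diff above: the penalty is either folded into the environment reward or surfaced via info["gripper_penalty"], depending on gripper_penalty_in_reward. The penalty condition used here (penalizing a repeated gripper command) is an assumption; the real wrapper's self.reward(...) rule is not shown in this hunk.

```python
import gymnasium as gym


class MinimalGripperPenaltyWrapper(gym.Wrapper):
    """Sketch only: same penalty routing as the diff, simplified penalty rule."""

    def __init__(self, env, penalty: float = -0.01, gripper_penalty_in_reward: bool = True):
        super().__init__(env)
        self.penalty = penalty
        self.gripper_penalty_in_reward = gripper_penalty_in_reward
        self.last_gripper_action = None

    def step(self, action):
        gripper_action = float(action[-1])
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Assumed rule: penalize a gripper command identical to the previous one.
        gripper_penalty = self.penalty if gripper_action == self.last_gripper_action else 0.0
        self.last_gripper_action = gripper_action
        if self.gripper_penalty_in_reward:
            reward += gripper_penalty
        else:
            info["gripper_penalty"] = gripper_penalty
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        self.last_gripper_action = None
        obs, info = super().reset(**kwargs)
        # Report a neutral penalty on reset so downstream consumers always find the key.
        info["gripper_penalty"] = 0.0
        return obs, info
```

When gripper_penalty_in_reward is False, the penalty presumably reaches the learner through the transition's complementary_info (see the actor and replay-buffer hunks above) and is read back in SACPolicy via complementary_info.get("gripper_penalty").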

View File

@@ -408,6 +408,7 @@ def add_actor_information_and_train(
    "done": done,
    "observation_feature": observation_features,
    "next_observation_feature": next_observation_features,
    "complementary_info": batch.get("complementary_info", None),
}

# Use the forward method for critic loss (includes both main critic and grasp critic)
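For orientation, a sketch of the kind of dictionary the learner now hands to the policy's forward/critic-loss path, with complementary_info passed through. Every key other than those visible in the hunk, and all shapes, are illustrative assumptions.

```python
import torch

batch_size = 4

# Hypothetical mini-batch; batch.get("complementary_info", None) returns None
# for batches that were recorded without this key.
batch = {
    "done": torch.zeros(batch_size, dtype=torch.bool),
    "complementary_info": {"gripper_penalty": torch.full((batch_size,), -0.01)},
}

forward_inputs = {
    "done": batch["done"],
    "observation_feature": None,       # precomputed encoder features, if available
    "next_observation_feature": None,
    "complementary_info": batch.get("complementary_info", None),
}
print(forward_inputs["complementary_info"]["gripper_penalty"])
```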