fixes for gripper penalty
This commit is contained in:
parent 90a30ed319
commit 06de182448
@@ -428,6 +428,7 @@ class SACPolicy(
         actions_discrete = torch.round(actions_discrete)
         actions_discrete = actions_discrete.long()
 
+        gripper_penalties: Tensor | None = None
         if complementary_info is not None:
             gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
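Defaulting `gripper_penalties` to None keeps the policy-side lookup safe when a batch carries no complementary info at all. A minimal sketch of that pattern, using a stand-in dict rather than the real training batch:

    import torch
    from torch import Tensor

    def get_gripper_penalties(complementary_info: dict[str, Tensor] | None) -> Tensor | None:
        # Stand-in for the policy-side lookup: returns None when the batch carries
        # no complementary info, or no "gripper_penalty" entry inside it.
        if complementary_info is None:
            return None
        return complementary_info.get("gripper_penalty")

    # Illustrative batch: a per-sample penalty of -0.01 for the second transition.
    info = {"gripper_penalty": torch.tensor([0.0, -0.01, 0.0])}
    assert get_gripper_penalties(None) is None
    assert get_gripper_penalties(info) is not None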
@@ -281,7 +281,7 @@ def act_with_policy(
             for key, tensor in obs.items():
                 if torch.isnan(tensor).any():
                     logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
 
             list_transition_to_send_to_learner.append(
                 Transition(
                     state=obs,
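The actor-side guard walks every tensor in the observation dict before the transition is queued for the learner. A self-contained sketch of the same check on a toy observation dict:

    import logging
    import torch

    def log_nan_observations(obs: dict[str, torch.Tensor], interaction_step: int) -> None:
        # Mirror of the actor-side guard: report any observation entry that contains NaNs.
        for key, tensor in obs.items():
            if torch.isnan(tensor).any():
                logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")

    obs = {"state": torch.tensor([1.0, float("nan")]), "image_feature": torch.zeros(3)}
    log_nan_observations(obs, interaction_step=42)  # logs an error for "state" only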
@@ -269,7 +269,7 @@ class ReplayBuffer:
                 self.complementary_info[key] = torch.empty(
                     (self.capacity, *value_shape), device=self.storage_device
                 )
-            elif isinstance(value, (int, float)):
+            elif isinstance(value, (int, float, bool)):
                 # Handle scalar values similar to reward
                 self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
             else:
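Widening the check to `bool` matters because the gripper-penalty path can hand the buffer plain Python booleans, which would otherwise fall through to the `else` branch. A small sketch of the preallocation pattern on its own (names here are illustrative, not the real `ReplayBuffer` API):

    import torch

    capacity = 4
    storage_device = "cpu"

    # Scalar complementary info (int, float, or bool) shares one flat buffer per key,
    # like the reward buffer: bools are simply stored as 0.0 / 1.0 floats.
    buffer = torch.empty((capacity,), device=storage_device)
    for index, value in enumerate([0.1, 2, True, False]):
        buffer[index] = float(value)

    print(buffer)  # tensor([0.1000, 2.0000, 1.0000, 0.0000])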
@@ -800,18 +800,21 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
         else:
             gripper_action = action[-1]
         obs, reward, terminated, truncated, info = self.env.step(action)
-        grasp_reward = self.reward(reward, gripper_action)
+        gripper_penalty = self.reward(reward, gripper_action)
 
         if self.gripper_penalty_in_reward:
-            reward += grasp_reward
+            reward += gripper_penalty
         else:
-            info["grasp_reward"] = grasp_reward
+            info["gripper_penalty"] = gripper_penalty
 
         return obs, reward, terminated, truncated, info
 
     def reset(self, **kwargs):
         self.last_gripper_state = None
-        return super().reset(**kwargs)
+        obs, info = super().reset(**kwargs)
+        if self.gripper_penalty_in_reward:
+            info["gripper_penalty"] = 0.0
+        return obs, info
 
 
 class GripperActionWrapper(gym.ActionWrapper):
     def __init__(self, env, quantization_threshold: float = 0.2):
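The hunk renames `grasp_reward` to `gripper_penalty` and makes `reset` forward the `(obs, info)` pair instead of discarding it. A stripped-down sketch of the resulting behaviour, assuming the penalty comes from some rule on the gripper action (the real class computes it via `self.reward(...)` and tracks extra state such as `last_gripper_state`):

    import gymnasium as gym

    class GripperPenaltySketch(gym.Wrapper):
        """Illustrative stand-in for GripperPenaltyWrapper, not the actual implementation."""

        def __init__(self, env, penalty: float = -0.01, gripper_penalty_in_reward: bool = True):
            super().__init__(env)
            self.penalty = penalty
            self.gripper_penalty_in_reward = gripper_penalty_in_reward

        def step(self, action):
            gripper_action = action[-1]
            obs, reward, terminated, truncated, info = self.env.step(action)
            # Placeholder penalty rule; the real wrapper derives it from reward and gripper state.
            gripper_penalty = self.penalty if abs(gripper_action) > 0.5 else 0.0

            if self.gripper_penalty_in_reward:
                # Fold the penalty straight into the scalar reward.
                reward += gripper_penalty
            else:
                # Leave the reward untouched and report the penalty separately.
                info["gripper_penalty"] = gripper_penalty
            return obs, reward, terminated, truncated, info

        def reset(self, **kwargs):
            obs, info = super().reset(**kwargs)
            if self.gripper_penalty_in_reward:
                info["gripper_penalty"] = 0.0
            return obs, info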
@@ -1192,7 +1195,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
        env = GripperActionWrapper(
            env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
        )
-       env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
+       env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
 
        if cfg.wrapper.ee_action_space_params is not None:
            env = EEActionWrapper(
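With this change the wrapper config must also expose `gripper_penalty_in_reward`. A hedged sketch of the fields `make_robot_env` reads in this hunk, written as a plain dataclass (the real config class and its defaults may differ; only `quantization_threshold = 0.2` is taken from the diff above):

    from dataclasses import dataclass

    @dataclass
    class WrapperConfigSketch:
        # Fields read by make_robot_env in this hunk; defaults are illustrative only.
        gripper_quantization_threshold: float = 0.2
        gripper_penalty: float = -0.01
        gripper_penalty_in_reward: bool = True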
@@ -408,6 +408,7 @@ def add_actor_information_and_train(
             "done": done,
             "observation_feature": observation_features,
             "next_observation_feature": next_observation_features,
+            "complementary_info": batch.get("complementary_info", None),
         }
 
         # Use the forward method for critic loss (includes both main critic and grasp critic)
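Forwarding `complementary_info` through the training batch is what lets the policy-side lookup above find the penalty tensor. A toy illustration of the None-safe `dict.get` used when assembling the forward batch (keys other than `done` are omitted for brevity):

    import torch

    # Replay batches without complementary info simply forward None.
    batch_without = {"done": torch.zeros(2)}
    batch_with = {
        "done": torch.zeros(2),
        "complementary_info": {"gripper_penalty": torch.tensor([-0.01, 0.0])},
    }

    for batch in (batch_without, batch_with):
        forward_batch = {
            "done": batch["done"],
            "complementary_info": batch.get("complementary_info", None),
        }
        print(forward_batch["complementary_info"])  # None, then the penalty dict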