modified grasppenalty, added init_final inititalization for the grasp critic

This commit is contained in:
Michel Aractingi 2025-04-07 10:26:11 +02:00
parent f3cea2a3e5
commit 90a30ed319
3 changed files with 26 additions and 21 deletions

View File

@ -203,6 +203,7 @@ class EnvWrapperConfig:
use_gripper: bool = False
gripper_quantization_threshold: float | None = None
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False

View File

@ -224,10 +224,6 @@ class SACPolicy(
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
if not use_target:
for name, param in critics.named_parameters():
if param.requires_grad:
print(f"Critic Ensemble layer {name}, norm {param.data.norm().item()}")
return q_values
def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
@ -243,10 +239,6 @@ class SACPolicy(
"""
grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
q_values = grasp_critic(observations, observation_features)
if not use_target:
for name, param in grasp_critic.named_parameters():
if param.requires_grad:
print(f"Grasp critic layer {name}, norm {param.data.norm().item()}")
return q_values
def forward(
@ -577,7 +569,6 @@ class SACObservationEncoder(nn.Module):
obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
vision_encoder_cache = self.get_image_features(obs_dict)
feat.append(vision_encoder_cache)
if vision_encoder_cache is not None:
feat.append(vision_encoder_cache)
@ -805,6 +796,7 @@ class GraspCritic(nn.Module):
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
init_final = 0.05
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)

View File

@ -370,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
self.device = device
def step(self, action):
observation, _, terminated, truncated, info = self.env.step(action)
observation, reward, terminated, truncated, info = self.env.step(action)
images = [
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
@ -378,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
]
start_time = time.perf_counter()
with torch.inference_mode():
reward = (
success = (
self.reward_classifier.predict_reward(images, threshold=0.8)
if self.reward_classifier is not None
else 0.0
)
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
if reward == 1.0:
if success == 1.0:
terminated = True
reward = 1.0
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
@ -773,28 +775,38 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
class GripperPenaltyWrapper(gym.RewardWrapper):
def __init__(self, env, penalty: float = -0.1):
def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
super().__init__(env)
self.penalty = penalty
self.gripper_penalty_in_reward = gripper_penalty_in_reward
self.last_gripper_state = None
def reward(self, reward, action):
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
if isinstance(action, tuple):
action = action[0]
action_normalized = action[-1] / MAX_GRIPPER_COMMAND
action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND
gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
gripper_state_normalized > 0.9 and action_normalized < 0.1
gripper_penalty_bool = (gripper_state_normalized < 0.75 and action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and action_normalized < -0.5
)
return reward + self.penalty * gripper_penalty_bool
return reward + self.penalty * int(gripper_penalty_bool)
def step(self, action):
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
if isinstance(action, tuple):
gripper_action = action[0][-1]
else:
gripper_action = action[-1]
obs, reward, terminated, truncated, info = self.env.step(action)
reward = self.reward(reward, action)
grasp_reward = self.reward(reward, gripper_action)
if self.gripper_penalty_in_reward:
reward += grasp_reward
else:
info["grasp_reward"] = grasp_reward
return obs, reward, terminated, truncated, info
def reset(self, **kwargs):
@ -1180,7 +1192,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env = GripperActionWrapper(
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
)
# env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(