modified grasppenalty, added init_final inititalization for the grasp critic

This commit is contained in:
Michel Aractingi 2025-04-07 10:26:11 +02:00
parent f3cea2a3e5
commit 90a30ed319
3 changed files with 26 additions and 21 deletions

View File

@ -203,6 +203,7 @@ class EnvWrapperConfig:
use_gripper: bool = False use_gripper: bool = False
gripper_quantization_threshold: float | None = None gripper_quantization_threshold: float | None = None
gripper_penalty: float = 0.0 gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False open_gripper_on_reset: bool = False

View File

@ -224,10 +224,6 @@ class SACPolicy(
critics = self.critic_target if use_target else self.critic_ensemble critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features) q_values = critics(observations, actions, observation_features)
if not use_target:
for name, param in critics.named_parameters():
if param.requires_grad:
print(f"Critic Ensemble layer {name}, norm {param.data.norm().item()}")
return q_values return q_values
def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor: def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
@ -243,10 +239,6 @@ class SACPolicy(
""" """
grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
q_values = grasp_critic(observations, observation_features) q_values = grasp_critic(observations, observation_features)
if not use_target:
for name, param in grasp_critic.named_parameters():
if param.requires_grad:
print(f"Grasp critic layer {name}, norm {param.data.norm().item()}")
return q_values return q_values
def forward( def forward(
@ -577,7 +569,6 @@ class SACObservationEncoder(nn.Module):
obs_dict = self.input_normalization(obs_dict) obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0 and vision_encoder_cache is None: if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
vision_encoder_cache = self.get_image_features(obs_dict) vision_encoder_cache = self.get_image_features(obs_dict)
feat.append(vision_encoder_cache)
if vision_encoder_cache is not None: if vision_encoder_cache is not None:
feat.append(vision_encoder_cache) feat.append(vision_encoder_cache)
@ -805,6 +796,7 @@ class GraspCritic(nn.Module):
) )
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim) self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
init_final = 0.05
if init_final is not None: if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final) nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final) nn.init.uniform_(self.output_layer.bias, -init_final, init_final)

View File

@ -370,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
self.device = device self.device = device
def step(self, action): def step(self, action):
observation, _, terminated, truncated, info = self.env.step(action) observation, reward, terminated, truncated, info = self.env.step(action)
images = [ images = [
observation[key].to(self.device, non_blocking=self.device.type == "cuda") observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation for key in observation
@ -378,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
] ]
start_time = time.perf_counter() start_time = time.perf_counter()
with torch.inference_mode(): with torch.inference_mode():
reward = ( success = (
self.reward_classifier.predict_reward(images, threshold=0.8) self.reward_classifier.predict_reward(images, threshold=0.8)
if self.reward_classifier is not None if self.reward_classifier is not None
else 0.0 else 0.0
) )
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time) info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
if reward == 1.0: if success == 1.0:
terminated = True terminated = True
reward = 1.0
return observation, reward, terminated, truncated, info return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None): def reset(self, seed=None, options=None):
@ -773,28 +775,38 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
class GripperPenaltyWrapper(gym.RewardWrapper): class GripperPenaltyWrapper(gym.RewardWrapper):
def __init__(self, env, penalty: float = -0.1): def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
super().__init__(env) super().__init__(env)
self.penalty = penalty self.penalty = penalty
self.gripper_penalty_in_reward = gripper_penalty_in_reward
self.last_gripper_state = None self.last_gripper_state = None
def reward(self, reward, action): def reward(self, reward, action):
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
if isinstance(action, tuple): action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND
action = action[0]
action_normalized = action[-1] / MAX_GRIPPER_COMMAND
gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or ( gripper_penalty_bool = (gripper_state_normalized < 0.75 and action_normalized > 0.5) or (
gripper_state_normalized > 0.9 and action_normalized < 0.1 gripper_state_normalized > 0.75 and action_normalized < -0.5
) )
return reward + self.penalty * gripper_penalty_bool return reward + self.penalty * int(gripper_penalty_bool)
def step(self, action): def step(self, action):
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
if isinstance(action, tuple):
gripper_action = action[0][-1]
else:
gripper_action = action[-1]
obs, reward, terminated, truncated, info = self.env.step(action) obs, reward, terminated, truncated, info = self.env.step(action)
reward = self.reward(reward, action) grasp_reward = self.reward(reward, gripper_action)
if self.gripper_penalty_in_reward:
reward += grasp_reward
else:
info["grasp_reward"] = grasp_reward
return obs, reward, terminated, truncated, info return obs, reward, terminated, truncated, info
def reset(self, **kwargs): def reset(self, **kwargs):
@ -1180,7 +1192,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env = GripperActionWrapper( env = GripperActionWrapper(
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
) )
# env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty) env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
if cfg.wrapper.ee_action_space_params is not None: if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper( env = EEActionWrapper(