diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index d91df43f..958f25d1 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -203,6 +203,7 @@ class EnvWrapperConfig:
     use_gripper: bool = False
     gripper_quantization_threshold: float | None = None
     gripper_penalty: float = 0.0
+    gripper_penalty_in_reward: bool = False
     open_gripper_on_reset: bool = False
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 29f4b5ff..7b3c9c41 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -224,10 +224,6 @@ class SACPolicy(
         critics = self.critic_target if use_target else self.critic_ensemble
         q_values = critics(observations, actions, observation_features)

-        if not use_target:
-            for name, param in critics.named_parameters():
-                if param.requires_grad:
-                    print(f"Critic Ensemble layer {name}, norm {param.data.norm().item()}")
         return q_values

     def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
@@ -243,10 +239,6 @@ class SACPolicy(
         """
         grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
         q_values = grasp_critic(observations, observation_features)
-        if not use_target:
-            for name, param in grasp_critic.named_parameters():
-                if param.requires_grad:
-                    print(f"Grasp critic layer {name}, norm {param.data.norm().item()}")
         return q_values

     def forward(
@@ -577,7 +569,6 @@ class SACObservationEncoder(nn.Module):
         obs_dict = self.input_normalization(obs_dict)
         if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
             vision_encoder_cache = self.get_image_features(obs_dict)
-            feat.append(vision_encoder_cache)

         if vision_encoder_cache is not None:
             feat.append(vision_encoder_cache)
@@ -805,6 +796,7 @@ class GraspCritic(nn.Module):
         )

         self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
+        init_final = 0.05
         if init_final is not None:
             nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
             nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 8eae6fb9..501a71ce 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -370,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
         self.device = device

     def step(self, action):
-        observation, _, terminated, truncated, info = self.env.step(action)
+        observation, reward, terminated, truncated, info = self.env.step(action)
         images = [
             observation[key].to(self.device, non_blocking=self.device.type == "cuda")
             for key in observation
@@ -378,15 +378,17 @@
         ]
         start_time = time.perf_counter()
         with torch.inference_mode():
-            reward = (
+            success = (
                 self.reward_classifier.predict_reward(images, threshold=0.8)
                 if self.reward_classifier is not None
                 else 0.0
             )
         info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)

-        if reward == 1.0:
+        if success == 1.0:
             terminated = True
+            reward = 1.0
+
         return observation, reward, terminated, truncated, info

     def reset(self, seed=None, options=None):
@@ -773,28 +775,38 @@ class BatchCompitableWrapper(gym.ObservationWrapper):


 class GripperPenaltyWrapper(gym.RewardWrapper):
-    def __init__(self, env, penalty: float = -0.1):
+    def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
         super().__init__(env)
         self.penalty = penalty
+        self.gripper_penalty_in_reward = gripper_penalty_in_reward
         self.last_gripper_state = None
+
     def reward(self, reward, action):
         gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND

-        if isinstance(action, tuple):
-            action = action[0]
-        action_normalized = action[-1] / MAX_GRIPPER_COMMAND
+        action_normalized = action - 1.0  # action / MAX_GRIPPER_COMMAND

-        gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
-            gripper_state_normalized > 0.9 and action_normalized < 0.1
+        gripper_penalty_bool = (gripper_state_normalized < 0.75 and action_normalized > 0.5) or (
+            gripper_state_normalized > 0.75 and action_normalized < -0.5
         )

-        return reward + self.penalty * gripper_penalty_bool
+        return reward + self.penalty * int(gripper_penalty_bool)

     def step(self, action):
         self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        if isinstance(action, tuple):
+            gripper_action = action[0][-1]
+        else:
+            gripper_action = action[-1]
         obs, reward, terminated, truncated, info = self.env.step(action)
-        reward = self.reward(reward, action)
+        grasp_reward = self.reward(reward, gripper_action)
+
+        if self.gripper_penalty_in_reward:
+            reward += grasp_reward
+        else:
+            info["grasp_reward"] = grasp_reward
+
         return obs, reward, terminated, truncated, info

     def reset(self, **kwargs):
@@ -1180,7 +1192,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         env = GripperActionWrapper(
             env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
         )
-        # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
+        env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)

     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(
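
For reference, a minimal sketch of the gripper-penalty term introduced above, assuming the quantized gripper command is one of 0.0, 1.0, or 2.0 (so that action - 1.0 falls in {-1.0, 0.0, 1.0}); the MAX_GRIPPER_COMMAND value below is a hypothetical placeholder for the constant defined in gym_manipulator.py:

# Illustrative sketch only; not part of the patch above.
# Assumes a discrete gripper command in {0.0, 1.0, 2.0}; MAX_GRIPPER_COMMAND is a
# hypothetical stand-in for the constant used in gym_manipulator.py.
MAX_GRIPPER_COMMAND = 255.0


def gripper_penalty(gripper_state: float, gripper_action: float, penalty: float = -0.1) -> float:
    """Reproduce the penalty term computed by GripperPenaltyWrapper.reward()."""
    state_normalized = gripper_state / MAX_GRIPPER_COMMAND
    action_normalized = gripper_action - 1.0
    penalize = (state_normalized < 0.75 and action_normalized > 0.5) or (
        state_normalized > 0.75 and action_normalized < -0.5
    )
    return penalty * int(penalize)


# With the gripper state well below the 0.75 threshold, a command that maps above +0.5
# is penalized; a neutral command is not.
assert gripper_penalty(gripper_state=25.0, gripper_action=2.0) == -0.1
assert gripper_penalty(gripper_state=25.0, gripper_action=1.0) == 0.0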