modified gripper penalty, added init_final initialization for the grasp critic
parent f3cea2a3e5
commit 90a30ed319
@@ -203,6 +203,7 @@ class EnvWrapperConfig:
     use_gripper: bool = False
     gripper_quantization_threshold: float | None = None
     gripper_penalty: float = 0.0
+    gripper_penalty_in_reward: bool = False
     open_gripper_on_reset: bool = False
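The new `gripper_penalty_in_reward` field mirrors the keyword argument added to `GripperPenaltyWrapper` further down in this diff: when true, the penalty is folded into the environment reward; otherwise it is reported separately in `info`. A minimal sketch of the routing the flag selects, with names taken from this diff and the reward/penalty values assumed for illustration:

```python
# Sketch only: what gripper_penalty_in_reward controls (values assumed).
reward, grasp_penalty = 0.0, -0.1
gripper_penalty_in_reward = False  # the new config default

info = {}
if gripper_penalty_in_reward:
    reward += grasp_penalty               # penalty shapes the main reward signal
else:
    info["grasp_reward"] = grasp_penalty  # penalty surfaced separately, e.g. for a grasp critic
print(reward, info)  # 0.0 {'grasp_reward': -0.1}
```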
@@ -224,10 +224,6 @@ class SACPolicy(
         critics = self.critic_target if use_target else self.critic_ensemble
         q_values = critics(observations, actions, observation_features)
-        if not use_target:
-            for name, param in critics.named_parameters():
-                if param.requires_grad:
-                    print(f"Critic Ensemble layer {name}, norm {param.data.norm().item()}")
         return q_values

     def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
@@ -243,10 +239,6 @@ class SACPolicy(
         """
         grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
         q_values = grasp_critic(observations, observation_features)
-        if not use_target:
-            for name, param in grasp_critic.named_parameters():
-                if param.requires_grad:
-                    print(f"Grasp critic layer {name}, norm {param.data.norm().item()}")
         return q_values

     def forward(
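Both hunks above delete the same pattern: per-parameter norm printing on every non-target forward pass, which spams stdout and adds Python-level overhead inside the training loop. If that diagnostic is still wanted, a small helper called every N steps is a lighter-weight option; a sketch (the helper name is hypothetical, not part of this diff):

```python
import torch.nn as nn

def log_param_norms(module: nn.Module, prefix: str = "") -> dict[str, float]:
    """Hypothetical helper: collect parameter norms into a dict
    instead of printing them inside every critic forward pass."""
    return {
        f"{prefix}{name}": param.data.norm().item()
        for name, param in module.named_parameters()
        if param.requires_grad
    }

# Usage sketch: call every N optimizer steps and hand the dict to a logger.
# norms = log_param_norms(policy.critic_ensemble, prefix="critic/")
```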
@@ -577,7 +569,6 @@ class SACObservationEncoder(nn.Module):
         obs_dict = self.input_normalization(obs_dict)
         if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
             vision_encoder_cache = self.get_image_features(obs_dict)
-            feat.append(vision_encoder_cache)

         if vision_encoder_cache is not None:
             feat.append(vision_encoder_cache)
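The removed `feat.append(...)` fixes a double-append: when the encoder computed the cache itself, image features were appended once inside the first branch and again by the unconditional `if vision_encoder_cache is not None` check. After the change, features are appended exactly once, whether the cache is computed locally or passed in by the caller. A condensed, standalone sketch of the corrected flow (control flow from the diff, everything else assumed):

```python
import torch

def encode(obs_dict, all_image_keys, get_image_features, vision_encoder_cache=None):
    """Standalone version of the corrected caching flow."""
    feat = []
    # Compute image features only if the caller did not supply a cache.
    if len(all_image_keys) > 0 and vision_encoder_cache is None:
        vision_encoder_cache = get_image_features(obs_dict)
    # Single append point: covers both the locally computed and the cached case.
    if vision_encoder_cache is not None:
        feat.append(vision_encoder_cache)
    return feat

# Appends exactly once whether the cache is computed here or supplied.
feats = encode({}, ["cam0"], lambda obs: torch.zeros(1, 8))
assert len(feats) == 1
```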
@@ -805,6 +796,7 @@ class GraspCritic(nn.Module):
         )

         self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
+        init_final = 0.05
         if init_final is not None:
             nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
             nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
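Hardcoding `init_final = 0.05` enables the small uniform initialization of the output head, a common trick for Q-networks: it keeps initial Q-estimates near zero, so early TD targets are dominated by observed rewards rather than by initialization noise. A standalone illustration (layer sizes are assumed, not taken from the diff):

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration: hidden_dims[-1] = 256, output_dim = 3 gripper actions.
output_layer = nn.Linear(in_features=256, out_features=3)
init_final = 0.05
nn.init.uniform_(output_layer.weight, -init_final, init_final)
nn.init.uniform_(output_layer.bias, -init_final, init_final)

# Q-values for a random feature vector start close to zero.
q = output_layer(torch.randn(1, 256))
print(q.abs().max())  # typically well below 1.0
```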
@@ -370,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
         self.device = device

     def step(self, action):
-        observation, _, terminated, truncated, info = self.env.step(action)
+        observation, reward, terminated, truncated, info = self.env.step(action)
         images = [
             observation[key].to(self.device, non_blocking=self.device.type == "cuda")
             for key in observation
@@ -378,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
         ]
         start_time = time.perf_counter()
         with torch.inference_mode():
-            reward = (
+            success = (
                 self.reward_classifier.predict_reward(images, threshold=0.8)
                 if self.reward_classifier is not None
                 else 0.0
             )
         info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)

-        if reward == 1.0:
+        if success == 1.0:
             terminated = True
+            reward = 1.0

         return observation, reward, terminated, truncated, info

     def reset(self, seed=None, options=None):
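The rename from `reward` to `success` separates the classifier's binary success signal from the scalar reward: the underlying environment's reward, now captured instead of discarded in `step`, passes through unchanged unless the classifier fires, in which case the episode terminates with reward 1.0. A distilled sketch of the new control flow, with the classifier call stubbed out:

```python
def step_logic(env_reward: float, classifier_success: float):
    """Distilled from the diff: success gates termination and overrides the reward."""
    reward, terminated = env_reward, False
    if classifier_success == 1.0:
        terminated = True
        reward = 1.0
    return reward, terminated

print(step_logic(0.0, 0.0))   # (0.0, False) -> env reward passes through
print(step_logic(-0.1, 1.0))  # (1.0, True)  -> success terminates with reward 1.0
```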
@@ -773,28 +775,38 @@ class BatchCompitableWrapper(gym.ObservationWrapper):

 class GripperPenaltyWrapper(gym.RewardWrapper):
-    def __init__(self, env, penalty: float = -0.1):
+    def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
         super().__init__(env)
         self.penalty = penalty
+        self.gripper_penalty_in_reward = gripper_penalty_in_reward
         self.last_gripper_state = None

     def reward(self, reward, action):
         gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND

-        if isinstance(action, tuple):
-            action = action[0]
-        action_normalized = action[-1] / MAX_GRIPPER_COMMAND
+        action_normalized = action - 1.0  # action / MAX_GRIPPER_COMMAND

-        gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
-            gripper_state_normalized > 0.9 and action_normalized < 0.1
+        gripper_penalty_bool = (gripper_state_normalized < 0.75 and action_normalized > 0.5) or (
+            gripper_state_normalized > 0.75 and action_normalized < -0.5
         )

-        return reward + self.penalty * gripper_penalty_bool
+        return reward + self.penalty * int(gripper_penalty_bool)

     def step(self, action):
         self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        if isinstance(action, tuple):
+            gripper_action = action[0][-1]
+        else:
+            gripper_action = action[-1]
         obs, reward, terminated, truncated, info = self.env.step(action)
-        reward = self.reward(reward, action)
+        grasp_reward = self.reward(reward, gripper_action)
+
+        if self.gripper_penalty_in_reward:
+            reward += grasp_reward
+        else:
+            info["grasp_reward"] = grasp_reward

         return obs, reward, terminated, truncated, info

     def reset(self, **kwargs):
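The rewritten wrapper extracts the scalar gripper command in `step` (handling tuple actions there), so `reward()` now receives a single value; `action - 1.0` suggests a discrete gripper command in {0, 1, 2} mapped to {-1, 0, +1}, replacing the old normalization by `MAX_GRIPPER_COMMAND`. A worked example of the new trigger condition, with the command encoding and position scale assumed:

```python
MAX_GRIPPER_COMMAND = 255.0  # assumed scale of the raw gripper position
penalty = -0.1

last_gripper_state = 25.0  # raw position: gripper nearly closed
gripper_state_normalized = last_gripper_state / MAX_GRIPPER_COMMAND  # ~0.10

gripper_action = 2.0                      # assumed discrete command: "open"
action_normalized = gripper_action - 1.0  # +1.0

# New thresholds: penalize opening a closed gripper or closing an open one.
gripper_penalty_bool = (gripper_state_normalized < 0.75 and action_normalized > 0.5) or (
    gripper_state_normalized > 0.75 and action_normalized < -0.5
)
print(penalty * int(gripper_penalty_bool))  # -0.1: closed gripper + open command
```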
@@ -1180,7 +1192,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         env = GripperActionWrapper(
             env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
         )
-        # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
+        env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)

         if cfg.wrapper.ee_action_space_params is not None:
             env = EEActionWrapper(
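The re-enabled call passes only `penalty`, so `gripper_penalty_in_reward` keeps its constructor default of `True` even though the config field added above defaults to `False`. If the config flag is meant to take effect here, the call would presumably also forward it; a hedged sketch, not part of this commit:

```python
# Hypothetical wiring, not in this diff: forward the new config flag as well.
env = GripperPenaltyWrapper(
    env=env,
    penalty=cfg.wrapper.gripper_penalty,
    gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
)
```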