Modified grasp penalty, added init_final initialization for the grasp critic
parent f3cea2a3e5
commit 90a30ed319

@@ -203,6 +203,7 @@ class EnvWrapperConfig:
     use_gripper: bool = False
     gripper_quantization_threshold: float | None = None
     gripper_penalty: float = 0.0
+    gripper_penalty_in_reward: bool = False
     open_gripper_on_reset: bool = False
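
For readers skimming the diff, the new `gripper_penalty_in_reward` flag sits next to the existing gripper options. Below is a minimal, self-contained sketch of the touched fields (defaults copied from the hunk above; the class name is a stand-in, not the repo's):

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class EnvWrapperConfigSketch:
    # Fields and defaults mirror the hunk above; only gripper_penalty_in_reward is new.
    use_gripper: bool = False
    gripper_quantization_threshold: float | None = None
    gripper_penalty: float = 0.0
    gripper_penalty_in_reward: bool = False  # fold the penalty into the reward vs. report it via info
    open_gripper_on_reset: bool = False


# Example: apply a gripper penalty but keep it out of the task reward.
cfg = EnvWrapperConfigSketch(use_gripper=True, gripper_penalty=-0.1, gripper_penalty_in_reward=False)
```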

@@ -224,10 +224,6 @@ class SACPolicy(

         critics = self.critic_target if use_target else self.critic_ensemble
         q_values = critics(observations, actions, observation_features)
-        if not use_target:
-            for name, param in critics.named_parameters():
-                if param.requires_grad:
-                    print(f"Critic Ensemble layer {name}, norm {param.data.norm().item()}")
         return q_values

     def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:

@@ -243,10 +239,6 @@ class SACPolicy(
         """
         grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
         q_values = grasp_critic(observations, observation_features)
-        if not use_target:
-            for name, param in grasp_critic.named_parameters():
-                if param.requires_grad:
-                    print(f"Grasp critic layer {name}, norm {param.data.norm().item()}")
         return q_values

     def forward(
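
The two hunks above drop the per-forward `print` diagnostics of parameter norms from `critic_forward` and `grasp_critic_forward`. If that signal is still wanted, a minimal sketch (not part of the commit) is to collect the norms on demand from the training loop instead:

```python
import torch.nn as nn


def param_norms(module: nn.Module, prefix: str = "") -> dict[str, float]:
    """Collect the L2 norm of every trainable parameter, mirroring the removed prints."""
    return {
        f"{prefix}{name}": param.data.norm().item()
        for name, param in module.named_parameters()
        if param.requires_grad
    }


# Hypothetical usage, assuming a policy object exposing the attributes seen in the diff:
# norms = param_norms(policy.critic_ensemble, prefix="critic/")
# norms.update(param_norms(policy.grasp_critic, prefix="grasp_critic/"))
```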

@@ -577,7 +569,6 @@ class SACObservationEncoder(nn.Module):
         obs_dict = self.input_normalization(obs_dict)
         if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
             vision_encoder_cache = self.get_image_features(obs_dict)
-            feat.append(vision_encoder_cache)

         if vision_encoder_cache is not None:
             feat.append(vision_encoder_cache)

@@ -805,6 +796,7 @@ class GraspCritic(nn.Module):
         )

         self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
+        init_final = 0.05
         if init_final is not None:
             nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
             nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
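
The added `init_final = 0.05` hard-codes a small uniform range for the grasp critic's output layer, so its initial Q-values start near zero (it also overrides whatever `init_final` may have been configured before this line). A short, self-contained sketch of the effect, with placeholder layer sizes:

```python
import torch
import torch.nn as nn

# Placeholder dimensions for illustration; the real hidden_dims/output_dim come from the config.
init_final = 0.05
output_layer = nn.Linear(in_features=256, out_features=3)
nn.init.uniform_(output_layer.weight, -init_final, init_final)
nn.init.uniform_(output_layer.bias, -init_final, init_final)

with torch.no_grad():
    q_values = output_layer(torch.randn(4, 256))

print(output_layer.weight.abs().max().item() <= init_final)  # True: weights bounded by init_final
print(q_values.abs().mean().item())  # small magnitudes: the grasp critic starts with near-zero Q-values
```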

@@ -370,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
         self.device = device

     def step(self, action):
-        observation, _, terminated, truncated, info = self.env.step(action)
+        observation, reward, terminated, truncated, info = self.env.step(action)
         images = [
             observation[key].to(self.device, non_blocking=self.device.type == "cuda")
             for key in observation

@@ -378,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
         ]
         start_time = time.perf_counter()
         with torch.inference_mode():
-            reward = (
+            success = (
                 self.reward_classifier.predict_reward(images, threshold=0.8)
                 if self.reward_classifier is not None
                 else 0.0
             )
         info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)

-        if reward == 1.0:
+        if success == 1.0:
             terminated = True
+            reward = 1.0

         return observation, reward, terminated, truncated, info

     def reset(self, seed=None, options=None):
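
This hunk renames the classifier output to `success`, keeps the wrapped env's reward instead of discarding it (see the earlier `step` hunk), and only forces `reward = 1.0` and termination when the classifier fires. A small standalone sketch of that decision logic (illustrative values, not diff content):

```python
def combine(env_reward: float, success: float, terminated: bool) -> tuple[float, bool]:
    """Sketch of the updated logic: classifier success overrides the reward and forces termination."""
    reward = env_reward
    if success == 1.0:
        terminated = True
        reward = 1.0
    return reward, terminated


print(combine(env_reward=-0.05, success=0.0, terminated=False))  # (-0.05, False): keep the env's reward
print(combine(env_reward=-0.05, success=1.0, terminated=False))  # (1.0, True): success overrides and ends the episode
```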

@@ -773,28 +775,38 @@ class BatchCompitableWrapper(gym.ObservationWrapper):


 class GripperPenaltyWrapper(gym.RewardWrapper):
-    def __init__(self, env, penalty: float = -0.1):
+    def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
         super().__init__(env)
         self.penalty = penalty
+        self.gripper_penalty_in_reward = gripper_penalty_in_reward
         self.last_gripper_state = None

     def reward(self, reward, action):
         gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND

-        if isinstance(action, tuple):
-            action = action[0]
-        action_normalized = action[-1] / MAX_GRIPPER_COMMAND
+        action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND

-        gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
-            gripper_state_normalized > 0.9 and action_normalized < 0.1
+        gripper_penalty_bool = (gripper_state_normalized < 0.75 and action_normalized > 0.5) or (
+            gripper_state_normalized > 0.75 and action_normalized < -0.5
         )

-        return reward + self.penalty * gripper_penalty_bool
+        return reward + self.penalty * int(gripper_penalty_bool)

     def step(self, action):
         self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        if isinstance(action, tuple):
+            gripper_action = action[0][-1]
+        else:
+            gripper_action = action[-1]
         obs, reward, terminated, truncated, info = self.env.step(action)
-        reward = self.reward(reward, action)
+        grasp_reward = self.reward(reward, gripper_action)

+        if self.gripper_penalty_in_reward:
+            reward += grasp_reward
+        else:
+            info["grasp_reward"] = grasp_reward

         return obs, reward, terminated, truncated, info

     def reset(self, **kwargs):
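
With the gripper state normalised by `MAX_GRIPPER_COMMAND` and the quantized gripper action shifted by -1.0, the reworked penalty fires when a strong gripper command disagrees with the current gripper position: normalised state below 0.75 with a shifted action above 0.5, or state above 0.75 with a shifted action below -0.5. Below is a self-contained sketch with an assumed `MAX_GRIPPER_COMMAND` value and illustrative inputs:

```python
# MAX_GRIPPER_COMMAND is an assumed placeholder here; the real constant lives in the repo.
MAX_GRIPPER_COMMAND = 255.0


def gripper_penalty(raw_gripper_state: float, gripper_action: float, penalty: float = -0.1) -> float:
    """Same thresholds as the updated reward(): penalize strong commands that fight the current state."""
    state = raw_gripper_state / MAX_GRIPPER_COMMAND
    action = gripper_action - 1.0  # same shift as action_normalized in the diff
    fires = (state < 0.75 and action > 0.5) or (state > 0.75 and action < -0.5)
    return penalty * int(fires)


print(gripper_penalty(raw_gripper_state=10.0, gripper_action=2.0))   # -0.1: low state, strong positive command
print(gripper_penalty(raw_gripper_state=250.0, gripper_action=0.0))  # -0.1: high state, strong negative command
print(gripper_penalty(raw_gripper_state=10.0, gripper_action=1.0))   # 0.0: neutral command, no penalty
```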

@@ -1180,7 +1192,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         env = GripperActionWrapper(
             env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
         )
-        # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
+        env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)

     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(
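
The hunk above enables the previously commented-out `GripperPenaltyWrapper` but only forwards `penalty`. A hedged sketch (a fragment, not repo code) of how the new config flag from the first hunk could also be threaded through; passing `gripper_penalty_in_reward` here is an assumption, not something this commit does:

```python
def apply_gripper_wrappers(env, wrapper_cfg):
    # Assumes GripperActionWrapper / GripperPenaltyWrapper are importable from the same env module.
    env = GripperActionWrapper(
        env=env, quantization_threshold=wrapper_cfg.gripper_quantization_threshold
    )
    return GripperPenaltyWrapper(
        env=env,
        penalty=wrapper_cfg.gripper_penalty,
        gripper_penalty_in_reward=wrapper_cfg.gripper_penalty_in_reward,
    )
```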