Added Gripper quantization wrapper and grasp penalty
Removed complementary info handling from the replay buffer and learner server, removed the get_gripper_action function, and added gripper parameters to `common/envs/configs.py`.
parent 7983baf4fc
commit fe2ff516a8
@@ -203,6 +203,9 @@ class EnvWrapperConfig:
    joint_masking_action_space: Optional[Any] = None
    ee_action_space_params: Optional[EEActionSpaceConfig] = None
    use_gripper: bool = False
    gripper_quantization_threshold: float = 0.8
    gripper_penalty: float = 0.0
    open_gripper_on_reset: bool = False


@EnvConfig.register_subclass(name="gym_manipulator")
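For orientation, a minimal sketch of how the new fields might be set when building the wrapper config. The field names come from the hunk above; the import path and surrounding plumbing are assumptions:

```python
# Hypothetical usage sketch; not part of this commit.
from lerobot.common.envs.configs import EnvWrapperConfig  # import path assumed from `common/envs/configs.py`

wrapper_cfg = EnvWrapperConfig(
    use_gripper=True,                    # enable gripper control in the env wrappers
    gripper_quantization_threshold=0.8,  # commands beyond this magnitude snap to a full-scale gripper move
    gripper_penalty=0.0,                 # reward penalty for toggling the gripper against its current state
    open_gripper_on_reset=False,         # whether the gripper should be opened on env reset
)
```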
@@ -245,10 +245,6 @@ class ReplayBuffer:
        self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
        self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)

        # Initialize complementary_info storage
        self.complementary_info_keys = []
        self.complementary_info_storage = {}

        self.initialized = True

    def __len__(self):
@@ -282,28 +278,6 @@ class ReplayBuffer:
        self.dones[self.position] = done
        self.truncateds[self.position] = truncated

        # Store complementary info if provided
        if complementary_info is not None:
            # Initialize storage for new keys on first encounter
            for key, value in complementary_info.items():
                if key not in self.complementary_info_keys:
                    self.complementary_info_keys.append(key)
                    if isinstance(value, torch.Tensor):
                        shape = value.shape if value.ndim > 0 else (1,)
                        self.complementary_info_storage[key] = torch.zeros(
                            (self.capacity, *shape), dtype=value.dtype, device=self.storage_device
                        )

                # Store the value
                if key in self.complementary_info_storage:
                    if isinstance(value, torch.Tensor):
                        self.complementary_info_storage[key][self.position] = value
                    else:
                        # For non-tensor values (like grasp_penalty)
                        self.complementary_info_storage[key][self.position] = torch.tensor(
                            value, device=self.storage_device
                        )

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
@@ -362,13 +336,6 @@ class ReplayBuffer:
        batch_dones = self.dones[idx].to(self.device).float()
        batch_truncateds = self.truncateds[idx].to(self.device).float()

        # Add complementary_info to batch if it exists
        batch_complementary_info = {}
        if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
            for key in self.complementary_info_keys:
                if key in self.complementary_info_storage:
                    batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)

        return BatchTransition(
            state=batch_state,
            action=batch_actions,
@@ -376,7 +343,6 @@ class ReplayBuffer:
            next_state=batch_next_state,
            done=batch_dones,
            truncated=batch_truncateds,
            complementary_info=batch_complementary_info if batch_complementary_info else None,
        )

    @classmethod
@@ -312,31 +312,6 @@ class GamepadController(InputController):
            logging.error("Error reading gamepad. Is it still connected?")
            return 0.0, 0.0, 0.0

    def get_gripper_action(self):
        """
        Get gripper action using L3/R3 buttons.
        Press left stick (L3) to open the gripper.
        Press right stick (R3) to close the gripper.
        """
        import pygame

        try:
            # Check if buttons are pressed
            l3_pressed = self.joystick.get_button(9)
            r3_pressed = self.joystick.get_button(10)

            # Determine action based on button presses
            if r3_pressed:
                return 1.0  # Close gripper
            elif l3_pressed:
                return -1.0  # Open gripper
            else:
                return 0.0  # No change

        except pygame.error:
            logging.error("Error reading gamepad. Is it still connected?")
            return 0.0


class GamepadControllerHID(InputController):
    """Generate motion deltas from gamepad input using HIDAPI."""
@@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
        return observation


class GripperPenaltyWrapper(gym.RewardWrapper):
    def __init__(self, env, penalty: float = -0.1):
        super().__init__(env)
        self.penalty = penalty
        self.last_gripper_state = None

    def reward(self, reward, action):
        gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND

        if isinstance(action, tuple):
            action = action[0]
        action_normalized = action[-1] / MAX_GRIPPER_COMMAND

        gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
            gripper_state_normalized > 0.9 and action_normalized < 0.1
        )
        breakpoint()

        return reward + self.penalty * gripper_penalty_bool

    def step(self, action):
        self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
        obs, reward, terminated, truncated, info = self.env.step(action)
        reward = self.reward(reward, action)
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        self.last_gripper_state = None
        return super().reset(**kwargs)


class GripperQuantizationWrapper(gym.ActionWrapper):
    def __init__(self, env, quantization_threshold: float = 0.2):
        super().__init__(env)
        self.quantization_threshold = quantization_threshold

    def action(self, action):
        is_intervention = False
        if isinstance(action, tuple):
            action, is_intervention = action

        gripper_command = action[-1]
        # Quantize gripper command to -1, 0 or 1
        if gripper_command < -self.quantization_threshold:
            gripper_command = -MAX_GRIPPER_COMMAND
        elif gripper_command > self.quantization_threshold:
            gripper_command = MAX_GRIPPER_COMMAND
        else:
            gripper_command = 0.0

        gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
        gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
        action[-1] = gripper_action.item()
        return action, is_intervention


class EEActionWrapper(gym.ActionWrapper):
    def __init__(self, env, ee_action_space_params=None, use_gripper=False):
        super().__init__(env)
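The quantization rule in `GripperQuantizationWrapper.action` above is easiest to read as a small pure function. A self-contained sketch; the constant value and the helper name are illustrative, not from the commit:

```python
# Illustrative sketch of the gripper quantization rule; mirrors the wrapper logic above.
MAX_GRIPPER_COMMAND = 2.0  # placeholder; the real constant comes from the robot configuration


def quantize_gripper_delta(command: float, threshold: float) -> float:
    """Snap a continuous gripper command to a full-scale move in either direction, or to no move."""
    if command < -threshold:
        return -MAX_GRIPPER_COMMAND
    if command > threshold:
        return MAX_GRIPPER_COMMAND
    return 0.0


# With threshold=0.8: 0.3 -> 0.0 (hold), 0.9 -> +MAX_GRIPPER_COMMAND, -0.9 -> -MAX_GRIPPER_COMMAND.
# The wrapper then adds this delta to the current gripper position and clips to [0, MAX_GRIPPER_COMMAND].
```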
@@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper):
            fk_func=self.fk_function,
        )
        if self.use_gripper:
            # Quantize gripper command to -1, 0 or 1
            if gripper_command < -0.2:
                gripper_command = -1.0
            elif gripper_command > 0.2:
                gripper_command = 1.0
            else:
                gripper_command = 0.0

            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
            target_joint_pos[-1] = gripper_action
            target_joint_pos[-1] = gripper_command

        return target_joint_pos, is_intervention
@@ -1069,31 +1115,6 @@ class ActionScaleWrapper(gym.ActionWrapper):
        return action * self.scale_vector, is_intervention


class GripperPenaltyWrapper(gym.Wrapper):
    def __init__(self, env, penalty=-0.05):
        super().__init__(env)
        self.penalty = penalty
        self.last_gripper_pos = None

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.last_gripper_pos = obs["observation.state"][0, 0]  # first idx for the gripper
        return obs, info

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)

        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (
            action[-1] > 0.5 and self.last_gripper_pos < 0.9
        ):
            info["grasp_penalty"] = self.penalty
        else:
            info["grasp_penalty"] = 0.0

        self.last_gripper_pos = observation["observation.state"][0, 0]  # first idx for the gripper
        return observation, reward, terminated, truncated, info


def make_robot_env(cfg) -> gym.vector.VectorEnv:
    """
    Factory function to create a vectorized robot environment.
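Both penalty variants, the new `GripperPenaltyWrapper(gym.RewardWrapper)` added earlier in this diff and the `gym.Wrapper` version shown here, implement the same idea: penalize actions that command the gripper to swing fully against its current state. A self-contained sketch of that condition, assuming state and action are normalized to [0, 1] as in the reward-wrapper version:

```python
# Illustrative sketch of the grasp-penalty condition; names are mine, thresholds follow the reward-wrapper version.
def grasp_penalty(gripper_state: float, gripper_action: float, penalty: float = -0.1) -> float:
    """Return `penalty` if the commanded gripper action opposes the current gripper state, else 0.0."""
    swings_high_while_low = gripper_state < 0.1 and gripper_action > 0.9
    swings_low_while_high = gripper_state > 0.9 and gripper_action < 0.1
    return penalty if (swings_high_while_low or swings_low_while_high) else 0.0


# Examples: grasp_penalty(0.05, 0.95) == -0.1, grasp_penalty(0.5, 0.95) == 0.0
```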
@@ -1143,6 +1164,12 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
    # Add reward computation and control wrappers
    # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
    if cfg.wrapper.use_gripper:
        env = GripperQuantizationWrapper(
            env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
        )
        # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)

    if cfg.wrapper.ee_action_space_params is not None:
        env = EEActionWrapper(
            env=env,
@@ -1169,7 +1196,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
    if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
        env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
    env = BatchCompitableWrapper(env=env)
    env = GripperPenaltyWrapper(env=env)

    return env
@@ -375,7 +375,6 @@ def add_actor_information_and_train(
        observations = batch["state"]
        next_observations = batch["next_state"]
        done = batch["done"]
        complementary_info = batch["complementary_info"]
        check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

        observation_features, next_observation_features = get_observation_features(
@@ -391,7 +390,6 @@ def add_actor_information_and_train(
            "done": done,
            "observation_feature": observation_features,
            "next_observation_feature": next_observation_features,
            "complementary_info": complementary_info,
        }

        # Use the forward method for critic loss
@@ -450,7 +448,6 @@ def add_actor_information_and_train(
            "done": done,
            "observation_feature": observation_features,
            "next_observation_feature": next_observation_features,
            "complementary_info": complementary_info,
        }

        # Use the forward method for critic loss