Added gripper quantization wrapper and grasp penalty

Removed complementary_info from the replay buffer and learner server
Removed the get_gripper_action function from the gamepad controller
Added gripper parameters to `common/envs/configs.py`
Michel Aractingi 2025-04-01 11:08:15 +02:00
parent 7983baf4fc
commit fe2ff516a8
5 changed files with 66 additions and 99 deletions

View File

@@ -203,6 +203,9 @@ class EnvWrapperConfig:
    joint_masking_action_space: Optional[Any] = None
    ee_action_space_params: Optional[EEActionSpaceConfig] = None
    use_gripper: bool = False
    gripper_quantization_threshold: float = 0.8
    gripper_penalty: float = 0.0
    open_gripper_on_reset: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
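For reference, a minimal sketch of how the new gripper fields might be set on this config. The import path is an assumption based on the commit message (`common/envs/configs.py`), and the penalty value is only illustrative:

# Sketch: enabling the new gripper options on EnvWrapperConfig.
# Import path assumed from the commit message; all other fields keep their defaults.
from lerobot.common.envs.configs import EnvWrapperConfig

wrapper_cfg = EnvWrapperConfig(
    use_gripper=True,
    gripper_quantization_threshold=0.8,  # |gripper command| above this snaps to a full open/close step
    gripper_penalty=-0.05,               # illustrative value; the default above is 0.0
    open_gripper_on_reset=False,
)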

View File

@@ -245,10 +245,6 @@ class ReplayBuffer:
        self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
        self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)

        # Initialize complementary_info storage
        self.complementary_info_keys = []
        self.complementary_info_storage = {}

        self.initialized = True

    def __len__(self):
@@ -282,28 +278,6 @@ class ReplayBuffer:
        self.dones[self.position] = done
        self.truncateds[self.position] = truncated

        # Store complementary info if provided
        if complementary_info is not None:
            # Initialize storage for new keys on first encounter
            for key, value in complementary_info.items():
                if key not in self.complementary_info_keys:
                    self.complementary_info_keys.append(key)
                    if isinstance(value, torch.Tensor):
                        shape = value.shape if value.ndim > 0 else (1,)
                        self.complementary_info_storage[key] = torch.zeros(
                            (self.capacity, *shape), dtype=value.dtype, device=self.storage_device
                        )
                # Store the value
                if key in self.complementary_info_storage:
                    if isinstance(value, torch.Tensor):
                        self.complementary_info_storage[key][self.position] = value
                    else:
                        # For non-tensor values (like grasp_penalty)
                        self.complementary_info_storage[key][self.position] = torch.tensor(
                            value, device=self.storage_device
                        )

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
@@ -362,13 +336,6 @@ class ReplayBuffer:
        batch_dones = self.dones[idx].to(self.device).float()
        batch_truncateds = self.truncateds[idx].to(self.device).float()

        # Add complementary_info to batch if it exists
        batch_complementary_info = {}
        if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
            for key in self.complementary_info_keys:
                if key in self.complementary_info_storage:
                    batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)

        return BatchTransition(
            state=batch_state,
            action=batch_actions,
@@ -376,7 +343,6 @@
            next_state=batch_next_state,
            done=batch_dones,
            truncated=batch_truncateds,
            complementary_info=batch_complementary_info if batch_complementary_info else None,
        )

    @classmethod
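With complementary_info gone from both storage and sampling, a sampled batch now carries only the core transition fields. A rough sketch of the consumer side, assuming a `sample(batch_size)` call and the key names visible in the `BatchTransition` construction above (`reward` is assumed to be among the fields elided by the hunk):

# Sketch: consuming a sampled batch after this change.
# `buffer` is an already-populated ReplayBuffer; the batch size is illustrative.
batch = buffer.sample(batch_size=256)

observations = batch["state"]          # dict of observation tensors
actions = batch["action"]
rewards = batch["reward"]              # assumed field, not shown in the truncated hunk
next_observations = batch["next_state"]
dones = batch["done"]
truncateds = batch["truncated"]
# batch["complementary_info"] no longer exists; the grasp penalty is folded into the reward by the env wrapper.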

View File

@@ -312,31 +312,6 @@ class GamepadController(InputController):
            logging.error("Error reading gamepad. Is it still connected?")
            return 0.0, 0.0, 0.0

    def get_gripper_action(self):
        """
        Get gripper action using L3/R3 buttons.
        Press left stick (L3) to open the gripper.
        Press right stick (R3) to close the gripper.
        """
        import pygame

        try:
            # Check if buttons are pressed
            l3_pressed = self.joystick.get_button(9)
            r3_pressed = self.joystick.get_button(10)

            # Determine action based on button presses
            if r3_pressed:
                return 1.0  # Close gripper
            elif l3_pressed:
                return -1.0  # Open gripper
            else:
                return 0.0  # No change
        except pygame.error:
            logging.error("Error reading gamepad. Is it still connected?")
            return 0.0


class GamepadControllerHID(InputController):
    """Generate motion deltas from gamepad input using HIDAPI."""

View File

@@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
        return observation

class GripperPenaltyWrapper(gym.RewardWrapper):
    def __init__(self, env, penalty: float = -0.1):
        super().__init__(env)
        self.penalty = penalty
        self.last_gripper_state = None

    def reward(self, reward, action):
        gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND

        if isinstance(action, tuple):
            action = action[0]
        action_normalized = action[-1] / MAX_GRIPPER_COMMAND

        gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
            gripper_state_normalized > 0.9 and action_normalized < 0.1
        )

        return reward + self.penalty * gripper_penalty_bool

    def step(self, action):
        self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
        obs, reward, terminated, truncated, info = self.env.step(action)
        reward = self.reward(reward, action)
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        self.last_gripper_state = None
        return super().reset(**kwargs)


class GripperQuantizationWrapper(gym.ActionWrapper):
    def __init__(self, env, quantization_threshold: float = 0.2):
        super().__init__(env)
        self.quantization_threshold = quantization_threshold

    def action(self, action):
        is_intervention = False
        if isinstance(action, tuple):
            action, is_intervention = action

        gripper_command = action[-1]
        # Quantize gripper command to -MAX_GRIPPER_COMMAND, 0 or +MAX_GRIPPER_COMMAND
        if gripper_command < -self.quantization_threshold:
            gripper_command = -MAX_GRIPPER_COMMAND
        elif gripper_command > self.quantization_threshold:
            gripper_command = MAX_GRIPPER_COMMAND
        else:
            gripper_command = 0.0

        gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
        gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
        action[-1] = gripper_action.item()
        return action, is_intervention
class EEActionWrapper(gym.ActionWrapper):
    def __init__(self, env, ee_action_space_params=None, use_gripper=False):
        super().__init__(env)
@@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper):
            fk_func=self.fk_function,
        )

        if self.use_gripper:
            # Quantize gripper command to -1, 0 or 1
            if gripper_command < -0.2:
                gripper_command = -1.0
            elif gripper_command > 0.2:
                gripper_command = 1.0
            else:
                gripper_command = 0.0

            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
            target_joint_pos[-1] = gripper_action
            target_joint_pos[-1] = gripper_command

        return target_joint_pos, is_intervention
@@ -1069,31 +1115,6 @@ class ActionScaleWrapper(gym.ActionWrapper):
        return action * self.scale_vector, is_intervention


class GripperPenaltyWrapper(gym.Wrapper):
    def __init__(self, env, penalty=-0.05):
        super().__init__(env)
        self.penalty = penalty
        self.last_gripper_pos = None

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.last_gripper_pos = obs["observation.state"][0, 0]  # first idx for the gripper
        return obs, info

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)

        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (
            action[-1] > 0.5 and self.last_gripper_pos < 0.9
        ):
            info["grasp_penalty"] = self.penalty
        else:
            info["grasp_penalty"] = 0.0

        self.last_gripper_pos = observation["observation.state"][0, 0]  # first idx for the gripper
        return observation, reward, terminated, truncated, info


def make_robot_env(cfg) -> gym.vector.VectorEnv:
    """
    Factory function to create a vectorized robot environment.
@@ -1143,6 +1164,12 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
    # Add reward computation and control wrappers
    # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
    if cfg.wrapper.use_gripper:
        env = GripperQuantizationWrapper(
            env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
        )
        # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
    if cfg.wrapper.ee_action_space_params is not None:
        env = EEActionWrapper(
            env=env,
@@ -1169,7 +1196,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
    if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
        env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
    env = BatchCompitableWrapper(env=env)
    env = GripperPenaltyWrapper(env=env)

    return env
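In `make_robot_env`, only the quantization wrapper is wired in when `use_gripper` is set; the reward-side `GripperPenaltyWrapper` stays commented out. The sketch below replays the thresholding and penalty rules outside of gym and the robot API to make the intended behavior concrete; `MAX_GRIPPER_COMMAND` is given a placeholder value and the helper names are hypothetical:

import numpy as np

MAX_GRIPPER_COMMAND = 2.0  # placeholder; the real constant is defined in the wrapper module


def quantize_gripper_command(command: float, current_pos: float, threshold: float = 0.2) -> float:
    """Mirror of GripperQuantizationWrapper.action for a single scalar gripper command."""
    if command < -threshold:
        delta = -MAX_GRIPPER_COMMAND
    elif command > threshold:
        delta = MAX_GRIPPER_COMMAND
    else:
        delta = 0.0
    return float(np.clip(current_pos + delta, 0, MAX_GRIPPER_COMMAND))


def grasp_penalty(current_pos: float, commanded_pos: float, penalty: float = -0.1) -> float:
    """Mirror of GripperPenaltyWrapper.reward: penalize commanding the gripper to the opposite extreme."""
    state_n = current_pos / MAX_GRIPPER_COMMAND
    action_n = commanded_pos / MAX_GRIPPER_COMMAND
    violated = (state_n < 0.1 and action_n > 0.9) or (state_n > 0.9 and action_n < 0.1)
    return penalty * violated


# Example: gripper sitting at one limit (pos 0.0) while the policy commands 0.95.
target = quantize_gripper_command(command=0.95, current_pos=0.0)
print(target, grasp_penalty(current_pos=0.0, commanded_pos=target))  # -> 2.0 -0.1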

View File

@@ -375,7 +375,6 @@ def add_actor_information_and_train(
        observations = batch["state"]
        next_observations = batch["next_state"]
        done = batch["done"]
        complementary_info = batch["complementary_info"]

        check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

        observation_features, next_observation_features = get_observation_features(
@@ -391,7 +390,6 @@ def add_actor_information_and_train(
            "done": done,
            "observation_feature": observation_features,
            "next_observation_feature": next_observation_features,
            "complementary_info": complementary_info,
        }

        # Use the forward method for critic loss
@@ -450,7 +448,6 @@ def add_actor_information_and_train(
            "done": done,
            "observation_feature": observation_features,
            "next_observation_feature": next_observation_features,
            "complementary_info": complementary_info,
        }

        # Use the forward method for critic loss
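Because `GripperPenaltyWrapper` adds the penalty directly to the environment's step reward, the learner no longer needs a complementary_info side channel, and the batch fed to the critic keeps only the standard keys. A rough sketch of the resulting dict, mirroring the hunks above (the state/action/reward key names are assumed from the surrounding variable names):

# Sketch: the batch passed to the critic's forward method after this change.
forward_batch = {
    "state": observations,
    "action": actions,
    "reward": rewards,                 # grasp penalty is already included here
    "next_state": next_observations,
    "done": done,
    "observation_feature": observation_features,
    "next_observation_feature": next_observation_features,
    # no "complementary_info" entry anymore
}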