Added Gripper quantization wrapper and grasp penalty

Removed complementary_info handling from the replay buffer and the learner server, removed the get_gripper_action function, and added gripper parameters to `common/envs/configs.py`.
parent 7983baf4fc
commit fe2ff516a8
@@ -203,6 +203,9 @@ class EnvWrapperConfig:
     joint_masking_action_space: Optional[Any] = None
     ee_action_space_params: Optional[EEActionSpaceConfig] = None
     use_gripper: bool = False
+    gripper_quantization_threshold: float = 0.8
+    gripper_penalty: float = 0.0
+    open_gripper_on_reset: bool = False
 
 
 @EnvConfig.register_subclass(name="gym_manipulator")
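For illustration, a minimal stand-alone sketch of the new gripper options. The dataclass below only mirrors the fields added above (names and defaults taken from the diff); it is not the real EnvWrapperConfig, and the example values are hypothetical:

# Illustrative sketch only: mirrors the gripper-related fields added to
# EnvWrapperConfig above; the rest of the config class is omitted.
from dataclasses import dataclass


@dataclass
class GripperOptions:
    use_gripper: bool = False
    gripper_quantization_threshold: float = 0.8
    gripper_penalty: float = 0.0
    open_gripper_on_reset: bool = False


# Example: enable the gripper and use a negative penalty for mismatched commands.
options = GripperOptions(use_gripper=True, gripper_penalty=-0.1)
print(options)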
@@ -245,10 +245,6 @@ class ReplayBuffer:
         self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
 
-        # Initialize complementary_info storage
-        self.complementary_info_keys = []
-        self.complementary_info_storage = {}
-
         self.initialized = True
 
     def __len__(self):

@@ -282,28 +278,6 @@ class ReplayBuffer:
         self.dones[self.position] = done
         self.truncateds[self.position] = truncated
 
-        # Store complementary info if provided
-        if complementary_info is not None:
-            # Initialize storage for new keys on first encounter
-            for key, value in complementary_info.items():
-                if key not in self.complementary_info_keys:
-                    self.complementary_info_keys.append(key)
-                    if isinstance(value, torch.Tensor):
-                        shape = value.shape if value.ndim > 0 else (1,)
-                        self.complementary_info_storage[key] = torch.zeros(
-                            (self.capacity, *shape), dtype=value.dtype, device=self.storage_device
-                        )
-
-                # Store the value
-                if key in self.complementary_info_storage:
-                    if isinstance(value, torch.Tensor):
-                        self.complementary_info_storage[key][self.position] = value
-                    else:
-                        # For non-tensor values (like grasp_penalty)
-                        self.complementary_info_storage[key][self.position] = torch.tensor(
-                            value, device=self.storage_device
-                        )
-
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
 

@@ -362,13 +336,6 @@ class ReplayBuffer:
         batch_dones = self.dones[idx].to(self.device).float()
         batch_truncateds = self.truncateds[idx].to(self.device).float()
 
-        # Add complementary_info to batch if it exists
-        batch_complementary_info = {}
-        if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
-            for key in self.complementary_info_keys:
-                if key in self.complementary_info_storage:
-                    batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
-
         return BatchTransition(
             state=batch_state,
             action=batch_actions,

@@ -376,7 +343,6 @@ class ReplayBuffer:
             next_state=batch_next_state,
             done=batch_dones,
             truncated=batch_truncateds,
-            complementary_info=batch_complementary_info if batch_complementary_info else None,
         )
 
     @classmethod
@@ -312,31 +312,6 @@ class GamepadController(InputController):
             logging.error("Error reading gamepad. Is it still connected?")
             return 0.0, 0.0, 0.0
 
-    def get_gripper_action(self):
-        """
-        Get gripper action using L3/R3 buttons.
-        Press left stick (L3) to open the gripper.
-        Press right stick (R3) to close the gripper.
-        """
-        import pygame
-
-        try:
-            # Check if buttons are pressed
-            l3_pressed = self.joystick.get_button(9)
-            r3_pressed = self.joystick.get_button(10)
-
-            # Determine action based on button presses
-            if r3_pressed:
-                return 1.0  # Close gripper
-            elif l3_pressed:
-                return -1.0  # Open gripper
-            else:
-                return 0.0  # No change
-
-        except pygame.error:
-            logging.error("Error reading gamepad. Is it still connected?")
-            return 0.0
-
 
 class GamepadControllerHID(InputController):
     """Generate motion deltas from gamepad input using HIDAPI."""
@@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
         return observation
 
 
+class GripperPenaltyWrapper(gym.RewardWrapper):
+    def __init__(self, env, penalty: float = -0.1):
+        super().__init__(env)
+        self.penalty = penalty
+        self.last_gripper_state = None
+
+    def reward(self, reward, action):
+        gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
+
+        if isinstance(action, tuple):
+            action = action[0]
+        action_normalized = action[-1] / MAX_GRIPPER_COMMAND
+
+        gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
+            gripper_state_normalized > 0.9 and action_normalized < 0.1
+        )
+        breakpoint()
+
+        return reward + self.penalty * gripper_penalty_bool
+
+    def step(self, action):
+        self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        reward = self.reward(reward, action)
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, **kwargs):
+        self.last_gripper_state = None
+        return super().reset(**kwargs)
+
+
+class GripperQuantizationWrapper(gym.ActionWrapper):
+    def __init__(self, env, quantization_threshold: float = 0.2):
+        super().__init__(env)
+        self.quantization_threshold = quantization_threshold
+
+    def action(self, action):
+        is_intervention = False
+        if isinstance(action, tuple):
+            action, is_intervention = action
+
+        gripper_command = action[-1]
+        # Quantize gripper command to -1, 0 or 1
+        if gripper_command < -self.quantization_threshold:
+            gripper_command = -MAX_GRIPPER_COMMAND
+        elif gripper_command > self.quantization_threshold:
+            gripper_command = MAX_GRIPPER_COMMAND
+        else:
+            gripper_command = 0.0
+
+        gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
+        action[-1] = gripper_action.item()
+        return action, is_intervention
+
+
 class EEActionWrapper(gym.ActionWrapper):
     def __init__(self, env, ee_action_space_params=None, use_gripper=False):
         super().__init__(env)
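The two wrappers above depend on a live robot connection. The following hardware-free sketch reproduces only their numeric logic: quantize_gripper mirrors the snap-then-clip behavior of GripperQuantizationWrapper.action, and grasp_penalty mirrors the mismatch test in GripperPenaltyWrapper.reward. MAX_GRIPPER_COMMAND is not defined in this diff, so the value used here is a placeholder assumption, and both helper functions are illustrative, not part of the commit:

# Hardware-free sketch of the gripper logic added above.
import numpy as np

MAX_GRIPPER_COMMAND = 100.0  # placeholder assumption, not taken from this commit


def quantize_gripper(gripper_command: float, gripper_state: float, threshold: float = 0.2) -> float:
    """Snap the raw command to -MAX/0/+MAX, then clip the resulting target to the valid range."""
    if gripper_command < -threshold:
        delta = -MAX_GRIPPER_COMMAND
    elif gripper_command > threshold:
        delta = MAX_GRIPPER_COMMAND
    else:
        delta = 0.0
    return float(np.clip(gripper_state + delta, 0, MAX_GRIPPER_COMMAND))


def grasp_penalty(gripper_state: float, gripper_command: float, penalty: float = -0.1) -> float:
    """Return the penalty term added to the reward when the command sits at the opposite extreme from the current state."""
    state_n = gripper_state / MAX_GRIPPER_COMMAND
    action_n = gripper_command / MAX_GRIPPER_COMMAND
    mismatch = (state_n < 0.1 and action_n > 0.9) or (state_n > 0.9 and action_n < 0.1)
    return penalty * mismatch


print(quantize_gripper(0.9, gripper_state=10.0))   # 100.0: above threshold, snapped to +MAX, clipped
print(quantize_gripper(-0.5, gripper_state=10.0))  # 0.0: below -threshold, snapped to -MAX, clipped at 0
print(quantize_gripper(0.1, gripper_state=10.0))   # 10.0: within threshold, target stays at current state
print(grasp_penalty(gripper_state=95.0, gripper_command=2.0))  # -0.1: state near max, command near min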
@@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper):
             fk_func=self.fk_function,
         )
         if self.use_gripper:
-            # Quantize gripper command to -1, 0 or 1
-            if gripper_command < -0.2:
-                gripper_command = -1.0
-            elif gripper_command > 0.2:
-                gripper_command = 1.0
-            else:
-                gripper_command = 0.0
-
-            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
-            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
-            target_joint_pos[-1] = gripper_action
+            target_joint_pos[-1] = gripper_command
 
         return target_joint_pos, is_intervention
 
@@ -1069,31 +1115,6 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention
 
 
-class GripperPenaltyWrapper(gym.Wrapper):
-    def __init__(self, env, penalty=-0.05):
-        super().__init__(env)
-        self.penalty = penalty
-        self.last_gripper_pos = None
-
-    def reset(self, **kwargs):
-        obs, info = self.env.reset(**kwargs)
-        self.last_gripper_pos = obs["observation.state"][0, 0]  # first idx for the gripper
-        return obs, info
-
-    def step(self, action):
-        observation, reward, terminated, truncated, info = self.env.step(action)
-
-        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (
-            action[-1] > 0.5 and self.last_gripper_pos < 0.9
-        ):
-            info["grasp_penalty"] = self.penalty
-        else:
-            info["grasp_penalty"] = 0.0
-
-        self.last_gripper_pos = observation["observation.state"][0, 0]  # first idx for the gripper
-        return observation, reward, terminated, truncated, info
-
-
 def make_robot_env(cfg) -> gym.vector.VectorEnv:
     """
     Factory function to create a vectorized robot environment.

@@ -1143,6 +1164,12 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # Add reward computation and control wrappers
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    if cfg.wrapper.use_gripper:
+        env = GripperQuantizationWrapper(
+            env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
+        )
+        # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
+
     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(
             env=env,

@@ -1169,7 +1196,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
-    env = GripperPenaltyWrapper(env=env)
 
     return env
 
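As a toy illustration of the conditional wrapper stacking used in make_robot_env above, the sketch below swaps in a standard gymnasium environment for the robot env; ToyQuantizeWrapper is a stand-in for GripperQuantizationWrapper (which requires a robot connection) and everything here is hypothetical scaffolding, not the commit's code:

# Toy, hardware-free illustration of conditional wrapper stacking.
import gymnasium as gym
import numpy as np


class ToyQuantizeWrapper(gym.ActionWrapper):
    """Stand-in wrapper: snap the last action dimension to -1/0/+1 using the same threshold logic."""

    def __init__(self, env, quantization_threshold: float = 0.2):
        super().__init__(env)
        self.quantization_threshold = quantization_threshold

    def action(self, action):
        action = np.asarray(action, dtype=np.float32).copy()
        last = action[-1]
        if last < -self.quantization_threshold:
            action[-1] = -1.0
        elif last > self.quantization_threshold:
            action[-1] = 1.0
        else:
            action[-1] = 0.0
        return action


def make_toy_env(use_gripper: bool) -> gym.Env:
    env = gym.make("Pendulum-v1")
    if use_gripper:  # mirrors the `if cfg.wrapper.use_gripper:` branch above
        env = ToyQuantizeWrapper(env, quantization_threshold=0.2)
    return env


env = make_toy_env(use_gripper=True)
obs, info = env.reset(seed=0)
# The outer ActionWrapper quantizes 0.5 -> 1.0 before the base env sees it.
obs, reward, terminated, truncated, info = env.step(np.array([0.5], dtype=np.float32))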
@@ -375,7 +375,6 @@ def add_actor_information_and_train(
        observations = batch["state"]
        next_observations = batch["next_state"]
        done = batch["done"]
-       complementary_info = batch["complementary_info"]
        check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
        observation_features, next_observation_features = get_observation_features(

@@ -391,7 +390,6 @@ def add_actor_information_and_train(
            "done": done,
            "observation_feature": observation_features,
            "next_observation_feature": next_observation_features,
-           "complementary_info": complementary_info,
        }
 
        # Use the forward method for critic loss

@@ -450,7 +448,6 @@ def add_actor_information_and_train(
            "done": done,
            "observation_feature": observation_features,
            "next_observation_feature": next_observation_features,
-           "complementary_info": complementary_info,
        }
 
        # Use the forward method for critic loss