General fixes in code, removed delta action, fixed grasp penalty, added logic to put gripper reward in info
This commit is contained in:
parent
e7edf2a8d8
commit
5428ab96f5
|
@ -171,7 +171,6 @@ class VideoRecordConfig:
|
|||
class WrapperConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
delta_action: float | None = None
|
||||
joint_masking_action_space: list[bool] | None = None
|
||||
|
||||
|
||||
|
@ -191,7 +190,6 @@ class EnvWrapperConfig:
|
|||
"""Configuration for environment wrappers."""
|
||||
|
||||
display_cameras: bool = False
|
||||
delta_action: float = 0.1
|
||||
use_relative_joint_positions: bool = True
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
|
@ -203,11 +201,13 @@ class EnvWrapperConfig:
|
|||
joint_masking_action_space: Optional[Any] = None
|
||||
ee_action_space_params: Optional[EEActionSpaceConfig] = None
|
||||
use_gripper: bool = False
|
||||
gripper_quantization_threshold: float = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
open_gripper_on_reset: bool = False
|
||||
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
|
|
|
@ -428,6 +428,7 @@ class SACPolicy(
|
|||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
gripper_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
|
||||
|
||||
|
|
|
@ -221,7 +221,6 @@ def record_episode(
|
|||
events=events,
|
||||
policy=policy,
|
||||
fps=fps,
|
||||
# record_delta_actions=record_delta_actions,
|
||||
teleoperate=policy is None,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
@ -267,8 +266,6 @@ def control_loop(
|
|||
|
||||
if teleoperate:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
# if record_delta_actions:
|
||||
# action["action"] = action["action"] - current_joint_positions
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
|
|
|
@ -363,8 +363,6 @@ def replay(
|
|||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"]
|
||||
# if replay_delta_actions:
|
||||
# action = action + current_joint_positions
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
|
|
|
@ -78,9 +78,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
|||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(
|
||||
val, device=device, non_blocking=non_blocking
|
||||
)
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
return transition
|
||||
|
@ -505,7 +503,6 @@ class ReplayBuffer:
|
|||
state_keys: Optional[Sequence[str]] = None,
|
||||
capacity: Optional[int] = None,
|
||||
action_mask: Optional[Sequence[int]] = None,
|
||||
action_delta: Optional[float] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
|
@ -520,7 +517,6 @@ class ReplayBuffer:
|
|||
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
|
||||
capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
|
||||
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
|
||||
action_delta (Optional[float]): Factor to divide actions by.
|
||||
image_augmentation_function (Optional[Callable]): Function for image augmentation.
|
||||
If None, uses default random shift with pad=4.
|
||||
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
||||
|
@ -565,9 +561,6 @@ class ReplayBuffer:
|
|||
else:
|
||||
first_action = first_action[:, action_mask]
|
||||
|
||||
if action_delta is not None:
|
||||
first_action = first_action / action_delta
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
if (
|
||||
|
@ -598,9 +591,6 @@ class ReplayBuffer:
|
|||
else:
|
||||
action = action[:, action_mask]
|
||||
|
||||
if action_delta is not None:
|
||||
action = action / action_delta
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
|
|
|
@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env):
|
|||
self,
|
||||
robot,
|
||||
use_delta_action_space: bool = True,
|
||||
delta: float | None = None,
|
||||
display_cameras: bool = False,
|
||||
):
|
||||
"""
|
||||
|
@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env):
|
|||
robot: The robot interface object used to connect and interact with the physical robot.
|
||||
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
|
||||
joint positions are used.
|
||||
delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
|
||||
0 and 1 when using a delta action space.
|
||||
display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
|
||||
"""
|
||||
super().__init__()
|
||||
|
@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env):
|
|||
self.current_step = 0
|
||||
self.episode_data = None
|
||||
|
||||
self.delta = delta
|
||||
self.use_delta_action_space = use_delta_action_space
|
||||
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
|
@ -374,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
|
|||
self.device = device
|
||||
|
||||
def step(self, action):
|
||||
observation, _, terminated, truncated, info = self.env.step(action)
|
||||
observation, reward, terminated, truncated, info = self.env.step(action)
|
||||
images = [
|
||||
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
for key in observation
|
||||
|
@ -382,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
|
|||
]
|
||||
start_time = time.perf_counter()
|
||||
with torch.inference_mode():
|
||||
reward = (
|
||||
success = (
|
||||
self.reward_classifier.predict_reward(images, threshold=0.8)
|
||||
if self.reward_classifier is not None
|
||||
else 0.0
|
||||
)
|
||||
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
if reward == 1.0:
|
||||
if success == 1.0:
|
||||
terminated = True
|
||||
reward = 1.0
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
|
@ -720,19 +718,31 @@ class ResetWrapper(gym.Wrapper):
|
|||
env: HILSerlRobotEnv,
|
||||
reset_pose: np.ndarray | None = None,
|
||||
reset_time_s: float = 5,
|
||||
open_gripper_on_reset: bool = False
|
||||
):
|
||||
super().__init__(env)
|
||||
self.reset_time_s = reset_time_s
|
||||
self.reset_pose = reset_pose
|
||||
self.robot = self.unwrapped.robot
|
||||
self.open_gripper_on_reset = open_gripper_on_reset
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
|
||||
|
||||
if self.reset_pose is not None:
|
||||
start_time = time.perf_counter()
|
||||
log_say("Reset the environment.", play_sounds=True)
|
||||
reset_follower_position(self.robot, self.reset_pose)
|
||||
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
log_say("Reset the environment done.", play_sounds=True)
|
||||
if self.open_gripper_on_reset:
|
||||
current_joint_pos = self.robot.follower_arms["main"].read("Present_Position")
|
||||
current_joint_pos[-1] = MAX_GRIPPER_COMMAND
|
||||
self.robot.send_action(torch.from_numpy(current_joint_pos))
|
||||
busy_wait(0.1)
|
||||
current_joint_pos[-1] = 0.0
|
||||
self.robot.send_action(torch.from_numpy(current_joint_pos))
|
||||
busy_wait(0.2)
|
||||
else:
|
||||
log_say(
|
||||
f"Manually reset the environment for {self.reset_time_s} seconds.",
|
||||
|
@ -762,37 +772,48 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
|
|||
|
||||
|
||||
class GripperPenaltyWrapper(gym.RewardWrapper):
|
||||
def __init__(self, env, penalty: float = -0.1):
|
||||
def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
|
||||
super().__init__(env)
|
||||
self.penalty = penalty
|
||||
self.gripper_penalty_in_reward = gripper_penalty_in_reward
|
||||
self.last_gripper_state = None
|
||||
|
||||
|
||||
def reward(self, reward, action):
|
||||
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
|
||||
|
||||
if isinstance(action, tuple):
|
||||
action = action[0]
|
||||
action_normalized = action[-1] / MAX_GRIPPER_COMMAND
|
||||
action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND
|
||||
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
|
||||
gripper_state_normalized > 0.9 and action_normalized < 0.1
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and action_normalized < -0.5
|
||||
)
|
||||
breakpoint()
|
||||
|
||||
return reward + self.penalty * gripper_penalty_bool
|
||||
return reward + self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
def step(self, action):
|
||||
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
|
||||
if isinstance(action, tuple):
|
||||
gripper_action = action[0][-1]
|
||||
else:
|
||||
gripper_action = action[-1]
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
reward = self.reward(reward, action)
|
||||
gripper_penalty = self.reward(reward, gripper_action)
|
||||
|
||||
if self.gripper_penalty_in_reward:
|
||||
reward += gripper_penalty
|
||||
else:
|
||||
info["gripper_penalty"] = gripper_penalty
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self.last_gripper_state = None
|
||||
return super().reset(**kwargs)
|
||||
obs, info = super().reset(**kwargs)
|
||||
if self.gripper_penalty_in_reward:
|
||||
info["gripper_penalty"] = 0.0
|
||||
return obs, info
|
||||
|
||||
|
||||
class GripperQuantizationWrapper(gym.ActionWrapper):
|
||||
class GripperActionWrapper(gym.ActionWrapper):
|
||||
def __init__(self, env, quantization_threshold: float = 0.2):
|
||||
super().__init__(env)
|
||||
self.quantization_threshold = quantization_threshold
|
||||
|
@ -801,16 +822,18 @@ class GripperQuantizationWrapper(gym.ActionWrapper):
|
|||
is_intervention = False
|
||||
if isinstance(action, tuple):
|
||||
action, is_intervention = action
|
||||
|
||||
gripper_command = action[-1]
|
||||
# Quantize gripper command to -1, 0 or 1
|
||||
if gripper_command < -self.quantization_threshold:
|
||||
gripper_command = -MAX_GRIPPER_COMMAND
|
||||
elif gripper_command > self.quantization_threshold:
|
||||
gripper_command = MAX_GRIPPER_COMMAND
|
||||
else:
|
||||
gripper_command = 0.0
|
||||
|
||||
# Gripper actions are between 0, 2
|
||||
# we want to quantize them to -1, 0 or 1
|
||||
gripper_command = gripper_command - 1.0
|
||||
|
||||
if self.quantization_threshold is not None:
|
||||
# Quantize gripper command to -1, 0 or 1
|
||||
gripper_command = (
|
||||
np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
|
||||
)
|
||||
gripper_command = gripper_command * MAX_GRIPPER_COMMAND
|
||||
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
|
||||
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
|
||||
action[-1] = gripper_action.item()
|
||||
|
@ -836,10 +859,12 @@ class EEActionWrapper(gym.ActionWrapper):
|
|||
]
|
||||
)
|
||||
if self.use_gripper:
|
||||
action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
|
||||
# gripper actions open at 2.0, and closed at 0.0
|
||||
min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
|
||||
max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
|
||||
ee_action_space = gym.spaces.Box(
|
||||
low=-action_space_bounds,
|
||||
high=action_space_bounds,
|
||||
low=min_action_space_bounds,
|
||||
high=max_action_space_bounds,
|
||||
shape=(3 + int(self.use_gripper),),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
@ -997,11 +1022,11 @@ class GamepadControlWrapper(gym.Wrapper):
|
|||
if self.use_gripper:
|
||||
gripper_command = self.controller.gripper_command()
|
||||
if gripper_command == "open":
|
||||
gamepad_action = np.concatenate([gamepad_action, [1.0]])
|
||||
gamepad_action = np.concatenate([gamepad_action, [2.0]])
|
||||
elif gripper_command == "close":
|
||||
gamepad_action = np.concatenate([gamepad_action, [-1.0]])
|
||||
else:
|
||||
gamepad_action = np.concatenate([gamepad_action, [0.0]])
|
||||
else:
|
||||
gamepad_action = np.concatenate([gamepad_action, [1.0]])
|
||||
|
||||
# Check episode ending buttons
|
||||
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
|
||||
|
@ -1141,7 +1166,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
|||
env = HILSerlRobotEnv(
|
||||
robot=robot,
|
||||
display_cameras=cfg.wrapper.display_cameras,
|
||||
delta=cfg.wrapper.delta_action,
|
||||
use_delta_action_space=cfg.wrapper.use_relative_joint_positions
|
||||
and cfg.wrapper.ee_action_space_params is None,
|
||||
)
|
||||
|
@ -1165,10 +1189,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
|||
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
||||
if cfg.wrapper.use_gripper:
|
||||
env = GripperQuantizationWrapper(
|
||||
env = GripperActionWrapper(
|
||||
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
||||
)
|
||||
# env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
|
||||
if cfg.wrapper.gripper_penalty is not None:
|
||||
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
|
||||
|
||||
if cfg.wrapper.ee_action_space_params is not None:
|
||||
env = EEActionWrapper(
|
||||
|
@ -1176,6 +1201,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
|||
ee_action_space_params=cfg.wrapper.ee_action_space_params,
|
||||
use_gripper=cfg.wrapper.use_gripper,
|
||||
)
|
||||
|
||||
if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
|
||||
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
|
||||
env = GamepadControlWrapper(
|
||||
|
@ -1192,6 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
|||
env=env,
|
||||
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
|
||||
reset_time_s=cfg.wrapper.reset_time_s,
|
||||
open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset
|
||||
)
|
||||
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
|
||||
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
|
||||
|
@ -1341,11 +1368,10 @@ def record_dataset(env, policy, cfg):
|
|||
dataset.push_to_hub()
|
||||
|
||||
|
||||
def replay_episode(env, repo_id, root=None, episode=0):
|
||||
def replay_episode(env, cfg):
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
local_files_only = root is not None
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
|
||||
env.reset()
|
||||
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
@ -1353,7 +1379,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
|
|||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"][:4]
|
||||
action = actions[idx]["action"]
|
||||
env.step((action, False))
|
||||
# env.step((action / env.unwrapped.delta, False))
|
||||
|
||||
|
@ -1384,9 +1410,7 @@ def main(cfg: EnvConfig):
|
|||
if cfg.mode == "replay":
|
||||
replay_episode(
|
||||
env,
|
||||
cfg.replay_repo_id,
|
||||
root=cfg.dataset_root,
|
||||
episode=cfg.replay_episode,
|
||||
cfg=cfg,
|
||||
)
|
||||
exit()
|
||||
|
||||
|
|
|
@ -406,7 +406,8 @@ def add_actor_information_and_train(
|
|||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss (includes both main critic and grasp critic)
|
||||
|
@ -992,7 +993,6 @@ def initialize_offline_replay_buffer(
|
|||
device=device,
|
||||
state_keys=cfg.policy.input_features.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
capacity=cfg.policy.offline_buffer_capacity,
|
||||
|
|
Loading…
Reference in New Issue