modifications to gym_manipulator and buffer
parent ab2c2d39fb
commit f3cea2a3e5
@@ -171,7 +171,6 @@ class VideoRecordConfig:
 class WrapperConfig:
     """Configuration for environment wrappers."""
 
-    delta_action: float | None = None
     joint_masking_action_space: list[bool] | None = None
 
 
@@ -191,7 +190,6 @@ class EnvWrapperConfig:
     """Configuration for environment wrappers."""
 
     display_cameras: bool = False
-    delta_action: float = 0.1
     use_relative_joint_positions: bool = True
     add_joint_velocity_to_observation: bool = False
     add_ee_pose_to_observation: bool = False
@@ -203,7 +201,7 @@ class EnvWrapperConfig:
     joint_masking_action_space: Optional[Any] = None
     ee_action_space_params: Optional[EEActionSpaceConfig] = None
     use_gripper: bool = False
-    gripper_quantization_threshold: float = 0.8
+    gripper_quantization_threshold: float | None = None
     gripper_penalty: float = 0.0
     open_gripper_on_reset: bool = False
 
@@ -144,8 +144,8 @@ class SACPolicy(
 
         self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
 
-        self.grasp_critic = torch.compile(self.grasp_critic)
-        self.grasp_critic_target = torch.compile(self.grasp_critic_target)
+        # self.grasp_critic = torch.compile(self.grasp_critic)
+        # self.grasp_critic_target = torch.compile(self.grasp_critic_target)
 
         self.actor = Policy(
             encoder=encoder_actor,
@@ -224,6 +224,10 @@ class SACPolicy(
 
         critics = self.critic_target if use_target else self.critic_ensemble
         q_values = critics(observations, actions, observation_features)
+        if not use_target:
+            for name, param in critics.named_parameters():
+                if param.requires_grad:
+                    print(f"Critic Ensemble layer {name}, norm {param.data.norm().item()}")
         return q_values
 
     def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
@@ -239,6 +243,10 @@ class SACPolicy(
         """
         grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
         q_values = grasp_critic(observations, observation_features)
+        if not use_target:
+            for name, param in grasp_critic.named_parameters():
+                if param.requires_grad:
+                    print(f"Grasp critic layer {name}, norm {param.data.norm().item()}")
         return q_values
 
     def forward(
@@ -221,7 +221,6 @@ def record_episode(
         events=events,
         policy=policy,
         fps=fps,
-        # record_delta_actions=record_delta_actions,
         teleoperate=policy is None,
         single_task=single_task,
     )
@@ -267,8 +266,6 @@ def control_loop(
 
         if teleoperate:
             observation, action = robot.teleop_step(record_data=True)
-            # if record_delta_actions:
-            #     action["action"] = action["action"] - current_joint_positions
         else:
             observation = robot.capture_observation()
 
@@ -250,28 +250,18 @@ def act_with_policy(
             logging.info("[ACTOR] Shutting down act_with_policy")
             return
 
-        if interaction_step >= cfg.policy.online_step_before_learning:
-            # Time policy inference and check if it meets FPS requirement
-            with TimerManager(
-                elapsed_time_list=list_policy_time,
-                label="Policy inference time",
-                log=False,
-            ) as timer: # noqa: F841
-                action = policy.select_action(batch=obs)
-            policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
-
-            log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
-
-            next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
-        else:
-            # TODO (azouitine): Make a custom space for torch tensor
-            action = online_env.action_space.sample()
-            next_obs, reward, done, truncated, info = online_env.step(action)
-
-            # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
-            action = (
-                torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
-            )
+        # Time policy inference and check if it meets FPS requirement
+        with TimerManager(
+            elapsed_time_list=list_policy_time,
+            label="Policy inference time",
+            log=False,
+        ) as timer: # noqa: F841
+            action = policy.select_action(batch=obs)
+        policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
+
+        log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
+
+        next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
 
         sum_reward_episode += float(reward)
         # Increment total steps counter for intervention rate
@@ -505,7 +505,6 @@ class ReplayBuffer:
         state_keys: Optional[Sequence[str]] = None,
         capacity: Optional[int] = None,
         action_mask: Optional[Sequence[int]] = None,
-        action_delta: Optional[float] = None,
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
@@ -520,7 +519,6 @@ class ReplayBuffer:
             state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
             capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
             action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
-            action_delta (Optional[float]): Factor to divide actions by.
             image_augmentation_function (Optional[Callable]): Function for image augmentation.
                 If None, uses default random shift with pad=4.
             use_drq (bool): Whether to use DrQ image augmentation when sampling.
@@ -565,8 +563,6 @@ class ReplayBuffer:
         else:
             first_action = first_action[:, action_mask]
 
-        if action_delta is not None:
-            first_action = first_action / action_delta
 
         # Get complementary info if available
         first_complementary_info = None
@@ -598,8 +594,6 @@ class ReplayBuffer:
             else:
                 action = action[:, action_mask]
 
-            if action_delta is not None:
-                action = action / action_delta
 
             replay_buffer.add(
                 state=data["state"],
@@ -258,25 +258,25 @@ class GamepadController(InputController):
                 elif event.button == 0:
                     self.episode_end_status = "rerecord_episode"
 
-                # RB button (6) for opening gripper
+                # LT button for closing gripper
                 elif event.button == 6:
-                    self.open_gripper_command = True
-
-                # LT button (7) for closing gripper
-                elif event.button == 7:
                     self.close_gripper_command = True
 
+                # RB button for opening gripper
+                elif event.button == 7:
+                    self.open_gripper_command = True
+
             # Reset episode status on button release
             elif event.type == pygame.JOYBUTTONUP:
                 if event.button in [0, 2, 3]:
                     self.episode_end_status = None
 
-                elif event.button == 6:
-                    self.open_gripper_command = False
+                if event.button == 6:
+                    self.close_gripper_command = False
 
-                elif event.button == 7:
-                    self.close_gripper_command = False
+                if event.button == 7:
+                    self.open_gripper_command = False
 
         # Check for RB button (typically button 5) for intervention flag
         if self.joystick.get_button(5):
             self.intervention_flag = True
@@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env):
         self,
         robot,
         use_delta_action_space: bool = True,
-        delta: float | None = None,
         display_cameras: bool = False,
     ):
         """
@@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env):
             robot: The robot interface object used to connect and interact with the physical robot.
             use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
                 joint positions are used.
-            delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
-                0 and 1 when using a delta action space.
             display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
         """
         super().__init__()
@@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env):
         self.current_step = 0
         self.episode_data = None
 
-        self.delta = delta
         self.use_delta_action_space = use_delta_action_space
         self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
 
@@ -555,6 +551,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
             # TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1]
             obs[k] = obs[k].clamp(0.0, 1.0)
 
+            # import cv2
+            # cv2.imwrite(f"tmp_img/{k}.jpg", cv2.cvtColor(obs[k].squeeze(0).permute(1,2,0).cpu().numpy()*255, cv2.COLOR_RGB2BGR))
+
             # Check for NaNs after processing
             if torch.isnan(obs[k]).any():
                 logging.error(f"NaN values detected in observation {k} after crop and resize")
@@ -720,19 +719,31 @@ class ResetWrapper(gym.Wrapper):
         env: HILSerlRobotEnv,
         reset_pose: np.ndarray | None = None,
         reset_time_s: float = 5,
+        open_gripper_on_reset: bool = False
     ):
         super().__init__(env)
         self.reset_time_s = reset_time_s
         self.reset_pose = reset_pose
         self.robot = self.unwrapped.robot
+        self.open_gripper_on_reset = open_gripper_on_reset
 
     def reset(self, *, seed=None, options=None):
         if self.reset_pose is not None:
             start_time = time.perf_counter()
             log_say("Reset the environment.", play_sounds=True)
             reset_follower_position(self.robot, self.reset_pose)
             busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
             log_say("Reset the environment done.", play_sounds=True)
+
+            if self.open_gripper_on_reset:
+                current_joint_pos = self.robot.follower_arms["main"].read("Present_Position")
+                current_joint_pos[-1] = MAX_GRIPPER_COMMAND
+                self.robot.send_action(torch.from_numpy(current_joint_pos))
+                busy_wait(0.1)
+                current_joint_pos[-1] = 0.0
+                self.robot.send_action(torch.from_numpy(current_joint_pos))
+                busy_wait(0.2)
         else:
             log_say(
                 f"Manually reset the environment for {self.reset_time_s} seconds.",
@@ -777,7 +788,6 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
         gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
             gripper_state_normalized > 0.9 and action_normalized < 0.1
         )
-        breakpoint()
 
         return reward + self.penalty * gripper_penalty_bool
 
@@ -791,8 +801,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
         self.last_gripper_state = None
         return super().reset(**kwargs)
 
-
-class GripperQuantizationWrapper(gym.ActionWrapper):
+class GripperActionWrapper(gym.ActionWrapper):
     def __init__(self, env, quantization_threshold: float = 0.2):
         super().__init__(env)
         self.quantization_threshold = quantization_threshold
@@ -801,16 +810,18 @@ class GripperQuantizationWrapper(gym.ActionWrapper):
         is_intervention = False
         if isinstance(action, tuple):
             action, is_intervention = action
 
         gripper_command = action[-1]
-        # Quantize gripper command to -1, 0 or 1
-        if gripper_command < -self.quantization_threshold:
-            gripper_command = -MAX_GRIPPER_COMMAND
-        elif gripper_command > self.quantization_threshold:
-            gripper_command = MAX_GRIPPER_COMMAND
-        else:
-            gripper_command = 0.0
+        # Gripper actions are between 0, 2
+        # we want to quantize them to -1, 0 or 1
+        gripper_command = gripper_command - 1.0
 
+        if self.quantization_threshold is not None:
+            # Quantize gripper command to -1, 0 or 1
+            gripper_command = (
+                np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
+            )
+        gripper_command = gripper_command * MAX_GRIPPER_COMMAND
         gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
         gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
         action[-1] = gripper_action.item()
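Note (not part of the commit): a minimal sketch of the gripper-command mapping that the wrapper hunk above introduces. MAX_GRIPPER_COMMAND and the 0.8 threshold are placeholder values for illustration only; the real constant and the configured gripper_quantization_threshold live in the wrapper module and config. Raw gripper actions arrive in [0, 2] (2 = open, 0 = close, 1 = neutral), are shifted to [-1, 1], optionally snapped to -1/0/+1, then scaled into a position delta added to the current gripper state.

import numpy as np

MAX_GRIPPER_COMMAND = 2.0  # placeholder value, assumed for this sketch


def gripper_delta(raw_command: float, quantization_threshold: float | None = 0.8) -> float:
    # Hypothetical helper mirroring the quantization logic after this commit.
    command = raw_command - 1.0  # map [0, 2] -> [-1, 1]
    if quantization_threshold is not None:
        # snap small commands to 0 and large ones to +/-1
        command = np.sign(command) if abs(command) > quantization_threshold else 0.0
    return command * MAX_GRIPPER_COMMAND


print(gripper_delta(2.0))  # +MAX_GRIPPER_COMMAND -> open
print(gripper_delta(0.0))  # -MAX_GRIPPER_COMMAND -> close
print(gripper_delta(1.0))  # 0.0 -> hold the current gripper position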
@@ -836,10 +847,12 @@ class EEActionWrapper(gym.ActionWrapper):
             ]
         )
         if self.use_gripper:
-            action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
+            # gripper actions open at 2.0, and closed at 0.0
+            min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
+            max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
         ee_action_space = gym.spaces.Box(
-            low=-action_space_bounds,
-            high=action_space_bounds,
+            low=min_action_space_bounds,
+            high=max_action_space_bounds,
             shape=(3 + int(self.use_gripper),),
             dtype=np.float32,
         )
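Note (not part of the commit): after the EEActionWrapper hunk above, the end-effector action space is no longer symmetric once the gripper dimension is appended; the XYZ deltas keep mirrored bounds while the gripper channel spans [0, 2]. A rough sketch of the resulting Box, with example XYZ bound values assumed (the real ones come from ee_action_space_params):

import gymnasium as gym
import numpy as np

xyz_bounds = np.array([0.02, 0.02, 0.02], dtype=np.float32)  # assumed example bounds

# gripper actions open at 2.0 and close at 0.0, so low/high are no longer mirror images
low = np.concatenate([-xyz_bounds, [0.0]]).astype(np.float32)
high = np.concatenate([xyz_bounds, [2.0]]).astype(np.float32)

ee_action_space = gym.spaces.Box(low=low, high=high, shape=(4,), dtype=np.float32)
print(ee_action_space.low, ee_action_space.high)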
@@ -997,11 +1010,11 @@ class GamepadControlWrapper(gym.Wrapper):
         if self.use_gripper:
             gripper_command = self.controller.gripper_command()
             if gripper_command == "open":
-                gamepad_action = np.concatenate([gamepad_action, [1.0]])
+                gamepad_action = np.concatenate([gamepad_action, [2.0]])
             elif gripper_command == "close":
-                gamepad_action = np.concatenate([gamepad_action, [-1.0]])
-            else:
                 gamepad_action = np.concatenate([gamepad_action, [0.0]])
+            else:
+                gamepad_action = np.concatenate([gamepad_action, [1.0]])
 
         # Check episode ending buttons
         # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
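Note (not part of the commit): with the GamepadControlWrapper hunk above, the gamepad encodes the gripper directly in the new [0, 2] convention, so the neutral value 1.0 becomes a zero delta once the gripper wrapper subtracts 1.0 downstream. A small sketch of the mapping (helper name is hypothetical):

def gamepad_gripper_value(gripper_command: str | None) -> float:
    # Mirrors the remapped branch: "open" -> 2.0, "close" -> 0.0, anything else -> neutral 1.0
    if gripper_command == "open":
        return 2.0
    if gripper_command == "close":
        return 0.0
    return 1.0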
@@ -1141,7 +1154,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     env = HILSerlRobotEnv(
         robot=robot,
         display_cameras=cfg.wrapper.display_cameras,
-        delta=cfg.wrapper.delta_action,
         use_delta_action_space=cfg.wrapper.use_relative_joint_positions
         and cfg.wrapper.ee_action_space_params is None,
     )
@@ -1165,7 +1177,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     if cfg.wrapper.use_gripper:
-        env = GripperQuantizationWrapper(
+        env = GripperActionWrapper(
             env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
         )
         # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
@@ -1176,6 +1188,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
             ee_action_space_params=cfg.wrapper.ee_action_space_params,
             use_gripper=cfg.wrapper.use_gripper,
         )
 
     if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
         # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
@@ -1192,6 +1205,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         env=env,
         reset_pose=cfg.wrapper.fixed_reset_joint_positions,
         reset_time_s=cfg.wrapper.reset_time_s,
+        open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset
     )
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
@@ -1341,11 +1355,10 @@ def record_dataset(env, policy, cfg):
     dataset.push_to_hub()
 
 
-def replay_episode(env, repo_id, root=None, episode=0):
+def replay_episode(env, cfg):
     from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
-    local_files_only = root is not None
-    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
+    dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
     env.reset()
 
     actions = dataset.hf_dataset.select_columns("action")
@@ -1353,7 +1366,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
     for idx in range(dataset.num_frames):
         start_episode_t = time.perf_counter()
 
-        action = actions[idx]["action"][:4]
+        action = actions[idx]["action"]
         env.step((action, False))
         # env.step((action / env.unwrapped.delta, False))
 
@@ -1384,9 +1397,7 @@ def main(cfg: EnvConfig):
     if cfg.mode == "replay":
         replay_episode(
             env,
-            cfg.replay_repo_id,
-            root=cfg.dataset_root,
-            episode=cfg.replay_episode,
+            cfg=cfg,
         )
         exit()
 
@@ -380,6 +380,7 @@ def add_actor_information_and_train(
         for _ in range(utd_ratio - 1):
             # Sample from the iterators
             batch = next(online_iterator)
+            # batch = replay_buffer.sample(batch_size)
 
             if dataset_repo_id is not None:
                 batch_offline = next(offline_iterator)
@@ -437,9 +438,11 @@ def add_actor_information_and_train(
 
         # Sample for the last update in the UTD ratio
         batch = next(online_iterator)
+        # batch = replay_buffer.sample(batch_size)
 
         if dataset_repo_id is not None:
             batch_offline = next(offline_iterator)
+            # batch_offline = offline_replay_buffer.sample(batch_size)
             batch = concatenate_batch_transitions(
                 left_batch_transitions=batch, right_batch_transition=batch_offline
             )
@@ -775,9 +778,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         params=policy.actor.parameters_to_optimize,
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(
-        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
-    )
+    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr)
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
@@ -992,7 +993,6 @@ def initialize_offline_replay_buffer(
         device=device,
         state_keys=cfg.policy.input_features.keys(),
         action_mask=active_action_dims,
-        action_delta=cfg.env.wrapper.delta_action,
         storage_device=storage_device,
         optimize_memory=True,
         capacity=cfg.policy.offline_buffer_capacity,