From d86d29fe21f0a5eda12cbe51b59377a8d6d9b9d7 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 1 Apr 2025 14:22:08 +0000 Subject: [PATCH] Add mock gripper support and enhance SAC policy action handling - Introduced mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation. - Added ManiskillMockGripperWrapper to adjust action space for environments with discrete actions. - Updated SACPolicy to compute continuous action dimensions correctly, ensuring compatibility with the new gripper setup. - Refactored action handling in the training loop to accommodate the changes in action dimensions. --- lerobot/common/envs/configs.py | 1 + lerobot/common/policies/sac/modeling_sac.py | 18 +++-- .../scripts/server/maniskill_manipulator.py | 71 ++++++++++++------- 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 440512c3..a6eda93b 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -257,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig): robot: str = "so100" # This is a hack to make the robot config work video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) wrapper: WrapperConfig = field(default_factory=WrapperConfig) + mock_gripper: bool = False features: dict[str, PolicyFeature] = field( default_factory=lambda: { "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index d0e8b25d..0c3d76d2 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters -DISCRETE_DIMENSION_INDEX = -1 +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension class SACPolicy( PreTrainedPolicy, @@ -82,7 +82,7 @@ class SACPolicy( # Create a list of critic heads critic_heads = [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], + input_dim=encoder_critic.output_dim + continuous_action_dim, **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) @@ -97,7 +97,7 @@ class SACPolicy( # Create target critic heads as deepcopies of the original critic heads target_critic_heads = [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], + input_dim=encoder_critic.output_dim + continuous_action_dim, **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) @@ -117,7 +117,10 @@ class SACPolicy( self.grasp_critic = None self.grasp_critic_target = None + continuous_action_dim = config.output_features["action"].shape[0] if config.num_discrete_actions is not None: + + continuous_action_dim -= 1 # Create grasp critic self.grasp_critic = GraspCritic( encoder=encoder_critic, @@ -139,15 +142,16 @@ class SACPolicy( self.grasp_critic = torch.compile(self.grasp_critic) self.grasp_critic_target = torch.compile(self.grasp_critic_target) + self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), - action_dim=config.output_features["action"].shape[0], + action_dim=continuous_action_dim, encoder_is_shared=config.shared_encoder, **asdict(config.policy_kwargs), ) if config.target_entropy is None: - config.target_entropy = 
-np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2) + config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2) # TODO (azouitine): Handle the case where the temparameter is a fixed # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise @@ -275,7 +279,9 @@ class SACPolicy( next_observations=next_observations, done=done, ) - return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} + return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} + + return {"loss_critic": loss_critic} if model == "actor": diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index e10b8766..f4a89888 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -11,6 +11,10 @@ from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from lerobot.common.envs.configs import ManiskillEnvConfig from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.modeling_sac import SACPolicy + def preprocess_maniskill_observation( @@ -152,6 +156,21 @@ class TimeLimitWrapper(gym.Wrapper): self.current_step = 0 return super().reset(seed=seed, options=options) +class ManiskillMockGripperWrapper(gym.Wrapper): + def __init__(self, env, nb_discrete_actions: int = 3): + super().__init__(env) + new_shape = env.action_space[0].shape[0] + 1 + new_low = np.concatenate([env.action_space[0].low, [0]]) + new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]]) + action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,)) + self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1])) + + def step(self, action): + action_agent, telop_action = action + real_action = action_agent[:-1] + final_action = (real_action, telop_action) + obs, reward, terminated, truncated, info = self.env.step(final_action) + return obs, reward, terminated, truncated, info def make_maniskill( cfg: ManiskillEnvConfig, @@ -197,40 +216,42 @@ def make_maniskill( env = ManiSkillCompat(env) env = ManiSkillActionWrapper(env) env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control + if cfg.mock_gripper: + env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3) return env -@parser.wrap() -def main(cfg: ManiskillEnvConfig): - """Main function to run the ManiSkill environment.""" - # Create the ManiSkill environment - env = make_maniskill(cfg, n_envs=1) +# @parser.wrap() +# def main(cfg: TrainPipelineConfig): +# """Main function to run the ManiSkill environment.""" +# # Create the ManiSkill environment +# env = make_maniskill(cfg.env, n_envs=1) - # Reset the environment - obs, info = env.reset() +# # Reset the environment +# obs, info = env.reset() - # Run a simple interaction loop - sum_reward = 0 - for i in range(100): - # Sample a random action - action = env.action_space.sample() +# # Run a simple interaction loop +# sum_reward = 0 +# for i in range(100): +# # Sample a random action +# action = env.action_space.sample() - # Step the environment - start_time = time.perf_counter() - obs, reward, terminated, truncated, info = env.step(action) - step_time = time.perf_counter() - start_time - sum_reward += reward - # Log information +# # Step the environment +# start_time = time.perf_counter() +# obs, reward, 
terminated, truncated, info = env.step(action) +# step_time = time.perf_counter() - start_time +# sum_reward += reward +# # Log information - # Reset if episode terminated - if terminated or truncated: - logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s") - sum_reward = 0 - obs, info = env.reset() +# # Reset if episode terminated +# if terminated or truncated: +# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s") +# sum_reward = 0 +# obs, info = env.reset() - # Close the environment - env.close() +# # Close the environment +# env.close() # if __name__ == "__main__":
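
Reviewer note: the "action handling in the training loop" mentioned in the commit message is not part of this diff. The sketch below is a minimal illustration, under the conventions this patch introduces (the mock gripper command always occupies the last action slot, `DISCRETE_DIMENSION_INDEX = -1`, and the actor/critic ensemble only consume the continuous part), of how a full environment action would be split and re-packed. The helper names `split_action` / `pack_action`, the dimension values, and the omission of the teleop half of the wrapper's Tuple action space are all illustrative assumptions, not code from the patch.

from __future__ import annotations

import numpy as np

# Mirrors DISCRETE_DIMENSION_INDEX in lerobot/common/policies/sac/modeling_sac.py:
# the (mock) gripper command always lives in the last action dimension.
DISCRETE_DIMENSION_INDEX = -1


def continuous_action_dim(env_action_dim: int, num_discrete_actions: int | None) -> int:
    # The actor and the critic ensemble only see the continuous part of the
    # action; when a discrete gripper head is enabled, the last slot is
    # reserved for it, so the continuous dimension shrinks by one.
    return env_action_dim - 1 if num_discrete_actions is not None else env_action_dim


def split_action(action: np.ndarray, num_discrete_actions: int | None):
    # Hypothetical helper: split a full environment action into the continuous
    # slice (for the actor/critics) and the discrete gripper command (for the
    # grasp critic).
    if num_discrete_actions is None:
        return action, None
    return action[..., :DISCRETE_DIMENSION_INDEX], action[..., DISCRETE_DIMENSION_INDEX]


def pack_action(continuous: np.ndarray, gripper_command: int | None) -> np.ndarray:
    # Hypothetical inverse: append the gripper command so that
    # ManiskillMockGripperWrapper.step() can strip it off again before
    # forwarding only the continuous part to ManiSkill.
    if gripper_command is None:
        return continuous
    return np.concatenate([continuous, np.array([gripper_command], dtype=continuous.dtype)])


if __name__ == "__main__":
    env_action_dim = 7        # illustrative; ManiskillEnvConfig declares a 7-dim action feature
    num_discrete_actions = 3  # matches nb_discrete_actions in ManiskillMockGripperWrapper

    assert continuous_action_dim(env_action_dim, num_discrete_actions) == 6

    full_action = pack_action(np.zeros(6, dtype=np.float32), gripper_command=2)
    cont, grip = split_action(full_action, num_discrete_actions)
    print(cont.shape, int(grip))  # -> (6,) 2

Keeping the gripper command in a fixed last slot, rather than a separate key or action space, lets the environment wrapper, the SAC heads, and any replay buffer share a single convention; only the dimension bookkeeping sketched above has to agree across them.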