Add mock gripper support and enhance SAC policy action handling

- Introduced a mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation.
- Added ManiskillMockGripperWrapper to extend the action space of environments with a discrete gripper action.
- Updated SACPolicy to compute the continuous action dimension correctly, keeping it compatible with the new gripper setup.
- Refactored action handling in the training loop to accommodate the change in action dimensions.
Author: AdilZouitine
Date: 2025-04-01 14:22:08 +00:00
Parent: 306c735172
Commit: 451a7b01db
3 changed files with 59 additions and 31 deletions
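
Note on the action layout these changes assume (an illustrative sketch, not code from this commit; the concrete numbers are hypothetical): the environment still advertises a 7-dimensional action, the last entry is the discrete gripper command, and everything before it is the continuous part handled by the SAC actor and critics.

    import numpy as np

    DISCRETE_DIMENSION_INDEX = -1  # gripper command is always the last dimension

    # Hypothetical combined action: 6 continuous dims followed by a gripper index in {0, 1, 2}
    action = np.array([0.1, -0.2, 0.05, 0.0, 0.3, -0.1, 2.0])
    continuous_part = action[:DISCRETE_DIMENSION_INDEX]      # shape (6,), fed to the actor and critics
    gripper_command = int(action[DISCRETE_DIMENSION_INDEX])  # fed to the grasp critic / mock gripper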

View File

@@ -257,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig):
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
     wrapper: WrapperConfig = field(default_factory=WrapperConfig)
+    mock_gripper: bool = False
     features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
             "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),

View File

@@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters
-DISCRETE_DIMENSION_INDEX = -1
+DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension
 class SACPolicy(
     PreTrainedPolicy,
@@ -82,7 +82,7 @@ class SACPolicy(
         # Create a list of critic heads
         critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
@@ -97,7 +97,7 @@ class SACPolicy(
         # Create target critic heads as deepcopies of the original critic heads
         target_critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
@@ -117,7 +117,10 @@ class SACPolicy(
         self.grasp_critic = None
         self.grasp_critic_target = None
+        continuous_action_dim = config.output_features["action"].shape[0]
         if config.num_discrete_actions is not None:
+            continuous_action_dim -= 1
             # Create grasp critic
             self.grasp_critic = GraspCritic(
                 encoder=encoder_critic,
@@ -139,15 +142,16 @@ class SACPolicy(
             self.grasp_critic = torch.compile(self.grasp_critic)
             self.grasp_critic_target = torch.compile(self.grasp_critic_target)
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
-            action_dim=config.output_features["action"].shape[0],
+            action_dim=continuous_action_dim,
             encoder_is_shared=config.shared_encoder,
             **asdict(config.policy_kwargs),
         )
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2  # (-dim(A)/2)
+            config.target_entropy = -np.prod(continuous_action_dim) / 2  # (-dim(A)/2)
         # TODO (azouitine): Handle the case where the temparameter is a fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -275,7 +279,9 @@ class SACPolicy(
                 next_observations=next_observations,
                 done=done,
             )
             return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
+
+        return {"loss_critic": loss_critic}
         if model == "actor":
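
The dimension bookkeeping introduced above, spelled out with illustrative values (a sketch, not code from the commit):

    action_dim = 7             # config.output_features["action"].shape[0]
    num_discrete_actions = 3   # gripper states; None disables the grasp critic

    continuous_action_dim = action_dim
    if num_discrete_actions is not None:
        continuous_action_dim -= 1  # actor and critic heads see 6 continuous dims

    target_entropy = -continuous_action_dim / 2  # (-dim(A)/2) default, now based on the continuous part only
    print(continuous_action_dim, target_entropy)  # 6 -3.0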

View File

@@ -11,6 +11,10 @@ from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
 from lerobot.common.envs.configs import ManiskillEnvConfig
 from lerobot.configs import parser
+from lerobot.configs.train import TrainPipelineConfig
+from lerobot.common.policies.sac.configuration_sac import SACConfig
+from lerobot.common.policies.sac.modeling_sac import SACPolicy
 def preprocess_maniskill_observation(
@@ -152,6 +156,21 @@ class TimeLimitWrapper(gym.Wrapper):
         self.current_step = 0
         return super().reset(seed=seed, options=options)
+class ManiskillMockGripperWrapper(gym.Wrapper):
+    def __init__(self, env, nb_discrete_actions: int = 3):
+        super().__init__(env)
+        new_shape = env.action_space[0].shape[0] + 1
+        new_low = np.concatenate([env.action_space[0].low, [0]])
+        new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
+        action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
+        self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
+
+    def step(self, action):
+        action_agent, telop_action = action
+        real_action = action_agent[:-1]
+        final_action = (real_action, telop_action)
+        obs, reward, terminated, truncated, info = self.env.step(final_action)
+        return obs, reward, terminated, truncated, info
 def make_maniskill(
     cfg: ManiskillEnvConfig,
@@ -197,40 +216,42 @@ def make_maniskill(
     env = ManiSkillCompat(env)
     env = ManiSkillActionWrapper(env)
     env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
+    if cfg.mock_gripper:
+        env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
     return env
-@parser.wrap()
-def main(cfg: ManiskillEnvConfig):
-    """Main function to run the ManiSkill environment."""
-    # Create the ManiSkill environment
-    env = make_maniskill(cfg, n_envs=1)
-    # Reset the environment
-    obs, info = env.reset()
-    # Run a simple interaction loop
-    sum_reward = 0
-    for i in range(100):
-        # Sample a random action
-        action = env.action_space.sample()
-        # Step the environment
-        start_time = time.perf_counter()
-        obs, reward, terminated, truncated, info = env.step(action)
-        step_time = time.perf_counter() - start_time
-        sum_reward += reward
-        # Log information
-        # Reset if episode terminated
-        if terminated or truncated:
-            logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
-            sum_reward = 0
-            obs, info = env.reset()
-    # Close the environment
-    env.close()
+# @parser.wrap()
+# def main(cfg: TrainPipelineConfig):
+#     """Main function to run the ManiSkill environment."""
+#     # Create the ManiSkill environment
+#     env = make_maniskill(cfg.env, n_envs=1)
+#     # Reset the environment
+#     obs, info = env.reset()
+#     # Run a simple interaction loop
+#     sum_reward = 0
+#     for i in range(100):
+#         # Sample a random action
+#         action = env.action_space.sample()
+#         # Step the environment
+#         start_time = time.perf_counter()
+#         obs, reward, terminated, truncated, info = env.step(action)
+#         step_time = time.perf_counter() - start_time
+#         sum_reward += reward
+#         # Log information
+#         # Reset if episode terminated
+#         if terminated or truncated:
+#             logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
+#             sum_reward = 0
+#             obs, info = env.reset()
+#     # Close the environment
+#     env.close()
 # if __name__ == "__main__":
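
A hedged usage sketch of the new wrapper. The stand-in environment below is hypothetical; it only mimics the Tuple action space (agent Box, teleop space) that the already-wrapped ManiSkill env exposes, and ManiskillMockGripperWrapper is assumed to be defined as in the diff above.

    import gymnasium as gym
    import numpy as np

    class DummyTupleActionEnv(gym.Env):
        # Hypothetical stand-in: 6 continuous agent dims plus a separate teleop space.
        def __init__(self):
            agent_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)
            teleop_space = gym.spaces.Discrete(2)
            self.action_space = gym.spaces.Tuple((agent_space, teleop_space))
            self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32)

        def reset(self, *, seed=None, options=None):
            return np.zeros(3, dtype=np.float32), {}

        def step(self, action):
            return np.zeros(3, dtype=np.float32), 0.0, False, False, {}

    env = ManiskillMockGripperWrapper(DummyTupleActionEnv(), nb_discrete_actions=3)
    action_agent, teleop_action = env.action_space.sample()
    # action_agent now has one extra trailing entry in [0, nb_discrete_actions - 1];
    # step() strips it before forwarding the tuple to the wrapped env.
    obs, reward, terminated, truncated, info = env.step((action_agent, teleop_action))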