Add mock gripper support and enhance SAC policy action handling

- Introduced a mock_gripper parameter in ManiskillEnvConfig to simulate a discrete gripper action for environments that do not expose one.
- Added ManiskillMockGripperWrapper, which appends a discrete gripper dimension to the agent action space and strips it before stepping the underlying environment.
- Updated SACPolicy to compute the continuous action dimension correctly (excluding the discrete gripper dimension), keeping it compatible with the new gripper setup; see the sketch below.
- Refactored action handling in the training loop to accommodate the change in action dimensions.
Authored by AdilZouitine on 2025-04-01 14:22:08 +00:00; committed by Michel Aractingi
Parent: f83d215e7a
Commit: d86d29fe21
3 changed files with 59 additions and 31 deletions
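As a quick illustration of the action layout these changes assume: the combined action keeps the discrete gripper command in its last slot (DISCRETE_DIMENSION_INDEX = -1), and the continuous part is everything before it. The snippet below is a minimal sketch only; the helper name and the 7-dim example are illustrative assumptions, not code from this commit.

```python
import numpy as np

DISCRETE_DIMENSION_INDEX = -1  # gripper is always the last dimension (as in the SAC policy diff below)

def split_action(action: np.ndarray, num_discrete_actions: int | None):
    """Illustrative helper: split a combined action into its continuous part and the gripper index."""
    if num_discrete_actions is None:
        return action, None
    continuous = action[:DISCRETE_DIMENSION_INDEX]         # first dim(A) - 1 entries
    gripper_index = int(action[DISCRETE_DIMENSION_INDEX])  # e.g. 0, 1 or 2
    return continuous, gripper_index

action = np.array([0.1, -0.2, 0.0, 0.3, 0.05, -0.1, 2.0])  # 6 continuous dims + gripper command
continuous, gripper = split_action(action, num_discrete_actions=3)
print(continuous.shape, gripper)  # (6,) 2
```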


@@ -257,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig):
robot: str = "so100" # This is a hack to make the robot config work
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
wrapper: WrapperConfig = field(default_factory=WrapperConfig)
mock_gripper: bool = False
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),

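A minimal sketch of how the new flag is meant to be consumed by make_maniskill (see the last file's diff). Constructing the config this way assumes the remaining ManiskillEnvConfig fields keep their defaults.

```python
from lerobot.common.envs.configs import ManiskillEnvConfig

# Assumes the other ManiskillEnvConfig fields have usable defaults.
cfg = ManiskillEnvConfig(mock_gripper=True)

# Downstream, make_maniskill applies the wrapper only when the flag is set:
#     if cfg.mock_gripper:
#         env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
```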

@@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
DISCRETE_DIMENSION_INDEX = -1
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class SACPolicy(
PreTrainedPolicy,
@@ -82,7 +82,7 @@ class SACPolicy(
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
@@ -97,7 +97,7 @@ class SACPolicy(
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
@@ -117,7 +117,10 @@ class SACPolicy(
self.grasp_critic = None
self.grasp_critic_target = None
continuous_action_dim = config.output_features["action"].shape[0]
if config.num_discrete_actions is not None:
continuous_action_dim -= 1
# Create grasp critic
self.grasp_critic = GraspCritic(
encoder=encoder_critic,
@@ -139,15 +142,16 @@ class SACPolicy(
self.grasp_critic = torch.compile(self.grasp_critic)
self.grasp_critic_target = torch.compile(self.grasp_critic_target)
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
action_dim=config.output_features["action"].shape[0],
action_dim=continuous_action_dim,
encoder_is_shared=config.shared_encoder,
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temperature parameter is fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -275,7 +279,9 @@ class SACPolicy(
next_observations=next_observations,
done=done,
)
return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
return {"loss_critic": loss_critic}
if model == "actor":

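A back-of-the-envelope check of the dimensions wired up in the SAC policy changes above, with illustrative numbers: the 7-dim action and 3 discrete gripper actions mirror this commit's config and wrapper, while the 256-dim encoder output is an assumption for the example.

```python
action_shape = (7,)          # config.output_features["action"].shape
num_discrete_actions = 3     # config.num_discrete_actions (None disables the grasp critic path)
encoder_output_dim = 256     # encoder_critic.output_dim, assumed for this example

continuous_action_dim = action_shape[0]
if num_discrete_actions is not None:
    continuous_action_dim -= 1                                   # drop the gripper dimension -> 6

critic_head_input_dim = encoder_output_dim + continuous_action_dim  # 256 + 6 = 262
target_entropy = -continuous_action_dim / 2                          # -dim(A)/2 = -3.0

print(continuous_action_dim, critic_head_input_dim, target_entropy)  # 6 262 -3.0
```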

@@ -11,6 +11,10 @@ from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from lerobot.common.envs.configs import ManiskillEnvConfig
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import SACPolicy
def preprocess_maniskill_observation(
@@ -152,6 +156,21 @@ class TimeLimitWrapper(gym.Wrapper):
self.current_step = 0
return super().reset(seed=seed, options=options)
class ManiskillMockGripperWrapper(gym.Wrapper):
def __init__(self, env, nb_discrete_actions: int = 3):
super().__init__(env)
new_shape = env.action_space[0].shape[0] + 1
new_low = np.concatenate([env.action_space[0].low, [0]])
new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
def step(self, action):
action_agent, telop_action = action
real_action = action_agent[:-1]
final_action = (real_action, telop_action)
obs, reward, terminated, truncated, info = self.env.step(final_action)
return obs, reward, terminated, truncated, info
def make_maniskill(
cfg: ManiskillEnvConfig,
@@ -197,40 +216,42 @@ def make_maniskill(
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
if cfg.mock_gripper:
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
return env
@parser.wrap()
def main(cfg: ManiskillEnvConfig):
"""Main function to run the ManiSkill environment."""
# Create the ManiSkill environment
env = make_maniskill(cfg, n_envs=1)
# @parser.wrap()
# def main(cfg: TrainPipelineConfig):
# """Main function to run the ManiSkill environment."""
# # Create the ManiSkill environment
# env = make_maniskill(cfg.env, n_envs=1)
# Reset the environment
obs, info = env.reset()
# # Reset the environment
# obs, info = env.reset()
# Run a simple interaction loop
sum_reward = 0
for i in range(100):
# Sample a random action
action = env.action_space.sample()
# # Run a simple interaction loop
# sum_reward = 0
# for i in range(100):
# # Sample a random action
# action = env.action_space.sample()
# Step the environment
start_time = time.perf_counter()
obs, reward, terminated, truncated, info = env.step(action)
step_time = time.perf_counter() - start_time
sum_reward += reward
# Log information
# # Step the environment
# start_time = time.perf_counter()
# obs, reward, terminated, truncated, info = env.step(action)
# step_time = time.perf_counter() - start_time
# sum_reward += reward
# # Log information
# Reset if episode terminated
if terminated or truncated:
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
sum_reward = 0
obs, info = env.reset()
# # Reset if episode terminated
# if terminated or truncated:
# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
# sum_reward = 0
# obs, info = env.reset()
# Close the environment
env.close()
# # Close the environment
# env.close()
# if __name__ == "__main__":
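To see the wrapper's behavior end to end, here is a self-contained sketch: a dummy environment with the same Tuple (agent Box, teleop space) action layout stands in for the real ManiSkill stack, and ManiskillMockGripperWrapper from the file above is assumed to be importable. The dummy env and all numbers are illustrative, not part of this commit.

```python
import gymnasium as gym
import numpy as np


class DummyTupleEnv(gym.Env):
    """Stand-in for the wrapped ManiSkill env: Tuple action space of (agent Box, teleop space)."""

    def __init__(self):
        agent_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)
        telop_space = gym.spaces.Discrete(2)  # placeholder for the teleop half of the tuple
        self.action_space = gym.spaces.Tuple((agent_space, telop_space))
        self.observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        agent_action, _telop = action
        assert agent_action.shape == (6,)  # the extra gripper slot never reaches the inner env
        return self.observation_space.sample(), 0.0, False, False, {}


env = ManiskillMockGripperWrapper(DummyTupleEnv(), nb_discrete_actions=3)
print(env.action_space[0].shape)  # (7,): 6 continuous dims + 1 discrete gripper slot

obs, info = env.reset()
action = (np.concatenate([np.zeros(6, dtype=np.float32), [2.0]]), 0)  # gripper index 2
obs, reward, terminated, truncated, info = env.step(action)
```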