Add mock gripper support and enhance SAC policy action handling
- Introduced mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation. - Added ManiskillMockGripperWrapper to adjust action space for environments with discrete actions. - Updated SACPolicy to compute continuous action dimensions correctly, ensuring compatibility with the new gripper setup. - Refactored action handling in the training loop to accommodate the changes in action dimensions.
This commit is contained in:
parent
f83d215e7a
commit
d86d29fe21
|
@ -257,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig):
|
|||
robot: str = "so100" # This is a hack to make the robot config work
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
wrapper: WrapperConfig = field(default_factory=WrapperConfig)
|
||||
mock_gripper: bool = False
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
|
|
|
@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
|
|||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
|
||||
DISCRETE_DIMENSION_INDEX = -1
|
||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
|
||||
class SACPolicy(
|
||||
PreTrainedPolicy,
|
||||
|
@ -82,7 +82,7 @@ class SACPolicy(
|
|||
# Create a list of critic heads
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
|
||||
input_dim=encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
|
@ -97,7 +97,7 @@ class SACPolicy(
|
|||
# Create target critic heads as deepcopies of the original critic heads
|
||||
target_critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
|
||||
input_dim=encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
|
@ -117,7 +117,10 @@ class SACPolicy(
|
|||
self.grasp_critic = None
|
||||
self.grasp_critic_target = None
|
||||
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
if config.num_discrete_actions is not None:
|
||||
|
||||
continuous_action_dim -= 1
|
||||
# Create grasp critic
|
||||
self.grasp_critic = GraspCritic(
|
||||
encoder=encoder_critic,
|
||||
|
@ -139,15 +142,16 @@ class SACPolicy(
|
|||
self.grasp_critic = torch.compile(self.grasp_critic)
|
||||
self.grasp_critic_target = torch.compile(self.grasp_critic_target)
|
||||
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
|
||||
action_dim=config.output_features["action"].shape[0],
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**asdict(config.policy_kwargs),
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
|
||||
config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
|
@ -277,6 +281,8 @@ class SACPolicy(
|
|||
)
|
||||
return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
|
||||
if model == "actor":
|
||||
return {"loss_actor": self.compute_loss_actor(
|
||||
|
|
|
@ -11,6 +11,10 @@ from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
|||
|
||||
from lerobot.common.envs.configs import ManiskillEnvConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(
|
||||
|
@ -152,6 +156,21 @@ class TimeLimitWrapper(gym.Wrapper):
|
|||
self.current_step = 0
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
class ManiskillMockGripperWrapper(gym.Wrapper):
|
||||
def __init__(self, env, nb_discrete_actions: int = 3):
|
||||
super().__init__(env)
|
||||
new_shape = env.action_space[0].shape[0] + 1
|
||||
new_low = np.concatenate([env.action_space[0].low, [0]])
|
||||
new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
|
||||
action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
|
||||
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
|
||||
|
||||
def step(self, action):
|
||||
action_agent, telop_action = action
|
||||
real_action = action_agent[:-1]
|
||||
final_action = (real_action, telop_action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(final_action)
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def make_maniskill(
|
||||
cfg: ManiskillEnvConfig,
|
||||
|
@ -197,40 +216,42 @@ def make_maniskill(
|
|||
env = ManiSkillCompat(env)
|
||||
env = ManiSkillActionWrapper(env)
|
||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
|
||||
if cfg.mock_gripper:
|
||||
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: ManiskillEnvConfig):
|
||||
"""Main function to run the ManiSkill environment."""
|
||||
# Create the ManiSkill environment
|
||||
env = make_maniskill(cfg, n_envs=1)
|
||||
# @parser.wrap()
|
||||
# def main(cfg: TrainPipelineConfig):
|
||||
# """Main function to run the ManiSkill environment."""
|
||||
# # Create the ManiSkill environment
|
||||
# env = make_maniskill(cfg.env, n_envs=1)
|
||||
|
||||
# Reset the environment
|
||||
obs, info = env.reset()
|
||||
# # Reset the environment
|
||||
# obs, info = env.reset()
|
||||
|
||||
# Run a simple interaction loop
|
||||
sum_reward = 0
|
||||
for i in range(100):
|
||||
# Sample a random action
|
||||
action = env.action_space.sample()
|
||||
# # Run a simple interaction loop
|
||||
# sum_reward = 0
|
||||
# for i in range(100):
|
||||
# # Sample a random action
|
||||
# action = env.action_space.sample()
|
||||
|
||||
# Step the environment
|
||||
start_time = time.perf_counter()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
step_time = time.perf_counter() - start_time
|
||||
sum_reward += reward
|
||||
# Log information
|
||||
# # Step the environment
|
||||
# start_time = time.perf_counter()
|
||||
# obs, reward, terminated, truncated, info = env.step(action)
|
||||
# step_time = time.perf_counter() - start_time
|
||||
# sum_reward += reward
|
||||
# # Log information
|
||||
|
||||
# Reset if episode terminated
|
||||
if terminated or truncated:
|
||||
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
|
||||
sum_reward = 0
|
||||
obs, info = env.reset()
|
||||
# # Reset if episode terminated
|
||||
# if terminated or truncated:
|
||||
# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
|
||||
# sum_reward = 0
|
||||
# obs, info = env.reset()
|
||||
|
||||
# Close the environment
|
||||
env.close()
|
||||
# # Close the environment
|
||||
# env.close()
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue