Add mock gripper support and enhance SAC policy action handling
- Introduced a mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation.
- Added ManiskillMockGripperWrapper to adjust the action space for environments with discrete actions.
- Updated SACPolicy to compute the continuous action dimension correctly, ensuring compatibility with the new gripper setup.
- Refactored action handling in the training loop to accommodate the change in action dimensions.
parent 306c735172, commit 451a7b01db
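Before the hunks, a minimal usage sketch of the new flag. It assumes ManiskillEnvConfig can be instantiated with its defaults; that is not verified by this commit.

from lerobot.common.envs.configs import ManiskillEnvConfig

cfg = ManiskillEnvConfig(mock_gripper=True)  # new field; defaults to False
# Downstream, make_maniskill (patched below) checks cfg.mock_gripper and, when set,
# wraps the env in ManiskillMockGripperWrapper so the agent action gains a trailing
# discrete gripper slot.
print(cfg.mock_gripper)  # True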
@@ -257,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig):
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
     wrapper: WrapperConfig = field(default_factory=WrapperConfig)
+    mock_gripper: bool = False
     features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
             "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
@@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters
 
-DISCRETE_DIMENSION_INDEX = -1
+DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension
 
 class SACPolicy(
     PreTrainedPolicy,
@@ -82,7 +82,7 @@ class SACPolicy(
         # Create a list of critic heads
         critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
@@ -97,7 +97,7 @@ class SACPolicy(
         # Create target critic heads as deepcopies of the original critic heads
         target_critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
            )
             for _ in range(config.num_critics)
@@ -117,7 +117,10 @@ class SACPolicy(
         self.grasp_critic = None
         self.grasp_critic_target = None
 
+        continuous_action_dim = config.output_features["action"].shape[0]
         if config.num_discrete_actions is not None:
+
+            continuous_action_dim -= 1
             # Create grasp critic
             self.grasp_critic = GraspCritic(
                 encoder=encoder_critic,
@@ -139,15 +142,16 @@ class SACPolicy(
             self.grasp_critic = torch.compile(self.grasp_critic)
             self.grasp_critic_target = torch.compile(self.grasp_critic_target)
 
+
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
-            action_dim=config.output_features["action"].shape[0],
+            action_dim=continuous_action_dim,
             encoder_is_shared=config.shared_encoder,
             **asdict(config.policy_kwargs),
         )
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2  # (-dim(A)/2)
+            config.target_entropy = -np.prod(continuous_action_dim) / 2  # (-dim(A)/2)
 
         # TODO (azouitine): Handle the case where the temparameter is a fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
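To make the dimension bookkeeping above concrete, a self-contained check that mirrors the new logic. The (7,) action shape comes from the ManiskillEnvConfig hunk; num_discrete_actions = 3 is an illustrative value, not something this diff sets.

import numpy as np

action_shape = (7,)        # "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,))
num_discrete_actions = 3   # illustrative; any non-None value takes the gripper branch

continuous_action_dim = action_shape[0]
if num_discrete_actions is not None:
    continuous_action_dim -= 1  # drop the trailing discrete gripper dimension

target_entropy = -np.prod(continuous_action_dim) / 2  # (-dim(A)/2)
print(continuous_action_dim, target_entropy)  # 6 -3.0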
@@ -275,7 +279,9 @@ class SACPolicy(
                     next_observations=next_observations,
                     done=done,
                 )
                 return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
 
+            return {"loss_critic": loss_critic}
+
 
         if model == "actor":
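The commit message also mentions refactored action handling in the training loop, but that code is not part of the hunks shown here. The following is only a generic sketch of how a combined action tensor could be split using the DISCRETE_DIMENSION_INDEX constant annotated above; the helper name and shapes are assumptions, not code from this commit.

import torch

DISCRETE_DIMENSION_INDEX = -1  # gripper is always the last dimension (see hunk above)

def split_action(action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical helper: continuous motion dims vs. the trailing discrete gripper command.
    continuous = action[..., :DISCRETE_DIMENSION_INDEX]
    gripper = action[..., DISCRETE_DIMENSION_INDEX].long()
    return continuous, gripper

batch = torch.randn(4, 7)      # e.g. 6 continuous dims plus 1 gripper slot
cont, grip = split_action(batch)
print(cont.shape, grip.shape)  # torch.Size([4, 6]) torch.Size([4])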
@@ -11,6 +11,10 @@ from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
 
 from lerobot.common.envs.configs import ManiskillEnvConfig
 from lerobot.configs import parser
+from lerobot.configs.train import TrainPipelineConfig
+from lerobot.common.policies.sac.configuration_sac import SACConfig
+from lerobot.common.policies.sac.modeling_sac import SACPolicy
+
 
 
 def preprocess_maniskill_observation(
@@ -152,6 +156,21 @@ class TimeLimitWrapper(gym.Wrapper):
         self.current_step = 0
         return super().reset(seed=seed, options=options)
 
+class ManiskillMockGripperWrapper(gym.Wrapper):
+    def __init__(self, env, nb_discrete_actions: int = 3):
+        super().__init__(env)
+        new_shape = env.action_space[0].shape[0] + 1
+        new_low = np.concatenate([env.action_space[0].low, [0]])
+        new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
+        action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
+        self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
+
+    def step(self, action):
+        action_agent, telop_action = action
+        real_action = action_agent[:-1]
+        final_action = (real_action, telop_action)
+        obs, reward, terminated, truncated, info = self.env.step(final_action)
+        return obs, reward, terminated, truncated, info
 
 def make_maniskill(
     cfg: ManiskillEnvConfig,
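The wrapper's action-space surgery in isolation, as a self-contained sketch: a stand-in 6-dimensional agent Box and a teleop Box are assumed here, whereas in make_maniskill the real spaces come from the ManiSkill wrappers.

import gymnasium as gym
import numpy as np

# Stand-in for the wrapped env's Tuple action space: (agent Box, teleop space).
agent_box = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)
teleop_box = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)

nb_discrete_actions = 3
# Same padding as ManiskillMockGripperWrapper.__init__: one extra trailing slot
# whose bounds cover the discrete gripper values 0 .. nb_discrete_actions - 1.
new_low = np.concatenate([agent_box.low, [0]])
new_high = np.concatenate([agent_box.high, [nb_discrete_actions - 1]])
padded_agent_box = gym.spaces.Box(low=new_low, high=new_high, shape=(7,))
action_space = gym.spaces.Tuple((padded_agent_box, teleop_box))

# In step(), the trailing slot is stripped again before the base env sees the action.
action_agent, telop_action = action_space.sample()
real_action = action_agent[:-1]
assert real_action.shape == agent_box.shape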
@@ -197,40 +216,42 @@ def make_maniskill(
     env = ManiSkillCompat(env)
     env = ManiSkillActionWrapper(env)
     env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
+    if cfg.mock_gripper:
+        env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
 
     return env
 
 
-@parser.wrap()
-def main(cfg: ManiskillEnvConfig):
-    """Main function to run the ManiSkill environment."""
-    # Create the ManiSkill environment
-    env = make_maniskill(cfg, n_envs=1)
+# @parser.wrap()
+# def main(cfg: TrainPipelineConfig):
+#     """Main function to run the ManiSkill environment."""
+#     # Create the ManiSkill environment
+#     env = make_maniskill(cfg.env, n_envs=1)
 
-    # Reset the environment
-    obs, info = env.reset()
+#     # Reset the environment
+#     obs, info = env.reset()
 
-    # Run a simple interaction loop
-    sum_reward = 0
-    for i in range(100):
-        # Sample a random action
-        action = env.action_space.sample()
+#     # Run a simple interaction loop
+#     sum_reward = 0
+#     for i in range(100):
+#         # Sample a random action
+#         action = env.action_space.sample()
 
-        # Step the environment
-        start_time = time.perf_counter()
-        obs, reward, terminated, truncated, info = env.step(action)
-        step_time = time.perf_counter() - start_time
-        sum_reward += reward
-        # Log information
+#         # Step the environment
+#         start_time = time.perf_counter()
+#         obs, reward, terminated, truncated, info = env.step(action)
+#         step_time = time.perf_counter() - start_time
+#         sum_reward += reward
+#         # Log information
 
-        # Reset if episode terminated
-        if terminated or truncated:
-            logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
-            sum_reward = 0
-            obs, info = env.reset()
+#         # Reset if episode terminated
+#         if terminated or truncated:
+#             logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
+#             sum_reward = 0
+#             obs, info = env.reset()
 
-    # Close the environment
-    env.close()
+#     # Close the environment
+#     env.close()
 
 
 # if __name__ == "__main__":