- Fixed a major issue in the loading of the policy parameters sent by the learner to the actor -- pass only the actor to `update_policy_parameters` and remove `strict=False` from `load_state_dict`, so mismatched keys raise an error instead of being silently ignored.

- Fixed a major issue in the normalization of actions in the critic's `forward` pass -- removed the `torch.no_grad` decorator from the normalization function in `normalize.py`, which was cutting the gradient flow through the normalized actions.

- Fixed a performance issue that limited the optimization frequency -- set the replay buffer's storage device to the same device used for learning.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-02-19 16:22:51 +00:00
parent befa1fe9af
commit ff47c0b0d3
7 changed files with 68 additions and 57 deletions

View File

@@ -130,7 +130,7 @@ class Normalize(nn.Module):
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    # @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, mode in self.modes.items():
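A minimal, standalone sketch (not the repository's code) of why the `torch.no_grad` decorator had to go: anything computed under `no_grad` is detached from the autograd graph, so a critic loss built on actions normalized inside `no_grad` cannot propagate gradients back to the actor. The min/max scaling below is just a stand-in for the buffered stats.

```python
import torch

actions = torch.randn(4, 7, requires_grad=True)  # stand-in for actor outputs
min_v, max_v = -10.0, 10.0

with torch.no_grad():
    norm_blocked = (actions - min_v) / (max_v - min_v) * 2 - 1
print(norm_blocked.requires_grad)  # False: the graph is cut, the actor gets no gradient

norm_ok = (actions - min_v) / (max_v - min_v) * 2 - 1  # same math, no no_grad
print(norm_ok.requires_grad)  # True: the critic loss can backpropagate into the actor
```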

View File

@@ -80,8 +80,8 @@ class SACPolicy(
             encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
             encoder_actor: SACObservationEncoder = encoder_critic
         else:
-            encoder_critic = SACObservationEncoder(config)
-            encoder_actor = SACObservationEncoder(config)
+            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
+            encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
 
         self.critic_ensemble = CriticEnsemble(
             encoder=encoder_critic,

View File

@@ -64,13 +64,29 @@ policy:
     action: [7]
 
   # Normalization / Unnormalization
-  input_normalization_modes: null
+  input_normalization_modes:
+    observation.state: min_max
+  input_normalization_params:
+    observation.state:
+      min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
+             1.0764e+00, -1.2680e+00,  0.0000e+00,  0.0000e+00, -9.3448e+00,
+            -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
+            -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
+             8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
+      max: [0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
+            0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
+            7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
+            0.4001]
   output_normalization_modes:
     action: min_max
   output_normalization_params:
     action:
       min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
       max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
+  output_normalization_shapes:
+    action: [7]
 
   # Architecture / modeling.
   # Neural networks.
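For reference, a small sketch of what `min_max` normalization with these per-dimension bounds amounts to, assuming the usual mapping to [-1, 1] (the exact convention lives in `normalize.py`):

```python
import torch

def min_max_normalize(x: torch.Tensor, min_v: torch.Tensor, max_v: torch.Tensor) -> torch.Tensor:
    """Map x from [min_v, max_v] to [-1, 1], per dimension."""
    return (x - min_v) / (max_v - min_v) * 2.0 - 1.0

def min_max_unnormalize(x: torch.Tensor, min_v: torch.Tensor, max_v: torch.Tensor) -> torch.Tensor:
    """Inverse mapping, e.g. for actions before they are sent to the robot."""
    return (x + 1.0) / 2.0 * (max_v - min_v) + min_v

# Example with the action bounds from the config above.
a_min = torch.full((7,), -10.0)
a_max = torch.full((7,), 10.0)
action = torch.tensor([5.0, -2.5, 0.0, 10.0, -10.0, 1.0, 7.5])
normed = min_max_normalize(action, a_min, a_max)          # values in [-1, 1]
print(min_max_unnormalize(normed, a_min, a_max))          # recovers the original action
```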

View File

@@ -166,7 +166,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, d
     logging.info("[ACTOR] Load new parameters from Learner.")
     state_dict = parameters_queue.get()
     state_dict = move_state_dict_to_device(state_dict, device=device)
-    policy.load_state_dict(state_dict, strict=False)
+    policy.load_state_dict(state_dict)
 
 
 def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
@@ -182,7 +182,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
 
     logging.info("make_env online")
 
-    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)
+    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
 
     set_global_seed(cfg.seed)
     device = get_safe_torch_device(cfg.device, log=True)
@@ -283,7 +283,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
                 # TODO: Handle logging for episode information
                 logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
 
-                update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
+                update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
 
                 if len(list_transition_to_send_to_learner) > 0:
                     send_transitions_in_chunks(
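A minimal, self-contained sketch (hypothetical stand-in modules, not the repository's classes) of why `strict=False` was hiding the bug: loading the learner's full-policy state dict into the actor with `strict=False` silently drops every non-matching key, so the actor can keep running on stale weights. Loading only the actor's state dict with the default `strict=True` makes any mismatch fail loudly.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the SAC actor and the full policy.
actor = nn.Linear(4, 2)
full_policy = nn.ModuleDict({"actor": nn.Linear(4, 2), "critic": nn.Linear(6, 1)})

full_state = full_policy.state_dict()  # keys like "actor.weight", "critic.bias", ...

# strict=False: nothing matches "weight"/"bias", nothing is loaded, no error is raised.
result = actor.load_state_dict(full_state, strict=False)
print(result.missing_keys, result.unexpected_keys)  # weights silently left untouched

# The fix: send only the actor's parameters and load them strictly.
actor_state = full_policy["actor"].state_dict()  # keys "weight", "bias"
actor.load_state_dict(actor_state)  # strict=True by default; raises on any mismatch
```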

View File

@@ -684,38 +684,34 @@ def make_robot_env(
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    if "maniskill" in cfg.name:
+    if "maniskill" in cfg.env.name:
         logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
         env = make_maniskill(
-            task=cfg.task,
-            obs_mode=cfg.obs,
-            control_mode=cfg.control_mode,
-            render_mode=cfg.render_mode,
-            sensor_configs={"width": cfg.render_size, "height": cfg.render_size},
-            device=cfg.device,
+            cfg=cfg,
+            n_envs=1,
         )
         return env
 
     # Create base environment
     env = HILSerlRobotEnv(
         robot=robot,
-        display_cameras=cfg.wrapper.display_cameras,
-        delta=cfg.wrapper.delta_action,
-        use_delta_action_space=cfg.wrapper.use_relative_joint_positions,
+        display_cameras=cfg.env.wrapper.display_cameras,
+        delta=cfg.env.wrapper.delta_action,
+        use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions,
     )
 
     # Add observation and image processing
     env = ConvertToLeRobotObservation(env=env, device=cfg.device)
-    if cfg.wrapper.crop_params_dict is not None:
+    if cfg.env.wrapper.crop_params_dict is not None:
         env = ImageCropResizeWrapper(
-            env=env, crop_params_dict=cfg.wrapper.crop_params_dict, resize_size=cfg.wrapper.resize_size
+            env=env, crop_params_dict=cfg.env.wrapper.crop_params_dict, resize_size=cfg.env.wrapper.resize_size
        )
     # Add reward computation and control wrappers
     env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
-    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
     env = KeyboardInterfaceWrapper(env=env)
-    env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
-    env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
+    env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s)
+    env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
 
     return env

View File

@@ -142,6 +142,7 @@ def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> Re
         capacity=cfg.training.online_buffer_capacity,
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
+        storage_device=device,
     )
 
     dataset = LeRobotDataset(
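A rough sketch of why the new `storage_device` argument raises the optimization frequency (the buffer API here is illustrative, not the repository's `ReplayBuffer`): when transitions are stored on the CPU, every sampled batch has to be copied host-to-device inside the training loop; storing them on the learning device turns sampling into a device-local indexing operation.

```python
import torch

capacity, state_dim, batch_size = 10_000, 25, 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Storage on CPU: every optimization step pays a host-to-device copy.
cpu_states = torch.randn(capacity, state_dim)
idx = torch.randint(0, capacity, (batch_size,))
batch_from_cpu = cpu_states[idx].to(device)  # transfer happens here, every step

# Storage on the learning device: sampling stays on-device, no transfer.
gpu_states = cpu_states.to(device)
batch_from_gpu = gpu_states[idx.to(device)]  # pure on-device indexing
```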

View File

@@ -3,10 +3,14 @@ import numpy as np
 import gymnasium as gym
 import torch
+from omegaconf import DictConfig
+from typing import Any
 
 """Make ManiSkill3 gym environment"""
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
 
 
 def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
     """Convert environment observation to LeRobot format observation.
     Args:
@@ -43,32 +47,29 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
 
 
 class ManiSkillObservationWrapper(gym.ObservationWrapper):
-    def __init__(self, env):
-        super().__init__(env)
-
-    def observation(self, observation):
-        return preprocess_maniskill_observation(observation)
-
-class ManiSkillToDeviceWrapper(gym.Wrapper):
     def __init__(self, env, device: torch.device = "cuda"):
         super().__init__(env)
         self.device = device
 
-    def reset(self, seed=None, options=None):
-        obs, info = self.env.reset(seed=seed, options=options)
-        obs = {k: v.to(self.device) for k, v in obs.items()}
-        return obs, info
-
-    def step(self, action):
-        obs, reward, terminated, truncated, info = self.env.step(action)
-        obs = {k: v.to(self.device) for k, v in obs.items()}
-        return obs, reward, terminated, truncated, info
+    def observation(self, observation):
+        observation = preprocess_maniskill_observation(observation)
+        observation = {k: v.to(self.device) for k, v in observation.items()}
+        return observation
 
 
 class ManiSkillCompat(gym.Wrapper):
     def __init__(self, env):
         super().__init__(env)
+        new_action_space_shape = env.action_space.shape[-1]
+        new_low = np.squeeze(env.action_space.low, axis=0)
+        new_high = np.squeeze(env.action_space.high, axis=0)
+        self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
+
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any]]:
+        options = {}
+        return super().reset(seed=seed, options=options)
 
     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
@@ -89,7 +90,7 @@ class ManiSkillActionWrapper(gym.ActionWrapper):
 
 
 class ManiSkillMultiplyActionWrapper(gym.Wrapper):
-    def __init__(self, env, multiply_factor: float = 10):
+    def __init__(self, env, multiply_factor: float = 1):
         super().__init__(env)
         self.multiply_factor = multiply_factor
         action_space_agent: gym.spaces.Box = env.action_space[0]
@@ -108,13 +109,8 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
 
 
 def make_maniskill(
-    task: str = "PushCube-v1",
-    obs_mode: str = "rgb",
-    control_mode: str = "pd_ee_delta_pose",
-    render_mode: str = "rgb_array",
-    sensor_configs: dict[str, int] | None = None,
-    n_envs: int = 1,
-    device: torch.device = "cuda",
+    cfg: DictConfig,
+    n_envs: int | None = None,
 ) -> gym.Env:
     """
     Factory function to create a ManiSkill environment with standard wrappers.
@@ -130,22 +126,24 @@ def make_maniskill(
     Returns:
         A wrapped ManiSkill environment
     """
-    if sensor_configs is None:
-        sensor_configs = {"width": 64, "height": 64}
-
     env = gym.make(
-        task,
-        obs_mode=obs_mode,
-        control_mode=control_mode,
-        render_mode=render_mode,
-        sensor_configs=sensor_configs,
+        cfg.env.task,
+        obs_mode=cfg.env.obs,
+        control_mode=cfg.env.control_mode,
+        render_mode=cfg.env.render_mode,
+        sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
         num_envs=n_envs,
     )
+    env = ManiSkillObservationWrapper(env, device=cfg.env.device)
+    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
+    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
+    env.unwrapped.metadata["render_fps"] = 20
     env = ManiSkillCompat(env)
-    env = ManiSkillObservationWrapper(env)
     env = ManiSkillActionWrapper(env)
-    env = ManiSkillMultiplyActionWrapper(env)
-    env = ManiSkillToDeviceWrapper(env, device=device)
+    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0)
 
     return env
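A hypothetical usage sketch of the new config-driven signature. The field names are taken from the attributes read in the hunk above (`cfg.env.task`, `cfg.env.obs`, `cfg.env.control_mode`, `cfg.env.render_mode`, `cfg.env.image_size`, `cfg.env.device`), the values mirror the defaults carried by the removed keyword arguments, and the import path for `make_maniskill` is left as an assumption.

```python
from omegaconf import OmegaConf

# from <module that defines make_maniskill> import make_maniskill  # import path is an assumption

# Hypothetical config fragment with only the fields make_maniskill reads.
cfg = OmegaConf.create(
    {
        "env": {
            "task": "PushCube-v1",               # old default of the removed `task` argument
            "obs": "rgb",                        # old default of `obs_mode`
            "control_mode": "pd_ee_delta_pose",
            "render_mode": "rgb_array",
            "image_size": 64,                    # replaces the old 64x64 `sensor_configs` default
            "device": "cuda",
        }
    }
)

env = make_maniskill(cfg=cfg, n_envs=1)
obs, info = env.reset()
```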