- Fixed big issue in the loading of the policy parameters sent by the learner to the actor -- pass only the actor to the `update_policy_parameters` and remove `strict=False`
- Fixed big issue in the normalization of the actions in the `forward` function of the critic -- remove the `torch.no_grad` decorator in `normalize.py` in the normalization function - Fixed performance issue to boost the optimization frequency by setting the storage device to be the same as the device of learning. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
befa1fe9af
commit
ff47c0b0d3
|
@ -130,7 +130,7 @@ class Normalize(nn.Module):
|
|||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
# @torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
|
|
|
@ -80,8 +80,8 @@ class SACPolicy(
|
|||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor = SACObservationEncoder(config)
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
|
||||
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
|
|
|
@ -64,13 +64,29 @@ policy:
|
|||
action: [7]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: null
|
||||
input_normalization_modes:
|
||||
observation.state: min_max
|
||||
input_normalization_params:
|
||||
observation.state:
|
||||
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
|
||||
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
|
||||
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
|
||||
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
|
||||
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
|
||||
|
||||
max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
|
||||
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
|
||||
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
|
||||
0.4001]
|
||||
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
action:
|
||||
min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
|
||||
max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
|
||||
output_normalization_shapes:
|
||||
action: [7]
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
|
|
|
@ -166,7 +166,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, d
|
|||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dict = parameters_queue.get()
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict, strict=False)
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
|
||||
|
@ -182,7 +182,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
|||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
@ -283,7 +283,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
|||
# TODO: Handle logging for episode information
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
send_transitions_in_chunks(
|
||||
|
|
|
@ -684,38 +684,34 @@ def make_robot_env(
|
|||
Returns:
|
||||
A vectorized gym environment with all the necessary wrappers applied.
|
||||
"""
|
||||
if "maniskill" in cfg.name:
|
||||
if "maniskill" in cfg.env.name:
|
||||
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
|
||||
env = make_maniskill(
|
||||
task=cfg.task,
|
||||
obs_mode=cfg.obs,
|
||||
control_mode=cfg.control_mode,
|
||||
render_mode=cfg.render_mode,
|
||||
sensor_configs={"width": cfg.render_size, "height": cfg.render_size},
|
||||
device=cfg.device,
|
||||
cfg=cfg,
|
||||
n_envs=1,
|
||||
)
|
||||
return env
|
||||
# Create base environment
|
||||
env = HILSerlRobotEnv(
|
||||
robot=robot,
|
||||
display_cameras=cfg.wrapper.display_cameras,
|
||||
delta=cfg.wrapper.delta_action,
|
||||
use_delta_action_space=cfg.wrapper.use_relative_joint_positions,
|
||||
display_cameras=cfg.env.wrapper.display_cameras,
|
||||
delta=cfg.env.wrapper.delta_action,
|
||||
use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions,
|
||||
)
|
||||
|
||||
# Add observation and image processing
|
||||
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
||||
if cfg.wrapper.crop_params_dict is not None:
|
||||
if cfg.env.wrapper.crop_params_dict is not None:
|
||||
env = ImageCropResizeWrapper(
|
||||
env=env, crop_params_dict=cfg.wrapper.crop_params_dict, resize_size=cfg.wrapper.resize_size
|
||||
env=env, crop_params_dict=cfg.env.wrapper.crop_params_dict, resize_size=cfg.env.wrapper.resize_size
|
||||
)
|
||||
|
||||
# Add reward computation and control wrappers
|
||||
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
|
||||
env = KeyboardInterfaceWrapper(env=env)
|
||||
env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
|
||||
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
|
||||
env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s)
|
||||
env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
|
||||
env = BatchCompitableWrapper(env=env)
|
||||
|
||||
return env
|
||||
|
|
|
@ -142,6 +142,7 @@ def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> Re
|
|||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
storage_device=device
|
||||
)
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
|
|
|
@ -3,10 +3,14 @@ import numpy as np
|
|||
import gymnasium as gym
|
||||
import torch
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from typing import Any
|
||||
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
|
@ -43,32 +47,29 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
|
|||
|
||||
|
||||
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def observation(self, observation):
|
||||
return preprocess_maniskill_observation(observation)
|
||||
|
||||
|
||||
class ManiSkillToDeviceWrapper(gym.Wrapper):
|
||||
def __init__(self, env, device: torch.device = "cuda"):
|
||||
super().__init__(env)
|
||||
self.device = device
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
obs = {k: v.to(self.device) for k, v in obs.items()}
|
||||
return obs, info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
obs = {k: v.to(self.device) for k, v in obs.items()}
|
||||
return obs, reward, terminated, truncated, info
|
||||
def observation(self, observation):
|
||||
observation = preprocess_maniskill_observation(observation)
|
||||
observation = {k: v.to(self.device) for k, v in observation.items()}
|
||||
return observation
|
||||
|
||||
|
||||
class ManiSkillCompat(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
new_action_space_shape = env.action_space.shape[-1]
|
||||
new_low = np.squeeze(env.action_space.low, axis=0)
|
||||
new_high = np.squeeze(env.action_space.high, axis=0)
|
||||
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[Any, dict[str, Any]]:
|
||||
options = {}
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
|
@ -89,7 +90,7 @@ class ManiSkillActionWrapper(gym.ActionWrapper):
|
|||
|
||||
|
||||
class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
||||
def __init__(self, env, multiply_factor: float = 10):
|
||||
def __init__(self, env, multiply_factor: float = 1):
|
||||
super().__init__(env)
|
||||
self.multiply_factor = multiply_factor
|
||||
action_space_agent: gym.spaces.Box = env.action_space[0]
|
||||
|
@ -108,13 +109,8 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
|||
|
||||
|
||||
def make_maniskill(
|
||||
task: str = "PushCube-v1",
|
||||
obs_mode: str = "rgb",
|
||||
control_mode: str = "pd_ee_delta_pose",
|
||||
render_mode: str = "rgb_array",
|
||||
sensor_configs: dict[str, int] | None = None,
|
||||
n_envs: int = 1,
|
||||
device: torch.device = "cuda",
|
||||
cfg: DictConfig,
|
||||
n_envs: int | None = None,
|
||||
) -> gym.Env:
|
||||
"""
|
||||
Factory function to create a ManiSkill environment with standard wrappers.
|
||||
|
@ -130,22 +126,24 @@ def make_maniskill(
|
|||
Returns:
|
||||
A wrapped ManiSkill environment
|
||||
"""
|
||||
if sensor_configs is None:
|
||||
sensor_configs = {"width": 64, "height": 64}
|
||||
|
||||
env = gym.make(
|
||||
task,
|
||||
obs_mode=obs_mode,
|
||||
control_mode=control_mode,
|
||||
render_mode=render_mode,
|
||||
sensor_configs=sensor_configs,
|
||||
cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
|
||||
num_envs=n_envs,
|
||||
)
|
||||
|
||||
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
env = ManiSkillCompat(env)
|
||||
env = ManiSkillObservationWrapper(env)
|
||||
env = ManiSkillActionWrapper(env)
|
||||
env = ManiSkillMultiplyActionWrapper(env)
|
||||
env = ManiSkillToDeviceWrapper(env, device=device)
|
||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue