- Fixed a major issue in how the actor loads the policy parameters sent by the learner: pass only the actor to `update_policy_parameters` and remove `strict=False` (see the sketch below).
- Fixed a major issue in the normalization of actions in the critic's `forward` pass: remove the `torch.no_grad` decorator from the normalization function in `normalize.py`.
- Fixed a performance issue, boosting the optimization frequency, by setting the replay buffer storage device to the same device used for learning.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Commit ff47c0b0d3 (parent befa1fe9af)
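The first fix above changes the actor-side update path: the learner ships only the actor's weights, and the actor loads them strictly so any key mismatch fails loudly instead of being silently skipped. The snippet below is a minimal sketch of that flow for orientation, not the repository's exact code; the `move_state_dict_to_device` helper and the call site mirror the diff further down, while the imports and type hints are illustrative assumptions.

# Minimal sketch (illustrative, not the repository's exact code) of the corrected
# actor-side parameter update: the queue carries only the actor's state dict and
# it is loaded strictly, so a key mismatch raises instead of being ignored.
import queue

import torch
import torch.nn as nn


def move_state_dict_to_device(state_dict: dict[str, torch.Tensor], device: str) -> dict[str, torch.Tensor]:
    # Plays the same role as the helper referenced in the diff: move tensors before loading.
    return {key: value.to(device) for key, value in state_dict.items()}


def update_policy_parameters(policy: nn.Module, parameters_queue: queue.Queue, device: str) -> None:
    # `policy` is expected to be the actor module only (policy.actor at the call site),
    # so the received keys match exactly and the default strict=True loading succeeds.
    state_dict = parameters_queue.get()
    state_dict = move_state_dict_to_device(state_dict, device=device)
    policy.load_state_dict(state_dict)  # strict=True by default


# Hypothetical usage, mirroring the actor loop in the diff below:
# update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)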
@@ -130,7 +130,7 @@ class Normalize(nn.Module):
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)

     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    # @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, mode in self.modes.items():
@@ -80,8 +80,8 @@ class SACPolicy(
             encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
             encoder_actor: SACObservationEncoder = encoder_critic
         else:
-            encoder_critic = SACObservationEncoder(config)
-            encoder_actor = SACObservationEncoder(config)
+            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
+            encoder_actor = SACObservationEncoder(config, self.normalize_inputs)

         self.critic_ensemble = CriticEnsemble(
             encoder=encoder_critic,
@@ -64,13 +64,29 @@ policy:
     action: [7]

   # Normalization / Unnormalization
-  input_normalization_modes: null
+  input_normalization_modes:
+    observation.state: min_max
+  input_normalization_params:
+    observation.state:
+      min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
+            1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
+            -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
+            -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
+            8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
+
+      max: [0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
+            0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
+            7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
+            0.4001]
+
   output_normalization_modes:
     action: min_max
   output_normalization_params:
     action:
       min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
       max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
+
+  output_normalization_shapes:
+    action: [7]

   # Architecture / modeling.
   # Neural networks.
@@ -166,7 +166,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
         logging.info("[ACTOR] Load new parameters from Learner.")
         state_dict = parameters_queue.get()
         state_dict = move_state_dict_to_device(state_dict, device=device)
-        policy.load_state_dict(state_dict, strict=False)
+        policy.load_state_dict(state_dict)


 def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
@@ -182,7 +182,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):

     logging.info("make_env online")

-    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)
+    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)

     set_global_seed(cfg.seed)
     device = get_safe_torch_device(cfg.device, log=True)
@@ -283,7 +283,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
             # TODO: Handle logging for episode information
             logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")

-            update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
+            update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)

             if len(list_transition_to_send_to_learner) > 0:
                 send_transitions_in_chunks(
@@ -684,38 +684,34 @@ def make_robot_env(
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    if "maniskill" in cfg.name:
+    if "maniskill" in cfg.env.name:
         logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
         env = make_maniskill(
-            task=cfg.task,
-            obs_mode=cfg.obs,
-            control_mode=cfg.control_mode,
-            render_mode=cfg.render_mode,
-            sensor_configs={"width": cfg.render_size, "height": cfg.render_size},
-            device=cfg.device,
+            cfg=cfg,
+            n_envs=1,
         )
         return env
     # Create base environment
     env = HILSerlRobotEnv(
         robot=robot,
-        display_cameras=cfg.wrapper.display_cameras,
-        delta=cfg.wrapper.delta_action,
-        use_delta_action_space=cfg.wrapper.use_relative_joint_positions,
+        display_cameras=cfg.env.wrapper.display_cameras,
+        delta=cfg.env.wrapper.delta_action,
+        use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions,
     )

     # Add observation and image processing
     env = ConvertToLeRobotObservation(env=env, device=cfg.device)
-    if cfg.wrapper.crop_params_dict is not None:
+    if cfg.env.wrapper.crop_params_dict is not None:
         env = ImageCropResizeWrapper(
-            env=env, crop_params_dict=cfg.wrapper.crop_params_dict, resize_size=cfg.wrapper.resize_size
+            env=env, crop_params_dict=cfg.env.wrapper.crop_params_dict, resize_size=cfg.env.wrapper.resize_size
         )

     # Add reward computation and control wrappers
     env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
-    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
     env = KeyboardInterfaceWrapper(env=env)
-    env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
-    env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
+    env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s)
+    env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)

     return env
@@ -142,6 +142,7 @@ def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
         capacity=cfg.training.online_buffer_capacity,
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
+        storage_device=device
     )

     dataset = LeRobotDataset(
@@ -3,10 +3,14 @@ import numpy as np
 import gymnasium as gym
 import torch

+from omegaconf import DictConfig
+from typing import Any
+
+
 """Make ManiSkill3 gym environment"""
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv


 def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
     """Convert environment observation to LeRobot format observation.
     Args:
@@ -43,32 +47,29 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:


 class ManiSkillObservationWrapper(gym.ObservationWrapper):
-    def __init__(self, env):
-        super().__init__(env)
-
-    def observation(self, observation):
-        return preprocess_maniskill_observation(observation)
-
-
-class ManiSkillToDeviceWrapper(gym.Wrapper):
     def __init__(self, env, device: torch.device = "cuda"):
         super().__init__(env)
         self.device = device

-    def reset(self, seed=None, options=None):
-        obs, info = self.env.reset(seed=seed, options=options)
-        obs = {k: v.to(self.device) for k, v in obs.items()}
-        return obs, info
-
-    def step(self, action):
-        obs, reward, terminated, truncated, info = self.env.step(action)
-        obs = {k: v.to(self.device) for k, v in obs.items()}
-        return obs, reward, terminated, truncated, info
+    def observation(self, observation):
+        observation = preprocess_maniskill_observation(observation)
+        observation = {k: v.to(self.device) for k, v in observation.items()}
+        return observation


 class ManiSkillCompat(gym.Wrapper):
     def __init__(self, env):
         super().__init__(env)
+        new_action_space_shape = env.action_space.shape[-1]
+        new_low = np.squeeze(env.action_space.low, axis=0)
+        new_high = np.squeeze(env.action_space.high, axis=0)
+        self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
+
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any]]:
+        options = {}
+        return super().reset(seed=seed, options=options)

     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
@@ -89,7 +90,7 @@ class ManiSkillActionWrapper(gym.ActionWrapper):


 class ManiSkillMultiplyActionWrapper(gym.Wrapper):
-    def __init__(self, env, multiply_factor: float = 10):
+    def __init__(self, env, multiply_factor: float = 1):
         super().__init__(env)
         self.multiply_factor = multiply_factor
         action_space_agent: gym.spaces.Box = env.action_space[0]
@@ -108,13 +109,8 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):


 def make_maniskill(
-    task: str = "PushCube-v1",
-    obs_mode: str = "rgb",
-    control_mode: str = "pd_ee_delta_pose",
-    render_mode: str = "rgb_array",
-    sensor_configs: dict[str, int] | None = None,
-    n_envs: int = 1,
-    device: torch.device = "cuda",
+    cfg: DictConfig,
+    n_envs: int | None = None,
 ) -> gym.Env:
     """
     Factory function to create a ManiSkill environment with standard wrappers.
@@ -130,22 +126,24 @@ def make_maniskill(
     Returns:
         A wrapped ManiSkill environment
     """
-    if sensor_configs is None:
-        sensor_configs = {"width": 64, "height": 64}
-
     env = gym.make(
-        task,
-        obs_mode=obs_mode,
-        control_mode=control_mode,
-        render_mode=render_mode,
-        sensor_configs=sensor_configs,
+        cfg.env.task,
+        obs_mode=cfg.env.obs,
+        control_mode=cfg.env.control_mode,
+        render_mode=cfg.env.render_mode,
+        sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
         num_envs=n_envs,
     )

+    env = ManiSkillObservationWrapper(env, device=cfg.env.device)
+    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
+    env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
+    env.unwrapped.metadata["render_fps"] = 20
     env = ManiSkillCompat(env)
-    env = ManiSkillObservationWrapper(env)
     env = ManiSkillActionWrapper(env)
-    env = ManiSkillMultiplyActionWrapper(env)
-    env = ManiSkillToDeviceWrapper(env, device=device)
+    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0)

     return env