From dd37bd412e61c5b2110c5b58b1e783afab3b0c82 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 26 Mar 2025 08:15:05 +0000 Subject: [PATCH] [WIP] Non functional yet Add ManiSkill environment configuration and wrappers - Introduced `VideoRecordConfig` for video recording settings. - Added `ManiskillEnvConfig` to encapsulate environment-specific configurations. - Implemented various wrappers for the ManiSkill environment, including observation and action scaling. - Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing. - Updated the `actor_server` and `learner_server` to utilize the new configuration structure. - Refactored the training pipeline to accommodate the new environment and policy configurations. --- lerobot/common/envs/configs.py | 58 ++ lerobot/common/envs/factory.py | 85 --- .../common/policies/sac/configuration_sac.py | 27 +- lerobot/common/utils/wandb_utils.py | 40 +- lerobot/configs/train.py | 5 +- lerobot/scripts/server/actor_server.py | 75 ++- lerobot/scripts/server/gym_manipulator.py | 136 ++--- lerobot/scripts/server/learner_server.py | 520 +++++++++++------- .../scripts/server/maniskill_manipulator.py | 157 ++++-- 9 files changed, 667 insertions(+), 436 deletions(-) diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index cf90048a..0414d64f 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -154,3 +154,61 @@ class XarmEnv(EnvConfig): "visualization_height": self.visualization_height, "max_episode_steps": self.episode_length, } + + +@dataclass +class VideoRecordConfig: + """Configuration for video recording in ManiSkill environments.""" + enabled: bool = False + record_dir: str = "videos" + trajectory_name: str = "trajectory" + +@EnvConfig.register_subclass("maniskill_push") +@dataclass +class ManiskillEnvConfig(EnvConfig): + """Configuration for the ManiSkill environment.""" + name: str = "maniskill/pushcube" + task: str = "PushCube-v1" + image_size: int = 64 + control_mode: str = "pd_ee_delta_pose" + state_dim: int = 25 + action_dim: int = 7 + fps: int = 400 + episode_length: int = 50 + obs_type: str = "rgb" + render_mode: str = "rgb_array" + render_size: int = 64 + device: str = "cuda" + robot: str = "so100" # This is a hack to make the robot config work + video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)), + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(25,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "observation.image": OBS_IMAGE, + "observation.state": OBS_ROBOT, + } + ) + reward_classifier: dict[str, str | None] = field( + default_factory=lambda: { + "pretrained_path": None, + "config_path": None, + } + ) + + @property + def gym_kwargs(self) -> dict: + return { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + "max_episode_steps": self.episode_length, + "control_mode": self.control_mode, + "sensor_configs": {"width": self.image_size, "height": self.image_size}, + "num_envs": 1, + } \ No newline at end of file diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index c3996d84..b6b3e547 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -69,88 +69,3 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g return env - -def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None: - """Make ManiSkill3 gym environment""" - from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv - - env = gym.make( - cfg.env.task, - obs_mode=cfg.env.obs, - control_mode=cfg.env.control_mode, - render_mode=cfg.env.render_mode, - sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size), - num_envs=n_envs, - ) - # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode - env = ManiSkillVectorEnv(env, ignore_terminations=True) - # state should have the size of 25 - # env = ConvertToLeRobotEnv(env, n_envs) - # env = PixelWrapper(cfg, env, n_envs) - env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env) - env.unwrapped.metadata["render_fps"] = 20 - - return env - - -class PixelWrapper(gym.Wrapper): - """ - Wrapper for pixel observations. Works with Maniskill vectorized environments - """ - - def __init__(self, cfg, env, num_envs, num_frames=3): - super().__init__(env) - self.cfg = cfg - self.env = env - self.observation_space = gym.spaces.Box( - low=0, - high=255, - shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size), - dtype=np.uint8, - ) - self._frames = deque([], maxlen=num_frames) - self._render_size = cfg.env.render_size - - def _get_obs(self, obs): - frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2) - self._frames.append(frame) - return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)} - - def reset(self, seed): - obs, info = self.env.reset() # (seed=seed) - for _ in range(self._frames.maxlen): - obs_frames = self._get_obs(obs) - return obs_frames, info - - def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) - return self._get_obs(obs), reward, terminated, truncated, info - - -# TODO: Remove this -class ConvertToLeRobotEnv(gym.Wrapper): - def __init__(self, env, num_envs): - super().__init__(env) - - def reset(self, seed=None, options=None): - obs, info = self.env.reset(seed=seed, options={}) - return self._get_obs(obs), info - - def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) - return self._get_obs(obs), reward, terminated, truncated, info - - def _get_obs(self, observation): - sensor_data = observation.pop("sensor_data") - del observation["sensor_param"] - images = [] - for cam_data in sensor_data.values(): - images.append(cam_data["rgb"]) - - images = torch.concat(images, axis=-1) - # flatten the rest of the data which should just be state data - observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device) - ret = dict() - ret["state"] = observation - ret["pixels"] = images - return ret diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index b34e5f60..5221a1f2 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -31,12 +31,19 @@ class SACConfig(PreTrainedConfig): Args: n_obs_steps: Number of environment steps worth of observations to pass to the policy. normalization_mapping: Mapping from feature types to normalization modes. + dataset_stats: Statistics for normalizing different data types. camera_number: Number of cameras to use. + device: Device to use for training. storage_device: Device to use for storage. vision_encoder_name: Name of the vision encoder to use. freeze_vision_encoder: Whether to freeze the vision encoder. image_encoder_hidden_dim: Hidden dimension for the image encoder. shared_encoder: Whether to use a shared encoder. + online_steps: Total number of online training steps. + online_env_seed: Seed for the online environment. + online_buffer_capacity: Capacity of the online replay buffer. + online_step_before_learning: Number of steps to collect before starting learning. + policy_update_freq: Frequency of policy updates. discount: Discount factor for the RL algorithm. temperature_init: Initial temperature for entropy regularization. num_critics: Number of critic networks. @@ -54,6 +61,8 @@ class SACConfig(PreTrainedConfig): critic_network_kwargs: Additional arguments for critic networks. actor_network_kwargs: Additional arguments for actor network. policy_kwargs: Additional arguments for policy. + actor_learner_config: Configuration for actor-learner communication. + concurrency: Configuration for concurrency model. """ # Input / output structure @@ -86,13 +95,21 @@ class SACConfig(PreTrainedConfig): # Architecture specifics camera_number: int = 1 + device: str = "cuda" storage_device: str = "cpu" # Set to "helper2424/resnet10" for hil serl vision_encoder_name: str | None = None freeze_vision_encoder: bool = True image_encoder_hidden_dim: int = 32 shared_encoder: bool = True - + + # Training parameter + online_steps: int = 1000000 + online_env_seed: int = 10000 + online_buffer_capacity: int = 10000 + online_step_before_learning: int = 100 + policy_update_freq: int = 1 + # SAC algorithm parameters discount: float = 0.99 temperature_init: float = 1.0 @@ -132,11 +149,17 @@ class SACConfig(PreTrainedConfig): } ) - # Deprecated, kept for backward compatibility actor_learner_config: dict[str, str | int] = field( default_factory=lambda: { "learner_host": "127.0.0.1", "learner_port": 50051, + "policy_parameters_push_frequency": 4, + } + ) + concurrency: dict[str, str] = field( + default_factory=lambda: { + "actor": "threads", + "learner": "threads" } ) diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py index 3fe241d4..9a192406 100644 --- a/lerobot/common/utils/wandb_utils.py +++ b/lerobot/common/utils/wandb_utils.py @@ -92,6 +92,8 @@ class WandBLogger: resume="must" if cfg.resume else None, mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online", ) + # Handle custom step key for rl asynchronous training. + self._wandb_custom_step_key: set[str] | None = None print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") self._wandb = wandb @@ -108,9 +110,24 @@ class WandBLogger: artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) self._wandb.log_artifact(artifact) - def log_dict(self, d: dict, step: int, mode: str = "train"): + def log_dict(self, d: dict, step: int, mode: str = "train", custom_step_key: str | None = None): if mode not in {"train", "eval"}: raise ValueError(mode) + if step is None and custom_step_key is None: + raise ValueError("Either step or custom_step_key must be provided.") + + # NOTE: This is not simple. Wandb step is it must always monotonically increase and it + # increases with each wandb.log call, but in the case of asynchronous RL for example, + # multiple time steps is possible for example, the interaction step with the environment, + # the training step, the evaluation step, etc. So we need to define a custom step key + # to log the correct step for each metric. + if custom_step_key is not None: + if self._wandb_custom_step_key is None: + self._wandb_custom_step_key = set() + new_custom_key = f"{mode}/{custom_step_key}" + if new_custom_key not in self._wandb_custom_step_key: + self._wandb_custom_step_key.add(new_custom_key) + self._wandb.define_metric(new_custom_key, hidden=True) for k, v in d.items(): if not isinstance(v, (int, float, str)): @@ -118,7 +135,26 @@ class WandBLogger: f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' ) continue - self._wandb.log({f"{mode}/{k}": v}, step=step) + + # Do not log the custom step key itself. + if ( + self._wandb_custom_step_key is not None + and k in self._wandb_custom_step_key + ): + continue + + if custom_step_key is not None: + value_custom_step = d[custom_step_key] + self._wandb.log( + { + f"{mode}/{k}": v, + f"{mode}/{custom_step_key}": value_custom_step, + } + ) + continue + + self._wandb.log(data={f"{mode}/{k}": v}, step=step) + def log_video(self, video_path: str, step: int, mode: str = "train"): if mode not in {"train", "eval"}: diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index 1e2f6544..f38cd8e6 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -34,11 +34,10 @@ TRAIN_CONFIG_NAME = "train_config.json" @dataclass class TrainPipelineConfig(HubMixin): - dataset: DatasetConfig + dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset env: envs.EnvConfig | None = None policy: PreTrainedConfig | None = None - # Set `dir` to where you would like to save all of the run outputs. If you run another training session - # with the same value for `dir` its contents will be overwritten unless you set `resume` to true. + # Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true. output_dir: Path | None = None job_name: str | None = None # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index e3a47e9b..7b68145e 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -21,26 +21,23 @@ from statistics import mean, quantiles # from lerobot.scripts.eval import eval_policy import grpc -import hydra import torch -from omegaconf import DictConfig from torch import nn from torch.multiprocessing import Event, Queue # TODO: Remove the import of maniskill -# from lerobot.common.envs.factory import make_maniskill_env -# from lerobot.common.envs.utils import preprocess_maniskill_observation from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy -from lerobot.common.robot_devices.robots.factory import make_robot -from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robot_devices.robots.utils import Robot, make_robot from lerobot.common.robot_devices.utils import busy_wait +from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.utils import ( TimerManager, get_safe_torch_device, init_logging, - set_global_seed, ) +from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service from lerobot.scripts.server.buffer import ( Transition, @@ -61,7 +58,7 @@ ACTOR_SHUTDOWN_TIMEOUT = 30 def receive_policy( - cfg: DictConfig, + cfg: TrainPipelineConfig, parameters_queue: Queue, shutdown_event: any, # Event, learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, @@ -72,12 +69,12 @@ def receive_policy( if not use_threads(cfg): # Setup process handlers to handle shutdown signal # But use shutdown event from the main process - setup_process_handlers(False) + setup_process_handlers(use_threads=False) if grpc_channel is None or learner_client is None: learner_client, grpc_channel = learner_service_client( - host=cfg.actor_learner_config.learner_host, - port=cfg.actor_learner_config.learner_port, + host=cfg.policy.actor_learner_config["learner_host"], + port=cfg.policy.actor_learner_config["learner_port"], ) try: @@ -132,7 +129,7 @@ def interactions_stream( def send_transitions( - cfg: DictConfig, + cfg: TrainPipelineConfig, transitions_queue: Queue, shutdown_event: any, # Event, learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, @@ -156,8 +153,8 @@ def send_transitions( if grpc_channel is None or learner_client is None: learner_client, grpc_channel = learner_service_client( - host=cfg.actor_learner_config.learner_host, - port=cfg.actor_learner_config.learner_port, + host=cfg.policy.actor_learner_config["learner_host"], + port=cfg.policy.actor_learner_config["learner_port"], ) try: @@ -173,7 +170,7 @@ def send_transitions( def send_interactions( - cfg: DictConfig, + cfg: TrainPipelineConfig, interactions_queue: Queue, shutdown_event: any, # Event, learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, @@ -196,8 +193,8 @@ def send_interactions( if grpc_channel is None or learner_client is None: learner_client, grpc_channel = learner_service_client( - host=cfg.actor_learner_config.learner_host, - port=cfg.actor_learner_config.learner_port, + host=cfg.policy.actor_learner_config["learner_host"], + port=cfg.policy.actor_learner_config["learner_port"], ) try: @@ -269,7 +266,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) def act_with_policy( - cfg: DictConfig, + cfg: TrainPipelineConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: any, # Event, @@ -291,7 +288,7 @@ def act_with_policy( online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg) - set_global_seed(cfg.seed) + set_seed(cfg.seed) device = get_safe_torch_device(cfg.device, log=True) torch.backends.cudnn.benchmark = True @@ -304,7 +301,7 @@ def act_with_policy( ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters # TODO: At some point we should just need make sac policy policy: SACPolicy = make_policy( - hydra_cfg=cfg, + cfg=cfg.policy, # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # Hack: But if we do online training, we do not need dataset_stats dataset_stats=None, @@ -469,7 +466,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]: return stats -def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int): +def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int): if policy_fps < cfg.fps: logging.warning( f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}" @@ -497,25 +494,25 @@ def establish_learner_connection( return False -def use_threads(cfg: DictConfig) -> bool: - return cfg.actor_learner_config.concurrency.actor == "threads" +def use_threads(cfg: TrainPipelineConfig) -> bool: + return cfg.policy.concurrency["actor"] == "threads" -@hydra.main(version_base="1.2", config_name="default", config_path="../../configs") -def actor_cli(cfg: dict): +@parser.wrap() +def actor_cli(cfg: TrainPipelineConfig): if not use_threads(cfg): import torch.multiprocessing as mp mp.set_start_method("spawn") init_logging(log_file="actor.log") - robot = make_robot(cfg=cfg.robot) + robot = make_robot(robot_type=cfg.env.robot) shutdown_event = setup_process_handlers(use_threads(cfg)) learner_client, grpc_channel = learner_service_client( - host=cfg.actor_learner_config.learner_host, - port=cfg.actor_learner_config.learner_port, + host=cfg.policy.actor_learner_config["learner_host"], + port=cfg.policy.actor_learner_config["learner_port"], ) logging.info("[ACTOR] Establishing connection with Learner") @@ -570,22 +567,22 @@ def actor_cli(cfg: dict): # TODO: Remove this once we merge into main reward_classifier = None if ( - cfg.env.reward_classifier.pretrained_path is not None - and cfg.env.reward_classifier.config_path is not None + cfg.env.reward_classifier["pretrained_path"] is not None + and cfg.env.reward_classifier["config_path"] is not None ): reward_classifier = get_classifier( - pretrained_path=cfg.env.reward_classifier.pretrained_path, - config_path=cfg.env.reward_classifier.config_path, + pretrained_path=cfg.env.reward_classifier["pretrained_path"], + config_path=cfg.env.reward_classifier["config_path"], ) act_with_policy( - cfg, - robot, - reward_classifier, - shutdown_event, - parameters_queue, - transitions_queue, - interactions_queue, + cfg=cfg, + robot=robot, + reward_classifier=reward_classifier, + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + transitions_queue=transitions_queue, + interactions_queue=interactions_queue, ) logging.info("[ACTOR] Policy process joined") diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 2f39bfdb..55c0c6de 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -15,10 +15,12 @@ import json from dataclasses import dataclass from lerobot.common.envs.utils import preprocess_observation +from lerobot.configs.train import TrainPipelineConfig +from lerobot.common.envs.configs import EnvConfig from lerobot.common.robot_devices.control_utils import ( busy_wait, is_headless, - reset_follower_position, + # reset_follower_position, ) from typing import Optional @@ -28,6 +30,7 @@ from lerobot.common.robot_devices.robots.utils import make_robot_from_config from lerobot.common.robot_devices.robots.configs import RobotConfig from lerobot.scripts.server.kinematics import RobotKinematics +from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill from lerobot.configs import parser logging.basicConfig(level=logging.INFO) @@ -1094,7 +1097,10 @@ class ActionScaleWrapper(gym.ActionWrapper): return action * self.scale_vector, is_intervention -def make_robot_env(cfg, robot) -> gym.vector.VectorEnv: +@parser.wrap() +def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv: +# def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv: +# def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv: """ Factory function to create a vectorized robot environment. @@ -1106,7 +1112,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv: Returns: A vectorized gym environment with all the necessary wrappers applied. """ - if "maniskill" in cfg.env_name: + if "maniskill" in cfg.name: from lerobot.scripts.server.maniskill_manipulator import make_maniskill logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN") @@ -1115,7 +1121,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv: n_envs=1, ) return env - + robot = cfg.robot # Create base environment env = HILSerlRobotEnv( robot=robot, @@ -1329,80 +1335,82 @@ def replay_episode(env, repo_id, root=None, episode=0): busy_wait(1 / 10 - dt_s) -@parser.wrap() -def main(cfg: HILSerlRobotEnvConfig): +# @parser.wrap() +# def main(cfg): - robot = make_robot_from_config(cfg.robot) +# robot = make_robot_from_config(cfg.robot) - reward_classifier = None #get_classifier( - # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file - # ) - user_relative_joint_positions = True +# reward_classifier = None #get_classifier( +# # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file +# # ) +# user_relative_joint_positions = True - env = make_robot_env(cfg, robot) +# env = make_robot_env(cfg, robot) - if cfg.mode == "record": - policy = None - if cfg.pretrained_policy_name_or_path is not None: - from lerobot.common.policies.sac.modeling_sac import SACPolicy +# if cfg.mode == "record": +# policy = None +# if cfg.pretrained_policy_name_or_path is not None: +# from lerobot.common.policies.sac.modeling_sac import SACPolicy - policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) - policy.to(cfg.device) - policy.eval() +# policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) +# policy.to(cfg.device) +# policy.eval() - record_dataset( - env, - cfg.repo_id, - root=cfg.dataset_root, - num_episodes=cfg.num_episodes, - fps=cfg.fps, - task_description=cfg.task, - policy=policy, - ) - exit() +# record_dataset( +# env, +# cfg.repo_id, +# root=cfg.dataset_root, +# num_episodes=cfg.num_episodes, +# fps=cfg.fps, +# task_description=cfg.task, +# policy=policy, +# ) +# exit() - if cfg.mode == "replay": - replay_episode( - env, - cfg.replay_repo_id, - root=cfg.dataset_root, - episode=cfg.replay_episode, - ) - exit() +# if cfg.mode == "replay": +# replay_episode( +# env, +# cfg.replay_repo_id, +# root=cfg.dataset_root, +# episode=cfg.replay_episode, +# ) +# exit() - env.reset() +# env.reset() - # Retrieve the robot's action space for joint commands. - action_space_robot = env.action_space.spaces[0] +# # Retrieve the robot's action space for joint commands. +# action_space_robot = env.action_space.spaces[0] - # Initialize the smoothed action as a random sample. - smoothed_action = action_space_robot.sample() +# # Initialize the smoothed action as a random sample. +# smoothed_action = action_space_robot.sample() - # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. - # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. - alpha = 1.0 +# # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. +# # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. +# alpha = 1.0 - num_episode = 0 - sucesses = [] - while num_episode < 20: - start_loop_s = time.perf_counter() - # Sample a new random action from the robot's action space. - new_random_action = action_space_robot.sample() - # Update the smoothed action using an exponential moving average. - smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action +# num_episode = 0 +# sucesses = [] +# while num_episode < 20: +# start_loop_s = time.perf_counter() +# # Sample a new random action from the robot's action space. +# new_random_action = action_space_robot.sample() +# # Update the smoothed action using an exponential moving average. +# smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action - # Execute the step: wrap the NumPy action in a torch tensor. - obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False)) - if terminated or truncated: - sucesses.append(reward) - env.reset() - num_episode += 1 +# # Execute the step: wrap the NumPy action in a torch tensor. +# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False)) +# if terminated or truncated: +# sucesses.append(reward) +# env.reset() +# num_episode += 1 - dt_s = time.perf_counter() - start_loop_s - busy_wait(1 / cfg.fps - dt_s) +# dt_s = time.perf_counter() - start_loop_s +# busy_wait(1 / cfg.fps - dt_s) - logging.info(f"Success after 20 steps {sucesses}") - logging.info(f"success rate {sum(sucesses) / len(sucesses)}") +# logging.info(f"Success after 20 steps {sucesses}") +# logging.info(f"success rate {sum(sucesses) / len(sucesses)}") +# if __name__ == "__main__": +# main() if __name__ == "__main__": - main() \ No newline at end of file + make_robot_env() \ No newline at end of file diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 7d7db5cd..c34182b7 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -19,40 +19,45 @@ import shutil import time from concurrent.futures import ThreadPoolExecutor from pprint import pformat +import os +from pathlib import Path +import draccus import grpc # Import generated stubs import hilserl_pb2_grpc # type: ignore -import hydra import torch -from deepdiff import DeepDiff -from omegaconf import DictConfig, OmegaConf from termcolor import colored from torch import nn -# from torch.multiprocessing import Event, Queue, Process -# from threading import Event, Thread -# from torch.multiprocessing import Queue, Event from torch.multiprocessing import Queue from torch.optim.optimizer import Optimizer from lerobot.common.datasets.factory import make_dataset +from lerobot.configs.train import TrainPipelineConfig +from lerobot.configs import parser # TODO: Remove the import of maniskill from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig +from lerobot.common.utils.train_utils import ( + get_step_checkpoint_dir, + get_step_identifier, + load_training_state as utils_load_training_state, + save_checkpoint, + update_last_checkpoint, +) +from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.utils import ( format_big_number, - get_global_random_state, get_safe_torch_device, - init_hydra_config, init_logging, - set_global_random_state, - set_global_seed, ) + +from lerobot.common.policies.utils import get_device_from_parameters +from lerobot.common.utils.wandb_utils import WandBLogger from lerobot.scripts.server import learner_service from lerobot.scripts.server.buffer import ( ReplayBuffer, @@ -64,102 +69,167 @@ from lerobot.scripts.server.buffer import ( state_to_bytes, ) from lerobot.scripts.server.utils import setup_process_handlers +from lerobot.common.constants import ( + CHECKPOINTS_DIR, + LAST_CHECKPOINT_LINK, + PRETRAINED_MODEL_DIR, + TRAINING_STATE_DIR, + TRAINING_STEP, +) -def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig: +def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig: + """ + Handle the resume logic for training. + + If resume is True: + - Verifies that a checkpoint exists + - Loads the checkpoint configuration + - Logs resumption details + - Returns the checkpoint configuration + + If resume is False: + - Checks if an output directory exists (to prevent accidental overwriting) + - Returns the original configuration + + Args: + cfg (TrainPipelineConfig): The training configuration + + Returns: + TrainPipelineConfig: The updated configuration + + Raises: + RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists + """ + out_dir = cfg.output_dir + + # Case 1: Not resuming, but need to check if directory exists to prevent overwrites if not cfg.resume: - if Logger.get_last_checkpoint_dir(out_dir).exists(): + checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + if os.path.exists(checkpoint_dir): raise RuntimeError( - f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. " + f"Output directory {checkpoint_dir} already exists. " "Use `resume=true` to resume training." ) return cfg - # if resume == True - checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir) - if not checkpoint_dir.exists(): + # Case 2: Resuming training + checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + if not os.path.exists(checkpoint_dir): raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True") - checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") + # Log that we found a valid checkpoint and are resuming logging.info( colored( - "Resume=True detected, resuming previous run", + "Valid checkpoint found: resume=True detected, resuming previous run", color="yellow", attrs=["bold"], ) ) - checkpoint_cfg = init_hydra_config(checkpoint_cfg_path) - diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)) - - if "values_changed" in diff and "root['resume']" in diff["values_changed"]: - del diff["values_changed"]["root['resume']"] - - if len(diff) > 0: - logging.warning( - f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n" - "Checkpoint configuration takes precedence." - ) - + # Load config using Draccus + checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json") + checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path) + + # Ensure resume flag is set in returned config checkpoint_cfg.resume = True return checkpoint_cfg def load_training_state( - cfg: DictConfig, - logger: Logger, - optimizers: Optimizer | dict, + cfg: TrainPipelineConfig, + optimizers: Optimizer | dict[str, Optimizer], ): + """ + Loads the training state (optimizers, step count, etc.) from a checkpoint. + + Args: + cfg (TrainPipelineConfig): Training configuration + optimizers (Optimizer | dict): Optimizers to load state into + + Returns: + tuple: (optimization_step, interaction_step) or (None, None) if not resuming + """ if not cfg.resume: return None, None - training_state = torch.load( - logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False - ) - - if isinstance(training_state["optimizer"], dict): - assert set(training_state["optimizer"].keys()) == set(optimizers.keys()) - for k, v in training_state["optimizer"].items(): - optimizers[k].load_state_dict(v) - else: - optimizers.load_state_dict(training_state["optimizer"]) - - set_global_random_state({k: training_state[k] for k in get_global_random_state()}) - return training_state["step"], training_state["interaction_step"] + # Construct path to the last checkpoint directory + checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + + logging.info(f"Loading training state from {checkpoint_dir}") + + try: + # Use the utility function from train_utils which loads the optimizer state + # The function returns (step, updated_optimizer, scheduler) + step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None) + + # For interaction step, we still need to load the training_state.pt file + training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt") + training_state = torch.load(training_state_path, weights_only=False) + interaction_step = training_state.get("interaction_step", 0) + + logging.info(f"Resuming from step {step}, interaction step {interaction_step}") + return step, interaction_step + + except Exception as e: + logging.error(f"Failed to load training state: {e}") + return None, None -def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None: +def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None: + """ + Log information about the training process. + + Args: + cfg (TrainPipelineConfig): Training configuration + policy (nn.Module): Policy model + """ num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - log_output_dir(out_dir) + + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.training.online_steps=}") + logging.info(f"{cfg.policy.online_steps=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") def initialize_replay_buffer( - cfg: DictConfig, logger: Logger, device: str, storage_device: str + cfg: TrainPipelineConfig, + device: str, + storage_device: str ) -> ReplayBuffer: + """ + Initialize a replay buffer, either empty or from a dataset if resuming. + + Args: + cfg (TrainPipelineConfig): Training configuration + device (str): Device to store tensors on + storage_device (str): Device for storage optimization + + Returns: + ReplayBuffer: Initialized replay buffer + """ if not cfg.resume: return ReplayBuffer( - capacity=cfg.training.online_buffer_capacity, + capacity=cfg.policy.online_buffer_capacity, device=device, - state_keys=cfg.policy.input_shapes.keys(), + state_keys=cfg.policy.input_features.keys(), storage_device=storage_device, optimize_memory=True, ) logging.info("Resume training load the online dataset") + dataset_path = os.path.join(cfg.output_dir, "dataset") dataset = LeRobotDataset( - repo_id=cfg.dataset_repo_id, + repo_id=cfg.dataset.dataset_repo_id, local_files_only=True, - root=logger.log_dir / "dataset", + root=dataset_path, ) return ReplayBuffer.from_lerobot_dataset( lerobot_dataset=dataset, - capacity=cfg.training.online_buffer_capacity, + capacity=cfg.policy.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys(), optimize_memory=True, @@ -167,33 +237,45 @@ def initialize_replay_buffer( def initialize_offline_replay_buffer( - cfg: DictConfig, - logger: Logger, + cfg: TrainPipelineConfig, device: str, storage_device: str, active_action_dims: list[int] | None = None, ) -> ReplayBuffer: + """ + Initialize an offline replay buffer from a dataset. + + Args: + cfg (TrainPipelineConfig): Training configuration + device (str): Device to store tensors on + storage_device (str): Device for storage optimization + active_action_dims (list[int] | None): Active action dimensions for masking + + Returns: + ReplayBuffer: Initialized offline replay buffer + """ if not cfg.resume: logging.info("make_dataset offline buffer") offline_dataset = make_dataset(cfg) - if cfg.resume: + else: logging.info("load offline dataset") + dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline") offline_dataset = LeRobotDataset( - repo_id=cfg.dataset_repo_id, + repo_id=cfg.dataset.dataset_repo_id, local_files_only=True, - root=logger.log_dir / "dataset_offline", + root=dataset_offline_path, ) logging.info("Convert to a offline replay buffer") offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( offline_dataset, device=device, - state_keys=cfg.policy.input_shapes.keys(), + state_keys=cfg.policy.input_features.keys(), action_mask=active_action_dims, action_delta=cfg.env.wrapper.delta_action, storage_device=storage_device, optimize_memory=True, - capacity=cfg.training.offline_buffer_capacity, + capacity=cfg.policy.offline_buffer_capacity, ) return offline_replay_buffer @@ -215,16 +297,23 @@ def get_observation_features( return observation_features, next_observation_features -def use_threads(cfg: DictConfig) -> bool: - return cfg.actor_learner_config.concurrency.learner == "threads" +def use_threads(cfg: TrainPipelineConfig) -> bool: + return cfg.policy.concurrency["learner"] == "threads" def start_learner_threads( - cfg: DictConfig, - logger: Logger, - out_dir: str, + cfg: TrainPipelineConfig, + wandb_logger: WandBLogger | None, shutdown_event: any, # Event, ) -> None: + """ + Start the learner threads for training. + + Args: + cfg (TrainPipelineConfig): Training configuration + wandb_logger (WandBLogger | None): Logger for metrics + shutdown_event: Event to signal shutdown + """ # Create multiprocessing queues transition_queue = Queue() interaction_message_queue = Queue() @@ -255,13 +344,12 @@ def start_learner_threads( communication_process.start() add_actor_information_and_train( - cfg, - logger, - out_dir, - shutdown_event, - transition_queue, - interaction_message_queue, - parameters_queue, + cfg=cfg, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + transition_queue=transition_queue, + interaction_message_queue=interaction_message_queue, + parameters_queue=parameters_queue, ) logging.info("[LEARNER] Training process stopped") @@ -286,7 +374,7 @@ def start_learner_server( transition_queue: Queue, interaction_message_queue: Queue, shutdown_event: any, # Event, - cfg: DictConfig, + cfg: TrainPipelineConfig, ): if not use_threads(cfg): # We need init logging for MP separataly @@ -298,11 +386,11 @@ def start_learner_server( setup_process_handlers(False) service = learner_service.LearnerService( - shutdown_event, - parameters_queue, - cfg.actor_learner_config.policy_parameters_push_frequency, - transition_queue, - interaction_message_queue, + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"], + transition_queue=transition_queue, + interaction_message_queue=interaction_message_queue, ) server = grpc.server( @@ -318,8 +406,8 @@ def start_learner_server( server, ) - host = cfg.actor_learner_config.learner_host - port = cfg.actor_learner_config.learner_port + host = cfg.policy.actor_learner_config["learner_host"] + port = cfg.policy.actor_learner_config["learner_port"] server.add_insecure_port(f"{host}:{port}") server.start() @@ -385,9 +473,8 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): def add_actor_information_and_train( - cfg, - logger: Logger, - out_dir: str, + cfg: TrainPipelineConfig, + wandb_logger: WandBLogger | None, shutdown_event: any, # Event, transition_queue: Queue, interaction_message_queue: Queue, @@ -405,69 +492,60 @@ def add_actor_information_and_train( - Periodically updates the actor, critic, and temperature optimizers. - Logs training statistics, including loss values and optimization frequency. - **NOTE:** - - This function performs multiple responsibilities (data transfer, training, and logging). - It should ideally be split into smaller functions in the future. - - Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks - significantly reduces performance. Instead, this function executes all operations in a single thread. - Args: - cfg: Configuration object containing hyperparameters. - device (str): The computing device (`"cpu"` or `"cuda"`). - logger (Logger): Logger instance for tracking training progress. - out_dir (str): The output directory for storing training checkpoints and logs. + cfg (TrainPipelineConfig): Configuration object containing hyperparameters. + wandb_logger (WandBLogger | None): Logger for tracking training progress. shutdown_event (Event): Event to signal shutdown. transition_queue (Queue): Queue for receiving transitions from the actor. interaction_message_queue (Queue): Queue for receiving interaction messages from the actor. parameters_queue (Queue): Queue for sending policy parameters to the actor. """ - device = get_safe_torch_device(cfg.device, log=True) - storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device) + device = get_safe_torch_device(try_device=cfg.policy.device, log=True) + storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device) logging.info("Initializing policy") - ### Instantiate the policy in both the actor and learner processes - ### To avoid sending a SACPolicy object through the port, we create a policy intance - ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters - # TODO: At some point we should just need make sac policy - + # Get checkpoint dir for resuming + checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None + pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None + + # TODO(Adil): This don't work anymore ! policy: SACPolicy = make_policy( - hydra_cfg=cfg, - # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, - # Hack: But if we do online traning, we do not need dataset_stats - dataset_stats=None, - pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, + cfg=cfg.policy, + # ds_meta=cfg.dataset, + env_cfg=cfg.env ) # Update the policy config with the grad_clip_norm value from training config if it exists - clip_grad_norm_value = cfg.training.grad_clip_norm + clip_grad_norm_value:float = cfg.policy.grad_clip_norm # compile policy policy = torch.compile(policy) assert isinstance(policy, nn.Module) + policy.train() - push_actor_policy_to_queue(parameters_queue, policy) + push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) last_time_policy_pushed = time.time() - optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy) - resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers) + optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy) + resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) - log_training_info(cfg, out_dir, policy) + log_training_info(cfg=cfg, policy= policy) - replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device) - batch_size = cfg.training.batch_size + replay_buffer = initialize_replay_buffer(cfg, device, storage_device) + batch_size = cfg.batch_size offline_replay_buffer = None - if cfg.dataset_repo_id is not None: + if cfg.dataset is not None: active_action_dims = None + # TODO: FIX THIS if cfg.env.wrapper.joint_masking_action_space is not None: active_action_dims = [ i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask ] offline_replay_buffer = initialize_offline_replay_buffer( cfg=cfg, - logger=logger, device=device, storage_device=storage_device, active_action_dims=active_action_dims, @@ -484,18 +562,22 @@ def add_actor_information_and_train( interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 # Extract variables from cfg - online_step_before_learning = cfg.training.online_step_before_learning + online_step_before_learning = cfg.policy.online_step_before_learning utd_ratio = cfg.policy.utd_ratio - dataset_repo_id = cfg.dataset_repo_id - fps = cfg.fps - log_freq = cfg.training.log_freq - save_freq = cfg.training.save_freq - device = cfg.device - storage_device = cfg.training.storage_device - policy_update_freq = cfg.training.policy_update_freq - policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency - save_checkpoint = cfg.training.save_checkpoint - online_steps = cfg.training.online_steps + + dataset_repo_id = None + if cfg.dataset is not None: + dataset_repo_id = cfg.dataset.repo_id + + fps = cfg.env.fps + log_freq = cfg.log_freq + save_freq = cfg.save_freq + device = cfg.policy.device + storage_device = cfg.policy.storage_device + policy_update_freq = cfg.policy.policy_update_freq + policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"] + save_checkpoint = cfg.save_checkpoint + online_steps = cfg.policy.online_steps while True: if shutdown_event is not None and shutdown_event.is_set(): @@ -516,7 +598,7 @@ def add_actor_information_and_train( continue replay_buffer.add(**transition) - if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get( + if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get( "is_intervention" ): offline_replay_buffer.add(**transition) @@ -528,7 +610,17 @@ def add_actor_information_and_train( interaction_message = bytes_to_python_object(interaction_message) # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging interaction_message["Interaction step"] += interaction_step_shift - logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step") + + # Log interaction messages with WandB if available + if wandb_logger: + wandb_logger.log_dict( + d=interaction_message, + mode="train", + custom_step_key="Interaction step" + ) + else: + # Log to console if no WandB logger + logging.info(f"Interaction: {interaction_message}") logging.debug("[LEARNER] Received interactions") @@ -538,11 +630,11 @@ def add_actor_information_and_train( logging.debug("[LEARNER] Starting optimization loop") time_for_one_optimization_step = time.time() for _ in range(utd_ratio - 1): - batch = replay_buffer.sample(batch_size) + batch = replay_buffer.sample(batch_size=batch_size) if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size) - batch = concatenate_batch_transitions(batch, batch_offline) + batch_offline = offline_replay_buffer.sample(batch_size=batch_size) + batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline) actions = batch["action"] rewards = batch["reward"] @@ -552,7 +644,7 @@ def add_actor_information_and_train( check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) observation_features, next_observation_features = get_observation_features( - policy, observations, next_observations + policy=policy, observations=observations, next_observations=next_observations ) loss_critic = policy.compute_loss_critic( observations=observations, @@ -568,15 +660,15 @@ def add_actor_information_and_train( # clip gradients critic_grad_norm = torch.nn.utils.clip_grad_norm_( - policy.critic_ensemble.parameters(), clip_grad_norm_value + parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value ) optimizers["critic"].step() - batch = replay_buffer.sample(batch_size) + batch = replay_buffer.sample(batch_size=batch_size) if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size) + batch_offline = offline_replay_buffer.sample(batch_size=batch_size) batch = concatenate_batch_transitions( left_batch_transitions=batch, right_batch_transition=batch_offline ) @@ -590,7 +682,7 @@ def add_actor_information_and_train( check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) observation_features, next_observation_features = get_observation_features( - policy, observations, next_observations + policy=policy, observations=observations, next_observations=next_observations ) loss_critic = policy.compute_loss_critic( observations=observations, @@ -606,7 +698,7 @@ def add_actor_information_and_train( # clip gradients critic_grad_norm = torch.nn.utils.clip_grad_norm_( - policy.critic_ensemble.parameters(), clip_grad_norm_value + parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value ).item() optimizers["critic"].step() @@ -627,7 +719,7 @@ def add_actor_information_and_train( # clip gradients actor_grad_norm = torch.nn.utils.clip_grad_norm_( - policy.actor.parameters_to_optimize, clip_grad_norm_value + parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value ).item() optimizers["actor"].step() @@ -645,7 +737,7 @@ def add_actor_information_and_train( # clip gradients temp_grad_norm = torch.nn.utils.clip_grad_norm_( - [policy.log_alpha], clip_grad_norm_value + parameters=[policy.log_alpha], max_norm=clip_grad_norm_value ).item() optimizers["temperature"].step() @@ -655,7 +747,7 @@ def add_actor_information_and_train( training_infos["temperature"] = policy.temperature if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: - push_actor_policy_to_queue(parameters_queue, policy) + push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) last_time_policy_pushed = time.time() policy.update_target_networks() @@ -665,22 +757,33 @@ def add_actor_information_and_train( if offline_replay_buffer is not None: training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer) training_infos["Optimization step"] = optimization_step - logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step") - # logging.info(f"Training infos: {training_infos}") + + # Log training metrics + if wandb_logger: + wandb_logger.log_dict( + d=training_infos, + mode="train", + custom_step_key="Optimization step" + ) + else: + # Log to console if no WandB logger + logging.info(f"Training: {training_infos}") time_for_one_optimization_step = time.time() - time_for_one_optimization_step frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}") - logger.log_dict( - { - "Optimization frequency loop [Hz]": frequency_for_one_optimization_step, - "Optimization step": optimization_step, - }, - mode="train", - custom_step_key="Optimization step", - ) + # Log optimization frequency + if wandb_logger: + wandb_logger.log_dict( + { + "Optimization frequency loop [Hz]": frequency_for_one_optimization_step, + "Optimization step": optimization_step, + }, + mode="train", + custom_step_key="Optimization step", + ) optimization_step += 1 if optimization_step % log_freq == 0: @@ -693,35 +796,45 @@ def add_actor_information_and_train( interaction_step = ( interaction_message["Interaction step"] if interaction_message is not None else 0 ) - logger.save_checkpoint( - optimization_step, - policy, + + # Create checkpoint directory + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step) + + # Save checkpoint + save_checkpoint( + checkpoint_dir, + optimization_step, + cfg, + policy, optimizers, - scheduler=None, - identifier=step_identifier, - interaction_step=interaction_step, + scheduler=None ) + + # Update the "last" symlink + update_last_checkpoint(checkpoint_dir) # TODO : temporarly save replay buffer here, remove later when on the robot # We want to control this with the keyboard inputs - dataset_dir = logger.log_dir / "dataset" - if dataset_dir.exists() and dataset_dir.is_dir(): - shutil.rmtree( - dataset_dir, - ) - replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset") + dataset_dir = os.path.join(cfg.output_dir, "dataset") + if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir): + shutil.rmtree(dataset_dir) + + # Save dataset + replay_buffer.to_lerobot_dataset( + dataset_repo_id, + fps=fps, + root=dataset_dir + ) + if offline_replay_buffer is not None: - dataset_dir = logger.log_dir / "dataset_offline" - - if dataset_dir.exists() and dataset_dir.is_dir(): - shutil.rmtree( - dataset_dir, - ) + dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline") + if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir): + shutil.rmtree(dataset_offline_dir) offline_replay_buffer.to_lerobot_dataset( - cfg.dataset_repo_id, - fps=cfg.fps, - root=logger.log_dir / "dataset_offline", + cfg.dataset.dataset_repo_id, + fps=cfg.env.fps, + root=dataset_offline_dir, ) logging.info("Resume training") @@ -756,12 +869,12 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): optimizer_actor = torch.optim.Adam( # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor params=policy.actor.parameters_to_optimize, - lr=policy.config.actor_lr, + lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam( - params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr + params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr ) - optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr) + optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None optimizers = { "actor": optimizer_actor, @@ -771,19 +884,38 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): return optimizers, lr_scheduler -def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): - if out_dir is None: - raise NotImplementedError() +def train(cfg: TrainPipelineConfig, job_name: str | None = None): + """ + Main training function that initializes and runs the training process. + + Args: + cfg (TrainPipelineConfig): The training configuration + job_name (str | None, optional): Job name for logging. Defaults to None. + """ + if cfg.output_dir is None: + raise ValueError("Output directory must be specified in config") + if job_name is None: - raise NotImplementedError() + job_name = cfg.job_name + + if job_name is None: + raise ValueError("Job name must be specified either in config or as a parameter") init_logging() - logging.info(pformat(OmegaConf.to_container(cfg))) + logging.info(pformat(cfg.to_dict())) - logger = Logger(cfg, out_dir, wandb_job_name=job_name) - cfg = handle_resume_logic(cfg, out_dir) + # Setup WandB logging if enabled + if cfg.wandb.enable and cfg.wandb.project: + from lerobot.common.utils.wandb_utils import WandBLogger + wandb_logger = WandBLogger(cfg) + else: + wandb_logger = None + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + + # Handle resume logic + cfg = handle_resume_logic(cfg) - set_global_seed(cfg.seed) + set_seed(seed=cfg.seed) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -791,24 +923,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No shutdown_event = setup_process_handlers(use_threads(cfg)) start_learner_threads( - cfg, - logger, - out_dir, - shutdown_event, + cfg=cfg, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, ) -@hydra.main(version_base="1.2", config_name="default", config_path="../../configs") -def train_cli(cfg: dict): +@parser.wrap() +def train_cli(cfg: TrainPipelineConfig): + if not use_threads(cfg): import torch.multiprocessing as mp - mp.set_start_method("spawn") + # Use the job_name from the config train( cfg, - out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir, - job_name=hydra.core.hydra_config.HydraConfig.get().job.name, + job_name=cfg.job_name, ) logging.info("[LEARNER] train_cli finished") @@ -816,5 +947,4 @@ def train_cli(cfg: dict): if __name__ == "__main__": train_cli() - logging.info("[LEARNER] main finished") diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index 9db7aa40..2ad7c661 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -1,4 +1,7 @@ -from typing import Any +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Tuple import einops import gymnasium as gym @@ -6,7 +9,11 @@ import numpy as np import torch from mani_skill.utils.wrappers.record import RecordEpisode from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv -from omegaconf import DictConfig + +from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig +from lerobot.configs import parser +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT def preprocess_maniskill_observation( @@ -46,9 +53,14 @@ def preprocess_maniskill_observation( return return_observations + + + class ManiSkillObservationWrapper(gym.ObservationWrapper): def __init__(self, env, device: torch.device = "cuda"): super().__init__(env) + if isinstance(device, str): + device = torch.device(device) self.device = device def observation(self, observation): @@ -108,76 +120,129 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper): return obs, reward, terminated, truncated, info +class BatchCompatibleWrapper(gym.ObservationWrapper): + """Ensures observations are batch-compatible by adding a batch dimension if necessary.""" + def __init__(self, env): + super().__init__(env) + + def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + for key in observation: + if "image" in key and observation[key].dim() == 3: + observation[key] = observation[key].unsqueeze(0) + if "state" in key and observation[key].dim() == 1: + observation[key] = observation[key].unsqueeze(0) + return observation + + +class TimeLimitWrapper(gym.Wrapper): + """Adds a time limit to the environment based on fps and control_time.""" + def __init__(self, env, control_time_s, fps): + super().__init__(env) + self.control_time_s = control_time_s + self.fps = fps + self.max_episode_steps = int(self.control_time_s * self.fps) + self.current_step = 0 + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + self.current_step += 1 + + if self.current_step >= self.max_episode_steps: + terminated = True + + return obs, reward, terminated, truncated, info + + def reset(self, *, seed=None, options=None): + self.current_step = 0 + return super().reset(seed=seed, options=options) + + def make_maniskill( - cfg: DictConfig, + cfg: ManiskillEnvConfig, n_envs: int | None = None, ) -> gym.Env: """ Factory function to create a ManiSkill environment with standard wrappers. Args: - task: Name of the ManiSkill task - obs_mode: Observation mode (rgb, rgbd, etc) - control_mode: Control mode for the robot - render_mode: Rendering mode - sensor_configs: Camera sensor configurations + cfg: Configuration for the ManiSkill environment n_envs: Number of parallel environments Returns: A wrapped ManiSkill environment """ - env = gym.make( - cfg.env.task, - obs_mode=cfg.env.obs, - control_mode=cfg.env.control_mode, - render_mode=cfg.env.render_mode, - sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size}, + cfg.task, + obs_mode=cfg.obs_type, + control_mode=cfg.control_mode, + render_mode=cfg.render_mode, + sensor_configs={"width": cfg.image_size, "height": cfg.image_size}, num_envs=n_envs, ) - if cfg.env.video_record.enabled: + # Add video recording if enabled + if cfg.video_record.enabled: env = RecordEpisode( env, - output_dir=cfg.env.video_record.record_dir, + output_dir=cfg.video_record.record_dir, save_trajectory=True, - trajectory_name=cfg.env.video_record.trajectory_name, + trajectory_name=cfg.video_record.trajectory_name, save_video=True, video_fps=30, ) - env = ManiSkillObservationWrapper(env, device=cfg.env.device) + + # Add observation and image processing + env = ManiSkillObservationWrapper(env, device=cfg.device) env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False) - env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env) - env.unwrapped.metadata["render_fps"] = 20 + env._max_episode_steps = env.max_episode_steps = cfg.episode_length + env.unwrapped.metadata["render_fps"] = cfg.fps + + # Add compatibility wrappers env = ManiSkillCompat(env) env = ManiSkillActionWrapper(env) - env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) - + env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control + return env -if __name__ == "__main__": - import argparse - - import hydra - - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml") - args = parser.parse_args() - - # Initialize config - with hydra.initialize(version_base=None, config_path="../../configs"): - cfg = hydra.compose(config_name="env/maniskill_example.yaml") - - env = make_maniskill( - task=cfg.env.task, - obs_mode=cfg.env.obs, - control_mode=cfg.env.control_mode, - render_mode=cfg.env.render_mode, - sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size}, - ) - - print("env done") +@parser.wrap() +def main(cfg: ManiskillEnvConfig): + """Main function to run the ManiSkill environment.""" + # Create the ManiSkill environment + env = make_maniskill(cfg, n_envs=1) + + # Reset the environment obs, info = env.reset() - random_action = env.action_space.sample() - obs, reward, terminated, truncated, info = env.step(random_action) + + # Run a simple interaction loop + sum_reward = 0 + for i in range(100): + # Sample a random action + action = env.action_space.sample() + + # Step the environment + start_time = time.perf_counter() + obs, reward, terminated, truncated, info = env.step(action) + step_time = time.perf_counter() - start_time + sum_reward += reward + # Log information + + # Reset if episode terminated + if terminated or truncated: + logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s") + sum_reward = 0 + obs, info = env.reset() + + # Close the environment + env.close() + + +# if __name__ == "__main__": +# logging.basicConfig(level=logging.INFO) +# main() + +if __name__ == "__main__": + import draccus + config = ManiskillEnvConfig() + draccus.set_config_type("json") + draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), ) \ No newline at end of file