diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 313f003a..42be632d 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -163,6 +163,12 @@ class VideoRecordConfig:
     record_dir: str = "videos"
     trajectory_name: str = "trajectory"
 
+@dataclass
+class WrapperConfig:
+    """Configuration for environment wrappers."""
+    delta_action: float | None = None
+    joint_masking_action_space: list[bool] | None = None
+
 @EnvConfig.register_subclass("maniskill_push")
 @dataclass
 class ManiskillEnvConfig(EnvConfig):
@@ -181,6 +187,7 @@ class ManiskillEnvConfig(EnvConfig):
     device: str = "cuda"
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
+    wrapper: WrapperConfig = field(default_factory=WrapperConfig)
     features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
             "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index d252ddc0..fa3b0187 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -233,11 +233,11 @@ class SACConfig(PreTrainedConfig):
 
     @property
     def observation_delta_indices(self) -> list:
-        return list(range(1 - self.n_obs_steps, 1))
+        return None
 
     @property
     def action_delta_indices(self) -> list:
-        return [0]  # SAC typically predicts one action at a time
+        return None  # SAC typically predicts one action at a time
 
     @property
     def reward_delta_indices(self) -> None:
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 856ea843..1f05d9a7 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -1100,15 +1100,15 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    # if "maniskill" in cfg.name:
-    #     from lerobot.scripts.server.maniskill_manipulator import make_maniskill
+    if "maniskill" in cfg.name:
+        from lerobot.scripts.server.maniskill_manipulator import make_maniskill
 
-    #     logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
-    #     env = make_maniskill(
-    #         cfg=cfg,
-    #         n_envs=1,
-    #     )
-    #     return env
+        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
+        env = make_maniskill(
+            cfg=cfg,
+            n_envs=1,
+        )
+        return env
     robot = make_robot_from_config(cfg.robot)
     # Create base environment
     env = HILSerlRobotEnv(
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 4deb1972..da3e6606 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -222,7 +222,7 @@ def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_devi
     # NOTE: In RL is possible to not have a dataset.
     repo_id = None
     if cfg.dataset is not None:
-        repo_id = cfg.dataset.dataset_repo_id
+        repo_id = cfg.dataset.repo_id
     dataset = LeRobotDataset(
         repo_id=repo_id,
         root=dataset_path,
@@ -261,8 +261,7 @@ def initialize_offline_replay_buffer(
     logging.info("load offline dataset")
     dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
     offline_dataset = LeRobotDataset(
-        repo_id=cfg.dataset.dataset_repo_id,
-        local_files_only=True,
+        repo_id=cfg.dataset.repo_id,
         root=dataset_offline_path,
     )
 
@@ -779,53 +778,106 @@ def add_actor_information_and_train(
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
 
         if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
-            logging.info(f"Checkpoint policy after step {optimization_step}")
-            _num_digits = max(6, len(str(online_steps)))
-            step_identifier = f"{optimization_step:0{_num_digits}d}"
-            interaction_step = (
-                interaction_message["Interaction step"] if interaction_message is not None else 0
+            save_training_checkpoint(
+                cfg=cfg,
+                optimization_step=optimization_step,
+                online_steps=online_steps,
+                interaction_message=interaction_message,
+                policy=policy,
+                optimizers=optimizers,
+                replay_buffer=replay_buffer,
+                offline_replay_buffer=offline_replay_buffer,
+                dataset_repo_id=dataset_repo_id,
+                fps=fps,
             )
 
-            # Create checkpoint directory
-            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
+def save_training_checkpoint(
+    cfg: TrainPipelineConfig,
+    optimization_step: int,
+    online_steps: int,
+    interaction_message: dict | None,
+    policy: nn.Module,
+    optimizers: dict[str, Optimizer],
+    replay_buffer: ReplayBuffer,
+    offline_replay_buffer: ReplayBuffer | None = None,
+    dataset_repo_id: str | None = None,
+    fps: int = 30,
+) -> None:
+    """
+    Save training checkpoint and associated data.
+
+    Args:
+        cfg: Training configuration
+        optimization_step: Current optimization step
+        online_steps: Total number of online steps
+        interaction_message: Dictionary containing interaction information
+        policy: Policy model to save
+        optimizers: Dictionary of optimizers
+        replay_buffer: Replay buffer to save as dataset
+        offline_replay_buffer: Optional offline replay buffer to save
+        dataset_repo_id: Repository ID for dataset
+        fps: Frames per second for dataset
+    """
+    logging.info(f"Checkpoint policy after step {optimization_step}")
+    _num_digits = max(6, len(str(online_steps)))
+    step_identifier = f"{optimization_step:0{_num_digits}d}"
+    interaction_step = (
+        interaction_message["Interaction step"] if interaction_message is not None else 0
+    )
+
+    # Create checkpoint directory
+    checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
+
+    # Save checkpoint
+    save_checkpoint(
+        checkpoint_dir=checkpoint_dir,
+        step=optimization_step,
+        cfg=cfg,
+        policy=policy,
+        optimizer=optimizers,
+        scheduler=None
+    )
+
+    # Save interaction step manually
+    training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
+    os.makedirs(training_state_dir, exist_ok=True)
+    training_state = {
+        "step": optimization_step,
+        "interaction_step": interaction_step
+    }
+    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
+
+    # Update the "last" symlink
+    update_last_checkpoint(checkpoint_dir)
 
-            # Save checkpoint
-            save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)
+    # TODO : temporarly save replay buffer here, remove later when on the robot
+    # We want to control this with the keyboard inputs
+    dataset_dir = os.path.join(cfg.output_dir, "dataset")
+    if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
+        shutil.rmtree(dataset_dir)
+
+    # Save dataset
+    # NOTE: Handle the case where the dataset repo id is not specified in the config
+    # eg. RL training without demonstrations data
+    repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
+    replay_buffer.to_lerobot_dataset(
+        repo_id=repo_id_buffer_save,
+        fps=fps,
+        root=dataset_dir
+    )
+
+    if offline_replay_buffer is not None:
+        dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
+        if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
+            shutil.rmtree(dataset_offline_dir)
 
-            # Save interaction step manually
-            training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
-            os.makedirs(training_state_dir, exist_ok=True)
-            training_state = {"step": optimization_step, "interaction_step": interaction_step}
-            torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
-
-            # Update the "last" symlink
-            update_last_checkpoint(checkpoint_dir)
-
-            # TODO : temporarly save replay buffer here, remove later when on the robot
-            # We want to control this with the keyboard inputs
-            dataset_dir = os.path.join(cfg.output_dir, "dataset")
-            if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
-                shutil.rmtree(dataset_dir)
-
-            # Save dataset
-            # NOTE: Handle the case where the dataset repo id is not specified in the config
-            # eg. RL training without demonstrations data
-            repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
-            replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)
-
-            if offline_replay_buffer is not None:
-                dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
-                if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
-                    shutil.rmtree(dataset_offline_dir)
-
-                offline_replay_buffer.to_lerobot_dataset(
-                    cfg.dataset.dataset_repo_id,
-                    fps=cfg.env.fps,
-                    root=dataset_offline_dir,
-                )
-
-            logging.info("Resume training")
+        offline_replay_buffer.to_lerobot_dataset(
+            cfg.dataset.repo_id,
+            fps=fps,
+            root=dataset_offline_dir,
+        )
+    logging.info("Resume training")
 
 
 def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     """
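
Not part of the diff above: a minimal sketch of how the new wrapper field on ManiskillEnvConfig could be populated, assuming the remaining ManiskillEnvConfig fields keep the defaults shown in the first hunks; the delta_action value and the 7-entry joint mask are made-up example values.

    from lerobot.common.envs.configs import ManiskillEnvConfig, WrapperConfig

    # Hypothetical settings: clamp per-step action deltas and expose only the first six joints.
    env_cfg = ManiskillEnvConfig(
        wrapper=WrapperConfig(
            delta_action=0.05,
            joint_masking_action_space=[True, True, True, True, True, True, False],
        )
    )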
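
Also not part of the diff: the SACConfig hunk switches observation_delta_indices and action_delta_indices from index lists to None. As a generic illustration (an assumed helper, not LeRobot's actual loader code), such index lists are typically converted into per-frame time offsets by dividing by the dataset fps, so returning None simply requests no temporal window around each frame.

    # Illustrative only; the function name and behaviour are assumptions, not repository code.
    def indices_to_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float] | None:
        if delta_indices is None:
            return None  # no frame stacking requested
        return [idx / fps for idx in delta_indices]

    print(indices_to_delta_timestamps([-1, 0], fps=30))  # roughly [-0.033, 0.0]
    print(indices_to_delta_timestamps(None, fps=30))     # None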