From 76df8a31b3d4936a5a39414354a3361c4a123678 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 4 Mar 2025 13:22:35 +0000 Subject: [PATCH] Add storage device configuration for SAC policy and replay buffer - Introduce `storage_device` parameter in SAC configuration and training settings - Update learner server to use configurable storage device for replay buffer - Reduce online buffer capacity in ManiSkill configuration - Modify replay buffer initialization to support custom storage device --- lerobot/common/policies/sac/configuration_sac.py | 2 ++ lerobot/configs/policy/sac_maniskill.yaml | 5 ++++- lerobot/scripts/server/learner_server.py | 9 +++++---- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index d225f11b..b834896e 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -64,6 +64,8 @@ class SACConfig: } ) camera_number: int = 1 + + storage_device: str = "cpu" # Add type annotations for these fields: vision_encoder_name: str | None = field(default="helper2424/resnet10") freeze_vision_encoder: bool = True diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml index c78df904..87fc4095 100644 --- a/lerobot/configs/policy/sac_maniskill.yaml +++ b/lerobot/configs/policy/sac_maniskill.yaml @@ -20,6 +20,9 @@ training: grad_clip_norm: 10.0 lr: 3e-4 + + storage_device: "cpu" + eval_freq: 2500 log_freq: 10 save_freq: 2000000 @@ -30,7 +33,7 @@ training: online_steps_between_rollouts: 1000 online_sampling_ratio: 1.0 online_env_seed: 10000 - online_buffer_capacity: 1000000 + online_buffer_capacity: 200000 online_buffer_seed_size: 0 online_step_before_learning: 500 do_online_rollout_async: false diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index edbeb01c..baba99e7 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -146,14 +146,14 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None: def initialize_replay_buffer( - cfg: DictConfig, logger: Logger, device: str + cfg: DictConfig, logger: Logger, device: str, storage_device:str ) -> ReplayBuffer: if not cfg.resume: return ReplayBuffer( capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys(), - storage_device=device, + storage_device=storage_device, optimize_memory=True, ) @@ -596,6 +596,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No set_global_seed(cfg.seed) device = get_safe_torch_device(cfg.device, log=True) + storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -628,7 +629,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No log_training_info(cfg, out_dir, policy) - replay_buffer = initialize_replay_buffer(cfg, logger, device) + replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device) batch_size = cfg.training.batch_size offline_replay_buffer = None @@ -649,7 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No state_keys=cfg.policy.input_shapes.keys(), action_mask=active_action_dims, action_delta=cfg.env.wrapper.delta_action, - storage_device=device, + storage_device=storage_device, optimize_memory=True, ) batch_size: int = batch_size // 2 # We will sample from both replay buffer