Add storage device configuration for SAC policy and replay buffer

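- Introduce a `storage_device` parameter (default `"cpu"`) in the SAC configuration and the training settings
- Update the learner server to use the configurable storage device for the replay buffers instead of reusing the compute device
- Reduce the online buffer capacity from 1,000,000 to 200,000 in the ManiSkill configuration
- Modify replay buffer initialization (`initialize_replay_buffer`) to accept a custom `storage_device` argument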
This commit is contained in:
AdilZouitine 2025-03-04 13:22:35 +00:00
parent 24f93c755a
commit 76df8a31b3
3 changed files with 11 additions and 5 deletions

View File

@ -64,6 +64,8 @@ class SACConfig:
}
)
camera_number: int = 1
storage_device: str = "cpu"
# Add type annotations for these fields:
vision_encoder_name: str | None = field(default="helper2424/resnet10")
freeze_vision_encoder: bool = True

View File

@ -20,6 +20,9 @@ training:
grad_clip_norm: 10.0
lr: 3e-4
storage_device: "cpu"
eval_freq: 2500
log_freq: 10
save_freq: 2000000
@ -30,7 +33,7 @@ training:
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 1000000
online_buffer_capacity: 200000
online_buffer_seed_size: 0
online_step_before_learning: 500
do_online_rollout_async: false

View File

@ -146,14 +146,14 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str
cfg: DictConfig, logger: Logger, device: str, storage_device:str
) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
storage_device=device,
storage_device=storage_device,
optimize_memory=True,
)
@ -596,6 +596,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@ -628,7 +629,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
log_training_info(cfg, out_dir, policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device)
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
batch_size = cfg.training.batch_size
offline_replay_buffer = None
@ -649,7 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=device,
storage_device=storage_device,
optimize_memory=True,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer