Add storage device configuration for SAC policy and replay buffer
- Introduce `storage_device` parameter in SAC configuration and training settings
- Update learner server to use configurable storage device for replay buffer
- Reduce online buffer capacity in ManiSkill configuration
- Modify replay buffer initialization to support custom storage device
commit 76df8a31b3
parent 24f93c755a
@@ -64,6 +64,8 @@ class SACConfig:
         }
     )
     camera_number: int = 1
+
+    storage_device: str = "cpu"
     # Add type annotations for these fields:
     vision_encoder_name: str | None = field(default="helper2424/resnet10")
     freeze_vision_encoder: bool = True
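
Because SACConfig is a dataclass, the new field can be overridden at construction time like any other attribute. A minimal sketch using only the fields visible in this hunk (the real class has many more; the stand-in name is hypothetical):

from dataclasses import dataclass, field


@dataclass
class SACConfigSketch:
    # Toy stand-in for the real SACConfig, limited to the fields shown above.
    camera_number: int = 1
    storage_device: str = "cpu"
    vision_encoder_name: str | None = field(default="helper2424/resnet10")
    freeze_vision_encoder: bool = True


cfg = SACConfigSketch(storage_device="cuda")  # e.g. keep buffer storage in GPU memory
print(cfg.storage_device)  # cuda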
@@ -20,6 +20,9 @@ training:
   grad_clip_norm: 10.0
   lr: 3e-4

+
+  storage_device: "cpu"
+
   eval_freq: 2500
   log_freq: 10
   save_freq: 2000000
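
These training settings are consumed as an OmegaConf DictConfig (the type appears in the Python hunks below), so the new key can also be set or overridden programmatically. A sketch, with the YAML inlined to keep it self-contained:

from omegaconf import OmegaConf

# Sketch: reading and overriding the new key; the dotlist override mirrors
# what a Hydra-style command-line override would do.
cfg = OmegaConf.create(
    """
    training:
      storage_device: "cpu"
      online_buffer_capacity: 200000
    """
)
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["training.storage_device=cuda"]))
print(cfg.training.storage_device)  # cuda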
@@ -30,7 +33,7 @@ training:
   online_steps_between_rollouts: 1000
   online_sampling_ratio: 1.0
   online_env_seed: 10000
-  online_buffer_capacity: 1000000
+  online_buffer_capacity: 200000
   online_buffer_seed_size: 0
   online_step_before_learning: 500
   do_online_rollout_async: false
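
The 5x capacity cut matters mostly for memory: buffer footprint scales linearly with capacity. A rough estimate, assuming one 128×128×3 uint8 camera frame plus a small float32 state vector per transition (the actual observation shapes are not part of this diff):

# Back-of-envelope memory estimate; shapes are assumptions, not from the diff.
image_bytes = 128 * 128 * 3   # one uint8 camera frame
state_bytes = 25 * 4          # 25-dim float32 state
per_transition = image_bytes + state_bytes

for capacity in (1_000_000, 200_000):
    print(f"{capacity:>9,} transitions ~ {capacity * per_transition / 1024**3:.1f} GiB")
# 1,000,000 transitions ~ 45.9 GiB
#   200,000 transitions ~ 9.2 GiB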
@@ -146,14 +146,14 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:


 def initialize_replay_buffer(
-    cfg: DictConfig, logger: Logger, device: str
+    cfg: DictConfig, logger: Logger, device: str, storage_device: str
 ) -> ReplayBuffer:
     if not cfg.resume:
         return ReplayBuffer(
             capacity=cfg.training.online_buffer_capacity,
             device=device,
             state_keys=cfg.policy.input_shapes.keys(),
-            storage_device=device,
+            storage_device=storage_device,
             optimize_memory=True,
         )
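
The repo's ReplayBuffer itself is not shown in this diff. As an illustration of the pattern the storage_device argument enables (keep the large transition tensors in cheap CPU RAM, move only each sampled batch to the training device), a minimal sketch that is not the repo's implementation:

import torch


class TinyReplayBuffer:
    # Minimal sketch, not lerobot's ReplayBuffer: transitions live on
    # `storage_device`; only sampled batches are moved to `device`.
    def __init__(self, capacity: int, device: str, storage_device: str = "cpu"):
        self.capacity = capacity
        self.device = torch.device(device)
        self.storage_device = torch.device(storage_device)
        self.data: list[dict[str, torch.Tensor]] = []

    def add(self, transition: dict[str, torch.Tensor]) -> None:
        if len(self.data) >= self.capacity:
            self.data.pop(0)  # drop the oldest transition once full (FIFO)
        self.data.append({k: v.to(self.storage_device) for k, v in transition.items()})

    def sample(self, batch_size: int) -> dict[str, torch.Tensor]:
        idx = torch.randint(len(self.data), (batch_size,)).tolist()
        batch = {k: torch.stack([self.data[i][k] for i in idx]) for k in self.data[0]}
        return {k: v.to(self.device) for k, v in batch.items()}  # transfer per batch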
@@ -596,6 +596,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     set_global_seed(cfg.seed)

     device = get_safe_torch_device(cfg.device, log=True)
+    storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
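
get_safe_torch_device is lerobot's helper for turning a config string into a validated torch.device; its exact behavior is not shown in this diff. A hypothetical stand-in capturing the idea:

import logging

import torch


def resolve_device(cfg_device: str, log: bool = False) -> torch.device:
    # Hypothetical stand-in, not lerobot's implementation: validate the
    # requested device and fall back to CPU when CUDA is unavailable.
    if cfg_device == "cuda" and not torch.cuda.is_available():
        if log:
            logging.warning("CUDA requested but unavailable; using CPU instead.")
        return torch.device("cpu")
    return torch.device(cfg_device)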
@@ -628,7 +629,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):

     log_training_info(cfg, out_dir, policy)

-    replay_buffer = initialize_replay_buffer(cfg, logger, device)
+    replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
     batch_size = cfg.training.batch_size
     offline_replay_buffer = None
@@ -649,7 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
             state_keys=cfg.policy.input_shapes.keys(),
             action_mask=active_action_dims,
             action_delta=cfg.env.wrapper.delta_action,
-            storage_device=device,
+            storage_device=storage_device,
             optimize_memory=True,
         )
         batch_size: int = batch_size // 2  # We will sample from both replay buffers
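
The halved batch_size implies that, when an offline buffer exists, each optimization step draws half its samples from each buffer. A sketch of that split (the merge logic is an assumption, not the repo's code):

import torch


def sample_mixed_batch(online_buffer, offline_buffer, batch_size: int) -> dict:
    # Assumed 50/50 split mirroring `batch_size // 2` above; both buffers are
    # expected to return dicts of identically keyed, equally shaped tensors.
    half = batch_size // 2
    online, offline = online_buffer.sample(half), offline_buffer.sample(half)
    return {k: torch.cat([online[k], offline[k]], dim=0) for k in online}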