Add storage device configuration for SAC policy and replay buffer
- Introduce `storage_device` parameter in SAC configuration and training settings - Update learner server to use configurable storage device for replay buffer - Reduce online buffer capacity in ManiSkill configuration - Modify replay buffer initialization to support custom storage device
This commit is contained in:
parent
1df9ee4f2d
commit
d8a1758122
|
@ -64,6 +64,8 @@ class SACConfig:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
camera_number: int = 1
|
camera_number: int = 1
|
||||||
|
|
||||||
|
storage_device: str = "cpu"
|
||||||
# Add type annotations for these fields:
|
# Add type annotations for these fields:
|
||||||
vision_encoder_name: str | None = field(default="helper2424/resnet10")
|
vision_encoder_name: str | None = field(default="helper2424/resnet10")
|
||||||
freeze_vision_encoder: bool = True
|
freeze_vision_encoder: bool = True
|
||||||
|
|
|
@ -20,6 +20,9 @@ training:
|
||||||
grad_clip_norm: 10.0
|
grad_clip_norm: 10.0
|
||||||
lr: 3e-4
|
lr: 3e-4
|
||||||
|
|
||||||
|
|
||||||
|
storage_device: "cpu"
|
||||||
|
|
||||||
eval_freq: 2500
|
eval_freq: 2500
|
||||||
log_freq: 10
|
log_freq: 10
|
||||||
save_freq: 2000000
|
save_freq: 2000000
|
||||||
|
@ -30,7 +33,7 @@ training:
|
||||||
online_steps_between_rollouts: 1000
|
online_steps_between_rollouts: 1000
|
||||||
online_sampling_ratio: 1.0
|
online_sampling_ratio: 1.0
|
||||||
online_env_seed: 10000
|
online_env_seed: 10000
|
||||||
online_buffer_capacity: 1000000
|
online_buffer_capacity: 200000
|
||||||
online_buffer_seed_size: 0
|
online_buffer_seed_size: 0
|
||||||
online_step_before_learning: 500
|
online_step_before_learning: 500
|
||||||
do_online_rollout_async: false
|
do_online_rollout_async: false
|
||||||
|
|
|
@ -146,14 +146,14 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
|
||||||
|
|
||||||
|
|
||||||
def initialize_replay_buffer(
|
def initialize_replay_buffer(
|
||||||
cfg: DictConfig, logger: Logger, device: str
|
cfg: DictConfig, logger: Logger, device: str, storage_device:str
|
||||||
) -> ReplayBuffer:
|
) -> ReplayBuffer:
|
||||||
if not cfg.resume:
|
if not cfg.resume:
|
||||||
return ReplayBuffer(
|
return ReplayBuffer(
|
||||||
capacity=cfg.training.online_buffer_capacity,
|
capacity=cfg.training.online_buffer_capacity,
|
||||||
device=device,
|
device=device,
|
||||||
state_keys=cfg.policy.input_shapes.keys(),
|
state_keys=cfg.policy.input_shapes.keys(),
|
||||||
storage_device=device,
|
storage_device=storage_device,
|
||||||
optimize_memory=True,
|
optimize_memory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -596,6 +596,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
set_global_seed(cfg.seed)
|
set_global_seed(cfg.seed)
|
||||||
|
|
||||||
device = get_safe_torch_device(cfg.device, log=True)
|
device = get_safe_torch_device(cfg.device, log=True)
|
||||||
|
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -628,7 +629,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
log_training_info(cfg, out_dir, policy)
|
log_training_info(cfg, out_dir, policy)
|
||||||
|
|
||||||
replay_buffer = initialize_replay_buffer(cfg, logger, device)
|
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
|
||||||
batch_size = cfg.training.batch_size
|
batch_size = cfg.training.batch_size
|
||||||
offline_replay_buffer = None
|
offline_replay_buffer = None
|
||||||
|
|
||||||
|
@ -649,7 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
state_keys=cfg.policy.input_shapes.keys(),
|
state_keys=cfg.policy.input_shapes.keys(),
|
||||||
action_mask=active_action_dims,
|
action_mask=active_action_dims,
|
||||||
action_delta=cfg.env.wrapper.delta_action,
|
action_delta=cfg.env.wrapper.delta_action,
|
||||||
storage_device=device,
|
storage_device=storage_device,
|
||||||
optimize_memory=True,
|
optimize_memory=True,
|
||||||
)
|
)
|
||||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||||
|
|
Loading…
Reference in New Issue