Update tensor device assignment in ReplayBuffer class

- Changed the device assignment for tensors in the ReplayBuffer class from `device` to `storage_device` for consistency and improved resource management.
This commit is contained in:
AdilZouitine 2025-03-21 14:21:31 +00:00
parent 68b8e274dd
commit 36714a14a7
1 changed files with 2 additions and 2 deletions

View File

@ -463,9 +463,9 @@ class ReplayBuffer:
for k, v in data.items():
if isinstance(v, dict):
for key, tensor in v.items():
v[key] = tensor.to(device)
v[key] = tensor.to(storage_device)
elif isinstance(v, torch.Tensor):
data[k] = v.to(device)
data[k] = v.to(storage_device)
action = data["action"]
if action_mask is not None: