Update tensor device assignment in ReplayBuffer class
- Changed the device assignment for tensors in the ReplayBuffer class from `device` to `storage_device` for consistency and improved resource management.
This commit is contained in:
parent
68b8e274dd
commit
36714a14a7
|
@ -463,9 +463,9 @@ class ReplayBuffer:
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
for key, tensor in v.items():
|
for key, tensor in v.items():
|
||||||
v[key] = tensor.to(device)
|
v[key] = tensor.to(storage_device)
|
||||||
elif isinstance(v, torch.Tensor):
|
elif isinstance(v, torch.Tensor):
|
||||||
data[k] = v.to(device)
|
data[k] = v.to(storage_device)
|
||||||
|
|
||||||
action = data["action"]
|
action = data["action"]
|
||||||
if action_mask is not None:
|
if action_mask is not None:
|
||||||
|
|
Loading…
Reference in New Issue