Optimized the replay buffer's memory usage: data is now stored on the CPU instead of on a GPU device, and the batches are then sent to the GPU.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi 2025-02-13 18:03:57 +01:00
parent 95de8e273d
commit c9e50bb9b1
1 changed files with 15 additions and 3 deletions

View File

@ -138,6 +138,7 @@ class ReplayBuffer:
state_keys: Optional[Sequence[str]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
):
"""
Args:
@ -147,9 +148,12 @@ class ReplayBuffer:
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
and returns a batch of augmented images. If None, a default augmentation function is used.
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
storage_device (str): The device (e.g. "cpu" or "cuda:0") on which data is stored when transitions are added to the buffer.
Using "cpu" can help save GPU memory.
"""
self.capacity = capacity
self.device = device
self.storage_device = storage_device
self.memory: list[Transition] = []
self.position = 0
@ -172,7 +176,16 @@ class ReplayBuffer:
done: bool,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Saves a transition."""
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
# Move tensors to the storage device
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
action = action.to(self.storage_device)
if complementary_info is not None:
complementary_info = {
key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
}
if len(self.memory) < self.capacity:
self.memory.append(None)
@ -185,7 +198,7 @@ class ReplayBuffer:
done=done,
complementary_info=complementary_info,
)
self.position: int = (self.position + 1) % self.capacity
self.position = (self.position + 1) % self.capacity
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
@classmethod
@ -475,7 +488,6 @@ class ReplayBuffer:
# Move to next frame
frame_idx_in_episode += 1
# If we reached an episode boundary, call save_episode, reset counters
if transition["done"]:
# Use some placeholder name for the task