Optimized the replay buffer from the memory side to store data on cpu instead of a gpu device and send the batches to the gpu.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
95de8e273d
commit
c9e50bb9b1
|
@ -138,6 +138,7 @@ class ReplayBuffer:
|
|||
state_keys: Optional[Sequence[str]] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -147,9 +148,12 @@ class ReplayBuffer:
|
|||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
|
||||
Using "cpu" can help save GPU memory.
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.storage_device = storage_device
|
||||
self.memory: list[Transition] = []
|
||||
self.position = 0
|
||||
|
||||
|
@ -172,7 +176,16 @@ class ReplayBuffer:
|
|||
done: bool,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Saves a transition."""
|
||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||
# Move tensors to the storage device
|
||||
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
|
||||
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
|
||||
action = action.to(self.storage_device)
|
||||
if complementary_info is not None:
|
||||
complementary_info = {
|
||||
key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
|
||||
}
|
||||
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(None)
|
||||
|
||||
|
@ -185,7 +198,7 @@ class ReplayBuffer:
|
|||
done=done,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
self.position: int = (self.position + 1) % self.capacity
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
|
||||
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
|
||||
@classmethod
|
||||
|
@ -475,7 +488,6 @@ class ReplayBuffer:
|
|||
|
||||
# Move to next frame
|
||||
frame_idx_in_episode += 1
|
||||
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if transition["done"]:
|
||||
# Use some placeholder name for the task
|
||||
|
|
Loading…
Reference in New Issue