Optimized the replay buffer from the memory side to store data on cpu instead of a gpu device and send the batches to the gpu.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
95de8e273d
commit
c9e50bb9b1
|
@ -138,6 +138,7 @@ class ReplayBuffer:
|
||||||
state_keys: Optional[Sequence[str]] = None,
|
state_keys: Optional[Sequence[str]] = None,
|
||||||
image_augmentation_function: Optional[Callable] = None,
|
image_augmentation_function: Optional[Callable] = None,
|
||||||
use_drq: bool = True,
|
use_drq: bool = True,
|
||||||
|
storage_device: str = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -147,9 +148,12 @@ class ReplayBuffer:
|
||||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||||
|
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
|
||||||
|
Using "cpu" can help save GPU memory.
|
||||||
"""
|
"""
|
||||||
self.capacity = capacity
|
self.capacity = capacity
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.storage_device = storage_device
|
||||||
self.memory: list[Transition] = []
|
self.memory: list[Transition] = []
|
||||||
self.position = 0
|
self.position = 0
|
||||||
|
|
||||||
|
@ -172,7 +176,16 @@ class ReplayBuffer:
|
||||||
done: bool,
|
done: bool,
|
||||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||||
):
|
):
|
||||||
"""Saves a transition."""
|
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||||
|
# Move tensors to the storage device
|
||||||
|
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
|
||||||
|
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
|
||||||
|
action = action.to(self.storage_device)
|
||||||
|
if complementary_info is not None:
|
||||||
|
complementary_info = {
|
||||||
|
key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
|
||||||
|
}
|
||||||
|
|
||||||
if len(self.memory) < self.capacity:
|
if len(self.memory) < self.capacity:
|
||||||
self.memory.append(None)
|
self.memory.append(None)
|
||||||
|
|
||||||
|
@ -185,7 +198,7 @@ class ReplayBuffer:
|
||||||
done=done,
|
done=done,
|
||||||
complementary_info=complementary_info,
|
complementary_info=complementary_info,
|
||||||
)
|
)
|
||||||
self.position: int = (self.position + 1) % self.capacity
|
self.position = (self.position + 1) % self.capacity
|
||||||
|
|
||||||
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
|
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -475,7 +488,6 @@ class ReplayBuffer:
|
||||||
|
|
||||||
# Move to next frame
|
# Move to next frame
|
||||||
frame_idx_in_episode += 1
|
frame_idx_in_episode += 1
|
||||||
|
|
||||||
# If we reached an episode boundary, call save_episode, reset counters
|
# If we reached an episode boundary, call save_episode, reset counters
|
||||||
if transition["done"]:
|
if transition["done"]:
|
||||||
# Use some placeholder name for the task
|
# Use some placeholder name for the task
|
||||||
|
|
Loading…
Reference in New Issue