From c9e50bb9b151755bcf1f91c1277d8bc85f0a23e5 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Thu, 13 Feb 2025 18:03:57 +0100 Subject: [PATCH] Optimized the replay buffer's memory usage by storing data on the CPU instead of a GPU device and sending the batches to the GPU. Co-authored-by: Adil Zouitine --- lerobot/scripts/server/buffer.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 99f5c55b..fb463762 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -138,6 +138,7 @@ class ReplayBuffer: state_keys: Optional[Sequence[str]] = None, image_augmentation_function: Optional[Callable] = None, use_drq: bool = True, + storage_device: str = "cpu", ): """ Args: @@ -147,9 +148,12 @@ class ReplayBuffer: image_augmentation_function (Optional[Callable]): A function that takes a batch of images and returns a batch of augmented images. If None, a default augmentation function is used. use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer. + storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer. + Using "cpu" can help save GPU memory. 
""" self.capacity = capacity self.device = device + self.storage_device = storage_device self.memory: list[Transition] = [] self.position = 0 @@ -172,7 +176,16 @@ class ReplayBuffer: done: bool, complementary_info: Optional[dict[str, torch.Tensor]] = None, ): - """Saves a transition.""" + """Saves a transition, ensuring tensors are stored on the designated storage device.""" + # Move tensors to the storage device + state = {key: tensor.to(self.storage_device) for key, tensor in state.items()} + next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()} + action = action.to(self.storage_device) + if complementary_info is not None: + complementary_info = { + key: tensor.to(self.storage_device) for key, tensor in complementary_info.items() + } + if len(self.memory) < self.capacity: self.memory.append(None) @@ -185,7 +198,7 @@ class ReplayBuffer: done=done, complementary_info=complementary_info, ) - self.position: int = (self.position + 1) % self.capacity + self.position = (self.position + 1) % self.capacity # TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them @classmethod @@ -475,7 +488,6 @@ class ReplayBuffer: # Move to next frame frame_idx_in_episode += 1 - # If we reached an episode boundary, call save_episode, reset counters if transition["done"]: # Use some placeholder name for the task