diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 99f5c55b..fb463762 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -138,6 +138,7 @@ class ReplayBuffer:
         state_keys: Optional[Sequence[str]] = None,
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
+        storage_device: str = "cpu",
     ):
         """
         Args:
@@ -147,9 +148,12 @@ class ReplayBuffer:
             image_augmentation_function (Optional[Callable]): A function that takes a batch of images
                 and returns a batch of augmented images. If None, a default augmentation function is used.
             use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
+            storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
+                Using "cpu" can help save GPU memory.
         """
         self.capacity = capacity
         self.device = device
+        self.storage_device = storage_device
         self.memory: list[Transition] = []
         self.position = 0
 
@@ -172,7 +176,16 @@ class ReplayBuffer:
         done: bool,
         complementary_info: Optional[dict[str, torch.Tensor]] = None,
     ):
-        """Saves a transition."""
+        """Saves a transition, ensuring tensors are stored on the designated storage device."""
+        # Move tensors to the storage device
+        state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
+        next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
+        action = action.to(self.storage_device)
+        if complementary_info is not None:
+            complementary_info = {
+                key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
+            }
+
         if len(self.memory) < self.capacity:
             self.memory.append(None)
 
@@ -185,7 +198,7 @@ class ReplayBuffer:
             done=done,
             complementary_info=complementary_info,
         )
-        self.position: int = (self.position + 1) % self.capacity
+        self.position = (self.position + 1) % self.capacity
 
     # TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
     @classmethod
@@ -475,7 +488,6 @@ class ReplayBuffer:
 
             # Move to next frame
             frame_idx_in_episode += 1
-
             # If we reached an episode boundary, call save_episode, reset counters
             if transition["done"]:
                 # Use some placeholder name for the task