diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 20787568..776ad9ec 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -286,9 +286,10 @@ class ReplayBuffer: raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.") batch_size = min(batch_size, self.size) + high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size # Random indices for sampling - create on the same device as storage - idx = torch.randint(low=0, high=self.size, size=(batch_size,), device=self.storage_device) + idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) # Identify image keys that need augmentation image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []