Fix: Prevent Invalid next_state References When optimize_memory=True (#918)

This commit is contained in:
s1lent4gnt 2025-03-31 09:43:40 +02:00 committed by GitHub
parent c05e4835d0
commit 66c3672738
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -286,9 +286,10 @@ class ReplayBuffer:
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
# Random indices for sampling - create on the same device as storage
idx = torch.randint(low=0, high=self.size, size=(batch_size,), device=self.storage_device)
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []