Fix: Prevent Invalid next_state References When optimize_memory=True (#918)
This commit is contained in:
parent
c05e4835d0
commit
66c3672738
|
@ -286,9 +286,10 @@ class ReplayBuffer:
|
||||||
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
||||||
|
|
||||||
batch_size = min(batch_size, self.size)
|
batch_size = min(batch_size, self.size)
|
||||||
|
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
||||||
|
|
||||||
# Random indices for sampling - create on the same device as storage
|
# Random indices for sampling - create on the same device as storage
|
||||||
idx = torch.randint(low=0, high=self.size, size=(batch_size,), device=self.storage_device)
|
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||||
|
|
||||||
# Identify image keys that need augmentation
|
# Identify image keys that need augmentation
|
||||||
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
|
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
|
||||||
|
|
Loading…
Reference in New Issue