Add complementary info in the replay buffer

- Added complementary info in the add method
- Added complementary info in the sample method
This commit is contained in:
s1lent4gnt 2025-03-31 17:36:35 +02:00
parent 4a1c26d9ee
commit 007fee9230
1 changed files with 37 additions and 0 deletions

View File

@ -244,6 +244,11 @@ class ReplayBuffer:
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
# Initialize complementary_info storage
self.complementary_info_keys = []
self.complementary_info_storage = {}
self.initialized = True
def __len__(self):
@ -277,6 +282,30 @@ class ReplayBuffer:
self.dones[self.position] = done
self.truncateds[self.position] = truncated
# Store complementary info if provided
if complementary_info is not None:
# Initialize storage for new keys on first encounter
for key, value in complementary_info.items():
if key not in self.complementary_info_keys:
self.complementary_info_keys.append(key)
if isinstance(value, torch.Tensor):
shape = value.shape if value.ndim > 0 else (1,)
self.complementary_info_storage[key] = torch.zeros(
(self.capacity, *shape),
dtype=value.dtype,
device=self.storage_device
)
# Store the value
if key in self.complementary_info_storage:
if isinstance(value, torch.Tensor):
self.complementary_info_storage[key][self.position] = value
else:
# For non-tensor values (like grasp_penalty)
self.complementary_info_storage[key][self.position] = torch.tensor(
value, device=self.storage_device
)
self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
@ -335,6 +364,13 @@ class ReplayBuffer:
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Add complementary_info to batch if it exists
batch_complementary_info = {}
if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys:
for key in self.complementary_info_keys:
if key in self.complementary_info_storage:
batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
return BatchTransition(
state=batch_state,
action=batch_actions,
@ -342,6 +378,7 @@ class ReplayBuffer:
next_state=batch_next_state,
done=batch_dones,
truncated=batch_truncateds,
complementary_info=batch_complementary_info if batch_complementary_info else None,
)
@classmethod