Add complementary info to the replay buffer
- Added storage of complementary info in the `add` method
- Added complementary info to the batch returned by the `sample` method
This commit is contained in:
parent
4a1c26d9ee
commit
007fee9230
|
@ -244,6 +244,11 @@ class ReplayBuffer:
|
|||
|
||||
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
|
||||
# Initialize complementary_info storage
|
||||
self.complementary_info_keys = []
|
||||
self.complementary_info_storage = {}
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def __len__(self):
|
||||
|
@ -277,6 +282,30 @@ class ReplayBuffer:
|
|||
self.dones[self.position] = done
|
||||
self.truncateds[self.position] = truncated
|
||||
|
||||
# Store complementary info if provided
|
||||
if complementary_info is not None:
|
||||
# Initialize storage for new keys on first encounter
|
||||
for key, value in complementary_info.items():
|
||||
if key not in self.complementary_info_keys:
|
||||
self.complementary_info_keys.append(key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
shape = value.shape if value.ndim > 0 else (1,)
|
||||
self.complementary_info_storage[key] = torch.zeros(
|
||||
(self.capacity, *shape),
|
||||
dtype=value.dtype,
|
||||
device=self.storage_device
|
||||
)
|
||||
|
||||
# Store the value
|
||||
if key in self.complementary_info_storage:
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info_storage[key][self.position] = value
|
||||
else:
|
||||
# For non-tensor values (like grasp_penalty)
|
||||
self.complementary_info_storage[key][self.position] = torch.tensor(
|
||||
value, device=self.storage_device
|
||||
)
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
|
||||
|
@ -335,6 +364,13 @@ class ReplayBuffer:
|
|||
batch_dones = self.dones[idx].to(self.device).float()
|
||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||
|
||||
# Add complementary_info to batch if it exists
|
||||
batch_complementary_info = {}
|
||||
if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys:
|
||||
for key in self.complementary_info_keys:
|
||||
if key in self.complementary_info_storage:
|
||||
batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
|
||||
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
|
@ -342,6 +378,7 @@ class ReplayBuffer:
|
|||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
truncated=batch_truncateds,
|
||||
complementary_info=batch_complementary_info if batch_complementary_info else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
Loading…
Reference in New Issue