Add complementary info in the replay buffer
- Added complementary info in the add method - Added complementary info in the sample method
This commit is contained in:
parent
4a1c26d9ee
commit
007fee9230
|
@ -244,6 +244,11 @@ class ReplayBuffer:
|
||||||
|
|
||||||
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||||
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||||
|
|
||||||
|
# Initialize complementary_info storage
|
||||||
|
self.complementary_info_keys = []
|
||||||
|
self.complementary_info_storage = {}
|
||||||
|
|
||||||
self.initialized = True
|
self.initialized = True
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
@ -277,6 +282,30 @@ class ReplayBuffer:
|
||||||
self.dones[self.position] = done
|
self.dones[self.position] = done
|
||||||
self.truncateds[self.position] = truncated
|
self.truncateds[self.position] = truncated
|
||||||
|
|
||||||
|
# Store complementary info if provided
|
||||||
|
if complementary_info is not None:
|
||||||
|
# Initialize storage for new keys on first encounter
|
||||||
|
for key, value in complementary_info.items():
|
||||||
|
if key not in self.complementary_info_keys:
|
||||||
|
self.complementary_info_keys.append(key)
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
shape = value.shape if value.ndim > 0 else (1,)
|
||||||
|
self.complementary_info_storage[key] = torch.zeros(
|
||||||
|
(self.capacity, *shape),
|
||||||
|
dtype=value.dtype,
|
||||||
|
device=self.storage_device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the value
|
||||||
|
if key in self.complementary_info_storage:
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
self.complementary_info_storage[key][self.position] = value
|
||||||
|
else:
|
||||||
|
# For non-tensor values (like grasp_penalty)
|
||||||
|
self.complementary_info_storage[key][self.position] = torch.tensor(
|
||||||
|
value, device=self.storage_device
|
||||||
|
)
|
||||||
|
|
||||||
self.position = (self.position + 1) % self.capacity
|
self.position = (self.position + 1) % self.capacity
|
||||||
self.size = min(self.size + 1, self.capacity)
|
self.size = min(self.size + 1, self.capacity)
|
||||||
|
|
||||||
|
@ -335,6 +364,13 @@ class ReplayBuffer:
|
||||||
batch_dones = self.dones[idx].to(self.device).float()
|
batch_dones = self.dones[idx].to(self.device).float()
|
||||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||||
|
|
||||||
|
# Add complementary_info to batch if it exists
|
||||||
|
batch_complementary_info = {}
|
||||||
|
if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys:
|
||||||
|
for key in self.complementary_info_keys:
|
||||||
|
if key in self.complementary_info_storage:
|
||||||
|
batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
|
||||||
|
|
||||||
return BatchTransition(
|
return BatchTransition(
|
||||||
state=batch_state,
|
state=batch_state,
|
||||||
action=batch_actions,
|
action=batch_actions,
|
||||||
|
@ -342,6 +378,7 @@ class ReplayBuffer:
|
||||||
next_state=batch_next_state,
|
next_state=batch_next_state,
|
||||||
done=batch_dones,
|
done=batch_dones,
|
||||||
truncated=batch_truncateds,
|
truncated=batch_truncateds,
|
||||||
|
complementary_info=batch_complementary_info if batch_complementary_info else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in New Issue