From 007fee923071a860ff05bf1d7b536375ed6dea5f Mon Sep 17 00:00:00 2001 From: s1lent4gnt Date: Mon, 31 Mar 2025 17:36:35 +0200 Subject: [PATCH] Add complementary info in the replay buffer - Added complementary info in the add method - Added complementary info in the sample method --- lerobot/scripts/server/buffer.py | 37 ++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 776ad9ec..1fbc8803 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -244,6 +244,11 @@ class ReplayBuffer: self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + + # Initialize complementary_info storage + self.complementary_info_keys = [] + self.complementary_info_storage = {} + self.initialized = True def __len__(self): @@ -277,6 +282,30 @@ class ReplayBuffer: self.dones[self.position] = done self.truncateds[self.position] = truncated + # Store complementary info if provided + if complementary_info is not None: + # Initialize storage for new keys on first encounter + for key, value in complementary_info.items(): + if key not in self.complementary_info_keys: + self.complementary_info_keys.append(key) + if isinstance(value, torch.Tensor): + shape = value.shape if value.ndim > 0 else (1,) + self.complementary_info_storage[key] = torch.zeros( + (self.capacity, *shape), + dtype=value.dtype, + device=self.storage_device + ) + + # Store the value + if key in self.complementary_info_storage: + if isinstance(value, torch.Tensor): + self.complementary_info_storage[key][self.position] = value + else: + # For non-tensor values (like grasp_penalty) + self.complementary_info_storage[key][self.position] = torch.tensor( + value, device=self.storage_device + ) + self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -335,6 +364,13 @@ class ReplayBuffer: batch_dones = self.dones[idx].to(self.device).float() batch_truncateds = self.truncateds[idx].to(self.device).float() + # Add complementary_info to batch if it exists + batch_complementary_info = {} + if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys: + for key in self.complementary_info_keys: + if key in self.complementary_info_storage: + batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device) + return BatchTransition( state=batch_state, action=batch_actions, @@ -342,6 +378,7 @@ class ReplayBuffer: next_state=batch_next_state, done=batch_dones, truncated=batch_truncateds, + complementary_info=batch_complementary_info if batch_complementary_info else None, ) @classmethod