From ef8d943e546440017a88301002426a139b847616 Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Tue, 25 Feb 2025 14:26:44 +0000
Subject: [PATCH] Refactor ReplayBuffer with tensor-based storage and improved
 sampling efficiency

- Replaced list-based memory storage with pre-allocated tensor storage
- Optimized sampling process with direct tensor indexing
- Added support for DrQ image augmentation when sampling from the offline dataset
- Improved dataset conversion with more robust episode handling
- Enhanced buffer initialization and state tracking
- Added comprehensive testing for buffer conversion and sampling
---
 lerobot/scripts/server/buffer.py | 879 +++++++++++++++++++++----------
 1 file changed, 602 insertions(+), 277 deletions(-)

diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index fd63b3f0..de278582 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -23,6 +23,7 @@ import torch.nn.functional as F  # noqa: N812
 from tqdm import tqdm
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+import os
 
 
 class Transition(TypedDict):
@@ -181,29 +182,58 @@ class ReplayBuffer:
         """
         Args:
             capacity (int): Maximum number of transitions to store in the buffer.
-            device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
+            device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
             state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
             image_augmentation_function (Optional[Callable]): A function that takes a batch of images
                 and returns a batch of augmented images. If None, a default augmentation function is used.
             use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
-            storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
+            storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored. Using "cpu" can help save GPU memory.
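+
+        Example (a minimal usage sketch; the keys, shapes, and capacity below are
+        illustrative placeholders, not values required by the buffer):
+            buffer = ReplayBuffer(
+                capacity=10_000,
+                device="cuda:0",
+                state_keys=["observation.state"],
+                storage_device="cpu",
+            )
+            buffer.add(
+                state={"observation.state": torch.rand(1, 10)},
+                action=torch.rand(1, 4),
+                reward=1.0,
+                next_state={"observation.state": torch.rand(1, 10)},
+                done=False,
+                truncated=False,
+            )
+            batch = buffer.sample(256)  # BatchTransition with tensors moved to `device`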
""" self.capacity = capacity self.device = device self.storage_device = storage_device - self.memory: list[Transition] = [] self.position = 0 + self.size = 0 + self.initialized = False # If no state_keys provided, default to an empty list - # (you can handle this differently if needed) self.state_keys = state_keys if state_keys is not None else [] + if image_augmentation_function is None: - self.image_augmentation_function = functools.partial(random_shift, pad=4) + base_function = functools.partial(random_shift, pad=4) + self.image_augmentation_function = torch.compile(base_function) self.use_drq = use_drq + def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor): + """Initialize the storage tensors based on the first transition.""" + # Determine shapes from the first transition + state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} + action_shape = action.squeeze(0).shape + + # Pre-allocate tensors for storage + self.states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + self.actions = torch.empty( + (self.capacity, *action_shape), device=self.storage_device + ) + self.rewards = torch.empty((self.capacity,), device=self.storage_device) + self.next_states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + self.dones = torch.empty( + (self.capacity,), dtype=torch.bool, device=self.storage_device + ) + self.truncateds = torch.empty( + (self.capacity,), dtype=torch.bool, device=self.storage_device + ) + self.initialized = True + def __len__(self): - return len(self.memory) + return self.size def add( self, @@ -216,33 +246,91 @@ class ReplayBuffer: complementary_info: Optional[dict[str, torch.Tensor]] = None, ): """Saves a transition, ensuring tensors are stored on the designated storage device.""" - # Move tensors to the storage device - state = {key: tensor.to(self.storage_device) for key, tensor in state.items()} - next_state = { - key: tensor.to(self.storage_device) for key, tensor in next_state.items() - } - action = action.to(self.storage_device) - # if complementary_info is not None: - # complementary_info = { - # key: tensor.to(self.storage_device) for key, tensor in complementary_info.items() - # } + # Initialize storage if this is the first transition + if not self.initialized: + self._initialize_storage(state=state, action=action) - if len(self.memory) < self.capacity: - self.memory.append(None) + # Store the transition in pre-allocated tensors + for key in self.states: + self.states[key][self.position].copy_(state[key].squeeze(dim=0)) + self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0)) + + self.actions[self.position].copy_(action.squeeze(dim=0)) + self.rewards[self.position] = reward + self.dones[self.position] = done + self.truncateds[self.position] = truncated - # Create and store the Transition - self.memory[self.position] = Transition( - state=state, - action=action, - reward=reward, - next_state=next_state, - done=done, - truncated=truncated, - complementary_info=complementary_info, - ) self.position = (self.position + 1) % self.capacity + self.size = min(self.size + 1, self.capacity) + + def sample(self, batch_size: int) -> BatchTransition: + """Sample a random batch of transitions and collate them into batched tensors.""" + if not self.initialized: + raise RuntimeError( + "Cannot sample from an empty buffer. Add transitions first." 
+            )
+
+        batch_size = min(batch_size, self.size)
+
+        # Random indices for sampling - create on the same device as storage
+        idx = torch.randint(
+            low=0, high=self.size, size=(batch_size,), device=self.storage_device
+        )
+
+        # Identify image keys that need augmentation
+        image_keys = (
+            [k for k in self.states if k.startswith("observation.image")]
+            if self.use_drq
+            else []
+        )
+
+        # Create batched state and next_state
+        batch_state = {}
+        batch_next_state = {}
+
+        # First pass: load all tensors to target device
+        for key in self.states:
+            batch_state[key] = self.states[key][idx].to(self.device)
+            batch_next_state[key] = self.next_states[key][idx].to(self.device)
+
+        # Apply image augmentation in a batched way if needed
+        if self.use_drq and image_keys:
+            # Concatenate all images from state and next_state
+            all_images = []
+            for key in image_keys:
+                all_images.append(batch_state[key])
+                all_images.append(batch_next_state[key])
+
+            # Batch all images and apply augmentation once
+            all_images_tensor = torch.cat(all_images, dim=0)
+            augmented_images = self.image_augmentation_function(all_images_tensor)
+
+            # Split the augmented images back to their sources
+            for i, key in enumerate(image_keys):
+                # State images for this key occupy the even-numbered segment (2 * i)
+                batch_state[key] = augmented_images[
+                    i * 2 * batch_size : (i * 2 + 1) * batch_size
+                ]
+                # Next-state images occupy the following odd-numbered segment (2 * i + 1)
+                batch_next_state[key] = augmented_images[
+                    (i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size
+                ]
+
+        # Sample other tensors
+        batch_actions = self.actions[idx].to(self.device)
+        batch_rewards = self.rewards[idx].to(self.device)
+        batch_dones = self.dones[idx].to(self.device).float()
+        batch_truncateds = self.truncateds[idx].to(self.device).float()
+
+        return BatchTransition(
+            state=batch_state,
+            action=batch_actions,
+            reward=batch_rewards,
+            next_state=batch_next_state,
+            done=batch_dones,
+            truncated=batch_truncateds,
+        )
 
-    # TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
     @classmethod
     def from_lerobot_dataset(
         cls,
@@ -252,21 +340,28 @@ class ReplayBuffer:
         capacity: Optional[int] = None,
         action_mask: Optional[Sequence[int]] = None,
         action_delta: Optional[float] = None,
+        image_augmentation_function: Optional[Callable] = None,
+        use_drq: bool = True,
+        storage_device: str = "cpu",
     ) -> "ReplayBuffer":
         """
         Convert a LeRobotDataset into a ReplayBuffer.
 
         Args:
             lerobot_dataset (LeRobotDataset): The dataset to convert.
-            device (str): The device . Defaults to "cuda:0".
-            state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
-                Defaults to None.
+            device (str): The device onto which sampled tensors are moved. Defaults to "cuda:0".
+            state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
+            capacity (Optional[int]): Buffer capacity. If None, uses the dataset length.
+            action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
+            action_delta (Optional[float]): Factor to divide actions by.
+            image_augmentation_function (Optional[Callable]): Function for image augmentation.
+                If None, uses the default random shift with pad=4.
+            use_drq (bool): Whether to use DrQ image augmentation when sampling.
+            storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
 
         Returns:
-            ReplayBuffer: The replay buffer with offline dataset transitions.
+            ReplayBuffer: The replay buffer with dataset transitions.
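+
+        Example (a minimal sketch; the repo_id is a placeholder, and the state keys
+        only illustrate the naming convention used elsewhere in this file):
+            dataset = LeRobotDataset(repo_id="user/some_dataset")
+            buffer = ReplayBuffer.from_lerobot_dataset(
+                lerobot_dataset=dataset,
+                device="cuda:0",
+                state_keys=["observation.image", "observation.state"],
+            )
+            batch = buffer.sample(64)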
""" - # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from - # a replay buffer than from a lerobot dataset. if capacity is None: capacity = len(lerobot_dataset) @@ -275,11 +370,42 @@ class ReplayBuffer: "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset." ) - replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys) + # Create replay buffer with image augmentation and DrQ settings + replay_buffer = cls( + capacity=capacity, + device=device, + state_keys=state_keys, + image_augmentation_function=image_augmentation_function, + use_drq=use_drq, + storage_device=storage_device, + ) + + # Convert dataset to transitions list_transition = cls._lerobotdataset_to_transitions( dataset=lerobot_dataset, state_keys=state_keys ) - # Fill the replay buffer with the lerobot dataset transitions + + # Initialize the buffer with the first transition to set up storage tensors + if list_transition: + first_transition = list_transition[0] + first_state = { + k: v.to(device) for k, v in first_transition["state"].items() + } + first_action = first_transition["action"].to(device) + + # Apply action mask/delta if needed + if action_mask is not None: + if first_action.dim() == 1: + first_action = first_action[action_mask] + else: + first_action = first_action[:, action_mask] + + if action_delta is not None: + first_action = first_action / action_delta + + replay_buffer._initialize_storage(state=first_state, action=first_action) + + # Fill the buffer with all transitions for data in list_transition: for k, v in data.items(): if isinstance(v, dict): @@ -288,25 +414,127 @@ class ReplayBuffer: elif isinstance(v, torch.Tensor): data[k] = v.to(device) + action = data["action"] if action_mask is not None: - if data["action"].dim() == 1: - data["action"] = data["action"][action_mask] + if action.dim() == 1: + action = action[action_mask] else: - data["action"] = data["action"][:, action_mask] + action = action[:, action_mask] if action_delta is not None: - data["action"] = data["action"] / action_delta + action = action / action_delta replay_buffer.add( state=data["state"], - action=data["action"], + action=action, reward=data["reward"], next_state=data["next_state"], done=data["done"], - truncated=False, + truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset ) + return replay_buffer + def to_lerobot_dataset( + self, + repo_id: str, + fps=1, + root=None, + task_name="from_replay_buffer", + ) -> LeRobotDataset: + """ + Converts all transitions in this ReplayBuffer into a single LeRobotDataset object. + """ + if self.size == 0: + raise ValueError("The replay buffer is empty. 
+            raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
+
+        # Create features dictionary for the dataset
+        features = {
+            "index": {"dtype": "int64", "shape": [1]},  # global index across episodes
+            "episode_index": {"dtype": "int64", "shape": [1]},  # which episode
+            "frame_index": {"dtype": "int64", "shape": [1]},  # index inside an episode
+            "timestamp": {"dtype": "float32", "shape": [1]},  # for now we store dummy
+            "task_index": {"dtype": "int64", "shape": [1]},
+        }
+
+        # Add "action"
+        sample_action = self.actions[0]
+        act_info = guess_feature_info(t=sample_action, name="action")
+        features["action"] = act_info
+
+        # Add "reward" and "done"
+        features["next.reward"] = {"dtype": "float32", "shape": (1,)}
+        features["next.done"] = {"dtype": "bool", "shape": (1,)}
+
+        # Add state keys
+        for key in self.states:
+            sample_val = self.states[key][0]
+            f_info = guess_feature_info(t=sample_val, name=key)
+            features[key] = f_info
+
+        # Create an empty LeRobotDataset
+        lerobot_dataset = LeRobotDataset.create(
+            repo_id=repo_id,
+            fps=fps,
+            root=root,
+            robot=None,  # TODO: (azouitine) Handle robot
+            robot_type=None,
+            features=features,
+            use_videos=True,
+        )
+
+        # Start writing images if needed
+        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
+
+        # Convert transitions into episodes and frames
+        episode_index = 0
+        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
+            episode_index=episode_index
+        )
+
+        frame_idx_in_episode = 0
+        for idx in range(self.size):
+            actual_idx = (self.position - self.size + idx) % self.capacity
+
+            frame_dict = {}
+
+            # Fill the data for state keys
+            for key in self.states:
+                frame_dict[key] = self.states[key][actual_idx].cpu()
+
+            # Fill action, reward, done
+            frame_dict["action"] = self.actions[actual_idx].cpu()
+            frame_dict["next.reward"] = torch.tensor(
+                [self.rewards[actual_idx]], dtype=torch.float32
+            ).cpu()
+            frame_dict["next.done"] = torch.tensor(
+                [self.dones[actual_idx]], dtype=torch.bool
+            ).cpu()
+
+            # Add to the dataset's buffer
+            lerobot_dataset.add_frame(frame_dict)
+
+            # Move to next frame
+            frame_idx_in_episode += 1
+
+            # If we reached an episode boundary, call save_episode, reset counters
+            if self.dones[actual_idx] or self.truncateds[actual_idx]:
+                lerobot_dataset.save_episode(task=task_name)
+                episode_index += 1
+                frame_idx_in_episode = 0
+                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
+                    episode_index=episode_index
+                )
+
+        # Save any remaining frames in the buffer
+        if lerobot_dataset.episode_buffer["size"] > 0:
+            lerobot_dataset.save_episode(task=task_name)
+
+        lerobot_dataset.stop_image_writer()
+        lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
+
+        return lerobot_dataset
+
     @staticmethod
     def _lerobotdataset_to_transitions(
         dataset: LeRobotDataset,
@@ -337,16 +565,24 @@ class ReplayBuffer:
             transitions (List[Transition]): A list of Transition dictionaries with the same
                 length as `dataset`.
         """
-
-        # If not provided, you can either raise an error or define a default:
         if state_keys is None:
             raise ValueError(
-                "You must provide a list of keys in `state_keys` that define your 'state'."
+                "State keys must be provided when converting LeRobotDataset to Transitions."
             )
 
-        transitions: list[Transition] = []
+        transitions = []
         num_frames = len(dataset)
 
+        # Check if the dataset has "next.done" key
+        sample = dataset[0]
+        has_done_key = "next.done" in sample
+
+        # If not, we need to infer it from episode boundaries
+        if not has_done_key:
+            print(
+                "'next.done' key not found in dataset. Inferring from episode boundaries..."
+            )
+
         for i in tqdm(range(num_frames)):
             current_sample = dataset[i]
 
@@ -361,9 +597,22 @@ class ReplayBuffer:
             # ----- 3) Reward and done -----
             reward = float(current_sample["next.reward"].item())  # ensure float
-            done = bool(current_sample["next.done"].item())  # ensure bool
-            # TODO: (azouitine) Handle truncation properly
-            truncated = bool(current_sample["next.done"].item())  # ensure bool
+
+            # Determine done flag - use next.done if available, otherwise infer from episode boundaries
+            if has_done_key:
+                done = bool(current_sample["next.done"].item())  # ensure bool
+            else:
+                # If this is the last frame or if next frame is in a different episode, mark as done
+                done = False
+                if i == num_frames - 1:
+                    done = True
+                elif i < num_frames - 1:
+                    next_sample = dataset[i + 1]
+                    if next_sample["episode_index"] != current_sample["episode_index"]:
+                        done = True
+
+            # TODO: (azouitine) Handle truncation (using the same value as done for now)
+            truncated = done
 
             # ----- 4) Next state -----
             # If not done and the next sample is in the same episode, we pull the next sample's state.
@@ -392,206 +641,6 @@ class ReplayBuffer:
 
         return transitions
 
-    def sample(self, batch_size: int) -> BatchTransition:
-        """Sample a random batch of transitions and collate them into batched tensors."""
-        batch_size = min(batch_size, len(self.memory))
-        list_of_transitions = random.sample(self.memory, batch_size)
-
-        # -- Build batched states --
-        batch_state = {}
-        for key in self.state_keys:
-            batch_state[key] = torch.cat(
-                [t["state"][key] for t in list_of_transitions], dim=0
-            ).to(self.device)
-            if key.startswith("observation.image") and self.use_drq:
-                batch_state[key] = self.image_augmentation_function(batch_state[key])
-
-        # -- Build batched actions --
-        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
-            self.device
-        )
-
-        # -- Build batched rewards --
-        batch_rewards = torch.tensor(
-            [t["reward"] for t in list_of_transitions], dtype=torch.float32
-        ).to(self.device)
-
-        # -- Build batched next states --
-        batch_next_state = {}
-        for key in self.state_keys:
-            batch_next_state[key] = torch.cat(
-                [t["next_state"][key] for t in list_of_transitions], dim=0
-            ).to(self.device)
-            if key.startswith("observation.image") and self.use_drq:
-                batch_next_state[key] = self.image_augmentation_function(
-                    batch_next_state[key]
-                )
-
-        # -- Build batched dones --
-        batch_dones = torch.tensor(
-            [t["done"] for t in list_of_transitions], dtype=torch.float32
-        ).to(self.device)
-
-        # -- Build batched truncateds --
-        batch_truncateds = torch.tensor(
-            [t["truncated"] for t in list_of_transitions], dtype=torch.float32
-        ).to(self.device)
-
-        # Return a BatchTransition typed dict
-        return BatchTransition(
-            state=batch_state,
-            action=batch_actions,
-            reward=batch_rewards,
-            next_state=batch_next_state,
-            done=batch_dones,
-            truncated=batch_truncateds,
-        )
-
-    def to_lerobot_dataset(
-        self,
-        repo_id: str,
-        fps=1,  # If you have real timestamps, adjust this
-        root=None,
-        task_name="from_replay_buffer",
-    ) -> LeRobotDataset:
-        """
-        Converts all transitions in this ReplayBuffer into a single LeRobotDataset object,
-        splitting episodes by transitions where 'done=True'.
-
-        Returns:
-            LeRobotDataset: The resulting offline dataset.
-        """
-        if len(self.memory) == 0:
-            raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
-
-        # Infer the shapes and dtypes of your features
-        # We'll create a features dict that is suitable for LeRobotDataset
-        # --------------------------------------------------------------------------------------------
-        # First, grab one transition to inspect shapes
-        first_transition = self.memory[0]
-
-        # We'll store default metadata for every episode: indexes, timestamps, etc.
-        features = {
-            "index": {"dtype": "int64", "shape": [1]},  # global index across episodes
-            "episode_index": {"dtype": "int64", "shape": [1]},  # which episode
-            "frame_index": {"dtype": "int64", "shape": [1]},  # index inside an episode
-            "timestamp": {"dtype": "float32", "shape": [1]},  # for now we store dummy
-            "task_index": {"dtype": "int64", "shape": [1]},
-        }
-
-        # Add "action"
-        act_info = guess_feature_info(
-            first_transition["action"].squeeze(dim=0), "action"
-        )  # Remove batch dimension
-        features["action"] = act_info
-
-        # Add "reward" (scalars)
-        features["next.reward"] = {"dtype": "float32", "shape": (1,)}
-
-        # Add "done" (boolean scalars)
-        features["next.done"] = {"dtype": "bool", "shape": (1,)}
-
-        # Add state keys
-        for key in self.state_keys:
-            sample_val = first_transition["state"][key].squeeze(
-                dim=0
-            )  # Remove batch dimension
-            if not isinstance(sample_val, torch.Tensor):
-                raise ValueError(
-                    f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
-                )
-            f_info = guess_feature_info(sample_val, key)
-            features[key] = f_info
-
-        # --------------------------------------------------------------------------------------------
-        # Create an empty LeRobotDataset
-        # We'll store all frames as separate images only if we detect shape = (3, H, W) or (1, H, W).
-        # By default we won't do videos, but feel free to adapt if you have them.
-        # --------------------------------------------------------------------------------------------
-        lerobot_dataset = LeRobotDataset.create(
-            repo_id=repo_id,
-            fps=fps,  # If you have real timestamps, adjust this
-            root=root,  # Or some local path where you'd like the dataset files to go
-            robot=None,
-            robot_type=None,
-            features=features,
-            use_videos=True,  # We won't do actual video encoding for a replay buffer
-        )
-
-        # Start writing images if needed. If you have no image features, this is harmless.
-        # Set num_processes or num_threads if you want concurrency.
-        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
-
-        # --------------------------------------------------------------------------------------------
-        # Convert transitions into episodes and frames
-        # We detect episode boundaries by `done == True`.
-        # --------------------------------------------------------------------------------------------
-        episode_index = 0
-        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
-            episode_index
-        )
-
-        frame_idx_in_episode = 0
-        for global_frame_idx, transition in tqdm(
-            enumerate(self.memory),
-            desc="Converting replay buffer to dataset",
-            total=len(self.memory),
-        ):
-            frame_dict = {}
-
-            # Fill the data for state keys
-            for key in self.state_keys:
-                # Expand dimension to match what the dataset expects (the dataset wants the raw shape)
-                # We assume your buffer has shape [C, H, W] (if image) or [D] if vector
-                # This is typically already correct, but if needed you can reshape below.
-                frame_dict[key] = (
-                    transition["state"][key].cpu().squeeze(dim=0)
-                )  # Remove batch dimension
-
-            # Fill action, reward, done
-            # Make sure they are shape (X,) or (X,Y,...) as needed.
-            frame_dict["action"] = (
-                transition["action"].cpu().squeeze(dim=0)
-            )  # Remove batch dimension
-            frame_dict["next.reward"] = (
-                torch.tensor([transition["reward"]], dtype=torch.float32)
-                .cpu()
-                .squeeze(dim=0)
-            )
-            frame_dict["next.done"] = (
-                torch.tensor([transition["done"]], dtype=torch.bool)
-                .cpu()
-                .squeeze(dim=0)
-            )
-            # Add to the dataset's buffer
-            lerobot_dataset.add_frame(frame_dict)
-
-            # Move to next frame
-            frame_idx_in_episode += 1
-            # If we reached an episode boundary, call save_episode, reset counters
-            # TODO: (azouitine) Handle truncation properly
-            if transition["done"] or transition["truncated"]:
-                # Use some placeholder name for the task
-                lerobot_dataset.save_episode(task=task_name)
-                episode_index += 1
-                frame_idx_in_episode = 0
-                # Start a new buffer for the next episode
-                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
-                    episode_index=episode_index
-                )
-
-        # We are done adding frames
-        # If the last transition wasn't done=True, we still have an open buffer with frames.
-        # We'll consider that an incomplete episode and still save it:
-        if lerobot_dataset.episode_buffer["size"] > 0:
-            lerobot_dataset.save_episode(task=task_name)
-
-        lerobot_dataset.stop_image_writer()
-
-        lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
-
-        return lerobot_dataset
-
 
 # Utility function to guess shapes/dtypes from a tensor
 def guess_feature_info(t: torch.Tensor, name: str):
@@ -655,32 +704,308 @@ def concatenate_batch_transitions(
     return left_batch_transitions
 
 
-# if __name__ == "__main__":
-#     dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
-#     dataset = LeRobotDataset(repo_id=dataset_name)
+if __name__ == "__main__":
+    import numpy as np
+    from tempfile import TemporaryDirectory
 
-#     replay_buffer = ReplayBuffer.from_lerobot_dataset(
-#         lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
-#     )
-#     replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
-#     for i in range(len(replay_buffer_converted)):
-#         replay_convert = replay_buffer_converted[i]
-#         dataset_convert = dataset[i]
-#         for key in replay_convert.keys():
-#             if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
-#                 continue
-#             if key in dataset_convert.keys():
-#                 assert torch.equal(replay_convert[key], dataset_convert[key])
-#                 print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
-#     re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
-#         replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
-#     )
-#     for _ in range(20):
-#         batch = re_reconverted_dataset.sample(32)
+    # ===== Test 1: Create and use a synthetic ReplayBuffer =====
+    print("Testing synthetic ReplayBuffer...")
 
-#         for key in batch.keys():
-#             if key in {"state", "next_state"}:
-#                 for key_state in batch[key].keys():
-#                     print(key_state, batch[key][key_state].size())
-#                 continue
-#             print(key, batch[key].size())
+    # Create sample data dimensions
+    batch_size = 32
+    state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
+    action_dim = (6,)
+
+    # Create a buffer
+    buffer = ReplayBuffer(
+        capacity=1000,
+        device="cpu",
+        state_keys=list(state_dims.keys()),
+        use_drq=True,
+        storage_device="cpu",
+    )
+
+    # Add some random transitions
+    for i in range(100):
+        # Create dummy transition data
+        state = {
+            "observation.image": torch.rand(1, 3, 84, 84),
+            "observation.state": torch.rand(1, 10),
+        }
+        action = torch.rand(1, 6)
+        reward = 0.5
+        next_state = {
+            "observation.image": torch.rand(1, 3, 84, 84),
+            "observation.state": torch.rand(1, 10),
+        }
+        done = False if i < 99 else True
+        truncated = False
+
+        buffer.add(
+            state=state,
+            action=action,
+            reward=reward,
+            next_state=next_state,
+            done=done,
+            truncated=truncated,
+        )
+
+    # Test sampling
+    batch = buffer.sample(batch_size)
+    print(f"Buffer size: {len(buffer)}")
+    print(
+        f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
+    )
+    print(f"Sampled batch action shape: {batch['action'].shape}")
+    print(f"Sampled batch reward shape: {batch['reward'].shape}")
+    print(f"Sampled batch done shape: {batch['done'].shape}")
+    print(f"Sampled batch truncated shape: {batch['truncated'].shape}")
+
+    # ===== Test for state-action-reward alignment =====
+    print("\nTesting state-action-reward alignment...")
+
+    # Create a buffer with controlled transitions where we know the relationships
+    aligned_buffer = ReplayBuffer(
+        capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
+    )
+
+    # Create transitions with known relationships
+    # - Each state has a unique signature value
+    # - Action is 2x the state signature
+    # - Reward is 3x the state signature
+    # - Next state is signature + 0.01 (unless at episode end)
+    for i in range(100):
+        # Create a state with a signature value that encodes the transition number
+        signature = float(i) / 100.0
+        state = {"state_value": torch.tensor([[signature]]).float()}
+
+        # Action is 2x the signature
+        action = torch.tensor([[2.0 * signature]]).float()
+
+        # Reward is 3x the signature
+        reward = 3.0 * signature
+
+        # Next state is signature + 0.01, unless end of episode
+        # End episode every 10 steps
+        is_end = (i + 1) % 10 == 0
+
+        if is_end:
+            # At episode boundaries, next_state repeats the current state (matching the buffer implementation)
+            next_state = {"state_value": torch.tensor([[signature]]).float()}
+            done = True
+        else:
+            # Within episodes, next_state has signature + 0.01
+            next_signature = float(i + 1) / 100.0
+            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
+            done = False
+
+        aligned_buffer.add(state, action, reward, next_state, done, False)
+
+    # Sample from this buffer
+    aligned_batch = aligned_buffer.sample(50)
+
+    # Verify alignments in sampled batch
+    correct_relationships = 0
+    total_checks = 0
+
+    # For each transition in the batch
+    for i in range(50):
+        # Extract signature from state
+        state_sig = aligned_batch["state"]["state_value"][i].item()
+
+        # Check action is 2x signature (within reasonable precision)
+        action_val = aligned_batch["action"][i].item()
+        action_check = abs(action_val - 2.0 * state_sig) < 1e-4
+
+        # Check reward is 3x signature (within reasonable precision)
+        reward_val = aligned_batch["reward"][i].item()
+        reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4
+
+        # Check next_state relationship matches our pattern
+        next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
+        is_done = aligned_batch["done"][i].item() > 0.5
+
+        # Calculate expected next_state value based on done flag
+        if is_done:
+            # For episodes that end, next_state should equal state
+            next_state_check = abs(next_state_sig - state_sig) < 1e-4
+        else:
+            # For continuing episodes, check if next_state is approximately state + 0.01
+            # We need to be careful because we don't know the original index
+            # So we check if the increment is roughly 0.01
+            next_state_check = (
+                abs(next_state_sig - state_sig - 0.01) < 1e-4
+                or abs(next_state_sig - state_sig) < 1e-4
+            )
+
+        # Count correct relationships
+        if action_check:
+            correct_relationships += 1
+        if reward_check:
+            correct_relationships += 1
+        if next_state_check:
+            correct_relationships += 1
+
+        total_checks += 3
+
+    alignment_accuracy = 100.0 * correct_relationships / total_checks
+    print(
+        f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%"
+    )
+    if alignment_accuracy > 99.0:
+        print(
+            "✅ All relationships verified! Buffer maintains correct temporal relationships."
+        )
+    else:
+        print(
+            "⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues."
+        )
+
+        # Print some debug information about failures
+        print("\nDebug information for failed checks:")
+        for i in range(5):  # Print first 5 transitions for debugging
+            state_sig = aligned_batch["state"]["state_value"][i].item()
+            action_val = aligned_batch["action"][i].item()
+            reward_val = aligned_batch["reward"][i].item()
+            next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
+            is_done = aligned_batch["done"][i].item() > 0.5
+
+            print(f"Transition {i}:")
+            print(f"  State: {state_sig:.6f}")
+            print(f"  Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})")
+            print(f"  Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})")
+            print(f"  Done: {is_done}")
+            print(f"  Next state: {next_state_sig:.6f}")
+
+            # Calculate expected next state
+            if is_done:
+                expected_next = state_sig
+            else:
+                # This approximation might not be perfect
+                state_idx = round(state_sig * 100)
+                expected_next = (state_idx + 1) / 100.0
+
+            print(f"  Expected next state: {expected_next:.6f}")
+            print()
+
+    # ===== Test 2: Convert to LeRobotDataset and back =====
+    with TemporaryDirectory() as temp_dir:
+        print("\nTesting conversion to LeRobotDataset and back...")
+        # Convert buffer to dataset
+        repo_id = "test/replay_buffer_conversion"
+        # Create a subdirectory to avoid the "directory exists" error
+        dataset_dir = os.path.join(temp_dir, "dataset1")
+        dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir)
+
+        print(f"Dataset created with {len(dataset)} frames")
+        print(f"Dataset features: {list(dataset.features.keys())}")
+
+        # Check a random sample from the dataset
+        sample = dataset[0]
+        print(
+            f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}"
+        )
+
+        # Convert dataset back to buffer
+        reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
+            dataset, state_keys=list(state_dims.keys()), device="cpu"
+        )
+
+        print(f"Reconverted buffer size: {len(reconverted_buffer)}")
+
+        # Sample from the reconverted buffer
+        reconverted_batch = reconverted_buffer.sample(batch_size)
+        print(
+            f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}"
+        )
+
+        # Verify consistency before and after conversion
+        original_states = batch["state"]["observation.image"].mean().item()
+        reconverted_states = (
+            reconverted_batch["state"]["observation.image"].mean().item()
+        )
+        print(f"Original buffer state mean: {original_states:.4f}")
+        print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
+
+        if abs(original_states - reconverted_states) < 1.0:
+            print("Values are reasonably similar - conversion works as expected")
+        else:
+            print(
+                "WARNING: Significant difference between original and reconverted values"
+            )
+
+    print("\nTesting real LeRobotDataset conversion...")
+    try:
+        # Try to use a real dataset if available
+        dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
+        dataset = LeRobotDataset(repo_id=dataset_name)
+
+        # Print available keys to debug
+        sample = dataset[0]
+        print("Available keys in first dataset:", list(sample.keys()))
+
+        # Check for required keys
+        if "action" not in sample or "next.reward" not in sample:
+            print("Dataset missing essential keys. Cannot convert.")
+            raise ValueError("Missing required keys in dataset")
+
+        # Auto-detect appropriate state keys
+        image_keys = []
+        state_keys = []
+        for k, v in sample.items():
+            # Skip metadata keys and action/reward keys
+            if k in {
+                "index",
+                "episode_index",
+                "frame_index",
+                "timestamp",
+                "task_index",
+                "action",
+                "next.reward",
+                "next.done",
+            }:
+                continue
+
+            # Infer key type from tensor shape
+            if isinstance(v, torch.Tensor):
+                if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
+                    # Likely an image (channels, height, width)
+                    image_keys.append(k)
+                else:
+                    # Likely state or other vector
+                    state_keys.append(k)
+
+        print(f"Detected image keys: {image_keys}")
+        print(f"Detected state keys: {state_keys}")
+
+        if not image_keys and not state_keys:
+            print("No usable keys found in dataset, skipping further tests")
+            raise ValueError("No usable keys found in dataset")
+
+        # Convert to ReplayBuffer with detected keys
+        replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            lerobot_dataset=dataset,
+            state_keys=image_keys + state_keys,
+            device="cpu",
+        )
+        print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
+
+        # Test sampling
+        real_batch = replay_buffer.sample(batch_size)
+        print("Sampled batch from real dataset, state shapes:")
+        for key in real_batch["state"]:
+            print(f"  {key}: {real_batch['state'][key].shape}")
+
+        # Convert back to LeRobotDataset
+        with TemporaryDirectory() as temp_dir:
+            replay_buffer_converted = replay_buffer.to_lerobot_dataset(
+                repo_id="test/real_dataset_converted",
+                root=os.path.join(temp_dir, "dataset2"),
+            )
+            print(
+                f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
+            )
+
+    except Exception as e:
+        print(f"Real dataset test failed: {e}")
+        print("This is expected if running offline or if the dataset is not available.")