#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
import io
import pickle
import queue
import threading
from typing import Any, Callable, Optional, Sequence, TypedDict

import torch
import torch.nn.functional as F  # noqa: N812
from tqdm import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


class Transition(TypedDict):
    """A single (s, a, r, s', done) step, optionally carrying extra per-step information."""

    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: float
    next_state: dict[str, torch.Tensor]
    done: bool
    truncated: bool
    # TypedDict fields cannot carry default values; ``None`` means "no extra information".
    complementary_info: dict[str, torch.Tensor | float | int] | None


class BatchTransition(TypedDict):
    """A batch of transitions, as returned by `ReplayBuffer.sample`."""

    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: torch.Tensor
    next_state: dict[str, torch.Tensor]
    done: torch.Tensor
    truncated: torch.Tensor
    # TypedDict fields cannot carry default values; ``None`` means "no extra information".
    complementary_info: dict[str, torch.Tensor | float | int] | None


def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
    """
    Move every tensor held by `transition` to `device`.

    The transition dictionary is updated in place and also returned. `reward`, `done` and `truncated`
    are only moved when they are tensors. Numeric values found in `complementary_info` are converted to
    tensors on `device`.

    Args:
        transition (Transition): The transition to move.
        device (str): Target device.

    Returns:
        Transition: The same dictionary, with its tensors on `device`.

    Raises:
        ValueError: If `complementary_info` holds a value that is neither a tensor nor a number.
    """
    device = torch.device(device)
    # Asynchronous copies are only requested when the destination is a CUDA device.
    non_blocking = device.type == "cuda"

    def _to_device(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {key: val.to(device, non_blocking=non_blocking) for key, val in tensors.items()}

    transition["state"] = _to_device(transition["state"])
    transition["action"] = transition["action"].to(device, non_blocking=non_blocking)

    # These fields may be plain Python scalars, in which case there is nothing to move.
    for field in ("reward", "done", "truncated"):
        if isinstance(transition[field], torch.Tensor):
            transition[field] = transition[field].to(device, non_blocking=non_blocking)

    transition["next_state"] = _to_device(transition["next_state"])

    info = transition.get("complementary_info")
    if info is not None:
        # Only existing keys are reassigned, so mutating the dict while iterating is safe.
        for key, val in info.items():
            if isinstance(val, torch.Tensor):
                info[key] = val.to(device, non_blocking=non_blocking)
            elif isinstance(val, (int, float, bool)):
                info[key] = torch.tensor(val, device=device)
            else:
                raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
    return transition


def move_state_dict_to_device(state_dict, device="cpu"):
    """
    Recursively move all tensors in a (potentially) nested dict/list/tuple structure to `device`.

    Containers are rebuilt (dict, list and tuple); any other non-tensor value is returned unchanged.
    """
    if isinstance(state_dict, torch.Tensor):
        return state_dict.to(device)
    elif isinstance(state_dict, dict):
        return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
    elif isinstance(state_dict, list):
        return [move_state_dict_to_device(v, device=device) for v in state_dict]
    elif isinstance(state_dict, tuple):
        return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
    else:
        return state_dict


def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
    """Serialize a model state dict to bytes with `torch.save`, for transmission."""
    buffer = io.BytesIO()
    torch.save(state_dict, buffer)
    return buffer.getvalue()


def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
    """Inverse of `state_to_bytes`: load a state dict from bytes with `torch.load`."""
    # NOTE(security): torch.load relies on pickle; only feed it bytes coming from a trusted peer.
    return torch.load(io.BytesIO(buffer))


def python_object_to_bytes(python_object: Any) -> bytes:
    """Serialize an arbitrary Python object to bytes with `pickle`."""
    return pickle.dumps(python_object)


def bytes_to_python_object(buffer: bytes) -> Any:
    """Inverse of `python_object_to_bytes`: unpickle a Python object from bytes."""
    # NOTE(security): unpickling can execute arbitrary code; only use on bytes from a trusted peer.
    return pickle.loads(buffer)


def bytes_to_transitions(buffer: bytes) -> list[Transition]:
    """Inverse of `transitions_to_bytes`: load a list of transitions from bytes with `torch.load`."""
    # NOTE(security): torch.load relies on pickle; only feed it bytes coming from a trusted peer.
    return torch.load(io.BytesIO(buffer))


def transitions_to_bytes(transitions: list[Transition]) -> bytes:
    """Serialize a list of transitions to bytes with `torch.save`."""
    buffer = io.BytesIO()
    torch.save(transitions, buffer)
    return buffer.getvalue()


def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    """
    Perform a per-image random crop over a batch of images in a vectorized way.

    Each image of the batch gets its own, independently drawn, crop offset.

    Args:
        images (torch.Tensor): Batch of images of shape (B, C, H, W).
        output_size (tuple): Crop size as (crop_h, crop_w).

    Returns:
        torch.Tensor: Cropped images of shape (B, C, crop_h, crop_w).

    Raises:
        ValueError: If the requested crop is larger than the images.
    """
    B, C, H, W = images.shape  # noqa: N806
    crop_h, crop_w = output_size

    if crop_h > H or crop_w > W:
        raise ValueError(
            f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
        )

    # One top-left corner per image.
    tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)

    # Absolute row / column indices of every pixel of every crop.
    rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
    cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)

    rows = rows.unsqueeze(2).expand(-1, -1, crop_w)  # (B, crop_h, crop_w)
    cols = cols.unsqueeze(1).expand(-1, crop_h, -1)  # (B, crop_h, crop_w)

    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)

    # Gather pixels
    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
    # cropped_hwcn => (B, crop_h, crop_w, C)

    cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
    return cropped


def random_shift(images: torch.Tensor, pad: int = 4):
    """
    Vectorized random shift.

    The images are padded by `pad` pixels on each side (replicating the border) and then randomly
    cropped back to their original size.

    Args:
        images (torch.Tensor): Batch of images of shape (B, C, H, W).
        pad (int): Number of padding pixels on each side.
    """
    _, _, h, w = images.shape
    images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
    return random_crop_vectorized(images=images, output_size=(h, w))


class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions stored in pre-allocated tensors."""

    def __init__(
        self,
        capacity: int,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
        image_augmentation_function: Optional[Callable] = None,
        use_drq: bool = True,
        storage_device: str = "cpu",
        optimize_memory: bool = False,
    ):
        """
        Args:
            capacity (int): Maximum number of transitions to store in the buffer.
            device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
            state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
            image_augmentation_function (Optional[Callable]): A function that takes a batch of images
                and returns a batch of augmented images. If None, a compiled `random_shift` with
                `pad=4` is used.
            use_drq (bool): Whether to apply the image augmentation (DrQ style) when sampling.
            storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
                Using "cpu" can help save GPU memory.
            optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
                they can be derived from states. This is useful for large datasets where
                next_state[i] = state[i+1].
        """
        self.capacity = capacity
        self.device = device
        self.storage_device = storage_device
        self.position = 0
        self.size = 0
        self.initialized = False
        self.optimize_memory = optimize_memory

        # Track episode boundaries for memory optimization
        self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)

        # If no state_keys provided, default to an empty list
        self.state_keys = state_keys if state_keys is not None else []

        if image_augmentation_function is None:
            base_function = functools.partial(random_shift, pad=4)
            self.image_augmentation_function = torch.compile(base_function)
        else:
            # A user-supplied augmentation must be stored too, otherwise `sample` cannot find it.
            self.image_augmentation_function = image_augmentation_function
        self.use_drq = use_drq

    def _initialize_storage(
        self,
        state: dict[str, torch.Tensor],
        action: torch.Tensor,
        complementary_info: Optional[dict[str, torch.Tensor]] = None,
    ):
        """
        Initialize the storage tensors based on the first transition.

        Only the shapes of the inputs are used (after squeezing dimension 0); the storage is
        allocated on `self.storage_device` whatever device the inputs live on.

        Raises:
            ValueError: If `complementary_info` holds a value that is neither a tensor nor a number.
        """
        # Determine shapes from the first transition
        state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
        action_shape = action.squeeze(0).shape

        # Pre-allocate tensors for storage
        self.states = {
            key: torch.empty((self.capacity, *shape), device=self.storage_device)
            for key, shape in state_shapes.items()
        }
        self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
        self.rewards = torch.empty((self.capacity,), device=self.storage_device)

        if not self.optimize_memory:
            # Standard approach: store states and next_states separately
            self.next_states = {
                key: torch.empty((self.capacity, *shape), device=self.storage_device)
                for key, shape in state_shapes.items()
            }
        else:
            # Memory-optimized approach: don't allocate a next_states buffer,
            # just keep a reference to states for API consistency.
            self.next_states = self.states

        self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
        self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)

        # Initialize storage for complementary_info
        self.has_complementary_info = complementary_info is not None
        self.complementary_info_keys = []
        self.complementary_info = {}

        if self.has_complementary_info:
            self.complementary_info_keys = list(complementary_info.keys())
            # Pre-allocate tensors for each key in complementary_info
            for key, value in complementary_info.items():
                if isinstance(value, torch.Tensor):
                    value_shape = value.squeeze(0).shape
                    self.complementary_info[key] = torch.empty(
                        (self.capacity, *value_shape), device=self.storage_device
                    )
                elif isinstance(value, (int, float)):
                    # Handle scalar values similar to reward
                    self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
                else:
                    raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")

        self.initialized = True

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size

    def add(
        self,
        state: dict[str, torch.Tensor],
        action: torch.Tensor,
        reward: float,
        next_state: dict[str, torch.Tensor],
        done: bool,
        truncated: bool,
        complementary_info: Optional[dict[str, torch.Tensor]] = None,
    ):
        """
        Saves a transition, ensuring tensors are stored on the designated storage device.

        The storage is allocated on the first call. Once the buffer is full, the oldest transition is
        overwritten. When `optimize_memory` is enabled, `next_state` is not stored.
        """
        # Initialize storage if this is the first transition
        if not self.initialized:
            self._initialize_storage(state=state, action=action, complementary_info=complementary_info)

        # Store the transition in pre-allocated tensors
        for key in self.states:
            self.states[key][self.position].copy_(state[key].squeeze(dim=0))

            if not self.optimize_memory:
                # Only store next_states if not optimizing memory
                self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))

        self.actions[self.position].copy_(action.squeeze(dim=0))
        self.rewards[self.position] = reward
        self.dones[self.position] = done
        self.truncateds[self.position] = truncated

        # Handle complementary_info if provided and storage is initialized
        if complementary_info is not None and self.has_complementary_info:
            # Only the keys seen at initialization time are stored
            for key in self.complementary_info_keys:
                if key in complementary_info:
                    value = complementary_info[key]
                    if isinstance(value, torch.Tensor):
                        self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
                    elif isinstance(value, (int, float)):
                        self.complementary_info[key][self.position] = value

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int) -> BatchTransition:
        """
        Sample a random batch of transitions and collate them into batched tensors.

        Indices are drawn with replacement. At most `len(self)` transitions are returned. When
        `use_drq` is set, every state whose key starts with "observation.image" goes through the
        image augmentation function, for both `state` and `next_state`.

        Raises:
            RuntimeError: If the buffer has not been initialized, or does not hold enough transitions
                to sample from.
        """
        if not self.initialized:
            raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")

        batch_size = min(batch_size, self.size)
        # In memory-optimized mode the next state of a transition is read at the following index, so
        # the most recent transition cannot be sampled until the buffer has wrapped around.
        # NOTE(review): once the buffer is full, the newest transition's "next" index holds the
        # oldest stored state - confirm this is acceptable for the learner.
        high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
        if high <= 0:
            # torch.randint would fail on an empty range with a much less helpful message.
            raise RuntimeError(
                f"Not enough transitions in the buffer to sample from (size={self.size}). "
                "Add more transitions first."
            )

        # Random indices for sampling - create on the same device as storage
        idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)

        # Identify image keys that need augmentation
        image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []

        # Memory-optimized approach - next_state is read from the next index
        next_idx = (idx + 1) % self.capacity if self.optimize_memory else None

        # Create batched state and next_state
        batch_state = {}
        batch_next_state = {}

        # First pass: load all state tensors to target device
        for key in self.states:
            batch_state[key] = self.states[key][idx].to(self.device)

            if not self.optimize_memory:
                # Standard approach - load next_states directly
                batch_next_state[key] = self.next_states[key][idx].to(self.device)
            else:
                batch_next_state[key] = self.states[key][next_idx].to(self.device)

        # Apply image augmentation in a batched way if needed
        if image_keys:
            # Concatenate all images from state and next_state, interleaved per key
            all_images = []
            for key in image_keys:
                all_images.append(batch_state[key])
                all_images.append(batch_next_state[key])

            # Batch all images and apply augmentation once
            all_images_tensor = torch.cat(all_images, dim=0)
            augmented_images = self.image_augmentation_function(all_images_tensor)

            # Split the augmented images back to their sources
            for i, key in enumerate(image_keys):
                # State images are at even chunk indices (0, 2, 4...)
                batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
                # Next state images are at odd chunk indices (1, 3, 5...)
                batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]

        # Sample other tensors
        batch_actions = self.actions[idx].to(self.device)
        batch_rewards = self.rewards[idx].to(self.device)
        batch_dones = self.dones[idx].to(self.device).float()
        batch_truncateds = self.truncateds[idx].to(self.device).float()

        # Sample complementary_info if available
        batch_complementary_info = None
        if self.has_complementary_info:
            batch_complementary_info = {}
            for key in self.complementary_info_keys:
                batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)

        return BatchTransition(
            state=batch_state,
            action=batch_actions,
            reward=batch_rewards,
            next_state=batch_next_state,
            done=batch_dones,
            truncated=batch_truncateds,
            complementary_info=batch_complementary_info,
        )

    def get_iterator(
        self,
        batch_size: int,
        async_prefetch: bool = True,
        queue_size: int = 2,
    ):
        """
        Creates an infinite iterator that yields batches of transitions.
        Will automatically restart when internal iterator is exhausted.

        Args:
            batch_size (int): Size of batches to sample
            async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
            queue_size (int): Number of batches to prefetch (default: 2)

        Yields:
            BatchTransition: Batched transitions
        """
        while True:  # Create an infinite loop
            if async_prefetch:
                iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
            else:
                iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)

            # `yield from` returns normally when the inner iterator is exhausted; the outer loop
            # then simply creates a new one.
            yield from iterator

    def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
        """
        Creates an iterator that prefetches batches in a background thread.

        The iterator stops once the background thread has died (e.g. after a sampling error) and the
        queue is empty. The thread is asked to stop when the iterator is closed.

        Args:
            batch_size (int): Size of batches to sample
            queue_size (int): Number of batches to prefetch (default: 2)

        Yields:
            BatchTransition: Prefetched batch transitions
        """
        # Use thread-safe queue
        data_queue = queue.Queue(maxsize=queue_size)
        stop_event = threading.Event()

        def prefetch_worker():
            while not stop_event.is_set():
                try:
                    data = self.sample(batch_size)
                except Exception as e:
                    print(f"Prefetch error: {e}")
                    return

                # Keep offering the same batch until there is room, instead of throwing it away and
                # sampling again each time the queue is full.
                while not stop_event.is_set():
                    try:
                        data_queue.put(data, block=True, timeout=0.5)
                        break
                    except queue.Full:
                        continue

        # Start prefetching thread
        thread = threading.Thread(target=prefetch_worker, daemon=True)
        thread.start()

        try:
            while not stop_event.is_set():
                try:
                    yield data_queue.get(block=True, timeout=0.5)
                except queue.Empty:
                    if not thread.is_alive():
                        break
        finally:
            # Clean up
            stop_event.set()
            thread.join(timeout=1.0)

    def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
        """
        Creates a simple non-threaded iterator that yields batches.

        Args:
            batch_size (int): Size of batches to sample
            queue_size (int): Number of initial batches to prefetch

        Yields:
            BatchTransition: Batch transitions
        """
        pending = collections.deque()

        def enqueue(n):
            for _ in range(n):
                pending.append(self.sample(batch_size))

        enqueue(queue_size)
        while pending:
            yield pending.popleft()
            enqueue(1)

    @classmethod
    def from_lerobot_dataset(
        cls,
        lerobot_dataset: LeRobotDataset,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
        capacity: Optional[int] = None,
        action_mask: Optional[Sequence[int]] = None,
        image_augmentation_function: Optional[Callable] = None,
        use_drq: bool = True,
        storage_device: str = "cpu",
        optimize_memory: bool = False,
    ) -> "ReplayBuffer":
        """
        Convert a LeRobotDataset into a ReplayBuffer.

        Args:
            lerobot_dataset (LeRobotDataset): The dataset to convert.
            device (str): The device for sampling tensors. Defaults to "cuda:0".
            state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
            capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
            action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
            image_augmentation_function (Optional[Callable]): Function for image augmentation.
                If None, uses default random shift with pad=4.
            use_drq (bool): Whether to use DrQ image augmentation when sampling.
            storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
            optimize_memory (bool): If True, reduces memory usage by not duplicating state data.

        Returns:
            ReplayBuffer: The replay buffer with dataset transitions.

        Raises:
            ValueError: If `capacity` is smaller than the dataset, or if `state_keys` is None.
        """
        if capacity is None:
            capacity = len(lerobot_dataset)

        if capacity < len(lerobot_dataset):
            raise ValueError(
                "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
            )

        # Create replay buffer with image augmentation and DrQ settings
        replay_buffer = cls(
            capacity=capacity,
            device=device,
            state_keys=state_keys,
            image_augmentation_function=image_augmentation_function,
            use_drq=use_drq,
            storage_device=storage_device,
            optimize_memory=optimize_memory,
        )

        # Convert dataset to transitions
        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)

        # Initialize the buffer with the first transition to set up storage tensors
        if list_transition:
            first_transition = list_transition[0]
            # Only shapes are needed to allocate the storage, so nothing is moved to the sampling
            # device here (it may be a GPU that is not needed at this point).
            first_action = first_transition["action"]

            # Apply action mask/delta if needed
            if action_mask is not None:
                if first_action.dim() == 1:
                    first_action = first_action[action_mask]
                else:
                    first_action = first_action[:, action_mask]

            replay_buffer._initialize_storage(
                state=first_transition["state"],
                action=first_action,
                complementary_info=first_transition.get("complementary_info"),
            )

        # Fill the buffer with all transitions
        for data in list_transition:
            for k, v in data.items():
                if isinstance(v, dict):
                    for key, value in v.items():
                        # complementary_info may also hold plain Python numbers
                        if isinstance(value, torch.Tensor):
                            v[key] = value.to(storage_device)
                elif isinstance(v, torch.Tensor):
                    data[k] = v.to(storage_device)

            action = data["action"]
            if action_mask is not None:
                if action.dim() == 1:
                    action = action[action_mask]
                else:
                    action = action[:, action_mask]

            replay_buffer.add(
                state=data["state"],
                action=action,
                reward=data["reward"],
                next_state=data["next_state"],
                done=data["done"],
                truncated=False,  # NOTE: Truncation are not supported yet in lerobot dataset
                complementary_info=data.get("complementary_info", None),
            )

        return replay_buffer

    def to_lerobot_dataset(
        self,
        repo_id: str,
        fps=1,
        root=None,
        task_name="from_replay_buffer",
    ) -> LeRobotDataset:
        """
        Converts all transitions in this ReplayBuffer into a single LeRobotDataset object.

        Transitions are written from the oldest to the most recent one. A new episode is started
        after every transition flagged as done or truncated.

        Raises:
            ValueError: If the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")

        # Create features dictionary for the dataset
        features = {
            "index": {"dtype": "int64", "shape": [1]},  # global index across episodes
            "episode_index": {"dtype": "int64", "shape": [1]},  # which episode
            "frame_index": {"dtype": "int64", "shape": [1]},  # index inside an episode
            "timestamp": {"dtype": "float32", "shape": [1]},  # for now we store dummy
            "task_index": {"dtype": "int64", "shape": [1]},
        }

        # Add "action"
        sample_action = self.actions[0]
        act_info = guess_feature_info(t=sample_action, name="action")
        features["action"] = act_info

        # Add "reward" and "done"
        features["next.reward"] = {"dtype": "float32", "shape": (1,)}
        features["next.done"] = {"dtype": "bool", "shape": (1,)}

        # Add state keys
        for key in self.states:
            sample_val = self.states[key][0]
            f_info = guess_feature_info(t=sample_val, name=key)
            features[key] = f_info

        # Add complementary_info keys if available
        if self.has_complementary_info:
            for key in self.complementary_info_keys:
                sample_val = self.complementary_info[key][0]
                # Scalars are exposed as 1-element features
                if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
                    sample_val = sample_val.unsqueeze(0)
                f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
                features[f"complementary_info.{key}"] = f_info

        # Create an empty LeRobotDataset
        lerobot_dataset = LeRobotDataset.create(
            repo_id=repo_id,
            fps=fps,
            root=root,
            robot=None,  # TODO: (azouitine) Handle robot
            robot_type=None,
            features=features,
            use_videos=True,
        )

        # Start writing images if needed
        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)

        # Convert transitions into episodes and frames
        episode_index = 0
        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)

        for idx in range(self.size):
            # Walk the ring buffer starting from the oldest stored transition
            actual_idx = (self.position - self.size + idx) % self.capacity

            frame_dict = {}

            # Fill the data for state keys
            for key in self.states:
                frame_dict[key] = self.states[key][actual_idx].cpu()

            # Fill action, reward, done
            frame_dict["action"] = self.actions[actual_idx].cpu()
            frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
            frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()

            # Add complementary_info if available
            if self.has_complementary_info:
                for key in self.complementary_info_keys:
                    val = self.complementary_info[key][actual_idx]
                    # Convert tensors to CPU
                    if isinstance(val, torch.Tensor):
                        if val.ndim == 0:
                            val = val.unsqueeze(0)
                        frame_dict[f"complementary_info.{key}"] = val.cpu()
                    # Non-tensor values can be used directly
                    else:
                        frame_dict[f"complementary_info.{key}"] = val

            # Add task field which is required by LeRobotDataset
            frame_dict["task"] = task_name

            # Add to the dataset's buffer
            lerobot_dataset.add_frame(frame_dict)

            # If we reached an episode boundary, call save_episode and start a new episode buffer
            if self.dones[actual_idx] or self.truncateds[actual_idx]:
                lerobot_dataset.save_episode()
                episode_index += 1
                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
                    episode_index=episode_index
                )

        # Save any remaining frames in the buffer
        if lerobot_dataset.episode_buffer["size"] > 0:
            lerobot_dataset.save_episode()

        lerobot_dataset.stop_image_writer()

        return lerobot_dataset

    @staticmethod
    def _lerobotdataset_to_transitions(
        dataset: LeRobotDataset,
        state_keys: Optional[Sequence[str]] = None,
    ) -> list[Transition]:
        """
        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.

        Args:
            dataset (LeRobotDataset):
                The dataset to convert. Each item in the dataset is expected to have
                at least the following keys:
                {
                    "action": ...
                    "next.reward": ...
                    "next.done": ...
                    "episode_index": ...
                }
                plus whatever your 'state_keys' specify. If "next.done" is missing, the done flag is
                inferred from the episode boundaries.

            state_keys (Optional[Sequence[str]]):
                The dataset keys to include in 'state' and 'next_state'. Their names
                will be kept as-is in the output transitions. E.g.
                ["observation.state", "observation.environment_state"].

        Returns:
            transitions (List[Transition]):
                A list of Transition dictionaries with the same length as `dataset`. Every tensor
                carries an extra leading batch dimension of size 1.

        Raises:
            ValueError: If `state_keys` is None.
        """
        if state_keys is None:
            raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")

        transitions = []
        num_frames = len(dataset)
        if num_frames == 0:
            # Nothing to convert (and `dataset[0]` below would fail)
            return transitions

        # Check if the dataset has "next.done" key
        sample = dataset[0]
        has_done_key = "next.done" in sample

        # Check for complementary_info keys
        complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
        has_complementary_info = len(complementary_info_keys) > 0

        # If not, we need to infer it from episode boundaries
        if not has_done_key:
            print("'next.done' key not found in dataset. Inferring from episode boundaries...")

        for i in tqdm(range(num_frames)):
            current_sample = dataset[i]
            is_last_frame = i == num_frames - 1
            # The following frame is fetched lazily and at most once per iteration
            next_sample = None

            # ----- 1) Current state -----
            current_state: dict[str, torch.Tensor] = {
                key: current_sample[key].unsqueeze(0)  # Add batch dimension
                for key in state_keys
            }

            # ----- 2) Action -----
            action = current_sample["action"].unsqueeze(0)  # Add batch dimension

            # ----- 3) Reward and done -----
            reward = float(current_sample["next.reward"].item())  # ensure float

            # Determine done flag - use next.done if available, otherwise infer from episode boundaries
            if has_done_key:
                done = bool(current_sample["next.done"].item())  # ensure bool
            elif is_last_frame:
                done = True
            else:
                # Done if the next frame belongs to a different episode
                next_sample = dataset[i + 1]
                done = bool(next_sample["episode_index"] != current_sample["episode_index"])

            # TODO: (azouitine) Handle truncation (using the same value as done for now)
            truncated = done

            # ----- 4) Next state -----
            # If not done and the next sample is in the same episode, we pull the next sample's state.
            # Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
            next_state = current_state  # default
            if not done and not is_last_frame:
                if next_sample is None:
                    next_sample = dataset[i + 1]
                if next_sample["episode_index"] == current_sample["episode_index"]:
                    # Build next_state from the same keys
                    next_state = {
                        key: next_sample[key].unsqueeze(0)  # Add batch dimension
                        for key in state_keys
                    }

            # ----- 5) Complementary info (if available) -----
            complementary_info = None
            if has_complementary_info:
                complementary_info = {}
                for key in complementary_info_keys:
                    # Strip the "complementary_info." prefix to get the actual key
                    clean_key = key[len("complementary_info.") :]
                    val = current_sample[key]
                    # Handle tensor and non-tensor values differently
                    if isinstance(val, torch.Tensor):
                        complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
                    else:
                        # TODO: (azouitine) Check if it's necessary to convert to tensor
                        # For non-tensor values, use directly
                        complementary_info[clean_key] = val

            # ----- Construct the Transition -----
            transition = Transition(
                state=current_state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
                truncated=truncated,
                complementary_info=complementary_info,
            )
            transitions.append(transition)

        return transitions


# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t, name: str):
    """
    Return a dictionary with the 'dtype' and 'shape' for a given tensor.

    A tensor with exactly 3 dimensions whose first dimension is 1 or 3 is reported as an 'image'.
    Anything else is reported as 'float32'. `name` is currently not used.
    """
    shape = tuple(t.shape)
    # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
    if len(shape) == 3 and shape[0] in [1, 3]:
        return {
            "dtype": "image",
            "shape": shape,
        }
    else:
        # Otherwise treat as numeric
        return {
            "dtype": "float32",
            "shape": shape,
        }


def concatenate_batch_transitions(
    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
    """
    Concatenate two batches of transitions along the batch dimension.

    NOTE: Be careful, `left_batch_transitions` is modified in place (and returned).

    `state` and `next_state` are concatenated for the keys present in the left batch. For
    `complementary_info`, keys present on both sides are concatenated and keys only present in the
    right batch are copied over.
    """
    # Concatenate state fields
    left_batch_transitions["state"] = {
        key: torch.cat(
            [left_batch_transitions["state"][key], right_batch_transition["state"][key]],
            dim=0,
        )
        for key in left_batch_transitions["state"]
    }

    # Concatenate basic fields
    left_batch_transitions["action"] = torch.cat(
        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
    )
    left_batch_transitions["reward"] = torch.cat(
        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
    )

    # Concatenate next_state fields
    left_batch_transitions["next_state"] = {
        key: torch.cat(
            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
            dim=0,
        )
        for key in left_batch_transitions["next_state"]
    }

    # Concatenate done and truncated fields
    left_batch_transitions["done"] = torch.cat(
        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
    )
    left_batch_transitions["truncated"] = torch.cat(
        [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
        dim=0,
    )

    # Handle complementary_info
    left_info = left_batch_transitions.get("complementary_info")
    right_info = right_batch_transition.get("complementary_info")

    # Only process if right_info exists
    if right_info is not None:
        # Initialize left complementary_info if needed
        if left_info is None:
            left_batch_transitions["complementary_info"] = right_info
        else:
            # Concatenate each field
            for key in right_info:
                if key in left_info:
                    left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
                else:
                    left_info[key] = right_info[key]

    return left_batch_transitions


if __name__ == "__main__":

    def test_load_dataset_with_complementary_info():
        """
        Test loading a dataset with complementary_info into a ReplayBuffer.
        The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains gripper_penalty values
        in complementary_info.
        """
        import time

        from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

        print("Loading dataset with complementary info...")

        # Load the dataset
        dataset = LeRobotDataset(
            repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty",
        )

        print(f"Dataset loaded with {len(dataset)} frames")
        print(f"Dataset features: {list(dataset.features.keys())}")

        # Check if dataset has complementary_info.gripper_penalty
        sample = dataset[0]
        complementary_info_keys = [key for key in sample if key.startswith("complementary_info")]
        print(f"Complementary info keys: {complementary_info_keys}")

        if "complementary_info.gripper_penalty" in sample:
            print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}")

        # Extract state keys for the buffer
        state_keys = []
        for key in sample:
            if key.startswith("observation"):
                state_keys.append(key)

        print(f"Using state keys: {state_keys}")

        # Create a replay buffer from the dataset
        start_time = time.time()
        buffer = ReplayBuffer.from_lerobot_dataset(
            lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False
        )
        load_time = time.time() - start_time
        print(f"Loaded dataset into buffer in {load_time:.2f} seconds")
        print(f"Buffer size: {len(buffer)}")

        # Check if complementary_info was transferred correctly
        print("Sampling from buffer to check complementary_info...")
        batch = buffer.sample(batch_size=4)

        if batch["complementary_info"] is not None:
            print("Complementary info in batch:")
            for key, value in batch["complementary_info"].items():
                print(f" {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}")
                if key == "gripper_penalty":
                    print(f" Sample gripper_penalty values: {value[:5]}")
        else:
            print("No complementary_info found in batch")

        # Now convert the buffer back to a LeRobotDataset
        print("\nConverting buffer back to LeRobotDataset...")
        start_time = time.time()
        new_dataset = buffer.to_lerobot_dataset(
            repo_id="test_dataset_from_buffer",
            fps=dataset.fps,
            root="./test_dataset_from_buffer",
            task_name="test_conversion",
        )
        convert_time = time.time() - start_time
        print(f"Converted buffer to dataset in {convert_time:.2f} seconds")
        print(f"New dataset size: {len(new_dataset)} frames")

        # Check if complementary_info was preserved
        new_sample = new_dataset[0]
        new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")]
        print(f"New dataset complementary info keys: {new_complementary_info_keys}")

        if "complementary_info.gripper_penalty" in new_sample:
            print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}")

        # Compare original and new datasets
        print("\nComparing original and new datasets:")
        print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}")
        print(f"Original features: {list(dataset.features.keys())}")
        print(f"New features: {list(new_dataset.features.keys())}")

        return buffer, dataset, new_dataset

    # Run the test
    test_load_dataset_with_complementary_info()