Refactor ReplayBuffer with tensor-based storage and improved sampling efficiency

- Replaced list-based memory storage with pre-allocated tensor storage
- Optimized sampling process with direct tensor indexing
- Added support for DrQ image augmentation during sampling for offline datasets
- Improved dataset conversion with more robust episode handling
- Enhanced buffer initialization and state tracking
- Added comprehensive testing for buffer conversion and sampling
AdilZouitine 2025-02-25 14:26:44 +00:00
parent 2c799508d7
commit 7c366e3223
1 changed file with 602 additions and 277 deletions

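The two storage bullets above boil down to the following shift in sampling mechanics. This is an illustrative sketch with made-up dimensions, not code from this commit:

import torch

# After the refactor: one pre-allocated tensor per field, reused for the buffer's lifetime.
capacity, obs_dim, batch_size = 1000, 10, 32
storage = torch.empty((capacity, obs_dim))
# Sampling becomes a single advanced-indexing op instead of random.sample over a Python list.
idx = torch.randint(low=0, high=capacity, size=(batch_size,))
batch = storage[idx]  # shape (32, 10), no per-transition torch.cat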

@@ -23,6 +23,7 @@ import torch.nn.functional as F  # noqa: N812
from tqdm import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import os


class Transition(TypedDict):
@@ -181,29 +182,58 @@ class ReplayBuffer:
        """
        Args:
            capacity (int): Maximum number of transitions to store in the buffer.
            device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
            state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
            image_augmentation_function (Optional[Callable]): A function that takes a batch of images
                and returns a batch of augmented images. If None, a default augmentation function is used.
            use_drq (bool): Whether to use the default DrQ image augmentation style when sampling from the buffer.
            storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
                Using "cpu" can help save GPU memory.
        """
        self.capacity = capacity
        self.device = device
        self.storage_device = storage_device
        self.position = 0
        self.size = 0
        self.initialized = False

        # If no state_keys provided, default to an empty list
        self.state_keys = state_keys if state_keys is not None else []

        if image_augmentation_function is None:
            base_function = functools.partial(random_shift, pad=4)
            self.image_augmentation_function = torch.compile(base_function)
        self.use_drq = use_drq
    def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor):
        """Initialize the storage tensors based on the first transition."""
        # Determine shapes from the first transition
        state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
        action_shape = action.squeeze(0).shape

        # Pre-allocate tensors for storage
        self.states = {
            key: torch.empty((self.capacity, *shape), device=self.storage_device)
            for key, shape in state_shapes.items()
        }
        self.actions = torch.empty(
            (self.capacity, *action_shape), device=self.storage_device
        )
        self.rewards = torch.empty((self.capacity,), device=self.storage_device)
        self.next_states = {
            key: torch.empty((self.capacity, *shape), device=self.storage_device)
            for key, shape in state_shapes.items()
        }
        self.dones = torch.empty(
            (self.capacity,), dtype=torch.bool, device=self.storage_device
        )
        self.truncateds = torch.empty(
            (self.capacity,), dtype=torch.bool, device=self.storage_device
        )
        self.initialized = True

    def __len__(self):
        return self.size
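Storage is allocated lazily, so the first add() fixes the per-key shapes for the life of the buffer. A minimal sketch of that contract, assuming the class above is in scope and using hypothetical shapes:

# Hypothetical shapes; the leading batch dim of 1 is squeezed away by _initialize_storage.
buffer = ReplayBuffer(capacity=10, device="cpu", state_keys=["observation.state"], storage_device="cpu")
first_state = {"observation.state": torch.rand(1, 4)}
buffer.add(state=first_state, action=torch.rand(1, 2), reward=0.0,
           next_state=first_state, done=False, truncated=False)
assert buffer.states["observation.state"].shape == (10, 4)  # (capacity, *feature_shape)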
    def add(
        self,
@@ -216,33 +246,91 @@ class ReplayBuffer:
        complementary_info: Optional[dict[str, torch.Tensor]] = None,
    ):
        """Saves a transition, ensuring tensors are stored on the designated storage device."""
        # Initialize storage if this is the first transition
        if not self.initialized:
            self._initialize_storage(state=state, action=action)

        # Store the transition in pre-allocated tensors
        for key in self.states:
            self.states[key][self.position].copy_(state[key].squeeze(dim=0))
            self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))

        self.actions[self.position].copy_(action.squeeze(dim=0))
        self.rewards[self.position] = reward
        self.dones[self.position] = done
        self.truncateds[self.position] = truncated

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
    def sample(self, batch_size: int) -> BatchTransition:
        """Sample a random batch of transitions and collate them into batched tensors."""
        if not self.initialized:
            raise RuntimeError(
                "Cannot sample from an empty buffer. Add transitions first."
            )

        batch_size = min(batch_size, self.size)

        # Random indices for sampling - create on the same device as storage
        idx = torch.randint(
            low=0, high=self.size, size=(batch_size,), device=self.storage_device
        )

        # Identify image keys that need augmentation
        image_keys = (
            [k for k in self.states if k.startswith("observation.image")]
            if self.use_drq
            else []
        )

        # Create batched state and next_state
        batch_state = {}
        batch_next_state = {}

        # First pass: load all tensors to target device
        for key in self.states:
            batch_state[key] = self.states[key][idx].to(self.device)
            batch_next_state[key] = self.next_states[key][idx].to(self.device)

        # Apply image augmentation in a batched way if needed
        if self.use_drq and image_keys:
            # Concatenate all images from state and next_state
            all_images = []
            for key in image_keys:
                all_images.append(batch_state[key])
                all_images.append(batch_next_state[key])

            # Batch all images and apply augmentation once
            all_images_tensor = torch.cat(all_images, dim=0)
            augmented_images = self.image_augmentation_function(all_images_tensor)

            # Split the augmented images back to their sources
            for i, key in enumerate(image_keys):
                # The state chunk for key i occupies even chunk position 2*i
                batch_state[key] = augmented_images[
                    i * 2 * batch_size : (i * 2 + 1) * batch_size
                ]
                # Its next-state chunk follows at odd chunk position 2*i + 1
                batch_next_state[key] = augmented_images[
                    (i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size
                ]

        # Sample other tensors
        batch_actions = self.actions[idx].to(self.device)
        batch_rewards = self.rewards[idx].to(self.device)
        batch_dones = self.dones[idx].to(self.device).float()
        batch_truncateds = self.truncateds[idx].to(self.device).float()

        return BatchTransition(
            state=batch_state,
            action=batch_actions,
            reward=batch_rewards,
            next_state=batch_next_state,
            done=batch_dones,
            truncated=batch_truncateds,
        )
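The interleaved concatenate/split in sample() is the easiest place for an off-by-one to hide; here is a standalone check of the slice arithmetic, with hypothetical key names and a dummy batch size:

import torch

batch_size = 4
image_keys = ["observation.image", "observation.image.wrist"]  # hypothetical keys
chunks = []
for i, _key in enumerate(image_keys):
    # The state chunk for key i is tagged 2*i, its next-state chunk 2*i + 1.
    chunks.append(torch.full((batch_size, 1), float(2 * i)))
    chunks.append(torch.full((batch_size, 1), float(2 * i + 1)))
stacked = torch.cat(chunks, dim=0)
for i, _key in enumerate(image_keys):
    state_chunk = stacked[i * 2 * batch_size : (i * 2 + 1) * batch_size]
    next_chunk = stacked[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
    assert (state_chunk == 2 * i).all() and (next_chunk == 2 * i + 1).all()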
    @classmethod
    def from_lerobot_dataset(
        cls,
@@ -252,21 +340,28 @@ class ReplayBuffer:
        capacity: Optional[int] = None,
        action_mask: Optional[Sequence[int]] = None,
        action_delta: Optional[float] = None,
        image_augmentation_function: Optional[Callable] = None,
        use_drq: bool = True,
        storage_device: str = "cpu",
    ) -> "ReplayBuffer":
""" """
Convert a LeRobotDataset into a ReplayBuffer. Convert a LeRobotDataset into a ReplayBuffer.
Args: Args:
lerobot_dataset (LeRobotDataset): The dataset to convert. lerobot_dataset (LeRobotDataset): The dataset to convert.
device (str): The device . Defaults to "cuda:0". device (str): The device for sampling tensors. Defaults to "cuda:0".
state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`. state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
Defaults to None. capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
action_delta (Optional[float]): Factor to divide actions by.
image_augmentation_function (Optional[Callable]): Function for image augmentation.
If None, uses default random shift with pad=4.
use_drq (bool): Whether to use DrQ image augmentation when sampling.
storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
Returns: Returns:
ReplayBuffer: The replay buffer with offline dataset transitions. ReplayBuffer: The replay buffer with dataset transitions.
""" """
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset.
if capacity is None: if capacity is None:
capacity = len(lerobot_dataset) capacity = len(lerobot_dataset)
@@ -275,11 +370,42 @@ class ReplayBuffer:
                "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
            )

        # Create replay buffer with image augmentation and DrQ settings
        replay_buffer = cls(
            capacity=capacity,
            device=device,
            state_keys=state_keys,
            image_augmentation_function=image_augmentation_function,
            use_drq=use_drq,
            storage_device=storage_device,
        )

        # Convert dataset to transitions
        list_transition = cls._lerobotdataset_to_transitions(
            dataset=lerobot_dataset, state_keys=state_keys
        )

        # Initialize the buffer with the first transition to set up storage tensors
        if list_transition:
            first_transition = list_transition[0]
            first_state = {
                k: v.to(device) for k, v in first_transition["state"].items()
            }
            first_action = first_transition["action"].to(device)

            # Apply action mask/delta if needed
            if action_mask is not None:
                if first_action.dim() == 1:
                    first_action = first_action[action_mask]
                else:
                    first_action = first_action[:, action_mask]

            if action_delta is not None:
                first_action = first_action / action_delta

            replay_buffer._initialize_storage(state=first_state, action=first_action)

        # Fill the buffer with all transitions
        for data in list_transition:
            for k, v in data.items():
                if isinstance(v, dict):
@@ -288,25 +414,127 @@ class ReplayBuffer:
                elif isinstance(v, torch.Tensor):
                    data[k] = v.to(device)

            action = data["action"]
            if action_mask is not None:
                if action.dim() == 1:
                    action = action[action_mask]
                else:
                    action = action[:, action_mask]

            if action_delta is not None:
                action = action / action_delta

            replay_buffer.add(
                state=data["state"],
                action=action,
                reward=data["reward"],
                next_state=data["next_state"],
                done=data["done"],
                truncated=False,  # NOTE: Truncation is not supported yet in the lerobot dataset
            )
        return replay_buffer
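A typical call, sketched with a placeholder repo_id (everything else mirrors the signature above):

# Placeholder repo_id; keep large image storage in CPU RAM, sample onto the GPU.
offline_buffer = ReplayBuffer.from_lerobot_dataset(
    lerobot_dataset=LeRobotDataset(repo_id="user/some_dataset"),
    state_keys=["observation.image", "observation.state"],
    device="cuda:0",
    use_drq=True,
    storage_device="cpu",
)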
    def to_lerobot_dataset(
        self,
        repo_id: str,
        fps=1,
        root=None,
        task_name="from_replay_buffer",
    ) -> LeRobotDataset:
        """
        Converts all transitions in this ReplayBuffer into a single LeRobotDataset object.
        """
        if self.size == 0:
            raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")

        # Create features dictionary for the dataset
        features = {
            "index": {"dtype": "int64", "shape": [1]},  # global index across episodes
            "episode_index": {"dtype": "int64", "shape": [1]},  # which episode
            "frame_index": {"dtype": "int64", "shape": [1]},  # index inside an episode
            "timestamp": {"dtype": "float32", "shape": [1]},  # dummy value for now
            "task_index": {"dtype": "int64", "shape": [1]},
        }

        # Add "action"
        sample_action = self.actions[0]
        act_info = guess_feature_info(t=sample_action, name="action")
        features["action"] = act_info

        # Add "reward" and "done"
        features["next.reward"] = {"dtype": "float32", "shape": (1,)}
        features["next.done"] = {"dtype": "bool", "shape": (1,)}

        # Add state keys
        for key in self.states:
            sample_val = self.states[key][0]
            f_info = guess_feature_info(t=sample_val, name=key)
            features[key] = f_info

        # Create an empty LeRobotDataset
        lerobot_dataset = LeRobotDataset.create(
            repo_id=repo_id,
            fps=fps,
            root=root,
            robot=None,  # TODO: (azouitine) Handle robot
            robot_type=None,
            features=features,
            use_videos=True,
        )

        # Start writing images if needed
        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)

        # Convert transitions into episodes and frames
        episode_index = 0
        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
            episode_index=episode_index
        )

        frame_idx_in_episode = 0
        for idx in range(self.size):
            actual_idx = (self.position - self.size + idx) % self.capacity

            frame_dict = {}

            # Fill the data for state keys
            for key in self.states:
                frame_dict[key] = self.states[key][actual_idx].cpu()

            # Fill action, reward, done
            frame_dict["action"] = self.actions[actual_idx].cpu()
            frame_dict["next.reward"] = torch.tensor(
                [self.rewards[actual_idx]], dtype=torch.float32
            ).cpu()
            frame_dict["next.done"] = torch.tensor(
                [self.dones[actual_idx]], dtype=torch.bool
            ).cpu()

            # Add to the dataset's buffer
            lerobot_dataset.add_frame(frame_dict)

            # Move to next frame
            frame_idx_in_episode += 1

            # If we reached an episode boundary, call save_episode, reset counters
            if self.dones[actual_idx] or self.truncateds[actual_idx]:
                lerobot_dataset.save_episode(task=task_name)
                episode_index += 1
                frame_idx_in_episode = 0
                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
                    episode_index=episode_index
                )

        # Save any remaining frames in the buffer
        if lerobot_dataset.episode_buffer["size"] > 0:
            lerobot_dataset.save_episode(task=task_name)

        lerobot_dataset.stop_image_writer()
        lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
        return lerobot_dataset
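The actual_idx arithmetic above replays the ring buffer in insertion order even after wraparound; a quick standalone check with made-up numbers:

capacity, size, position = 5, 5, 2  # full buffer whose oldest frame sits at index 2
order = [(position - size + idx) % capacity for idx in range(size)]
assert order == [2, 3, 4, 0, 1]  # oldest -> newest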
    @staticmethod
    def _lerobotdataset_to_transitions(
        dataset: LeRobotDataset,
@@ -337,16 +565,24 @@ class ReplayBuffer:
            transitions (List[Transition]):
                A list of Transition dictionaries with the same length as `dataset`.
        """
        if state_keys is None:
            raise ValueError(
                "State keys must be provided when converting LeRobotDataset to Transitions."
            )

        transitions = []
        num_frames = len(dataset)

        # Check if the dataset has "next.done" key
        sample = dataset[0]
        has_done_key = "next.done" in sample

        # If not, we need to infer it from episode boundaries
        if not has_done_key:
            print(
                "'next.done' key not found in dataset. Inferring from episode boundaries..."
            )

        for i in tqdm(range(num_frames)):
            current_sample = dataset[i]
@@ -361,9 +597,22 @@ class ReplayBuffer:
            # ----- 3) Reward and done -----
            reward = float(current_sample["next.reward"].item())  # ensure float

            # Determine done flag - use next.done if available, otherwise infer from episode boundaries
            if has_done_key:
                done = bool(current_sample["next.done"].item())  # ensure bool
            else:
                # If this is the last frame or if the next frame is in a different episode, mark as done
                done = False
                if i == num_frames - 1:
                    done = True
                elif i < num_frames - 1:
                    next_sample = dataset[i + 1]
                    if next_sample["episode_index"] != current_sample["episode_index"]:
                        done = True

            # TODO: (azouitine) Handle truncation (using the same value as done for now)
            truncated = done

            # ----- 4) Next state -----
            # If not done and the next sample is in the same episode, we pull the next sample's state.
@@ -392,206 +641,6 @@ class ReplayBuffer:
        return transitions
# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t: torch.Tensor, name: str):
@@ -655,32 +704,308 @@ def concatenate_batch_transitions(
    return left_batch_transitions
if __name__ == "__main__":
    import numpy as np
    from tempfile import TemporaryDirectory

    # ===== Test 1: Create and use a synthetic ReplayBuffer =====
    print("Testing synthetic ReplayBuffer...")

    # Create sample data dimensions
    batch_size = 32
    state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
    action_dim = (6,)

    # Create a buffer
    buffer = ReplayBuffer(
        capacity=1000,
        device="cpu",
        state_keys=list(state_dims.keys()),
        use_drq=True,
        storage_device="cpu",
    )

    # Add some random transitions
    for i in range(100):
        # Create dummy transition data
        state = {
            "observation.image": torch.rand(1, 3, 84, 84),
            "observation.state": torch.rand(1, 10),
        }
        action = torch.rand(1, 6)
        reward = 0.5
        next_state = {
            "observation.image": torch.rand(1, 3, 84, 84),
            "observation.state": torch.rand(1, 10),
        }
        done = i == 99  # end the episode on the last transition
        truncated = False

        buffer.add(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            truncated=truncated,
        )

    # Test sampling
    batch = buffer.sample(batch_size)
    print(f"Buffer size: {len(buffer)}")
    print(
        f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
    )
    print(f"Sampled batch action shape: {batch['action'].shape}")
    print(f"Sampled batch reward shape: {batch['reward'].shape}")
    print(f"Sampled batch done shape: {batch['done'].shape}")
    print(f"Sampled batch truncated shape: {batch['truncated'].shape}")
    # ===== Test for state-action-reward alignment =====
    print("\nTesting state-action-reward alignment...")

    # Create a buffer with controlled transitions where we know the relationships
    aligned_buffer = ReplayBuffer(
        capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
    )

    # Create transitions with known relationships:
    # - Each state has a unique signature value
    # - Action is 2x the state signature
    # - Reward is 3x the state signature
    # - Next state is signature + 0.01 (unless at episode end)
    for i in range(100):
        # Create a state with a signature value that encodes the transition number
        signature = float(i) / 100.0
        state = {"state_value": torch.tensor([[signature]]).float()}

        # Action is 2x the signature
        action = torch.tensor([[2.0 * signature]]).float()

        # Reward is 3x the signature
        reward = 3.0 * signature

        # Next state is signature + 0.01, unless this is the end of an episode
        # (episodes end every 10 steps)
        is_end = (i + 1) % 10 == 0
        if is_end:
            # At episode boundaries, next_state repeats the current state
            next_state = {"state_value": torch.tensor([[signature]]).float()}
            done = True
        else:
            # Within episodes, next_state has signature + 0.01
            next_signature = float(i + 1) / 100.0
            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
            done = False

        aligned_buffer.add(state, action, reward, next_state, done, False)
    # Sample from this buffer
    aligned_batch = aligned_buffer.sample(50)

    # Verify alignments in the sampled batch
    correct_relationships = 0
    total_checks = 0

    # For each transition in the batch
    for i in range(50):
        # Extract signature from state
        state_sig = aligned_batch["state"]["state_value"][i].item()

        # Check action is 2x signature (within reasonable precision)
        action_val = aligned_batch["action"][i].item()
        action_check = abs(action_val - 2.0 * state_sig) < 1e-4

        # Check reward is 3x signature (within reasonable precision)
        reward_val = aligned_batch["reward"][i].item()
        reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4

        # Check next_state relationship matches our pattern
        next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
        is_done = aligned_batch["done"][i].item() > 0.5

        # Calculate the expected next_state value based on the done flag
        if is_done:
            # For episodes that end, next_state should equal state
            next_state_check = abs(next_state_sig - state_sig) < 1e-4
        else:
            # For continuing episodes, next_state should be approximately state + 0.01.
            # We don't know the original index here, so accept an increment of
            # roughly 0.01 (or an unchanged value).
            next_state_check = (
                abs(next_state_sig - state_sig - 0.01) < 1e-4
                or abs(next_state_sig - state_sig) < 1e-4
            )

        # Count correct relationships
        if action_check:
            correct_relationships += 1
        if reward_check:
            correct_relationships += 1
        if next_state_check:
            correct_relationships += 1

        total_checks += 3
    alignment_accuracy = 100.0 * correct_relationships / total_checks
    print(
        f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%"
    )
    if alignment_accuracy > 99.0:
        print(
            "✅ All relationships verified! Buffer maintains correct temporal relationships."
        )
    else:
        print(
            "⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues."
        )

        # Print some debug information about failures
        print("\nDebug information for failed checks:")
        for i in range(5):  # Print first 5 transitions for debugging
            state_sig = aligned_batch["state"]["state_value"][i].item()
            action_val = aligned_batch["action"][i].item()
            reward_val = aligned_batch["reward"][i].item()
            next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
            is_done = aligned_batch["done"][i].item() > 0.5

            print(f"Transition {i}:")
            print(f"  State: {state_sig:.6f}")
            print(f"  Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})")
            print(f"  Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})")
            print(f"  Done: {is_done}")
            print(f"  Next state: {next_state_sig:.6f}")

            # Calculate the expected next state
            if is_done:
                expected_next = state_sig
            else:
                # This approximation might not be perfect
                state_idx = round(state_sig * 100)
                expected_next = (state_idx + 1) / 100.0
            print(f"  Expected next state: {expected_next:.6f}")
            print()
    # ===== Test 2: Convert to LeRobotDataset and back =====
    with TemporaryDirectory() as temp_dir:
        print("\nTesting conversion to LeRobotDataset and back...")
        # Convert buffer to dataset
        repo_id = "test/replay_buffer_conversion"
        # Create a subdirectory to avoid the "directory exists" error
        dataset_dir = os.path.join(temp_dir, "dataset1")
        dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir)

        print(f"Dataset created with {len(dataset)} frames")
        print(f"Dataset features: {list(dataset.features.keys())}")

        # Check a random sample from the dataset
        sample = dataset[0]
        print(
            f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}"
        )

        # Convert dataset back to buffer
        reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
            dataset, state_keys=list(state_dims.keys()), device="cpu"
        )
        print(f"Reconverted buffer size: {len(reconverted_buffer)}")

        # Sample from the reconverted buffer
        reconverted_batch = reconverted_buffer.sample(batch_size)
        print(
            f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}"
        )

        # Verify consistency before and after conversion
        original_states = batch["state"]["observation.image"].mean().item()
        reconverted_states = (
            reconverted_batch["state"]["observation.image"].mean().item()
        )
        print(f"Original buffer state mean: {original_states:.4f}")
        print(f"Reconverted buffer state mean: {reconverted_states:.4f}")

        if abs(original_states - reconverted_states) < 1.0:
            print("Values are reasonably similar - conversion works as expected")
        else:
            print(
                "WARNING: Significant difference between original and reconverted values"
            )
print("\nTesting real LeRobotDataset conversion...")
try:
# Try to use a real dataset if available
dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
dataset = LeRobotDataset(repo_id=dataset_name)
# Print available keys to debug
sample = dataset[0]
print("Available keys in first dataset:", list(sample.keys()))
# Check for required keys
if "action" not in sample or "next.reward" not in sample:
print("Dataset missing essential keys. Cannot convert.")
raise ValueError("Missing required keys in dataset")
# Auto-detect appropriate state keys
image_keys = []
state_keys = []
for k, v in sample.items():
# Skip metadata keys and action/reward keys
if k in {
"index",
"episode_index",
"frame_index",
"timestamp",
"task_index",
"action",
"next.reward",
"next.done",
}:
continue
# Infer key type from tensor shape
if isinstance(v, torch.Tensor):
if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
# Likely an image (channels, height, width)
image_keys.append(k)
else:
# Likely state or other vector
state_keys.append(k)
print(f"Detected image keys: {image_keys}")
print(f"Detected state keys: {state_keys}")
if not image_keys and not state_keys:
print("No usable keys found in dataset, skipping further tests")
raise ValueError("No usable keys found in dataset")
# Convert to ReplayBuffer with detected keys
replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
state_keys=image_keys + state_keys,
device="cpu",
)
print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
# Test sampling
real_batch = replay_buffer.sample(batch_size)
print("Sampled batch from real dataset, state shapes:")
for key in real_batch["state"]:
print(f" {key}: {real_batch['state'][key].shape}")
# Convert back to LeRobotDataset
with TemporaryDirectory() as temp_dir:
replay_buffer_converted = replay_buffer.to_lerobot_dataset(
repo_id="test/real_dataset_converted",
root=os.path.join(temp_dir, "dataset2"),
)
print(
f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
)
except Exception as e:
print(f"Real dataset test failed: {e}")
print("This is expected if running offline or if the dataset is not available.")