Refactor ReplayBuffer with tensor-based storage and improved sampling efficiency
- Replaced list-based memory storage with pre-allocated tensor storage
- Optimized sampling with direct tensor indexing
- Added support for DrQ image augmentation during sampling for offline datasets
- Improved dataset conversion with more robust episode handling
- Enhanced buffer initialization and state tracking
- Added comprehensive testing for buffer conversion and sampling
parent 42a038173f
commit ef8d943e54
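A minimal usage sketch of the refactored buffer (illustrative, not part of the commit); the shapes, capacity, and devices below are assumptions:

    import torch

    buffer = ReplayBuffer(
        capacity=10_000,
        device="cuda:0",  # device that sample() moves batches to
        state_keys=["observation.image", "observation.state"],
        use_drq=True,  # augment "observation.image*" keys when sampling
        storage_device="cpu",  # pre-allocated storage tensors live here
    )
    buffer.add(
        state={
            "observation.image": torch.rand(1, 3, 84, 84),
            "observation.state": torch.rand(1, 10),
        },
        action=torch.rand(1, 6),
        reward=1.0,
        next_state={
            "observation.image": torch.rand(1, 3, 84, 84),
            "observation.state": torch.rand(1, 10),
        },
        done=False,
        truncated=False,
    )
    # ...after filling with more transitions:
    batch = buffer.sample(256)  # BatchTransition of batched tensors on "cuda:0"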
@@ -23,6 +23,7 @@ import torch.nn.functional as F  # noqa: N812
 from tqdm import tqdm
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+import os
 
 
 class Transition(TypedDict):
@@ -181,29 +182,58 @@ class ReplayBuffer:
         """
         Args:
             capacity (int): Maximum number of transitions to store in the buffer.
-            device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
+            device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
             state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
             image_augmentation_function (Optional[Callable]): A function that takes a batch of images
                 and returns a batch of augmented images. If None, a default augmentation function is used.
             use_drq (bool): Whether to use the default DrQ image augmentation style when sampling from the buffer.
-            storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
+            storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
+                Using "cpu" can help save GPU memory.
         """
         self.capacity = capacity
         self.device = device
         self.storage_device = storage_device
-        self.memory: list[Transition] = []
         self.position = 0
+        self.size = 0
+        self.initialized = False
 
         # If no state_keys provided, default to an empty list
         # (you can handle this differently if needed)
         self.state_keys = state_keys if state_keys is not None else []
 
         if image_augmentation_function is None:
-            self.image_augmentation_function = functools.partial(random_shift, pad=4)
+            base_function = functools.partial(random_shift, pad=4)
+            self.image_augmentation_function = torch.compile(base_function)
         self.use_drq = use_drq
 
+    def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor):
+        """Initialize the storage tensors based on the first transition."""
+        # Determine shapes from the first transition
+        state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
+        action_shape = action.squeeze(0).shape
+
+        # Pre-allocate tensors for storage
+        self.states = {
+            key: torch.empty((self.capacity, *shape), device=self.storage_device)
+            for key, shape in state_shapes.items()
+        }
+        self.actions = torch.empty(
+            (self.capacity, *action_shape), device=self.storage_device
+        )
+        self.rewards = torch.empty((self.capacity,), device=self.storage_device)
+        self.next_states = {
+            key: torch.empty((self.capacity, *shape), device=self.storage_device)
+            for key, shape in state_shapes.items()
+        }
+        self.dones = torch.empty(
+            (self.capacity,), dtype=torch.bool, device=self.storage_device
+        )
+        self.truncateds = torch.empty(
+            (self.capacity,), dtype=torch.bool, device=self.storage_device
+        )
+        self.initialized = True
+
     def __len__(self):
-        return len(self.memory)
+        return self.size
 
     def add(
         self,
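Note that `_initialize_storage` pre-allocates two full-capacity tensors per state key (`states` and `next_states`). A quick back-of-the-envelope sketch (illustrative, not part of the commit; float32 assumed, since `torch.empty` defaults to it) of why `storage_device="cpu"` matters for image observations:

    capacity = 10_000
    n_elems = capacity * 3 * 84 * 84  # one (3, 84, 84) image key
    bytes_total = 2 * n_elems * 4  # states + next_states, 4 bytes per float32
    print(f"{bytes_total / 1e9:.2f} GB")  # ~1.69 GB for a single image key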
@@ -216,33 +246,91 @@ class ReplayBuffer:
         complementary_info: Optional[dict[str, torch.Tensor]] = None,
     ):
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
+        # Move tensors to the storage device
+        state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
+        next_state = {
+            key: tensor.to(self.storage_device) for key, tensor in next_state.items()
+        }
+        action = action.to(self.storage_device)
+        # if complementary_info is not None:
+        #     complementary_info = {
+        #         key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
+        #     }
+        # Initialize storage if this is the first transition
+        if not self.initialized:
+            self._initialize_storage(state=state, action=action)
 
-        if len(self.memory) < self.capacity:
-            self.memory.append(None)
+        # Store the transition in pre-allocated tensors
+        for key in self.states:
+            self.states[key][self.position].copy_(state[key].squeeze(dim=0))
+            self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
+
+        self.actions[self.position].copy_(action.squeeze(dim=0))
+        self.rewards[self.position] = reward
+        self.dones[self.position] = done
+        self.truncateds[self.position] = truncated
 
-        # Create and store the Transition
-        self.memory[self.position] = Transition(
-            state=state,
-            action=action,
-            reward=reward,
-            next_state=next_state,
-            done=done,
-            truncated=truncated,
-            complementary_info=complementary_info,
-        )
         self.position = (self.position + 1) % self.capacity
+        self.size = min(self.size + 1, self.capacity)
 
+    def sample(self, batch_size: int) -> BatchTransition:
+        """Sample a random batch of transitions and collate them into batched tensors."""
+        if not self.initialized:
+            raise RuntimeError(
+                "Cannot sample from an empty buffer. Add transitions first."
+            )
+
+        batch_size = min(batch_size, self.size)
+
+        # Random indices for sampling - create on the same device as storage
+        idx = torch.randint(
+            low=0, high=self.size, size=(batch_size,), device=self.storage_device
+        )
+
+        # Identify image keys that need augmentation
+        image_keys = (
+            [k for k in self.states if k.startswith("observation.image")]
+            if self.use_drq
+            else []
+        )
+
+        # Create batched state and next_state
+        batch_state = {}
+        batch_next_state = {}
+
+        # First pass: load all tensors to target device
+        for key in self.states:
+            batch_state[key] = self.states[key][idx].to(self.device)
+            batch_next_state[key] = self.next_states[key][idx].to(self.device)
+
+        # Apply image augmentation in a batched way if needed
+        if self.use_drq and image_keys:
+            # Concatenate all images from state and next_state
+            all_images = []
+            for key in image_keys:
+                all_images.append(batch_state[key])
+                all_images.append(batch_next_state[key])
+
+            # Batch all images and apply augmentation once
+            all_images_tensor = torch.cat(all_images, dim=0)
+            augmented_images = self.image_augmentation_function(all_images_tensor)
+
+            # Split the augmented images back to their sources
+            for i, key in enumerate(image_keys):
+                # State images for key i sit in even-numbered blocks (0, 2, 4, ...)
+                batch_state[key] = augmented_images[
+                    i * 2 * batch_size : (i * 2 + 1) * batch_size
+                ]
+                # Next-state images for key i sit in odd-numbered blocks (1, 3, 5, ...)
+                batch_next_state[key] = augmented_images[
+                    (i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size
+                ]
+
+        # Sample other tensors
+        batch_actions = self.actions[idx].to(self.device)
+        batch_rewards = self.rewards[idx].to(self.device)
+        batch_dones = self.dones[idx].to(self.device).float()
+        batch_truncateds = self.truncateds[idx].to(self.device).float()
+
+        return BatchTransition(
+            state=batch_state,
+            action=batch_actions,
+            reward=batch_rewards,
+            next_state=batch_next_state,
+            done=batch_dones,
+            truncated=batch_truncateds,
+        )
+
+    # TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
     @classmethod
     def from_lerobot_dataset(
         cls,
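The `random_shift` function that the constructor compiles is not shown in this diff; a typical DrQ-style pad-and-random-crop looks like the sketch below (an assumption about its behavior, not the repository's exact implementation):

    import torch
    import torch.nn.functional as F  # noqa: N812

    def random_shift(images: torch.Tensor, pad: int = 4) -> torch.Tensor:
        # Replicate-pad an (N, C, H, W) batch, then crop back to (H, W) at a
        # random offset per image, shifting content by up to `pad` pixels.
        n, _, h, w = images.shape
        padded = F.pad(images, (pad, pad, pad, pad), mode="replicate")
        xs = torch.randint(0, 2 * pad + 1, (n,))
        ys = torch.randint(0, 2 * pad + 1, (n,))
        return torch.stack(
            [padded[i, :, ys[i] : ys[i] + h, xs[i] : xs[i] + w] for i in range(n)]
        )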
@@ -252,21 +340,28 @@ class ReplayBuffer:
         capacity: Optional[int] = None,
         action_mask: Optional[Sequence[int]] = None,
         action_delta: Optional[float] = None,
+        image_augmentation_function: Optional[Callable] = None,
+        use_drq: bool = True,
+        storage_device: str = "cpu",
     ) -> "ReplayBuffer":
         """
        Convert a LeRobotDataset into a ReplayBuffer.
 
         Args:
             lerobot_dataset (LeRobotDataset): The dataset to convert.
-            device (str): The device . Defaults to "cuda:0".
-            state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
-                Defaults to None.
+            device (str): The device for sampling tensors. Defaults to "cuda:0".
+            state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
+            capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
+            action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
+            action_delta (Optional[float]): Factor to divide actions by.
+            image_augmentation_function (Optional[Callable]): Function for image augmentation.
+                If None, uses default random shift with pad=4.
+            use_drq (bool): Whether to use DrQ image augmentation when sampling.
+            storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
 
         Returns:
-            ReplayBuffer: The replay buffer with offline dataset transitions.
+            ReplayBuffer: The replay buffer with dataset transitions.
         """
         # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
         # a replay buffer than from a lerobot dataset.
         if capacity is None:
             capacity = len(lerobot_dataset)
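For reference, a sketch of calling the updated classmethod (the repo id and key names below are placeholders, not from the commit):

    dataset = LeRobotDataset(repo_id="user/some_dataset")  # placeholder repo id
    buffer = ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        device="cuda:0",
        state_keys=["observation.image", "observation.state"],
        action_mask=[0, 1, 2],  # keep only the first three action dimensions
        action_delta=2.0,  # divide every action by 2.0
        use_drq=True,
        storage_device="cpu",
    )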
@@ -275,11 +370,42 @@ class ReplayBuffer:
                 "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
             )
 
-        replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
+        # Create replay buffer with image augmentation and DrQ settings
+        replay_buffer = cls(
+            capacity=capacity,
+            device=device,
+            state_keys=state_keys,
+            image_augmentation_function=image_augmentation_function,
+            use_drq=use_drq,
+            storage_device=storage_device,
+        )
 
+        # Convert dataset to transitions
         list_transition = cls._lerobotdataset_to_transitions(
             dataset=lerobot_dataset, state_keys=state_keys
         )
-        # Fill the replay buffer with the lerobot dataset transitions
 
+        # Initialize the buffer with the first transition to set up storage tensors
+        if list_transition:
+            first_transition = list_transition[0]
+            first_state = {
+                k: v.to(device) for k, v in first_transition["state"].items()
+            }
+            first_action = first_transition["action"].to(device)
+
+            # Apply action mask/delta if needed
+            if action_mask is not None:
+                if first_action.dim() == 1:
+                    first_action = first_action[action_mask]
+                else:
+                    first_action = first_action[:, action_mask]
+
+            if action_delta is not None:
+                first_action = first_action / action_delta
+
+            replay_buffer._initialize_storage(state=first_state, action=first_action)
+
+        # Fill the buffer with all transitions
         for data in list_transition:
             for k, v in data.items():
                 if isinstance(v, dict):
@@ -288,25 +414,127 @@ class ReplayBuffer:
                 elif isinstance(v, torch.Tensor):
                     data[k] = v.to(device)
 
+            action = data["action"]
             if action_mask is not None:
-                if data["action"].dim() == 1:
-                    data["action"] = data["action"][action_mask]
+                if action.dim() == 1:
+                    action = action[action_mask]
                 else:
-                    data["action"] = data["action"][:, action_mask]
+                    action = action[:, action_mask]
 
             if action_delta is not None:
-                data["action"] = data["action"] / action_delta
+                action = action / action_delta
 
             replay_buffer.add(
                 state=data["state"],
-                action=data["action"],
+                action=action,
                 reward=data["reward"],
                 next_state=data["next_state"],
                 done=data["done"],
-                truncated=False,
+                truncated=False,  # NOTE: Truncation is not supported yet in lerobot dataset
             )
 
         return replay_buffer
 
+    def to_lerobot_dataset(
+        self,
+        repo_id: str,
+        fps=1,
+        root=None,
+        task_name="from_replay_buffer",
+    ) -> LeRobotDataset:
+        """
+        Converts all transitions in this ReplayBuffer into a single LeRobotDataset object.
+        """
+        if self.size == 0:
+            raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
+
+        # Create features dictionary for the dataset
+        features = {
+            "index": {"dtype": "int64", "shape": [1]},  # global index across episodes
+            "episode_index": {"dtype": "int64", "shape": [1]},  # which episode
+            "frame_index": {"dtype": "int64", "shape": [1]},  # index inside an episode
+            "timestamp": {"dtype": "float32", "shape": [1]},  # for now we store dummy
+            "task_index": {"dtype": "int64", "shape": [1]},
+        }
+
+        # Add "action"
+        sample_action = self.actions[0]
+        act_info = guess_feature_info(t=sample_action, name="action")
+        features["action"] = act_info
+
+        # Add "reward" and "done"
+        features["next.reward"] = {"dtype": "float32", "shape": (1,)}
+        features["next.done"] = {"dtype": "bool", "shape": (1,)}
+
+        # Add state keys
+        for key in self.states:
+            sample_val = self.states[key][0]
+            f_info = guess_feature_info(t=sample_val, name=key)
+            features[key] = f_info
+
+        # Create an empty LeRobotDataset
+        lerobot_dataset = LeRobotDataset.create(
+            repo_id=repo_id,
+            fps=fps,
+            root=root,
+            robot=None,  # TODO: (azouitine) Handle robot
+            robot_type=None,
+            features=features,
+            use_videos=True,
+        )
+
+        # Start writing images if needed
+        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
+
+        # Convert transitions into episodes and frames
+        episode_index = 0
+        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
+            episode_index=episode_index
+        )
+
+        frame_idx_in_episode = 0
+        for idx in range(self.size):
+            actual_idx = (self.position - self.size + idx) % self.capacity
+
+            frame_dict = {}
+
+            # Fill the data for state keys
+            for key in self.states:
+                frame_dict[key] = self.states[key][actual_idx].cpu()
+
+            # Fill action, reward, done
+            frame_dict["action"] = self.actions[actual_idx].cpu()
+            frame_dict["next.reward"] = torch.tensor(
+                [self.rewards[actual_idx]], dtype=torch.float32
+            ).cpu()
+            frame_dict["next.done"] = torch.tensor(
+                [self.dones[actual_idx]], dtype=torch.bool
+            ).cpu()
+
+            # Add to the dataset's buffer
+            lerobot_dataset.add_frame(frame_dict)
+
+            # Move to next frame
+            frame_idx_in_episode += 1
+
+            # If we reached an episode boundary, call save_episode, reset counters
+            if self.dones[actual_idx] or self.truncateds[actual_idx]:
+                lerobot_dataset.save_episode(task=task_name)
+                episode_index += 1
+                frame_idx_in_episode = 0
+                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
+                    episode_index=episode_index
+                )
+
+        # Save any remaining frames in the buffer
+        if lerobot_dataset.episode_buffer["size"] > 0:
+            lerobot_dataset.save_episode(task=task_name)
+
+        lerobot_dataset.stop_image_writer()
+        lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
+
+        return lerobot_dataset
+
     @staticmethod
     def _lerobotdataset_to_transitions(
         dataset: LeRobotDataset,
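The `actual_idx` arithmetic in the export loop above recovers insertion order once the ring buffer has wrapped; a small worked example (illustrative, not part of the commit):

    # capacity=5 buffer after 7 adds: position=2, size=5.
    capacity, position, size = 5, 2, 5
    order = [(position - size + idx) % capacity for idx in range(size)]
    print(order)  # [2, 3, 4, 0, 1] -> slots visited oldest to newest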
@@ -337,16 +565,24 @@ class ReplayBuffer:
             transitions (List[Transition]):
                 A list of Transition dictionaries with the same length as `dataset`.
         """
 
         # If not provided, you can either raise an error or define a default:
         if state_keys is None:
             raise ValueError(
-                "You must provide a list of keys in `state_keys` that define your 'state'."
+                "State keys must be provided when converting LeRobotDataset to Transitions."
             )
 
-        transitions: list[Transition] = []
+        transitions = []
         num_frames = len(dataset)
 
+        # Check if the dataset has "next.done" key
+        sample = dataset[0]
+        has_done_key = "next.done" in sample
+
+        # If not, we need to infer it from episode boundaries
+        if not has_done_key:
+            print(
+                "'next.done' key not found in dataset. Inferring from episode boundaries..."
+            )
+
         for i in tqdm(range(num_frames)):
             current_sample = dataset[i]
@@ -361,9 +597,22 @@ class ReplayBuffer:
 
             # ----- 3) Reward and done -----
             reward = float(current_sample["next.reward"].item())  # ensure float
-            done = bool(current_sample["next.done"].item())  # ensure bool
-            # TODO: (azouitine) Handle truncation properly
-            truncated = bool(current_sample["next.done"].item())  # ensure bool
+
+            # Determine done flag - use next.done if available, otherwise infer from episode boundaries
+            if has_done_key:
+                done = bool(current_sample["next.done"].item())  # ensure bool
+            else:
+                # If this is the last frame or if next frame is in a different episode, mark as done
+                done = False
+                if i == num_frames - 1:
+                    done = True
+                elif i < num_frames - 1:
+                    next_sample = dataset[i + 1]
+                    if next_sample["episode_index"] != current_sample["episode_index"]:
+                        done = True
+
+            # TODO: (azouitine) Handle truncation (using the same value as done for now)
+            truncated = done
 
             # ----- 4) Next state -----
             # If not done and the next sample is in the same episode, we pull the next sample's state.
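The boundary inference above can be checked on a toy episode layout (illustrative, not part of the commit):

    episode_index = [0, 0, 0, 1, 1, 2]  # per-frame episode ids
    num_frames = len(episode_index)
    dones = [
        i == num_frames - 1 or episode_index[i + 1] != episode_index[i]
        for i in range(num_frames)
    ]
    print(dones)  # [False, False, True, False, True, True]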
@@ -392,206 +641,6 @@ class ReplayBuffer:
 
         return transitions
 
-    def sample(self, batch_size: int) -> BatchTransition:
-        """Sample a random batch of transitions and collate them into batched tensors."""
-        batch_size = min(batch_size, len(self.memory))
-        list_of_transitions = random.sample(self.memory, batch_size)
-
-        # -- Build batched states --
-        batch_state = {}
-        for key in self.state_keys:
-            batch_state[key] = torch.cat(
-                [t["state"][key] for t in list_of_transitions], dim=0
-            ).to(self.device)
-            if key.startswith("observation.image") and self.use_drq:
-                batch_state[key] = self.image_augmentation_function(batch_state[key])
-
-        # -- Build batched actions --
-        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
-            self.device
-        )
-
-        # -- Build batched rewards --
-        batch_rewards = torch.tensor(
-            [t["reward"] for t in list_of_transitions], dtype=torch.float32
-        ).to(self.device)
-
-        # -- Build batched next states --
-        batch_next_state = {}
-        for key in self.state_keys:
-            batch_next_state[key] = torch.cat(
-                [t["next_state"][key] for t in list_of_transitions], dim=0
-            ).to(self.device)
-            if key.startswith("observation.image") and self.use_drq:
-                batch_next_state[key] = self.image_augmentation_function(
-                    batch_next_state[key]
-                )
-
-        # -- Build batched dones --
-        batch_dones = torch.tensor(
-            [t["done"] for t in list_of_transitions], dtype=torch.float32
-        ).to(self.device)
-
-        # -- Build batched truncateds --
-        batch_truncateds = torch.tensor(
-            [t["truncated"] for t in list_of_transitions], dtype=torch.float32
-        ).to(self.device)
-
-        # Return a BatchTransition typed dict
-        return BatchTransition(
-            state=batch_state,
-            action=batch_actions,
-            reward=batch_rewards,
-            next_state=batch_next_state,
-            done=batch_dones,
-            truncated=batch_truncateds,
-        )
-
-    def to_lerobot_dataset(
-        self,
-        repo_id: str,
-        fps=1,  # If you have real timestamps, adjust this
-        root=None,
-        task_name="from_replay_buffer",
-    ) -> LeRobotDataset:
-        """
-        Converts all transitions in this ReplayBuffer into a single LeRobotDataset object,
-        splitting episodes by transitions where 'done=True'.
-
-        Returns:
-            LeRobotDataset: The resulting offline dataset.
-        """
-        if len(self.memory) == 0:
-            raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
-
-        # Infer the shapes and dtypes of your features
-        # We'll create a features dict that is suitable for LeRobotDataset
-        # --------------------------------------------------------------------------------------------
-        # First, grab one transition to inspect shapes
-        first_transition = self.memory[0]
-
-        # We'll store default metadata for every episode: indexes, timestamps, etc.
-        features = {
-            "index": {"dtype": "int64", "shape": [1]},  # global index across episodes
-            "episode_index": {"dtype": "int64", "shape": [1]},  # which episode
-            "frame_index": {"dtype": "int64", "shape": [1]},  # index inside an episode
-            "timestamp": {"dtype": "float32", "shape": [1]},  # for now we store dummy
-            "task_index": {"dtype": "int64", "shape": [1]},
-        }
-
-        # Add "action"
-        act_info = guess_feature_info(
-            first_transition["action"].squeeze(dim=0), "action"
-        )  # Remove batch dimension
-        features["action"] = act_info
-
-        # Add "reward" (scalars)
-        features["next.reward"] = {"dtype": "float32", "shape": (1,)}
-
-        # Add "done" (boolean scalars)
-        features["next.done"] = {"dtype": "bool", "shape": (1,)}
-
-        # Add state keys
-        for key in self.state_keys:
-            sample_val = first_transition["state"][key].squeeze(
-                dim=0
-            )  # Remove batch dimension
-            if not isinstance(sample_val, torch.Tensor):
-                raise ValueError(
-                    f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
-                )
-            f_info = guess_feature_info(sample_val, key)
-            features[key] = f_info
-
-        # --------------------------------------------------------------------------------------------
-        # Create an empty LeRobotDataset
-        # We'll store all frames as separate images only if we detect shape = (3, H, W) or (1, H, W).
-        # By default we won't do videos, but feel free to adapt if you have them.
-        # --------------------------------------------------------------------------------------------
-        lerobot_dataset = LeRobotDataset.create(
-            repo_id=repo_id,
-            fps=fps,  # If you have real timestamps, adjust this
-            root=root,  # Or some local path where you'd like the dataset files to go
-            robot=None,
-            robot_type=None,
-            features=features,
-            use_videos=True,  # We won't do actual video encoding for a replay buffer
-        )
-
-        # Start writing images if needed. If you have no image features, this is harmless.
-        # Set num_processes or num_threads if you want concurrency.
-        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
-
-        # --------------------------------------------------------------------------------------------
-        # Convert transitions into episodes and frames
-        # We detect episode boundaries by `done == True`.
-        # --------------------------------------------------------------------------------------------
-        episode_index = 0
-        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
-            episode_index
-        )
-
-        frame_idx_in_episode = 0
-        for global_frame_idx, transition in tqdm(
-            enumerate(self.memory),
-            desc="Converting replay buffer to dataset",
-            total=len(self.memory),
-        ):
-            frame_dict = {}
-
-            # Fill the data for state keys
-            for key in self.state_keys:
-                # Expand dimension to match what the dataset expects (the dataset wants the raw shape)
-                # We assume your buffer has shape [C, H, W] (if image) or [D] if vector
-                # This is typically already correct, but if needed you can reshape below.
-                frame_dict[key] = (
-                    transition["state"][key].cpu().squeeze(dim=0)
-                )  # Remove batch dimension
-
-            # Fill action, reward, done
-            # Make sure they are shape (X,) or (X,Y,...) as needed.
-            frame_dict["action"] = (
-                transition["action"].cpu().squeeze(dim=0)
-            )  # Remove batch dimension
-            frame_dict["next.reward"] = (
-                torch.tensor([transition["reward"]], dtype=torch.float32)
-                .cpu()
-                .squeeze(dim=0)
-            )
-            frame_dict["next.done"] = (
-                torch.tensor([transition["done"]], dtype=torch.bool)
-                .cpu()
-                .squeeze(dim=0)
-            )
-            # Add to the dataset's buffer
-            lerobot_dataset.add_frame(frame_dict)
-
-            # Move to next frame
-            frame_idx_in_episode += 1
-            # If we reached an episode boundary, call save_episode, reset counters
-            # TODO: (azouitine) Handle truncation properly
-            if transition["done"] or transition["truncated"]:
-                # Use some placeholder name for the task
-                lerobot_dataset.save_episode(task=task_name)
-                episode_index += 1
-                frame_idx_in_episode = 0
-                # Start a new buffer for the next episode
-                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
-                    episode_index=episode_index
-                )
-
-        # We are done adding frames
-        # If the last transition wasn't done=True, we still have an open buffer with frames.
-        # We'll consider that an incomplete episode and still save it:
-        if lerobot_dataset.episode_buffer["size"] > 0:
-            lerobot_dataset.save_episode(task=task_name)
-
-        lerobot_dataset.stop_image_writer()
-
-        lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
-
-        return lerobot_dataset
 
 
 # Utility function to guess shapes/dtypes from a tensor
 def guess_feature_info(t: torch.Tensor, name: str):
@@ -655,32 +704,308 @@ def concatenate_batch_transitions(
     return left_batch_transitions
 
 
-# if __name__ == "__main__":
-#     dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
-#     dataset = LeRobotDataset(repo_id=dataset_name)
+if __name__ == "__main__":
+    import numpy as np
+    from tempfile import TemporaryDirectory
 
-#     replay_buffer = ReplayBuffer.from_lerobot_dataset(
-#         lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
-#     )
-#     replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
-#     for i in range(len(replay_buffer_converted)):
-#         replay_convert = replay_buffer_converted[i]
-#         dataset_convert = dataset[i]
-#         for key in replay_convert.keys():
-#             if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
-#                 continue
-#             if key in dataset_convert.keys():
-#                 assert torch.equal(replay_convert[key], dataset_convert[key])
-#                 print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
-#     re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
-#         replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
-#     )
-#     for _ in range(20):
-#         batch = re_reconverted_dataset.sample(32)
+    # ===== Test 1: Create and use a synthetic ReplayBuffer =====
+    print("Testing synthetic ReplayBuffer...")
 
-#         for key in batch.keys():
-#             if key in {"state", "next_state"}:
-#                 for key_state in batch[key].keys():
-#                     print(key_state, batch[key][key_state].size())
-#                 continue
-#             print(key, batch[key].size())
+    # Create sample data dimensions
+    batch_size = 32
+    state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
+    action_dim = (6,)
+
+    # Create a buffer
+    buffer = ReplayBuffer(
+        capacity=1000,
+        device="cpu",
+        state_keys=list(state_dims.keys()),
+        use_drq=True,
+        storage_device="cpu",
+    )
+
+    # Add some random transitions
+    for i in range(100):
+        # Create dummy transition data
+        state = {
+            "observation.image": torch.rand(1, 3, 84, 84),
+            "observation.state": torch.rand(1, 10),
+        }
+        action = torch.rand(1, 6)
+        reward = 0.5
+        next_state = {
+            "observation.image": torch.rand(1, 3, 84, 84),
+            "observation.state": torch.rand(1, 10),
+        }
+        done = False if i < 99 else True
+        truncated = False
+
+        buffer.add(
+            state=state,
+            action=action,
+            reward=reward,
+            next_state=next_state,
+            done=done,
+            truncated=truncated,
+        )
+
+    # Test sampling
+    batch = buffer.sample(batch_size)
+    print(f"Buffer size: {len(buffer)}")
+    print(
+        f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
+    )
+    print(f"Sampled batch action shape: {batch['action'].shape}")
+    print(f"Sampled batch reward shape: {batch['reward'].shape}")
+    print(f"Sampled batch done shape: {batch['done'].shape}")
+    print(f"Sampled batch truncated shape: {batch['truncated'].shape}")
+
+    # ===== Test for state-action-reward alignment =====
+    print("\nTesting state-action-reward alignment...")
+
+    # Create a buffer with controlled transitions where we know the relationships
+    aligned_buffer = ReplayBuffer(
+        capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
+    )
+
+    # Create transitions with known relationships
+    # - Each state has a unique signature value
+    # - Action is 2x the state signature
+    # - Reward is 3x the state signature
+    # - Next state is signature + 0.01 (unless at episode end)
+    for i in range(100):
+        # Create a state with a signature value that encodes the transition number
+        signature = float(i) / 100.0
+        state = {"state_value": torch.tensor([[signature]]).float()}
+
+        # Action is 2x the signature
+        action = torch.tensor([[2.0 * signature]]).float()
+
+        # Reward is 3x the signature
+        reward = 3.0 * signature
+
+        # Next state is signature + 0.01, unless end of episode
+        # End episode every 10 steps
+        is_end = (i + 1) % 10 == 0
+
+        if is_end:
+            # At episode boundaries, next_state repeats the current state (as per the implementation)
+            next_state = {"state_value": torch.tensor([[signature]]).float()}
+            done = True
+        else:
+            # Within episodes, next_state has signature + 0.01
+            next_signature = float(i + 1) / 100.0
+            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
+            done = False
+
+        aligned_buffer.add(state, action, reward, next_state, done, False)
+
+    # Sample from this buffer
+    aligned_batch = aligned_buffer.sample(50)
+
+    # Verify alignments in sampled batch
+    correct_relationships = 0
+    total_checks = 0
+
+    # For each transition in the batch
+    for i in range(50):
+        # Extract signature from state
+        state_sig = aligned_batch["state"]["state_value"][i].item()
+
+        # Check action is 2x signature (within reasonable precision)
+        action_val = aligned_batch["action"][i].item()
+        action_check = abs(action_val - 2.0 * state_sig) < 1e-4
+
+        # Check reward is 3x signature (within reasonable precision)
+        reward_val = aligned_batch["reward"][i].item()
+        reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4
+
+        # Check next_state relationship matches our pattern
+        next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
+        is_done = aligned_batch["done"][i].item() > 0.5
+
+        # Calculate expected next_state value based on done flag
+        if is_done:
+            # For episodes that end, next_state should equal state
+            next_state_check = abs(next_state_sig - state_sig) < 1e-4
+        else:
+            # For continuing episodes, check if next_state is approximately state + 0.01
+            # We need to be careful because we don't know the original index
+            # So we check if the increment is roughly 0.01
+            next_state_check = (
+                abs(next_state_sig - state_sig - 0.01) < 1e-4
+                or abs(next_state_sig - state_sig) < 1e-4
+            )
+
+        # Count correct relationships
+        if action_check:
+            correct_relationships += 1
+        if reward_check:
+            correct_relationships += 1
+        if next_state_check:
+            correct_relationships += 1
+
+        total_checks += 3
+
+    alignment_accuracy = 100.0 * correct_relationships / total_checks
+    print(
+        f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%"
+    )
+    if alignment_accuracy > 99.0:
+        print(
+            "✅ All relationships verified! Buffer maintains correct temporal relationships."
+        )
+    else:
+        print(
+            "⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues."
+        )
+
+        # Print some debug information about failures
+        print("\nDebug information for failed checks:")
+        for i in range(5):  # Print first 5 transitions for debugging
+            state_sig = aligned_batch["state"]["state_value"][i].item()
+            action_val = aligned_batch["action"][i].item()
+            reward_val = aligned_batch["reward"][i].item()
+            next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
+            is_done = aligned_batch["done"][i].item() > 0.5
+
+            print(f"Transition {i}:")
+            print(f"  State: {state_sig:.6f}")
+            print(f"  Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})")
+            print(f"  Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})")
+            print(f"  Done: {is_done}")
+            print(f"  Next state: {next_state_sig:.6f}")
+
+            # Calculate expected next state
+            if is_done:
+                expected_next = state_sig
+            else:
+                # This approximation might not be perfect
+                state_idx = round(state_sig * 100)
+                expected_next = (state_idx + 1) / 100.0
+
+            print(f"  Expected next state: {expected_next:.6f}")
+            print()
+
+    # ===== Test 2: Convert to LeRobotDataset and back =====
+    with TemporaryDirectory() as temp_dir:
+        print("\nTesting conversion to LeRobotDataset and back...")
+        # Convert buffer to dataset
+        repo_id = "test/replay_buffer_conversion"
+        # Create a subdirectory to avoid the "directory exists" error
+        dataset_dir = os.path.join(temp_dir, "dataset1")
+        dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir)
+
+        print(f"Dataset created with {len(dataset)} frames")
+        print(f"Dataset features: {list(dataset.features.keys())}")
+
+        # Check a random sample from the dataset
+        sample = dataset[0]
+        print(
+            f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}"
+        )
+
+        # Convert dataset back to buffer
+        reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
+            dataset, state_keys=list(state_dims.keys()), device="cpu"
+        )
+
+        print(f"Reconverted buffer size: {len(reconverted_buffer)}")
+
+        # Sample from the reconverted buffer
+        reconverted_batch = reconverted_buffer.sample(batch_size)
+        print(
+            f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}"
+        )
+
+        # Verify consistency before and after conversion
+        original_states = batch["state"]["observation.image"].mean().item()
+        reconverted_states = (
+            reconverted_batch["state"]["observation.image"].mean().item()
+        )
+        print(f"Original buffer state mean: {original_states:.4f}")
+        print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
+
+        if abs(original_states - reconverted_states) < 1.0:
+            print("Values are reasonably similar - conversion works as expected")
+        else:
+            print(
+                "WARNING: Significant difference between original and reconverted values"
+            )
+
+    print("\nTesting real LeRobotDataset conversion...")
+    try:
+        # Try to use a real dataset if available
+        dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
+        dataset = LeRobotDataset(repo_id=dataset_name)
+
+        # Print available keys to debug
+        sample = dataset[0]
+        print("Available keys in first dataset:", list(sample.keys()))
+
+        # Check for required keys
+        if "action" not in sample or "next.reward" not in sample:
+            print("Dataset missing essential keys. Cannot convert.")
+            raise ValueError("Missing required keys in dataset")
+
+        # Auto-detect appropriate state keys
+        image_keys = []
+        state_keys = []
+        for k, v in sample.items():
+            # Skip metadata keys and action/reward keys
+            if k in {
+                "index",
+                "episode_index",
+                "frame_index",
+                "timestamp",
+                "task_index",
+                "action",
+                "next.reward",
+                "next.done",
+            }:
+                continue
+
+            # Infer key type from tensor shape
+            if isinstance(v, torch.Tensor):
+                if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
+                    # Likely an image (channels, height, width)
+                    image_keys.append(k)
+                else:
+                    # Likely state or other vector
+                    state_keys.append(k)
+
+        print(f"Detected image keys: {image_keys}")
+        print(f"Detected state keys: {state_keys}")
+
+        if not image_keys and not state_keys:
+            print("No usable keys found in dataset, skipping further tests")
+            raise ValueError("No usable keys found in dataset")
+
+        # Convert to ReplayBuffer with detected keys
+        replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            lerobot_dataset=dataset,
+            state_keys=image_keys + state_keys,
+            device="cpu",
+        )
+        print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
+
+        # Test sampling
+        real_batch = replay_buffer.sample(batch_size)
+        print("Sampled batch from real dataset, state shapes:")
+        for key in real_batch["state"]:
+            print(f"  {key}: {real_batch['state'][key].shape}")
+
+        # Convert back to LeRobotDataset
+        with TemporaryDirectory() as temp_dir:
+            replay_buffer_converted = replay_buffer.to_lerobot_dataset(
+                repo_id="test/real_dataset_converted",
+                root=os.path.join(temp_dir, "dataset2"),
+            )
+            print(
+                f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
+            )
+
+    except Exception as e:
+        print(f"Real dataset test failed: {e}")
+        print("This is expected if running offline or if the dataset is not available.")