#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import io
import os
import pickle
from typing import Any, Callable, Optional, Sequence, TypedDict

import torch
import torch.nn.functional as F  # noqa: N812
from tqdm import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


class Transition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: float
    next_state: dict[str, torch.Tensor]
    done: bool
    truncated: bool
    # TypedDict fields cannot take default values; None marks "no extra info"
    complementary_info: Optional[dict[str, Any]]


class BatchTransition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: torch.Tensor
    next_state: dict[str, torch.Tensor]
    done: torch.Tensor
    truncated: torch.Tensor


def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
    """Move every tensor in a Transition to the target device (non-blocking for CUDA)."""
    device = torch.device(device)
    non_blocking = device.type == "cuda"

    # Move state tensors to the target device
    transition["state"] = {
        key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
    }

    # Move action to the target device
    transition["action"] = transition["action"].to(device, non_blocking=non_blocking)

    # Reward, done, and truncated are plain Python scalars unless they were stored as tensors
    if isinstance(transition["reward"], torch.Tensor):
        transition["reward"] = transition["reward"].to(device=device, non_blocking=non_blocking)

    if isinstance(transition["done"], torch.Tensor):
        transition["done"] = transition["done"].to(device, non_blocking=non_blocking)

    if isinstance(transition["truncated"], torch.Tensor):
        transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)

    # Move next_state tensors to the target device
    transition["next_state"] = {
        key: val.to(device, non_blocking=non_blocking)
        for key, val in transition["next_state"].items()
    }

    # If complementary_info is present, move its tensors to the target device
    # if transition["complementary_info"] is not None:
    #     transition["complementary_info"] = {
    #         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
    #     }
    return transition


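# Example (illustrative sketch, not part of the module API): moving a transition
# between devices. Assumes a CUDA device is available; on a CPU-only machine the
# same call works with device="cpu".
#
#     transition = Transition(
#         state={"observation.state": torch.zeros(1, 10)},
#         action=torch.zeros(1, 6),
#         reward=0.0,
#         next_state={"observation.state": torch.zeros(1, 10)},
#         done=False,
#         truncated=False,
#         complementary_info=None,
#     )
#     transition = move_transition_to_device(transition, device="cuda:0")

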
def move_state_dict_to_device(state_dict, device="cpu"):
    """
    Recursively move all tensors in a (potentially) nested
    dict/list/tuple structure to the target device.
    """
    if isinstance(state_dict, torch.Tensor):
        return state_dict.to(device)
    elif isinstance(state_dict, dict):
        return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
    elif isinstance(state_dict, list):
        return [move_state_dict_to_device(v, device=device) for v in state_dict]
    elif isinstance(state_dict, tuple):
        return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
    else:
        return state_dict


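# Example (illustrative sketch): the helper recurses through nested containers,
# so optimizer states and similar structures can be moved in a single call.
#
#     nested = {"layer": [torch.ones(2), (torch.zeros(3), "metadata")]}
#     nested_cpu = move_state_dict_to_device(nested, device="cpu")
#     # Non-tensor leaves (e.g. the "metadata" string) are returned unchanged.

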
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
    """Serialize a model state dict to bytes for transmission."""
    buffer = io.BytesIO()

    torch.save(state_dict, buffer)

    return buffer.getvalue()


def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
    buffer = io.BytesIO(buffer)
    buffer.seek(0)
    return torch.load(buffer)


def python_object_to_bytes(python_object: Any) -> bytes:
    return pickle.dumps(python_object)


def bytes_to_python_object(buffer: bytes) -> Any:
    buffer = io.BytesIO(buffer)
    buffer.seek(0)
    return pickle.load(buffer)


def bytes_to_transitions(buffer: bytes) -> list[Transition]:
    buffer = io.BytesIO(buffer)
    buffer.seek(0)
    return torch.load(buffer)


def transitions_to_bytes(transitions: list[Transition]) -> bytes:
    buffer = io.BytesIO()
    torch.save(transitions, buffer)
    return buffer.getvalue()


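# Example (illustrative sketch): the byte-level helpers are symmetric, so a
# round trip should reproduce the original tensors exactly.
#
#     sd = {"weight": torch.randn(4, 4)}
#     restored = bytes_to_state_dict(state_to_bytes(sd))
#     assert torch.equal(sd["weight"], restored["weight"])
#
# Note that bytes_to_state_dict/bytes_to_transitions rely on torch.load, which
# (like pickle) must only be fed data from a trusted source.

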
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    """
    Perform a per-image random crop over a batch of images in a vectorized way.
    """
    B, C, H, W = images.shape  # noqa: N806
    crop_h, crop_w = output_size

    if crop_h > H or crop_w > W:
        raise ValueError(
            f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
        )

    # Sample an independent top-left corner for each image in the batch
    tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)

    rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
    cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)

    rows = rows.unsqueeze(2).expand(-1, -1, crop_w)  # (B, crop_h, crop_w)
    cols = cols.unsqueeze(1).expand(-1, crop_h, -1)  # (B, crop_h, crop_w)

    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)

    # Gather pixels
    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
    # cropped_hwcn => (B, crop_h, crop_w, C)

    cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
    return cropped


def random_shift(images: torch.Tensor, pad: int = 4):
    """Vectorized random shift: pad by `pad` pixels (replicate) then randomly crop back to (H, W)."""
    _, _, h, w = images.shape
    images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
    return random_crop_vectorized(images=images, output_size=(h, w))


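# Example (illustrative sketch): random_shift pads by `pad` pixels and crops back
# to the original resolution, so shapes are preserved while content is jittered
# by up to `pad` pixels per axis (the DrQ-style augmentation used by the buffer below).
#
#     imgs = torch.rand(8, 3, 84, 84)
#     shifted = random_shift(imgs, pad=4)
#     assert shifted.shape == imgs.shape

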
class ReplayBuffer:
    def __init__(
        self,
        capacity: int,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
        image_augmentation_function: Optional[Callable] = None,
        use_drq: bool = True,
        storage_device: str = "cpu",
        optimize_memory: bool = False,
    ):
        """
        Args:
            capacity (int): Maximum number of transitions to store in the buffer.
            device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
            state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
            image_augmentation_function (Optional[Callable]): A function that takes a batch of images
                and returns a batch of augmented images. If None, a default augmentation function is used.
            use_drq (bool): Whether to apply the DrQ-style image augmentation when sampling from the buffer.
            storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
                Using "cpu" can help save GPU memory.
            optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
                they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
        """
        self.capacity = capacity
        self.device = device
        self.storage_device = storage_device
        self.position = 0
        self.size = 0
        self.initialized = False
        self.optimize_memory = optimize_memory

        # Track episode boundaries for memory optimization
        self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)

        # If no state_keys provided, default to an empty list
        self.state_keys = state_keys if state_keys is not None else []

        if image_augmentation_function is None:
            base_function = functools.partial(random_shift, pad=4)
            self.image_augmentation_function = torch.compile(base_function)
        else:
            self.image_augmentation_function = image_augmentation_function
        self.use_drq = use_drq

    def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor):
        """Initialize the storage tensors based on the first transition."""
        # Determine shapes from the first transition
        state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
        action_shape = action.squeeze(0).shape

        # Pre-allocate tensors for storage
        self.states = {
            key: torch.empty((self.capacity, *shape), device=self.storage_device)
            for key, shape in state_shapes.items()
        }
        self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
        self.rewards = torch.empty((self.capacity,), device=self.storage_device)

        if not self.optimize_memory:
            # Standard approach: store states and next_states separately
            self.next_states = {
                key: torch.empty((self.capacity, *shape), device=self.storage_device)
                for key, shape in state_shapes.items()
            }
        else:
            # Memory-optimized approach: don't allocate a next_states buffer;
            # keep a reference to states for a consistent API
            self.next_states = self.states

        self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
        self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)

        self.initialized = True

    def __len__(self):
        return self.size

    def add(
        self,
        state: dict[str, torch.Tensor],
        action: torch.Tensor,
        reward: float,
        next_state: dict[str, torch.Tensor],
        done: bool,
        truncated: bool,
        complementary_info: Optional[dict[str, torch.Tensor]] = None,
    ):
        """Save a transition, ensuring tensors are stored on the designated storage device.

        NOTE: `complementary_info` is accepted for API compatibility but is not stored yet.
        """
        # Initialize storage if this is the first transition
        if not self.initialized:
            self._initialize_storage(state=state, action=action)

        # Store the transition in pre-allocated tensors
        for key in self.states:
            self.states[key][self.position].copy_(state[key].squeeze(dim=0))

            if not self.optimize_memory:
                # Only store next_states if not optimizing memory
                self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))

        self.actions[self.position].copy_(action.squeeze(dim=0))
        self.rewards[self.position] = reward
        self.dones[self.position] = done
        self.truncateds[self.position] = truncated

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

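    # Example (illustrative sketch): `add` uses a ring buffer, so once `capacity`
    # transitions have been stored, new entries overwrite the oldest ones:
    #
    #     buf = ReplayBuffer(capacity=3, device="cpu", state_keys=["s"], storage_device="cpu")
    #     for _ in range(4):
    #         buf.add(state={"s": torch.zeros(1, 2)}, action=torch.zeros(1, 1),
    #                 reward=0.0, next_state={"s": torch.zeros(1, 2)}, done=False, truncated=False)
    #     # len(buf) == 3, and buf.position == 1 (index 0 was overwritten).
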
    def sample(self, batch_size: int) -> BatchTransition:
        """Sample a random batch of transitions and collate them into batched tensors."""
        if not self.initialized:
            raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")

        batch_size = min(batch_size, self.size)
        # In memory-optimized mode, the most recent transition has no valid successor yet,
        # so exclude it from sampling until the buffer wraps around. The max(1, ...) guard
        # keeps torch.randint valid when the buffer holds a single transition.
        high = max(1, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size

        # Random indices for sampling - create on the same device as storage
        idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)

        # Identify image keys that need augmentation
        image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []

        # Create batched state and next_state
        batch_state = {}
        batch_next_state = {}

        # First pass: load all state tensors to the target device
        for key in self.states:
            batch_state[key] = self.states[key][idx].to(self.device)

            if not self.optimize_memory:
                # Standard approach - load next_states directly
                batch_next_state[key] = self.next_states[key][idx].to(self.device)
            else:
                # Memory-optimized approach - get next_state from the next index
                next_idx = (idx + 1) % self.capacity
                batch_next_state[key] = self.states[key][next_idx].to(self.device)

        # Apply image augmentation in a batched way if needed
        if self.use_drq and image_keys:
            # Concatenate all images from state and next_state
            all_images = []
            for key in image_keys:
                all_images.append(batch_state[key])
                all_images.append(batch_next_state[key])

            # Batch all images and apply augmentation once
            all_images_tensor = torch.cat(all_images, dim=0)
            augmented_images = self.image_augmentation_function(all_images_tensor)

            # Split the augmented tensor back to its sources; images were concatenated as
            # [state_k0, next_k0, state_k1, next_k1, ...] blocks of size batch_size
            for i, key in enumerate(image_keys):
                # State images live in even-numbered blocks (0, 2, 4, ...)
                batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
                # Next-state images live in odd-numbered blocks (1, 3, 5, ...)
                batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]

        # Sample other tensors
        batch_actions = self.actions[idx].to(self.device)
        batch_rewards = self.rewards[idx].to(self.device)
        batch_dones = self.dones[idx].to(self.device).float()
        batch_truncateds = self.truncateds[idx].to(self.device).float()

        return BatchTransition(
            state=batch_state,
            action=batch_actions,
            reward=batch_rewards,
            next_state=batch_next_state,
            done=batch_dones,
            truncated=batch_truncateds,
        )

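    # Example (illustrative sketch): a typical critic update consumes a sampled
    # batch; all tensors arrive on `self.device`, with done/truncated as floats.
    # Here `gamma` and `target_q` are placeholders for the caller's discount
    # factor and target network, not names defined in this module.
    #
    #     batch = buffer.sample(256)
    #     q_target = batch["reward"] + (1.0 - batch["done"]) * gamma * target_q(batch["next_state"])
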
    @classmethod
    def from_lerobot_dataset(
        cls,
        lerobot_dataset: LeRobotDataset,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
        capacity: Optional[int] = None,
        action_mask: Optional[Sequence[int]] = None,
        action_delta: Optional[float] = None,
        image_augmentation_function: Optional[Callable] = None,
        use_drq: bool = True,
        storage_device: str = "cpu",
        optimize_memory: bool = False,
    ) -> "ReplayBuffer":
        """
        Convert a LeRobotDataset into a ReplayBuffer.

        Args:
            lerobot_dataset (LeRobotDataset): The dataset to convert.
            device (str): The device for sampling tensors. Defaults to "cuda:0".
            state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
            capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
            action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
            action_delta (Optional[float]): Factor to divide actions by.
            image_augmentation_function (Optional[Callable]): Function for image augmentation.
                If None, uses default random shift with pad=4.
            use_drq (bool): Whether to use DrQ image augmentation when sampling.
            storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
            optimize_memory (bool): If True, reduces memory usage by not duplicating state data.

        Returns:
            ReplayBuffer: The replay buffer with dataset transitions.
        """
        if capacity is None:
            capacity = len(lerobot_dataset)

        if capacity < len(lerobot_dataset):
            raise ValueError(
                "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
            )

        # Create replay buffer with image augmentation and DrQ settings
        replay_buffer = cls(
            capacity=capacity,
            device=device,
            state_keys=state_keys,
            image_augmentation_function=image_augmentation_function,
            use_drq=use_drq,
            storage_device=storage_device,
            optimize_memory=optimize_memory,
        )

        # Convert dataset to transitions
        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)

        # Initialize the buffer with the first transition to set up storage tensors
        if list_transition:
            first_transition = list_transition[0]
            first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
            first_action = first_transition["action"].to(device)

            # Apply action mask/delta if needed
            if action_mask is not None:
                if first_action.dim() == 1:
                    first_action = first_action[action_mask]
                else:
                    first_action = first_action[:, action_mask]

            if action_delta is not None:
                first_action = first_action / action_delta

            replay_buffer._initialize_storage(state=first_state, action=first_action)

        # Fill the buffer with all transitions
        for data in list_transition:
            for k, v in data.items():
                if isinstance(v, dict):
                    for key, tensor in v.items():
                        v[key] = tensor.to(storage_device)
                elif isinstance(v, torch.Tensor):
                    data[k] = v.to(storage_device)

            action = data["action"]
            if action_mask is not None:
                if action.dim() == 1:
                    action = action[action_mask]
                else:
                    action = action[:, action_mask]

            if action_delta is not None:
                action = action / action_delta

            replay_buffer.add(
                state=data["state"],
                action=action,
                reward=data["reward"],
                next_state=data["next_state"],
                done=data["done"],
                truncated=False,  # NOTE: truncation is not yet supported in LeRobotDataset
            )

        return replay_buffer

    def to_lerobot_dataset(
        self,
        repo_id: str,
        fps=1,
        root=None,
        task_name="from_replay_buffer",
    ) -> LeRobotDataset:
        """
        Convert all transitions in this ReplayBuffer into a single LeRobotDataset object.
        """
        if self.size == 0:
            raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")

        # Create features dictionary for the dataset
        features = {
            "index": {"dtype": "int64", "shape": [1]},  # global index across episodes
            "episode_index": {"dtype": "int64", "shape": [1]},  # which episode
            "frame_index": {"dtype": "int64", "shape": [1]},  # index inside an episode
            "timestamp": {"dtype": "float32", "shape": [1]},  # for now we store dummy values
            "task_index": {"dtype": "int64", "shape": [1]},
        }

        # Add "action"
        sample_action = self.actions[0]
        act_info = guess_feature_info(t=sample_action, name="action")
        features["action"] = act_info

        # Add "reward" and "done"
        features["next.reward"] = {"dtype": "float32", "shape": (1,)}
        features["next.done"] = {"dtype": "bool", "shape": (1,)}

        # Add state keys
        for key in self.states:
            sample_val = self.states[key][0]
            f_info = guess_feature_info(t=sample_val, name=key)
            features[key] = f_info

        # Create an empty LeRobotDataset
        lerobot_dataset = LeRobotDataset.create(
            repo_id=repo_id,
            fps=fps,
            root=root,
            robot=None,  # TODO: (azouitine) Handle robot
            robot_type=None,
            features=features,
            use_videos=True,
        )

        # Start writing images if needed
        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)

        # Convert transitions into episodes and frames
        episode_index = 0
        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)

        frame_idx_in_episode = 0
        for idx in range(self.size):
            actual_idx = (self.position - self.size + idx) % self.capacity

            frame_dict = {}

            # Fill the data for state keys
            for key in self.states:
                frame_dict[key] = self.states[key][actual_idx].cpu()

            # Fill action, reward, done
            frame_dict["action"] = self.actions[actual_idx].cpu()
            frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
            frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()

            # Add the task field, which is required by LeRobotDataset
            frame_dict["task"] = task_name

            # Add to the dataset's buffer
            lerobot_dataset.add_frame(frame_dict)

            # Move to the next frame
            frame_idx_in_episode += 1

            # If we reached an episode boundary, save the episode and reset counters
            if self.dones[actual_idx] or self.truncateds[actual_idx]:
                lerobot_dataset.save_episode()
                episode_index += 1
                frame_idx_in_episode = 0
                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
                    episode_index=episode_index
                )

        # Save any remaining frames in the buffer
        if lerobot_dataset.episode_buffer["size"] > 0:
            lerobot_dataset.save_episode()

        lerobot_dataset.stop_image_writer()

        return lerobot_dataset

    @staticmethod
    def _lerobotdataset_to_transitions(
        dataset: LeRobotDataset,
        state_keys: Optional[Sequence[str]] = None,
    ) -> list[Transition]:
        """
        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.

        Args:
            dataset (LeRobotDataset):
                The dataset to convert. Each item in the dataset is expected to have
                at least the following keys:
                {
                    "action": ...
                    "next.reward": ...
                    "next.done": ...
                    "episode_index": ...
                }
                plus whatever keys `state_keys` specifies.

            state_keys (Optional[Sequence[str]]):
                The dataset keys to include in `state` and `next_state`. Their names
                are kept as-is in the output transitions, e.g.
                ["observation.state", "observation.environment_state"].
                Must be provided; the function raises if None.

        Returns:
            transitions (List[Transition]):
                A list of Transition dictionaries with the same length as `dataset`.
        """
        if state_keys is None:
            raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")

        transitions = []
        num_frames = len(dataset)

        # Check if the dataset has a "next.done" key
        sample = dataset[0]
        has_done_key = "next.done" in sample

        # If not, we need to infer it from episode boundaries
        if not has_done_key:
            print("'next.done' key not found in dataset. Inferring from episode boundaries...")

        for i in tqdm(range(num_frames)):
            current_sample = dataset[i]

            # ----- 1) Current state -----
            current_state: dict[str, torch.Tensor] = {}
            for key in state_keys:
                val = current_sample[key]
                current_state[key] = val.unsqueeze(0)  # Add batch dimension

            # ----- 2) Action -----
            action = current_sample["action"].unsqueeze(0)  # Add batch dimension

            # ----- 3) Reward and done -----
            reward = float(current_sample["next.reward"].item())  # ensure float

            # Determine done flag - use next.done if available, otherwise infer from episode boundaries
            if has_done_key:
                done = bool(current_sample["next.done"].item())  # ensure bool
            else:
                # If this is the last frame or the next frame is in a different episode, mark as done
                done = False
                if i == num_frames - 1:
                    done = True
                elif i < num_frames - 1:
                    next_sample = dataset[i + 1]
                    if next_sample["episode_index"] != current_sample["episode_index"]:
                        done = True

            # TODO: (azouitine) Handle truncation (using the same value as done for now)
            truncated = done

            # ----- 4) Next state -----
            # If not done and the next sample is in the same episode, we pull the next sample's state.
            # Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
            next_state = current_state  # default
            if not done and (i < num_frames - 1):
                next_sample = dataset[i + 1]
                if next_sample["episode_index"] == current_sample["episode_index"]:
                    # Build next_state from the same keys
                    next_state_data: dict[str, torch.Tensor] = {}
                    for key in state_keys:
                        val = next_sample[key]
                        next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
                    next_state = next_state_data

            # ----- Construct the Transition -----
            transition = Transition(
                state=current_state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
                truncated=truncated,
                complementary_info=None,
            )
            transitions.append(transition)

        return transitions


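# Example (illustrative sketch): seeding a buffer from an offline dataset. The
# repo_id is a placeholder; any LeRobotDataset with "action"/"next.reward" keys
# and the listed state keys should work.
#
#     offline_dataset = LeRobotDataset(repo_id="user/my-dataset")
#     buffer = ReplayBuffer.from_lerobot_dataset(
#         offline_dataset,
#         state_keys=["observation.state", "observation.image"],
#         device="cpu",
#         optimize_memory=True,
#     )
#     batch = buffer.sample(32)

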
# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t: torch.Tensor, name: str):
    """
    Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
    If it looks like a 3D (C, H, W) shape, we consider it an 'image';
    otherwise we default to 'float32' for numeric data. Customize as needed.

    NOTE: `name` is currently unused but kept for API symmetry.
    """
    shape = tuple(t.shape)
    # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
    if len(shape) == 3 and shape[0] in [1, 3]:
        return {
            "dtype": "image",
            "shape": shape,
        }
    else:
        # Otherwise treat as numeric
        return {
            "dtype": "float32",
            "shape": shape,
        }


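# Example (illustrative sketch): a (3, H, W) tensor is classified as an image,
# anything else falls back to float32.
#
#     guess_feature_info(torch.rand(3, 84, 84), "observation.image")
#     # -> {"dtype": "image", "shape": (3, 84, 84)}
#     guess_feature_info(torch.rand(10), "observation.state")
#     # -> {"dtype": "float32", "shape": (10,)}

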
def concatenate_batch_transitions(
    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
    """Concatenate two BatchTransitions along the batch dimension.

    NOTE: Be careful, this modifies `left_batch_transitions` in place.
    """
    left_batch_transitions["state"] = {
        key: torch.cat(
            [
                left_batch_transitions["state"][key],
                right_batch_transition["state"][key],
            ],
            dim=0,
        )
        for key in left_batch_transitions["state"]
    }
    left_batch_transitions["action"] = torch.cat(
        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
    )
    left_batch_transitions["reward"] = torch.cat(
        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
    )
    left_batch_transitions["next_state"] = {
        key: torch.cat(
            [
                left_batch_transitions["next_state"][key],
                right_batch_transition["next_state"][key],
            ],
            dim=0,
        )
        for key in left_batch_transitions["next_state"]
    }
    left_batch_transitions["done"] = torch.cat(
        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
    )
    left_batch_transitions["truncated"] = torch.cat(
        [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
        dim=0,
    )
    return left_batch_transitions


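# Example (illustrative sketch): mixing an online and an offline batch, e.g. for
# 50/50 sampling in offline-to-online RL. `online_buffer` and `offline_buffer`
# are placeholders; remember that the left operand is mutated in place.
#
#     online_batch = online_buffer.sample(128)
#     offline_batch = offline_buffer.sample(128)
#     mixed = concatenate_batch_transitions(online_batch, offline_batch)
#     # mixed is online_batch, now holding 256 transitions per key.

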
if __name__ == "__main__":
    from tempfile import TemporaryDirectory

    # ===== Test 1: Create and use a synthetic ReplayBuffer =====
    print("Testing synthetic ReplayBuffer...")

    # Create sample data dimensions
    batch_size = 32
    state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
    action_dim = (6,)

    # Create a buffer
    buffer = ReplayBuffer(
        capacity=1000,
        device="cpu",
        state_keys=list(state_dims.keys()),
        use_drq=True,
        storage_device="cpu",
    )

    # Add some random transitions
    for i in range(100):
        # Create dummy transition data
        state = {
            "observation.image": torch.rand(1, 3, 84, 84),
            "observation.state": torch.rand(1, 10),
        }
        action = torch.rand(1, 6)
        reward = 0.5
        next_state = {
            "observation.image": torch.rand(1, 3, 84, 84),
            "observation.state": torch.rand(1, 10),
        }
        done = i == 99
        truncated = False

        buffer.add(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            truncated=truncated,
        )

    # Test sampling
    batch = buffer.sample(batch_size)
    print(f"Buffer size: {len(buffer)}")
    print(
        f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
    )
    print(f"Sampled batch action shape: {batch['action'].shape}")
    print(f"Sampled batch reward shape: {batch['reward'].shape}")
    print(f"Sampled batch done shape: {batch['done'].shape}")
    print(f"Sampled batch truncated shape: {batch['truncated'].shape}")

    # ===== Test for state-action-reward alignment =====
    print("\nTesting state-action-reward alignment...")

    # Create a buffer with controlled transitions where we know the relationships
    aligned_buffer = ReplayBuffer(
        capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
    )

    # Create transitions with known relationships
    # - Each state has a unique signature value
    # - Action is 2x the state signature
    # - Reward is 3x the state signature
    # - Next state is signature + 0.01 (unless at episode end)
    for i in range(100):
        # Create a state with a signature value that encodes the transition number
        signature = float(i) / 100.0
        state = {"state_value": torch.tensor([[signature]]).float()}

        # Action is 2x the signature
        action = torch.tensor([[2.0 * signature]]).float()

        # Reward is 3x the signature
        reward = 3.0 * signature

        # Next state is signature + 0.01, unless end of episode
        # End episode every 10 steps
        is_end = (i + 1) % 10 == 0

        if is_end:
            # At episode boundaries, next_state repeats the current state
            # (mirroring the convention used elsewhere in this module)
            next_state = {"state_value": torch.tensor([[signature]]).float()}
            done = True
        else:
            # Within episodes, next_state has signature + 0.01
            next_signature = float(i + 1) / 100.0
            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
            done = False

        aligned_buffer.add(state, action, reward, next_state, done, False)

    # Sample from this buffer
    aligned_batch = aligned_buffer.sample(50)

    # Verify alignments in the sampled batch
    correct_relationships = 0
    total_checks = 0

    # For each transition in the batch
    for i in range(50):
        # Extract signature from state
        state_sig = aligned_batch["state"]["state_value"][i].item()

        # Check action is 2x signature (within reasonable precision)
        action_val = aligned_batch["action"][i].item()
        action_check = abs(action_val - 2.0 * state_sig) < 1e-4

        # Check reward is 3x signature (within reasonable precision)
        reward_val = aligned_batch["reward"][i].item()
        reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4

        # Check next_state relationship matches our pattern
        next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
        is_done = aligned_batch["done"][i].item() > 0.5

        # Calculate expected next_state value based on the done flag
        if is_done:
            # For episodes that end, next_state should equal state
            next_state_check = abs(next_state_sig - state_sig) < 1e-4
        else:
            # For continuing episodes, check if next_state is approximately state + 0.01.
            # We don't know the original index here, so we accept either an increment
            # of roughly 0.01 or an unchanged value.
            next_state_check = (
                abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
            )

        # Count correct relationships
        if action_check:
            correct_relationships += 1
        if reward_check:
            correct_relationships += 1
        if next_state_check:
            correct_relationships += 1

        total_checks += 3

    alignment_accuracy = 100.0 * correct_relationships / total_checks
    print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
    if alignment_accuracy > 99.0:
        print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
    else:
        print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")

        # Print some debug information about failures
        print("\nDebug information for failed checks:")
        for i in range(5):  # Print first 5 transitions for debugging
            state_sig = aligned_batch["state"]["state_value"][i].item()
            action_val = aligned_batch["action"][i].item()
            reward_val = aligned_batch["reward"][i].item()
            next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
            is_done = aligned_batch["done"][i].item() > 0.5

            print(f"Transition {i}:")
            print(f"  State: {state_sig:.6f}")
            print(f"  Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})")
            print(f"  Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})")
            print(f"  Done: {is_done}")
            print(f"  Next state: {next_state_sig:.6f}")

            # Calculate expected next state
            if is_done:
                expected_next = state_sig
            else:
                # This approximation might not be perfect
                state_idx = round(state_sig * 100)
                expected_next = (state_idx + 1) / 100.0

            print(f"  Expected next state: {expected_next:.6f}")
            print()

    # ===== Test 2: Convert to LeRobotDataset and back =====
    with TemporaryDirectory() as temp_dir:
        print("\nTesting conversion to LeRobotDataset and back...")
        # Convert buffer to dataset
        repo_id = "test/replay_buffer_conversion"
        # Create a subdirectory to avoid the "directory exists" error
        dataset_dir = os.path.join(temp_dir, "dataset1")
        dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir)

        print(f"Dataset created with {len(dataset)} frames")
        print(f"Dataset features: {list(dataset.features.keys())}")

        # Inspect the first sample from the dataset
        sample = dataset[0]
        print(
            f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}"
        )

        # Convert dataset back to buffer
        reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
            dataset, state_keys=list(state_dims.keys()), device="cpu"
        )

        print(f"Reconverted buffer size: {len(reconverted_buffer)}")

        # Sample from the reconverted buffer
        reconverted_batch = reconverted_buffer.sample(batch_size)
        print(
            f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}"
        )

        # Verify consistency before and after conversion
        original_states = batch["state"]["observation.image"].mean().item()
        reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
        print(f"Original buffer state mean: {original_states:.4f}")
        print(f"Reconverted buffer state mean: {reconverted_states:.4f}")

        if abs(original_states - reconverted_states) < 1.0:
            print("Values are reasonably similar - conversion works as expected")
        else:
            print("WARNING: Significant difference between original and reconverted values")

    print("\nAll previous tests completed!")

    # ===== Test for memory optimization =====
    print("\n===== Testing Memory Optimization =====")

    # Create two buffers, one with memory optimization and one without
    standard_buffer = ReplayBuffer(
        capacity=1000,
        device="cpu",
        state_keys=["observation.image", "observation.state"],
        storage_device="cpu",
        optimize_memory=False,
        use_drq=True,
    )

    optimized_buffer = ReplayBuffer(
        capacity=1000,
        device="cpu",
        state_keys=["observation.image", "observation.state"],
        storage_device="cpu",
        optimize_memory=True,
        use_drq=True,
    )

    # Generate sample data with larger state dimensions for better memory impact
    print("Generating test data...")
    num_episodes = 10
    steps_per_episode = 50
    total_steps = num_episodes * steps_per_episode

    for episode in range(num_episodes):
        for step in range(steps_per_episode):
            # Index in the overall sequence
            i = episode * steps_per_episode + step

            # Create state with identifiable values
            img = torch.ones((3, 84, 84)) * (i / total_steps)
            state_vec = torch.ones((10,)) * (i / total_steps)

            state = {
                "observation.image": img.unsqueeze(0),
                "observation.state": state_vec.unsqueeze(0),
            }

            # Create next state (i+1, or same as current if last in episode)
            is_last_step = step == steps_per_episode - 1

            if is_last_step:
                # At episode end, next state = current state
                next_img = img.clone()
                next_state_vec = state_vec.clone()
                done = True
                truncated = False
            else:
                # Within episode, next state has the incremented value
                next_val = (i + 1) / total_steps
                next_img = torch.ones((3, 84, 84)) * next_val
                next_state_vec = torch.ones((10,)) * next_val
                done = False
                truncated = False

            next_state = {
                "observation.image": next_img.unsqueeze(0),
                "observation.state": next_state_vec.unsqueeze(0),
            }

            # Action and reward
            action = torch.tensor([[i / total_steps]])
            reward = float(i / total_steps)

            # Add to both buffers
            standard_buffer.add(state, action, reward, next_state, done, truncated)
            optimized_buffer.add(state, action, reward, next_state, done, truncated)

    # Verify episode boundaries with the simplified approach
    print("\nVerifying simplified memory optimization...")

    # Test with a new buffer holding a small sequence
    test_buffer = ReplayBuffer(
        capacity=20,
        device="cpu",
        state_keys=["value"],
        storage_device="cpu",
        optimize_memory=True,
        use_drq=False,
    )

    # Add a simple sequence with known episode boundaries
    for i in range(20):
        val = float(i)
        state = {"value": torch.tensor([[val]]).float()}
        next_val = float(i + 1) if i % 5 != 4 else val  # Episode ends every 5 steps
        next_state = {"value": torch.tensor([[next_val]]).float()}

        # Set done=True at every 5th step
        done = (i % 5) == 4
        action = torch.tensor([[0.0]])
        reward = 1.0
        truncated = False

        test_buffer.add(state, action, reward, next_state, done, truncated)

    # Get a sequential batch for verification
    sequential_batch_size = test_buffer.size
    all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)

    # Get state tensors
    batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}

    # Get next_state using the memory-optimized approach (simply index + 1)
    next_indices = (all_indices + 1) % test_buffer.capacity
    batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}

    # Get other tensors
    batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)

    # Print sequential values
    print("State, Next State, Done (Sequential values with simplified optimization):")
    state_values = batch_state["value"].squeeze().tolist()
    next_values = batch_next_state["value"].squeeze().tolist()
    done_flags = batch_dones.tolist()

    # Print all values
    for i in range(len(state_values)):
        print(f"  {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}")

    # Explain the memory optimization tradeoff
    print("\nWith simplified memory optimization:")
    print("- We always use the next state in the buffer (index+1) as next_state")
    print("- For terminal states, this means using the first state of the next episode")
    print("- This is a common tradeoff in RL implementations for memory efficiency")
    print("- Since we track done flags, the algorithm can handle these transitions correctly")

    # Test random sampling
    print("\nVerifying random sampling with simplified memory optimization...")
    random_samples = test_buffer.sample(20)  # Sample all transitions

    # Extract values
    random_state_values = random_samples["state"]["value"].squeeze().tolist()
    random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
    random_done_flags = random_samples["done"].bool().tolist()

    # Print a few samples
    print("Random samples - State, Next State, Done (First 10):")
    for i in range(10):
        print(f"  {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}")

    # Calculate memory savings: the standard buffer stores each state twice
    # (once in states, once in next_states), the optimized buffer only once
    std_mem = (
        sum(
            standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size()
            for key in standard_buffer.states
        )
        * 2
    )
    opt_mem = sum(
        optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size()
        for key in optimized_buffer.states
    )

    savings_percent = (std_mem - opt_mem) / std_mem * 100

    print("\nMemory optimization result:")
    print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
    print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
    print(f"- Memory savings for state tensors: {savings_percent:.1f}%")

    print("\nAll memory optimization tests completed!")

    # # ===== Test real dataset conversion =====
    # print("\n===== Testing Real LeRobotDataset Conversion =====")
    # try:
    #     # Try to use a real dataset if available
    #     dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
    #     dataset = LeRobotDataset(repo_id=dataset_name)
    #
    #     # Print available keys to debug
    #     sample = dataset[0]
    #     print("Available keys in dataset:", list(sample.keys()))
    #
    #     # Check for required keys
    #     if "action" not in sample or "next.reward" not in sample:
    #         print("Dataset missing essential keys. Cannot convert.")
    #         raise ValueError("Missing required keys in dataset")
    #
    #     # Auto-detect appropriate state keys
    #     image_keys = []
    #     state_keys = []
    #     for k, v in sample.items():
    #         # Skip metadata keys and action/reward keys
    #         if k in {
    #             "index",
    #             "episode_index",
    #             "frame_index",
    #             "timestamp",
    #             "task_index",
    #             "action",
    #             "next.reward",
    #             "next.done",
    #         }:
    #             continue
    #
    #         # Infer key type from tensor shape
    #         if isinstance(v, torch.Tensor):
    #             if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
    #                 # Likely an image (channels, height, width)
    #                 image_keys.append(k)
    #             else:
    #                 # Likely state or other vector
    #                 state_keys.append(k)
    #
    #     print(f"Detected image keys: {image_keys}")
    #     print(f"Detected state keys: {state_keys}")
    #
    #     if not image_keys and not state_keys:
    #         print("No usable keys found in dataset, skipping further tests")
    #         raise ValueError("No usable keys found in dataset")
    #
    #     # Test with standard and memory-optimized buffers
    #     for optimize_memory in [False, True]:
    #         buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
    #         print(f"\nTesting {buffer_type} buffer with real dataset...")
    #
    #         # Convert to ReplayBuffer with detected keys
    #         replay_buffer = ReplayBuffer.from_lerobot_dataset(
    #             lerobot_dataset=dataset,
    #             state_keys=image_keys + state_keys,
    #             device="cpu",
    #             optimize_memory=optimize_memory,
    #         )
    #         print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
    #
    #         # Test sampling
    #         real_batch = replay_buffer.sample(32)
    #         print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
    #         for key in real_batch["state"]:
    #             print(f"  {key}: {real_batch['state'][key].shape}")
    #
    #         # Convert back to LeRobotDataset
    #         with TemporaryDirectory() as temp_dir:
    #             dataset_name = f"test/real_dataset_converted_{buffer_type}"
    #             replay_buffer_converted = replay_buffer.to_lerobot_dataset(
    #                 repo_id=dataset_name,
    #                 root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
    #             )
    #             print(
    #                 f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
    #             )
    # except Exception as e:
    #     print(f"Real dataset test failed: {e}")
    #     print("This is expected if running offline or if the dataset is not available.")

    # print("\nAll tests completed!")