# File: lerobot/lerobot/scripts/server/load_dataset_and_buffer.py
# (file-viewer metadata: 69 lines, 2.3 KiB, Python)

#!/usr/bin/env python
import logging
from pathlib import Path
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.buffer import ReplayBuffer
# Ask the root logger for INFO-level output; this runs once, at import time.
logging.basicConfig(level="INFO")
# Logger named after this module; all messages in this script go through it.
logger = logging.getLogger(name=__name__)
def main(repo_id: str = "aractingi/pushcube_gamepad") -> None:
    """Load a LeRobotDataset, convert it to a ReplayBuffer, and log a sample.

    The function logs the dataset's episode count, frame count, FPS and
    feature names, builds a ``ReplayBuffer`` via
    ``ReplayBuffer.from_lerobot_dataset`` using every key of the first sample
    whose name contains ``"observation"``, and finally samples up to five
    transitions and logs the batch keys, state shapes, action shape and
    rewards.

    Args:
        repo_id: Repository identifier forwarded to ``LeRobotDataset``.
            Defaults to the dataset this script was originally written for,
            so calling ``main()`` with no arguments behaves as before.
    """
    logger.info("Loading LeRobotDataset...")
    dataset = LeRobotDataset(
        repo_id=repo_id,
        download_videos=True,  # Set to False if you don't need video data
    )

    # Summarize what was loaded.
    logger.info("Dataset loaded successfully!")
    logger.info(f"Number of episodes: {dataset.num_episodes}")
    logger.info(f"Number of frames: {dataset.num_frames}")
    logger.info(f"FPS: {dataset.fps}")
    logger.info(f"Features: {list(dataset.features.keys())}")

    logger.info("Converting dataset to ReplayBuffer...")

    # The state keys are discovered from the first sample: any key whose name
    # contains "observation" is treated as part of the state.
    sample = dataset[0]
    state_keys = [key for key in sample.keys() if "observation" in key]
    logger.info(f"Using observation keys: {state_keys}")

    # Prefer the first CUDA device when one is available, otherwise use CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    buffer = ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        device=device,
        state_keys=state_keys,
        capacity=None,  # Use all data from the dataset
        use_drq=True,
        optimize_memory=False,
    )

    num_transitions = len(buffer)
    logger.info(f"ReplayBuffer created with {num_transitions} transitions")

    # Nothing to sample from an empty buffer; say so instead of exiting silently.
    if num_transitions == 0:
        logger.warning("ReplayBuffer is empty; skipping sampling.")
        return

    # Never request more transitions than the buffer holds.
    batch_size = min(5, num_transitions)
    logger.info(f"Sampling {batch_size} transitions from the buffer...")
    batch = buffer.sample(batch_size)
    logger.info(f"Batch keys: {list(batch.keys())}")

    logger.info("State shapes:")
    for key, tensor in batch["state"].items():
        logger.info(f" {key}: {tensor.shape}")

    logger.info(f"Action shape: {batch['action'].shape}")
    logger.info(f"Reward shape: {batch['reward'].shape}")
    logger.info(f"Sample rewards: {batch['reward']}")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()