#!/usr/bin/env python
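"""Example: load a LeRobotDataset from the Hugging Face Hub and convert it to a ReplayBuffer.

The script logs basic dataset statistics, builds the buffer with
ReplayBuffer.from_lerobot_dataset using the dataset's observation keys as state,
and samples a small batch to show the shapes of the resulting tensors.
"""
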
import logging
from pathlib import Path

import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.buffer import ReplayBuffer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def main():
    # Initialize the dataset
    logger.info("Loading LeRobotDataset...")
    dataset = LeRobotDataset(
        repo_id="aractingi/pushcube_gamepad",
        download_videos=True,  # Set to False if you don't need video data
    )

    # Print dataset information
    logger.info("Dataset loaded successfully!")
    logger.info(f"Number of episodes: {dataset.num_episodes}")
    logger.info(f"Number of frames: {dataset.num_frames}")
    logger.info(f"FPS: {dataset.fps}")
    logger.info(f"Features: {list(dataset.features.keys())}")

    # Convert dataset to ReplayBuffer
    logger.info("Converting dataset to ReplayBuffer...")

    # Define which keys from the dataset to use as state
    # Get all observation keys from the first sample
    sample = dataset[0]
    state_keys = [key for key in sample.keys() if "observation" in key]
    logger.info(f"Using observation keys: {state_keys}")

    # Create ReplayBuffer from the dataset
    buffer = ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        state_keys=state_keys,
        capacity=None,  # Use all data from the dataset
        use_drq=True,  # Presumably applies DrQ-style image augmentation when sampling
        optimize_memory=False,
    )

    logger.info(f"ReplayBuffer created with {len(buffer)} transitions")

    # Sample from the buffer and display information
    if len(buffer) > 0:
        batch_size = min(5, len(buffer))
        logger.info(f"Sampling {batch_size} transitions from the buffer...")

        batch = buffer.sample(batch_size)

        logger.info(f"Batch keys: {list(batch.keys())}")

        # Print shapes of state tensors
        logger.info("State shapes:")
        for key, tensor in batch["state"].items():
            logger.info(f"  {key}: {tensor.shape}")

        # Print action and reward information
        logger.info(f"Action shape: {batch['action'].shape}")
        logger.info(f"Reward shape: {batch['reward'].shape}")
        logger.info(f"Sample rewards: {batch['reward']}")


if __name__ == "__main__":
    main()