Fix for the issue https://github.com/huggingface/lerobot/issues/638 (#639)
parent 5ac79d5b2b
commit 212b12cf82
@ -0,0 +1,20 @@
{
    "name": "Maniskill Dev Container",
    "image": "maniskill",
    "workspaceFolder": "/lerobot",
    "mounts": [
        "source=${localWorkspaceFolder},target=/lerobot,type=bind,consistency=cached"
    ],
    "runArgs": [
        "--network", "host",
        "--gpus", "all",
        "--runtime=nvidia"
    ],
    "settings": {
        "terminal.integrated.defaultProfile.linux": "bash"
    },
    "extensions": [
        "ms-python.python",
        "ms-vscode.cpptools"
    ]
}
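The config above assumes a locally available image tagged maniskill (presumably built from the Dockerfile added later in this commit) and passes the host GPU through via runArgs. A quick sanity check to run inside the container, as a sketch rather than part of the commit:

# Sketch (not part of this commit): verify inside the dev container that the
# "--gpus all" / "--runtime=nvidia" runArgs actually expose a CUDA device.
import torch

print("CUDA available:", torch.cuda.is_available())
print("CUDA devices:", torch.cuda.device_count())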
@ -0,0 +1,6 @@
{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
@ -0,0 +1,65 @@
# Use the Nvidia base image
FROM nvidia/cudagl:11.3.1-devel-ubuntu20.04
ENV NVIDIA_DRIVER_CAPABILITIES=all
ARG PYTHON_VERSION=3.10

# Install OS-level packages (list sorted, duplicates removed)
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    bash-completion \
    build-essential \
    ca-certificates \
    cmake \
    curl \
    ffmpeg \
    git \
    git-lfs \
    htop \
    libegl1 \
    libegl1-mesa \
    libgl1-mesa-glx \
    libglib2.0-0 \
    libjpeg-dev \
    libpng-dev \
    libvulkan1 \
    libxext6 \
    portaudio19-dev \
    rsync \
    tmux \
    unzip \
    vim \
    vulkan-utils \
    wget \
    xvfb \
    && rm -rf /var/lib/apt/lists/*

# Install (mini)conda
RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda init && \
    /opt/conda/bin/conda install -y python="$PYTHON_VERSION" && \
    /opt/conda/bin/conda clean -ya

ENV PATH=/opt/conda/bin:$PATH
SHELL ["/bin/bash", "-c"]

# Install Poetry (see https://github.com/haosulab/ManiSkill/issues/9)
RUN curl -sSL https://install.python-poetry.org | python3 -
ENV PATH="/root/.local/bin:${PATH}"

# Copy the Vulkan/EGL ICD and layer manifests
COPY docker/manyskill-lerobot-gpu/nvidia_icd.json /usr/share/vulkan/icd.d/nvidia_icd.json
COPY docker/manyskill-lerobot-gpu/nvidia_layers.json /etc/vulkan/implicit_layer.d/nvidia_layers.json

# Install LeRobot
COPY . /lerobot
WORKDIR /lerobot
RUN poetry install --sync --all-extras
RUN pip install --upgrade mani-skill==3.0.0.b15 && pip cache purge

# Pre-download the PhysX GPU binary
RUN python -c "exec('import sapien.physx as physx;\ntry:\n physx.enable_gpu()\nexcept:\n pass;')"
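The final RUN instruction pre-fetches SAPIEN's PhysX GPU binary at build time: importing sapien.physx and calling enable_gpu() triggers the download, and the exception is swallowed because no GPU is attached during docker build. The embedded one-liner, written out plainly as a sketch:

# Plain-Python equivalent (a sketch) of the Dockerfile's embedded one-liner.
# enable_gpu() fails without an attached GPU, which is expected at build time;
# the side effect we want is the PhysX GPU binary download.
import sapien.physx as physx

try:
    physx.enable_gpu()
except Exception:
    pass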
@ -0,0 +1,7 @@
{
    "file_format_version" : "1.0.0",
    "ICD": {
        "library_path": "libGLX_nvidia.so.0",
        "api_version" : "1.2.155"
    }
}
@ -0,0 +1,21 @@
{
    "file_format_version" : "1.0.0",
    "layer": {
        "name": "VK_LAYER_NV_optimus",
        "type": "INSTANCE",
        "library_path": "libGLX_nvidia.so.0",
        "api_version" : "1.2.155",
        "implementation_version" : "1",
        "description" : "NVIDIA Optimus layer",
        "functions": {
            "vkGetInstanceProcAddr": "vk_optimusGetInstanceProcAddr",
            "vkGetDeviceProcAddr": "vk_optimusGetDeviceProcAddr"
        },
        "enable_environment": {
            "__NV_PRIME_RENDER_OFFLOAD": "1"
        },
        "disable_environment": {
            "DISABLE_LAYER_NV_OPTIMUS_1": ""
        }
    }
}
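The two ICD manifests and the Optimus layer manifest above are the files the Dockerfile copies into /usr/share/vulkan/icd.d/ and /etc/vulkan/implicit_layer.d/. A small sketch for checking, inside the container, that they parse and reference the NVIDIA libraries:

# Sketch: confirm the installed Vulkan/EGL manifests parse and point at
# NVIDIA libraries. Paths are taken from the Dockerfile's COPY destinations.
import json

for path in (
    "/usr/share/vulkan/icd.d/nvidia_icd.json",
    "/etc/vulkan/implicit_layer.d/nvidia_layers.json",
):
    with open(path) as f:
        manifest = json.load(f)
    section = manifest.get("ICD") or manifest.get("layer")
    print(path, "->", section.get("library_path"))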
@ -0,0 +1,327 @@
#!/usr/bin/env python
"""
Benchmark ReplayBuffer performance with different device configurations.

This script compares the performance of ReplayBuffer across device
configurations (by default: pure GPU, CPU storage with GPU compute with and
without pinned memory, and pure CPU).

For each configuration, it benchmarks:
- the add method (adding transitions to the buffer)
- the sample method with a batch size of 512
"""

import argparse
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch

from lerobot.scripts.server.buffer import ReplayBuffer


def generate_random_transition(
    image_shape: Tuple[int, int, int] = (3, 224, 224),
    state_shape: Tuple[int, ...] = (10,),
    action_shape: Tuple[int, ...] = (6,),
    device: str = "cpu",
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, float, Dict[str, torch.Tensor], bool, bool]:
    """Generate a random transition for testing."""
    # Create a state with both image and vector components
    state = {
        "observation.image": torch.randn(1, *image_shape, device=device),
        "observation.vector": torch.randn(1, *state_shape, device=device),
    }

    # Create a next_state with the same structure
    next_state = {
        "observation.image": torch.randn(1, *image_shape, device=device),
        "observation.vector": torch.randn(1, *state_shape, device=device),
    }

    # Create a random action, reward, and termination flags
    action = torch.randn(1, *action_shape, device=device)
    reward = float(torch.rand(1).item())
    done = bool(torch.rand(1) > 0.9)  # 10% chance of being done
    truncated = bool(torch.rand(1) > 0.95)  # 5% chance of being truncated

    return state, action, reward, next_state, done, truncated


def warm_up_gpu():
    """Warm up the GPU to ensure consistent benchmarking."""
    if torch.cuda.is_available():
        print("Warming up GPU...")
        # Run a few large matmuls to warm up the GPU
        x = torch.randn(1000, 1000, device="cuda")
        for _ in range(10):
            x = torch.matmul(x, x)
        # Clear the cache afterwards
        torch.cuda.empty_cache()
        print("GPU warm-up complete")


def benchmark_buffer(
    capacity: int = 100000,
    add_count: int = 10000,
    sample_count: int = 100,
    batch_size: int = 512,
    image_shape: Tuple[int, int, int] = (3, 224, 224),
    state_shape: Tuple[int, ...] = (10,),
    action_shape: Tuple[int, ...] = (6,),
    configs: Optional[List[Dict]] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Benchmark ReplayBuffer with different configurations.

    Args:
        capacity: Buffer capacity
        add_count: Number of transitions to add during the benchmark
        sample_count: Number of sampling operations to benchmark
        batch_size: Batch size for sampling
        image_shape: Shape of image observations
        state_shape: Shape of vector observations
        action_shape: Shape of actions
        configs: List of configurations to benchmark

    Returns:
        Dictionary with benchmark results
    """
    if configs is None:
        configs = [
            {
                "name": "Pure GPU",
                "device": "cuda:0",
                "storage_device": "cuda:0",
                "use_pinned_memory": False,
                "async_transfer": False,
            },
            {
                "name": "CPU Storage + GPU Compute",
                "device": "cuda:0",
                "storage_device": "cpu",
                "use_pinned_memory": True,
                "async_transfer": True,
            },
            {
                "name": "CPU Storage + GPU Compute (No Pinned Memory)",
                "device": "cuda:0",
                "storage_device": "cpu",
                "use_pinned_memory": False,
                "async_transfer": False,
            },
            {
                "name": "Pure CPU",
                "device": "cpu",
                "storage_device": "cpu",
                "use_pinned_memory": False,
                "async_transfer": False,
            },
        ]

    results = defaultdict(dict)

    for config in configs:
        if not torch.cuda.is_available() and "cuda" in config["device"]:
            print(f"Skipping {config['name']} as CUDA is not available")
            continue

        print(f"\nBenchmarking configuration: {config['name']}")
        print(f"  - Compute device: {config['device']}")
        print(f"  - Storage device: {config['storage_device']}")
        print(f"  - Pinned memory: {config['use_pinned_memory']}")
        print(f"  - Async transfer: {config['async_transfer']}")

        # Create a buffer with this configuration
        buffer = ReplayBuffer(
            capacity=capacity,
            device=config["device"],
            storage_device=config["storage_device"],
            use_pinned_memory=config["use_pinned_memory"],
            async_transfer=config["async_transfer"],
            optimize_memory=False,  # Keep it simple for benchmarking
        )

        # Transitions are generated on the compute device, so "Pure GPU"
        # measures a device-local add while CPU-storage configs also
        # measure the transfer.
        initial_device = "cuda:0" if config["device"] == "cuda:0" else "cpu"

        # Benchmark the add operation
        add_times = []
        for i in range(add_count):
            state, action, reward, next_state, done, truncated = generate_random_transition(
                image_shape=image_shape,
                state_shape=state_shape,
                action_shape=action_shape,
                device=initial_device,
            )

            # Measure the add time
            start_time = time.perf_counter()
            buffer.add(state, action, reward, next_state, done, truncated)
            end_time = time.perf_counter()

            add_times.append(end_time - start_time)

            # Print progress
            if (i + 1) % max(add_count // 10, 1) == 0:
                print(f"  Added {i + 1}/{add_count} transitions")

        # Ensure the buffer has enough samples for the sampling benchmark
        while len(buffer) < batch_size:
            state, action, reward, next_state, done, truncated = generate_random_transition(
                image_shape=image_shape,
                state_shape=state_shape,
                action_shape=action_shape,
                device=initial_device,
            )
            buffer.add(state, action, reward, next_state, done, truncated)

        # Benchmark the sample operation
        sample_times = []
        for i in range(sample_count):
            # Synchronize the GPU before timing (for a fair comparison)
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            start_time = time.perf_counter()
            buffer.sample(batch_size)

            # Ensure the computation is complete before stopping the timer
            if torch.cuda.is_available() and "cuda" in config["device"]:
                torch.cuda.synchronize()

            end_time = time.perf_counter()
            sample_times.append(end_time - start_time)

            # Print progress
            if (i + 1) % max(sample_count // 5, 1) == 0:
                print(f"  Sampled {i + 1}/{sample_count} batches")

        # Record the results, skipping the first 100 adds as warmup
        # (fall back to all timings if fewer than 100 adds were measured)
        steady_add_times = add_times[100:] if len(add_times) > 100 else add_times
        results[config["name"]]["add_avg_ms"] = np.mean(steady_add_times) * 1000
        results[config["name"]]["add_min_ms"] = np.min(steady_add_times) * 1000
        results[config["name"]]["add_max_ms"] = np.max(steady_add_times) * 1000
        results[config["name"]]["sample_avg_ms"] = np.mean(sample_times) * 1000
        results[config["name"]]["sample_min_ms"] = np.min(sample_times) * 1000
        results[config["name"]]["sample_max_ms"] = np.max(sample_times) * 1000

        # Force cleanup
        del buffer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return results


def plot_results(results: Dict[str, Dict[str, float]], output_path: Optional[str] = None):
    """Plot benchmark results."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    # Extract data for plotting
    configs = list(results.keys())
    add_times = [results[config]["add_avg_ms"] for config in configs]
    add_mins = [results[config]["add_min_ms"] for config in configs]
    add_maxs = [results[config]["add_max_ms"] for config in configs]

    sample_times = [results[config]["sample_avg_ms"] for config in configs]
    sample_mins = [results[config]["sample_min_ms"] for config in configs]
    sample_maxs = [results[config]["sample_max_ms"] for config in configs]

    # add() operation plot
    bar_width = 0.5
    x = np.arange(len(configs))
    bars1 = ax1.bar(x, add_times, bar_width, label='Average', color='skyblue')

    # Error bars spanning min..max around the average
    ax1.errorbar(x, add_times, yerr=[
        np.array(add_times) - np.array(add_mins),
        np.array(add_maxs) - np.array(add_times)
    ], fmt='none', ecolor='black', capsize=5)

    ax1.set_xlabel('Configuration')
    ax1.set_ylabel('Time (ms)')
    ax1.set_title('add() Operation Performance')
    ax1.set_xticks(x)
    ax1.set_xticklabels(configs, rotation=45, ha='right')

    # Annotate the bars with values
    for bar in bars1:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2., height + 0.1,
                 f'{height:.2f}ms', ha='center', va='bottom', rotation=0)

    # sample() operation plot
    bars2 = ax2.bar(x, sample_times, bar_width, label='Average', color='lightgreen')

    ax2.errorbar(x, sample_times, yerr=[
        np.array(sample_times) - np.array(sample_mins),
        np.array(sample_maxs) - np.array(sample_times)
    ], fmt='none', ecolor='black', capsize=5)

    ax2.set_xlabel('Configuration')
    ax2.set_ylabel('Time (ms)')
    ax2.set_title('sample(512) Operation Performance')
    ax2.set_xticks(x)
    ax2.set_xticklabels(configs, rotation=45, ha='right')

    # Annotate the bars with values
    for bar in bars2:
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width() / 2., height + 0.1,
                 f'{height:.2f}ms', ha='center', va='bottom', rotation=0)

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path)
        print(f"Results saved to {output_path}")

    plt.show()


def print_results(results: Dict[str, Dict[str, float]]):
    """Print benchmark results in a formatted table."""
    print("\n=== BENCHMARK RESULTS ===")

    # Header
    print(f"{'Configuration':<40} | {'add() avg':<10} | {'add() min':<10} | {'add() max':<10} | "
          f"{'sample() avg':<12} | {'sample() min':<12} | {'sample() max':<12}")
    print("-" * 120)

    # Data rows
    for config, metrics in results.items():
        print(f"{config:<40} | "
              f"{metrics['add_avg_ms']:<10.2f} | "
              f"{metrics['add_min_ms']:<10.2f} | "
              f"{metrics['add_max_ms']:<10.2f} | "
              f"{metrics['sample_avg_ms']:<12.2f} | "
              f"{metrics['sample_min_ms']:<12.2f} | "
              f"{metrics['sample_max_ms']:<12.2f}")


def main():
    parser = argparse.ArgumentParser(description="Benchmark ReplayBuffer performance")
    parser.add_argument("--capacity", type=int, default=50000, help="Buffer capacity")
    parser.add_argument("--add-count", type=int, default=10000, help="Number of add operations to benchmark")
    parser.add_argument("--sample-count", type=int, default=100, help="Number of sample operations to benchmark")
    parser.add_argument("--batch-size", type=int, default=512, help="Batch size for sampling")
    parser.add_argument("--output", type=str, default=None, help="Path to save the results plot")

    args = parser.parse_args()

    # Check whether CUDA is available
    if not torch.cuda.is_available():
        print("WARNING: CUDA is not available, only CPU benchmarks will be run")
    else:
        warm_up_gpu()

    # Run the benchmark
    results = benchmark_buffer(
        capacity=args.capacity,
        add_count=args.add_count,
        sample_count=args.sample_count,
        batch_size=args.batch_size,
    )

    # Print and plot the results
    print_results(results)
    plot_results(results, args.output)


if __name__ == "__main__":
    main()
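Beyond the CLI, the benchmark pieces compose directly from Python. A minimal sketch with reduced sizes; the module name benchmark_replay_buffer is an assumption, since the diff view does not show the file's path:

# Sketch: drive the benchmark programmatically with small sizes.
# The module name "benchmark_replay_buffer" is assumed (file path not shown).
from benchmark_replay_buffer import benchmark_buffer, print_results, warm_up_gpu

warm_up_gpu()  # no-op on CPU-only machines
results = benchmark_buffer(capacity=2000, add_count=500, sample_count=20, batch_size=64)
print_results(results)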
@ -0,0 +1,54 @@
#!/usr/bin/env python

import logging

import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def main():
    # Initialize the dataset
    dataset = LeRobotDataset(
        repo_id="aractingi/pushcube_gamepad",
        download_videos=True,  # Set to False if you don't need video data
    )

    # Print dataset information
    logger.info("Dataset loaded successfully!")
    logger.info(f"Number of episodes: {dataset.num_episodes}")
    logger.info(f"Number of frames: {dataset.num_frames}")
    logger.info(f"FPS: {dataset.fps}")
    logger.info(f"Features: {list(dataset.features.keys())}")

    # Get a sample frame
    sample = dataset[0]
    logger.info(f"\nSample frame keys: {list(sample.keys())}")

    # Print shapes of some key features
    for key in ["observation.images.laptop", "observation.images.phone"]:
        if key in sample:
            logger.info(f"Shape of {key}: {sample[key].shape}")

    # Print task information
    logger.info(f"\nTotal tasks: {dataset.meta.total_tasks}")
    logger.info("Tasks:")
    for task_idx, task in dataset.meta.tasks.items():
        logger.info(f"  {task_idx}: {task}")

    # Scan the dataset for the global min/max of the state vector
    device = "cuda" if torch.cuda.is_available() else "cpu"
    global_min = float("inf")
    global_max = float("-inf")
    for sample in dataset:
        for k, v in sample.items():
            if isinstance(v, torch.Tensor):
                v = v.to(device)
                if k == "observation.state":
                    global_min = min(global_min, torch.min(v).item())
                    global_max = max(global_max, torch.max(v).item())
    print(global_min, global_max)


if __name__ == "__main__":
    main()
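Iterating the whole dataset as above decodes video for every frame just to read observation.state. If only the state column is needed, reading it from the underlying Hugging Face table avoids the decoding; a sketch, assuming dataset.hf_dataset exposes the column as recent LeRobotDataset versions do:

# Sketch: min/max of observation.state without decoding any video frames.
# Assumes dataset.hf_dataset exposes the column (an assumption, not shown here).
import torch

states = torch.stack([torch.as_tensor(s) for s in dataset.hf_dataset["observation.state"]])
print(states.min().item(), states.max().item())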
@ -0,0 +1,69 @@
#!/usr/bin/env python

import logging

import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.buffer import ReplayBuffer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def main():
    # Initialize the dataset
    logger.info("Loading LeRobotDataset...")
    dataset = LeRobotDataset(
        repo_id="aractingi/pushcube_gamepad",
        download_videos=True,  # Set to False if you don't need video data
    )

    # Print dataset information
    logger.info("Dataset loaded successfully!")
    logger.info(f"Number of episodes: {dataset.num_episodes}")
    logger.info(f"Number of frames: {dataset.num_frames}")
    logger.info(f"FPS: {dataset.fps}")
    logger.info(f"Features: {list(dataset.features.keys())}")

    # Convert the dataset to a ReplayBuffer
    logger.info("Converting dataset to ReplayBuffer...")

    # Define which keys from the dataset to use as state:
    # take all observation keys from the first sample
    sample = dataset[0]
    state_keys = [key for key in sample.keys() if "observation" in key]
    logger.info(f"Using observation keys: {state_keys}")

    # Create a ReplayBuffer from the dataset
    buffer = ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        state_keys=state_keys,
        capacity=None,  # Use all data from the dataset
        use_drq=True,
        optimize_memory=False,
    )

    logger.info(f"ReplayBuffer created with {len(buffer)} transitions")

    # Sample from the buffer and display information
    if len(buffer) > 0:
        batch_size = min(5, len(buffer))
        logger.info(f"Sampling {batch_size} transitions from the buffer...")

        batch = buffer.sample(batch_size)

        logger.info(f"Batch keys: {list(batch.keys())}")

        # Print shapes of the state tensors
        logger.info("State shapes:")
        for key, tensor in batch["state"].items():
            logger.info(f"  {key}: {tensor.shape}")

        # Print action and reward information
        logger.info(f"Action shape: {batch['action'].shape}")
        logger.info(f"Reward shape: {batch['reward'].shape}")
        logger.info(f"Sample rewards: {batch['reward']}")


if __name__ == "__main__":
    main()
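Repeated sampling is the usual access pattern once the buffer is built; a short follow-on sketch, reusing only names the script above already defines (buffer and the "reward" batch key):

# Follow-on sketch: estimate reward statistics by repeated sampling.
# Reuses `buffer` and the "reward" batch key from the script above.
import torch

rewards = torch.cat([buffer.sample(min(256, len(buffer)))["reward"] for _ in range(10)])
print(f"reward mean={rewards.mean():.3f}, std={rewards.std():.3f}")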