Enhance SAC configuration and replay buffer with asynchronous prefetching support

- Added an `async_prefetch` parameter to `SACConfig` for improved buffer management.
- Implemented a `get_iterator` method in `ReplayBuffer` to support asynchronous prefetching of batches.
- Updated `learner_server` to use the new iterators for online and offline sampling, improving training efficiency.
This commit is contained in:
parent 51f1625c20
commit 38a8dbd9c9
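In practice, the change replaces per-step `replay_buffer.sample(batch_size)` calls in the learner with a single iterator created up front and consumed with `next()`. A minimal sketch of that consumer-side pattern follows; the `get_iterator` keyword arguments come from this commit, while `replay_buffer`, `batch_size`, `num_updates`, and `update_policy` are placeholders:

```python
# Sketch only: placeholder names, not code from this commit.
online_iterator = replay_buffer.get_iterator(
    batch_size=batch_size,
    async_prefetch=True,  # sample batches in a background thread
    queue_size=2,         # keep up to two batches ready ahead of the consumer
)

for _ in range(num_updates):
    batch = next(online_iterator)  # iterator is infinite, so next() never raises StopIteration
    update_policy(batch)           # placeholder for the SAC update step
```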
@@ -42,8 +42,6 @@ class CriticNetworkConfig:
     final_activation: str | None = None
 
 
-
-
 @dataclass
 class ActorNetworkConfig:
     hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
@@ -94,6 +92,7 @@ class SACConfig(PreTrainedConfig):
         online_env_seed: Seed for the online environment.
         online_buffer_capacity: Capacity of the online replay buffer.
         offline_buffer_capacity: Capacity of the offline replay buffer.
+        async_prefetch: Whether to use asynchronous prefetching for the buffers.
         online_step_before_learning: Number of steps before learning starts.
         policy_update_freq: Frequency of policy updates.
        discount: Discount factor for the SAC algorithm.

@@ -154,6 +153,7 @@ class SACConfig(PreTrainedConfig):
     online_env_seed: int = 10000
     online_buffer_capacity: int = 100000
     offline_buffer_capacity: int = 100000
+    async_prefetch: bool = False
     online_step_before_learning: int = 100
     policy_update_freq: int = 1
 
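Because the new field defaults to `False`, asynchronous prefetching is opt-in. A minimal sketch of enabling it when constructing the config in Python; the import path is assumed from lerobot's usual layout and may differ in this branch, and the capacity values simply repeat the defaults shown above:

```python
# Sketch only: import path assumed, not verified against this branch.
from lerobot.common.policies.sac.configuration_sac import SACConfig

cfg = SACConfig(
    online_buffer_capacity=100000,
    offline_buffer_capacity=100000,
    async_prefetch=True,  # the default introduced by this commit is False
)
```

The same flag can presumably also be set through the training configuration (it is read as `cfg.policy.async_prefetch` in the learner below), though the exact override syntax is not shown in this diff.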
@@ -345,6 +345,109 @@ class ReplayBuffer:
             truncated=batch_truncateds,
         )
 
+    def get_iterator(
+        self,
+        batch_size: int,
+        async_prefetch: bool = True,
+        queue_size: int = 2,
+    ):
+        """
+        Creates an infinite iterator that yields batches of transitions.
+        Will automatically restart when internal iterator is exhausted.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
+            queue_size (int): Number of batches to prefetch (default: 2)
+
+        Yields:
+            BatchTransition: Batched transitions
+        """
+        while True:  # Create an infinite loop
+            if async_prefetch:
+                # Get the standard iterator
+                iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
+            else:
+                iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
+
+            # Yield all items from the iterator
+            try:
+                yield from iterator
+            except StopIteration:
+                # Just continue the outer loop to create a new iterator
+                pass
+
+    def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
+        """
+        Creates an iterator that prefetches batches in a background thread.
+
+        Args:
+            queue_size (int): Number of batches to prefetch (default: 2)
+            batch_size (int): Size of batches to sample (default: 128)
+
+        Yields:
+            BatchTransition: Prefetched batch transitions
+        """
+        import threading
+        import queue
+
+        # Use thread-safe queue
+        data_queue = queue.Queue(maxsize=queue_size)
+        running = [True]  # Use list to allow modification in nested function
+
+        def prefetch_worker():
+            while running[0]:
+                try:
+                    # Sample data and add to queue
+                    data = self.sample(batch_size)
+                    data_queue.put(data, block=True, timeout=0.5)
+                except queue.Full:
+                    continue
+                except Exception as e:
+                    print(f"Prefetch error: {e}")
+                    break
+
+        # Start prefetching thread
+        thread = threading.Thread(target=prefetch_worker, daemon=True)
+        thread.start()
+
+        try:
+            while running[0]:
+                try:
+                    yield data_queue.get(block=True, timeout=0.5)
+                except queue.Empty:
+                    if not thread.is_alive():
+                        break
+        finally:
+            # Clean up
+            running[0] = False
+            thread.join(timeout=1.0)
+
+    def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
+        """
+        Creates a simple non-threaded iterator that yields batches.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            queue_size (int): Number of initial batches to prefetch
+
+        Yields:
+            BatchTransition: Batch transitions
+        """
+        import collections
+
+        queue = collections.deque()
+
+        def enqueue(n):
+            for _ in range(n):
+                data = self.sample(batch_size)
+                queue.append(data)
+
+        enqueue(queue_size)
+        while queue:
+            yield queue.popleft()
+            enqueue(1)
+
     @classmethod
     def from_lerobot_dataset(
         cls,
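The asynchronous variant above pairs a bounded `queue.Queue` with a daemon worker thread: the small `maxsize` caps how far sampling can run ahead of the consumer, the short put/get timeouts let both sides notice shutdown, and the `finally` block stops the worker and joins it when the iterator is discarded. Below is a self-contained sketch of the same pattern with a stand-in sampler instead of the real `ReplayBuffer`; `DummyBuffer` and `prefetch_batches` are illustrative names, not part of the repository:

```python
# Standalone illustration of the prefetching pattern used by _get_async_iterator.
import queue
import random
import threading


class DummyBuffer:
    """Stand-in for ReplayBuffer: sampling just returns random floats."""

    def sample(self, batch_size):
        return [random.random() for _ in range(batch_size)]


def prefetch_batches(buffer, batch_size, queue_size=2):
    """Yield batches sampled in a background thread (bounded queue = backpressure)."""
    data_queue = queue.Queue(maxsize=queue_size)
    running = [True]  # mutable flag shared with the worker

    def worker():
        while running[0]:
            try:
                data_queue.put(buffer.sample(batch_size), block=True, timeout=0.5)
            except queue.Full:
                continue  # consumer is behind; try again

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    try:
        while True:
            yield data_queue.get(block=True, timeout=5.0)
    finally:
        running[0] = False
        thread.join(timeout=1.0)


if __name__ == "__main__":
    it = prefetch_batches(DummyBuffer(), batch_size=4)
    for _ in range(3):
        print(next(it))  # batches keep arriving while the consumer works
```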
@@ -710,475 +813,4 @@ def concatenate_batch_transitions(
 
 
 if __name__ == "__main__":
-    from tempfile import TemporaryDirectory
+    pass  # All test code is currently commented out
-
-    # ===== Test 1: Create and use a synthetic ReplayBuffer =====
-    print("Testing synthetic ReplayBuffer...")
-
-    # Create sample data dimensions
-    batch_size = 32
-    state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
-    action_dim = (6,)
-
-    # Create a buffer
-    buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=list(state_dims.keys()),
-        use_drq=True,
-        storage_device="cpu",
-    )
-
-    # Add some random transitions
-    for i in range(100):
-        # Create dummy transition data
-        state = {
-            "observation.image": torch.rand(1, 3, 84, 84),
-            "observation.state": torch.rand(1, 10),
-        }
-        action = torch.rand(1, 6)
-        reward = 0.5
-        next_state = {
-            "observation.image": torch.rand(1, 3, 84, 84),
-            "observation.state": torch.rand(1, 10),
-        }
-        done = False if i < 99 else True
-        truncated = False
-
-        buffer.add(
-            state=state,
-            action=action,
-            reward=reward,
-            next_state=next_state,
-            done=done,
-            truncated=truncated,
-        )
-
-    # Test sampling
-    batch = buffer.sample(batch_size)
-    print(f"Buffer size: {len(buffer)}")
-    print(
-        f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
-    )
-    print(f"Sampled batch action shape: {batch['action'].shape}")
-    print(f"Sampled batch reward shape: {batch['reward'].shape}")
-    print(f"Sampled batch done shape: {batch['done'].shape}")
-    print(f"Sampled batch truncated shape: {batch['truncated'].shape}")
-
-    # ===== Test for state-action-reward alignment =====
-    print("\nTesting state-action-reward alignment...")
-
-    # Create a buffer with controlled transitions where we know the relationships
-    aligned_buffer = ReplayBuffer(
-        capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
-    )
-
-    # Create transitions with known relationships
-    # - Each state has a unique signature value
-    # - Action is 2x the state signature
-    # - Reward is 3x the state signature
-    # - Next state is signature + 0.01 (unless at episode end)
-    for i in range(100):
-        # Create a state with a signature value that encodes the transition number
-        signature = float(i) / 100.0
-        state = {"state_value": torch.tensor([[signature]]).float()}
-
-        # Action is 2x the signature
-        action = torch.tensor([[2.0 * signature]]).float()
-
-        # Reward is 3x the signature
-        reward = 3.0 * signature
-
-        # Next state is signature + 0.01, unless end of episode
-        # End episode every 10 steps
-        is_end = (i + 1) % 10 == 0
-
-        if is_end:
-            # At episode boundaries, next_state repeats current state (as per your implementation)
-            next_state = {"state_value": torch.tensor([[signature]]).float()}
-            done = True
-        else:
-            # Within episodes, next_state has signature + 0.01
-            next_signature = float(i + 1) / 100.0
-            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
-            done = False
-
-        aligned_buffer.add(state, action, reward, next_state, done, False)
-
-    # Sample from this buffer
-    aligned_batch = aligned_buffer.sample(50)
-
-    # Verify alignments in sampled batch
-    correct_relationships = 0
-    total_checks = 0
-
-    # For each transition in the batch
-    for i in range(50):
-        # Extract signature from state
-        state_sig = aligned_batch["state"]["state_value"][i].item()
-
-        # Check action is 2x signature (within reasonable precision)
-        action_val = aligned_batch["action"][i].item()
-        action_check = abs(action_val - 2.0 * state_sig) < 1e-4
-
-        # Check reward is 3x signature (within reasonable precision)
-        reward_val = aligned_batch["reward"][i].item()
-        reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4
-
-        # Check next_state relationship matches our pattern
-        next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
-        is_done = aligned_batch["done"][i].item() > 0.5
-
-        # Calculate expected next_state value based on done flag
-        if is_done:
-            # For episodes that end, next_state should equal state
-            next_state_check = abs(next_state_sig - state_sig) < 1e-4
-        else:
-            # For continuing episodes, check if next_state is approximately state + 0.01
-            # We need to be careful because we don't know the original index
-            # So we check if the increment is roughly 0.01
-            next_state_check = (
-                abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
-            )
-
-        # Count correct relationships
-        if action_check:
-            correct_relationships += 1
-        if reward_check:
-            correct_relationships += 1
-        if next_state_check:
-            correct_relationships += 1
-
-        total_checks += 3
-
-    alignment_accuracy = 100.0 * correct_relationships / total_checks
-    print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
-    if alignment_accuracy > 99.0:
-        print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
-    else:
-        print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")
-
-        # Print some debug information about failures
-        print("\nDebug information for failed checks:")
-        for i in range(5):  # Print first 5 transitions for debugging
-            state_sig = aligned_batch["state"]["state_value"][i].item()
-            action_val = aligned_batch["action"][i].item()
-            reward_val = aligned_batch["reward"][i].item()
-            next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
-            is_done = aligned_batch["done"][i].item() > 0.5
-
-            print(f"Transition {i}:")
-            print(f"  State: {state_sig:.6f}")
-            print(f"  Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})")
-            print(f"  Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})")
-            print(f"  Done: {is_done}")
-            print(f"  Next state: {next_state_sig:.6f}")
-
-            # Calculate expected next state
-            if is_done:
-                expected_next = state_sig
-            else:
-                # This approximation might not be perfect
-                state_idx = round(state_sig * 100)
-                expected_next = (state_idx + 1) / 100.0
-
-            print(f"  Expected next state: {expected_next:.6f}")
-            print()
-
-    # ===== Test 2: Convert to LeRobotDataset and back =====
-    with TemporaryDirectory() as temp_dir:
-        print("\nTesting conversion to LeRobotDataset and back...")
-        # Convert buffer to dataset
-        repo_id = "test/replay_buffer_conversion"
-        # Create a subdirectory to avoid the "directory exists" error
-        dataset_dir = os.path.join(temp_dir, "dataset1")
-        dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir)
-
-        print(f"Dataset created with {len(dataset)} frames")
-        print(f"Dataset features: {list(dataset.features.keys())}")
-
-        # Check a random sample from the dataset
-        sample = dataset[0]
-        print(
-            f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}"
-        )
-
-        # Convert dataset back to buffer
-        reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
-            dataset, state_keys=list(state_dims.keys()), device="cpu"
-        )
-
-        print(f"Reconverted buffer size: {len(reconverted_buffer)}")
-
-        # Sample from the reconverted buffer
-        reconverted_batch = reconverted_buffer.sample(batch_size)
-        print(
-            f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}"
-        )
-
-        # Verify consistency before and after conversion
-        original_states = batch["state"]["observation.image"].mean().item()
-        reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
-        print(f"Original buffer state mean: {original_states:.4f}")
-        print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
-
-        if abs(original_states - reconverted_states) < 1.0:
-            print("Values are reasonably similar - conversion works as expected")
-        else:
-            print("WARNING: Significant difference between original and reconverted values")
-
-    print("\nAll previous tests completed!")
-
-    # ===== Test for memory optimization =====
-    print("\n===== Testing Memory Optimization =====")
-
-    # Create two buffers, one with memory optimization and one without
-    standard_buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=["observation.image", "observation.state"],
-        storage_device="cpu",
-        optimize_memory=False,
-        use_drq=True,
-    )
-
-    optimized_buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=["observation.image", "observation.state"],
-        storage_device="cpu",
-        optimize_memory=True,
-        use_drq=True,
-    )
-
-    # Generate sample data with larger state dimensions for better memory impact
-    print("Generating test data...")
-    num_episodes = 10
-    steps_per_episode = 50
-    total_steps = num_episodes * steps_per_episode
-
-    for episode in range(num_episodes):
-        for step in range(steps_per_episode):
-            # Index in the overall sequence
-            i = episode * steps_per_episode + step
-
-            # Create state with identifiable values
-            img = torch.ones((3, 84, 84)) * (i / total_steps)
-            state_vec = torch.ones((10,)) * (i / total_steps)
-
-            state = {
-                "observation.image": img.unsqueeze(0),
-                "observation.state": state_vec.unsqueeze(0),
-            }
-
-            # Create next state (i+1 or same as current if last in episode)
-            is_last_step = step == steps_per_episode - 1
-
-            if is_last_step:
-                # At episode end, next state = current state
-                next_img = img.clone()
-                next_state_vec = state_vec.clone()
-                done = True
-                truncated = False
-            else:
-                # Within episode, next state has incremented value
-                next_val = (i + 1) / total_steps
-                next_img = torch.ones((3, 84, 84)) * next_val
-                next_state_vec = torch.ones((10,)) * next_val
-                done = False
-                truncated = False
-
-            next_state = {
-                "observation.image": next_img.unsqueeze(0),
-                "observation.state": next_state_vec.unsqueeze(0),
-            }
-
-            # Action and reward
-            action = torch.tensor([[i / total_steps]])
-            reward = float(i / total_steps)
-
-            # Add to both buffers
-            standard_buffer.add(state, action, reward, next_state, done, truncated)
-            optimized_buffer.add(state, action, reward, next_state, done, truncated)
-
-    # Verify episode boundaries with our simplified approach
-    print("\nVerifying simplified memory optimization...")
-
-    # Test with a new buffer with a small sequence
-    test_buffer = ReplayBuffer(
-        capacity=20,
-        device="cpu",
-        state_keys=["value"],
-        storage_device="cpu",
-        optimize_memory=True,
-        use_drq=False,
-    )
-
-    # Add a simple sequence with known episode boundaries
-    for i in range(20):
-        val = float(i)
-        state = {"value": torch.tensor([[val]]).float()}
-        next_val = float(i + 1) if i % 5 != 4 else val  # Episode ends every 5 steps
-        next_state = {"value": torch.tensor([[next_val]]).float()}
-
-        # Set done=True at every 5th step
-        done = (i % 5) == 4
-        action = torch.tensor([[0.0]])
-        reward = 1.0
-        truncated = False
-
-        test_buffer.add(state, action, reward, next_state, done, truncated)
-
-    # Get sequential batch for verification
-    sequential_batch_size = test_buffer.size
-    all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
-
-    # Get state tensors
-    batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}
-
-    # Get next_state using memory-optimized approach (simply index+1)
-    next_indices = (all_indices + 1) % test_buffer.capacity
-    batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}
-
-    # Get other tensors
-    batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
-
-    # Print sequential values
-    print("State, Next State, Done (Sequential values with simplified optimization):")
-    state_values = batch_state["value"].squeeze().tolist()
-    next_values = batch_next_state["value"].squeeze().tolist()
-    done_flags = batch_dones.tolist()
-
-    # Print all values
-    for i in range(len(state_values)):
-        print(f"  {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}")
-
-    # Explain the memory optimization tradeoff
-    print("\nWith simplified memory optimization:")
-    print("- We always use the next state in the buffer (index+1) as next_state")
-    print("- For terminal states, this means using the first state of the next episode")
-    print("- This is a common tradeoff in RL implementations for memory efficiency")
-    print("- Since we track done flags, the algorithm can handle these transitions correctly")
-
-    # Test random sampling
-    print("\nVerifying random sampling with simplified memory optimization...")
-    random_samples = test_buffer.sample(20)  # Sample all transitions
-
-    # Extract values
-    random_state_values = random_samples["state"]["value"].squeeze().tolist()
-    random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
-    random_done_flags = random_samples["done"].bool().tolist()
-
-    # Print a few samples
-    print("Random samples - State, Next State, Done (First 10):")
-    for i in range(10):
-        print(f"  {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}")
-
-    # Calculate memory savings
-    # Assume optimized_buffer and standard_buffer have already been initialized and filled
-    std_mem = (
-        sum(
-            standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size()
-            for key in standard_buffer.states
-        )
-        * 2
-    )
-    opt_mem = sum(
-        optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size()
-        for key in optimized_buffer.states
-    )
-
-    savings_percent = (std_mem - opt_mem) / std_mem * 100
-
-    print("\nMemory optimization result:")
-    print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
-    print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
-    print(f"- Memory savings for state tensors: {savings_percent:.1f}%")
-
-    print("\nAll memory optimization tests completed!")
-
-    # # ===== Test real dataset conversion =====
-    # print("\n===== Testing Real LeRobotDataset Conversion =====")
-    # try:
-    #     # Try to use a real dataset if available
-    #     dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
-    #     dataset = LeRobotDataset(repo_id=dataset_name)
-
-    #     # Print available keys to debug
-    #     sample = dataset[0]
-    #     print("Available keys in dataset:", list(sample.keys()))

-    #     # Check for required keys
-    #     if "action" not in sample or "next.reward" not in sample:
-    #         print("Dataset missing essential keys. Cannot convert.")
-    #         raise ValueError("Missing required keys in dataset")
-
-    #     # Auto-detect appropriate state keys
-    #     image_keys = []
-    #     state_keys = []
-    #     for k, v in sample.items():
-    #         # Skip metadata keys and action/reward keys
-    #         if k in {
-    #             "index",
-    #             "episode_index",
-    #             "frame_index",
-    #             "timestamp",
-    #             "task_index",
-    #             "action",
-    #             "next.reward",
-    #             "next.done",
-    #         }:
-    #             continue
-
-    #         # Infer key type from tensor shape
-    #         if isinstance(v, torch.Tensor):
-    #             if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
-    #                 # Likely an image (channels, height, width)
-    #                 image_keys.append(k)
-    #             else:
-    #                 # Likely state or other vector
-    #                 state_keys.append(k)
-
-    #     print(f"Detected image keys: {image_keys}")
-    #     print(f"Detected state keys: {state_keys}")
-
-    #     if not image_keys and not state_keys:
-    #         print("No usable keys found in dataset, skipping further tests")
-    #         raise ValueError("No usable keys found in dataset")
-
-    #     # Test with standard and memory-optimized buffers
-    #     for optimize_memory in [False, True]:
-    #         buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
-    #         print(f"\nTesting {buffer_type} buffer with real dataset...")
-
-    #         # Convert to ReplayBuffer with detected keys
-    #         replay_buffer = ReplayBuffer.from_lerobot_dataset(
-    #             lerobot_dataset=dataset,
-    #             state_keys=image_keys + state_keys,
-    #             device="cpu",
-    #             optimize_memory=optimize_memory,
-    #         )
-    #         print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
-
-    #         # Test sampling
-    #         real_batch = replay_buffer.sample(32)
-    #         print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
-    #         for key in real_batch["state"]:
-    #             print(f"  {key}: {real_batch['state'][key].shape}")
-
-    #         # Convert back to LeRobotDataset
-    #         with TemporaryDirectory() as temp_dir:
-    #             dataset_name = f"test/real_dataset_converted_{buffer_type}"
-    #             replay_buffer_converted = replay_buffer.to_lerobot_dataset(
-    #                 repo_id=dataset_name,
-    #                 root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
-    #             )
-    #             print(
-    #                 f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
-    #             )
-
-    # except Exception as e:
-    #     print(f"Real dataset test failed: {e}")
-    #     print("This is expected if running offline or if the dataset is not available.")
-
-    # print("\nAll tests completed!")
@@ -269,6 +269,7 @@ def add_actor_information_and_train(
     policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
     saving_checkpoint = cfg.save_checkpoint
     online_steps = cfg.policy.online_steps
+    async_prefetch = cfg.policy.async_prefetch
 
     # Initialize logging for multiprocessing
     if not use_threads(cfg):

@@ -326,6 +327,9 @@ def add_actor_information_and_train(
     if cfg.dataset is not None:
         dataset_repo_id = cfg.dataset.repo_id
 
+    # Initialize iterators
+    online_iterator = None
+    offline_iterator = None
     # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
     while True:
         # Exit the training loop if shutdown is requested

@@ -359,13 +363,26 @@ def add_actor_information_and_train(
         if len(replay_buffer) < online_step_before_learning:
             continue
 
+        if online_iterator is None:
+            logging.debug("[LEARNER] Initializing online replay buffer iterator")
+            online_iterator = replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
+        if offline_replay_buffer is not None and offline_iterator is None:
+            logging.debug("[LEARNER] Initializing offline replay buffer iterator")
+            offline_iterator = offline_replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
         logging.debug("[LEARNER] Starting optimization loop")
         time_for_one_optimization_step = time.time()
         for _ in range(utd_ratio - 1):
-            batch = replay_buffer.sample(batch_size=batch_size)
+            # Sample from the iterators
+            batch = next(online_iterator)
 
             if dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
+                batch_offline = next(offline_iterator)
                 batch = concatenate_batch_transitions(
                     left_batch_transitions=batch, right_batch_transition=batch_offline
                 )

@@ -418,10 +435,11 @@ def add_actor_information_and_train(
         # Update target networks
         policy.update_target_networks()
 
-        batch = replay_buffer.sample(batch_size=batch_size)
+        # Sample for the last update in the UTD ratio
+        batch = next(online_iterator)
 
        if dataset_repo_id is not None:
-            batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
+            batch_offline = next(offline_iterator)
             batch = concatenate_batch_transitions(
                 left_batch_transitions=batch, right_batch_transition=batch_offline
             )
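Two behavioral notes on the learner change: the iterators are created lazily, only after the online buffer has passed the `online_step_before_learning` warm-up check, so the prefetch thread never samples from an under-filled buffer; and because a prefetched batch comes from a queue of up to `queue_size` pre-sampled batches, it can slightly lag the most recently added transitions. Setting `async_prefetch=False` falls back to `_get_naive_iterator`, which samples on the calling thread. A minimal sketch of toggling this at iterator-creation time (placeholder names):

```python
# Sketch only: placeholder names around the get_iterator() call shown above.
use_background_prefetch = False  # e.g. disable while debugging sampling issues
online_iterator = replay_buffer.get_iterator(
    batch_size=batch_size,
    async_prefetch=use_background_prefetch,
    queue_size=2,
)
batch = next(online_iterator)
```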