diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index ae38b1c5..3d01f47c 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -42,8 +42,6 @@ class CriticNetworkConfig:
     final_activation: str | None = None
 
 
-
-
 @dataclass
 class ActorNetworkConfig:
     hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
@@ -94,6 +92,7 @@ class SACConfig(PreTrainedConfig):
         online_env_seed: Seed for the online environment.
        online_buffer_capacity: Capacity of the online replay buffer.
        offline_buffer_capacity: Capacity of the offline replay buffer.
+        async_prefetch: Whether to use asynchronous prefetching for the buffers.
        online_step_before_learning: Number of steps before learning starts.
        policy_update_freq: Frequency of policy updates.
        discount: Discount factor for the SAC algorithm.
@@ -154,6 +153,7 @@ class SACConfig(PreTrainedConfig):
     online_env_seed: int = 10000
     online_buffer_capacity: int = 100000
     offline_buffer_capacity: int = 100000
+    async_prefetch: bool = False
     online_step_before_learning: int = 100
     policy_update_freq: int = 1
 
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 2af3995e..c8f85372 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -345,6 +345,109 @@ class ReplayBuffer:
             truncated=batch_truncateds,
         )
 
+    def get_iterator(
+        self,
+        batch_size: int,
+        async_prefetch: bool = True,
+        queue_size: int = 2,
+    ):
+        """
+        Creates an infinite iterator that yields batches of transitions.
+        Will automatically restart when the internal iterator is exhausted.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
+            queue_size (int): Number of batches to prefetch (default: 2)
+
+        Yields:
+            BatchTransition: Batched transitions
+        """
+        while True:  # Create an infinite loop
+            if async_prefetch:
+                # Get the asynchronous, thread-based iterator
+                iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
+            else:
+                iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
+
+            # Yield all items from the iterator
+            try:
+                yield from iterator
+            except StopIteration:
+                # Just continue the outer loop to create a new iterator
+                pass
+
+    def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
+        """
+        Creates an iterator that prefetches batches in a background thread.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            queue_size (int): Number of batches to prefetch (default: 2)
+
+        Yields:
+            BatchTransition: Prefetched batch transitions
+        """
+        import threading
+        import queue
+
+        # Use thread-safe queue
+        data_queue = queue.Queue(maxsize=queue_size)
+        running = [True]  # Use list to allow modification in nested function
+
+        def prefetch_worker():
+            while running[0]:
+                try:
+                    # Sample data and add to queue
+                    data = self.sample(batch_size)
+                    data_queue.put(data, block=True, timeout=0.5)
+                except queue.Full:
+                    continue
+                except Exception as e:
+                    print(f"Prefetch error: {e}")
+                    break
+
+        # Start prefetching thread
+        thread = threading.Thread(target=prefetch_worker, daemon=True)
+        thread.start()
+
+        try:
+            while running[0]:
+                try:
+                    yield data_queue.get(block=True, timeout=0.5)
+                except queue.Empty:
+                    if not thread.is_alive():
+                        break
+        finally:
+            # Clean up
+            running[0] = False
+            thread.join(timeout=1.0)
+
+    def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
+        """
+        Creates a simple non-threaded iterator that yields batches.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            queue_size (int): Number of initial batches to prefetch
+
+        Yields:
+            BatchTransition: Batch transitions
+        """
+        import collections
+
+        queue = collections.deque()
+
+        def enqueue(n):
+            for _ in range(n):
+                data = self.sample(batch_size)
+                queue.append(data)
+
+        enqueue(queue_size)
+        while queue:
+            yield queue.popleft()
+            enqueue(1)
+
     @classmethod
     def from_lerobot_dataset(
         cls,
@@ -710,475 +813,4 @@ def concatenate_batch_transitions(
 
 
 if __name__ == "__main__":
-    from tempfile import TemporaryDirectory
-
-    # ===== Test 1: Create and use a synthetic ReplayBuffer =====
-    print("Testing synthetic ReplayBuffer...")
-
-    # Create sample data dimensions
-    batch_size = 32
-    state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
-    action_dim = (6,)
-
-    # Create a buffer
-    buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=list(state_dims.keys()),
-        use_drq=True,
-        storage_device="cpu",
-    )
-
-    # Add some random transitions
-    for i in range(100):
-        # Create dummy transition data
-        state = {
-            "observation.image": torch.rand(1, 3, 84, 84),
-            "observation.state": torch.rand(1, 10),
-        }
-        action = torch.rand(1, 6)
-        reward = 0.5
-        next_state = {
-            "observation.image": torch.rand(1, 3, 84, 84),
-            "observation.state": torch.rand(1, 10),
-        }
-        done = False if i < 99 else True
-        truncated = False
-
-        buffer.add(
-            state=state,
-            action=action,
-            reward=reward,
-            next_state=next_state,
-            done=done,
-            truncated=truncated,
-        )
-
-    # Test sampling
-    batch = buffer.sample(batch_size)
-    print(f"Buffer size: {len(buffer)}")
-    print(
-        f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
-    )
-    print(f"Sampled batch action shape: {batch['action'].shape}")
-    print(f"Sampled batch reward shape: {batch['reward'].shape}")
-    print(f"Sampled batch done shape: {batch['done'].shape}")
-    print(f"Sampled batch truncated shape: {batch['truncated'].shape}")
-
-    # ===== Test for state-action-reward alignment =====
-    print("\nTesting state-action-reward alignment...")
-
-    # Create a buffer with controlled transitions where we know the relationships
-    aligned_buffer = ReplayBuffer(
-        capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
-    )
-
-    # Create transitions with known relationships
-    # - Each state has a unique signature value
-    # - Action is 2x the state signature
-    # - Reward is 3x the state signature
-    # - Next state is signature + 0.01 (unless at episode end)
-    for i in range(100):
-        # Create a state with a signature value that encodes the transition number
-        signature = float(i) / 100.0
-        state = {"state_value": torch.tensor([[signature]]).float()}
-
-        # Action is 2x the signature
-        action = torch.tensor([[2.0 * signature]]).float()
-
-        # Reward is 3x the signature
-        reward = 3.0 * signature
-
-        # Next state is signature + 0.01, unless end of episode
-        # End episode every 10 steps
-        is_end = (i + 1) % 10 == 0
-
-        if is_end:
-            # At episode boundaries, next_state repeats current state (as per your implementation)
-            next_state = {"state_value": torch.tensor([[signature]]).float()}
-            done = True
-        else:
-            # Within episodes, next_state has signature + 0.01
-            next_signature = float(i + 1) / 100.0
-            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
-            done = False
-
-        aligned_buffer.add(state, action, reward, next_state, done, False)
-
-    # Sample from this buffer
-    aligned_batch = aligned_buffer.sample(50)
-
-    # Verify alignments in sampled batch
-    correct_relationships = 0
-    total_checks = 0
-
-    # For each transition in the batch
-    for i in range(50):
-        # Extract signature from state
-        state_sig = aligned_batch["state"]["state_value"][i].item()
-
-        # Check action is 2x signature (within reasonable precision)
-        action_val = aligned_batch["action"][i].item()
-        action_check = abs(action_val - 2.0 * state_sig) < 1e-4
-
-        # Check reward is 3x signature (within reasonable precision)
-        reward_val = aligned_batch["reward"][i].item()
-        reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4
-
-        # Check next_state relationship matches our pattern
-        next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
-        is_done = aligned_batch["done"][i].item() > 0.5
-
-        # Calculate expected next_state value based on done flag
-        if is_done:
-            # For episodes that end, next_state should equal state
-            next_state_check = abs(next_state_sig - state_sig) < 1e-4
-        else:
-            # For continuing episodes, check if next_state is approximately state + 0.01
-            # We need to be careful because we don't know the original index
-            # So we check if the increment is roughly 0.01
-            next_state_check = (
-                abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
-            )
-
-        # Count correct relationships
-        if action_check:
-            correct_relationships += 1
-        if reward_check:
-            correct_relationships += 1
-        if next_state_check:
-            correct_relationships += 1
-
-        total_checks += 3
-
-    alignment_accuracy = 100.0 * correct_relationships / total_checks
-    print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
-    if alignment_accuracy > 99.0:
-        print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
-    else:
-        print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")
-
-        # Print some debug information about failures
-        print("\nDebug information for failed checks:")
-        for i in range(5):  # Print first 5 transitions for debugging
-            state_sig = aligned_batch["state"]["state_value"][i].item()
-            action_val = aligned_batch["action"][i].item()
-            reward_val = aligned_batch["reward"][i].item()
-            next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
-            is_done = aligned_batch["done"][i].item() > 0.5
-
-            print(f"Transition {i}:")
-            print(f"  State: {state_sig:.6f}")
-            print(f"  Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})")
-            print(f"  Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})")
-            print(f"  Done: {is_done}")
-            print(f"  Next state: {next_state_sig:.6f}")
-
-            # Calculate expected next state
-            if is_done:
-                expected_next = state_sig
-            else:
-                # This approximation might not be perfect
-                state_idx = round(state_sig * 100)
-                expected_next = (state_idx + 1) / 100.0
-
-            print(f"  Expected next state: {expected_next:.6f}")
-            print()
-
-    # ===== Test 2: Convert to LeRobotDataset and back =====
-    with TemporaryDirectory() as temp_dir:
-        print("\nTesting conversion to LeRobotDataset and back...")
-        # Convert buffer to dataset
-        repo_id = "test/replay_buffer_conversion"
-        # Create a subdirectory to avoid the "directory exists" error
-        dataset_dir = os.path.join(temp_dir, "dataset1")
-        dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir)
-
-        print(f"Dataset created with {len(dataset)} frames")
-        print(f"Dataset features: {list(dataset.features.keys())}")
-
-        # Check a random sample from the dataset
-        sample = dataset[0]
-        print(
-            f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}"
-        )
-
-        # Convert dataset back to buffer
-        reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
-            dataset, state_keys=list(state_dims.keys()), device="cpu"
-        )
-
-        print(f"Reconverted buffer size: {len(reconverted_buffer)}")
-
-        # Sample from the reconverted buffer
-        reconverted_batch = reconverted_buffer.sample(batch_size)
-        print(
-            f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}"
-        )
-
-        # Verify consistency before and after conversion
-        original_states = batch["state"]["observation.image"].mean().item()
-        reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
-        print(f"Original buffer state mean: {original_states:.4f}")
-        print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
-
-        if abs(original_states - reconverted_states) < 1.0:
-            print("Values are reasonably similar - conversion works as expected")
-        else:
-            print("WARNING: Significant difference between original and reconverted values")
-
-    print("\nAll previous tests completed!")
-
-    # ===== Test for memory optimization =====
-    print("\n===== Testing Memory Optimization =====")
-
-    # Create two buffers, one with memory optimization and one without
-    standard_buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=["observation.image", "observation.state"],
-        storage_device="cpu",
-        optimize_memory=False,
-        use_drq=True,
-    )
-
-    optimized_buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=["observation.image", "observation.state"],
-        storage_device="cpu",
-        optimize_memory=True,
-        use_drq=True,
-    )
-
-    # Generate sample data with larger state dimensions for better memory impact
-    print("Generating test data...")
-    num_episodes = 10
-    steps_per_episode = 50
-    total_steps = num_episodes * steps_per_episode
-
-    for episode in range(num_episodes):
-        for step in range(steps_per_episode):
-            # Index in the overall sequence
-            i = episode * steps_per_episode + step
-
-            # Create state with identifiable values
-            img = torch.ones((3, 84, 84)) * (i / total_steps)
-            state_vec = torch.ones((10,)) * (i / total_steps)
-
-            state = {
-                "observation.image": img.unsqueeze(0),
-                "observation.state": state_vec.unsqueeze(0),
-            }
-
-            # Create next state (i+1 or same as current if last in episode)
-            is_last_step = step == steps_per_episode - 1
-
-            if is_last_step:
-                # At episode end, next state = current state
-                next_img = img.clone()
-                next_state_vec = state_vec.clone()
-                done = True
-                truncated = False
-            else:
-                # Within episode, next state has incremented value
-                next_val = (i + 1) / total_steps
-                next_img = torch.ones((3, 84, 84)) * next_val
-                next_state_vec = torch.ones((10,)) * next_val
-                done = False
-                truncated = False
-
-            next_state = {
-                "observation.image": next_img.unsqueeze(0),
-                "observation.state": next_state_vec.unsqueeze(0),
-            }
-
-            # Action and reward
-            action = torch.tensor([[i / total_steps]])
-            reward = float(i / total_steps)
-
-            # Add to both buffers
-            standard_buffer.add(state, action, reward, next_state, done, truncated)
-            optimized_buffer.add(state, action, reward, next_state, done, truncated)
-
-    # Verify episode boundaries with our simplified approach
-    print("\nVerifying simplified memory optimization...")
-
-    # Test with a new buffer with a small sequence
-    test_buffer = ReplayBuffer(
-        capacity=20,
-        device="cpu",
-        state_keys=["value"],
-        storage_device="cpu",
-        optimize_memory=True,
-        use_drq=False,
-    )
-
-    # Add a simple sequence with known episode boundaries
-    for i in range(20):
-        val = float(i)
-        state = {"value": torch.tensor([[val]]).float()}
-        next_val = float(i + 1) if i % 5 != 4 else val  # Episode ends every 5 steps
-        next_state = {"value": torch.tensor([[next_val]]).float()}
-
-        # Set done=True at every 5th step
-        done = (i % 5) == 4
-        action = torch.tensor([[0.0]])
-        reward = 1.0
-        truncated = False
-
-        test_buffer.add(state, action, reward, next_state, done, truncated)
-
-    # Get sequential batch for verification
-    sequential_batch_size = test_buffer.size
-    all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
-
-    # Get state tensors
-    batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}
-
-    # Get next_state using memory-optimized approach (simply index+1)
-    next_indices = (all_indices + 1) % test_buffer.capacity
-    batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}
-
-    # Get other tensors
-    batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
-
-    # Print sequential values
-    print("State, Next State, Done (Sequential values with simplified optimization):")
-    state_values = batch_state["value"].squeeze().tolist()
-    next_values = batch_next_state["value"].squeeze().tolist()
-    done_flags = batch_dones.tolist()
-
-    # Print all values
-    for i in range(len(state_values)):
-        print(f"  {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}")
-
-    # Explain the memory optimization tradeoff
-    print("\nWith simplified memory optimization:")
-    print("- We always use the next state in the buffer (index+1) as next_state")
-    print("- For terminal states, this means using the first state of the next episode")
-    print("- This is a common tradeoff in RL implementations for memory efficiency")
efficiency") - print("- Since we track done flags, the algorithm can handle these transitions correctly") - - # Test random sampling - print("\nVerifying random sampling with simplified memory optimization...") - random_samples = test_buffer.sample(20) # Sample all transitions - - # Extract values - random_state_values = random_samples["state"]["value"].squeeze().tolist() - random_next_values = random_samples["next_state"]["value"].squeeze().tolist() - random_done_flags = random_samples["done"].bool().tolist() - - # Print a few samples - print("Random samples - State, Next State, Done (First 10):") - for i in range(10): - print(f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}") - - # Calculate memory savings - # Assume optimized_buffer and standard_buffer have already been initialized and filled - std_mem = ( - sum( - standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size() - for key in standard_buffer.states - ) - * 2 - ) - opt_mem = sum( - optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size() - for key in optimized_buffer.states - ) - - savings_percent = (std_mem - opt_mem) / std_mem * 100 - - print("\nMemory optimization result:") - print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB") - print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB") - print(f"- Memory savings for state tensors: {savings_percent:.1f}%") - - print("\nAll memory optimization tests completed!") - - # # ===== Test real dataset conversion ===== - # print("\n===== Testing Real LeRobotDataset Conversion =====") - # try: - # # Try to use a real dataset if available - # dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small" - # dataset = LeRobotDataset(repo_id=dataset_name) - - # # Print available keys to debug - # sample = dataset[0] - # print("Available keys in dataset:", list(sample.keys())) - - # # Check for required keys - # if "action" not in sample or "next.reward" not in sample: - # print("Dataset missing essential keys. 
Cannot convert.") - # raise ValueError("Missing required keys in dataset") - - # # Auto-detect appropriate state keys - # image_keys = [] - # state_keys = [] - # for k, v in sample.items(): - # # Skip metadata keys and action/reward keys - # if k in { - # "index", - # "episode_index", - # "frame_index", - # "timestamp", - # "task_index", - # "action", - # "next.reward", - # "next.done", - # }: - # continue - - # # Infer key type from tensor shape - # if isinstance(v, torch.Tensor): - # if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1): - # # Likely an image (channels, height, width) - # image_keys.append(k) - # else: - # # Likely state or other vector - # state_keys.append(k) - - # print(f"Detected image keys: {image_keys}") - # print(f"Detected state keys: {state_keys}") - - # if not image_keys and not state_keys: - # print("No usable keys found in dataset, skipping further tests") - # raise ValueError("No usable keys found in dataset") - - # # Test with standard and memory-optimized buffers - # for optimize_memory in [False, True]: - # buffer_type = "Standard" if not optimize_memory else "Memory-optimized" - # print(f"\nTesting {buffer_type} buffer with real dataset...") - - # # Convert to ReplayBuffer with detected keys - # replay_buffer = ReplayBuffer.from_lerobot_dataset( - # lerobot_dataset=dataset, - # state_keys=image_keys + state_keys, - # device="cpu", - # optimize_memory=optimize_memory, - # ) - # print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}") - - # # Test sampling - # real_batch = replay_buffer.sample(32) - # print(f"Sampled batch from real dataset ({buffer_type}), state shapes:") - # for key in real_batch["state"]: - # print(f" {key}: {real_batch['state'][key].shape}") - - # # Convert back to LeRobotDataset - # with TemporaryDirectory() as temp_dir: - # dataset_name = f"test/real_dataset_converted_{buffer_type}" - # replay_buffer_converted = replay_buffer.to_lerobot_dataset( - # repo_id=dataset_name, - # root=os.path.join(temp_dir, f"dataset_{buffer_type}"), - # ) - # print( - # f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames" - # ) - - # except Exception as e: - # print(f"Real dataset test failed: {e}") - # print("This is expected if running offline or if the dataset is not available.") - - # print("\nAll tests completed!") + pass # All test code is currently commented out diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index ce9a1b41..08baa6ea 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -269,6 +269,7 @@ def add_actor_information_and_train( policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency saving_checkpoint = cfg.save_checkpoint online_steps = cfg.policy.online_steps + async_prefetch = cfg.policy.async_prefetch # Initialize logging for multiprocessing if not use_threads(cfg): @@ -326,6 +327,9 @@ def add_actor_information_and_train( if cfg.dataset is not None: dataset_repo_id = cfg.dataset.repo_id + # Initialize iterators + online_iterator = None + offline_iterator = None # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER while True: # Exit the training loop if shutdown is requested @@ -359,16 +363,29 @@ def add_actor_information_and_train( if len(replay_buffer) < online_step_before_learning: continue + if online_iterator is None: + logging.debug("[LEARNER] Initializing online replay buffer iterator") + online_iterator = replay_buffer.get_iterator( + 
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
+        if offline_replay_buffer is not None and offline_iterator is None:
+            logging.debug("[LEARNER] Initializing offline replay buffer iterator")
+            offline_iterator = offline_replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
         logging.debug("[LEARNER] Starting optimization loop")
         time_for_one_optimization_step = time.time()
         for _ in range(utd_ratio - 1):
-            batch = replay_buffer.sample(batch_size=batch_size)
+            # Sample from the iterators
+            batch = next(online_iterator)
 
-            if dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
-                batch = concatenate_batch_transitions(
-                    left_batch_transitions=batch, right_batch_transition=batch_offline
-                )
+            if dataset_repo_id is not None:
+                batch_offline = next(offline_iterator)
+                batch = concatenate_batch_transitions(
+                    left_batch_transitions=batch, right_batch_transition=batch_offline
+                )
 
             actions = batch["action"]
             rewards = batch["reward"]
@@ -418,10 +435,11 @@ def add_actor_information_and_train(
         # Update target networks
         policy.update_target_networks()
 
-        batch = replay_buffer.sample(batch_size=batch_size)
+        # Sample for the last update in the UTD ratio
+        batch = next(online_iterator)
 
         if dataset_repo_id is not None:
-            batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
+            batch_offline = next(offline_iterator)
             batch = concatenate_batch_transitions(
                 left_batch_transitions=batch, right_batch_transition=batch_offline
             )
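
Note on the prefetching pattern added in ReplayBuffer.get_iterator: a daemon thread keeps a small bounded queue of pre-sampled batches so that sampling and collation overlap with the gradient step, while queue_size caps memory use and staleness. The snippet below is a minimal, self-contained sketch of that producer/consumer pattern, not code from this patch; DummyBuffer and its sample method are hypothetical stand-ins for the real ReplayBuffer.

import queue
import threading
import time


class DummyBuffer:
    """Hypothetical stand-in for ReplayBuffer; only the iterator pattern mirrors the patch."""

    def sample(self, batch_size: int) -> list[int]:
        time.sleep(0.01)  # pretend sampling and collation take a moment
        return list(range(batch_size))

    def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):
        if not async_prefetch:
            # Naive path: sample on demand, no background thread
            while True:
                yield self.sample(batch_size)

        # Async path: a daemon thread fills a bounded queue while the consumer drains it
        data_queue: queue.Queue = queue.Queue(maxsize=queue_size)
        running = [True]

        def worker():
            while running[0]:
                try:
                    data_queue.put(self.sample(batch_size), block=True, timeout=0.5)
                except queue.Full:
                    continue  # consumer is behind; retry so shutdown is noticed promptly

        threading.Thread(target=worker, daemon=True).start()
        try:
            while True:
                yield data_queue.get()
        finally:
            running[0] = False  # stop the worker once the generator is closed


if __name__ == "__main__":
    it = DummyBuffer().get_iterator(batch_size=4, async_prefetch=True, queue_size=2)
    for _ in range(3):
        print(next(it))  # batches keep arriving; no explicit sample() call per step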
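
On the learner side, the change is from calling replay_buffer.sample(batch_size) on every gradient step to pulling from a single infinite iterator with next(), created lazily once the buffer has warmed up. A compressed, runnable sketch of that control flow follows; make_batch_stream and the loop bounds are illustrative assumptions, not from the patch.

import itertools
from typing import Iterator


def make_batch_stream() -> Iterator[dict]:
    """Illustrative stand-in for replay_buffer.get_iterator(...): an endless stream of batches."""
    for i in itertools.count():
        yield {"action": i, "reward": float(i)}


online_iterator = None
utd_ratio = 4

for optimization_step in range(2):
    # Created lazily, mirroring the learner: only after the warm-up threshold is reached.
    if online_iterator is None:
        online_iterator = make_batch_stream()

    # Critic-only updates consume the first utd_ratio - 1 batches...
    for _ in range(utd_ratio - 1):
        batch = next(online_iterator)  # replaces replay_buffer.sample(batch_size=batch_size)

    # ...then one more batch is pulled for the final update of the UTD cycle.
    batch = next(online_iterator)
    print(optimization_step, batch)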