diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index de278582..905157f1 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -178,6 +178,7 @@ class ReplayBuffer:
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
+        optimize_memory: bool = False,
     ):
         """
         Args:
@@ -189,6 +190,8 @@ class ReplayBuffer:
             use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
             storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
                 Using "cpu" can help save GPU memory.
+            optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
+                they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
         """
         self.capacity = capacity
         self.device = device
@@ -196,6 +199,12 @@
         self.position = 0
         self.size = 0
         self.initialized = False
+        self.optimize_memory = optimize_memory
+
+        # Track episode boundaries for memory optimization
+        self.episode_ends = torch.zeros(
+            capacity, dtype=torch.bool, device=storage_device
+        )
 
         # If no state_keys provided, default to an empty list
         self.state_keys = state_keys if state_keys is not None else []
@@ -220,10 +229,18 @@
             (self.capacity, *action_shape), device=self.storage_device
         )
         self.rewards = torch.empty((self.capacity,), device=self.storage_device)
-        self.next_states = {
-            key: torch.empty((self.capacity, *shape), device=self.storage_device)
-            for key, shape in state_shapes.items()
-        }
+
+        if not self.optimize_memory:
+            # Standard approach: store states and next_states separately
+            self.next_states = {
+                key: torch.empty((self.capacity, *shape), device=self.storage_device)
+                for key, shape in state_shapes.items()
+            }
+        else:
+            # Memory-optimized approach: don't allocate next_states buffer
+            # Just create a reference to states for consistent API
+            self.next_states = self.states  # Just a reference for API consistency
+
         self.dones = torch.empty(
             (self.capacity,), dtype=torch.bool, device=self.storage_device
         )
@@ -253,7 +270,12 @@
         # Store the transition in pre-allocated tensors
         for key in self.states:
             self.states[key][self.position].copy_(state[key].squeeze(dim=0))
-            self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
+
+            if not self.optimize_memory:
+                # Only store next_states if not optimizing memory
+                self.next_states[key][self.position].copy_(
+                    next_state[key].squeeze(dim=0)
+                )
 
         self.actions[self.position].copy_(action.squeeze(dim=0))
         self.rewards[self.position] = reward
@@ -288,10 +310,17 @@
         batch_state = {}
         batch_next_state = {}
 
-        # First pass: load all tensors to target device
+        # First pass: load all state tensors to target device
         for key in self.states:
             batch_state[key] = self.states[key][idx].to(self.device)
-            batch_next_state[key] = self.next_states[key][idx].to(self.device)
+
+            if not self.optimize_memory:
+                # Standard approach - load next_states directly
+                batch_next_state[key] = self.next_states[key][idx].to(self.device)
+            else:
+                # Memory-optimized approach - get next_state from the next index
+                next_idx = (idx + 1) % self.capacity
+                batch_next_state[key] = self.states[key][next_idx].to(self.device)
 
         # Apply image augmentation in a batched way if needed
         if self.use_drq and image_keys:
@@ -343,6 +372,7 @@ class ReplayBuffer:
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
+        optimize_memory: bool = False,
     ) -> "ReplayBuffer":
         """
        Convert a LeRobotDataset into a ReplayBuffer.
@@ -358,6 +388,7 @@ class ReplayBuffer:
                 If None, uses default random shift with pad=4.
             use_drq (bool): Whether to use DrQ image augmentation when sampling.
             storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
+            optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
 
         Returns:
             ReplayBuffer: The replay buffer with dataset transitions.
@@ -378,6 +409,7 @@ class ReplayBuffer:
             image_augmentation_function=image_augmentation_function,
             use_drq=use_drq,
             storage_device=storage_device,
+            optimize_memory=optimize_memory,
         )
 
         # Convert dataset to transitions
@@ -934,78 +966,268 @@ if __name__ == "__main__":
                 "WARNING: Significant difference between original and reconverted values"
             )
 
-    print("\nTesting real LeRobotDataset conversion...")
-    try:
-        # Try to use a real dataset if available
-        dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
-        dataset = LeRobotDataset(repo_id=dataset_name)
+    print("\nAll previous tests completed!")
 
-        # Print available keys to debug
-        sample = dataset[0]
-        print("Available keys in first dataset:", list(sample.keys()))
+    # ===== Test for memory optimization =====
+    print("\n===== Testing Memory Optimization =====")
 
-        # Check for required keys
-        if "action" not in sample or "next.reward" not in sample:
-            print("Dataset missing essential keys. Cannot convert.")
-            raise ValueError("Missing required keys in dataset")
+    # Create two buffers, one with memory optimization and one without
+    standard_buffer = ReplayBuffer(
+        capacity=1000,
+        device="cpu",
+        state_keys=["observation.image", "observation.state"],
+        storage_device="cpu",
+        optimize_memory=False,
+        use_drq=True,
+    )
 
-        # Auto-detect appropriate state keys
-        image_keys = []
-        state_keys = []
-        for k, v in sample.items():
-            # Skip metadata keys and action/reward keys
-            if k in {
-                "index",
-                "episode_index",
-                "frame_index",
-                "timestamp",
-                "task_index",
-                "action",
-                "next.reward",
-                "next.done",
-            }:
-                continue
+    optimized_buffer = ReplayBuffer(
+        capacity=1000,
+        device="cpu",
+        state_keys=["observation.image", "observation.state"],
+        storage_device="cpu",
+        optimize_memory=True,
+        use_drq=True,
+    )
 
-            # Infer key type from tensor shape
-            if isinstance(v, torch.Tensor):
-                if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
-                    # Likely an image (channels, height, width)
-                    image_keys.append(k)
-                else:
-                    # Likely state or other vector
-                    state_keys.append(k)
+    # Generate sample data with larger state dimensions for better memory impact
+    print("Generating test data...")
+    num_episodes = 10
+    steps_per_episode = 50
+    total_steps = num_episodes * steps_per_episode
 
-        print(f"Detected image keys: {image_keys}")
-        print(f"Detected state keys: {state_keys}")
+    for episode in range(num_episodes):
+        for step in range(steps_per_episode):
+            # Index in the overall sequence
+            i = episode * steps_per_episode + step
 
-        if not image_keys and not state_keys:
-            print("No usable keys found in dataset, skipping further tests")
-            raise ValueError("No usable keys found in dataset")
+            # Create state with identifiable values
+            img = torch.ones((3, 84, 84)) * (i / total_steps)
+            state_vec = torch.ones((10,)) * (i / total_steps)
 
-        # Convert to ReplayBuffer with detected keys
-        replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            lerobot_dataset=dataset,
-            state_keys=image_keys + state_keys,
-            device="cpu",
+            state = {
+                "observation.image": img.unsqueeze(0),
+                "observation.state": state_vec.unsqueeze(0),
+            }
+
+            # Create next state (i+1 or same as current if last in episode)
+            is_last_step = step == steps_per_episode - 1
+
+            if is_last_step:
+                # At episode end, next state = current state
+                next_img = img.clone()
+                next_state_vec = state_vec.clone()
+                done = True
+                truncated = False
+            else:
+                # Within episode, next state has incremented value
+                next_val = (i + 1) / total_steps
+                next_img = torch.ones((3, 84, 84)) * next_val
+                next_state_vec = torch.ones((10,)) * next_val
+                done = False
+                truncated = False
+
+            next_state = {
+                "observation.image": next_img.unsqueeze(0),
+                "observation.state": next_state_vec.unsqueeze(0),
+            }
+
+            # Action and reward
+            action = torch.tensor([[i / total_steps]])
+            reward = float(i / total_steps)
+
+            # Add to both buffers
+            standard_buffer.add(state, action, reward, next_state, done, truncated)
+            optimized_buffer.add(state, action, reward, next_state, done, truncated)
+
+    # Verify episode boundaries with our simplified approach
+    print("\nVerifying simplified memory optimization...")
+
+    # Test with a new buffer with a small sequence
+    test_buffer = ReplayBuffer(
+        capacity=20,
+        device="cpu",
+        state_keys=["value"],
+        storage_device="cpu",
+        optimize_memory=True,
+        use_drq=False,
+    )
+
+    # Add a simple sequence with known episode boundaries
+    for i in range(20):
+        val = float(i)
+        state = {"value": torch.tensor([[val]]).float()}
+        next_val = float(i + 1) if i % 5 != 4 else val  # Episode ends every 5 steps
+        next_state = {"value": torch.tensor([[next_val]]).float()}
+
+        # Set done=True at every 5th step
+        done = (i % 5) == 4
+        action = torch.tensor([[0.0]])
+        reward = 1.0
+        truncated = False
+
+        test_buffer.add(state, action, reward, next_state, done, truncated)
+
+    # Get sequential batch for verification
+    sequential_batch_size = test_buffer.size
+    all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
+
+    # Get state tensors
+    batch_state = {
+        "value": test_buffer.states["value"][all_indices].to(test_buffer.device)
+    }
+
+    # Get next_state using memory-optimized approach (simply index+1)
+    next_indices = (all_indices + 1) % test_buffer.capacity
+    batch_next_state = {
+        "value": test_buffer.states["value"][next_indices].to(test_buffer.device)
+    }
+
+    # Get other tensors
+    batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
+
+    # Print sequential values
+    print("State, Next State, Done (Sequential values with simplified optimization):")
+    state_values = batch_state["value"].squeeze().tolist()
+    next_values = batch_next_state["value"].squeeze().tolist()
+    done_flags = batch_dones.tolist()
+
+    # Print all values
+    for i in range(len(state_values)):
+        print(f"  {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}")
+
+    # Explain the memory optimization tradeoff
+    print("\nWith simplified memory optimization:")
+    print("- We always use the next state in the buffer (index+1) as next_state")
+    print("- For terminal states, this means using the first state of the next episode")
+    print("- This is a common tradeoff in RL implementations for memory efficiency")
+    print(
+        "- Since we track done flags, the algorithm can handle these transitions correctly"
+    )
+
+    # Test random sampling
+    print("\nVerifying random sampling with simplified memory optimization...")
+    random_samples = test_buffer.sample(20)  # Sample all transitions
+
+    # Extract values
+    random_state_values = random_samples["state"]["value"].squeeze().tolist()
+    random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
+    random_done_flags = random_samples["done"].bool().tolist()
+
+    # Print a few samples
+    print("Random samples - State, Next State, Done (First 10):")
+    for i in range(10):
+        print(
+            f"  {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}"
         )
-        print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
 
-        # Test sampling
-        real_batch = replay_buffer.sample(batch_size)
-        print("Sampled batch from real dataset, state shapes:")
-        for key in real_batch["state"]:
-            print(f"  {key}: {real_batch['state'][key].shape}")
+    # Calculate memory savings
+    # Assume optimized_buffer and standard_buffer have already been initialized and filled
+    std_mem = (
+        sum(
+            standard_buffer.states[key].nelement()
+            * standard_buffer.states[key].element_size()
+            for key in standard_buffer.states
+        )
+        * 2
+    )
+    opt_mem = sum(
+        optimized_buffer.states[key].nelement()
+        * optimized_buffer.states[key].element_size()
+        for key in optimized_buffer.states
+    )
 
-        # Convert back to LeRobotDataset
-        with TemporaryDirectory() as temp_dir:
-            replay_buffer_converted = replay_buffer.to_lerobot_dataset(
-                repo_id="test/real_dataset_converted",
-                root=os.path.join(temp_dir, "dataset2"),
-            )
-            print(
-                f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
-            )
+    savings_percent = (std_mem - opt_mem) / std_mem * 100
 
-    except Exception as e:
-        print(f"Real dataset test failed: {e}")
-        print("This is expected if running offline or if the dataset is not available.")
+    print(f"\nMemory optimization result:")
+    print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
+    print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
+    print(f"- Memory savings for state tensors: {savings_percent:.1f}%")
+
+    print("\nAll memory optimization tests completed!")
+
+    # # ===== Test real dataset conversion =====
+    # print("\n===== Testing Real LeRobotDataset Conversion =====")
+    # try:
+    #     # Try to use a real dataset if available
+    #     dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
+    #     dataset = LeRobotDataset(repo_id=dataset_name)
+
+    #     # Print available keys to debug
+    #     sample = dataset[0]
+    #     print("Available keys in dataset:", list(sample.keys()))
+
+    #     # Check for required keys
+    #     if "action" not in sample or "next.reward" not in sample:
+    #         print("Dataset missing essential keys. Cannot convert.")
+    #         raise ValueError("Missing required keys in dataset")
+
+    #     # Auto-detect appropriate state keys
+    #     image_keys = []
+    #     state_keys = []
+    #     for k, v in sample.items():
+    #         # Skip metadata keys and action/reward keys
+    #         if k in {
+    #             "index",
+    #             "episode_index",
+    #             "frame_index",
+    #             "timestamp",
+    #             "task_index",
+    #             "action",
+    #             "next.reward",
+    #             "next.done",
+    #         }:
+    #             continue
+
+    #         # Infer key type from tensor shape
+    #         if isinstance(v, torch.Tensor):
+    #             if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
+    #                 # Likely an image (channels, height, width)
+    #                 image_keys.append(k)
+    #             else:
+    #                 # Likely state or other vector
+    #                 state_keys.append(k)
+
+    #     print(f"Detected image keys: {image_keys}")
+    #     print(f"Detected state keys: {state_keys}")
+
+    #     if not image_keys and not state_keys:
+    #         print("No usable keys found in dataset, skipping further tests")
+    #         raise ValueError("No usable keys found in dataset")
+
+    #     # Test with standard and memory-optimized buffers
+    #     for optimize_memory in [False, True]:
+    #         buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
+    #         print(f"\nTesting {buffer_type} buffer with real dataset...")
+
+    #         # Convert to ReplayBuffer with detected keys
+    #         replay_buffer = ReplayBuffer.from_lerobot_dataset(
+    #             lerobot_dataset=dataset,
+    #             state_keys=image_keys + state_keys,
+    #             device="cpu",
+    #             optimize_memory=optimize_memory,
+    #         )
+    #         print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
+
+    #         # Test sampling
+    #         real_batch = replay_buffer.sample(32)
+    #         print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
+    #         for key in real_batch["state"]:
+    #             print(f"  {key}: {real_batch['state'][key].shape}")
+
+    #         # Convert back to LeRobotDataset
+    #         with TemporaryDirectory() as temp_dir:
+    #             dataset_name = f"test/real_dataset_converted_{buffer_type}"
+    #             replay_buffer_converted = replay_buffer.to_lerobot_dataset(
+    #                 repo_id=dataset_name,
+    #                 root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
+    #             )
+    #             print(
+    #                 f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
+    #             )
+
+    # except Exception as e:
+    #     print(f"Real dataset test failed: {e}")
+    #     print("This is expected if running offline or if the dataset is not available.")
+
+    # print("\nAll tests completed!")
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index a4e42305..edbeb01c 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -154,6 +154,7 @@ def initialize_replay_buffer(
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
         storage_device=device,
+        optimize_memory=True,
     )
 
     dataset = LeRobotDataset(
@@ -166,6 +167,7 @@ def initialize_replay_buffer(
         capacity=cfg.training.online_buffer_capacity,
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
+        optimize_memory=True,
    )
 
 
@@ -648,6 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
        action_mask=active_action_dims,
         action_delta=cfg.env.wrapper.delta_action,
         storage_device=device,
+        optimize_memory=True,
    )
     batch_size: int = batch_size // 2  # We will sample from both replay buffer