Add memory optimization option to ReplayBuffer

- Introduce `optimize_memory` parameter to reduce memory usage in replay buffer
- Implement simplified memory optimization by not storing duplicate next_states (sketched below)
- Update learner server and buffer initialization to use memory optimization by default
AdilZouitine 2025-02-25 19:04:58 +00:00
parent 5b4a7aa81d
commit 1df9ee4f2d
2 changed files with 296 additions and 71 deletions
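
Before the diffs, a rough sketch of the idea in plain PyTorch. Everything below is illustrative only (the class, names, and shapes are made up for this note and are not part of the commit): instead of allocating a second next_states buffer, the optimized path reuses states and reads slot i + 1 at sample time.

import torch


class TinyBuffer:
    """Toy buffer illustrating the optimize_memory idea (hypothetical, simplified)."""

    def __init__(self, capacity: int, state_dim: int, optimize_memory: bool = True):
        self.capacity = capacity
        self.optimize_memory = optimize_memory
        self.states = torch.empty((capacity, state_dim))
        # Only allocate the second buffer when memory optimization is off.
        self.next_states = None if optimize_memory else torch.empty((capacity, state_dim))
        self.dones = torch.zeros(capacity, dtype=torch.bool)
        self.position = 0
        self.size = 0

    def add(self, state, next_state, done):
        self.states[self.position] = state
        if not self.optimize_memory:
            self.next_states[self.position] = next_state
        self.dones[self.position] = done
        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int):
        idx = torch.randint(0, self.size, (batch_size,))
        if self.optimize_memory:
            # next_state[i] is read from the following slot; done flags mark the
            # episode boundaries where this crosses into the next episode.
            next_states = self.states[(idx + 1) % self.capacity]
        else:
            next_states = self.next_states[idx]
        return self.states[idx], next_states, self.dones[idx]

The saving is roughly half of the state storage; the cost is that, right after a terminal transition, slot i + 1 already belongs to the next episode, which the stored done flag lets the learner mask out.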

View File

@@ -178,6 +178,7 @@ class ReplayBuffer:
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
+        optimize_memory: bool = False,
     ):
         """
         Args:
@@ -189,6 +190,8 @@
             use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
             storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
                 Using "cpu" can help save GPU memory.
+            optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
+                they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
         """
         self.capacity = capacity
         self.device = device
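
To make the new flag concrete, a call site might look like the following sketch. It only uses parameters visible in this diff; the key names and sizes are placeholders, and `ReplayBuffer` is assumed importable from the module this commit modifies:

buffer = ReplayBuffer(
    capacity=100_000,
    device="cuda:0",                # device that sampled batches are moved to
    state_keys=["observation.image", "observation.state"],
    storage_device="cpu",           # keep the large tensors in CPU RAM
    optimize_memory=True,           # skip the duplicate next_states allocation
    use_drq=True,
)
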
@@ -196,6 +199,12 @@
         self.position = 0
         self.size = 0
         self.initialized = False
+        self.optimize_memory = optimize_memory
+
+        # Track episode boundaries for memory optimization
+        self.episode_ends = torch.zeros(
+            capacity, dtype=torch.bool, device=storage_device
+        )

         # If no state_keys provided, default to an empty list
         self.state_keys = state_keys if state_keys is not None else []
@@ -220,10 +229,18 @@
             (self.capacity, *action_shape), device=self.storage_device
         )
         self.rewards = torch.empty((self.capacity,), device=self.storage_device)
-        self.next_states = {
-            key: torch.empty((self.capacity, *shape), device=self.storage_device)
-            for key, shape in state_shapes.items()
-        }
+
+        if not self.optimize_memory:
+            # Standard approach: store states and next_states separately
+            self.next_states = {
+                key: torch.empty((self.capacity, *shape), device=self.storage_device)
+                for key, shape in state_shapes.items()
+            }
+        else:
+            # Memory-optimized approach: don't allocate next_states buffer
+            # Just create a reference to states for consistent API
+            self.next_states = self.states  # Just a reference for API consistency
+
         self.dones = torch.empty(
             (self.capacity,), dtype=torch.bool, device=self.storage_device
         )
@@ -253,7 +270,12 @@
         # Store the transition in pre-allocated tensors
         for key in self.states:
             self.states[key][self.position].copy_(state[key].squeeze(dim=0))
-            self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
+
+            if not self.optimize_memory:
+                # Only store next_states if not optimizing memory
+                self.next_states[key][self.position].copy_(
+                    next_state[key].squeeze(dim=0)
+                )

         self.actions[self.position].copy_(action.squeeze(dim=0))
         self.rewards[self.position] = reward
@@ -288,10 +310,17 @@
         batch_state = {}
         batch_next_state = {}

-        # First pass: load all tensors to target device
+        # First pass: load all state tensors to target device
         for key in self.states:
             batch_state[key] = self.states[key][idx].to(self.device)
-            batch_next_state[key] = self.next_states[key][idx].to(self.device)
+
+            if not self.optimize_memory:
+                # Standard approach - load next_states directly
+                batch_next_state[key] = self.next_states[key][idx].to(self.device)
+            else:
+                # Memory-optimized approach - get next_state from the next index
+                next_idx = (idx + 1) % self.capacity
+                batch_next_state[key] = self.states[key][next_idx].to(self.device)

         # Apply image augmentation in a batched way if needed
         if self.use_drq and image_keys:
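
The memory-optimized branch above leans on modular index arithmetic. As a quick illustrative check (not code from the commit):

import torch

capacity = 5
idx = torch.tensor([0, 3, 4])      # sampled positions
next_idx = (idx + 1) % capacity    # tensor([1, 4, 0]); the last slot wraps to 0

Together with the done flags stored per transition, this wrap-around is what keeps a borrowed next_state from leaking across episode (or buffer) boundaries.
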
@@ -343,6 +372,7 @@
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
+        optimize_memory: bool = False,
     ) -> "ReplayBuffer":
         """
         Convert a LeRobotDataset into a ReplayBuffer.
@@ -358,6 +388,7 @@
                 If None, uses default random shift with pad=4.
             use_drq (bool): Whether to use DrQ image augmentation when sampling.
             storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
+            optimize_memory (bool): If True, reduces memory usage by not duplicating state data.

         Returns:
             ReplayBuffer: The replay buffer with dataset transitions.
@@ -378,6 +409,7 @@
             image_augmentation_function=image_augmentation_function,
             use_drq=use_drq,
             storage_device=storage_device,
+            optimize_memory=optimize_memory,
         )

         # Convert dataset to transitions
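
Likewise for the dataset-conversion path, a hedged usage sketch based only on the call pattern visible elsewhere in this diff (the repo id and state keys are illustrative placeholders):

dataset = LeRobotDataset(repo_id="AdilZtn/Maniskill-Pushcube-demonstration-small")
offline_buffer = ReplayBuffer.from_lerobot_dataset(
    lerobot_dataset=dataset,
    state_keys=["observation.image", "observation.state"],
    device="cpu",
    optimize_memory=True,  # forwarded to the ReplayBuffer constructor
)
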
@@ -934,78 +966,268 @@ if __name__ == "__main__":
             "WARNING: Significant difference between original and reconverted values"
         )
-    print("\nTesting real LeRobotDataset conversion...")
-    try:
-        # Try to use a real dataset if available
-        dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
-        dataset = LeRobotDataset(repo_id=dataset_name)
-
-        # Print available keys to debug
-        sample = dataset[0]
-        print("Available keys in first dataset:", list(sample.keys()))
-
-        # Check for required keys
-        if "action" not in sample or "next.reward" not in sample:
-            print("Dataset missing essential keys. Cannot convert.")
-            raise ValueError("Missing required keys in dataset")
-
-        # Auto-detect appropriate state keys
-        image_keys = []
-        state_keys = []
-        for k, v in sample.items():
-            # Skip metadata keys and action/reward keys
-            if k in {
-                "index",
-                "episode_index",
-                "frame_index",
-                "timestamp",
-                "task_index",
-                "action",
-                "next.reward",
-                "next.done",
-            }:
-                continue
-
-            # Infer key type from tensor shape
-            if isinstance(v, torch.Tensor):
-                if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
-                    # Likely an image (channels, height, width)
-                    image_keys.append(k)
-                else:
-                    # Likely state or other vector
-                    state_keys.append(k)
-
-        print(f"Detected image keys: {image_keys}")
-        print(f"Detected state keys: {state_keys}")
-
-        if not image_keys and not state_keys:
-            print("No usable keys found in dataset, skipping further tests")
-            raise ValueError("No usable keys found in dataset")
-
-        # Convert to ReplayBuffer with detected keys
-        replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            lerobot_dataset=dataset,
-            state_keys=image_keys + state_keys,
-            device="cpu",
-        )
-        print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
-
-        # Test sampling
-        real_batch = replay_buffer.sample(batch_size)
-        print("Sampled batch from real dataset, state shapes:")
-        for key in real_batch["state"]:
-            print(f"  {key}: {real_batch['state'][key].shape}")
-
-        # Convert back to LeRobotDataset
-        with TemporaryDirectory() as temp_dir:
-            replay_buffer_converted = replay_buffer.to_lerobot_dataset(
-                repo_id="test/real_dataset_converted",
-                root=os.path.join(temp_dir, "dataset2"),
-            )
-            print(
-                f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
-            )
-    except Exception as e:
-        print(f"Real dataset test failed: {e}")
-        print("This is expected if running offline or if the dataset is not available.")
+    print("\nAll previous tests completed!")
+
+    # ===== Test for memory optimization =====
+    print("\n===== Testing Memory Optimization =====")
+
+    # Create two buffers, one with memory optimization and one without
+    standard_buffer = ReplayBuffer(
+        capacity=1000,
+        device="cpu",
+        state_keys=["observation.image", "observation.state"],
+        storage_device="cpu",
+        optimize_memory=False,
+        use_drq=True,
+    )
+
+    optimized_buffer = ReplayBuffer(
+        capacity=1000,
+        device="cpu",
+        state_keys=["observation.image", "observation.state"],
+        storage_device="cpu",
+        optimize_memory=True,
+        use_drq=True,
+    )
+
+    # Generate sample data with larger state dimensions for better memory impact
+    print("Generating test data...")
+    num_episodes = 10
+    steps_per_episode = 50
+    total_steps = num_episodes * steps_per_episode
+
+    for episode in range(num_episodes):
+        for step in range(steps_per_episode):
+            # Index in the overall sequence
+            i = episode * steps_per_episode + step
+
+            # Create state with identifiable values
+            img = torch.ones((3, 84, 84)) * (i / total_steps)
+            state_vec = torch.ones((10,)) * (i / total_steps)
+            state = {
+                "observation.image": img.unsqueeze(0),
+                "observation.state": state_vec.unsqueeze(0),
+            }
+
+            # Create next state (i+1 or same as current if last in episode)
+            is_last_step = step == steps_per_episode - 1
+            if is_last_step:
+                # At episode end, next state = current state
+                next_img = img.clone()
+                next_state_vec = state_vec.clone()
+                done = True
+                truncated = False
+            else:
+                # Within episode, next state has incremented value
+                next_val = (i + 1) / total_steps
+                next_img = torch.ones((3, 84, 84)) * next_val
+                next_state_vec = torch.ones((10,)) * next_val
+                done = False
+                truncated = False
+
+            next_state = {
+                "observation.image": next_img.unsqueeze(0),
+                "observation.state": next_state_vec.unsqueeze(0),
+            }
+
+            # Action and reward
+            action = torch.tensor([[i / total_steps]])
+            reward = float(i / total_steps)
+
+            # Add to both buffers
+            standard_buffer.add(state, action, reward, next_state, done, truncated)
+            optimized_buffer.add(state, action, reward, next_state, done, truncated)
+
+    # Verify episode boundaries with our simplified approach
+    print("\nVerifying simplified memory optimization...")
+
+    # Test with a new buffer with a small sequence
+    test_buffer = ReplayBuffer(
+        capacity=20,
+        device="cpu",
+        state_keys=["value"],
+        storage_device="cpu",
+        optimize_memory=True,
+        use_drq=False,
+    )
+
+    # Add a simple sequence with known episode boundaries
+    for i in range(20):
+        val = float(i)
+        state = {"value": torch.tensor([[val]]).float()}
+        next_val = float(i + 1) if i % 5 != 4 else val  # Episode ends every 5 steps
+        next_state = {"value": torch.tensor([[next_val]]).float()}
+
+        # Set done=True at every 5th step
+        done = (i % 5) == 4
+        action = torch.tensor([[0.0]])
+        reward = 1.0
+        truncated = False
+
+        test_buffer.add(state, action, reward, next_state, done, truncated)
+
+    # Get sequential batch for verification
+    sequential_batch_size = test_buffer.size
+    all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
+
+    # Get state tensors
+    batch_state = {
+        "value": test_buffer.states["value"][all_indices].to(test_buffer.device)
+    }
+
+    # Get next_state using memory-optimized approach (simply index+1)
+    next_indices = (all_indices + 1) % test_buffer.capacity
+    batch_next_state = {
+        "value": test_buffer.states["value"][next_indices].to(test_buffer.device)
+    }
+
+    # Get other tensors
+    batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
+
+    # Print sequential values
+    print("State, Next State, Done (Sequential values with simplified optimization):")
+    state_values = batch_state["value"].squeeze().tolist()
+    next_values = batch_next_state["value"].squeeze().tolist()
+    done_flags = batch_dones.tolist()
+
+    # Print all values
+    for i in range(len(state_values)):
+        print(f"  {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}")
+
+    # Explain the memory optimization tradeoff
+    print("\nWith simplified memory optimization:")
+    print("- We always use the next state in the buffer (index+1) as next_state")
+    print("- For terminal states, this means using the first state of the next episode")
+    print("- This is a common tradeoff in RL implementations for memory efficiency")
+    print(
+        "- Since we track done flags, the algorithm can handle these transitions correctly"
+    )
+
+    # Test random sampling
+    print("\nVerifying random sampling with simplified memory optimization...")
+    random_samples = test_buffer.sample(20)  # Sample all transitions
+
+    # Extract values
+    random_state_values = random_samples["state"]["value"].squeeze().tolist()
+    random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
+    random_done_flags = random_samples["done"].bool().tolist()
+
+    # Print a few samples
+    print("Random samples - State, Next State, Done (First 10):")
+    for i in range(10):
+        print(
+            f"  {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}"
+        )
+
+    # Calculate memory savings
+    # Assume optimized_buffer and standard_buffer have already been initialized and filled
+    std_mem = (
+        sum(
+            standard_buffer.states[key].nelement()
+            * standard_buffer.states[key].element_size()
+            for key in standard_buffer.states
+        )
+        * 2
+    )
+    opt_mem = sum(
+        optimized_buffer.states[key].nelement()
+        * optimized_buffer.states[key].element_size()
+        for key in optimized_buffer.states
+    )
+
+    savings_percent = (std_mem - opt_mem) / std_mem * 100
+
+    print(f"\nMemory optimization result:")
+    print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
+    print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
+    print(f"- Memory savings for state tensors: {savings_percent:.1f}%")
+
+    print("\nAll memory optimization tests completed!")
+
+    # # ===== Test real dataset conversion =====
+    # print("\n===== Testing Real LeRobotDataset Conversion =====")
+    # try:
+    #     # Try to use a real dataset if available
+    #     dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
+    #     dataset = LeRobotDataset(repo_id=dataset_name)
+    #     # Print available keys to debug
+    #     sample = dataset[0]
+    #     print("Available keys in dataset:", list(sample.keys()))
+    #     # Check for required keys
+    #     if "action" not in sample or "next.reward" not in sample:
+    #         print("Dataset missing essential keys. Cannot convert.")
+    #         raise ValueError("Missing required keys in dataset")
+    #     # Auto-detect appropriate state keys
+    #     image_keys = []
+    #     state_keys = []
+    #     for k, v in sample.items():
+    #         # Skip metadata keys and action/reward keys
+    #         if k in {
+    #             "index",
+    #             "episode_index",
+    #             "frame_index",
+    #             "timestamp",
+    #             "task_index",
+    #             "action",
+    #             "next.reward",
+    #             "next.done",
+    #         }:
+    #             continue
+    #         # Infer key type from tensor shape
+    #         if isinstance(v, torch.Tensor):
+    #             if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
+    #                 # Likely an image (channels, height, width)
+    #                 image_keys.append(k)
+    #             else:
+    #                 # Likely state or other vector
+    #                 state_keys.append(k)
+    #     print(f"Detected image keys: {image_keys}")
+    #     print(f"Detected state keys: {state_keys}")
+    #     if not image_keys and not state_keys:
+    #         print("No usable keys found in dataset, skipping further tests")
+    #         raise ValueError("No usable keys found in dataset")
+    #     # Test with standard and memory-optimized buffers
+    #     for optimize_memory in [False, True]:
+    #         buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
+    #         print(f"\nTesting {buffer_type} buffer with real dataset...")
+    #         # Convert to ReplayBuffer with detected keys
+    #         replay_buffer = ReplayBuffer.from_lerobot_dataset(
+    #             lerobot_dataset=dataset,
+    #             state_keys=image_keys + state_keys,
+    #             device="cpu",
+    #             optimize_memory=optimize_memory,
+    #         )
+    #         print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
+    #         # Test sampling
+    #         real_batch = replay_buffer.sample(32)
+    #         print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
+    #         for key in real_batch["state"]:
+    #             print(f"  {key}: {real_batch['state'][key].shape}")
+    #         # Convert back to LeRobotDataset
+    #         with TemporaryDirectory() as temp_dir:
+    #             dataset_name = f"test/real_dataset_converted_{buffer_type}"
+    #             replay_buffer_converted = replay_buffer.to_lerobot_dataset(
+    #                 repo_id=dataset_name,
+    #                 root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
+    #             )
+    #             print(
+    #                 f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
+    #             )
+    # except Exception as e:
+    #     print(f"Real dataset test failed: {e}")
+    #     print("This is expected if running offline or if the dataset is not available.")
+    # print("\nAll tests completed!")

View File

@@ -154,6 +154,7 @@ def initialize_replay_buffer(
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
         storage_device=device,
+        optimize_memory=True,
     )

     dataset = LeRobotDataset(
@@ -166,6 +167,7 @@ def initialize_replay_buffer(
         capacity=cfg.training.online_buffer_capacity,
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
+        optimize_memory=True,
     )
@@ -648,6 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             action_mask=active_action_dims,
             action_delta=cfg.env.wrapper.delta_action,
             storage_device=device,
+            optimize_memory=True,
         )

         batch_size: int = batch_size // 2  # We will sample from both replay buffer
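
The halved batch_size in the last hunk hints at the sampling pattern on the learner side: half of each training batch from the online buffer, half from the offline/demonstration buffer. A hedged sketch of that pattern (buffer names are illustrative; only the state/next_state/done keys are confirmed by this diff, the rest is an assumption):

import torch

# online_buffer and offline_buffer are assumed to be two ReplayBuffer instances
batch_size = 256 // 2  # sample half of the batch from each buffer
online_batch = online_buffer.sample(batch_size)
offline_batch = offline_buffer.sample(batch_size)

batch = {
    # dict-valued entries (per-modality tensors) are concatenated key by key
    "state": {
        k: torch.cat([online_batch["state"][k], offline_batch["state"][k]], dim=0)
        for k in online_batch["state"]
    },
    "next_state": {
        k: torch.cat([online_batch["next_state"][k], offline_batch["next_state"][k]], dim=0)
        for k in online_batch["next_state"]
    },
    "done": torch.cat([online_batch["done"], offline_batch["done"]], dim=0),
}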