Add memory optimization option to ReplayBuffer

- Introduce an `optimize_memory` parameter to reduce the replay buffer's memory usage
- Implement a simplified memory optimization that avoids storing duplicate next_states, deriving next_state[i] from state[i+1] instead (sketched below)
- Update the learner server and buffer initialization to enable memory optimization by default
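
The core of the change is the index+1 trick sketched below. This standalone illustration assumes a single flat state tensor (the real buffer stores a dict of tensors keyed by `state_keys`); the names here are illustrative, not taken from the diff:

import torch

capacity = 8
states = torch.arange(capacity, dtype=torch.float32)  # stand-in for the state buffer
idx = torch.tensor([0, 3, 7])                         # sampled transition indices
next_idx = (idx + 1) % capacity                       # wrap from the last slot back to 0
print(states[idx])       # tensor([0., 3., 7.])
print(states[next_idx])  # tensor([1., 4., 0.])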
AdilZouitine 2025-02-25 19:04:58 +00:00
parent 5b4a7aa81d
commit 1df9ee4f2d
2 changed files with 296 additions and 71 deletions

File 1 of 2 (ReplayBuffer implementation):

@@ -178,6 +178,7 @@ class ReplayBuffer:
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
optimize_memory: bool = False,
):
"""
Args:
@@ -189,6 +190,8 @@
use_drq (bool): Whether to use the default DrQ image augmentation style when sampling from the buffer.
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
Using "cpu" can help save GPU memory.
optimize_memory (bool): If True, reduces memory usage by not storing duplicate next_states when
they can be derived from states, i.e. next_state[i] = state[i+1]. Useful for large buffers.
"""
self.capacity = capacity
self.device = device
@@ -196,6 +199,12 @@
self.position = 0
self.size = 0
self.initialized = False
self.optimize_memory = optimize_memory
# Track episode boundaries for memory optimization
self.episode_ends = torch.zeros(
capacity, dtype=torch.bool, device=storage_device
)
# If no state_keys provided, default to an empty list
self.state_keys = state_keys if state_keys is not None else []
@@ -220,10 +229,18 @@
(self.capacity, *action_shape), device=self.storage_device
)
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
self.next_states = {
key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items()
}
if not self.optimize_memory:
# Standard approach: store states and next_states separately
self.next_states = {
key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items()
}
else:
# Memory-optimized approach: don't allocate next_states buffer
# Just create a reference to states for consistent API
self.next_states = self.states # Just a reference for API consistency
self.dones = torch.empty(
(self.capacity,), dtype=torch.bool, device=self.storage_device
)
@@ -253,7 +270,12 @@
# Store the transition in pre-allocated tensors
for key in self.states:
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
if not self.optimize_memory:
# Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(
next_state[key].squeeze(dim=0)
)
self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward
@@ -288,10 +310,17 @@
batch_state = {}
batch_next_state = {}
# First pass: load all tensors to target device
# First pass: load all state tensors to target device
for key in self.states:
batch_state[key] = self.states[key][idx].to(self.device)
batch_next_state[key] = self.next_states[key][idx].to(self.device)
if not self.optimize_memory:
# Standard approach - load next_states directly
batch_next_state[key] = self.next_states[key][idx].to(self.device)
else:
# Memory-optimized approach - get next_state from the next index
next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device)
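# Note: when idx + 1 wraps past the end of the buffer or crosses an episode
# boundary, states[next_idx] comes from a different episode; the done flag
# stored for idx is what prevents such transitions from being bootstrapped.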
# Apply image augmentation in a batched way if needed
if self.use_drq and image_keys:
@@ -343,6 +372,7 @@
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
optimize_memory: bool = False,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
@@ -358,6 +388,7 @@
If None, uses default random shift with pad=4.
use_drq (bool): Whether to use DrQ image augmentation when sampling.
storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
Returns:
ReplayBuffer: The replay buffer with dataset transitions.
@@ -378,6 +409,7 @@
image_augmentation_function=image_augmentation_function,
use_drq=use_drq,
storage_device=storage_device,
optimize_memory=optimize_memory,
)
# Convert dataset to transitions
@@ -934,78 +966,268 @@ if __name__ == "__main__":
"WARNING: Significant difference between original and reconverted values"
)
print("\nTesting real LeRobotDataset conversion...")
try:
# Try to use a real dataset if available
dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
dataset = LeRobotDataset(repo_id=dataset_name)
print("\nAll previous tests completed!")
# Print available keys to debug
sample = dataset[0]
print("Available keys in first dataset:", list(sample.keys()))
# ===== Test for memory optimization =====
print("\n===== Testing Memory Optimization =====")
# Check for required keys
if "action" not in sample or "next.reward" not in sample:
print("Dataset missing essential keys. Cannot convert.")
raise ValueError("Missing required keys in dataset")
# Create two buffers, one with memory optimization and one without
standard_buffer = ReplayBuffer(
capacity=1000,
device="cpu",
state_keys=["observation.image", "observation.state"],
storage_device="cpu",
optimize_memory=False,
use_drq=True,
)
# Auto-detect appropriate state keys
image_keys = []
state_keys = []
for k, v in sample.items():
# Skip metadata keys and action/reward keys
if k in {
"index",
"episode_index",
"frame_index",
"timestamp",
"task_index",
"action",
"next.reward",
"next.done",
}:
continue
optimized_buffer = ReplayBuffer(
capacity=1000,
device="cpu",
state_keys=["observation.image", "observation.state"],
storage_device="cpu",
optimize_memory=True,
use_drq=True,
)
# Infer key type from tensor shape
if isinstance(v, torch.Tensor):
if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
# Likely an image (channels, height, width)
image_keys.append(k)
else:
# Likely state or other vector
state_keys.append(k)
# Generate sample data with larger state dimensions for better memory impact
print("Generating test data...")
num_episodes = 10
steps_per_episode = 50
total_steps = num_episodes * steps_per_episode
print(f"Detected image keys: {image_keys}")
print(f"Detected state keys: {state_keys}")
for episode in range(num_episodes):
for step in range(steps_per_episode):
# Index in the overall sequence
i = episode * steps_per_episode + step
if not image_keys and not state_keys:
print("No usable keys found in dataset, skipping further tests")
raise ValueError("No usable keys found in dataset")
# Create state with identifiable values
img = torch.ones((3, 84, 84)) * (i / total_steps)
state_vec = torch.ones((10,)) * (i / total_steps)
# Convert to ReplayBuffer with detected keys
replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
state_keys=image_keys + state_keys,
device="cpu",
state = {
"observation.image": img.unsqueeze(0),
"observation.state": state_vec.unsqueeze(0),
}
# Create next state (i+1 or same as current if last in episode)
is_last_step = step == steps_per_episode - 1
if is_last_step:
# At episode end, next state = current state
next_img = img.clone()
next_state_vec = state_vec.clone()
done = True
truncated = False
else:
# Within episode, next state has incremented value
next_val = (i + 1) / total_steps
next_img = torch.ones((3, 84, 84)) * next_val
next_state_vec = torch.ones((10,)) * next_val
done = False
truncated = False
next_state = {
"observation.image": next_img.unsqueeze(0),
"observation.state": next_state_vec.unsqueeze(0),
}
# Action and reward
action = torch.tensor([[i / total_steps]])
reward = float(i / total_steps)
# Add to both buffers
standard_buffer.add(state, action, reward, next_state, done, truncated)
optimized_buffer.add(state, action, reward, next_state, done, truncated)
# Verify episode boundaries with our simplified approach
print("\nVerifying simplified memory optimization...")
# Test with a new buffer with a small sequence
test_buffer = ReplayBuffer(
capacity=20,
device="cpu",
state_keys=["value"],
storage_device="cpu",
optimize_memory=True,
use_drq=False,
)
# Add a simple sequence with known episode boundaries
for i in range(20):
val = float(i)
state = {"value": torch.tensor([[val]]).float()}
next_val = float(i + 1) if i % 5 != 4 else val # Episode ends every 5 steps
next_state = {"value": torch.tensor([[next_val]]).float()}
# Set done=True at every 5th step
done = (i % 5) == 4
action = torch.tensor([[0.0]])
reward = 1.0
truncated = False
test_buffer.add(state, action, reward, next_state, done, truncated)
# Get sequential batch for verification
sequential_batch_size = test_buffer.size
all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
# Get state tensors
batch_state = {
"value": test_buffer.states["value"][all_indices].to(test_buffer.device)
}
# Get next_state using memory-optimized approach (simply index+1)
next_indices = (all_indices + 1) % test_buffer.capacity
batch_next_state = {
"value": test_buffer.states["value"][next_indices].to(test_buffer.device)
}
# Get other tensors
batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
# Print sequential values
print("State, Next State, Done (Sequential values with simplified optimization):")
state_values = batch_state["value"].squeeze().tolist()
next_values = batch_next_state["value"].squeeze().tolist()
done_flags = batch_dones.tolist()
# Print all values
for i in range(len(state_values)):
print(f" {state_values[i]:.1f}{next_values[i]:.1f}, Done: {done_flags[i]}")
# Explain the memory optimization tradeoff
print("\nWith simplified memory optimization:")
print("- We always use the next state in the buffer (index+1) as next_state")
print("- For terminal states, this means using the first state of the next episode")
print("- This is a common tradeoff in RL implementations for memory efficiency")
print(
"- Since we track done flags, the algorithm can handle these transitions correctly"
)
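# Hypothetical illustration (not part of the original test): done-flag gating
# zeroes the bootstrap term, so a cross-episode next_state cannot leak into
# the TD target:
#   td_target = reward + gamma * (1 - done) * q_next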
# Test random sampling
print("\nVerifying random sampling with simplified memory optimization...")
random_samples = test_buffer.sample(20) # Sample all transitions
# Extract values
random_state_values = random_samples["state"]["value"].squeeze().tolist()
random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
random_done_flags = random_samples["done"].bool().tolist()
# Print a few samples
print("Random samples - State, Next State, Done (First 10):")
for i in range(10):
print(
f" {random_state_values[i]:.1f}{random_next_values[i]:.1f}, Done: {random_done_flags[i]}"
)
print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
# Test sampling
real_batch = replay_buffer.sample(batch_size)
print("Sampled batch from real dataset, state shapes:")
for key in real_batch["state"]:
print(f" {key}: {real_batch['state'][key].shape}")
# Calculate memory savings
# standard_buffer and optimized_buffer were filled with identical transitions above
std_mem = (
sum(
standard_buffer.states[key].nelement()
* standard_buffer.states[key].element_size()
for key in standard_buffer.states
)
* 2
)
opt_mem = sum(
optimized_buffer.states[key].nelement()
* optimized_buffer.states[key].element_size()
for key in optimized_buffer.states
)
# Convert back to LeRobotDataset
with TemporaryDirectory() as temp_dir:
replay_buffer_converted = replay_buffer.to_lerobot_dataset(
repo_id="test/real_dataset_converted",
root=os.path.join(temp_dir, "dataset2"),
)
print(
f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
)
savings_percent = (std_mem - opt_mem) / std_mem * 100
except Exception as e:
print(f"Real dataset test failed: {e}")
print("This is expected if running offline or if the dataset is not available.")
print(f"\nMemory optimization result:")
print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
print(f"- Memory savings for state tensors: {savings_percent:.1f}%")
print("\nAll memory optimization tests completed!")
# # ===== Test real dataset conversion =====
# print("\n===== Testing Real LeRobotDataset Conversion =====")
# try:
# # Try to use a real dataset if available
# dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
# dataset = LeRobotDataset(repo_id=dataset_name)
# # Print available keys to debug
# sample = dataset[0]
# print("Available keys in dataset:", list(sample.keys()))
# # Check for required keys
# if "action" not in sample or "next.reward" not in sample:
# print("Dataset missing essential keys. Cannot convert.")
# raise ValueError("Missing required keys in dataset")
# # Auto-detect appropriate state keys
# image_keys = []
# state_keys = []
# for k, v in sample.items():
# # Skip metadata keys and action/reward keys
# if k in {
# "index",
# "episode_index",
# "frame_index",
# "timestamp",
# "task_index",
# "action",
# "next.reward",
# "next.done",
# }:
# continue
# # Infer key type from tensor shape
# if isinstance(v, torch.Tensor):
# if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
# # Likely an image (channels, height, width)
# image_keys.append(k)
# else:
# # Likely state or other vector
# state_keys.append(k)
# print(f"Detected image keys: {image_keys}")
# print(f"Detected state keys: {state_keys}")
# if not image_keys and not state_keys:
# print("No usable keys found in dataset, skipping further tests")
# raise ValueError("No usable keys found in dataset")
# # Test with standard and memory-optimized buffers
# for optimize_memory in [False, True]:
# buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
# print(f"\nTesting {buffer_type} buffer with real dataset...")
# # Convert to ReplayBuffer with detected keys
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
# lerobot_dataset=dataset,
# state_keys=image_keys + state_keys,
# device="cpu",
# optimize_memory=optimize_memory,
# )
# print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
# # Test sampling
# real_batch = replay_buffer.sample(32)
# print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
# for key in real_batch["state"]:
# print(f" {key}: {real_batch['state'][key].shape}")
# # Convert back to LeRobotDataset
# with TemporaryDirectory() as temp_dir:
# dataset_name = f"test/real_dataset_converted_{buffer_type}"
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(
# repo_id=dataset_name,
# root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
# )
# print(
# f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
# )
# except Exception as e:
# print(f"Real dataset test failed: {e}")
# print("This is expected if running offline or if the dataset is not available.")
# print("\nAll tests completed!")

File 2 of 2 (learner server):

@@ -154,6 +154,7 @@ def initialize_replay_buffer(
device=device,
state_keys=cfg.policy.input_shapes.keys(),
storage_device=device,
optimize_memory=True,
)
dataset = LeRobotDataset(
@@ -166,6 +167,7 @@
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
optimize_memory=True,
)
@@ -648,6 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=device,
optimize_memory=True,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffers
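
For callers outside the learner server, opting in would look like the following minimal sketch (the keyword names come from this diff; the values and state key names are illustrative):

buffer = ReplayBuffer(
    capacity=100_000,
    device="cuda:0",                 # device sampled batches are moved to
    state_keys=["observation.image", "observation.state"],
    storage_device="cpu",            # keep the large buffers in host memory
    optimize_memory=True,            # skip the duplicate next_states allocation
)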