Add memory optimization option to ReplayBuffer
- Introduce `optimize_memory` parameter to reduce memory usage in the replay buffer
- Implement simplified memory optimization by not storing duplicate next_states
- Update learner server and buffer initialization to use memory optimization by default
parent 5b4a7aa81d
commit 1df9ee4f2d
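The core idea of the change, shown as a minimal self-contained sketch (the tensor names and shapes below are illustrative, not the buffer's real API): with `optimize_memory` enabled, only `states` is stored, and the next state for a sampled index `i` is read back as `states[(i + 1) % capacity]`, with the stored `done` flag telling the learner when that lookup crosses an episode boundary.

```python
import torch

# Illustrative stand-ins for the buffer's per-key storage (not the real API).
capacity = 8
states = torch.arange(capacity, dtype=torch.float32)
dones = torch.zeros(capacity, dtype=torch.bool)
dones[3] = True  # pretend an episode terminates at index 3

idx = torch.tensor([0, 2, 3, 6])       # sampled transition indices
next_idx = (idx + 1) % capacity        # memory-optimized next_state lookup
batch_state = states[idx]
batch_next_state = states[next_idx]    # at index 3 this comes from the next episode
batch_done = dones[idx]                # the done flag marks where not to bootstrap
print(batch_state, batch_next_state, batch_done)
```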
@@ -178,6 +178,7 @@ class ReplayBuffer:
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
+        optimize_memory: bool = False,
     ):
         """
         Args:
@@ -189,6 +190,8 @@ class ReplayBuffer:
             use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
             storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
                 Using "cpu" can help save GPU memory.
+            optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
+                they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
         """
         self.capacity = capacity
         self.device = device
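A brief usage sketch of the new flag, mirroring how the test code later in this commit constructs its buffers (assume `ReplayBuffer` from the module this diff touches is in scope; the observation keys are the ones used in that test):

```python
buffer = ReplayBuffer(
    capacity=1000,
    device="cpu",
    state_keys=["observation.image", "observation.state"],
    storage_device="cpu",
    optimize_memory=True,  # next_states are derived from states at sample time
)
```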
@@ -196,6 +199,12 @@ class ReplayBuffer:
         self.position = 0
         self.size = 0
         self.initialized = False
+        self.optimize_memory = optimize_memory
+
+        # Track episode boundaries for memory optimization
+        self.episode_ends = torch.zeros(
+            capacity, dtype=torch.bool, device=storage_device
+        )
 
         # If no state_keys provided, default to an empty list
         self.state_keys = state_keys if state_keys is not None else []
@@ -220,10 +229,18 @@ class ReplayBuffer:
             (self.capacity, *action_shape), device=self.storage_device
         )
         self.rewards = torch.empty((self.capacity,), device=self.storage_device)
-        self.next_states = {
-            key: torch.empty((self.capacity, *shape), device=self.storage_device)
-            for key, shape in state_shapes.items()
-        }
+
+        if not self.optimize_memory:
+            # Standard approach: store states and next_states separately
+            self.next_states = {
+                key: torch.empty((self.capacity, *shape), device=self.storage_device)
+                for key, shape in state_shapes.items()
+            }
+        else:
+            # Memory-optimized approach: don't allocate next_states buffer
+            # Just create a reference to states for consistent API
+            self.next_states = self.states  # Just a reference for API consistency
+
         self.dones = torch.empty(
             (self.capacity,), dtype=torch.bool, device=self.storage_device
         )
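In the optimized branch, `self.next_states = self.states` is a plain Python alias: both names point at the same per-key tensors, so no second set of buffers is allocated. A tiny sketch of that behaviour (shapes here are illustrative):

```python
import torch

states = {"observation.image": torch.empty((1000, 3, 84, 84))}
next_states = states  # alias, not a copy: no additional memory is allocated
assert next_states["observation.image"].data_ptr() == states["observation.image"].data_ptr()
```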
@@ -253,7 +270,12 @@ class ReplayBuffer:
         # Store the transition in pre-allocated tensors
         for key in self.states:
             self.states[key][self.position].copy_(state[key].squeeze(dim=0))
-            self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
+
+            if not self.optimize_memory:
+                # Only store next_states if not optimizing memory
+                self.next_states[key][self.position].copy_(
+                    next_state[key].squeeze(dim=0)
+                )
 
         self.actions[self.position].copy_(action.squeeze(dim=0))
         self.rewards[self.position] = reward
@@ -288,10 +310,17 @@ class ReplayBuffer:
         batch_state = {}
         batch_next_state = {}
 
-        # First pass: load all tensors to target device
+        # First pass: load all state tensors to target device
        for key in self.states:
             batch_state[key] = self.states[key][idx].to(self.device)
-            batch_next_state[key] = self.next_states[key][idx].to(self.device)
+
+            if not self.optimize_memory:
+                # Standard approach - load next_states directly
+                batch_next_state[key] = self.next_states[key][idx].to(self.device)
+            else:
+                # Memory-optimized approach - get next_state from the next index
+                next_idx = (idx + 1) % self.capacity
+                batch_next_state[key] = self.states[key][next_idx].to(self.device)
 
         # Apply image augmentation in a batched way if needed
         if self.use_drq and image_keys:
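One consequence of deriving `next_state` from `states[idx + 1]` is that, for terminal transitions, the sampled next state belongs to the following episode (and for the newest slot the modulo wraps to the oldest one). The commit's own test output notes why this is tolerable: typical TD learners mask the bootstrap term with the done flag, so the "wrong" next state never contributes to the loss. A minimal sketch of that masking, with made-up numbers:

```python
import torch

rewards = torch.tensor([1.0, 1.0])
dones = torch.tensor([0.0, 1.0])     # second transition is terminal
q_next = torch.tensor([5.0, 99.0])   # 99.0 comes from an unrelated next episode
gamma = 0.99
td_target = rewards + gamma * (1.0 - dones) * q_next
print(td_target)  # tensor([5.9500, 1.0000]) - the bogus value is masked out
```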
@@ -343,6 +372,7 @@ class ReplayBuffer:
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
+        optimize_memory: bool = False,
     ) -> "ReplayBuffer":
         """
         Convert a LeRobotDataset into a ReplayBuffer.
@@ -358,6 +388,7 @@ class ReplayBuffer:
                 If None, uses default random shift with pad=4.
             use_drq (bool): Whether to use DrQ image augmentation when sampling.
             storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
+            optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
 
         Returns:
             ReplayBuffer: The replay buffer with dataset transitions.
@@ -378,6 +409,7 @@ class ReplayBuffer:
             image_augmentation_function=image_augmentation_function,
             use_drq=use_drq,
             storage_device=storage_device,
+            optimize_memory=optimize_memory,
         )
 
         # Convert dataset to transitions
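A hedged usage sketch of the new plumbing (the dataset repo id is the one used in this commit's tests; the `LeRobotDataset` import path is assumed from the surrounding project and may differ):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # assumed import path

dataset = LeRobotDataset(repo_id="AdilZtn/Maniskill-Pushcube-demonstration-small")
buffer = ReplayBuffer.from_lerobot_dataset(
    lerobot_dataset=dataset,
    state_keys=["observation.image", "observation.state"],
    device="cpu",
    storage_device="cpu",
    optimize_memory=True,
)
batch = buffer.sample(32)
```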
@@ -934,78 +966,268 @@ if __name__ == "__main__":
"WARNING: Significant difference between original and reconverted values"
)

print("\nTesting real LeRobotDataset conversion...")
try:
# Try to use a real dataset if available
dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
dataset = LeRobotDataset(repo_id=dataset_name)
print("\nAll previous tests completed!")

# Print available keys to debug
sample = dataset[0]
print("Available keys in first dataset:", list(sample.keys()))
# ===== Test for memory optimization =====
print("\n===== Testing Memory Optimization =====")

# Check for required keys
if "action" not in sample or "next.reward" not in sample:
print("Dataset missing essential keys. Cannot convert.")
raise ValueError("Missing required keys in dataset")
# Create two buffers, one with memory optimization and one without
standard_buffer = ReplayBuffer(
capacity=1000,
device="cpu",
state_keys=["observation.image", "observation.state"],
storage_device="cpu",
optimize_memory=False,
use_drq=True,
)

# Auto-detect appropriate state keys
image_keys = []
state_keys = []
for k, v in sample.items():
# Skip metadata keys and action/reward keys
if k in {
"index",
"episode_index",
"frame_index",
"timestamp",
"task_index",
"action",
"next.reward",
"next.done",
}:
continue
optimized_buffer = ReplayBuffer(
capacity=1000,
device="cpu",
state_keys=["observation.image", "observation.state"],
storage_device="cpu",
optimize_memory=True,
use_drq=True,
)

# Infer key type from tensor shape
if isinstance(v, torch.Tensor):
if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
# Likely an image (channels, height, width)
image_keys.append(k)
else:
# Likely state or other vector
state_keys.append(k)
# Generate sample data with larger state dimensions for better memory impact
print("Generating test data...")
num_episodes = 10
steps_per_episode = 50
total_steps = num_episodes * steps_per_episode

print(f"Detected image keys: {image_keys}")
print(f"Detected state keys: {state_keys}")
for episode in range(num_episodes):
for step in range(steps_per_episode):
# Index in the overall sequence
i = episode * steps_per_episode + step

if not image_keys and not state_keys:
print("No usable keys found in dataset, skipping further tests")
raise ValueError("No usable keys found in dataset")
# Create state with identifiable values
img = torch.ones((3, 84, 84)) * (i / total_steps)
state_vec = torch.ones((10,)) * (i / total_steps)

# Convert to ReplayBuffer with detected keys
replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
state_keys=image_keys + state_keys,
device="cpu",
state = {
"observation.image": img.unsqueeze(0),
"observation.state": state_vec.unsqueeze(0),
}

# Create next state (i+1 or same as current if last in episode)
is_last_step = step == steps_per_episode - 1

if is_last_step:
# At episode end, next state = current state
next_img = img.clone()
next_state_vec = state_vec.clone()
done = True
truncated = False
else:
# Within episode, next state has incremented value
next_val = (i + 1) / total_steps
next_img = torch.ones((3, 84, 84)) * next_val
next_state_vec = torch.ones((10,)) * next_val
done = False
truncated = False

next_state = {
"observation.image": next_img.unsqueeze(0),
"observation.state": next_state_vec.unsqueeze(0),
}

# Action and reward
action = torch.tensor([[i / total_steps]])
reward = float(i / total_steps)

# Add to both buffers
standard_buffer.add(state, action, reward, next_state, done, truncated)
optimized_buffer.add(state, action, reward, next_state, done, truncated)

# Verify episode boundaries with our simplified approach
print("\nVerifying simplified memory optimization...")

# Test with a new buffer with a small sequence
test_buffer = ReplayBuffer(
capacity=20,
device="cpu",
state_keys=["value"],
storage_device="cpu",
optimize_memory=True,
use_drq=False,
)

# Add a simple sequence with known episode boundaries
for i in range(20):
val = float(i)
state = {"value": torch.tensor([[val]]).float()}
next_val = float(i + 1) if i % 5 != 4 else val # Episode ends every 5 steps
next_state = {"value": torch.tensor([[next_val]]).float()}

# Set done=True at every 5th step
done = (i % 5) == 4
action = torch.tensor([[0.0]])
reward = 1.0
truncated = False

test_buffer.add(state, action, reward, next_state, done, truncated)

# Get sequential batch for verification
sequential_batch_size = test_buffer.size
all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)

# Get state tensors
batch_state = {
"value": test_buffer.states["value"][all_indices].to(test_buffer.device)
}

# Get next_state using memory-optimized approach (simply index+1)
next_indices = (all_indices + 1) % test_buffer.capacity
batch_next_state = {
"value": test_buffer.states["value"][next_indices].to(test_buffer.device)
}

# Get other tensors
batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)

# Print sequential values
print("State, Next State, Done (Sequential values with simplified optimization):")
state_values = batch_state["value"].squeeze().tolist()
next_values = batch_next_state["value"].squeeze().tolist()
done_flags = batch_dones.tolist()

# Print all values
for i in range(len(state_values)):
print(f" {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}")

# Explain the memory optimization tradeoff
print("\nWith simplified memory optimization:")
print("- We always use the next state in the buffer (index+1) as next_state")
print("- For terminal states, this means using the first state of the next episode")
print("- This is a common tradeoff in RL implementations for memory efficiency")
print(
"- Since we track done flags, the algorithm can handle these transitions correctly"
)

# Test random sampling
print("\nVerifying random sampling with simplified memory optimization...")
random_samples = test_buffer.sample(20) # Sample all transitions

# Extract values
random_state_values = random_samples["state"]["value"].squeeze().tolist()
random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
random_done_flags = random_samples["done"].bool().tolist()

# Print a few samples
print("Random samples - State, Next State, Done (First 10):")
for i in range(10):
print(
f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}"
)
print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")

# Test sampling
real_batch = replay_buffer.sample(batch_size)
print("Sampled batch from real dataset, state shapes:")
for key in real_batch["state"]:
print(f" {key}: {real_batch['state'][key].shape}")
# Calculate memory savings
# Assume optimized_buffer and standard_buffer have already been initialized and filled
std_mem = (
sum(
standard_buffer.states[key].nelement()
* standard_buffer.states[key].element_size()
for key in standard_buffer.states
)
* 2
)
opt_mem = sum(
optimized_buffer.states[key].nelement()
* optimized_buffer.states[key].element_size()
for key in optimized_buffer.states
)

# Convert back to LeRobotDataset
with TemporaryDirectory() as temp_dir:
replay_buffer_converted = replay_buffer.to_lerobot_dataset(
repo_id="test/real_dataset_converted",
root=os.path.join(temp_dir, "dataset2"),
)
print(
f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
)
savings_percent = (std_mem - opt_mem) / std_mem * 100

except Exception as e:
print(f"Real dataset test failed: {e}")
print("This is expected if running offline or if the dataset is not available.")
print(f"\nMemory optimization result:")
print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
print(f"- Memory savings for state tensors: {savings_percent:.1f}%")

print("\nAll memory optimization tests completed!")

# # ===== Test real dataset conversion =====
# print("\n===== Testing Real LeRobotDataset Conversion =====")
# try:
# # Try to use a real dataset if available
# dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
# dataset = LeRobotDataset(repo_id=dataset_name)

# # Print available keys to debug
# sample = dataset[0]
# print("Available keys in dataset:", list(sample.keys()))

# # Check for required keys
# if "action" not in sample or "next.reward" not in sample:
# print("Dataset missing essential keys. Cannot convert.")
# raise ValueError("Missing required keys in dataset")

# # Auto-detect appropriate state keys
# image_keys = []
# state_keys = []
# for k, v in sample.items():
# # Skip metadata keys and action/reward keys
# if k in {
# "index",
# "episode_index",
# "frame_index",
# "timestamp",
# "task_index",
# "action",
# "next.reward",
# "next.done",
# }:
# continue

# # Infer key type from tensor shape
# if isinstance(v, torch.Tensor):
# if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
# # Likely an image (channels, height, width)
# image_keys.append(k)
# else:
# # Likely state or other vector
# state_keys.append(k)

# print(f"Detected image keys: {image_keys}")
# print(f"Detected state keys: {state_keys}")

# if not image_keys and not state_keys:
# print("No usable keys found in dataset, skipping further tests")
# raise ValueError("No usable keys found in dataset")

# # Test with standard and memory-optimized buffers
# for optimize_memory in [False, True]:
# buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
# print(f"\nTesting {buffer_type} buffer with real dataset...")

# # Convert to ReplayBuffer with detected keys
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
# lerobot_dataset=dataset,
# state_keys=image_keys + state_keys,
# device="cpu",
# optimize_memory=optimize_memory,
# )
# print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")

# # Test sampling
# real_batch = replay_buffer.sample(32)
# print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
# for key in real_batch["state"]:
# print(f" {key}: {real_batch['state'][key].shape}")

# # Convert back to LeRobotDataset
# with TemporaryDirectory() as temp_dir:
# dataset_name = f"test/real_dataset_converted_{buffer_type}"
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(
# repo_id=dataset_name,
# root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
# )
# print(
# f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
# )

# except Exception as e:
# print(f"Real dataset test failed: {e}")
# print("This is expected if running offline or if the dataset is not available.")

# print("\nAll tests completed!")
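For the shapes used in the memory-optimization test above (capacity 1000, one 3x84x84 float32 image and a 10-dim float32 vector per step), the savings calculation should report roughly 161.6 MB for the standard buffer versus 80.8 MB for the optimized one, i.e. about 50% of the state storage, which is expected by construction since `std_mem` counts two copies of every state tensor. A quick back-of-the-envelope check:

```python
capacity = 1000
image_bytes = capacity * 3 * 84 * 84 * 4   # float32 image buffer
vector_bytes = capacity * 10 * 4           # float32 state-vector buffer
per_copy_mb = (image_bytes + vector_bytes) / (1024 * 1024)
print(f"standard ≈ {2 * per_copy_mb:.2f} MB, optimized ≈ {per_copy_mb:.2f} MB")
# standard ≈ 161.58 MB, optimized ≈ 80.79 MB
```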
@@ -154,6 +154,7 @@ def initialize_replay_buffer(
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
         storage_device=device,
+        optimize_memory=True,
     )
 
     dataset = LeRobotDataset(
@@ -166,6 +167,7 @@ def initialize_replay_buffer(
         capacity=cfg.training.online_buffer_capacity,
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
+        optimize_memory=True,
     )
 
 
@@ -648,6 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         action_mask=active_action_dims,
         action_delta=cfg.env.wrapper.delta_action,
         storage_device=device,
+        optimize_memory=True,
     )
     batch_size: int = batch_size // 2  # We will sample from both replay buffer
 
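Taken together, the learner-side buffers are now built with memory optimization on by default. Roughly, based only on the config fields visible in this diff (other arguments are omitted and assumed unchanged):

```python
replay_buffer = ReplayBuffer(
    capacity=cfg.training.online_buffer_capacity,
    device=device,
    state_keys=cfg.policy.input_shapes.keys(),
    storage_device=device,
    optimize_memory=True,  # state tensors are stored once, not duplicated as next_states
)
```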