Add memory optimization option to ReplayBuffer
- Introduce `optimize_memory` parameter to reduce memory usage in the replay buffer
- Implement simplified memory optimization by not storing duplicate next_states
- Update learner server and buffer initialization to use memory optimization by default
parent 5b4a7aa81d
commit 1df9ee4f2d
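
The core idea of the change, shown as a minimal standalone sketch (the MiniBuffer class and all names below are illustrative only, not the actual ReplayBuffer API): store each state once and derive next_state at sample time as the entry at index + 1, using the stored done flag to mark where that derivation crosses an episode boundary.

    import torch

    # Illustrative sketch only -- not the actual ReplayBuffer API.
    # Store each state once; derive next_state at sample time as index + 1.
    class MiniBuffer:
        def __init__(self, capacity: int, state_dim: int):
            self.capacity = capacity
            self.states = torch.empty((capacity, state_dim))
            self.dones = torch.zeros((capacity,), dtype=torch.bool)
            self.position = 0
            self.size = 0

        def add(self, state: torch.Tensor, done: bool) -> None:
            self.states[self.position] = state
            self.dones[self.position] = done
            self.position = (self.position + 1) % self.capacity
            self.size = min(self.size + 1, self.capacity)

        def sample(self, batch_size: int):
            idx = torch.randint(0, self.size, (batch_size,))
            # next_state is derived, not stored: the entry written right after idx.
            next_idx = (idx + 1) % self.capacity
            # Where dones[idx] is True, states[next_idx] belongs to the next
            # episode; the done flag lets the learner mask that bootstrap.
            return self.states[idx], self.states[next_idx], self.dones[idx]

    buf = MiniBuffer(capacity=100, state_dim=4)
    for t in range(10):
        buf.add(torch.full((4,), float(t)), done=(t % 5 == 4))
    s, s_next, d = buf.sample(4)

The diff below applies the same index + 1 trick inside ReplayBuffer.sample() and skips allocating the next_states tensors entirely.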
@@ -178,6 +178,7 @@ class ReplayBuffer:
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
+        optimize_memory: bool = False,
     ):
         """
         Args:
@@ -189,6 +190,8 @@ class ReplayBuffer:
             use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
             storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
                 Using "cpu" can help save GPU memory.
+            optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
+                they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
         """
         self.capacity = capacity
         self.device = device
@@ -196,6 +199,12 @@ class ReplayBuffer:
         self.position = 0
         self.size = 0
         self.initialized = False
+        self.optimize_memory = optimize_memory
+
+        # Track episode boundaries for memory optimization
+        self.episode_ends = torch.zeros(
+            capacity, dtype=torch.bool, device=storage_device
+        )
 
         # If no state_keys provided, default to an empty list
         self.state_keys = state_keys if state_keys is not None else []
@@ -220,10 +229,18 @@ class ReplayBuffer:
             (self.capacity, *action_shape), device=self.storage_device
         )
         self.rewards = torch.empty((self.capacity,), device=self.storage_device)
-        self.next_states = {
-            key: torch.empty((self.capacity, *shape), device=self.storage_device)
-            for key, shape in state_shapes.items()
-        }
+
+        if not self.optimize_memory:
+            # Standard approach: store states and next_states separately
+            self.next_states = {
+                key: torch.empty((self.capacity, *shape), device=self.storage_device)
+                for key, shape in state_shapes.items()
+            }
+        else:
+            # Memory-optimized approach: don't allocate next_states buffer
+            # Just create a reference to states for consistent API
+            self.next_states = self.states  # Just a reference for API consistency
+
         self.dones = torch.empty(
             (self.capacity,), dtype=torch.bool, device=self.storage_device
         )
@@ -253,7 +270,12 @@ class ReplayBuffer:
         # Store the transition in pre-allocated tensors
         for key in self.states:
             self.states[key][self.position].copy_(state[key].squeeze(dim=0))
-            self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
+
+            if not self.optimize_memory:
+                # Only store next_states if not optimizing memory
+                self.next_states[key][self.position].copy_(
+                    next_state[key].squeeze(dim=0)
+                )
 
         self.actions[self.position].copy_(action.squeeze(dim=0))
         self.rewards[self.position] = reward
@@ -288,10 +310,17 @@ class ReplayBuffer:
         batch_state = {}
         batch_next_state = {}
 
-        # First pass: load all tensors to target device
+        # First pass: load all state tensors to target device
         for key in self.states:
             batch_state[key] = self.states[key][idx].to(self.device)
-            batch_next_state[key] = self.next_states[key][idx].to(self.device)
+
+            if not self.optimize_memory:
+                # Standard approach - load next_states directly
+                batch_next_state[key] = self.next_states[key][idx].to(self.device)
+            else:
+                # Memory-optimized approach - get next_state from the next index
+                next_idx = (idx + 1) % self.capacity
+                batch_next_state[key] = self.states[key][next_idx].to(self.device)
 
         # Apply image augmentation in a batched way if needed
         if self.use_drq and image_keys:
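
A note on correctness of the index + 1 derivation above: at an episode boundary the derived next_state belongs to the following episode, but with a standard TD(0) target the (1 - done) mask zeroes the bootstrap term, so that value never reaches the loss. A sketch of that mask (illustrative helper, not code from this patch):

    import torch

    # Standard TD(0) target: the (1 - done) mask zeroes the bootstrap term,
    # so the value of a next_state that leaked across an episode boundary
    # never contributes to the target.
    def td_target(reward, next_q, done, gamma=0.99):
        return reward + gamma * next_q * (1.0 - done.float())

    reward = torch.tensor([1.0, 1.0])
    next_q = torch.tensor([5.0, 5.0])   # Q-value of the derived next_state
    done = torch.tensor([False, True])  # second transition ends its episode
    print(td_target(reward, next_q, done))  # tensor([5.9500, 1.0000])
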
@@ -343,6 +372,7 @@ class ReplayBuffer:
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
+        optimize_memory: bool = False,
     ) -> "ReplayBuffer":
         """
         Convert a LeRobotDataset into a ReplayBuffer.
@@ -358,6 +388,7 @@ class ReplayBuffer:
                 If None, uses default random shift with pad=4.
             use_drq (bool): Whether to use DrQ image augmentation when sampling.
             storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
+            optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
 
         Returns:
             ReplayBuffer: The replay buffer with dataset transitions.
@@ -378,6 +409,7 @@ class ReplayBuffer:
             image_augmentation_function=image_augmentation_function,
             use_drq=use_drq,
             storage_device=storage_device,
+            optimize_memory=optimize_memory,
         )
 
         # Convert dataset to transitions
@@ -934,78 +966,268 @@ if __name__ == "__main__":
             "WARNING: Significant difference between original and reconverted values"
         )
 
-    print("\nTesting real LeRobotDataset conversion...")
-    try:
-        # Try to use a real dataset if available
-        dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
-        dataset = LeRobotDataset(repo_id=dataset_name)
-
-        # Print available keys to debug
-        sample = dataset[0]
-        print("Available keys in first dataset:", list(sample.keys()))
-
-        # Check for required keys
-        if "action" not in sample or "next.reward" not in sample:
-            print("Dataset missing essential keys. Cannot convert.")
-            raise ValueError("Missing required keys in dataset")
-
-        # Auto-detect appropriate state keys
-        image_keys = []
-        state_keys = []
-        for k, v in sample.items():
-            # Skip metadata keys and action/reward keys
-            if k in {
-                "index",
-                "episode_index",
-                "frame_index",
-                "timestamp",
-                "task_index",
-                "action",
-                "next.reward",
-                "next.done",
-            }:
-                continue
-
-            # Infer key type from tensor shape
-            if isinstance(v, torch.Tensor):
-                if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
-                    # Likely an image (channels, height, width)
-                    image_keys.append(k)
-                else:
-                    # Likely state or other vector
-                    state_keys.append(k)
-
-        print(f"Detected image keys: {image_keys}")
-        print(f"Detected state keys: {state_keys}")
-
-        if not image_keys and not state_keys:
-            print("No usable keys found in dataset, skipping further tests")
-            raise ValueError("No usable keys found in dataset")
-
-        # Convert to ReplayBuffer with detected keys
-        replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            lerobot_dataset=dataset,
-            state_keys=image_keys + state_keys,
-            device="cpu",
-        )
-        print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
-
-        # Test sampling
-        real_batch = replay_buffer.sample(batch_size)
-        print("Sampled batch from real dataset, state shapes:")
-        for key in real_batch["state"]:
-            print(f"  {key}: {real_batch['state'][key].shape}")
-
-        # Convert back to LeRobotDataset
-        with TemporaryDirectory() as temp_dir:
-            replay_buffer_converted = replay_buffer.to_lerobot_dataset(
-                repo_id="test/real_dataset_converted",
-                root=os.path.join(temp_dir, "dataset2"),
-            )
-            print(
-                f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
-            )
-
-    except Exception as e:
-        print(f"Real dataset test failed: {e}")
-        print("This is expected if running offline or if the dataset is not available.")
+    print("\nAll previous tests completed!")
+
+    # ===== Test for memory optimization =====
+    print("\n===== Testing Memory Optimization =====")
+
+    # Create two buffers, one with memory optimization and one without
+    standard_buffer = ReplayBuffer(
+        capacity=1000,
+        device="cpu",
+        state_keys=["observation.image", "observation.state"],
+        storage_device="cpu",
+        optimize_memory=False,
+        use_drq=True,
+    )
+
+    optimized_buffer = ReplayBuffer(
+        capacity=1000,
+        device="cpu",
+        state_keys=["observation.image", "observation.state"],
+        storage_device="cpu",
+        optimize_memory=True,
+        use_drq=True,
+    )
+
+    # Generate sample data with larger state dimensions for better memory impact
+    print("Generating test data...")
+    num_episodes = 10
+    steps_per_episode = 50
+    total_steps = num_episodes * steps_per_episode
+
+    for episode in range(num_episodes):
+        for step in range(steps_per_episode):
+            # Index in the overall sequence
+            i = episode * steps_per_episode + step
+
+            # Create state with identifiable values
+            img = torch.ones((3, 84, 84)) * (i / total_steps)
+            state_vec = torch.ones((10,)) * (i / total_steps)
+
+            state = {
+                "observation.image": img.unsqueeze(0),
+                "observation.state": state_vec.unsqueeze(0),
+            }
+
+            # Create next state (i+1 or same as current if last in episode)
+            is_last_step = step == steps_per_episode - 1
+
+            if is_last_step:
+                # At episode end, next state = current state
+                next_img = img.clone()
+                next_state_vec = state_vec.clone()
+                done = True
+                truncated = False
+            else:
+                # Within episode, next state has incremented value
+                next_val = (i + 1) / total_steps
+                next_img = torch.ones((3, 84, 84)) * next_val
+                next_state_vec = torch.ones((10,)) * next_val
+                done = False
+                truncated = False
+
+            next_state = {
+                "observation.image": next_img.unsqueeze(0),
+                "observation.state": next_state_vec.unsqueeze(0),
+            }
+
+            # Action and reward
+            action = torch.tensor([[i / total_steps]])
+            reward = float(i / total_steps)
+
+            # Add to both buffers
+            standard_buffer.add(state, action, reward, next_state, done, truncated)
+            optimized_buffer.add(state, action, reward, next_state, done, truncated)
+
+    # Verify episode boundaries with our simplified approach
+    print("\nVerifying simplified memory optimization...")
+
+    # Test with a new buffer with a small sequence
+    test_buffer = ReplayBuffer(
+        capacity=20,
+        device="cpu",
+        state_keys=["value"],
+        storage_device="cpu",
+        optimize_memory=True,
+        use_drq=False,
+    )
+
+    # Add a simple sequence with known episode boundaries
+    for i in range(20):
+        val = float(i)
+        state = {"value": torch.tensor([[val]]).float()}
+        next_val = float(i + 1) if i % 5 != 4 else val  # Episode ends every 5 steps
+        next_state = {"value": torch.tensor([[next_val]]).float()}
+
+        # Set done=True at every 5th step
+        done = (i % 5) == 4
+        action = torch.tensor([[0.0]])
+        reward = 1.0
+        truncated = False
+
+        test_buffer.add(state, action, reward, next_state, done, truncated)
+
+    # Get sequential batch for verification
+    sequential_batch_size = test_buffer.size
+    all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
+
+    # Get state tensors
+    batch_state = {
+        "value": test_buffer.states["value"][all_indices].to(test_buffer.device)
+    }
+
+    # Get next_state using memory-optimized approach (simply index+1)
+    next_indices = (all_indices + 1) % test_buffer.capacity
+    batch_next_state = {
+        "value": test_buffer.states["value"][next_indices].to(test_buffer.device)
+    }
+
+    # Get other tensors
+    batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
+
+    # Print sequential values
+    print("State, Next State, Done (Sequential values with simplified optimization):")
+    state_values = batch_state["value"].squeeze().tolist()
+    next_values = batch_next_state["value"].squeeze().tolist()
+    done_flags = batch_dones.tolist()
+
+    # Print all values
+    for i in range(len(state_values)):
+        print(f"  {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}")
+
+    # Explain the memory optimization tradeoff
+    print("\nWith simplified memory optimization:")
+    print("- We always use the next state in the buffer (index+1) as next_state")
+    print("- For terminal states, this means using the first state of the next episode")
+    print("- This is a common tradeoff in RL implementations for memory efficiency")
+    print(
+        "- Since we track done flags, the algorithm can handle these transitions correctly"
+    )
+
+    # Test random sampling
+    print("\nVerifying random sampling with simplified memory optimization...")
+    random_samples = test_buffer.sample(20)  # Sample all transitions
+
+    # Extract values
+    random_state_values = random_samples["state"]["value"].squeeze().tolist()
+    random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
+    random_done_flags = random_samples["done"].bool().tolist()
+
+    # Print a few samples
+    print("Random samples - State, Next State, Done (First 10):")
+    for i in range(10):
+        print(
+            f"  {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}"
+        )
+
+    # Calculate memory savings
+    # Assume optimized_buffer and standard_buffer have already been initialized and filled
+    std_mem = (
+        sum(
+            standard_buffer.states[key].nelement()
+            * standard_buffer.states[key].element_size()
+            for key in standard_buffer.states
+        )
+        * 2
+    )
+    opt_mem = sum(
+        optimized_buffer.states[key].nelement()
+        * optimized_buffer.states[key].element_size()
+        for key in optimized_buffer.states
+    )
+
+    savings_percent = (std_mem - opt_mem) / std_mem * 100
+
+    print(f"\nMemory optimization result:")
+    print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
+    print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
+    print(f"- Memory savings for state tensors: {savings_percent:.1f}%")
+
+    print("\nAll memory optimization tests completed!")
+
+    # # ===== Test real dataset conversion =====
+    # print("\n===== Testing Real LeRobotDataset Conversion =====")
+    # try:
+    #     # Try to use a real dataset if available
+    #     dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
+    #     dataset = LeRobotDataset(repo_id=dataset_name)
+
+    #     # Print available keys to debug
+    #     sample = dataset[0]
+    #     print("Available keys in dataset:", list(sample.keys()))
+
+    #     # Check for required keys
+    #     if "action" not in sample or "next.reward" not in sample:
+    #         print("Dataset missing essential keys. Cannot convert.")
+    #         raise ValueError("Missing required keys in dataset")
+
+    #     # Auto-detect appropriate state keys
+    #     image_keys = []
+    #     state_keys = []
+    #     for k, v in sample.items():
+    #         # Skip metadata keys and action/reward keys
+    #         if k in {
+    #             "index",
+    #             "episode_index",
+    #             "frame_index",
+    #             "timestamp",
+    #             "task_index",
+    #             "action",
+    #             "next.reward",
+    #             "next.done",
+    #         }:
+    #             continue
+
+    #         # Infer key type from tensor shape
+    #         if isinstance(v, torch.Tensor):
+    #             if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
+    #                 # Likely an image (channels, height, width)
+    #                 image_keys.append(k)
+    #             else:
+    #                 # Likely state or other vector
+    #                 state_keys.append(k)
+
+    #     print(f"Detected image keys: {image_keys}")
+    #     print(f"Detected state keys: {state_keys}")
+
+    #     if not image_keys and not state_keys:
+    #         print("No usable keys found in dataset, skipping further tests")
+    #         raise ValueError("No usable keys found in dataset")
+
+    #     # Test with standard and memory-optimized buffers
+    #     for optimize_memory in [False, True]:
+    #         buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
+    #         print(f"\nTesting {buffer_type} buffer with real dataset...")
+
+    #         # Convert to ReplayBuffer with detected keys
+    #         replay_buffer = ReplayBuffer.from_lerobot_dataset(
+    #             lerobot_dataset=dataset,
+    #             state_keys=image_keys + state_keys,
+    #             device="cpu",
+    #             optimize_memory=optimize_memory,
+    #         )
+    #         print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
+
+    #         # Test sampling
+    #         real_batch = replay_buffer.sample(32)
+    #         print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
+    #         for key in real_batch["state"]:
+    #             print(f"  {key}: {real_batch['state'][key].shape}")
+
+    #         # Convert back to LeRobotDataset
+    #         with TemporaryDirectory() as temp_dir:
+    #             dataset_name = f"test/real_dataset_converted_{buffer_type}"
+    #             replay_buffer_converted = replay_buffer.to_lerobot_dataset(
+    #                 repo_id=dataset_name,
+    #                 root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
+    #             )
+    #             print(
+    #                 f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
+    #             )
+
+    # except Exception as e:
+    #     print(f"Real dataset test failed: {e}")
+    #     print("This is expected if running offline or if the dataset is not available.")
+
+    # print("\nAll tests completed!")
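
For the buffer shapes used in the test above ((3, 84, 84) images plus a (10,) state vector at capacity 1000), the expected savings can be checked by hand. A quick sketch, assuming float32 storage (the default dtype for torch.empty):

    # Back-of-the-envelope check for the test above, assuming float32 storage.
    capacity = 1000
    image_bytes = 3 * 84 * 84 * 4  # one (3, 84, 84) float32 image: 84,672 bytes
    state_bytes = 10 * 4           # one (10,) float32 state vector

    per_frame = image_bytes + state_bytes
    standard = 2 * capacity * per_frame  # states + duplicated next_states
    optimized = capacity * per_frame     # states only; next_states are derived

    print(f"standard:  {standard / 1024**2:.2f} MiB")   # ~161.57 MiB
    print(f"optimized: {optimized / 1024**2:.2f} MiB")  # ~80.79 MiB
    print(f"savings:   {(standard - optimized) / standard:.0%}")  # 50%
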
@@ -154,6 +154,7 @@ def initialize_replay_buffer(
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
         storage_device=device,
+        optimize_memory=True,
     )
 
     dataset = LeRobotDataset(
@@ -166,6 +167,7 @@ def initialize_replay_buffer(
         capacity=cfg.training.online_buffer_capacity,
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
+        optimize_memory=True,
     )
 
 
@@ -648,6 +650,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         action_mask=active_action_dims,
         action_delta=cfg.env.wrapper.delta_action,
         storage_device=device,
+        optimize_memory=True,
     )
     batch_size: int = batch_size // 2  # We will sample from both replay buffer
 
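
Putting it together, enabling the flag from user code might look like the sketch below. The ReplayBuffer import path is an assumption based on where this diff appears to live; the from_lerobot_dataset arguments mirror the signature shown in the hunks above.

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    # Assumed import path for the class modified in this diff; adjust to
    # wherever ReplayBuffer actually lives in your checkout.
    from lerobot.scripts.server.buffer import ReplayBuffer

    dataset = LeRobotDataset(repo_id="AdilZtn/Maniskill-Pushcube-demonstration-small")
    buffer = ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        state_keys=["observation.image", "observation.state"],
        device="cpu",
        storage_device="cpu",
        optimize_memory=True,  # derive next_states at sample time
    )
    batch = buffer.sample(32)  # batch["next_state"] comes from index + 1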