Refactor SACPolicy and learner server for improved replay buffer management

- Updated SACPolicy to create critic heads using a list comprehension for better readability.
- Simplified model saving and loading by using the `save_model` and `load_model` helpers from the safetensors library (a minimal usage sketch follows this list).
- Introduced an `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization.
- Enhanced logging for dataset loading processes to improve traceability during training.
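
For reference, a minimal sketch of the new save/load pattern (illustrative only: a toy `nn.Linear` module and a placeholder file name, not the actual `SACPolicy` code):

```python
import torch.nn as nn
from safetensors.torch import load_model, save_model

model = nn.Linear(4, 2)

# save_model writes the module's state dict, resolving shared tensors itself
save_model(model, "model.safetensors")

# load_model restores the parameters in place on the requested device
load_model(model, "model.safetensors", device="cpu")
```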
Author: AdilZouitine, 2025-03-18 14:57:15 +00:00
Parent: b82faf7d8c
Commit: 4bb2077afa
2 changed files with 159 additions and 138 deletions


@@ -18,7 +18,7 @@
# TODO: (1) better device management
from copy import deepcopy
from typing import Callable, Optional, Tuple, Union, Dict
from typing import Callable, Optional, Tuple, Union, Dict, List
from pathlib import Path
import einops
@@ -88,33 +88,33 @@ class SACPolicy(
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
self.critic_ensemble = CriticEnsemble(
encoder=encoder_critic,
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
ensemble=critic_heads,
output_normalization=self.normalize_targets,
)
# Create target critic heads with the same architecture as the online critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
self.critic_target = CriticEnsemble(
encoder=encoder_critic,
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
ensemble=target_critic_heads,
output_normalization=self.normalize_targets,
)
@@ -149,19 +149,9 @@ class SACPolicy(
import json
from dataclasses import asdict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import save_file
from safetensors.torch import save_model
# NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap
# implies one side effect: the __batch_size parameters are saved in the state_dict
# __batch_size is torch.Size or safetensor save only torch.Tensor
# so we need to filter them out before saving
simplified_state_dict = {}
for name, param in self.named_parameters():
simplified_state_dict[name] = param
save_file(
simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
)
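# save_model serializes the module's state dict directly; with the tensordict-based
# Ensemble gone, the state dict holds only tensors and needs no manual filtering.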
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
# Save config
config_dict = asdict(self.config)
@@ -191,7 +181,7 @@ class SACPolicy(
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import load_file
from safetensors.torch import load_model
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
@@ -243,28 +233,7 @@ class SACPolicy(
# Load state dict from safetensors file
if os.path.exists(safetensors_file):
# Note: The load_file function returns a dict with the parameters, but __batch_size
# is not loaded so we need to copy it from the model state_dict
# Load the parameters only
loaded_state_dict = load_file(safetensors_file, device=map_location)
# Copy batch size parameters
find_and_copy_params(
original_state_dict=model.state_dict(),
loaded_state_dict=loaded_state_dict,
pattern="__batch_size",
match_type="endswith",
)
# Copy normalization buffer parameters
find_and_copy_params(
original_state_dict=model.state_dict(),
loaded_state_dict=loaded_state_dict,
pattern="_orig_mod.output_normalization.buffer_action",
match_type="contains",
)
model.load_state_dict(loaded_state_dict, strict=False)
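# load_model restores the weights in place, making the manual __batch_size and
# normalization-buffer copying above unnecessary.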
load_model(model, filename=safetensors_file, device=map_location)
return model
@@ -594,21 +563,21 @@ class CriticEnsemble(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
ensemble: "Ensemble[CriticHead]",
ensemble: List[CriticHead],
output_normalization: nn.Module,
init_final: Optional[float] = None,
):
super().__init__()
self.encoder = encoder
self.ensemble = ensemble
self.init_final = init_final
self.output_normalization = output_normalization
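# Registering the heads in an nn.ModuleList makes their parameters visible to state_dict() and parameters()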
self.critics = nn.ModuleList(ensemble)
self.parameters_to_optimize = []
# Handle the case where part of the encoder is frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.ensemble.parameters())
self.parameters_to_optimize += list(self.critics.parameters())
def forward(
self,
@@ -632,8 +601,15 @@ class CriticEnsemble(nn.Module):
)
inputs = torch.cat([obs_enc, actions], dim=-1)
q_values = self.ensemble(inputs) # [num_critics, B, 1]
return q_values.squeeze(-1) # [num_critics, B]
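# NOTE: the vmap-batched Ensemble call above is replaced by an explicit loop over
# the critic heads: simpler serialization at the cost of per-critic dispatch.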
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
class Policy(nn.Module):
@@ -706,9 +682,9 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(
log_std
).any(), "[ERROR] log_std became NaN after std_layer!"
assert not torch.isnan(log_std).any(), (
"[ERROR] log_std became NaN after std_layer!"
)
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
@@ -1025,73 +1001,78 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
if __name__ == "__main__":
# Test the SACObservationEncoder
# Benchmark the CriticEnsemble performance
import time
config = SACConfig()
config.num_critics = 10
config.vision_encoder_name = None
encoder = SACObservationEncoder(config, nn.Identity())
# actor_encoder = SACObservationEncoder(config)
# encoder = torch.compile(encoder)
# Configuration
num_critics = 10
batch_size = 32
action_dim = 7
obs_dim = 64
hidden_dims = [256, 256]
num_iterations = 100
print("Creating test environment...")
# Create a simple dummy encoder
class DummyEncoder(nn.Module):
def __init__(self):
super().__init__()
self.output_dim = obs_dim
self.parameters_to_optimize = []
def forward(self, obs):
# Just return a random tensor of the right shape
# In practice, this would encode the observations
return torch.randn(batch_size, obs_dim, device=device)
# Create critic heads
print(f"Creating {num_critics} critic heads...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic_heads = [
CriticHead(
input_dim=obs_dim + action_dim,
hidden_dims=hidden_dims,
).to(device)
for _ in range(num_critics)
]
# Create the critic ensemble
print("Creating CriticEnsemble...")
critic_ensemble = CriticEnsemble(
encoder=encoder,
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
encoder=DummyEncoder().to(device),
ensemble=critic_heads,
output_normalization=nn.Identity(),
)
# actor = Policy(
# encoder=actor_encoder,
# network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
# action_dim=config.output_shapes["action"][0],
# encoder_is_shared=config.shared_encoder,
# **config.policy_kwargs,
# )
# encoder = encoder.to("cuda:0")
# critic_ensemble = torch.compile(critic_ensemble)
critic_ensemble = critic_ensemble.to("cuda:0")
# actor = torch.compile(actor)
# actor = actor.to("cuda:0")
).to(device)
# Create random input data
print("Creating input data...")
obs_dict = {
"observation.image": torch.randn(8, 3, 84, 84),
"observation.state": torch.randn(8, 4),
"observation.state": torch.randn(batch_size, obs_dim, device=device),
}
actions = torch.randn(8, 2).to("cuda:0")
# obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
# print("compiling...")
q_value = critic_ensemble(obs_dict, actions)
print(q_value.size())
# action = actor(obs_dict)
# print("compiled")
# start = time.perf_counter()
# for _ in range(1000):
# # features = encoder(obs_dict)
# action = actor(obs_dict)
# # q_value = critic_ensemble(obs_dict, actions)
# print("Time taken:", time.perf_counter() - start)
# Compare the performance of the ensemble vs a for loop of 16 MLPs
ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)])
ensemble = ensemble.to("cuda:0")
critic = CriticHead(256, [256, 256])
critic = critic.to("cuda:0")
data_ensemble = torch.randn(8, 256).to("cuda:0")
ensemble = torch.compile(ensemble)
# critic = torch.compile(critic)
print(ensemble(data_ensemble).size())
print(critic(data_ensemble).size())
start = time.perf_counter()
for _ in range(1000):
ensemble(data_ensemble)
print("Time taken:", time.perf_counter() - start)
start = time.perf_counter()
for _ in range(1000):
for i in range(2):
critic(data_ensemble)
print("Time taken:", time.perf_counter() - start)
actions = torch.randn(batch_size, action_dim, device=device)
# Warmup run
print("Warming up...")
_ = critic_ensemble(obs_dict, actions)
# Time the forward pass
print(f"Running benchmark with {num_iterations} iterations...")
start_time = time.perf_counter()
for _ in range(num_iterations):
q_values = critic_ensemble(obs_dict, actions)
end_time = time.perf_counter()
# Print results
elapsed_time = end_time - start_time
print(f"Total time: {elapsed_time:.4f} seconds")
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
# Verify that all critic heads produce different outputs
# This confirms each critic head is unique
# print("\nVerifying critic outputs are different:")
# for i in range(num_critics):
# for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")


@@ -121,7 +121,7 @@ def load_training_state(
return None, None
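# weights_only=False: the checkpoint pickles full optimizer state objects, which
# the weights_only=True default in newer PyTorch releases would reject.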
training_state = torch.load(
logger.last_checkpoint_dir / logger.training_state_file_name
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
)
if isinstance(training_state["optimizer"], dict):
@@ -160,6 +160,7 @@ def initialize_replay_buffer(
optimize_memory=True,
)
logging.info("Resume training load the online dataset")
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id,
local_files_only=True,
@@ -174,6 +175,37 @@ def initialize_replay_buffer(
)
def initialize_offline_replay_buffer(
    cfg: DictConfig,
    logger: Logger,
    device: str,
    storage_device: str,
    active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
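    """Create the offline replay buffer used by the learner.

    On a fresh run the offline dataset is built via make_dataset(cfg); when
    resuming, it is reloaded from logger.log_dir / "dataset_offline". The
    dataset is then wrapped in a memory-optimized ReplayBuffer on the
    requested device/storage_device.
    """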
    if not cfg.resume:
        logging.info("make_dataset offline buffer")
        offline_dataset = make_dataset(cfg)
    else:
        logging.info("load offline dataset")
        offline_dataset = LeRobotDataset(
            repo_id=cfg.dataset_repo_id,
            local_files_only=True,
            root=logger.log_dir / "dataset_offline",
        )
    logging.info("Convert to an offline replay buffer")
    offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
        offline_dataset,
        device=device,
        state_keys=cfg.policy.input_shapes.keys(),
        action_mask=active_action_dims,
        action_delta=cfg.env.wrapper.delta_action,
        storage_device=storage_device,
        optimize_memory=True,
    )
    return offline_replay_buffer
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
@@ -447,9 +479,6 @@ def add_actor_information_and_train(
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
@@ -457,14 +486,12 @@ def add_actor_information_and_train(
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
logger=logger,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
active_action_dims=active_action_dims,
)
batch_size: int = batch_size // 2  # We will sample from both replay buffers
@@ -714,6 +741,19 @@ def add_actor_information_and_train(
replay_buffer.to_lerobot_dataset(
cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
)
if offline_replay_buffer is not None:
dataset_dir = logger.log_dir / "dataset_offline"
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(dataset_dir)
offline_replay_buffer.to_lerobot_dataset(
cfg.dataset_repo_id,
fps=cfg.fps,
root=logger.log_dir / "dataset_offline",
)
logging.info("Resume training")