Refactor SACPolicy and learner server for improved replay buffer management

- Updated SACPolicy to build its critic heads with a list comprehension for better readability.
- Simplified model saving and loading by using the `save_model` and `load_model` functions from the safetensors library.
- Introduced an `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization.
- Enhanced logging around dataset loading to improve traceability during training.
AdilZouitine 2025-03-18 14:57:15 +00:00
parent 9e3c8461ca
commit 17ec837a7a
2 changed files with 159 additions and 138 deletions

File 1 of 2: SAC policy module (SACPolicy, CriticEnsemble)

@@ -18,7 +18,7 @@
 # TODO: (1) better device management

 from copy import deepcopy
-from typing import Callable, Optional, Tuple, Union, Dict
+from typing import Callable, Optional, Tuple, Union, Dict, List
 from pathlib import Path

 import einops
@@ -88,33 +88,33 @@ class SACPolicy(
         encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
         encoder_actor = SACObservationEncoder(config, self.normalize_inputs)

+        # Create a list of critic heads
+        critic_heads = [
+            CriticHead(
+                input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                **config.critic_network_kwargs,
+            )
+            for _ in range(config.num_critics)
+        ]
+
         self.critic_ensemble = CriticEnsemble(
             encoder=encoder_critic,
-            ensemble=Ensemble(
-                [
-                    CriticHead(
-                        input_dim=encoder_critic.output_dim
-                        + config.output_shapes["action"][0],
-                        **config.critic_network_kwargs,
-                    )
-                    for _ in range(config.num_critics)
-                ]
-            ),
+            ensemble=critic_heads,
             output_normalization=self.normalize_targets,
         )

+        # Create target critic heads with the same configuration as the online heads
+        target_critic_heads = [
+            CriticHead(
+                input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                **config.critic_network_kwargs,
+            )
+            for _ in range(config.num_critics)
+        ]
+
         self.critic_target = CriticEnsemble(
             encoder=encoder_critic,
-            ensemble=Ensemble(
-                [
-                    CriticHead(
-                        input_dim=encoder_critic.output_dim
-                        + config.output_shapes["action"][0],
-                        **config.critic_network_kwargs,
-                    )
-                    for _ in range(config.num_critics)
-                ]
-            ),
+            ensemble=target_critic_heads,
             output_normalization=self.normalize_targets,
         )
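Note: the comprehension builds a plain Python list; it only becomes part of the module tree once CriticEnsemble wraps it in `nn.ModuleList` (see the `__init__` change further down). A minimal sketch of why that wrapping matters, using a stand-in `Head` module rather than the real `CriticHead`:

    import torch.nn as nn

    class Head(nn.Module):  # stand-in for CriticHead
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 1)

    class ListEnsemble(nn.Module):
        def __init__(self):
            super().__init__()
            self.heads = [Head() for _ in range(3)]  # plain list: parameters NOT registered

    class ModuleListEnsemble(nn.Module):
        def __init__(self):
            super().__init__()
            self.heads = nn.ModuleList(Head() for _ in range(3))  # registered submodules

    print(sum(p.numel() for p in ListEnsemble().parameters()))        # 0
    print(sum(p.numel() for p in ModuleListEnsemble().parameters()))  # 27 = 3 * (8 + 1)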
@@ -149,19 +149,9 @@ class SACPolicy(
         import json
         from dataclasses import asdict

         from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
-        from safetensors.torch import save_file
+        from safetensors.torch import save_model

-        # NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap
-        # implies one side effect: the __batch_size parameters are saved in the state_dict
-        # __batch_size is torch.Size or safetensor save only torch.Tensor
-        # so we need to filter them out before saving
-        simplified_state_dict = {}
-        for name, param in self.named_parameters():
-            simplified_state_dict[name] = param
-        save_file(
-            simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
-        )
+        save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))

         # Save config
         config_dict = asdict(self.config)
@@ -191,7 +181,7 @@ class SACPolicy(
         from pathlib import Path

         from huggingface_hub import hf_hub_download
         from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
-        from safetensors.torch import load_file
+        from safetensors.torch import load_model

         from lerobot.common.policies.sac.configuration_sac import SACConfig

         # Check if model_id is a local path or a hub model ID
@@ -243,28 +233,7 @@ class SACPolicy(
         # Load state dict from safetensors file
         if os.path.exists(safetensors_file):
-            # Note: The load_file function returns a dict with the parameters, but __batch_size
-            # is not loaded so we need to copy it from the model state_dict
-            # Load the parameters only
-            loaded_state_dict = load_file(safetensors_file, device=map_location)
-
-            # Copy batch size parameters
-            find_and_copy_params(
-                original_state_dict=model.state_dict(),
-                loaded_state_dict=loaded_state_dict,
-                pattern="__batch_size",
-                match_type="endswith",
-            )
-
-            # Copy normalization buffer parameters
-            find_and_copy_params(
-                original_state_dict=model.state_dict(),
-                loaded_state_dict=loaded_state_dict,
-                pattern="_orig_mod.output_normalization.buffer_action",
-                match_type="contains",
-            )
-
-            model.load_state_dict(loaded_state_dict, strict=False)
+            load_model(model, filename=safetensors_file, device=map_location)

         return model
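For context, safetensors' `save_model`/`load_model` serialize directly from and into an existing `nn.Module`, which is what makes the manual state-dict filtering above removable. A minimal round-trip sketch (the file name is illustrative; the `device` keyword of `load_model` assumes a reasonably recent safetensors release):

    import torch
    import torch.nn as nn
    from safetensors.torch import load_model, save_model

    model = nn.Linear(4, 2)
    save_model(model, "model.safetensors")  # writes only tensor state

    restored = nn.Linear(4, 2)  # architecture must match the saved model
    load_model(restored, "model.safetensors", device="cpu")  # loads weights in place
    assert torch.equal(model.weight, restored.weight)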
@@ -594,21 +563,21 @@ class CriticEnsemble(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        ensemble: "Ensemble[CriticHead]",
+        ensemble: List[CriticHead],
         output_normalization: nn.Module,
         init_final: Optional[float] = None,
     ):
         super().__init__()
         self.encoder = encoder
-        self.ensemble = ensemble
         self.init_final = init_final
         self.output_normalization = output_normalization
+        self.critics = nn.ModuleList(ensemble)
         self.parameters_to_optimize = []

         # Handle the case where a part of the encoder is frozen
         if self.encoder is not None:
             self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
-        self.parameters_to_optimize += list(self.ensemble.parameters())
+        self.parameters_to_optimize += list(self.critics.parameters())

     def forward(
         self,
@@ -632,8 +601,15 @@ class CriticEnsemble(nn.Module):
         )
         inputs = torch.cat([obs_enc, actions], dim=-1)
-        q_values = self.ensemble(inputs)  # [num_critics, B, 1]
-        return q_values.squeeze(-1)  # [num_critics, B]
+
+        # Loop through critics and collect outputs
+        q_values = []
+        for critic in self.critics:
+            q_values.append(critic(inputs))
+
+        # Stack outputs to match expected shape [num_critics, batch_size]
+        q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
+        return q_values


 class Policy(nn.Module):
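The loop-and-stack pattern above trades the previous vmap-batched `Ensemble` call for ordinary Python iteration. A small self-contained shape check with stand-in critics:

    import torch
    import torch.nn as nn

    critics = nn.ModuleList(nn.Linear(10, 1) for _ in range(4))  # 4 stand-in critic heads
    inputs = torch.randn(32, 10)  # [batch_size, obs_dim + action_dim]

    q_values = torch.stack([c(inputs).squeeze(-1) for c in critics], dim=0)
    print(q_values.shape)  # torch.Size([4, 32]) -> [num_critics, batch_size]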
@@ -706,9 +682,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
@@ -1025,73 +1001,78 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:

 if __name__ == "__main__":
-    # Test the SACObservationEncoder
+    # Benchmark the CriticEnsemble performance
     import time

-    config = SACConfig()
-    config.num_critics = 10
-    config.vision_encoder_name = None
-    encoder = SACObservationEncoder(config, nn.Identity())
-    # actor_encoder = SACObservationEncoder(config)
-    # encoder = torch.compile(encoder)
+    # Configuration
+    num_critics = 10
+    batch_size = 32
+    action_dim = 7
+    obs_dim = 64
+    hidden_dims = [256, 256]
+    num_iterations = 100
+
+    print("Creating test environment...")
+
+    # Create a simple dummy encoder
+    class DummyEncoder(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.output_dim = obs_dim
+            self.parameters_to_optimize = []
+
+        def forward(self, obs):
+            # Just return a random tensor of the right shape
+            # In practice, this would encode the observations
+            return torch.randn(batch_size, obs_dim, device=device)
+
+    # Create critic heads
+    print(f"Creating {num_critics} critic heads...")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    critic_heads = [
+        CriticHead(
+            input_dim=obs_dim + action_dim,
+            hidden_dims=hidden_dims,
+        ).to(device)
+        for _ in range(num_critics)
+    ]
+
+    # Create the critic ensemble
+    print("Creating CriticEnsemble...")
     critic_ensemble = CriticEnsemble(
-        encoder=encoder,
-        ensemble=Ensemble(
-            [
-                CriticHead(
-                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
-                    **config.critic_network_kwargs,
-                )
-                for _ in range(config.num_critics)
-            ]
-        ),
+        encoder=DummyEncoder().to(device),
+        ensemble=critic_heads,
         output_normalization=nn.Identity(),
-    )
-    # actor = Policy(
-    #     encoder=actor_encoder,
-    #     network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
-    #     action_dim=config.output_shapes["action"][0],
-    #     encoder_is_shared=config.shared_encoder,
-    #     **config.policy_kwargs,
-    # )
-
-    # encoder = encoder.to("cuda:0")
-    # critic_ensemble = torch.compile(critic_ensemble)
-    critic_ensemble = critic_ensemble.to("cuda:0")
-    # actor = torch.compile(actor)
-    # actor = actor.to("cuda:0")
+    ).to(device)
+
+    # Create random input data
+    print("Creating input data...")
     obs_dict = {
-        "observation.image": torch.randn(8, 3, 84, 84),
-        "observation.state": torch.randn(8, 4),
+        "observation.state": torch.randn(batch_size, obs_dim, device=device),
     }
-    actions = torch.randn(8, 2).to("cuda:0")
-    # obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
-    # print("compiling...")
-    q_value = critic_ensemble(obs_dict, actions)
-    print(q_value.size())
-    # action = actor(obs_dict)
-    # print("compiled")
-    # start = time.perf_counter()
-    # for _ in range(1000):
-    #     # features = encoder(obs_dict)
-    #     action = actor(obs_dict)
-    #     # q_value = critic_ensemble(obs_dict, actions)
-    # print("Time taken:", time.perf_counter() - start)
-
-    # Compare the performance of the ensemble vs a for loop of 16 MLPs
-    ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)])
-    ensemble = ensemble.to("cuda:0")
-    critic = CriticHead(256, [256, 256])
-    critic = critic.to("cuda:0")
-    data_ensemble = torch.randn(8, 256).to("cuda:0")
-    ensemble = torch.compile(ensemble)
-    # critic = torch.compile(critic)
-    print(ensemble(data_ensemble).size())
-    print(critic(data_ensemble).size())
-    start = time.perf_counter()
-    for _ in range(1000):
-        ensemble(data_ensemble)
-    print("Time taken:", time.perf_counter() - start)
-    start = time.perf_counter()
-    for _ in range(1000):
-        for i in range(2):
-            critic(data_ensemble)
-    print("Time taken:", time.perf_counter() - start)
+    actions = torch.randn(batch_size, action_dim, device=device)
+
+    # Warmup run
+    print("Warming up...")
+    _ = critic_ensemble(obs_dict, actions)
+
+    # Time the forward pass
+    print(f"Running benchmark with {num_iterations} iterations...")
+    start_time = time.perf_counter()
+    for _ in range(num_iterations):
+        q_values = critic_ensemble(obs_dict, actions)
+    end_time = time.perf_counter()
+
+    # Print results
+    elapsed_time = end_time - start_time
+    print(f"Total time: {elapsed_time:.4f} seconds")
+    print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
+    print(f"Output shape: {q_values.shape}")  # Should be [num_critics, batch_size]
+
+    # Verify that all critic heads produce different outputs
+    # This confirms each critic head is unique
+    # print("\nVerifying critic outputs are different:")
+    # for i in range(num_critics):
+    #     for j in range(i + 1, num_critics):
+    #         diff = torch.abs(q_values[i] - q_values[j]).mean().item()
+    #         print(f"Mean difference between critic {i} and {j}: {diff:.6f}")

File 2 of 2: learner server

@@ -121,7 +121,7 @@ def load_training_state(
         return None, None

     training_state = torch.load(
-        logger.last_checkpoint_dir / logger.training_state_file_name
+        logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
     )

     if isinstance(training_state["optimizer"], dict):
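Context for `weights_only=False`: recent PyTorch versions default `torch.load` to `weights_only=True`, which refuses to unpickle arbitrary Python objects; a training-state checkpoint that stores an optimizer instance (rather than a plain state dict) then fails to load. A minimal illustration (the file name is hypothetical):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    state = {"step": 1000, "optimizer": torch.optim.SGD(params, lr=0.1)}  # real object, not a state_dict
    torch.save(state, "training_state.pt")

    # weights_only=True would reject the SGD instance; full unpickling restores it.
    # Only do this for checkpoints from a trusted source.
    restored = torch.load("training_state.pt", weights_only=False)
    print(restored["step"])  # 1000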
@@ -160,6 +160,7 @@ def initialize_replay_buffer(
             optimize_memory=True,
         )

+    logging.info("Resume training: loading the online dataset")
     dataset = LeRobotDataset(
         repo_id=cfg.dataset_repo_id,
         local_files_only=True,
@@ -174,6 +175,37 @@ def initialize_replay_buffer(
     )


+def initialize_offline_replay_buffer(
+    cfg: DictConfig,
+    logger: Logger,
+    device: str,
+    storage_device: str,
+    active_action_dims: list[int] | None = None,
+) -> ReplayBuffer:
+    if not cfg.resume:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+
+    if cfg.resume:
+        logging.info("load offline dataset")
+        offline_dataset = LeRobotDataset(
+            repo_id=cfg.dataset_repo_id,
+            local_files_only=True,
+            root=logger.log_dir / "dataset_offline",
+        )
+
+    logging.info("Convert to an offline replay buffer")
+    offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+        offline_dataset,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
+        action_mask=active_action_dims,
+        action_delta=cfg.env.wrapper.delta_action,
+        storage_device=storage_device,
+        optimize_memory=True,
+    )
+    return offline_replay_buffer
+
+
 def get_observation_features(
     policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
 ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
@@ -447,9 +479,6 @@ def add_actor_information_and_train(
     offline_replay_buffer = None

     if cfg.dataset_repo_id is not None:
-        logging.info("make_dataset offline buffer")
-        offline_dataset = make_dataset(cfg)
-        logging.info("Convertion to a offline replay buffer")
         active_action_dims = None
         if cfg.env.wrapper.joint_masking_action_space is not None:
             active_action_dims = [
@@ -457,14 +486,12 @@ def add_actor_information_and_train(
                 for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
                 if mask
             ]
-        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            offline_dataset,
+        offline_replay_buffer = initialize_offline_replay_buffer(
+            cfg=cfg,
+            logger=logger,
             device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
-            action_mask=active_action_dims,
-            action_delta=cfg.env.wrapper.delta_action,
             storage_device=storage_device,
-            optimize_memory=True,
+            active_action_dims=active_action_dims,
         )

         batch_size: int = batch_size // 2  # We will sample from both replay buffers
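The `active_action_dims` argument passed above is simply the indices of the joints enabled in the boolean mask. A tiny self-contained illustration (the mask values are made up):

    # Hypothetical mask: train three of five joints
    joint_masking_action_space = [True, True, False, True, False]
    active_action_dims = [i for i, mask in enumerate(joint_masking_action_space) if mask]
    print(active_action_dims)  # [0, 1, 3]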
@@ -714,6 +741,19 @@ def add_actor_information_and_train(
         replay_buffer.to_lerobot_dataset(
             cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
         )
+        if offline_replay_buffer is not None:
+            dataset_dir = logger.log_dir / "dataset_offline"
+            if dataset_dir.exists() and dataset_dir.is_dir():
+                shutil.rmtree(
+                    dataset_dir,
+                )
+
+            offline_replay_buffer.to_lerobot_dataset(
+                cfg.dataset_repo_id,
+                fps=cfg.fps,
+                root=logger.log_dir / "dataset_offline",
+            )
+
     logging.info("Resume training")