Refactor SACPolicy and learner server for improved replay buffer management
- Updated SACPolicy to create critic heads using a list comprehension for better readability. - Simplified the saving and loading of models using `save_model` and `load_model` functions from the safetensors library. - Introduced `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization. - Enhanced logging for dataset loading processes to improve traceability during training.
This commit is contained in:
parent
b82faf7d8c
commit
4bb2077afa
|
@ -18,7 +18,7 @@
|
|||
# TODO: (1) better device management
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Optional, Tuple, Union, Dict
|
||||
from typing import Callable, Optional, Tuple, Union, Dict, List
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
|
@ -88,33 +88,33 @@ class SACPolicy(
|
|||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
|
||||
|
||||
# Create a list of critic heads
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=Ensemble(
|
||||
[
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim
|
||||
+ config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
ensemble=critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
# Create target critic heads as deepcopies of the original critic heads
|
||||
target_critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=Ensemble(
|
||||
[
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim
|
||||
+ config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
ensemble=target_critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
|
@ -149,19 +149,9 @@ class SACPolicy(
|
|||
import json
|
||||
from dataclasses import asdict
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
|
||||
from safetensors.torch import save_file
|
||||
from safetensors.torch import save_model
|
||||
|
||||
# NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap
|
||||
# implies one side effect: the __batch_size parameters are saved in the state_dict
|
||||
# __batch_size is torch.Size or safetensor save only torch.Tensor
|
||||
# so we need to filter them out before saving
|
||||
simplified_state_dict = {}
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
simplified_state_dict[name] = param
|
||||
save_file(
|
||||
simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
|
||||
)
|
||||
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
# Save config
|
||||
config_dict = asdict(self.config)
|
||||
|
@ -191,7 +181,7 @@ class SACPolicy(
|
|||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
|
||||
from safetensors.torch import load_file
|
||||
from safetensors.torch import load_model
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
|
||||
# Check if model_id is a local path or a hub model ID
|
||||
|
@ -243,28 +233,7 @@ class SACPolicy(
|
|||
|
||||
# Load state dict from safetensors file
|
||||
if os.path.exists(safetensors_file):
|
||||
# Note: The load_file function returns a dict with the parameters, but __batch_size
|
||||
# is not loaded so we need to copy it from the model state_dict
|
||||
# Load the parameters only
|
||||
loaded_state_dict = load_file(safetensors_file, device=map_location)
|
||||
|
||||
# Copy batch size parameters
|
||||
find_and_copy_params(
|
||||
original_state_dict=model.state_dict(),
|
||||
loaded_state_dict=loaded_state_dict,
|
||||
pattern="__batch_size",
|
||||
match_type="endswith",
|
||||
)
|
||||
|
||||
# Copy normalization buffer parameters
|
||||
find_and_copy_params(
|
||||
original_state_dict=model.state_dict(),
|
||||
loaded_state_dict=loaded_state_dict,
|
||||
pattern="_orig_mod.output_normalization.buffer_action",
|
||||
match_type="contains",
|
||||
)
|
||||
|
||||
model.load_state_dict(loaded_state_dict, strict=False)
|
||||
load_model(model, filename=safetensors_file, device=map_location)
|
||||
|
||||
return model
|
||||
|
||||
|
@ -594,21 +563,21 @@ class CriticEnsemble(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
ensemble: "Ensemble[CriticHead]",
|
||||
ensemble: List[CriticHead],
|
||||
output_normalization: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.ensemble = ensemble
|
||||
self.init_final = init_final
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
# Handle the case where a part of the encoder if frozen
|
||||
if self.encoder is not None:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
|
||||
self.parameters_to_optimize += list(self.ensemble.parameters())
|
||||
self.parameters_to_optimize += list(self.critics.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -632,8 +601,15 @@ class CriticEnsemble(nn.Module):
|
|||
)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
q_values = self.ensemble(inputs) # [num_critics, B, 1]
|
||||
return q_values.squeeze(-1) # [num_critics, B]
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
|
@ -706,9 +682,9 @@ class Policy(nn.Module):
|
|||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(
|
||||
log_std
|
||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
assert not torch.isnan(log_std).any(), (
|
||||
"[ERROR] log_std became NaN after std_layer!"
|
||||
)
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
|
@ -1025,73 +1001,78 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the SACObservationEncoder
|
||||
# Benchmark the CriticEnsemble performance
|
||||
import time
|
||||
|
||||
config = SACConfig()
|
||||
config.num_critics = 10
|
||||
config.vision_encoder_name = None
|
||||
encoder = SACObservationEncoder(config, nn.Identity())
|
||||
# actor_encoder = SACObservationEncoder(config)
|
||||
# encoder = torch.compile(encoder)
|
||||
# Configuration
|
||||
num_critics = 10
|
||||
batch_size = 32
|
||||
action_dim = 7
|
||||
obs_dim = 64
|
||||
hidden_dims = [256, 256]
|
||||
num_iterations = 100
|
||||
|
||||
print("Creating test environment...")
|
||||
|
||||
# Create a simple dummy encoder
|
||||
class DummyEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.output_dim = obs_dim
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
def forward(self, obs):
|
||||
# Just return a random tensor of the right shape
|
||||
# In practice, this would encode the observations
|
||||
return torch.randn(batch_size, obs_dim, device=device)
|
||||
|
||||
# Create critic heads
|
||||
print(f"Creating {num_critics} critic heads...")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=obs_dim + action_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
).to(device)
|
||||
for _ in range(num_critics)
|
||||
]
|
||||
|
||||
# Create the critic ensemble
|
||||
print("Creating CriticEnsemble...")
|
||||
critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder,
|
||||
ensemble=Ensemble(
|
||||
[
|
||||
CriticHead(
|
||||
input_dim=encoder.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
encoder=DummyEncoder().to(device),
|
||||
ensemble=critic_heads,
|
||||
output_normalization=nn.Identity(),
|
||||
)
|
||||
# actor = Policy(
|
||||
# encoder=actor_encoder,
|
||||
# network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
|
||||
# action_dim=config.output_shapes["action"][0],
|
||||
# encoder_is_shared=config.shared_encoder,
|
||||
# **config.policy_kwargs,
|
||||
# )
|
||||
# encoder = encoder.to("cuda:0")
|
||||
# critic_ensemble = torch.compile(critic_ensemble)
|
||||
critic_ensemble = critic_ensemble.to("cuda:0")
|
||||
# actor = torch.compile(actor)
|
||||
# actor = actor.to("cuda:0")
|
||||
).to(device)
|
||||
|
||||
# Create random input data
|
||||
print("Creating input data...")
|
||||
obs_dict = {
|
||||
"observation.image": torch.randn(8, 3, 84, 84),
|
||||
"observation.state": torch.randn(8, 4),
|
||||
"observation.state": torch.randn(batch_size, obs_dim, device=device),
|
||||
}
|
||||
actions = torch.randn(8, 2).to("cuda:0")
|
||||
# obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
|
||||
# print("compiling...")
|
||||
q_value = critic_ensemble(obs_dict, actions)
|
||||
print(q_value.size())
|
||||
# action = actor(obs_dict)
|
||||
# print("compiled")
|
||||
# start = time.perf_counter()
|
||||
# for _ in range(1000):
|
||||
# # features = encoder(obs_dict)
|
||||
# action = actor(obs_dict)
|
||||
# # q_value = critic_ensemble(obs_dict, actions)
|
||||
# print("Time taken:", time.perf_counter() - start)
|
||||
# Compare the performance of the ensemble vs a for loop of 16 MLPs
|
||||
ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)])
|
||||
ensemble = ensemble.to("cuda:0")
|
||||
critic = CriticHead(256, [256, 256])
|
||||
critic = critic.to("cuda:0")
|
||||
data_ensemble = torch.randn(8, 256).to("cuda:0")
|
||||
ensemble = torch.compile(ensemble)
|
||||
# critic = torch.compile(critic)
|
||||
print(ensemble(data_ensemble).size())
|
||||
print(critic(data_ensemble).size())
|
||||
start = time.perf_counter()
|
||||
for _ in range(1000):
|
||||
ensemble(data_ensemble)
|
||||
print("Time taken:", time.perf_counter() - start)
|
||||
start = time.perf_counter()
|
||||
for _ in range(1000):
|
||||
for i in range(2):
|
||||
critic(data_ensemble)
|
||||
print("Time taken:", time.perf_counter() - start)
|
||||
actions = torch.randn(batch_size, action_dim, device=device)
|
||||
|
||||
# Warmup run
|
||||
print("Warming up...")
|
||||
_ = critic_ensemble(obs_dict, actions)
|
||||
|
||||
# Time the forward pass
|
||||
print(f"Running benchmark with {num_iterations} iterations...")
|
||||
start_time = time.perf_counter()
|
||||
for _ in range(num_iterations):
|
||||
q_values = critic_ensemble(obs_dict, actions)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
# Print results
|
||||
elapsed_time = end_time - start_time
|
||||
print(f"Total time: {elapsed_time:.4f} seconds")
|
||||
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
||||
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
||||
|
||||
# Verify that all critic heads produce different outputs
|
||||
# This confirms each critic head is unique
|
||||
# print("\nVerifying critic outputs are different:")
|
||||
# for i in range(num_critics):
|
||||
# for j in range(i + 1, num_critics):
|
||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||
|
|
|
@ -121,7 +121,7 @@ def load_training_state(
|
|||
return None, None
|
||||
|
||||
training_state = torch.load(
|
||||
logger.last_checkpoint_dir / logger.training_state_file_name
|
||||
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
|
||||
)
|
||||
|
||||
if isinstance(training_state["optimizer"], dict):
|
||||
|
@ -160,6 +160,7 @@ def initialize_replay_buffer(
|
|||
optimize_memory=True,
|
||||
)
|
||||
|
||||
logging.info("Resume training load the online dataset")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset_repo_id,
|
||||
local_files_only=True,
|
||||
|
@ -174,6 +175,37 @@ def initialize_replay_buffer(
|
|||
)
|
||||
|
||||
|
||||
def initialize_offline_replay_buffer(
|
||||
cfg: DictConfig,
|
||||
logger: Logger,
|
||||
device: str,
|
||||
storage_device: str,
|
||||
active_action_dims: list[int] | None = None,
|
||||
) -> ReplayBuffer:
|
||||
if not cfg.resume:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
if cfg.resume:
|
||||
logging.info("load offline dataset")
|
||||
offline_dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset_repo_id,
|
||||
local_files_only=True,
|
||||
root=logger.log_dir / "dataset_offline",
|
||||
)
|
||||
|
||||
logging.info("Convert to a offline replay buffer")
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
)
|
||||
return offline_replay_buffer
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
|
@ -447,9 +479,6 @@ def add_actor_information_and_train(
|
|||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
active_action_dims = None
|
||||
if cfg.env.wrapper.joint_masking_action_space is not None:
|
||||
active_action_dims = [
|
||||
|
@ -457,14 +486,12 @@ def add_actor_information_and_train(
|
|||
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
|
||||
if mask
|
||||
]
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset,
|
||||
offline_replay_buffer = initialize_offline_replay_buffer(
|
||||
cfg=cfg,
|
||||
logger=logger,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
active_action_dims=active_action_dims,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
|
@ -714,6 +741,19 @@ def add_actor_information_and_train(
|
|||
replay_buffer.to_lerobot_dataset(
|
||||
cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
|
||||
)
|
||||
if offline_replay_buffer is not None:
|
||||
dataset_dir = logger.log_dir / "dataset_offline"
|
||||
|
||||
if dataset_dir.exists() and dataset_dir.is_dir():
|
||||
shutil.rmtree(
|
||||
dataset_dir,
|
||||
)
|
||||
|
||||
offline_replay_buffer.to_lerobot_dataset(
|
||||
cfg.dataset_repo_id,
|
||||
fps=cfg.fps,
|
||||
root=logger.log_dir / "dataset_offline",
|
||||
)
|
||||
|
||||
logging.info("Resume training")
|
||||
|
||||
|
|
Loading…
Reference in New Issue