Refactor SACPolicy and learner server for improved replay buffer management
- Updated SACPolicy to create critic heads with a list comprehension for better readability.
- Simplified model saving and loading with the `save_model` and `load_model` functions from the safetensors library (a round-trip sketch follows below).
- Introduced an `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization.
- Enhanced logging around dataset loading to improve traceability during training.
parent 944355c042
commit 7dd581d817
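The safetensors bullet above replaces hand-rolled state-dict filtering with the library's module-level helpers. As a quick orientation before the diff (whose first hunks are from the SAC policy module), here is a minimal sketch of the round trip those helpers provide; the toy model and file name are illustrative only:

import torch
import torch.nn as nn
from safetensors.torch import save_model, load_model

# save_model serializes a module's tensors directly, handling shared weights;
# load_model restores them in place (strict by default).
model = nn.Linear(4, 2)
save_model(model, "model.safetensors")

restored = nn.Linear(4, 2)
load_model(restored, "model.safetensors")
assert torch.equal(model.weight, restored.weight)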
@@ -18,7 +18,7 @@
 # TODO: (1) better device management
 
 from copy import deepcopy
-from typing import Callable, Optional, Tuple, Union, Dict
+from typing import Callable, Optional, Tuple, Union, Dict, List
 from pathlib import Path
 
 import einops
@@ -88,33 +88,33 @@ class SACPolicy(
         encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
         encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
 
+        # Create a list of critic heads
+        critic_heads = [
+            CriticHead(
+                input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                **config.critic_network_kwargs,
+            )
+            for _ in range(config.num_critics)
+        ]
+
         self.critic_ensemble = CriticEnsemble(
             encoder=encoder_critic,
-            ensemble=Ensemble(
-                [
-                    CriticHead(
-                        input_dim=encoder_critic.output_dim
-                        + config.output_shapes["action"][0],
-                        **config.critic_network_kwargs,
-                    )
-                    for _ in range(config.num_critics)
-                ]
-            ),
+            ensemble=critic_heads,
             output_normalization=self.normalize_targets,
         )
 
+        # Create target critic heads as deepcopies of the original critic heads
+        target_critic_heads = [
+            CriticHead(
+                input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                **config.critic_network_kwargs,
+            )
+            for _ in range(config.num_critics)
+        ]
+
         self.critic_target = CriticEnsemble(
             encoder=encoder_critic,
-            ensemble=Ensemble(
-                [
-                    CriticHead(
-                        input_dim=encoder_critic.output_dim
-                        + config.output_shapes["action"][0],
-                        **config.critic_network_kwargs,
-                    )
-                    for _ in range(config.num_critics)
-                ]
-            ),
+            ensemble=target_critic_heads,
             output_normalization=self.normalize_targets,
         )
 
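Outside the diff context, the new construction boils down to a plain list comprehension whose result is registered through nn.ModuleList inside CriticEnsemble (see the later hunks). A standalone sketch, with a simplified stand-in for the repo's CriticHead:

import torch.nn as nn

class SimpleCriticHead(nn.Module):
    # Stand-in for the repo's CriticHead (an MLP over [obs, action] features).
    def __init__(self, input_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.net(x)

num_critics, input_dim = 2, 10
critic_heads = [SimpleCriticHead(input_dim) for _ in range(num_critics)]
critics = nn.ModuleList(critic_heads)  # registers all heads as submodules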
@@ -149,19 +149,9 @@ class SACPolicy(
         import json
         from dataclasses import asdict
         from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
-        from safetensors.torch import save_file
+        from safetensors.torch import save_model
 
-        # NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap
-        # implies one side effect: the __batch_size parameters are saved in the state_dict
-        # __batch_size is torch.Size or safetensor save only torch.Tensor
-        # so we need to filter them out before saving
-        simplified_state_dict = {}
-
-        for name, param in self.named_parameters():
-            simplified_state_dict[name] = param
-        save_file(
-            simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
-        )
+        save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
 
         # Save config
         config_dict = asdict(self.config)
@@ -191,7 +181,7 @@ class SACPolicy(
         from pathlib import Path
         from huggingface_hub import hf_hub_download
         from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
-        from safetensors.torch import load_file
+        from safetensors.torch import load_model
         from lerobot.common.policies.sac.configuration_sac import SACConfig
 
         # Check if model_id is a local path or a hub model ID
@@ -243,28 +233,7 @@ class SACPolicy(
 
         # Load state dict from safetensors file
         if os.path.exists(safetensors_file):
-            # Note: The load_file function returns a dict with the parameters, but __batch_size
-            # is not loaded so we need to copy it from the model state_dict
-            # Load the parameters only
-            loaded_state_dict = load_file(safetensors_file, device=map_location)
-
-            # Copy batch size parameters
-            find_and_copy_params(
-                original_state_dict=model.state_dict(),
-                loaded_state_dict=loaded_state_dict,
-                pattern="__batch_size",
-                match_type="endswith",
-            )
-
-            # Copy normalization buffer parameters
-            find_and_copy_params(
-                original_state_dict=model.state_dict(),
-                loaded_state_dict=loaded_state_dict,
-                pattern="_orig_mod.output_normalization.buffer_action",
-                match_type="contains",
-            )
-
-            model.load_state_dict(loaded_state_dict, strict=False)
+            load_model(model, filename=safetensors_file, device=map_location)
 
         return model
 
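One behavioral difference worth flagging: load_model is strict by default, whereas the removed path finished with load_state_dict(..., strict=False) after copying auxiliary entries. If tolerant loading were still needed, the call would presumably become:

# Hypothetical variant, not in the commit: opt back into tolerant loading.
load_model(model, filename=safetensors_file, device=map_location, strict=False)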
@@ -594,21 +563,21 @@ class CriticEnsemble(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        ensemble: "Ensemble[CriticHead]",
+        ensemble: List[CriticHead],
         output_normalization: nn.Module,
         init_final: Optional[float] = None,
     ):
         super().__init__()
         self.encoder = encoder
-        self.ensemble = ensemble
         self.init_final = init_final
         self.output_normalization = output_normalization
+        self.critics = nn.ModuleList(ensemble)
 
         self.parameters_to_optimize = []
         # Handle the case where a part of the encoder if frozen
         if self.encoder is not None:
             self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
-        self.parameters_to_optimize += list(self.ensemble.parameters())
+        self.parameters_to_optimize += list(self.critics.parameters())
 
     def forward(
         self,
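Design note on the hunk above: nn.ModuleList is what makes the heads visible to PyTorch; storing them in a bare Python list would hide them from parameters(), state_dict(), and .to(device). A minimal illustration of the difference:

import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = [nn.Linear(2, 2) for _ in range(3)]  # NOT registered

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList(nn.Linear(2, 2) for _ in range(3))

print(len(list(PlainList().parameters())))   # 0 -- an optimizer would see nothing
print(len(list(Registered().parameters())))  # 6 -- weight + bias per head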
@@ -632,8 +601,15 @@ class CriticEnsemble(nn.Module):
         )
 
         inputs = torch.cat([obs_enc, actions], dim=-1)
-        q_values = self.ensemble(inputs)  # [num_critics, B, 1]
-        return q_values.squeeze(-1)  # [num_critics, B]
+        # Loop through critics and collect outputs
+        q_values = []
+        for critic in self.critics:
+            q_values.append(critic(inputs))
+
+        # Stack outputs to match expected shape [num_critics, batch_size]
+        q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
+        return q_values
 
 
 class Policy(nn.Module):
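The rewritten forward pass trades the vmap-batched Ensemble call for an explicit loop, which also removes the tensordict side effects (the __batch_size entries) that the old save path had to filter out. Each head emits [batch, 1]; squeezing and stacking restores the [num_critics, batch] contract. A shape-only sketch with stand-in heads:

import torch

num_critics, batch, feat = 3, 8, 16
heads = [torch.nn.Linear(feat, 1) for _ in range(num_critics)]  # stand-ins
inputs = torch.randn(batch, feat)

q_values = torch.stack([h(inputs).squeeze(-1) for h in heads], dim=0)
print(q_values.shape)  # torch.Size([3, 8])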
@@ -706,9 +682,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )
 
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
@@ -1025,73 +1001,78 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
 
 
 if __name__ == "__main__":
-    # Test the SACObservationEncoder
+    # Benchmark the CriticEnsemble performance
     import time
 
-    config = SACConfig()
-    config.num_critics = 10
-    config.vision_encoder_name = None
-    encoder = SACObservationEncoder(config, nn.Identity())
-    # actor_encoder = SACObservationEncoder(config)
-    # encoder = torch.compile(encoder)
+    # Configuration
+    num_critics = 10
+    batch_size = 32
+    action_dim = 7
+    obs_dim = 64
+    hidden_dims = [256, 256]
+    num_iterations = 100
+
+    print("Creating test environment...")
+
+    # Create a simple dummy encoder
+    class DummyEncoder(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.output_dim = obs_dim
+            self.parameters_to_optimize = []
+
+        def forward(self, obs):
+            # Just return a random tensor of the right shape
+            # In practice, this would encode the observations
+            return torch.randn(batch_size, obs_dim, device=device)
+
+    # Create critic heads
+    print(f"Creating {num_critics} critic heads...")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    critic_heads = [
+        CriticHead(
+            input_dim=obs_dim + action_dim,
+            hidden_dims=hidden_dims,
+        ).to(device)
+        for _ in range(num_critics)
+    ]
+
+    # Create the critic ensemble
+    print("Creating CriticEnsemble...")
     critic_ensemble = CriticEnsemble(
-        encoder=encoder,
-        ensemble=Ensemble(
-            [
-                CriticHead(
-                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
-                    **config.critic_network_kwargs,
-                )
-                for _ in range(config.num_critics)
-            ]
-        ),
+        encoder=DummyEncoder().to(device),
+        ensemble=critic_heads,
         output_normalization=nn.Identity(),
-    )
-    # actor = Policy(
-    #     encoder=actor_encoder,
-    #     network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
-    #     action_dim=config.output_shapes["action"][0],
-    #     encoder_is_shared=config.shared_encoder,
-    #     **config.policy_kwargs,
-    # )
-    # encoder = encoder.to("cuda:0")
-    # critic_ensemble = torch.compile(critic_ensemble)
-    critic_ensemble = critic_ensemble.to("cuda:0")
-    # actor = torch.compile(actor)
-    # actor = actor.to("cuda:0")
+    ).to(device)
+
+    # Create random input data
+    print("Creating input data...")
     obs_dict = {
-        "observation.image": torch.randn(8, 3, 84, 84),
-        "observation.state": torch.randn(8, 4),
+        "observation.state": torch.randn(batch_size, obs_dim, device=device),
     }
-    actions = torch.randn(8, 2).to("cuda:0")
-    # obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
-    # print("compiling...")
-    q_value = critic_ensemble(obs_dict, actions)
-    print(q_value.size())
-    # action = actor(obs_dict)
-    # print("compiled")
-    # start = time.perf_counter()
-    # for _ in range(1000):
-    #     # features = encoder(obs_dict)
-    #     action = actor(obs_dict)
-    #     # q_value = critic_ensemble(obs_dict, actions)
-    # print("Time taken:", time.perf_counter() - start)
-    # Compare the performance of the ensemble vs a for loop of 16 MLPs
-    ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)])
-    ensemble = ensemble.to("cuda:0")
-    critic = CriticHead(256, [256, 256])
-    critic = critic.to("cuda:0")
-    data_ensemble = torch.randn(8, 256).to("cuda:0")
-    ensemble = torch.compile(ensemble)
-    # critic = torch.compile(critic)
-    print(ensemble(data_ensemble).size())
-    print(critic(data_ensemble).size())
-    start = time.perf_counter()
-    for _ in range(1000):
-        ensemble(data_ensemble)
-    print("Time taken:", time.perf_counter() - start)
-    start = time.perf_counter()
-    for _ in range(1000):
-        for i in range(2):
-            critic(data_ensemble)
-    print("Time taken:", time.perf_counter() - start)
+    actions = torch.randn(batch_size, action_dim, device=device)
+
+    # Warmup run
+    print("Warming up...")
+    _ = critic_ensemble(obs_dict, actions)
+
+    # Time the forward pass
+    print(f"Running benchmark with {num_iterations} iterations...")
+    start_time = time.perf_counter()
+    for _ in range(num_iterations):
+        q_values = critic_ensemble(obs_dict, actions)
+    end_time = time.perf_counter()
+
+    # Print results
+    elapsed_time = end_time - start_time
+    print(f"Total time: {elapsed_time:.4f} seconds")
+    print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
+    print(f"Output shape: {q_values.shape}")  # Should be [num_critics, batch_size]
+
+    # Verify that all critic heads produce different outputs
+    # This confirms each critic head is unique
+    # print("\nVerifying critic outputs are different:")
+    # for i in range(num_critics):
+    #     for j in range(i + 1, num_critics):
+    #         diff = torch.abs(q_values[i] - q_values[j]).mean().item()
+    #         print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
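A caveat on the benchmark above when it runs on GPU: CUDA launches kernels asynchronously, so perf_counter around the loop alone can under-report. A suggested (not committed) variant synchronizes around the timed region:

# Suggested timing variant, reusing the script's names (critic_ensemble,
# obs_dict, actions, num_iterations); not part of this commit.
import time
import torch

if torch.cuda.is_available():
    torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(num_iterations):
    q_values = critic_ensemble(obs_dict, actions)
if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed_time = time.perf_counter() - start_time

The remaining hunks below are from the learner server side of the commit.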
@@ -121,7 +121,7 @@ def load_training_state(
        return None, None
 
    training_state = torch.load(
-        logger.last_checkpoint_dir / logger.training_state_file_name
+        logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
    )
 
    if isinstance(training_state["optimizer"], dict):
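Context for the weights_only=False addition: recent PyTorch releases default torch.load to weights_only=True, which only admits tensors and primitive containers; a training-state checkpoint that pickles optimizer objects needs the explicit opt-out (safe only for trusted files):

# Opt out of the restricted unpickler to recover full training state.
# checkpoint_path is a placeholder for the logger's checkpoint file.
training_state = torch.load(checkpoint_path, weights_only=False)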
@@ -160,6 +160,7 @@ def initialize_replay_buffer(
        optimize_memory=True,
    )
 
+    logging.info("Resume training load the online dataset")
    dataset = LeRobotDataset(
        repo_id=cfg.dataset_repo_id,
        local_files_only=True,
@@ -174,6 +175,37 @@ def initialize_replay_buffer(
    )
 
 
+def initialize_offline_replay_buffer(
+    cfg: DictConfig,
+    logger: Logger,
+    device: str,
+    storage_device: str,
+    active_action_dims: list[int] | None = None,
+) -> ReplayBuffer:
+    if not cfg.resume:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+    if cfg.resume:
+        logging.info("load offline dataset")
+        offline_dataset = LeRobotDataset(
+            repo_id=cfg.dataset_repo_id,
+            local_files_only=True,
+            root=logger.log_dir / "dataset_offline",
+        )
+
+    logging.info("Convert to a offline replay buffer")
+    offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+        offline_dataset,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
+        action_mask=active_action_dims,
+        action_delta=cfg.env.wrapper.delta_action,
+        storage_device=storage_device,
+        optimize_memory=True,
+    )
+    return offline_replay_buffer
+
+
 def get_observation_features(
    policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
 ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
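The new helper centralizes the fresh-start versus resume branch: a fresh run builds the offline dataset via make_dataset(cfg), while a resumed run reloads the local copy previously exported under log_dir / "dataset_offline". Its call site, shown in the hunks that follow, reduces to:

# Mirrors the new call site in add_actor_information_and_train.
offline_replay_buffer = initialize_offline_replay_buffer(
    cfg=cfg,
    logger=logger,
    device=device,
    storage_device=storage_device,
    active_action_dims=active_action_dims,
)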
@@ -447,9 +479,6 @@ def add_actor_information_and_train(
    offline_replay_buffer = None
 
    if cfg.dataset_repo_id is not None:
-        logging.info("make_dataset offline buffer")
-        offline_dataset = make_dataset(cfg)
-        logging.info("Convertion to a offline replay buffer")
        active_action_dims = None
        if cfg.env.wrapper.joint_masking_action_space is not None:
            active_action_dims = [
@@ -457,14 +486,12 @@ def add_actor_information_and_train(
                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
                if mask
            ]
-        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            offline_dataset,
+        offline_replay_buffer = initialize_offline_replay_buffer(
+            cfg=cfg,
+            logger=logger,
            device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
-            action_mask=active_action_dims,
-            action_delta=cfg.env.wrapper.delta_action,
            storage_device=storage_device,
-            optimize_memory=True,
+            active_action_dims=active_action_dims,
        )
        batch_size: int = batch_size // 2  # We will sample from both replay buffer
 
@@ -714,6 +741,19 @@ def add_actor_information_and_train(
            replay_buffer.to_lerobot_dataset(
                cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
            )
+            if offline_replay_buffer is not None:
+                dataset_dir = logger.log_dir / "dataset_offline"
+
+                if dataset_dir.exists() and dataset_dir.is_dir():
+                    shutil.rmtree(
+                        dataset_dir,
+                    )
+
+                offline_replay_buffer.to_lerobot_dataset(
+                    cfg.dataset_repo_id,
+                    fps=cfg.fps,
+                    root=logger.log_dir / "dataset_offline",
+                )
+
            logging.info("Resume training")
 