[WIP] Update SAC configuration and environment settings

- Reduced the frame rate in `ManiskillEnvConfig` from 400 to 200 FPS.
- Enhanced `SACConfig` with new dataclasses for the actor, learner, and network configurations.
- Improved input and output feature management in `SACConfig` via typed `PolicyFeature` entries.
- Refactored `actor_server` and `learner_server` to read configuration values through attribute access instead of dictionary keys (a usage sketch follows the commit metadata below).
- Updated the training pipeline to validate configurations and handle dataset repo IDs more robustly.
AdilZouitine 2025-03-27 08:13:20 +00:00
parent 626e5dd35c
commit 052a4acfc2
7 changed files with 183 additions and 126 deletions
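
The core of the refactor is replacing plain dicts on `SACConfig` with typed dataclasses, so both servers can use attribute access. A minimal sketch of the resulting access pattern (illustrative only, not part of the diff; `SACConfigSketch` is a hypothetical stand-in for the real `SACConfig`, using the defaults shown below):

from dataclasses import dataclass, field

@dataclass
class ActorLearnerConfig:
    learner_host: str = "127.0.0.1"
    learner_port: int = 50051
    policy_parameters_push_frequency: int = 4

@dataclass
class ConcurrencyConfig:
    actor: str = "threads"
    learner: str = "threads"

@dataclass
class SACConfigSketch:  # hypothetical stand-in for the real SACConfig
    actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
    concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

cfg = SACConfigSketch()
# Before this commit: cfg.actor_learner_config["learner_host"]
# After: attribute access, which IDEs and type checkers can verify.
assert cfg.actor_learner_config.learner_host == "127.0.0.1"
assert cfg.concurrency.actor == "threads"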

View File

@ -173,7 +173,7 @@ class ManiskillEnvConfig(EnvConfig):
control_mode: str = "pd_ee_delta_pose"
state_dim: int = 25
action_dim: int = 7
fps: int = 400
fps: int = 200
episode_length: int = 50
obs_type: str = "rgb"
render_mode: str = "rgb_array"
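
Halving `fps` from 400 to 200 doubles the actor's per-step time budget from 2.5 ms to 5 ms. A rough sketch of the pacing loop, based on the `busy_wait(1 / cfg.env.fps - dt_time)` pattern in the actor server hunk further down (the `busy_wait` below is a simplified stand-in for lerobot's helper, not the real implementation):

import time

def busy_wait(seconds: float) -> None:
    # simplified stand-in for lerobot's busy_wait helper
    end = time.perf_counter() + max(seconds, 0.0)
    while time.perf_counter() < end:
        pass

fps = 200  # new ManiskillEnvConfig default
start_time = time.perf_counter()
# ... policy inference and env.step() would happen here ...
dt_time = time.perf_counter() - start_time
busy_wait(1 / fps - dt_time)  # sleep out the remainder of the 5 ms step budget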

View File

@ -16,58 +16,100 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Optional
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
@dataclass
class ConcurrencyConfig:
actor: str = "threads"
learner: str = "threads"
@dataclass
class ActorLearnerConfig:
learner_host: str = "127.0.0.1"
learner_port: int = 50051
policy_parameters_push_frequency: int = 4
@dataclass
class CriticNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
final_activation: str | None = None
@dataclass
class ActorNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
@dataclass
class PolicyConfig:
use_tanh_squash: bool = True
log_std_min: int = -5
log_std_max: int = 2
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@dataclass
class SACConfig(PreTrainedConfig):
"""Configuration class for Soft Actor-Critic (SAC) policy.
"""Soft Actor-Critic (SAC) configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy.
normalization_mapping: Mapping from feature types to normalization modes.
dataset_stats: Statistics for normalizing different data types.
camera_number: Number of cameras to use.
device: Device to use for training.
storage_device: Device to use for storage.
vision_encoder_name: Name of the vision encoder to use.
freeze_vision_encoder: Whether to freeze the vision encoder.
image_encoder_hidden_dim: Hidden dimension for the image encoder.
shared_encoder: Whether to use a shared encoder.
online_steps: Total number of online training steps.
actor_network: Configuration for the actor network architecture.
critic_network: Configuration for the critic network architecture.
policy: Configuration for the policy parameters.
n_obs_steps: Number of observation steps to consider.
normalization_mapping: Mapping of feature types to normalization modes.
dataset_stats: Statistics for normalizing different types of inputs.
input_features: Dictionary of input features with their types and shapes.
output_features: Dictionary of output features with their types and shapes.
camera_number: Number of cameras used for visual observations.
device: Device to run the model on (e.g., "cuda", "cpu").
storage_device: Device to store the model on.
vision_encoder_name: Name of the vision encoder model.
freeze_vision_encoder: Whether to freeze the vision encoder during training.
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
shared_encoder: Whether to use a shared encoder for actor and critic.
concurrency: Configuration for concurrency settings.
actor_learner: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
online_step_before_learning: Number of steps to collect before starting learning.
offline_buffer_capacity: Capacity of the offline replay buffer.
online_step_before_learning: Number of steps before learning starts.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the RL algorithm.
temperature_init: Initial temperature for entropy regularization.
num_critics: Number of critic networks.
num_subsample_critics: Number of critics to subsample.
critic_lr: Learning rate for critic networks.
actor_lr: Learning rate for actor network.
temperature_lr: Learning rate for temperature parameter.
critic_target_update_weight: Weight for soft target updates.
utd_ratio: Update-to-data ratio (>1 to enable).
state_encoder_hidden_dim: Hidden dimension for state encoder.
latent_dim: Dimension of latent representation.
target_entropy: Target entropy for automatic temperature tuning.
use_backup_entropy: Whether to use backup entropy.
grad_clip_norm: Gradient clipping norm.
critic_network_kwargs: Additional arguments for critic networks.
actor_network_kwargs: Additional arguments for actor network.
policy_kwargs: Additional arguments for policy.
actor_learner_config: Configuration for actor-learner communication.
concurrency: Configuration for concurrency model.
discount: Discount factor for the SAC algorithm.
temperature_init: Initial temperature value.
num_critics: Number of critics in the ensemble.
num_subsample_critics: Number of subsampled critics for training.
critic_lr: Learning rate for the critic network.
actor_lr: Learning rate for the actor network.
temperature_lr: Learning rate for the temperature parameter.
critic_target_update_weight: Weight for the critic target update.
utd_ratio: Update-to-data ratio for the UTD algorithm.
state_encoder_hidden_dim: Hidden dimension size for the state encoder.
latent_dim: Dimension of the latent space.
target_entropy: Target entropy for the SAC algorithm.
use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
grad_clip_norm: Gradient clipping norm for the SAC algorithm.
"""
# Input / output structure
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
@ -76,6 +118,7 @@ class SACConfig(PreTrainedConfig):
"ACTION": NormalizationMode.MIN_MAX,
}
)
dataset_stats: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"observation.image": {
@ -93,6 +136,18 @@ class SACConfig(PreTrainedConfig):
}
)
input_features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
)
output_features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
}
)
# Architecture specifics
camera_number: int = 1
device: str = "cuda"
@ -106,7 +161,8 @@ class SACConfig(PreTrainedConfig):
# Training parameter
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 10000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
online_step_before_learning: int = 100
policy_update_freq: int = 1
@ -127,40 +183,21 @@ class SACConfig(PreTrainedConfig):
grad_clip_norm: float = 40.0
# Network configuration
critic_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
"final_activation": None,
}
critic_network_kwargs: CriticNetworkConfig = field(
default_factory=CriticNetworkConfig
)
actor_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
}
actor_network_kwargs: ActorNetworkConfig = field(
default_factory=ActorNetworkConfig
)
policy_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
"init_final": 0.05,
}
policy_kwargs: PolicyConfig = field(
default_factory=PolicyConfig
)
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
"policy_parameters_push_frequency": 4,
}
actor_learner_config: ActorLearnerConfig = field(
default_factory=ActorLearnerConfig
)
concurrency: dict[str, str] = field(
default_factory=lambda: {
"actor": "threads",
"learner": "threads"
}
concurrency: ConcurrencyConfig = field(
default_factory=ConcurrencyConfig
)
def __post_init__(self):
@ -181,9 +218,18 @@ class SACConfig(PreTrainedConfig):
return None
def validate_features(self) -> None:
# TODO: Maybe we should remove this raise?
if len(self.image_features) == 0:
raise ValueError("You must provide at least one image among the inputs.")
if "observation.image" not in self.input_features:
raise ValueError("You must provide 'observation.image' in the input features")
if "observation.state" not in self.input_features:
raise ValueError("You must provide 'observation.state' in the input features")
if "action" not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def image_features(self) -> list[str]:
return [key for key in self.input_features.keys() if 'image' in key]
@property
def observation_delta_indices(self) -> list:
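
For reference, a self-contained sketch of how the new typed features and the stricter `validate_features` behave (the `FeatureType`/`PolicyFeature` classes below are simplified stand-ins for `lerobot.configs.types`, not the real definitions):

from dataclasses import dataclass
from enum import Enum

class FeatureType(str, Enum):  # simplified stand-in
    VISUAL = "VISUAL"
    STATE = "STATE"
    ACTION = "ACTION"

@dataclass
class PolicyFeature:  # simplified stand-in
    type: FeatureType
    shape: tuple

input_features = {
    "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
output_features = {"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}

# validate_features() now checks for these exact keys instead of "at least one image"
for key, features in (("observation.image", input_features),
                      ("observation.state", input_features),
                      ("action", output_features)):
    if key not in features:
        raise ValueError(f"You must provide '{key}'")

# the image_features property is just a key filter
image_features = [key for key in input_features if "image" in key]
assert image_features == ["observation.image"]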

View File

@ -17,6 +17,7 @@
# TODO: (1) better device management
from dataclasses import asdict
import math
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
@ -88,7 +89,7 @@ class SACPolicy(
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
]
@ -103,7 +104,7 @@ class SACPolicy(
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
]
@ -121,10 +122,10 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
action_dim=config.output_features["action"].shape[0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
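
The `**asdict(...)` change above is needed because the `*_kwargs` fields are now dataclasses rather than dicts; `dataclasses.asdict` flattens them back into keyword arguments. A minimal sketch (the `critic_head` function is a hypothetical stand-in for `CriticHead`):

from dataclasses import asdict, dataclass, field

@dataclass
class CriticNetworkConfig:
    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
    activate_final: bool = True
    final_activation: str | None = None

def critic_head(input_dim, hidden_dims, activate_final, final_activation):
    # hypothetical stand-in for CriticHead; just returns the resolved kwargs
    return {"input_dim": input_dim, "hidden_dims": hidden_dims,
            "activate_final": activate_final, "final_activation": final_activation}

cfg = CriticNetworkConfig()
head = critic_head(input_dim=256 + 7, **asdict(cfg))  # dataclass -> kwargs dict
assert head["hidden_dims"] == [256, 256]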

View File

@ -106,6 +106,7 @@ class TrainPipelineConfig(HubMixin):
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if self.dataset is not None:
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")

View File

@ -73,8 +73,8 @@ def receive_policy(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
@ -85,6 +85,7 @@ def receive_policy(
shutdown_event,
log_prefix="[ACTOR] parameters",
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
@ -153,8 +154,8 @@ def send_transitions(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
@ -193,8 +194,8 @@ def send_interactions(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
@ -286,10 +287,10 @@ def act_with_policy(
logging.info("make_env online")
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
online_env = make_robot_env( cfg=cfg.env)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@ -302,11 +303,7 @@ def act_with_policy(
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
cfg=cfg.policy,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
# TODO: Handle resume training
device=device,
env_cfg=cfg.env,
)
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
@ -322,13 +319,13 @@ def act_with_policy(
episode_intervention_steps = 0
episode_total_steps = 0
for interaction_step in range(cfg.training.online_steps):
for interaction_step in range(cfg.policy.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down act_with_policy")
return
if interaction_step >= cfg.training.online_step_before_learning:
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with TimerManager(
elapsed_time_list=list_policy_time,
@ -426,9 +423,9 @@ def act_with_policy(
episode_total_steps = 0
obs, info = online_env.reset()
if cfg.fps is not None:
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
busy_wait(1 / cfg.fps - dt_time)
busy_wait(1 / cfg.env.fps - dt_time)
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
@ -467,9 +464,9 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
if policy_fps < cfg.fps:
if policy_fps < cfg.env.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
)
@ -495,7 +492,7 @@ def establish_learner_connection(
def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.policy.concurrency["actor"] == "threads"
return cfg.policy.concurrency.actor == "threads"
@parser.wrap()
@ -511,8 +508,8 @@ def actor_cli(cfg: TrainPipelineConfig):
shutdown_event = setup_process_handlers(use_threads(cfg))
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
logging.info("[ACTOR] Establishing connection with Learner")

View File

@ -1097,7 +1097,6 @@ class ActionScaleWrapper(gym.ActionWrapper):
return action * self.scale_vector, is_intervention
@parser.wrap()
def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:

View File

@ -48,6 +48,7 @@ from lerobot.common.utils.train_utils import (
load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
save_training_state,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
@ -160,11 +161,12 @@ def load_training_state(
try:
# Use the utility function from train_utils which loads the optimizer state
# The function returns (step, updated_optimizer, scheduler)
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# For interaction step, we still need to load the training_state.pt file
# Load interaction step separately from training_state.pt
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
interaction_step = 0
if os.path.exists(training_state_path):
training_state = torch.load(training_state_path, weights_only=False)
interaction_step = training_state.get("interaction_step", 0)
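
The interaction step is not covered by `utils_load_training_state`, so the learner now saves and reloads it through its own `training_state.pt` file (the matching save appears in the checkpoint hunk further down). A self-contained sketch of that round trip, assuming a plain dict payload as in the diff (directory names here are hypothetical):

import os
import torch

TRAINING_STATE_DIR = "training_state"  # assumed value of the train_utils constant
checkpoint_dir = "outputs/train/example_checkpoint"  # hypothetical path

# Save side (mirrors the checkpoint hunk below)
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
os.makedirs(training_state_dir, exist_ok=True)
torch.save({"step": 1000, "interaction_step": 5000},
           os.path.join(training_state_dir, "training_state.pt"))

# Load side (mirrors the hunk above)
training_state_path = os.path.join(training_state_dir, "training_state.pt")
interaction_step = 0
if os.path.exists(training_state_path):
    training_state = torch.load(training_state_path, weights_only=False)
    interaction_step = training_state.get("interaction_step", 0)
assert interaction_step == 5000
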
@ -222,16 +224,20 @@ def initialize_replay_buffer(
logging.info("Resume training load the online dataset")
dataset_path = os.path.join(cfg.output_dir, "dataset")
# NOTE: In RL, it is possible to not have a dataset.
repo_id = None
if cfg.dataset is not None:
repo_id = cfg.dataset.dataset_repo_id
dataset = LeRobotDataset(
repo_id=cfg.dataset.dataset_repo_id,
local_files_only=True,
repo_id=repo_id,
root=dataset_path,
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.policy.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
state_keys=cfg.policy.input_features.keys(),
optimize_memory=True,
)
@ -298,7 +304,7 @@ def get_observation_features(
def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.policy.concurrency["learner"] == "threads"
return cfg.policy.concurrency.learner == "threads"
def start_learner_threads(
@ -388,7 +394,7 @@ def start_learner_server(
service = learner_service.LearnerService(
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency,
transition_queue=transition_queue,
interaction_message_queue=interaction_message_queue,
)
@ -406,8 +412,8 @@ def start_learner_server(
server,
)
host = cfg.policy.actor_learner_config["learner_host"]
port = cfg.policy.actor_learner_config["learner_port"]
host = cfg.policy.actor_learner_config.learner_host
port = cfg.policy.actor_learner_config.learner_port
server.add_insecure_port(f"{host}:{port}")
server.start()
@ -509,7 +515,6 @@ def add_actor_information_and_train(
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
# TODO(Adil): This don't work anymore !
policy: SACPolicy = make_policy(
cfg=cfg.policy,
# ds_meta=cfg.dataset,
@ -575,8 +580,8 @@ def add_actor_information_and_train(
device = cfg.policy.device
storage_device = cfg.policy.storage_device
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
save_checkpoint = cfg.save_checkpoint
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
while True:
@ -598,7 +603,7 @@ def add_actor_information_and_train(
continue
replay_buffer.add(**transition)
if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
"is_intervention"
):
offline_replay_buffer.add(**transition)
@ -618,9 +623,6 @@ def add_actor_information_and_train(
mode="train",
custom_step_key="Interaction step"
)
else:
# Log to console if no WandB logger
logging.info(f"Interaction: {interaction_message}")
logging.debug("[LEARNER] Received interactions")
@ -765,9 +767,6 @@ def add_actor_information_and_train(
mode="train",
custom_step_key="Optimization step"
)
else:
# Log to console if no WandB logger
logging.info(f"Training: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
@ -789,7 +788,7 @@ def add_actor_information_and_train(
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
@ -810,6 +809,15 @@ def add_actor_information_and_train(
scheduler=None
)
# Save interaction step manually
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
os.makedirs(training_state_dir, exist_ok=True)
training_state = {
"step": optimization_step,
"interaction_step": interaction_step
}
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)
@ -820,8 +828,11 @@ def add_actor_information_and_train(
shutil.rmtree(dataset_dir)
# Save dataset
# NOTE: Handle the case where the dataset repo id is not specified in the config
# e.g. RL training without demonstration data
repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
replay_buffer.to_lerobot_dataset(
dataset_repo_id,
repo_id=repo_id_buffer_save,
fps=fps,
root=dataset_dir
)
@ -892,8 +903,10 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
cfg (TrainPipelineConfig): The training configuration
job_name (str | None, optional): Job name for logging. Defaults to None.
"""
if cfg.output_dir is None:
raise ValueError("Output directory must be specified in config")
cfg.validate()
# if cfg.output_dir is None:
# raise ValueError("Output directory must be specified in config")
if job_name is None:
job_name = cfg.job_name
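
The explicit `output_dir` check in `train()` can be dropped because `cfg.validate()` derives a default output directory when none is set, per the `TrainPipelineConfig` hunk earlier in this commit. Roughly (job name below is hypothetical):

from datetime import datetime
from pathlib import Path

job_name = "sac_maniskill"  # hypothetical
output_dir = None
if output_dir is None:
    now = datetime.now()
    train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{job_name}"
    output_dir = Path("outputs/train") / train_dir
print(output_dir)  # e.g. outputs/train/2025-03-27/08-13-20_sac_maniskill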