[WIP] Update SAC configuration and environment settings

- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200.
- Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations.
- Improved input and output feature management in `SACConfig`.
- Refactored `actor_server` and `learner_server` to access configuration values as dataclass attributes instead of dict keys (see the access-pattern sketch below the commit metadata).
- Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
AdilZouitine 2025-03-27 08:13:20 +00:00
parent 626e5dd35c
commit 052a4acfc2
7 changed files with 183 additions and 126 deletions
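
The core refactor replaces dict-valued fields in `SACConfig` with typed dataclasses, so downstream code reads attributes instead of string keys. A minimal, self-contained sketch of the before/after access pattern, reusing the `ActorLearnerConfig` defaults from the diff below (the dict literal is only an illustration of the old style):

    from dataclasses import dataclass

    @dataclass
    class ActorLearnerConfig:
        learner_host: str = "127.0.0.1"
        learner_port: int = 50051
        policy_parameters_push_frequency: int = 4

    # Before: untyped dict access, unchecked keys
    old_cfg = {"learner_host": "127.0.0.1", "learner_port": 50051}
    host, port = old_cfg["learner_host"], old_cfg["learner_port"]

    # After: attribute access on a typed dataclass with defaults
    new_cfg = ActorLearnerConfig()
    host, port = new_cfg.learner_host, new_cfg.learner_port
    assert (host, port) == ("127.0.0.1", 50051)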

File 1 of 7 — ManiskillEnvConfig (environment settings)

@@ -173,7 +173,7 @@ class ManiskillEnvConfig(EnvConfig):
     control_mode: str = "pd_ee_delta_pose"
     state_dim: int = 25
     action_dim: int = 7
-    fps: int = 400
+    fps: int = 200
     episode_length: int = 50
     obs_type: str = "rgb"
     render_mode: str = "rgb_array"
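
For context on the fps change (a hedged aside, not part of the diff): the actor loop shown later paces itself with `busy_wait(1 / cfg.env.fps - dt_time)`, so halving fps from 400 to 200 doubles the per-step budget from 2.5 ms to 5 ms. A rough sketch of that pacing logic, with a hypothetical `busy_wait` stand-in:

    import time

    def busy_wait(seconds: float) -> None:
        # Hypothetical stand-in for the helper used in the actor loop.
        end = time.perf_counter() + max(seconds, 0.0)
        while time.perf_counter() < end:
            pass

    fps = 200  # was 400 before this commit
    start_time = time.perf_counter()
    # ... one control step (policy inference + env.step()) would run here ...
    dt_time = time.perf_counter() - start_time
    busy_wait(1 / fps - dt_time)  # burn off the remainder of the 5 ms budget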

File 2 of 7 — SACConfig (SAC policy configuration)

@@ -16,58 +16,100 @@
 # limitations under the License.
 from dataclasses import dataclass, field
-from typing import Any
+from typing import Any, Optional
 from lerobot.common.optim.optimizers import MultiAdamConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
+from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
+
+
+@dataclass
+class ConcurrencyConfig:
+    actor: str = "threads"
+    learner: str = "threads"
+
+
+@dataclass
+class ActorLearnerConfig:
+    learner_host: str = "127.0.0.1"
+    learner_port: int = 50051
+    policy_parameters_push_frequency: int = 4
+
+
+@dataclass
+class CriticNetworkConfig:
+    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
+    activate_final: bool = True
+    final_activation: str | None = None
+
+
+@dataclass
+class ActorNetworkConfig:
+    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
+    activate_final: bool = True
+
+
+@dataclass
+class PolicyConfig:
+    use_tanh_squash: bool = True
+    log_std_min: int = -5
+    log_std_max: int = 2
+    init_final: float = 0.05
+
+
 @PreTrainedConfig.register_subclass("sac")
 @dataclass
 class SACConfig(PreTrainedConfig):
-    """Configuration class for Soft Actor-Critic (SAC) policy.
-
-    SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
-    reinforcement learning framework. It learns a policy and a Q-function simultaneously
-    using experience collected from the environment.
-
-    This configuration class contains all the parameters needed to define a SAC agent,
-    including network architectures, optimization settings, and algorithm-specific
-    hyperparameters.
+    """Soft Actor-Critic (SAC) configuration.
 
     Args:
-        n_obs_steps: Number of environment steps worth of observations to pass to the policy.
-        normalization_mapping: Mapping from feature types to normalization modes.
-        dataset_stats: Statistics for normalizing different data types.
-        camera_number: Number of cameras to use.
-        device: Device to use for training.
-        storage_device: Device to use for storage.
-        vision_encoder_name: Name of the vision encoder to use.
-        freeze_vision_encoder: Whether to freeze the vision encoder.
-        image_encoder_hidden_dim: Hidden dimension for the image encoder.
-        shared_encoder: Whether to use a shared encoder.
-        online_steps: Total number of online training steps.
+        actor_network: Configuration for the actor network architecture.
+        critic_network: Configuration for the critic network architecture.
+        policy: Configuration for the policy parameters.
+        n_obs_steps: Number of observation steps to consider.
+        normalization_mapping: Mapping of feature types to normalization modes.
+        dataset_stats: Statistics for normalizing different types of inputs.
+        input_features: Dictionary of input features with their types and shapes.
+        output_features: Dictionary of output features with their types and shapes.
+        camera_number: Number of cameras used for visual observations.
+        device: Device to run the model on (e.g., "cuda", "cpu").
+        storage_device: Device to store the model on.
+        vision_encoder_name: Name of the vision encoder model.
+        freeze_vision_encoder: Whether to freeze the vision encoder during training.
+        image_encoder_hidden_dim: Hidden dimension size for the image encoder.
+        shared_encoder: Whether to use a shared encoder for actor and critic.
+        concurrency: Configuration for concurrency settings.
+        actor_learner: Configuration for actor-learner architecture.
+        online_steps: Number of steps for online training.
         online_env_seed: Seed for the online environment.
         online_buffer_capacity: Capacity of the online replay buffer.
-        online_step_before_learning: Number of steps to collect before starting learning.
+        offline_buffer_capacity: Capacity of the offline replay buffer.
+        online_step_before_learning: Number of steps before learning starts.
         policy_update_freq: Frequency of policy updates.
-        discount: Discount factor for the RL algorithm.
-        temperature_init: Initial temperature for entropy regularization.
-        num_critics: Number of critic networks.
-        num_subsample_critics: Number of critics to subsample.
-        critic_lr: Learning rate for critic networks.
-        actor_lr: Learning rate for actor network.
-        temperature_lr: Learning rate for temperature parameter.
-        critic_target_update_weight: Weight for soft target updates.
-        utd_ratio: Update-to-data ratio (>1 to enable).
-        state_encoder_hidden_dim: Hidden dimension for state encoder.
-        latent_dim: Dimension of latent representation.
-        target_entropy: Target entropy for automatic temperature tuning.
-        use_backup_entropy: Whether to use backup entropy.
-        grad_clip_norm: Gradient clipping norm.
+        discount: Discount factor for the SAC algorithm.
+        temperature_init: Initial temperature value.
+        num_critics: Number of critics in the ensemble.
+        num_subsample_critics: Number of subsampled critics for training.
+        critic_lr: Learning rate for the critic network.
+        actor_lr: Learning rate for the actor network.
+        temperature_lr: Learning rate for the temperature parameter.
+        critic_target_update_weight: Weight for the critic target update.
+        utd_ratio: Update-to-data ratio for the UTD algorithm.
+        state_encoder_hidden_dim: Hidden dimension size for the state encoder.
+        latent_dim: Dimension of the latent space.
+        target_entropy: Target entropy for the SAC algorithm.
+        use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
+        grad_clip_norm: Gradient clipping norm for the SAC algorithm.
+        critic_network_kwargs: Additional arguments for critic networks.
+        actor_network_kwargs: Additional arguments for actor network.
+        policy_kwargs: Additional arguments for policy.
+        actor_learner_config: Configuration for actor-learner communication.
+        concurrency: Configuration for concurrency model.
     """

+    # Input / output structure
+    n_obs_steps: int = 1
     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
             "VISUAL": NormalizationMode.MEAN_STD,
@@ -76,6 +118,7 @@ class SACConfig(PreTrainedConfig):
             "ACTION": NormalizationMode.MIN_MAX,
         }
     )
+
     dataset_stats: dict[str, dict[str, list[float]]] = field(
         default_factory=lambda: {
             "observation.image": {
@@ -93,6 +136,18 @@ class SACConfig(PreTrainedConfig):
         }
     )
+    input_features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
+            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
+        }
+    )
+    output_features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
+        }
+    )
+
     # Architecture specifics
     camera_number: int = 1
     device: str = "cuda"
@@ -106,7 +161,8 @@ class SACConfig(PreTrainedConfig):
     # Training parameter
     online_steps: int = 1000000
     online_env_seed: int = 10000
-    online_buffer_capacity: int = 10000
+    online_buffer_capacity: int = 100000
+    offline_buffer_capacity: int = 100000
     online_step_before_learning: int = 100
     policy_update_freq: int = 1
@@ -127,40 +183,21 @@ class SACConfig(PreTrainedConfig):
     grad_clip_norm: float = 40.0

     # Network configuration
-    critic_network_kwargs: dict[str, Any] = field(
-        default_factory=lambda: {
-            "hidden_dims": [256, 256],
-            "activate_final": True,
-            "final_activation": None,
-        }
+    critic_network_kwargs: CriticNetworkConfig = field(
+        default_factory=CriticNetworkConfig
     )
-    actor_network_kwargs: dict[str, Any] = field(
-        default_factory=lambda: {
-            "hidden_dims": [256, 256],
-            "activate_final": True,
-        }
+    actor_network_kwargs: ActorNetworkConfig = field(
+        default_factory=ActorNetworkConfig
     )
-    policy_kwargs: dict[str, Any] = field(
-        default_factory=lambda: {
-            "use_tanh_squash": True,
-            "log_std_min": -5,
-            "log_std_max": 2,
-            "init_final": 0.05,
-        }
+    policy_kwargs: PolicyConfig = field(
+        default_factory=PolicyConfig
     )
-    actor_learner_config: dict[str, str | int] = field(
-        default_factory=lambda: {
-            "learner_host": "127.0.0.1",
-            "learner_port": 50051,
-            "policy_parameters_push_frequency": 4,
-        }
+    actor_learner_config: ActorLearnerConfig = field(
+        default_factory=ActorLearnerConfig
     )
-    concurrency: dict[str, str] = field(
-        default_factory=lambda: {
-            "actor": "threads",
-            "learner": "threads"
-        }
+    concurrency: ConcurrencyConfig = field(
+        default_factory=ConcurrencyConfig
     )

     def __post_init__(self):
@@ -181,9 +218,18 @@ class SACConfig(PreTrainedConfig):
         return None

     def validate_features(self) -> None:
-        # TODO: Maybe we should remove this raise?
-        if len(self.image_features) == 0:
-            raise ValueError("You must provide at least one image among the inputs.")
+        if "observation.image" not in self.input_features:
+            raise ValueError("You must provide 'observation.image' in the input features")
+        if "observation.state" not in self.input_features:
+            raise ValueError("You must provide 'observation.state' in the input features")
+        if "action" not in self.output_features:
+            raise ValueError("You must provide 'action' in the output features")
+
+    @property
+    def image_features(self) -> list[str]:
+        return [key for key in self.input_features.keys() if 'image' in key]

     @property
     def observation_delta_indices(self) -> list:
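
A standalone sketch of the new feature-validation logic and the `image_features` property introduced above, using a simplified `PolicyFeature` stand-in rather than the real lerobot type:

    from dataclasses import dataclass

    @dataclass
    class PolicyFeature:  # simplified stand-in for lerobot.configs.types.PolicyFeature
        type: str
        shape: tuple

    input_features = {
        "observation.image": PolicyFeature(type="VISUAL", shape=(3, 64, 64)),
        "observation.state": PolicyFeature(type="STATE", shape=(2,)),
    }
    output_features = {"action": PolicyFeature(type="ACTION", shape=(3,))}

    # Mirrors SACConfig.validate_features(): the required keys must be present.
    for key in ("observation.image", "observation.state"):
        if key not in input_features:
            raise ValueError(f"You must provide '{key}' in the input features")
    if "action" not in output_features:
        raise ValueError("You must provide 'action' in the output features")

    # Mirrors the image_features property: every input key containing 'image'.
    image_features = [key for key in input_features if "image" in key]
    assert image_features == ["observation.image"]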

File 3 of 7 — SACPolicy (policy and critic construction)

@@ -17,6 +17,7 @@
 # TODO: (1) better device management
+from dataclasses import asdict
 import math
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -88,7 +89,7 @@ class SACPolicy(
         critic_heads = [
             CriticHead(
                 input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
-                **config.critic_network_kwargs,
+                **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
         ]
@@ -103,7 +104,7 @@ class SACPolicy(
         target_critic_heads = [
             CriticHead(
                 input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
-                **config.critic_network_kwargs,
+                **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
         ]
@@ -121,10 +122,10 @@ class SACPolicy(
         self.actor = Policy(
             encoder=encoder_actor,
-            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
+            network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
             action_dim=config.output_features["action"].shape[0],
             encoder_is_shared=config.shared_encoder,
-            **config.policy_kwargs,
+            **asdict(config.policy_kwargs),
         )

         if config.target_entropy is None:
             config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2  # (-dim(A)/2)
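
The `asdict` calls above convert the new nested dataclasses back into plain keyword dictionaries at the point of use, so the network constructors keep their existing `**kwargs` signatures. A minimal sketch of that pattern; `CriticNetworkConfig` has the fields from the diff, while `critic_head` is a hypothetical consumer standing in for the real `CriticHead`:

    from dataclasses import asdict, dataclass, field

    @dataclass
    class CriticNetworkConfig:
        hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
        activate_final: bool = True
        final_activation: str | None = None

    def critic_head(input_dim: int, hidden_dims: list[int], activate_final: bool, final_activation=None):
        # Hypothetical consumer with the same keyword interface as before the refactor.
        return {"input_dim": input_dim, "hidden_dims": hidden_dims,
                "activate_final": activate_final, "final_activation": final_activation}

    cfg = CriticNetworkConfig()
    head = critic_head(input_dim=32, **asdict(cfg))  # dataclass -> dict -> keyword arguments
    assert head["hidden_dims"] == [256, 256]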

File 4 of 7 — TrainPipelineConfig (training pipeline validation)

@@ -106,8 +106,9 @@ class TrainPipelineConfig(HubMixin):
             train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
             self.output_dir = Path("outputs/train") / train_dir

-        if isinstance(self.dataset.repo_id, list):
-            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
+        if self.dataset is not None:
+            if isinstance(self.dataset.repo_id, list):
+                raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")

         if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
             raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")

File 5 of 7 — actor server

@@ -73,8 +73,8 @@ def receive_policy(
     if grpc_channel is None or learner_client is None:
         learner_client, grpc_channel = learner_service_client(
-            host=cfg.policy.actor_learner_config["learner_host"],
-            port=cfg.policy.actor_learner_config["learner_port"],
+            host=cfg.policy.actor_learner_config.learner_host,
+            port=cfg.policy.actor_learner_config.learner_port,
         )

     try:
@@ -85,6 +85,7 @@ def receive_policy(
             shutdown_event,
             log_prefix="[ACTOR] parameters",
         )
+
     except grpc.RpcError as e:
         logging.error(f"[ACTOR] gRPC error: {e}")
@@ -153,8 +154,8 @@ def send_transitions(
     if grpc_channel is None or learner_client is None:
         learner_client, grpc_channel = learner_service_client(
-            host=cfg.policy.actor_learner_config["learner_host"],
-            port=cfg.policy.actor_learner_config["learner_port"],
+            host=cfg.policy.actor_learner_config.learner_host,
+            port=cfg.policy.actor_learner_config.learner_port,
         )

     try:
@@ -193,8 +194,8 @@ def send_interactions(
     if grpc_channel is None or learner_client is None:
         learner_client, grpc_channel = learner_service_client(
-            host=cfg.policy.actor_learner_config["learner_host"],
-            port=cfg.policy.actor_learner_config["learner_port"],
+            host=cfg.policy.actor_learner_config.learner_host,
+            port=cfg.policy.actor_learner_config.learner_port,
        )

     try:
@@ -286,10 +287,10 @@ def act_with_policy(
     logging.info("make_env online")

-    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
+    online_env = make_robot_env(cfg=cfg.env)

     set_seed(cfg.seed)
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.policy.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -302,11 +303,7 @@ def act_with_policy(
     # TODO: At some point we should just need make sac policy
     policy: SACPolicy = make_policy(
         cfg=cfg.policy,
-        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
-        # Hack: But if we do online training, we do not need dataset_stats
-        dataset_stats=None,
-        # TODO: Handle resume training
-        device=device,
+        env_cfg=cfg.env,
     )
     policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
@@ -322,13 +319,13 @@ def act_with_policy(
     episode_intervention_steps = 0
     episode_total_steps = 0

-    for interaction_step in range(cfg.training.online_steps):
+    for interaction_step in range(cfg.policy.online_steps):
         start_time = time.perf_counter()
         if shutdown_event.is_set():
             logging.info("[ACTOR] Shutting down act_with_policy")
             return

-        if interaction_step >= cfg.training.online_step_before_learning:
+        if interaction_step >= cfg.policy.online_step_before_learning:
             # Time policy inference and check if it meets FPS requirement
             with TimerManager(
                 elapsed_time_list=list_policy_time,
@@ -426,9 +423,9 @@ def act_with_policy(
             episode_total_steps = 0
             obs, info = online_env.reset()

-        if cfg.fps is not None:
+        if cfg.env.fps is not None:
             dt_time = time.perf_counter() - start_time
-            busy_wait(1 / cfg.fps - dt_time)
+            busy_wait(1 / cfg.env.fps - dt_time)


 def push_transitions_to_transport_queue(transitions: list, transitions_queue):
@@ -467,9 +464,9 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
 def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
-    if policy_fps < cfg.fps:
+    if policy_fps < cfg.env.fps:
         logging.warning(
-            f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
+            f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
         )
@@ -495,7 +492,7 @@ def establish_learner_connection(
 def use_threads(cfg: TrainPipelineConfig) -> bool:
-    return cfg.policy.concurrency["actor"] == "threads"
+    return cfg.policy.concurrency.actor == "threads"


 @parser.wrap()
@@ -511,8 +508,8 @@ def actor_cli(cfg: TrainPipelineConfig):
     shutdown_event = setup_process_handlers(use_threads(cfg))

     learner_client, grpc_channel = learner_service_client(
-        host=cfg.policy.actor_learner_config["learner_host"],
-        port=cfg.policy.actor_learner_config["learner_port"],
+        host=cfg.policy.actor_learner_config.learner_host,
+        port=cfg.policy.actor_learner_config.learner_port,
     )

     logging.info("[ACTOR] Establishing connection with Learner")

File 6 of 7 — make_robot_env (robot environment factory)

@@ -1097,7 +1097,6 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention


-@parser.wrap()
 def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
     # def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
     # def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:

File 7 of 7 — learner server

@@ -48,6 +48,7 @@ from lerobot.common.utils.train_utils import (
     load_training_state as utils_load_training_state,
     save_checkpoint,
     update_last_checkpoint,
+    save_training_state,
 )
 from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.utils import (
@@ -160,13 +161,14 @@ def load_training_state(
     try:
         # Use the utility function from train_utils which loads the optimizer state
+        # The function returns (step, updated_optimizer, scheduler)
         step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)

-        # For interaction step, we still need to load the training_state.pt file
+        # Load interaction step separately from training_state.pt
         training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
-        training_state = torch.load(training_state_path, weights_only=False)
-        interaction_step = training_state.get("interaction_step", 0)
+        interaction_step = 0
+        if os.path.exists(training_state_path):
+            training_state = torch.load(training_state_path, weights_only=False)
+            interaction_step = training_state.get("interaction_step", 0)

         logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
         return step, interaction_step
@@ -222,16 +224,20 @@ def initialize_replay_buffer(
     logging.info("Resume training load the online dataset")
     dataset_path = os.path.join(cfg.output_dir, "dataset")
+
+    # NOTE: In RL is possible to not have a dataset.
+    repo_id = None
+    if cfg.dataset is not None:
+        repo_id = cfg.dataset.dataset_repo_id
     dataset = LeRobotDataset(
-        repo_id=cfg.dataset.dataset_repo_id,
-        local_files_only=True,
+        repo_id=repo_id,
         root=dataset_path,
     )
+
     return ReplayBuffer.from_lerobot_dataset(
         lerobot_dataset=dataset,
         capacity=cfg.policy.online_buffer_capacity,
         device=device,
-        state_keys=cfg.policy.input_shapes.keys(),
+        state_keys=cfg.policy.input_features.keys(),
         optimize_memory=True,
     )
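
The replay-buffer initialization now tolerates a missing dataset config by resolving the repo id to `None`, and a later hunk in this file falls back to the env task name when saving the buffer. A compact sketch of that resolution logic, with hypothetical stand-in objects and an assumed task name:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DatasetConfig:  # hypothetical stand-in
        dataset_repo_id: str

    def resolve_repo_id(dataset: Optional[DatasetConfig], env_task: str) -> tuple[Optional[str], str]:
        # In RL it is possible to not have a demonstration dataset at all.
        repo_id = dataset.dataset_repo_id if dataset is not None else None
        # Fallback used when saving the buffer as a LeRobotDataset.
        repo_id_buffer_save = env_task if repo_id is None else repo_id
        return repo_id, repo_id_buffer_save

    assert resolve_repo_id(None, "PushCube-v1") == (None, "PushCube-v1")
    assert resolve_repo_id(DatasetConfig("user/demos"), "PushCube-v1") == ("user/demos", "user/demos")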
@@ -298,7 +304,7 @@ def get_observation_features(
 def use_threads(cfg: TrainPipelineConfig) -> bool:
-    return cfg.policy.concurrency["learner"] == "threads"
+    return cfg.policy.concurrency.learner == "threads"


 def start_learner_threads(
@@ -388,7 +394,7 @@ def start_learner_server(
     service = learner_service.LearnerService(
         shutdown_event=shutdown_event,
         parameters_queue=parameters_queue,
-        seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
+        seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency,
         transition_queue=transition_queue,
         interaction_message_queue=interaction_message_queue,
     )
@@ -406,8 +412,8 @@ def start_learner_server(
         server,
     )

-    host = cfg.policy.actor_learner_config["learner_host"]
-    port = cfg.policy.actor_learner_config["learner_port"]
+    host = cfg.policy.actor_learner_config.learner_host
+    port = cfg.policy.actor_learner_config.learner_port

     server.add_insecure_port(f"{host}:{port}")
     server.start()
@@ -509,7 +515,6 @@ def add_actor_information_and_train(
     checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
     pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None

-    # TODO(Adil): This don't work anymore !
     policy: SACPolicy = make_policy(
         cfg=cfg.policy,
         # ds_meta=cfg.dataset,
@@ -575,8 +580,8 @@ def add_actor_information_and_train(
     device = cfg.policy.device
     storage_device = cfg.policy.storage_device
     policy_update_freq = cfg.policy.policy_update_freq
-    policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
-    save_checkpoint = cfg.save_checkpoint
+    policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
+    saving_checkpoint = cfg.save_checkpoint
     online_steps = cfg.policy.online_steps

     while True:
@@ -598,7 +603,7 @@ def add_actor_information_and_train(
                 continue
             replay_buffer.add(**transition)

-            if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
+            if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
                 "is_intervention"
             ):
                 offline_replay_buffer.add(**transition)
@@ -618,9 +623,6 @@ def add_actor_information_and_train(
                     mode="train",
                     custom_step_key="Interaction step"
                 )
-            else:
-                # Log to console if no WandB logger
-                logging.info(f"Interaction: {interaction_message}")

             logging.debug("[LEARNER] Received interactions")
@@ -765,9 +767,6 @@ def add_actor_information_and_train(
                 mode="train",
                 custom_step_key="Optimization step"
             )
-        else:
-            # Log to console if no WandB logger
-            logging.info(f"Training: {training_infos}")

         time_for_one_optimization_step = time.time() - time_for_one_optimization_step
         frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
@@ -789,7 +788,7 @@ def add_actor_information_and_train(
         if optimization_step % log_freq == 0:
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")

-        if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
+        if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
             logging.info(f"Checkpoint policy after step {optimization_step}")
             _num_digits = max(6, len(str(online_steps)))
             step_identifier = f"{optimization_step:0{_num_digits}d}"
@@ -810,6 +809,15 @@ def add_actor_information_and_train(
                 scheduler=None
             )

+            # Save interaction step manually
+            training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
+            os.makedirs(training_state_dir, exist_ok=True)
+            training_state = {
+                "step": optimization_step,
+                "interaction_step": interaction_step
+            }
+            torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
+
             # Update the "last" symlink
             update_last_checkpoint(checkpoint_dir)
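
Two hunks in this file form a pair: the learner writes a `training_state.pt` with the optimization and interaction steps next to the checkpoint, and `load_training_state` (earlier in the file) reads the interaction step back, defaulting to 0 if the file is absent. A round-trip sketch of that contract, assuming only `torch`, a temporary directory, and that the `TRAINING_STATE_DIR` constant resolves to a plain directory name:

    import os
    import tempfile
    import torch

    TRAINING_STATE_DIR = "training_state"  # assumed value of the constant used in the diff

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        # Save side (learner): persist both counters next to the checkpoint.
        training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
        os.makedirs(training_state_dir, exist_ok=True)
        torch.save({"step": 1000, "interaction_step": 25_000},
                   os.path.join(training_state_dir, "training_state.pt"))

        # Load side (resume): fall back to 0 when the file does not exist.
        path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
        interaction_step = 0
        if os.path.exists(path):
            state = torch.load(path, weights_only=False)
            interaction_step = state.get("interaction_step", 0)
        assert interaction_step == 25_000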
@@ -820,8 +828,11 @@ def add_actor_information_and_train(
                 shutil.rmtree(dataset_dir)

             # Save dataset
+            # NOTE: Handle the case where the dataset repo id is not specified in the config
+            # eg. RL training without demonstrations data
+            repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
             replay_buffer.to_lerobot_dataset(
-                dataset_repo_id,
+                repo_id=repo_id_buffer_save,
                 fps=fps,
                 root=dataset_dir
             )
@@ -892,8 +903,10 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
         cfg (TrainPipelineConfig): The training configuration
         job_name (str | None, optional): Job name for logging. Defaults to None.
     """
-    if cfg.output_dir is None:
-        raise ValueError("Output directory must be specified in config")
+    cfg.validate()
+    # if cfg.output_dir is None:
+    #     raise ValueError("Output directory must be specified in config")

     if job_name is None:
         job_name = cfg.job_name