[WIP] Update SAC configuration and environment settings
- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200.
- Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations.
- Improved input and output feature management in `SACConfig`.
- Refactored `actor_server` and `learner_server` to access configuration properties directly.
- Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
parent 626e5dd35c
commit 052a4acfc2
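For context, a minimal sketch (not part of this commit) of the nested-config style the change introduces. The dataclasses below are stand-in copies of the ones added in the diff; the real classes live in the repo's SAC configuration module.

```python
# Sketch only: stand-ins mirroring the dataclasses added in this commit.
from dataclasses import dataclass, field


@dataclass
class ActorLearnerConfig:
    learner_host: str = "127.0.0.1"
    learner_port: int = 50051
    policy_parameters_push_frequency: int = 4


@dataclass
class ConcurrencyConfig:
    actor: str = "threads"
    learner: str = "threads"


@dataclass
class SACConfigSketch:  # hypothetical stand-in for the real SACConfig
    actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
    concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)


cfg = SACConfigSketch()
# Before: cfg.actor_learner_config["learner_host"]; after: plain attribute access.
assert cfg.actor_learner_config.learner_host == "127.0.0.1"
assert cfg.concurrency.actor == "threads"
```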
@@ -173,7 +173,7 @@ class ManiskillEnvConfig(EnvConfig):
     control_mode: str = "pd_ee_delta_pose"
     state_dim: int = 25
     action_dim: int = 7
-    fps: int = 400
+    fps: int = 200
     episode_length: int = 50
     obs_type: str = "rgb"
     render_mode: str = "rgb_array"
@@ -16,58 +16,100 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import Any
+from typing import Any, Optional
 
 from lerobot.common.optim.optimizers import MultiAdamConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
+from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
 
 
+@dataclass
+class ConcurrencyConfig:
+    actor: str = "threads"
+    learner: str = "threads"
+
+
+@dataclass
+class ActorLearnerConfig:
+    learner_host: str = "127.0.0.1"
+    learner_port: int = 50051
+    policy_parameters_push_frequency: int = 4
+
+
+@dataclass
+class CriticNetworkConfig:
+    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
+    activate_final: bool = True
+    final_activation: str | None = None
+
+
+@dataclass
+class ActorNetworkConfig:
+    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
+    activate_final: bool = True
+
+
+@dataclass
+class PolicyConfig:
+    use_tanh_squash: bool = True
+    log_std_min: int = -5
+    log_std_max: int = 2
+    init_final: float = 0.05
+
+
 @PreTrainedConfig.register_subclass("sac")
 @dataclass
 class SACConfig(PreTrainedConfig):
-    """Configuration class for Soft Actor-Critic (SAC) policy.
+    """Soft Actor-Critic (SAC) configuration.
+
+    SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
+    reinforcement learning framework. It learns a policy and a Q-function simultaneously
+    using experience collected from the environment.
+
+    This configuration class contains all the parameters needed to define a SAC agent,
+    including network architectures, optimization settings, and algorithm-specific
+    hyperparameters.
 
     Args:
-        n_obs_steps: Number of environment steps worth of observations to pass to the policy.
-        normalization_mapping: Mapping from feature types to normalization modes.
-        dataset_stats: Statistics for normalizing different data types.
-        camera_number: Number of cameras to use.
-        device: Device to use for training.
-        storage_device: Device to use for storage.
-        vision_encoder_name: Name of the vision encoder to use.
-        freeze_vision_encoder: Whether to freeze the vision encoder.
-        image_encoder_hidden_dim: Hidden dimension for the image encoder.
-        shared_encoder: Whether to use a shared encoder.
-        online_steps: Total number of online training steps.
+        actor_network: Configuration for the actor network architecture.
+        critic_network: Configuration for the critic network architecture.
+        policy: Configuration for the policy parameters.
+        n_obs_steps: Number of observation steps to consider.
+        normalization_mapping: Mapping of feature types to normalization modes.
+        dataset_stats: Statistics for normalizing different types of inputs.
+        input_features: Dictionary of input features with their types and shapes.
+        output_features: Dictionary of output features with their types and shapes.
+        camera_number: Number of cameras used for visual observations.
+        device: Device to run the model on (e.g., "cuda", "cpu").
+        storage_device: Device to store the model on.
+        vision_encoder_name: Name of the vision encoder model.
+        freeze_vision_encoder: Whether to freeze the vision encoder during training.
+        image_encoder_hidden_dim: Hidden dimension size for the image encoder.
+        shared_encoder: Whether to use a shared encoder for actor and critic.
+        concurrency: Configuration for concurrency settings.
+        actor_learner: Configuration for actor-learner architecture.
+        online_steps: Number of steps for online training.
         online_env_seed: Seed for the online environment.
         online_buffer_capacity: Capacity of the online replay buffer.
-        online_step_before_learning: Number of steps to collect before starting learning.
+        offline_buffer_capacity: Capacity of the offline replay buffer.
+        online_step_before_learning: Number of steps before learning starts.
         policy_update_freq: Frequency of policy updates.
-        discount: Discount factor for the RL algorithm.
-        temperature_init: Initial temperature for entropy regularization.
-        num_critics: Number of critic networks.
-        num_subsample_critics: Number of critics to subsample.
-        critic_lr: Learning rate for critic networks.
-        actor_lr: Learning rate for actor network.
-        temperature_lr: Learning rate for temperature parameter.
-        critic_target_update_weight: Weight for soft target updates.
-        utd_ratio: Update-to-data ratio (>1 to enable).
-        state_encoder_hidden_dim: Hidden dimension for state encoder.
-        latent_dim: Dimension of latent representation.
-        target_entropy: Target entropy for automatic temperature tuning.
-        use_backup_entropy: Whether to use backup entropy.
-        grad_clip_norm: Gradient clipping norm.
-        critic_network_kwargs: Additional arguments for critic networks.
-        actor_network_kwargs: Additional arguments for actor network.
-        policy_kwargs: Additional arguments for policy.
-        actor_learner_config: Configuration for actor-learner communication.
-        concurrency: Configuration for concurrency model.
+        discount: Discount factor for the SAC algorithm.
+        temperature_init: Initial temperature value.
+        num_critics: Number of critics in the ensemble.
+        num_subsample_critics: Number of subsampled critics for training.
+        critic_lr: Learning rate for the critic network.
+        actor_lr: Learning rate for the actor network.
+        temperature_lr: Learning rate for the temperature parameter.
+        critic_target_update_weight: Weight for the critic target update.
+        utd_ratio: Update-to-data ratio for the UTD algorithm.
+        state_encoder_hidden_dim: Hidden dimension size for the state encoder.
+        latent_dim: Dimension of the latent space.
+        target_entropy: Target entropy for the SAC algorithm.
+        use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
+        grad_clip_norm: Gradient clipping norm for the SAC algorithm.
     """
 
-    # Input / output structure
-    n_obs_steps: int = 1
 
     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
             "VISUAL": NormalizationMode.MEAN_STD,
@@ -76,6 +118,7 @@ class SACConfig(PreTrainedConfig):
             "ACTION": NormalizationMode.MIN_MAX,
         }
     )
+
     dataset_stats: dict[str, dict[str, list[float]]] = field(
         default_factory=lambda: {
             "observation.image": {
@@ -93,6 +136,18 @@ class SACConfig(PreTrainedConfig):
         }
     )
 
+    input_features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
+            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
+        }
+    )
+    output_features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
+        }
+    )
+
     # Architecture specifics
     camera_number: int = 1
     device: str = "cuda"
@@ -106,7 +161,8 @@ class SACConfig(PreTrainedConfig):
     # Training parameter
     online_steps: int = 1000000
     online_env_seed: int = 10000
-    online_buffer_capacity: int = 10000
+    online_buffer_capacity: int = 100000
+    offline_buffer_capacity: int = 100000
     online_step_before_learning: int = 100
     policy_update_freq: int = 1
 
@@ -127,40 +183,21 @@ class SACConfig(PreTrainedConfig):
     grad_clip_norm: float = 40.0
 
     # Network configuration
-    critic_network_kwargs: dict[str, Any] = field(
-        default_factory=lambda: {
-            "hidden_dims": [256, 256],
-            "activate_final": True,
-            "final_activation": None,
-        }
+    critic_network_kwargs: CriticNetworkConfig = field(
+        default_factory=CriticNetworkConfig
     )
-    actor_network_kwargs: dict[str, Any] = field(
-        default_factory=lambda: {
-            "hidden_dims": [256, 256],
-            "activate_final": True,
-        }
+    actor_network_kwargs: ActorNetworkConfig = field(
+        default_factory=ActorNetworkConfig
     )
-    policy_kwargs: dict[str, Any] = field(
-        default_factory=lambda: {
-            "use_tanh_squash": True,
-            "log_std_min": -5,
-            "log_std_max": 2,
-            "init_final": 0.05,
-        }
+    policy_kwargs: PolicyConfig = field(
+        default_factory=PolicyConfig
     )
 
-    actor_learner_config: dict[str, str | int] = field(
-        default_factory=lambda: {
-            "learner_host": "127.0.0.1",
-            "learner_port": 50051,
-            "policy_parameters_push_frequency": 4,
-        }
+    actor_learner_config: ActorLearnerConfig = field(
+        default_factory=ActorLearnerConfig
    )
-    concurrency: dict[str, str] = field(
-        default_factory=lambda: {
-            "actor": "threads",
-            "learner": "threads"
-        }
+    concurrency: ConcurrencyConfig = field(
+        default_factory=ConcurrencyConfig
     )
 
     def __post_init__(self):
@@ -181,9 +218,18 @@ class SACConfig(PreTrainedConfig):
         return None
 
     def validate_features(self) -> None:
-        # TODO: Maybe we should remove this raise?
-        if len(self.image_features) == 0:
-            raise ValueError("You must provide at least one image among the inputs.")
+        if "observation.image" not in self.input_features:
+            raise ValueError("You must provide 'observation.image' in the input features")
+
+        if "observation.state" not in self.input_features:
+            raise ValueError("You must provide 'observation.state' in the input features")
+
+        if "action" not in self.output_features:
+            raise ValueError("You must provide 'action' in the output features")
+
+    @property
+    def image_features(self) -> list[str]:
+        return [key for key in self.input_features.keys() if 'image' in key]
 
     @property
     def observation_delta_indices(self) -> list:
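A small usage sketch of what the new feature handling implies for callers; `PolicyFeature` is replaced here by a hypothetical stand-in with string types so the snippet runs without lerobot installed. `image_features` is now derived from the input keys, and `validate_features` insists on the three keys used above.

```python
# Sketch only: the feature checks introduced in validate_features(), with stand-in types.
from dataclasses import dataclass


@dataclass
class PolicyFeature:  # hypothetical stand-in for lerobot.configs.types.PolicyFeature
    type: str
    shape: tuple


input_features = {
    "observation.image": PolicyFeature(type="VISUAL", shape=(3, 64, 64)),
    "observation.state": PolicyFeature(type="STATE", shape=(2,)),
}
output_features = {"action": PolicyFeature(type="ACTION", shape=(3,))}

# image_features is a property derived from the input keys rather than a stored field.
image_features = [key for key in input_features if "image" in key]
assert image_features == ["observation.image"]

# validate_features() raises if any of the three required keys is missing.
for required, features in [
    ("observation.image", input_features),
    ("observation.state", input_features),
    ("action", output_features),
]:
    if required not in features:
        raise ValueError(f"You must provide '{required}' in the features")
```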
@@ -17,6 +17,7 @@
 
 # TODO: (1) better device management
 
+from dataclasses import asdict
 import math
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -88,7 +89,7 @@ class SACPolicy(
         critic_heads = [
             CriticHead(
                 input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
-                **config.critic_network_kwargs,
+                **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
         ]
@@ -103,7 +104,7 @@ class SACPolicy(
         target_critic_heads = [
             CriticHead(
                 input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
-                **config.critic_network_kwargs,
+                **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
         ]
@@ -121,10 +122,10 @@ class SACPolicy(
 
         self.actor = Policy(
             encoder=encoder_actor,
-            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
+            network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
             action_dim=config.output_features["action"].shape[0],
             encoder_is_shared=config.shared_encoder,
-            **config.policy_kwargs,
+            **asdict(config.policy_kwargs),
         )
         if config.target_entropy is None:
             config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2  # (-dim(A)/2)
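Why `asdict()` now appears around the kwargs configs: a dataclass instance cannot be `**`-unpacked directly, but `dataclasses.asdict()` produces the same keyword dict the old code stored. A hedged sketch, with `make_head` standing in for the repo's `CriticHead`/`MLP` constructors:

```python
# Sketch only: dataclass -> kwargs via asdict(), as used for the network configs here.
from dataclasses import asdict, dataclass, field


@dataclass
class CriticNetworkConfig:
    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
    activate_final: bool = True
    final_activation: str | None = None


def make_head(hidden_dims, activate_final, final_activation=None):
    # Hypothetical stand-in for CriticHead/MLP: just records the kwargs it received.
    return {"hidden_dims": hidden_dims, "activate_final": activate_final,
            "final_activation": final_activation}


cfg = CriticNetworkConfig()
head = make_head(**asdict(cfg))  # plays the role of the old **config.critic_network_kwargs
assert head["hidden_dims"] == [256, 256]
```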
@@ -106,8 +106,9 @@ class TrainPipelineConfig(HubMixin):
             train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
             self.output_dir = Path("outputs/train") / train_dir
 
-        if isinstance(self.dataset.repo_id, list):
-            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
+        if self.dataset is not None:
+            if isinstance(self.dataset.repo_id, list):
+                raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
 
         if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
             raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
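A minimal sketch of the new guard's behaviour, with a hypothetical `DatasetStub` in place of the real dataset config: `dataset=None` is now tolerated, while a list-valued `repo_id` still raises.

```python
# Sketch only: the dataset guard now skips the repo_id check when no dataset is set.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class DatasetStub:  # hypothetical stand-in for the real dataset config
    repo_id: Union[str, list]


def check_dataset(dataset: Optional[DatasetStub]) -> None:
    if dataset is not None:
        if isinstance(dataset.repo_id, list):
            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")


check_dataset(None)                      # allowed: online RL run without a dataset
check_dataset(DatasetStub("user/repo"))  # allowed: single repo id
```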
@@ -73,8 +73,8 @@ def receive_policy(
 
     if grpc_channel is None or learner_client is None:
         learner_client, grpc_channel = learner_service_client(
-            host=cfg.policy.actor_learner_config["learner_host"],
-            port=cfg.policy.actor_learner_config["learner_port"],
+            host=cfg.policy.actor_learner_config.learner_host,
+            port=cfg.policy.actor_learner_config.learner_port,
         )
 
     try:
@@ -85,6 +85,7 @@ def receive_policy(
             shutdown_event,
             log_prefix="[ACTOR] parameters",
         )
 
     except grpc.RpcError as e:
         logging.error(f"[ACTOR] gRPC error: {e}")
+
@@ -153,8 +154,8 @@ def send_transitions(
 
     if grpc_channel is None or learner_client is None:
         learner_client, grpc_channel = learner_service_client(
-            host=cfg.policy.actor_learner_config["learner_host"],
-            port=cfg.policy.actor_learner_config["learner_port"],
+            host=cfg.policy.actor_learner_config.learner_host,
+            port=cfg.policy.actor_learner_config.learner_port,
         )
 
     try:
@@ -193,8 +194,8 @@ def send_interactions(
 
     if grpc_channel is None or learner_client is None:
         learner_client, grpc_channel = learner_service_client(
-            host=cfg.policy.actor_learner_config["learner_host"],
-            port=cfg.policy.actor_learner_config["learner_port"],
+            host=cfg.policy.actor_learner_config.learner_host,
+            port=cfg.policy.actor_learner_config.learner_port,
        )
 
     try:
@@ -286,10 +287,10 @@ def act_with_policy(
 
     logging.info("make_env online")
 
-    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
+    online_env = make_robot_env(cfg=cfg.env)
 
     set_seed(cfg.seed)
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.policy.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -302,11 +303,7 @@ def act_with_policy(
     # TODO: At some point we should just need make sac policy
     policy: SACPolicy = make_policy(
         cfg=cfg.policy,
-        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
-        # Hack: But if we do online training, we do not need dataset_stats
-        dataset_stats=None,
-        # TODO: Handle resume training
-        device=device,
+        env_cfg=cfg.env,
     )
     policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
@@ -322,13 +319,13 @@ def act_with_policy(
     episode_intervention_steps = 0
     episode_total_steps = 0
 
-    for interaction_step in range(cfg.training.online_steps):
+    for interaction_step in range(cfg.policy.online_steps):
         start_time = time.perf_counter()
         if shutdown_event.is_set():
             logging.info("[ACTOR] Shutting down act_with_policy")
             return
 
-        if interaction_step >= cfg.training.online_step_before_learning:
+        if interaction_step >= cfg.policy.online_step_before_learning:
             # Time policy inference and check if it meets FPS requirement
             with TimerManager(
                 elapsed_time_list=list_policy_time,
@@ -426,9 +423,9 @@ def act_with_policy(
             episode_total_steps = 0
             obs, info = online_env.reset()
 
-        if cfg.fps is not None:
+        if cfg.env.fps is not None:
             dt_time = time.perf_counter() - start_time
-            busy_wait(1 / cfg.fps - dt_time)
+            busy_wait(1 / cfg.env.fps - dt_time)
 
 
 def push_transitions_to_transport_queue(transitions: list, transitions_queue):
@@ -467,9 +464,9 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
 
 
 def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
-    if policy_fps < cfg.fps:
+    if policy_fps < cfg.env.fps:
         logging.warning(
-            f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
+            f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
         )
 
 
@@ -495,7 +492,7 @@ def establish_learner_connection(
 
 
 def use_threads(cfg: TrainPipelineConfig) -> bool:
-    return cfg.policy.concurrency["actor"] == "threads"
+    return cfg.policy.concurrency.actor == "threads"
 
 
 @parser.wrap()
@@ -511,8 +508,8 @@ def actor_cli(cfg: TrainPipelineConfig):
     shutdown_event = setup_process_handlers(use_threads(cfg))
 
     learner_client, grpc_channel = learner_service_client(
-        host=cfg.policy.actor_learner_config["learner_host"],
-        port=cfg.policy.actor_learner_config["learner_port"],
+        host=cfg.policy.actor_learner_config.learner_host,
+        port=cfg.policy.actor_learner_config.learner_port,
    )
 
     logging.info("[ACTOR] Establishing connection with Learner")
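Illustrative only: the actor now reads the learner address as attributes of the typed config rather than dict keys. The sketch uses plain `grpcio` for the channel, not the repo's `learner_service_client` helper, whose internals are not shown in this diff.

```python
# Sketch only: resolving the learner address from the typed config and opening a channel.
from dataclasses import dataclass

import grpc


@dataclass
class ActorLearnerConfig:
    learner_host: str = "127.0.0.1"
    learner_port: int = 50051


def open_learner_channel(cfg: ActorLearnerConfig) -> grpc.Channel:
    # Was: cfg["learner_host"], cfg["learner_port"]; now plain attributes.
    return grpc.insecure_channel(f"{cfg.learner_host}:{cfg.learner_port}")


channel = open_learner_channel(ActorLearnerConfig())
```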
@@ -1097,7 +1097,6 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention
 
 
-@parser.wrap()
 def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
 # def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
 # def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:
@@ -48,6 +48,7 @@ from lerobot.common.utils.train_utils import (
     load_training_state as utils_load_training_state,
     save_checkpoint,
     update_last_checkpoint,
+    save_training_state,
 )
 from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.utils import (
@@ -160,13 +161,14 @@ def load_training_state(
 
     try:
         # Use the utility function from train_utils which loads the optimizer state
-        # The function returns (step, updated_optimizer, scheduler)
         step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
 
-        # For interaction step, we still need to load the training_state.pt file
+        # Load interaction step separately from training_state.pt
         training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
-        training_state = torch.load(training_state_path, weights_only=False)
-        interaction_step = training_state.get("interaction_step", 0)
+        interaction_step = 0
+        if os.path.exists(training_state_path):
+            training_state = torch.load(training_state_path, weights_only=False)
+            interaction_step = training_state.get("interaction_step", 0)
 
         logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
         return step, interaction_step
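A sketch of the more defensive resume path above: if `training_state.pt` is missing, the interaction step falls back to 0 instead of crashing. The `training_state` subdirectory name is an assumption; the diff only shows the `TRAINING_STATE_DIR` constant.

```python
# Sketch only: guarded load of the interaction step; the subdir name is an assumption.
import os

import torch


def load_interaction_step(checkpoint_dir: str, training_state_subdir: str = "training_state") -> int:
    training_state_path = os.path.join(checkpoint_dir, training_state_subdir, "training_state.pt")
    interaction_step = 0
    if os.path.exists(training_state_path):
        training_state = torch.load(training_state_path, weights_only=False)
        interaction_step = training_state.get("interaction_step", 0)
    return interaction_step
```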
@@ -222,16 +224,20 @@ def initialize_replay_buffer(
 
     logging.info("Resume training load the online dataset")
     dataset_path = os.path.join(cfg.output_dir, "dataset")
+
+    # NOTE: In RL is possible to not have a dataset.
+    repo_id = None
+    if cfg.dataset is not None:
+        repo_id = cfg.dataset.dataset_repo_id
     dataset = LeRobotDataset(
-        repo_id=cfg.dataset.dataset_repo_id,
-        local_files_only=True,
+        repo_id=repo_id,
         root=dataset_path,
     )
     return ReplayBuffer.from_lerobot_dataset(
         lerobot_dataset=dataset,
         capacity=cfg.policy.online_buffer_capacity,
         device=device,
-        state_keys=cfg.policy.input_shapes.keys(),
+        state_keys=cfg.policy.input_features.keys(),
         optimize_memory=True,
     )
 
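A small sketch (hypothetical `DatasetConfigStub`) of the optional repo id used when resuming: with `cfg.dataset` set to `None`, the buffer is rebuilt from the local `root` path and `repo_id=None` is passed through.

```python
# Sketch only: repo id resolution when the config may have no dataset section.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DatasetConfigStub:  # hypothetical stand-in for the real dataset config
    dataset_repo_id: str


def resolve_buffer_repo_id(dataset_cfg: Optional[DatasetConfigStub]) -> Optional[str]:
    # NOTE: in RL it is possible to not have a dataset at all.
    return None if dataset_cfg is None else dataset_cfg.dataset_repo_id


assert resolve_buffer_repo_id(None) is None
assert resolve_buffer_repo_id(DatasetConfigStub("user/demos")) == "user/demos"
```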
@@ -298,7 +304,7 @@ def get_observation_features(
 
 
 def use_threads(cfg: TrainPipelineConfig) -> bool:
-    return cfg.policy.concurrency["learner"] == "threads"
+    return cfg.policy.concurrency.learner == "threads"
 
 
 def start_learner_threads(
@@ -388,7 +394,7 @@ def start_learner_server(
     service = learner_service.LearnerService(
         shutdown_event=shutdown_event,
         parameters_queue=parameters_queue,
-        seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
+        seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency,
         transition_queue=transition_queue,
         interaction_message_queue=interaction_message_queue,
     )
@@ -406,8 +412,8 @@ def start_learner_server(
         server,
     )
 
-    host = cfg.policy.actor_learner_config["learner_host"]
-    port = cfg.policy.actor_learner_config["learner_port"]
+    host = cfg.policy.actor_learner_config.learner_host
+    port = cfg.policy.actor_learner_config.learner_port
 
     server.add_insecure_port(f"{host}:{port}")
     server.start()
@@ -509,7 +515,6 @@ def add_actor_information_and_train(
     checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
     pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
 
-    # TODO(Adil): This don't work anymore !
     policy: SACPolicy = make_policy(
         cfg=cfg.policy,
         # ds_meta=cfg.dataset,
@@ -575,8 +580,8 @@ def add_actor_information_and_train(
     device = cfg.policy.device
     storage_device = cfg.policy.storage_device
     policy_update_freq = cfg.policy.policy_update_freq
-    policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
-    save_checkpoint = cfg.save_checkpoint
+    policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
+    saving_checkpoint = cfg.save_checkpoint
     online_steps = cfg.policy.online_steps
 
     while True:
@@ -598,7 +603,7 @@ def add_actor_information_and_train(
                 continue
             replay_buffer.add(**transition)
 
-            if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
+            if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
                 "is_intervention"
             ):
                 offline_replay_buffer.add(**transition)
@@ -618,9 +623,6 @@ def add_actor_information_and_train(
                     mode="train",
                     custom_step_key="Interaction step"
                 )
-            else:
-                # Log to console if no WandB logger
-                logging.info(f"Interaction: {interaction_message}")
 
             logging.debug("[LEARNER] Received interactions")
 
@@ -765,9 +767,6 @@ def add_actor_information_and_train(
                     mode="train",
                     custom_step_key="Optimization step"
                 )
-            else:
-                # Log to console if no WandB logger
-                logging.info(f"Training: {training_infos}")
 
             time_for_one_optimization_step = time.time() - time_for_one_optimization_step
             frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
@@ -789,7 +788,7 @@ def add_actor_information_and_train(
         if optimization_step % log_freq == 0:
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
 
-        if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
+        if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
             logging.info(f"Checkpoint policy after step {optimization_step}")
             _num_digits = max(6, len(str(online_steps)))
             step_identifier = f"{optimization_step:0{_num_digits}d}"
@@ -810,6 +809,15 @@ def add_actor_information_and_train(
                 scheduler=None
             )
 
+            # Save interaction step manually
+            training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
+            os.makedirs(training_state_dir, exist_ok=True)
+            training_state = {
+                "step": optimization_step,
+                "interaction_step": interaction_step
+            }
+            torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
+
             # Update the "last" symlink
             update_last_checkpoint(checkpoint_dir)
 
@@ -820,8 +828,11 @@ def add_actor_information_and_train(
                 shutil.rmtree(dataset_dir)
 
             # Save dataset
+            # NOTE: Handle the case where the dataset repo id is not specified in the config
+            # eg. RL training without demonstrations data
+            repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
             replay_buffer.to_lerobot_dataset(
-                dataset_repo_id,
+                repo_id=repo_id_buffer_save,
                 fps=fps,
                 root=dataset_dir
             )
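The save-side counterpart of the resume logic sketched earlier, under the same assumption about the `TRAINING_STATE_DIR` value; the dictionary keys mirror the ones written in the hunk above.

```python
# Sketch only: manual bookkeeping of the interaction step next to the checkpoint.
import os

import torch


def save_interaction_step(
    checkpoint_dir: str,
    optimization_step: int,
    interaction_step: int,
    training_state_subdir: str = "training_state",  # assumed value of TRAINING_STATE_DIR
) -> None:
    training_state_dir = os.path.join(checkpoint_dir, training_state_subdir)
    os.makedirs(training_state_dir, exist_ok=True)
    training_state = {"step": optimization_step, "interaction_step": interaction_step}
    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
```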
@@ -892,8 +903,10 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
         cfg (TrainPipelineConfig): The training configuration
         job_name (str | None, optional): Job name for logging. Defaults to None.
     """
-    if cfg.output_dir is None:
-        raise ValueError("Output directory must be specified in config")
+    cfg.validate()
+    # if cfg.output_dir is None:
+    #     raise ValueError("Output directory must be specified in config")
+
     if job_name is None:
         job_name = cfg.job_name
 