From 052a4acfc2e2905af0f839be4e71df832ae47733 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Thu, 27 Mar 2025 08:13:20 +0000 Subject: [PATCH] [WIP] Update SAC configuration and environment settings - Reduced frame rate in `ManiskillEnvConfig` from 400 to 200. - Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations. - Improved input and output feature management in `SACConfig`. - Refactored `actor_server` and `learner_server` to access configuration properties directly. - Updated training pipeline to validate configurations and handle dataset repo IDs more robustly. --- lerobot/common/envs/configs.py | 2 +- .../common/policies/sac/configuration_sac.py | 188 +++++++++++------- lerobot/common/policies/sac/modeling_sac.py | 9 +- lerobot/configs/train.py | 5 +- lerobot/scripts/server/actor_server.py | 41 ++-- lerobot/scripts/server/gym_manipulator.py | 1 - lerobot/scripts/server/learner_server.py | 63 +++--- 7 files changed, 183 insertions(+), 126 deletions(-) diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 0414d64f..313f003a 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -173,7 +173,7 @@ class ManiskillEnvConfig(EnvConfig): control_mode: str = "pd_ee_delta_pose" state_dim: int = 25 action_dim: int = 7 - fps: int = 400 + fps: int = 200 episode_length: int = 50 obs_type: str = "rgb" render_mode: str = "rgb_array" diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 5221a1f2..d252ddc0 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -16,58 +16,100 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any +from typing import Any, Optional from lerobot.common.optim.optimizers import MultiAdamConfig from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode +from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType + + +@dataclass +class ConcurrencyConfig: + actor: str = "threads" + learner: str = "threads" + + + +@dataclass +class ActorLearnerConfig: + learner_host: str = "127.0.0.1" + learner_port: int = 50051 + policy_parameters_push_frequency: int = 4 + + +@dataclass +class CriticNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + final_activation: str | None = None + + +@dataclass +class ActorNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + + +@dataclass +class PolicyConfig: + use_tanh_squash: bool = True + log_std_min: int = -5 + log_std_max: int = 2 + init_final: float = 0.05 @PreTrainedConfig.register_subclass("sac") @dataclass class SACConfig(PreTrainedConfig): - """Configuration class for Soft Actor-Critic (SAC) policy. + """Soft Actor-Critic (SAC) configuration. + + SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy + reinforcement learning framework. It learns a policy and a Q-function simultaneously + using experience collected from the environment. + + This configuration class contains all the parameters needed to define a SAC agent, + including network architectures, optimization settings, and algorithm-specific + hyperparameters. Args: - n_obs_steps: Number of environment steps worth of observations to pass to the policy. 
- normalization_mapping: Mapping from feature types to normalization modes. - dataset_stats: Statistics for normalizing different data types. - camera_number: Number of cameras to use. - device: Device to use for training. - storage_device: Device to use for storage. - vision_encoder_name: Name of the vision encoder to use. - freeze_vision_encoder: Whether to freeze the vision encoder. - image_encoder_hidden_dim: Hidden dimension for the image encoder. - shared_encoder: Whether to use a shared encoder. - online_steps: Total number of online training steps. + actor_network: Configuration for the actor network architecture. + critic_network: Configuration for the critic network architecture. + policy: Configuration for the policy parameters. + n_obs_steps: Number of observation steps to consider. + normalization_mapping: Mapping of feature types to normalization modes. + dataset_stats: Statistics for normalizing different types of inputs. + input_features: Dictionary of input features with their types and shapes. + output_features: Dictionary of output features with their types and shapes. + camera_number: Number of cameras used for visual observations. + device: Device to run the model on (e.g., "cuda", "cpu"). + storage_device: Device to store the model on. + vision_encoder_name: Name of the vision encoder model. + freeze_vision_encoder: Whether to freeze the vision encoder during training. + image_encoder_hidden_dim: Hidden dimension size for the image encoder. + shared_encoder: Whether to use a shared encoder for actor and critic. + concurrency: Configuration for concurrency settings. + actor_learner: Configuration for actor-learner architecture. + online_steps: Number of steps for online training. online_env_seed: Seed for the online environment. online_buffer_capacity: Capacity of the online replay buffer. - online_step_before_learning: Number of steps to collect before starting learning. + offline_buffer_capacity: Capacity of the offline replay buffer. + online_step_before_learning: Number of steps before learning starts. policy_update_freq: Frequency of policy updates. - discount: Discount factor for the RL algorithm. - temperature_init: Initial temperature for entropy regularization. - num_critics: Number of critic networks. - num_subsample_critics: Number of critics to subsample. - critic_lr: Learning rate for critic networks. - actor_lr: Learning rate for actor network. - temperature_lr: Learning rate for temperature parameter. - critic_target_update_weight: Weight for soft target updates. - utd_ratio: Update-to-data ratio (>1 to enable). - state_encoder_hidden_dim: Hidden dimension for state encoder. - latent_dim: Dimension of latent representation. - target_entropy: Target entropy for automatic temperature tuning. - use_backup_entropy: Whether to use backup entropy. - grad_clip_norm: Gradient clipping norm. - critic_network_kwargs: Additional arguments for critic networks. - actor_network_kwargs: Additional arguments for actor network. - policy_kwargs: Additional arguments for policy. - actor_learner_config: Configuration for actor-learner communication. - concurrency: Configuration for concurrency model. + discount: Discount factor for the SAC algorithm. + temperature_init: Initial temperature value. + num_critics: Number of critics in the ensemble. + num_subsample_critics: Number of subsampled critics for training. + critic_lr: Learning rate for the critic network. + actor_lr: Learning rate for the actor network. + temperature_lr: Learning rate for the temperature parameter. 
+ critic_target_update_weight: Weight for the critic target update. + utd_ratio: Update-to-data ratio for the UTD algorithm. + state_encoder_hidden_dim: Hidden dimension size for the state encoder. + latent_dim: Dimension of the latent space. + target_entropy: Target entropy for the SAC algorithm. + use_backup_entropy: Whether to use backup entropy for the SAC algorithm. + grad_clip_norm: Gradient clipping norm for the SAC algorithm. """ - - # Input / output structure - n_obs_steps: int = 1 - normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.MEAN_STD, @@ -76,6 +118,7 @@ class SACConfig(PreTrainedConfig): "ACTION": NormalizationMode.MIN_MAX, } ) + dataset_stats: dict[str, dict[str, list[float]]] = field( default_factory=lambda: { "observation.image": { @@ -93,6 +136,18 @@ class SACConfig(PreTrainedConfig): } ) + input_features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)), + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)), + } + ) + output_features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)), + } + ) + # Architecture specifics camera_number: int = 1 device: str = "cuda" @@ -106,7 +161,8 @@ class SACConfig(PreTrainedConfig): # Training parameter online_steps: int = 1000000 online_env_seed: int = 10000 - online_buffer_capacity: int = 10000 + online_buffer_capacity: int = 100000 + offline_buffer_capacity: int = 100000 online_step_before_learning: int = 100 policy_update_freq: int = 1 @@ -127,40 +183,21 @@ class SACConfig(PreTrainedConfig): grad_clip_norm: float = 40.0 # Network configuration - critic_network_kwargs: dict[str, Any] = field( - default_factory=lambda: { - "hidden_dims": [256, 256], - "activate_final": True, - "final_activation": None, - } + critic_network_kwargs: CriticNetworkConfig = field( + default_factory=CriticNetworkConfig ) - actor_network_kwargs: dict[str, Any] = field( - default_factory=lambda: { - "hidden_dims": [256, 256], - "activate_final": True, - } + actor_network_kwargs: ActorNetworkConfig = field( + default_factory=ActorNetworkConfig ) - policy_kwargs: dict[str, Any] = field( - default_factory=lambda: { - "use_tanh_squash": True, - "log_std_min": -5, - "log_std_max": 2, - "init_final": 0.05, - } + policy_kwargs: PolicyConfig = field( + default_factory=PolicyConfig ) - actor_learner_config: dict[str, str | int] = field( - default_factory=lambda: { - "learner_host": "127.0.0.1", - "learner_port": 50051, - "policy_parameters_push_frequency": 4, - } + actor_learner_config: ActorLearnerConfig = field( + default_factory=ActorLearnerConfig ) - concurrency: dict[str, str] = field( - default_factory=lambda: { - "actor": "threads", - "learner": "threads" - } + concurrency: ConcurrencyConfig = field( + default_factory=ConcurrencyConfig ) def __post_init__(self): @@ -181,9 +218,18 @@ class SACConfig(PreTrainedConfig): return None def validate_features(self) -> None: - # TODO: Maybe we should remove this raise? 
- if len(self.image_features) == 0: - raise ValueError("You must provide at least one image among the inputs.") + if "observation.image" not in self.input_features: + raise ValueError("You must provide 'observation.image' in the input features") + + if "observation.state" not in self.input_features: + raise ValueError("You must provide 'observation.state' in the input features") + + if "action" not in self.output_features: + raise ValueError("You must provide 'action' in the output features") + + @property + def image_features(self) -> list[str]: + return [key for key in self.input_features.keys() if 'image' in key] @property def observation_delta_indices(self) -> list: diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 083ef567..b54def54 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -17,6 +17,7 @@ # TODO: (1) better device management +from dataclasses import asdict import math from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple, Union @@ -88,7 +89,7 @@ class SACPolicy( critic_heads = [ CriticHead( input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], - **config.critic_network_kwargs, + **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) ] @@ -103,7 +104,7 @@ class SACPolicy( target_critic_heads = [ CriticHead( input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], - **config.critic_network_kwargs, + **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) ] @@ -121,10 +122,10 @@ class SACPolicy( self.actor = Policy( encoder=encoder_actor, - network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs), + network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), action_dim=config.output_features["action"].shape[0], encoder_is_shared=config.shared_encoder, - **config.policy_kwargs, + **asdict(config.policy_kwargs), ) if config.target_entropy is None: config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2) diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index f38cd8e6..02a9edd6 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -106,8 +106,9 @@ class TrainPipelineConfig(HubMixin): train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}" self.output_dir = Path("outputs/train") / train_dir - if isinstance(self.dataset.repo_id, list): - raise NotImplementedError("LeRobotMultiDataset is not currently implemented.") + if self.dataset is not None: + if isinstance(self.dataset.repo_id, list): + raise NotImplementedError("LeRobotMultiDataset is not currently implemented.") if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index 7b68145e..dcb9c3d3 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -73,8 +73,8 @@ def receive_policy( if grpc_channel is None or learner_client is None: learner_client, grpc_channel = learner_service_client( - host=cfg.policy.actor_learner_config["learner_host"], - port=cfg.policy.actor_learner_config["learner_port"], + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, ) 
try: @@ -85,6 +85,7 @@ def receive_policy( shutdown_event, log_prefix="[ACTOR] parameters", ) + except grpc.RpcError as e: logging.error(f"[ACTOR] gRPC error: {e}") @@ -153,8 +154,8 @@ def send_transitions( if grpc_channel is None or learner_client is None: learner_client, grpc_channel = learner_service_client( - host=cfg.policy.actor_learner_config["learner_host"], - port=cfg.policy.actor_learner_config["learner_port"], + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, ) try: @@ -193,8 +194,8 @@ def send_interactions( if grpc_channel is None or learner_client is None: learner_client, grpc_channel = learner_service_client( - host=cfg.policy.actor_learner_config["learner_host"], - port=cfg.policy.actor_learner_config["learner_port"], + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, ) try: @@ -286,10 +287,10 @@ def act_with_policy( logging.info("make_env online") - online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg) + online_env = make_robot_env( cfg=cfg.env) set_seed(cfg.seed) - device = get_safe_torch_device(cfg.device, log=True) + device = get_safe_torch_device(cfg.policy.device, log=True) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -302,11 +303,7 @@ def act_with_policy( # TODO: At some point we should just need make sac policy policy: SACPolicy = make_policy( cfg=cfg.policy, - # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, - # Hack: But if we do online training, we do not need dataset_stats - dataset_stats=None, - # TODO: Handle resume training - device=device, + env_cfg=cfg.env, ) policy = torch.compile(policy) assert isinstance(policy, nn.Module) @@ -322,13 +319,13 @@ def act_with_policy( episode_intervention_steps = 0 episode_total_steps = 0 - for interaction_step in range(cfg.training.online_steps): + for interaction_step in range(cfg.policy.online_steps): start_time = time.perf_counter() if shutdown_event.is_set(): logging.info("[ACTOR] Shutting down act_with_policy") return - if interaction_step >= cfg.training.online_step_before_learning: + if interaction_step >= cfg.policy.online_step_before_learning: # Time policy inference and check if it meets FPS requirement with TimerManager( elapsed_time_list=list_policy_time, @@ -426,9 +423,9 @@ def act_with_policy( episode_total_steps = 0 obs, info = online_env.reset() - if cfg.fps is not None: + if cfg.env.fps is not None: dt_time = time.perf_counter() - start_time - busy_wait(1 / cfg.fps - dt_time) + busy_wait(1 / cfg.env.fps - dt_time) def push_transitions_to_transport_queue(transitions: list, transitions_queue): @@ -467,9 +464,9 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]: def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int): - if policy_fps < cfg.fps: + if policy_fps < cfg.env.fps: logging.warning( - f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}" + f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}" ) @@ -495,7 +492,7 @@ def establish_learner_connection( def use_threads(cfg: TrainPipelineConfig) -> bool: - return cfg.policy.concurrency["actor"] == "threads" + return cfg.policy.concurrency.actor == "threads" @parser.wrap() @@ -511,8 +508,8 @@ def actor_cli(cfg: TrainPipelineConfig): shutdown_event = setup_process_handlers(use_threads(cfg)) learner_client, 
grpc_channel = learner_service_client( - host=cfg.policy.actor_learner_config["learner_host"], - port=cfg.policy.actor_learner_config["learner_port"], + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, ) logging.info("[ACTOR] Establishing connection with Learner") diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 55c0c6de..4abd385f 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1097,7 +1097,6 @@ class ActionScaleWrapper(gym.ActionWrapper): return action * self.scale_vector, is_intervention -@parser.wrap() def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv: # def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv: # def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv: diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index c34182b7..04fab60f 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -48,6 +48,7 @@ from lerobot.common.utils.train_utils import ( load_training_state as utils_load_training_state, save_checkpoint, update_last_checkpoint, + save_training_state, ) from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.utils import ( @@ -160,13 +161,14 @@ def load_training_state( try: # Use the utility function from train_utils which loads the optimizer state - # The function returns (step, updated_optimizer, scheduler) step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None) - # For interaction step, we still need to load the training_state.pt file + # Load interaction step separately from training_state.pt training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt") - training_state = torch.load(training_state_path, weights_only=False) - interaction_step = training_state.get("interaction_step", 0) + interaction_step = 0 + if os.path.exists(training_state_path): + training_state = torch.load(training_state_path, weights_only=False) + interaction_step = training_state.get("interaction_step", 0) logging.info(f"Resuming from step {step}, interaction step {interaction_step}") return step, interaction_step @@ -222,16 +224,20 @@ def initialize_replay_buffer( logging.info("Resume training load the online dataset") dataset_path = os.path.join(cfg.output_dir, "dataset") + + # NOTE: In RL is possible to not have a dataset. 
+ repo_id = None + if cfg.dataset is not None: + repo_id = cfg.dataset.dataset_repo_id dataset = LeRobotDataset( - repo_id=cfg.dataset.dataset_repo_id, - local_files_only=True, + repo_id=repo_id, root=dataset_path, ) return ReplayBuffer.from_lerobot_dataset( lerobot_dataset=dataset, capacity=cfg.policy.online_buffer_capacity, device=device, - state_keys=cfg.policy.input_shapes.keys(), + state_keys=cfg.policy.input_features.keys(), optimize_memory=True, ) @@ -298,7 +304,7 @@ def get_observation_features( def use_threads(cfg: TrainPipelineConfig) -> bool: - return cfg.policy.concurrency["learner"] == "threads" + return cfg.policy.concurrency.learner == "threads" def start_learner_threads( @@ -388,7 +394,7 @@ def start_learner_server( service = learner_service.LearnerService( shutdown_event=shutdown_event, parameters_queue=parameters_queue, - seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"], + seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency, transition_queue=transition_queue, interaction_message_queue=interaction_message_queue, ) @@ -406,8 +412,8 @@ def start_learner_server( server, ) - host = cfg.policy.actor_learner_config["learner_host"] - port = cfg.policy.actor_learner_config["learner_port"] + host = cfg.policy.actor_learner_config.learner_host + port = cfg.policy.actor_learner_config.learner_port server.add_insecure_port(f"{host}:{port}") server.start() @@ -509,7 +515,6 @@ def add_actor_information_and_train( checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None - # TODO(Adil): This don't work anymore ! policy: SACPolicy = make_policy( cfg=cfg.policy, # ds_meta=cfg.dataset, @@ -575,8 +580,8 @@ def add_actor_information_and_train( device = cfg.policy.device storage_device = cfg.policy.storage_device policy_update_freq = cfg.policy.policy_update_freq - policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"] - save_checkpoint = cfg.save_checkpoint + policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency + saving_checkpoint = cfg.save_checkpoint online_steps = cfg.policy.online_steps while True: @@ -598,7 +603,7 @@ def add_actor_information_and_train( continue replay_buffer.add(**transition) - if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get( + if dataset_repo_id is not None and transition.get("complementary_info", {}).get( "is_intervention" ): offline_replay_buffer.add(**transition) @@ -618,9 +623,6 @@ def add_actor_information_and_train( mode="train", custom_step_key="Interaction step" ) - else: - # Log to console if no WandB logger - logging.info(f"Interaction: {interaction_message}") logging.debug("[LEARNER] Received interactions") @@ -765,9 +767,6 @@ def add_actor_information_and_train( mode="train", custom_step_key="Optimization step" ) - else: - # Log to console if no WandB logger - logging.info(f"Training: {training_infos}") time_for_one_optimization_step = time.time() - time_for_one_optimization_step frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) @@ -789,7 +788,7 @@ def add_actor_information_and_train( if optimization_step % log_freq == 0: logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") - if save_checkpoint and (optimization_step % save_freq == 0 or 
optimization_step == online_steps): + if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps): logging.info(f"Checkpoint policy after step {optimization_step}") _num_digits = max(6, len(str(online_steps))) step_identifier = f"{optimization_step:0{_num_digits}d}" @@ -810,6 +809,15 @@ def add_actor_information_and_train( scheduler=None ) + # Save interaction step manually + training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR) + os.makedirs(training_state_dir, exist_ok=True) + training_state = { + "step": optimization_step, + "interaction_step": interaction_step + } + torch.save(training_state, os.path.join(training_state_dir, "training_state.pt")) + # Update the "last" symlink update_last_checkpoint(checkpoint_dir) @@ -820,8 +828,11 @@ def add_actor_information_and_train( shutil.rmtree(dataset_dir) # Save dataset + # NOTE: Handle the case where the dataset repo id is not specified in the config + # eg. RL training without demonstrations data + repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id replay_buffer.to_lerobot_dataset( - dataset_repo_id, + repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir ) @@ -892,8 +903,10 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None): cfg (TrainPipelineConfig): The training configuration job_name (str | None, optional): Job name for logging. Defaults to None. """ - if cfg.output_dir is None: - raise ValueError("Output directory must be specified in config") + + cfg.validate() + # if cfg.output_dir is None: + # raise ValueError("Output directory must be specified in config") if job_name is None: job_name = cfg.job_name
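
The core refactor in this patch replaces the dict-valued fields of `SACConfig`
(`critic_network_kwargs`, `actor_network_kwargs`, `policy_kwargs`,
`actor_learner_config`, `concurrency`) with typed dataclasses, so call sites
move from `cfg.policy.actor_learner_config["learner_host"]` to plain attribute
access and unpack network kwargs via `dataclasses.asdict()` (as `modeling_sac.py`
now does for `CriticHead`, the actor `MLP`, and the policy kwargs). The snippet
below is a minimal, self-contained sketch of that pattern; the class names
`SACConfigSketch` and `CriticHead` here are hypothetical stand-ins, not the
actual lerobot modules:

    # Sketch of the dict-to-dataclass config pattern adopted by this patch.
    # Requires Python >= 3.10 for the `str | None` / `list[int]` annotations,
    # which the patched code also uses.
    from dataclasses import asdict, dataclass, field


    @dataclass
    class CriticNetworkConfig:
        hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
        activate_final: bool = True
        final_activation: str | None = None


    @dataclass
    class ActorLearnerConfig:
        learner_host: str = "127.0.0.1"
        learner_port: int = 50051
        policy_parameters_push_frequency: int = 4


    @dataclass
    class SACConfigSketch:  # hypothetical stand-in for SACConfig
        critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
        actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)


    class CriticHead:  # hypothetical stand-in for the real CriticHead module
        def __init__(self, input_dim, hidden_dims, activate_final, final_activation):
            self.dims = [input_dim, *hidden_dims]


    cfg = SACConfigSketch()

    # Call sites now use attribute access instead of dict lookups ...
    host = cfg.actor_learner_config.learner_host
    port = cfg.actor_learner_config.learner_port

    # ... and asdict() turns the nested dataclass back into keyword arguments,
    # so constructors that take **kwargs keep working unchanged.
    head = CriticHead(input_dim=32, **asdict(cfg.critic_network_kwargs))

    print(host, port, head.dims)  # 127.0.0.1 50051 [32, 256, 256]

One practical consequence of this choice: a mistyped key now fails at startup
with an AttributeError on the config object instead of a silent default or a
KeyError deep inside the actor/learner loop, while `asdict()` preserves the
existing keyword-argument constructors of the network modules.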