Add WrapperConfig for environment wrappers and update SACConfig properties

- Introduced a `WrapperConfig` dataclass for environment wrapper configuration.
- Updated `ManiskillEnvConfig` to include a `wrapper` field so wrapper settings are part of the environment config (usage sketch below).
- Modified `SACConfig` to return `None` from the `observation_delta_indices` and `action_delta_indices` properties.
- Refactored the `make_robot_env` function and extracted checkpoint saving into a `save_training_checkpoint` helper to improve readability and maintainability.
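
A minimal usage sketch of the new configuration surface (not part of the diff; the import path and field values below are assumptions, and the remaining `ManiskillEnvConfig` fields are expected to keep their defaults):

```python
# Hypothetical example of the new wrapper options; the import path is assumed
# from this diff's context and may differ in the actual module layout.
from lerobot.common.envs.configs import ManiskillEnvConfig, WrapperConfig

env_cfg = ManiskillEnvConfig(
    wrapper=WrapperConfig(
        delta_action=0.1,  # example value; semantics come from the wrapper implementation
        joint_masking_action_space=[True, True, True, True, True, False, False],  # 7-dim action mask
    ),
)
```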
AdilZouitine 2025-03-27 17:07:06 +00:00
parent d0b7690bc0
commit 79e0f6e06c
4 changed files with 114 additions and 55 deletions


@@ -163,6 +163,12 @@ class VideoRecordConfig:
     record_dir: str = "videos"
     trajectory_name: str = "trajectory"
+@dataclass
+class WrapperConfig:
+    """Configuration for environment wrappers."""
+
+    delta_action: float | None = None
+    joint_masking_action_space: list[bool] | None = None
 
 @EnvConfig.register_subclass("maniskill_push")
 @dataclass
 class ManiskillEnvConfig(EnvConfig):
@@ -181,6 +187,7 @@ class ManiskillEnvConfig(EnvConfig):
     device: str = "cuda"
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
+    wrapper: WrapperConfig = field(default_factory=WrapperConfig)
     features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
             "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),


@@ -233,11 +233,11 @@ class SACConfig(PreTrainedConfig):
 
     @property
     def observation_delta_indices(self) -> list:
-        return list(range(1 - self.n_obs_steps, 1))
+        return None
 
     @property
     def action_delta_indices(self) -> list:
-        return [0]  # SAC typically predicts one action at a time
+        return None  # SAC typically predicts one action at a time
 
     @property
     def reward_delta_indices(self) -> None:
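
For context on what returning `None` changes downstream (the helper below is an illustrative sketch, not the library's actual function): delta indices are typically converted into per-key `delta_timestamps` (index / fps seconds) when loading a dataset, and `None` means no temporally offset frames are requested.

```python
# Illustrative sketch only: shows the usual index -> index / fps mapping and why
# None disables temporal stacking. Names below are hypothetical stand-ins.
from dataclasses import dataclass


@dataclass
class PolicyCfgStub:
    observation_delta_indices: list[int] | None = None
    action_delta_indices: list[int] | None = None


def build_delta_timestamps(cfg: PolicyCfgStub, fps: float) -> dict[str, list[float]]:
    delta_timestamps = {}
    if cfg.observation_delta_indices is not None:
        delta_timestamps["observation.state"] = [i / fps for i in cfg.observation_delta_indices]
    if cfg.action_delta_indices is not None:
        delta_timestamps["action"] = [i / fps for i in cfg.action_delta_indices]
    return delta_timestamps


# With the SAC change both properties are None, so no temporal window is requested:
assert build_delta_timestamps(PolicyCfgStub(), fps=30) == {}
```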


@@ -1100,15 +1100,15 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    # if "maniskill" in cfg.name:
-    #     from lerobot.scripts.server.maniskill_manipulator import make_maniskill
-    #     logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
-    #     env = make_maniskill(
-    #         cfg=cfg,
-    #         n_envs=1,
-    #     )
-    #     return env
+    if "maniskill" in cfg.name:
+        from lerobot.scripts.server.maniskill_manipulator import make_maniskill
+
+        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
+        env = make_maniskill(
+            cfg=cfg,
+            n_envs=1,
+        )
+        return env
     robot = make_robot_from_config(cfg.robot)
 
     # Create base environment
     env = HILSerlRobotEnv(
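
A usage sketch of the re-enabled branch (the call site below is an assumption; only `make_robot_env` and `make_maniskill` appear in this diff): a config whose `name` contains "maniskill" now returns the ManiSkill vector env directly instead of going through the HIL robot pipeline.

```python
# Hypothetical call site; import paths and the config's `name` attribute are assumptions.
from lerobot.common.envs.configs import ManiskillEnvConfig
from lerobot.scripts.server.gym_manipulator import make_robot_env

cfg = ManiskillEnvConfig()
env = make_robot_env(cfg)  # short-circuits into make_maniskill(cfg=cfg, n_envs=1)
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```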


@@ -222,7 +222,7 @@ def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_devi
     # NOTE: In RL is possible to not have a dataset.
     repo_id = None
     if cfg.dataset is not None:
-        repo_id = cfg.dataset.dataset_repo_id
+        repo_id = cfg.dataset.repo_id
     dataset = LeRobotDataset(
         repo_id=repo_id,
         root=dataset_path,
@@ -261,8 +261,7 @@ def initialize_offline_replay_buffer(
     logging.info("load offline dataset")
     dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
     offline_dataset = LeRobotDataset(
-        repo_id=cfg.dataset.dataset_repo_id,
-        local_files_only=True,
+        repo_id=cfg.dataset.repo_id,
         root=dataset_offline_path,
     )
@@ -779,53 +778,106 @@ def add_actor_information_and_train(
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
 
         if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
-            logging.info(f"Checkpoint policy after step {optimization_step}")
-            _num_digits = max(6, len(str(online_steps)))
-            step_identifier = f"{optimization_step:0{_num_digits}d}"
-            interaction_step = (
-                interaction_message["Interaction step"] if interaction_message is not None else 0
-            )
-            # Create checkpoint directory
-            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
-            # Save checkpoint
-            save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)
-            # Save interaction step manually
-            training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
-            os.makedirs(training_state_dir, exist_ok=True)
-            training_state = {"step": optimization_step, "interaction_step": interaction_step}
-            torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
-            # Update the "last" symlink
-            update_last_checkpoint(checkpoint_dir)
-            # TODO : temporarly save replay buffer here, remove later when on the robot
-            # We want to control this with the keyboard inputs
-            dataset_dir = os.path.join(cfg.output_dir, "dataset")
-            if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
-                shutil.rmtree(dataset_dir)
-            # Save dataset
-            # NOTE: Handle the case where the dataset repo id is not specified in the config
-            # eg. RL training without demonstrations data
-            repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
-            replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)
-            if offline_replay_buffer is not None:
-                dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
-                if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
-                    shutil.rmtree(dataset_offline_dir)
-                offline_replay_buffer.to_lerobot_dataset(
-                    cfg.dataset.dataset_repo_id,
-                    fps=cfg.env.fps,
-                    root=dataset_offline_dir,
-                )
-            logging.info("Resume training")
+            save_training_checkpoint(
+                cfg=cfg,
+                optimization_step=optimization_step,
+                online_steps=online_steps,
+                interaction_message=interaction_message,
+                policy=policy,
+                optimizers=optimizers,
+                replay_buffer=replay_buffer,
+                offline_replay_buffer=offline_replay_buffer,
+                dataset_repo_id=dataset_repo_id,
+                fps=fps,
+            )
+
+
+def save_training_checkpoint(
+    cfg: TrainPipelineConfig,
+    optimization_step: int,
+    online_steps: int,
+    interaction_message: dict | None,
+    policy: nn.Module,
+    optimizers: dict[str, Optimizer],
+    replay_buffer: ReplayBuffer,
+    offline_replay_buffer: ReplayBuffer | None = None,
+    dataset_repo_id: str | None = None,
+    fps: int = 30,
+) -> None:
+    """
+    Save training checkpoint and associated data.
+
+    Args:
+        cfg: Training configuration
+        optimization_step: Current optimization step
+        online_steps: Total number of online steps
+        interaction_message: Dictionary containing interaction information
+        policy: Policy model to save
+        optimizers: Dictionary of optimizers
+        replay_buffer: Replay buffer to save as dataset
+        offline_replay_buffer: Optional offline replay buffer to save
+        dataset_repo_id: Repository ID for dataset
+        fps: Frames per second for dataset
+    """
+    logging.info(f"Checkpoint policy after step {optimization_step}")
+    _num_digits = max(6, len(str(online_steps)))
+    step_identifier = f"{optimization_step:0{_num_digits}d}"
+    interaction_step = (
+        interaction_message["Interaction step"] if interaction_message is not None else 0
+    )
+
+    # Create checkpoint directory
+    checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
+
+    # Save checkpoint
+    save_checkpoint(
+        checkpoint_dir=checkpoint_dir,
+        step=optimization_step,
+        cfg=cfg,
+        policy=policy,
+        optimizer=optimizers,
+        scheduler=None
+    )
+
+    # Save interaction step manually
+    training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
+    os.makedirs(training_state_dir, exist_ok=True)
+    training_state = {
+        "step": optimization_step,
+        "interaction_step": interaction_step
+    }
+    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
+
+    # Update the "last" symlink
+    update_last_checkpoint(checkpoint_dir)
+
+    # TODO : temporarly save replay buffer here, remove later when on the robot
+    # We want to control this with the keyboard inputs
+    dataset_dir = os.path.join(cfg.output_dir, "dataset")
+    if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
+        shutil.rmtree(dataset_dir)
+
+    # Save dataset
+    # NOTE: Handle the case where the dataset repo id is not specified in the config
+    # eg. RL training without demonstrations data
+    repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
+    replay_buffer.to_lerobot_dataset(
+        repo_id=repo_id_buffer_save,
+        fps=fps,
+        root=dataset_dir
+    )
+
+    if offline_replay_buffer is not None:
+        dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
+        if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
+            shutil.rmtree(dataset_offline_dir)
+
+        offline_replay_buffer.to_lerobot_dataset(
+            cfg.dataset.repo_id,
+            fps=fps,
+            root=dataset_offline_dir,
+        )
+
+    logging.info("Resume training")
 
 
 def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     """