Add WrapperConfig for environment wrappers and update SACConfig properties
- Introduced `WrapperConfig` dataclass for environment wrapper configurations.
- Updated `ManiskillEnvConfig` to include a `wrapper` field for enhanced environment management.
- Modified `SACConfig` to return `None` for `observation_delta_indices` and `action_delta_indices` properties.
- Refactored `make_robot_env` function to improve readability and maintainability.
parent d0b7690bc0
commit 79e0f6e06c
@@ -163,6 +163,12 @@ class VideoRecordConfig:
     record_dir: str = "videos"
     trajectory_name: str = "trajectory"
 
+@dataclass
+class WrapperConfig:
+    """Configuration for environment wrappers."""
+    delta_action: float | None = None
+    joint_masking_action_space: list[bool] | None = None
+
 @EnvConfig.register_subclass("maniskill_push")
 @dataclass
 class ManiskillEnvConfig(EnvConfig):
@@ -181,6 +187,7 @@ class ManiskillEnvConfig(EnvConfig):
     device: str = "cuda"
     robot: str = "so100" # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
+    wrapper: WrapperConfig = field(default_factory=WrapperConfig)
     features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
             "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
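For context, a minimal usage sketch of the new wrapper field (not part of this commit; the import path and the field values are illustrative assumptions):

from lerobot.common.envs.configs import ManiskillEnvConfig, WrapperConfig  # assumed module path

# Limit per-step action deltas and mask the last joint of a 7-dim action space (values are illustrative).
wrapper = WrapperConfig(
    delta_action=0.1,
    joint_masking_action_space=[True, True, True, True, True, True, False],
)
env_cfg = ManiskillEnvConfig(wrapper=wrapper)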
@@ -233,11 +233,11 @@ class SACConfig(PreTrainedConfig):
 
     @property
     def observation_delta_indices(self) -> list:
-        return list(range(1 - self.n_obs_steps, 1))
+        return None
 
     @property
     def action_delta_indices(self) -> list:
-        return [0]  # SAC typically predicts one action at a time
+        return None  # SAC typically predicts one action at a time
 
     @property
     def reward_delta_indices(self) -> None:
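Downstream, delta indices are typically converted into delta_timestamps for dataset sampling, so returning None signals that no temporal stacking is requested. A minimal sketch of that conversion (illustrative helper, not the LeRobot API):

def build_delta_timestamps(indices: list[int] | None, fps: int) -> list[float] | None:
    # With SACConfig now returning None, the None simply propagates: no frame stacking.
    if indices is None:
        return None
    return [i / fps for i in indices]

assert build_delta_timestamps(None, fps=30) is None
assert build_delta_timestamps([0], fps=30) == [0.0]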
@@ -1100,15 +1100,15 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    # if "maniskill" in cfg.name:
-    #     from lerobot.scripts.server.maniskill_manipulator import make_maniskill
+    if "maniskill" in cfg.name:
+        from lerobot.scripts.server.maniskill_manipulator import make_maniskill
 
-    # logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
-    # env = make_maniskill(
-    #     cfg=cfg,
-    #     n_envs=1,
-    # )
-    # return env
+        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
+        env = make_maniskill(
+            cfg=cfg,
+            n_envs=1,
+        )
+        return env
     robot = make_robot_from_config(cfg.robot)
     # Create base environment
     env = HILSerlRobotEnv(
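A hedged sketch of how the re-enabled branch is exercised (constructor defaults, the name check, and the gymnasium-style reset are assumptions, not shown in this diff):

env_cfg = ManiskillEnvConfig()  # cfg.name is assumed to contain "maniskill" via the registered subclass name
env = make_robot_env(env_cfg)   # takes the early-return path through make_maniskill(cfg=env_cfg, n_envs=1)
obs, info = env.reset()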
@@ -222,7 +222,7 @@ def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_devi
     # NOTE: In RL is possible to not have a dataset.
     repo_id = None
     if cfg.dataset is not None:
-        repo_id = cfg.dataset.dataset_repo_id
+        repo_id = cfg.dataset.repo_id
     dataset = LeRobotDataset(
         repo_id=repo_id,
         root=dataset_path,
@@ -261,8 +261,7 @@ def initialize_offline_replay_buffer(
     logging.info("load offline dataset")
     dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
     offline_dataset = LeRobotDataset(
-        repo_id=cfg.dataset.dataset_repo_id,
-        local_files_only=True,
+        repo_id=cfg.dataset.repo_id,
         root=dataset_offline_path,
     )
 
@@ -779,53 +778,106 @@ def add_actor_information_and_train(
         logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
 
         if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
-            logging.info(f"Checkpoint policy after step {optimization_step}")
-            _num_digits = max(6, len(str(online_steps)))
-            step_identifier = f"{optimization_step:0{_num_digits}d}"
-            interaction_step = (
-                interaction_message["Interaction step"] if interaction_message is not None else 0
-            )
-
-            # Create checkpoint directory
-            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
-
-            # Save checkpoint
-            save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)
-
-            # Save interaction step manually
-            training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
-            os.makedirs(training_state_dir, exist_ok=True)
-            training_state = {"step": optimization_step, "interaction_step": interaction_step}
-            torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
-
-            # Update the "last" symlink
-            update_last_checkpoint(checkpoint_dir)
-
-            # TODO : temporarly save replay buffer here, remove later when on the robot
-            # We want to control this with the keyboard inputs
-            dataset_dir = os.path.join(cfg.output_dir, "dataset")
-            if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
-                shutil.rmtree(dataset_dir)
-
-            # Save dataset
-            # NOTE: Handle the case where the dataset repo id is not specified in the config
-            # eg. RL training without demonstrations data
-            repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
-            replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)
-
-            if offline_replay_buffer is not None:
-                dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
-                if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
-                    shutil.rmtree(dataset_offline_dir)
-
-                offline_replay_buffer.to_lerobot_dataset(
-                    cfg.dataset.dataset_repo_id,
-                    fps=cfg.env.fps,
-                    root=dataset_offline_dir,
-                )
-
-            logging.info("Resume training")
+            save_training_checkpoint(
+                cfg=cfg,
+                optimization_step=optimization_step,
+                online_steps=online_steps,
+                interaction_message=interaction_message,
+                policy=policy,
+                optimizers=optimizers,
+                replay_buffer=replay_buffer,
+                offline_replay_buffer=offline_replay_buffer,
+                dataset_repo_id=dataset_repo_id,
+                fps=fps,
+            )
+
+
+def save_training_checkpoint(
+    cfg: TrainPipelineConfig,
+    optimization_step: int,
+    online_steps: int,
+    interaction_message: dict | None,
+    policy: nn.Module,
+    optimizers: dict[str, Optimizer],
+    replay_buffer: ReplayBuffer,
+    offline_replay_buffer: ReplayBuffer | None = None,
+    dataset_repo_id: str | None = None,
+    fps: int = 30,
+) -> None:
+    """
+    Save training checkpoint and associated data.
+
+    Args:
+        cfg: Training configuration
+        optimization_step: Current optimization step
+        online_steps: Total number of online steps
+        interaction_message: Dictionary containing interaction information
+        policy: Policy model to save
+        optimizers: Dictionary of optimizers
+        replay_buffer: Replay buffer to save as dataset
+        offline_replay_buffer: Optional offline replay buffer to save
+        dataset_repo_id: Repository ID for dataset
+        fps: Frames per second for dataset
+    """
+    logging.info(f"Checkpoint policy after step {optimization_step}")
+    _num_digits = max(6, len(str(online_steps)))
+    step_identifier = f"{optimization_step:0{_num_digits}d}"
+    interaction_step = (
+        interaction_message["Interaction step"] if interaction_message is not None else 0
+    )
+
+    # Create checkpoint directory
+    checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
+
+    # Save checkpoint
+    save_checkpoint(
+        checkpoint_dir=checkpoint_dir,
+        step=optimization_step,
+        cfg=cfg,
+        policy=policy,
+        optimizer=optimizers,
+        scheduler=None
+    )
+
+    # Save interaction step manually
+    training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
+    os.makedirs(training_state_dir, exist_ok=True)
+    training_state = {
+        "step": optimization_step,
+        "interaction_step": interaction_step
+    }
+    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
+
+    # Update the "last" symlink
+    update_last_checkpoint(checkpoint_dir)
+
+    # TODO : temporarly save replay buffer here, remove later when on the robot
+    # We want to control this with the keyboard inputs
+    dataset_dir = os.path.join(cfg.output_dir, "dataset")
+    if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
+        shutil.rmtree(dataset_dir)
+
+    # Save dataset
+    # NOTE: Handle the case where the dataset repo id is not specified in the config
+    # eg. RL training without demonstrations data
+    repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
+    replay_buffer.to_lerobot_dataset(
+        repo_id=repo_id_buffer_save,
+        fps=fps,
+        root=dataset_dir
+    )
+
+    if offline_replay_buffer is not None:
+        dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
+        if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
+            shutil.rmtree(dataset_offline_dir)
+
+        offline_replay_buffer.to_lerobot_dataset(
+            cfg.dataset.repo_id,
+            fps=fps,
+            root=dataset_offline_dir,
+        )
+
+    logging.info("Resume training")
 
 
 def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     """
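For orientation, a hedged sketch of the on-disk layout that save_training_checkpoint appears to produce, inferred from the paths in the hunk above; the exact checkpoint sub-directory naming comes from get_step_checkpoint_dir and TRAINING_STATE_DIR, so the literal names below are assumptions:

import os

output_dir = "outputs/train/run"  # stand-in for cfg.output_dir
checkpoint_dir = os.path.join(output_dir, "checkpoints", "000500")  # assumed naming from get_step_checkpoint_dir

expected_paths = [
    os.path.join(checkpoint_dir, "training_state", "training_state.pt"),  # step + interaction step (TRAINING_STATE_DIR assumed)
    os.path.join(output_dir, "dataset"),          # replay buffer exported via to_lerobot_dataset
    os.path.join(output_dir, "dataset_offline"),  # only written when an offline buffer exists
]
print("\n".join(expected_paths))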