Add WrapperConfig for environment wrappers and update SACConfig properties
- Introduced a `WrapperConfig` dataclass for environment wrapper configuration.
- Updated `ManiskillEnvConfig` with a `wrapper` field for enhanced environment management.
- Modified `SACConfig` so that the `observation_delta_indices` and `action_delta_indices` properties return `None`.
- Renamed `cfg.dataset.dataset_repo_id` to `cfg.dataset.repo_id` where replay buffers and the offline dataset are initialized.
- Refactored `make_robot_env` and extracted checkpoint saving into a `save_training_checkpoint` helper to improve readability and maintainability.
parent b69132c79d
commit 88cc2b8fc8
@@ -163,6 +163,12 @@ class VideoRecordConfig:
     record_dir: str = "videos"
     trajectory_name: str = "trajectory"
 
 
+@dataclass
+class WrapperConfig:
+    """Configuration for environment wrappers."""
+
+    delta_action: float | None = None
+    joint_masking_action_space: list[bool] | None = None
+
+
 @EnvConfig.register_subclass("maniskill_push")
 @dataclass
 class ManiskillEnvConfig(EnvConfig):
@@ -181,6 +187,7 @@ class ManiskillEnvConfig(EnvConfig):
     device: str = "cuda"
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
+    wrapper: WrapperConfig = field(default_factory=WrapperConfig)
     features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
             "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
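Taken together, the two hunks above let the wrapper settings travel with the environment config. A minimal usage sketch follows; the import path and field values are assumptions for illustration and are not part of this commit:

```python
# Sketch only: module path and values are illustrative assumptions.
from lerobot.common.envs.configs import ManiskillEnvConfig, WrapperConfig

env_cfg = ManiskillEnvConfig(
    wrapper=WrapperConfig(
        delta_action=0.05,                      # hypothetical per-step action scale
        joint_masking_action_space=[True] * 6,  # hypothetical joint mask
    )
)
assert env_cfg.wrapper.delta_action == 0.05
```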
@@ -233,11 +233,11 @@ class SACConfig(PreTrainedConfig):
 
     @property
     def observation_delta_indices(self) -> list:
-        return list(range(1 - self.n_obs_steps, 1))
+        return None
 
     @property
     def action_delta_indices(self) -> list:
-        return [0]  # SAC typically predicts one action at a time
+        return None  # SAC typically predicts one action at a time
 
     @property
     def reward_delta_indices(self) -> None:
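Returning `None` from these properties means the policy no longer requests time-offset frames when the dataset is loaded; in LeRobot the delta indices are typically converted into `delta_timestamps` by dividing each index by the dataset fps. A rough, self-contained sketch of that relationship (the helper below is illustrative, not the library's exact function):

```python
# Illustrative helper: how delta indices are typically turned into delta timestamps.
def indices_to_delta_timestamps(indices: list[int] | None, fps: int) -> list[float] | None:
    if indices is None:
        return None  # no temporal window requested
    return [i / fps for i in indices]

print(indices_to_delta_timestamps([0], fps=30))   # before this change: [0.0] (single current step)
print(indices_to_delta_timestamps(None, fps=30))  # after this change: None (no delta timestamps)
```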
@@ -1100,15 +1100,15 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    # if "maniskill" in cfg.name:
-    #     from lerobot.scripts.server.maniskill_manipulator import make_maniskill
-
-    #     logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
-    #     env = make_maniskill(
-    #         cfg=cfg,
-    #         n_envs=1,
-    #     )
-    #     return env
+    if "maniskill" in cfg.name:
+        from lerobot.scripts.server.maniskill_manipulator import make_maniskill
+
+        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
+        env = make_maniskill(
+            cfg=cfg,
+            n_envs=1,
+        )
+        return env
     robot = make_robot_from_config(cfg.robot)
     # Create base environment
    env = HILSerlRobotEnv(
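With the ManiSkill branch re-enabled, environment construction is routed purely by the config name. A hedged sketch of the call site follows; the import paths and the value of `cfg.name` are assumptions, only the dispatch logic is shown in the diff:

```python
# Sketch only: import paths are assumed, not shown in this commit.
from lerobot.common.envs.configs import ManiskillEnvConfig
from lerobot.scripts.server.gym_manipulator import make_robot_env

cfg = ManiskillEnvConfig()   # name assumed to match the registered "maniskill_push"
env = make_robot_env(cfg)    # dispatches to make_maniskill(cfg=cfg, n_envs=1)
obs, info = env.reset()
```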
@@ -222,7 +222,7 @@ def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_devi
     # NOTE: In RL is possible to not have a dataset.
     repo_id = None
     if cfg.dataset is not None:
-        repo_id = cfg.dataset.dataset_repo_id
+        repo_id = cfg.dataset.repo_id
     dataset = LeRobotDataset(
         repo_id=repo_id,
         root=dataset_path,
@@ -261,8 +261,7 @@ def initialize_offline_replay_buffer(
     logging.info("load offline dataset")
     dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
     offline_dataset = LeRobotDataset(
-        repo_id=cfg.dataset.dataset_repo_id,
-        local_files_only=True,
+        repo_id=cfg.dataset.repo_id,
         root=dataset_offline_path,
     )
 
@@ -779,53 +778,106 @@ def add_actor_information_and_train(
         logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
 
         if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
-            logging.info(f"Checkpoint policy after step {optimization_step}")
-            _num_digits = max(6, len(str(online_steps)))
-            step_identifier = f"{optimization_step:0{_num_digits}d}"
-            interaction_step = (
-                interaction_message["Interaction step"] if interaction_message is not None else 0
-            )
-
-            # Create checkpoint directory
-            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
-
-            # Save checkpoint
-            save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)
-
-            # Save interaction step manually
-            training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
-            os.makedirs(training_state_dir, exist_ok=True)
-            training_state = {"step": optimization_step, "interaction_step": interaction_step}
-            torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
-
-            # Update the "last" symlink
-            update_last_checkpoint(checkpoint_dir)
-
-            # TODO : temporarly save replay buffer here, remove later when on the robot
-            # We want to control this with the keyboard inputs
-            dataset_dir = os.path.join(cfg.output_dir, "dataset")
-            if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
-                shutil.rmtree(dataset_dir)
-
-            # Save dataset
-            # NOTE: Handle the case where the dataset repo id is not specified in the config
-            # eg. RL training without demonstrations data
-            repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
-            replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)
-
-            if offline_replay_buffer is not None:
-                dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
-                if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
-                    shutil.rmtree(dataset_offline_dir)
-
-                offline_replay_buffer.to_lerobot_dataset(
-                    cfg.dataset.dataset_repo_id,
-                    fps=cfg.env.fps,
-                    root=dataset_offline_dir,
-                )
-
-            logging.info("Resume training")
+            save_training_checkpoint(
+                cfg=cfg,
+                optimization_step=optimization_step,
+                online_steps=online_steps,
+                interaction_message=interaction_message,
+                policy=policy,
+                optimizers=optimizers,
+                replay_buffer=replay_buffer,
+                offline_replay_buffer=offline_replay_buffer,
+                dataset_repo_id=dataset_repo_id,
+                fps=fps,
+            )
+
+
+def save_training_checkpoint(
+    cfg: TrainPipelineConfig,
+    optimization_step: int,
+    online_steps: int,
+    interaction_message: dict | None,
+    policy: nn.Module,
+    optimizers: dict[str, Optimizer],
+    replay_buffer: ReplayBuffer,
+    offline_replay_buffer: ReplayBuffer | None = None,
+    dataset_repo_id: str | None = None,
+    fps: int = 30,
+) -> None:
+    """
+    Save training checkpoint and associated data.
+
+    Args:
+        cfg: Training configuration
+        optimization_step: Current optimization step
+        online_steps: Total number of online steps
+        interaction_message: Dictionary containing interaction information
+        policy: Policy model to save
+        optimizers: Dictionary of optimizers
+        replay_buffer: Replay buffer to save as dataset
+        offline_replay_buffer: Optional offline replay buffer to save
+        dataset_repo_id: Repository ID for dataset
+        fps: Frames per second for dataset
+    """
+    logging.info(f"Checkpoint policy after step {optimization_step}")
+    _num_digits = max(6, len(str(online_steps)))
+    step_identifier = f"{optimization_step:0{_num_digits}d}"
+    interaction_step = (
+        interaction_message["Interaction step"] if interaction_message is not None else 0
+    )
+
+    # Create checkpoint directory
+    checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
+
+    # Save checkpoint
+    save_checkpoint(
+        checkpoint_dir=checkpoint_dir,
+        step=optimization_step,
+        cfg=cfg,
+        policy=policy,
+        optimizer=optimizers,
+        scheduler=None
+    )
+
+    # Save interaction step manually
+    training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
+    os.makedirs(training_state_dir, exist_ok=True)
+    training_state = {
+        "step": optimization_step,
+        "interaction_step": interaction_step
+    }
+    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
+
+    # Update the "last" symlink
+    update_last_checkpoint(checkpoint_dir)
+
+    # TODO : temporarly save replay buffer here, remove later when on the robot
+    # We want to control this with the keyboard inputs
+    dataset_dir = os.path.join(cfg.output_dir, "dataset")
+    if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
+        shutil.rmtree(dataset_dir)
+
+    # Save dataset
+    # NOTE: Handle the case where the dataset repo id is not specified in the config
+    # eg. RL training without demonstrations data
+    repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
+    replay_buffer.to_lerobot_dataset(
+        repo_id=repo_id_buffer_save,
+        fps=fps,
+        root=dataset_dir
+    )
+
+    if offline_replay_buffer is not None:
+        dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
+        if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
+            shutil.rmtree(dataset_offline_dir)
+
+        offline_replay_buffer.to_lerobot_dataset(
+            cfg.dataset.repo_id,
+            fps=fps,
+            root=dataset_offline_dir,
+        )
+
+    logging.info("Resume training")
+
+
 def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     """
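The training state written by `save_training_checkpoint` can be read back when a run is resumed. A minimal sketch, assuming the same directory layout as above and that `TRAINING_STATE_DIR` resolves to a subfolder of the checkpoint directory (the default name below is an assumption):

```python
import os

import torch

# Sketch only: mirrors the save path used above; the directory name is an assumption.
def load_training_state(checkpoint_dir: str, training_state_dir: str = "training_state") -> tuple[int, int]:
    state = torch.load(os.path.join(checkpoint_dir, training_state_dir, "training_state.pt"))
    return state["step"], state["interaction_step"]
```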