[WIP] Non functional yet

Add ManiSkill environment configuration and wrappers

- Introduced `VideoRecordConfig` for video recording settings.
- Added `ManiskillEnvConfig` to encapsulate environment-specific configurations.
- Implemented various wrappers for the ManiSkill environment, including observation and action scaling.
- Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing.
- Updated the `actor_server` and `learner_server` to utilize the new configuration structure.
- Refactored the training pipeline to accommodate the new environment and policy configurations.
This commit is contained in:
AdilZouitine 2025-03-26 08:15:05 +00:00
parent b7b6d8102f
commit dd37bd412e
9 changed files with 667 additions and 436 deletions

View File

@ -154,3 +154,61 @@ class XarmEnv(EnvConfig):
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
@dataclass
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
@EnvConfig.register_subclass("maniskill_push")
@dataclass
class ManiskillEnvConfig(EnvConfig):
"""Configuration for the ManiSkill environment."""
name: str = "maniskill/pushcube"
task: str = "PushCube-v1"
image_size: int = 64
control_mode: str = "pd_ee_delta_pose"
state_dim: int = 25
action_dim: int = 7
fps: int = 400
episode_length: int = 50
obs_type: str = "rgb"
render_mode: str = "rgb_array"
render_size: int = 64
device: str = "cuda"
robot: str = "so100" # This is a hack to make the robot config work
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(25,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_ROBOT,
}
)
reward_classifier: dict[str, str | None] = field(
default_factory=lambda: {
"pretrained_path": None,
"config_path": None,
}
)
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"max_episode_steps": self.episode_length,
"control_mode": self.control_mode,
"sensor_configs": {"width": self.image_size, "height": self.image_size},
"num_envs": 1,
}

View File

@ -69,88 +69,3 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
return env
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
env = gym.make(
cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
num_envs=n_envs,
)
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
env = ManiSkillVectorEnv(env, ignore_terminations=True)
# state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
return env
class PixelWrapper(gym.Wrapper):
"""
Wrapper for pixel observations. Works with Maniskill vectorized environments
"""
def __init__(self, cfg, env, num_envs, num_frames=3):
super().__init__(env)
self.cfg = cfg
self.env = env
self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
dtype=np.uint8,
)
self._frames = deque([], maxlen=num_frames)
self._render_size = cfg.env.render_size
def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame)
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
def reset(self, seed):
obs, info = self.env.reset() # (seed=seed)
for _ in range(self._frames.maxlen):
obs_frames = self._get_obs(obs)
return obs_frames, info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
# TODO: Remove this
class ConvertToLeRobotEnv(gym.Wrapper):
def __init__(self, env, num_envs):
super().__init__(env)
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options={})
return self._get_obs(obs), info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
def _get_obs(self, observation):
sensor_data = observation.pop("sensor_data")
del observation["sensor_param"]
images = []
for cam_data in sensor_data.values():
images.append(cam_data["rgb"])
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
ret = dict()
ret["state"] = observation
ret["pixels"] = images
return ret

View File

@ -31,12 +31,19 @@ class SACConfig(PreTrainedConfig):
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy.
normalization_mapping: Mapping from feature types to normalization modes.
dataset_stats: Statistics for normalizing different data types.
camera_number: Number of cameras to use.
device: Device to use for training.
storage_device: Device to use for storage.
vision_encoder_name: Name of the vision encoder to use.
freeze_vision_encoder: Whether to freeze the vision encoder.
image_encoder_hidden_dim: Hidden dimension for the image encoder.
shared_encoder: Whether to use a shared encoder.
online_steps: Total number of online training steps.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
online_step_before_learning: Number of steps to collect before starting learning.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the RL algorithm.
temperature_init: Initial temperature for entropy regularization.
num_critics: Number of critic networks.
@ -54,6 +61,8 @@ class SACConfig(PreTrainedConfig):
critic_network_kwargs: Additional arguments for critic networks.
actor_network_kwargs: Additional arguments for actor network.
policy_kwargs: Additional arguments for policy.
actor_learner_config: Configuration for actor-learner communication.
concurrency: Configuration for concurrency model.
"""
# Input / output structure
@ -86,6 +95,7 @@ class SACConfig(PreTrainedConfig):
# Architecture specifics
camera_number: int = 1
device: str = "cuda"
storage_device: str = "cpu"
# Set to "helper2424/resnet10" for hil serl
vision_encoder_name: str | None = None
@ -93,6 +103,13 @@ class SACConfig(PreTrainedConfig):
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
# Training parameter
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 10000
online_step_before_learning: int = 100
policy_update_freq: int = 1
# SAC algorithm parameters
discount: float = 0.99
temperature_init: float = 1.0
@ -132,11 +149,17 @@ class SACConfig(PreTrainedConfig):
}
)
# Deprecated, kept for backward compatibility
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
"policy_parameters_push_frequency": 4,
}
)
concurrency: dict[str, str] = field(
default_factory=lambda: {
"actor": "threads",
"learner": "threads"
}
)

View File

@ -92,6 +92,8 @@ class WandBLogger:
resume="must" if cfg.resume else None,
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
)
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@ -108,9 +110,24 @@ class WandBLogger:
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"):
def log_dict(self, d: dict, step: int, mode: str = "train", custom_step_key: str | None = None):
if mode not in {"train", "eval"}:
raise ValueError(mode)
if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.")
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
# increases with each wandb.log call, but in the case of asynchronous RL for example,
# multiple time steps is possible for example, the interaction step with the environment,
# the training step, the evaluation step, etc. So we need to define a custom step key
# to log the correct step for each metric.
if custom_step_key is not None:
if self._wandb_custom_step_key is None:
self._wandb_custom_step_key = set()
new_custom_key = f"{mode}/{custom_step_key}"
if new_custom_key not in self._wandb_custom_step_key:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items():
if not isinstance(v, (int, float, str)):
@ -118,7 +135,26 @@ class WandBLogger:
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
# Do not log the custom step key itself.
if (
self._wandb_custom_step_key is not None
and k in self._wandb_custom_step_key
):
continue
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
self._wandb.log(
{
f"{mode}/{k}": v,
f"{mode}/{custom_step_key}": value_custom_step,
}
)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:

View File

@ -34,11 +34,10 @@ TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None
job_name: str | None = None
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure

View File

@ -21,26 +21,23 @@ from statistics import mean, quantiles
# from lerobot.scripts.eval import eval_policy
import grpc
import hydra
import torch
from omegaconf import DictConfig
from torch import nn
from torch.multiprocessing import Event, Queue
# TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env
# from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.robots.utils import Robot, make_robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
init_logging,
set_global_seed,
)
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import (
Transition,
@ -61,7 +58,7 @@ ACTOR_SHUTDOWN_TIMEOUT = 30
def receive_policy(
cfg: DictConfig,
cfg: TrainPipelineConfig,
parameters_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
@ -72,12 +69,12 @@ def receive_policy(
if not use_threads(cfg):
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
setup_process_handlers(use_threads=False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
)
try:
@ -132,7 +129,7 @@ def interactions_stream(
def send_transitions(
cfg: DictConfig,
cfg: TrainPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
@ -156,8 +153,8 @@ def send_transitions(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
)
try:
@ -173,7 +170,7 @@ def send_transitions(
def send_interactions(
cfg: DictConfig,
cfg: TrainPipelineConfig,
interactions_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
@ -196,8 +193,8 @@ def send_interactions(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
)
try:
@ -269,7 +266,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
def act_with_policy(
cfg: DictConfig,
cfg: TrainPipelineConfig,
robot: Robot,
reward_classifier: nn.Module,
shutdown_event: any, # Event,
@ -291,7 +288,7 @@ def act_with_policy(
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
set_global_seed(cfg.seed)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
@ -304,7 +301,7 @@ def act_with_policy(
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
cfg=cfg.policy,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
@ -469,7 +466,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
return stats
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
if policy_fps < cfg.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
@ -497,25 +494,25 @@ def establish_learner_connection(
return False
def use_threads(cfg: DictConfig) -> bool:
return cfg.actor_learner_config.concurrency.actor == "threads"
def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.policy.concurrency["actor"] == "threads"
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def actor_cli(cfg: dict):
@parser.wrap()
def actor_cli(cfg: TrainPipelineConfig):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
init_logging(log_file="actor.log")
robot = make_robot(cfg=cfg.robot)
robot = make_robot(robot_type=cfg.env.robot)
shutdown_event = setup_process_handlers(use_threads(cfg))
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
)
logging.info("[ACTOR] Establishing connection with Learner")
@ -570,22 +567,22 @@ def actor_cli(cfg: dict):
# TODO: Remove this once we merge into main
reward_classifier = None
if (
cfg.env.reward_classifier.pretrained_path is not None
and cfg.env.reward_classifier.config_path is not None
cfg.env.reward_classifier["pretrained_path"] is not None
and cfg.env.reward_classifier["config_path"] is not None
):
reward_classifier = get_classifier(
pretrained_path=cfg.env.reward_classifier.pretrained_path,
config_path=cfg.env.reward_classifier.config_path,
pretrained_path=cfg.env.reward_classifier["pretrained_path"],
config_path=cfg.env.reward_classifier["config_path"],
)
act_with_policy(
cfg,
robot,
reward_classifier,
shutdown_event,
parameters_queue,
transitions_queue,
interactions_queue,
cfg=cfg,
robot=robot,
reward_classifier=reward_classifier,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
transitions_queue=transitions_queue,
interactions_queue=interactions_queue,
)
logging.info("[ACTOR] Policy process joined")

View File

@ -15,10 +15,12 @@ import json
from dataclasses import dataclass
from lerobot.common.envs.utils import preprocess_observation
from lerobot.configs.train import TrainPipelineConfig
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
reset_follower_position,
# reset_follower_position,
)
from typing import Optional
@ -28,6 +30,7 @@ from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.scripts.server.kinematics import RobotKinematics
from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill
from lerobot.configs import parser
logging.basicConfig(level=logging.INFO)
@ -1094,7 +1097,10 @@ class ActionScaleWrapper(gym.ActionWrapper):
return action * self.scale_vector, is_intervention
def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
@parser.wrap()
def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:
"""
Factory function to create a vectorized robot environment.
@ -1106,7 +1112,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
Returns:
A vectorized gym environment with all the necessary wrappers applied.
"""
if "maniskill" in cfg.env_name:
if "maniskill" in cfg.name:
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
@ -1115,7 +1121,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
n_envs=1,
)
return env
robot = cfg.robot
# Create base environment
env = HILSerlRobotEnv(
robot=robot,
@ -1329,80 +1335,82 @@ def replay_episode(env, repo_id, root=None, episode=0):
busy_wait(1 / 10 - dt_s)
@parser.wrap()
def main(cfg: HILSerlRobotEnvConfig):
# @parser.wrap()
# def main(cfg):
robot = make_robot_from_config(cfg.robot)
# robot = make_robot_from_config(cfg.robot)
reward_classifier = None #get_classifier(
# cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
# )
user_relative_joint_positions = True
# reward_classifier = None #get_classifier(
# # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
# # )
# user_relative_joint_positions = True
env = make_robot_env(cfg, robot)
# env = make_robot_env(cfg, robot)
if cfg.mode == "record":
policy = None
if cfg.pretrained_policy_name_or_path is not None:
from lerobot.common.policies.sac.modeling_sac import SACPolicy
# if cfg.mode == "record":
# policy = None
# if cfg.pretrained_policy_name_or_path is not None:
# from lerobot.common.policies.sac.modeling_sac import SACPolicy
policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
policy.to(cfg.device)
policy.eval()
# policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
# policy.to(cfg.device)
# policy.eval()
record_dataset(
env,
cfg.repo_id,
root=cfg.dataset_root,
num_episodes=cfg.num_episodes,
fps=cfg.fps,
task_description=cfg.task,
policy=policy,
)
exit()
# record_dataset(
# env,
# cfg.repo_id,
# root=cfg.dataset_root,
# num_episodes=cfg.num_episodes,
# fps=cfg.fps,
# task_description=cfg.task,
# policy=policy,
# )
# exit()
if cfg.mode == "replay":
replay_episode(
env,
cfg.replay_repo_id,
root=cfg.dataset_root,
episode=cfg.replay_episode,
)
exit()
# if cfg.mode == "replay":
# replay_episode(
# env,
# cfg.replay_repo_id,
# root=cfg.dataset_root,
# episode=cfg.replay_episode,
# )
# exit()
env.reset()
# env.reset()
# Retrieve the robot's action space for joint commands.
action_space_robot = env.action_space.spaces[0]
# # Retrieve the robot's action space for joint commands.
# action_space_robot = env.action_space.spaces[0]
# Initialize the smoothed action as a random sample.
smoothed_action = action_space_robot.sample()
# # Initialize the smoothed action as a random sample.
# smoothed_action = action_space_robot.sample()
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
alpha = 1.0
# # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
# # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
# alpha = 1.0
num_episode = 0
sucesses = []
while num_episode < 20:
start_loop_s = time.perf_counter()
# Sample a new random action from the robot's action space.
new_random_action = action_space_robot.sample()
# Update the smoothed action using an exponential moving average.
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
# num_episode = 0
# sucesses = []
# while num_episode < 20:
# start_loop_s = time.perf_counter()
# # Sample a new random action from the robot's action space.
# new_random_action = action_space_robot.sample()
# # Update the smoothed action using an exponential moving average.
# smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
# Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
if terminated or truncated:
sucesses.append(reward)
env.reset()
num_episode += 1
# # Execute the step: wrap the NumPy action in a torch tensor.
# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
# if terminated or truncated:
# sucesses.append(reward)
# env.reset()
# num_episode += 1
dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / cfg.fps - dt_s)
# dt_s = time.perf_counter() - start_loop_s
# busy_wait(1 / cfg.fps - dt_s)
logging.info(f"Success after 20 steps {sucesses}")
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
# logging.info(f"Success after 20 steps {sucesses}")
# logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
# if __name__ == "__main__":
# main()
if __name__ == "__main__":
main()
make_robot_env()

View File

@ -19,40 +19,45 @@ import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from pprint import pformat
import os
from pathlib import Path
import draccus
import grpc
# Import generated stubs
import hilserl_pb2_grpc # type: ignore
import hydra
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.common.datasets.factory import make_dataset
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs import parser
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
format_big_number,
get_global_random_state,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_random_state,
set_global_seed,
)
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
@ -64,102 +69,167 @@ from lerobot.scripts.server.buffer import (
state_to_bytes,
)
from lerobot.scripts.server.utils import setup_process_handlers
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
"""
Handle the resume logic for training.
If resume is True:
- Verifies that a checkpoint exists
- Loads the checkpoint configuration
- Logs resumption details
- Returns the checkpoint configuration
If resume is False:
- Checks if an output directory exists (to prevent accidental overwriting)
- Returns the original configuration
Args:
cfg (TrainPipelineConfig): The training configuration
Returns:
TrainPipelineConfig: The updated configuration
Raises:
RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
"""
out_dir = cfg.output_dir
# Case 1: Not resuming, but need to check if directory exists to prevent overwrites
if not cfg.resume:
if Logger.get_last_checkpoint_dir(out_dir).exists():
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if os.path.exists(checkpoint_dir):
raise RuntimeError(
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
f"Output directory {checkpoint_dir} already exists. "
"Use `resume=true` to resume training."
)
return cfg
# if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
# Case 2: Resuming training
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if not os.path.exists(checkpoint_dir):
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
# Log that we found a valid checkpoint and are resuming
logging.info(
colored(
"Resume=True detected, resuming previous run",
"Valid checkpoint found: resume=True detected, resuming previous run",
color="yellow",
attrs=["bold"],
)
)
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
if len(diff) > 0:
logging.warning(
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
"Checkpoint configuration takes precedence."
)
# Load config using Draccus
checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
# Ensure resume flag is set in returned config
checkpoint_cfg.resume = True
return checkpoint_cfg
def load_training_state(
cfg: DictConfig,
logger: Logger,
optimizers: Optimizer | dict,
cfg: TrainPipelineConfig,
optimizers: Optimizer | dict[str, Optimizer],
):
"""
Loads the training state (optimizers, step count, etc.) from a checkpoint.
Args:
cfg (TrainPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into
Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
"""
if not cfg.resume:
return None, None
training_state = torch.load(
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
)
# Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if isinstance(training_state["optimizer"], dict):
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
for k, v in training_state["optimizer"].items():
optimizers[k].load_state_dict(v)
else:
optimizers.load_state_dict(training_state["optimizer"])
logging.info(f"Loading training state from {checkpoint_dir}")
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"], training_state["interaction_step"]
try:
# Use the utility function from train_utils which loads the optimizer state
# The function returns (step, updated_optimizer, scheduler)
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# For interaction step, we still need to load the training_state.pt file
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
training_state = torch.load(training_state_path, weights_only=False)
interaction_step = training_state.get("interaction_step", 0)
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step
except Exception as e:
logging.error(f"Failed to load training state: {e}")
return None, None
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
"""
Log information about the training process.
Args:
cfg (TrainPipelineConfig): Training configuration
policy (nn.Module): Policy model
"""
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{cfg.policy.online_steps=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str, storage_device: str
cfg: TrainPipelineConfig,
device: str,
storage_device: str
) -> ReplayBuffer:
"""
Initialize a replay buffer, either empty or from a dataset if resuming.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
Returns:
ReplayBuffer: Initialized replay buffer
"""
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
capacity=cfg.policy.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
state_keys=cfg.policy.input_features.keys(),
storage_device=storage_device,
optimize_memory=True,
)
logging.info("Resume training load the online dataset")
dataset_path = os.path.join(cfg.output_dir, "dataset")
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id,
repo_id=cfg.dataset.dataset_repo_id,
local_files_only=True,
root=logger.log_dir / "dataset",
root=dataset_path,
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.training.online_buffer_capacity,
capacity=cfg.policy.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
optimize_memory=True,
@ -167,33 +237,45 @@ def initialize_replay_buffer(
def initialize_offline_replay_buffer(
cfg: DictConfig,
logger: Logger,
cfg: TrainPipelineConfig,
device: str,
storage_device: str,
active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
"""
Initialize an offline replay buffer from a dataset.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking
Returns:
ReplayBuffer: Initialized offline replay buffer
"""
if not cfg.resume:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
if cfg.resume:
else:
logging.info("load offline dataset")
dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
offline_dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id,
repo_id=cfg.dataset.dataset_repo_id,
local_files_only=True,
root=logger.log_dir / "dataset_offline",
root=dataset_offline_path,
)
logging.info("Convert to a offline replay buffer")
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.training.offline_buffer_capacity,
capacity=cfg.policy.offline_buffer_capacity,
)
return offline_replay_buffer
@ -215,16 +297,23 @@ def get_observation_features(
return observation_features, next_observation_features
def use_threads(cfg: DictConfig) -> bool:
return cfg.actor_learner_config.concurrency.learner == "threads"
def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.policy.concurrency["learner"] == "threads"
def start_learner_threads(
cfg: DictConfig,
logger: Logger,
out_dir: str,
cfg: TrainPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
) -> None:
"""
Start the learner threads for training.
Args:
cfg (TrainPipelineConfig): Training configuration
wandb_logger (WandBLogger | None): Logger for metrics
shutdown_event: Event to signal shutdown
"""
# Create multiprocessing queues
transition_queue = Queue()
interaction_message_queue = Queue()
@ -255,13 +344,12 @@ def start_learner_threads(
communication_process.start()
add_actor_information_and_train(
cfg,
logger,
out_dir,
shutdown_event,
transition_queue,
interaction_message_queue,
parameters_queue,
cfg=cfg,
wandb_logger=wandb_logger,
shutdown_event=shutdown_event,
transition_queue=transition_queue,
interaction_message_queue=interaction_message_queue,
parameters_queue=parameters_queue,
)
logging.info("[LEARNER] Training process stopped")
@ -286,7 +374,7 @@ def start_learner_server(
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
cfg: DictConfig,
cfg: TrainPipelineConfig,
):
if not use_threads(cfg):
# We need init logging for MP separataly
@ -298,11 +386,11 @@ def start_learner_server(
setup_process_handlers(False)
service = learner_service.LearnerService(
shutdown_event,
parameters_queue,
cfg.actor_learner_config.policy_parameters_push_frequency,
transition_queue,
interaction_message_queue,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
transition_queue=transition_queue,
interaction_message_queue=interaction_message_queue,
)
server = grpc.server(
@ -318,8 +406,8 @@ def start_learner_server(
server,
)
host = cfg.actor_learner_config.learner_host
port = cfg.actor_learner_config.learner_port
host = cfg.policy.actor_learner_config["learner_host"]
port = cfg.policy.actor_learner_config["learner_port"]
server.add_insecure_port(f"{host}:{port}")
server.start()
@ -385,9 +473,8 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
def add_actor_information_and_train(
cfg,
logger: Logger,
out_dir: str,
cfg: TrainPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
transition_queue: Queue,
interaction_message_queue: Queue,
@ -405,69 +492,60 @@ def add_actor_information_and_train(
- Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency.
**NOTE:**
- This function performs multiple responsibilities (data transfer, training, and logging).
It should ideally be split into smaller functions in the future.
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
significantly reduces performance. Instead, this function executes all operations in a single thread.
Args:
cfg: Configuration object containing hyperparameters.
device (str): The computing device (`"cpu"` or `"cuda"`).
logger (Logger): Logger instance for tracking training progress.
out_dir (str): The output directory for storing training checkpoints and logs.
cfg (TrainPipelineConfig): Configuration object containing hyperparameters.
wandb_logger (WandBLogger | None): Logger for tracking training progress.
shutdown_event (Event): Event to signal shutdown.
transition_queue (Queue): Queue for receiving transitions from the actor.
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
parameters_queue (Queue): Queue for sending policy parameters to the actor.
"""
device = get_safe_torch_device(cfg.device, log=True)
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
logging.info("Initializing policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy intance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need make sac policy
# Get checkpoint dir for resuming
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
# TODO(Adil): This don't work anymore !
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
cfg=cfg.policy,
# ds_meta=cfg.dataset,
env_cfg=cfg.env
)
# Update the policy config with the grad_clip_norm value from training config if it exists
clip_grad_norm_value = cfg.training.grad_clip_norm
clip_grad_norm_value:float = cfg.policy.grad_clip_norm
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
policy.train()
push_actor_policy_to_queue(parameters_queue, policy)
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg, out_dir, policy)
log_training_info(cfg=cfg, policy= policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
batch_size = cfg.training.batch_size
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
batch_size = cfg.batch_size
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
if cfg.dataset is not None:
active_action_dims = None
# TODO: FIX THIS
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
]
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
logger=logger,
device=device,
storage_device=storage_device,
active_action_dims=active_action_dims,
@ -484,18 +562,22 @@ def add_actor_information_and_train(
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
# Extract variables from cfg
online_step_before_learning = cfg.training.online_step_before_learning
online_step_before_learning = cfg.policy.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio
dataset_repo_id = cfg.dataset_repo_id
fps = cfg.fps
log_freq = cfg.training.log_freq
save_freq = cfg.training.save_freq
device = cfg.device
storage_device = cfg.training.storage_device
policy_update_freq = cfg.training.policy_update_freq
policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency
save_checkpoint = cfg.training.save_checkpoint
online_steps = cfg.training.online_steps
dataset_repo_id = None
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
fps = cfg.env.fps
log_freq = cfg.log_freq
save_freq = cfg.save_freq
device = cfg.policy.device
storage_device = cfg.policy.storage_device
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
save_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
while True:
if shutdown_event is not None and shutdown_event.is_set():
@ -516,7 +598,7 @@ def add_actor_information_and_train(
continue
replay_buffer.add(**transition)
if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get(
if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
"is_intervention"
):
offline_replay_buffer.add(**transition)
@ -528,7 +610,17 @@ def add_actor_information_and_train(
interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
# Log interaction messages with WandB if available
if wandb_logger:
wandb_logger.log_dict(
d=interaction_message,
mode="train",
custom_step_key="Interaction step"
)
else:
# Log to console if no WandB logger
logging.info(f"Interaction: {interaction_message}")
logging.debug("[LEARNER] Received interactions")
@ -538,11 +630,11 @@ def add_actor_information_and_train(
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
batch = replay_buffer.sample(batch_size=batch_size)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)
actions = batch["action"]
rewards = batch["reward"]
@ -552,7 +644,7 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
policy=policy, observations=observations, next_observations=next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
@ -568,15 +660,15 @@ def add_actor_information_and_train(
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
batch = replay_buffer.sample(batch_size)
batch = replay_buffer.sample(batch_size=batch_size)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
@ -590,7 +682,7 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
policy=policy, observations=observations, next_observations=next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
@ -606,7 +698,7 @@ def add_actor_information_and_train(
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
@ -627,7 +719,7 @@ def add_actor_information_and_train(
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.actor.parameters_to_optimize, clip_grad_norm_value
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
@ -645,7 +737,7 @@ def add_actor_information_and_train(
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
[policy.log_alpha], clip_grad_norm_value
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
@ -655,7 +747,7 @@ def add_actor_information_and_train(
training_infos["temperature"] = policy.temperature
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue, policy)
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
policy.update_target_networks()
@ -665,15 +757,26 @@ def add_actor_information_and_train(
if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
training_infos["Optimization step"] = optimization_step
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
# logging.info(f"Training infos: {training_infos}")
# Log training metrics
if wandb_logger:
wandb_logger.log_dict(
d=training_infos,
mode="train",
custom_step_key="Optimization step"
)
else:
# Log to console if no WandB logger
logging.info(f"Training: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logger.log_dict(
# Log optimization frequency
if wandb_logger:
wandb_logger.log_dict(
{
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
"Optimization step": optimization_step,
@ -693,35 +796,45 @@ def add_actor_information_and_train(
interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0
)
logger.save_checkpoint(
# Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
# Save checkpoint
save_checkpoint(
checkpoint_dir,
optimization_step,
cfg,
policy,
optimizers,
scheduler=None,
identifier=step_identifier,
interaction_step=interaction_step,
scheduler=None
)
# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)
# TODO : temporarly save replay buffer here, remove later when on the robot
# We want to control this with the keyboard inputs
dataset_dir = logger.log_dir / "dataset"
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
)
replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset")
if offline_replay_buffer is not None:
dataset_dir = logger.log_dir / "dataset_offline"
dataset_dir = os.path.join(cfg.output_dir, "dataset")
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
shutil.rmtree(dataset_dir)
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
# Save dataset
replay_buffer.to_lerobot_dataset(
dataset_repo_id,
fps=fps,
root=dataset_dir
)
if offline_replay_buffer is not None:
dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
shutil.rmtree(dataset_offline_dir)
offline_replay_buffer.to_lerobot_dataset(
cfg.dataset_repo_id,
fps=cfg.fps,
root=logger.log_dir / "dataset_offline",
cfg.dataset.dataset_repo_id,
fps=cfg.env.fps,
root=dataset_offline_dir,
)
logging.info("Resume training")
@ -756,12 +869,12 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize,
lr=policy.config.actor_lr,
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@ -771,19 +884,38 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
return optimizers, lr_scheduler
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
"""
Main training function that initializes and runs the training process.
Args:
cfg (TrainPipelineConfig): The training configuration
job_name (str | None, optional): Job name for logging. Defaults to None.
"""
if cfg.output_dir is None:
raise ValueError("Output directory must be specified in config")
if job_name is None:
raise NotImplementedError()
job_name = cfg.job_name
if job_name is None:
raise ValueError("Job name must be specified either in config or as a parameter")
init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
logging.info(pformat(cfg.to_dict()))
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
cfg = handle_resume_logic(cfg, out_dir)
# Setup WandB logging if enabled
if cfg.wandb.enable and cfg.wandb.project:
from lerobot.common.utils.wandb_utils import WandBLogger
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
set_global_seed(cfg.seed)
# Handle resume logic
cfg = handle_resume_logic(cfg)
set_seed(seed=cfg.seed)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@ -791,24 +923,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
shutdown_event = setup_process_handlers(use_threads(cfg))
start_learner_threads(
cfg,
logger,
out_dir,
shutdown_event,
cfg=cfg,
wandb_logger=wandb_logger,
shutdown_event=shutdown_event,
)
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
@parser.wrap()
def train_cli(cfg: TrainPipelineConfig):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
# Use the job_name from the config
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
job_name=cfg.job_name,
)
logging.info("[LEARNER] train_cli finished")
@ -816,5 +947,4 @@ def train_cli(cfg: dict):
if __name__ == "__main__":
train_cli()
logging.info("[LEARNER] main finished")

View File

@ -1,4 +1,7 @@
from typing import Any
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import einops
import gymnasium as gym
@ -6,7 +9,11 @@ import numpy as np
import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from omegaconf import DictConfig
from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
from lerobot.configs import parser
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
def preprocess_maniskill_observation(
@ -46,9 +53,14 @@ def preprocess_maniskill_observation(
return return_observations
class ManiSkillObservationWrapper(gym.ObservationWrapper):
def __init__(self, env, device: torch.device = "cuda"):
super().__init__(env)
if isinstance(device, str):
device = torch.device(device)
self.device = device
def observation(self, observation):
@ -108,76 +120,129 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
return obs, reward, terminated, truncated, info
class BatchCompatibleWrapper(gym.ObservationWrapper):
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
def __init__(self, env):
super().__init__(env)
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
for key in observation:
if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0)
if "state" in key and observation[key].dim() == 1:
observation[key] = observation[key].unsqueeze(0)
return observation
class TimeLimitWrapper(gym.Wrapper):
"""Adds a time limit to the environment based on fps and control_time."""
def __init__(self, env, control_time_s, fps):
super().__init__(env)
self.control_time_s = control_time_s
self.fps = fps
self.max_episode_steps = int(self.control_time_s * self.fps)
self.current_step = 0
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
self.current_step += 1
if self.current_step >= self.max_episode_steps:
terminated = True
return obs, reward, terminated, truncated, info
def reset(self, *, seed=None, options=None):
self.current_step = 0
return super().reset(seed=seed, options=options)
def make_maniskill(
cfg: DictConfig,
cfg: ManiskillEnvConfig,
n_envs: int | None = None,
) -> gym.Env:
"""
Factory function to create a ManiSkill environment with standard wrappers.
Args:
task: Name of the ManiSkill task
obs_mode: Observation mode (rgb, rgbd, etc)
control_mode: Control mode for the robot
render_mode: Rendering mode
sensor_configs: Camera sensor configurations
cfg: Configuration for the ManiSkill environment
n_envs: Number of parallel environments
Returns:
A wrapped ManiSkill environment
"""
env = gym.make(
cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
cfg.task,
obs_mode=cfg.obs_type,
control_mode=cfg.control_mode,
render_mode=cfg.render_mode,
sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
num_envs=n_envs,
)
if cfg.env.video_record.enabled:
# Add video recording if enabled
if cfg.video_record.enabled:
env = RecordEpisode(
env,
output_dir=cfg.env.video_record.record_dir,
output_dir=cfg.video_record.record_dir,
save_trajectory=True,
trajectory_name=cfg.env.video_record.trajectory_name,
trajectory_name=cfg.video_record.trajectory_name,
save_video=True,
video_fps=30,
)
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
# Add observation and image processing
env = ManiSkillObservationWrapper(env, device=cfg.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
env.unwrapped.metadata["render_fps"] = cfg.fps
# Add compatibility wrappers
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
return env
if __name__ == "__main__":
import argparse
@parser.wrap()
def main(cfg: ManiskillEnvConfig):
"""Main function to run the ManiSkill environment."""
# Create the ManiSkill environment
env = make_maniskill(cfg, n_envs=1)
import hydra
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
args = parser.parse_args()
# Initialize config
with hydra.initialize(version_base=None, config_path="../../configs"):
cfg = hydra.compose(config_name="env/maniskill_example.yaml")
env = make_maniskill(
task=cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size},
)
print("env done")
# Reset the environment
obs, info = env.reset()
random_action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(random_action)
# Run a simple interaction loop
sum_reward = 0
for i in range(100):
# Sample a random action
action = env.action_space.sample()
# Step the environment
start_time = time.perf_counter()
obs, reward, terminated, truncated, info = env.step(action)
step_time = time.perf_counter() - start_time
sum_reward += reward
# Log information
# Reset if episode terminated
if terminated or truncated:
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
sum_reward = 0
obs, info = env.reset()
# Close the environment
env.close()
# if __name__ == "__main__":
# logging.basicConfig(level=logging.INFO)
# main()
if __name__ == "__main__":
import draccus
config = ManiskillEnvConfig()
draccus.set_config_type("json")
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )