[WIP] Non-functional yet

Add ManiSkill environment configuration and wrappers

- Introduced `VideoRecordConfig` for video recording settings.
- Added `ManiskillEnvConfig` to encapsulate environment-specific configurations.
- Implemented various wrappers for the ManiSkill environment, including observation preprocessing and action scaling.
- Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing (see the sketch below).
- Updated the `actor_server` and `learner_server` to utilize the new configuration structure.
- Refactored the training pipeline to accommodate the new environment and policy configurations.
AdilZouitine 2025-03-26 08:15:05 +00:00
parent b7b6d8102f
commit dd37bd412e
9 changed files with 667 additions and 436 deletions
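
For orientation, here is a minimal sketch (not part of the commit) of how the new pieces are meant to fit together. The import path and the `make_maniskill(cfg, n_envs)` call mirror usages that appear in the hunks below; the video-recording wiring is inferred from the commit message and may differ from the actual wrapper behavior.

```python
# Hypothetical usage sketch; names mirror this diff, exact behavior may differ.
from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill

cfg = ManiskillEnvConfig(
    task="PushCube-v1",
    image_size=64,
    control_mode="pd_ee_delta_pose",
)
cfg.video_record.enabled = True          # nested VideoRecordConfig
cfg.video_record.record_dir = "videos"

env = make_maniskill(cfg=cfg, n_envs=1)  # wrapped env: video recording + observation processing
obs, info = env.reset()
```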

View File

@ -154,3 +154,61 @@ class XarmEnv(EnvConfig):
"visualization_height": self.visualization_height, "visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length, "max_episode_steps": self.episode_length,
} }
@dataclass
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
@EnvConfig.register_subclass("maniskill_push")
@dataclass
class ManiskillEnvConfig(EnvConfig):
"""Configuration for the ManiSkill environment."""
name: str = "maniskill/pushcube"
task: str = "PushCube-v1"
image_size: int = 64
control_mode: str = "pd_ee_delta_pose"
state_dim: int = 25
action_dim: int = 7
fps: int = 400
episode_length: int = 50
obs_type: str = "rgb"
render_mode: str = "rgb_array"
render_size: int = 64
device: str = "cuda"
robot: str = "so100" # This is a hack to make the robot config work
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(25,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_ROBOT,
}
)
reward_classifier: dict[str, str | None] = field(
default_factory=lambda: {
"pretrained_path": None,
"config_path": None,
}
)
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"max_episode_steps": self.episode_length,
"control_mode": self.control_mode,
"sensor_configs": {"width": self.image_size, "height": self.image_size},
"num_envs": 1,
}
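
For reference, the registered config can be instantiated directly; its `gym_kwargs` property collects everything a factory would forward when constructing the ManiSkill task. The final `gym.make` call is an assumption and is not shown in this hunk.

```python
from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig  # import path as used later in this diff

cfg = ManiskillEnvConfig()  # defaults: PushCube-v1, 64x64 RGB, pd_ee_delta_pose
print(cfg.gym_kwargs)
# {'obs_type': 'rgb', 'render_mode': 'rgb_array', 'max_episode_steps': 50,
#  'control_mode': 'pd_ee_delta_pose',
#  'sensor_configs': {'width': 64, 'height': 64}, 'num_envs': 1}

# A factory is then expected to unpack these, roughly along the lines of
# (assumed; key names may need remapping for ManiSkill's gym.make):
# env = gym.make(cfg.task, **cfg.gym_kwargs)
```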

View File

@ -69,88 +69,3 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
return env return env
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
env = gym.make(
cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
num_envs=n_envs,
)
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
env = ManiSkillVectorEnv(env, ignore_terminations=True)
# state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
return env
class PixelWrapper(gym.Wrapper):
"""
Wrapper for pixel observations. Works with Maniskill vectorized environments
"""
def __init__(self, cfg, env, num_envs, num_frames=3):
super().__init__(env)
self.cfg = cfg
self.env = env
self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
dtype=np.uint8,
)
self._frames = deque([], maxlen=num_frames)
self._render_size = cfg.env.render_size
def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame)
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
def reset(self, seed):
obs, info = self.env.reset() # (seed=seed)
for _ in range(self._frames.maxlen):
obs_frames = self._get_obs(obs)
return obs_frames, info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
# TODO: Remove this
class ConvertToLeRobotEnv(gym.Wrapper):
def __init__(self, env, num_envs):
super().__init__(env)
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options={})
return self._get_obs(obs), info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
def _get_obs(self, observation):
sensor_data = observation.pop("sensor_data")
del observation["sensor_param"]
images = []
for cam_data in sensor_data.values():
images.append(cam_data["rgb"])
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
ret = dict()
ret["state"] = observation
ret["pixels"] = images
return ret

View File

@ -31,12 +31,19 @@ class SACConfig(PreTrainedConfig):
Args: Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy. n_obs_steps: Number of environment steps worth of observations to pass to the policy.
normalization_mapping: Mapping from feature types to normalization modes. normalization_mapping: Mapping from feature types to normalization modes.
dataset_stats: Statistics for normalizing different data types.
camera_number: Number of cameras to use. camera_number: Number of cameras to use.
device: Device to use for training.
storage_device: Device to use for storage. storage_device: Device to use for storage.
vision_encoder_name: Name of the vision encoder to use. vision_encoder_name: Name of the vision encoder to use.
freeze_vision_encoder: Whether to freeze the vision encoder. freeze_vision_encoder: Whether to freeze the vision encoder.
image_encoder_hidden_dim: Hidden dimension for the image encoder. image_encoder_hidden_dim: Hidden dimension for the image encoder.
shared_encoder: Whether to use a shared encoder. shared_encoder: Whether to use a shared encoder.
online_steps: Total number of online training steps.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
online_step_before_learning: Number of steps to collect before starting learning.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the RL algorithm. discount: Discount factor for the RL algorithm.
temperature_init: Initial temperature for entropy regularization. temperature_init: Initial temperature for entropy regularization.
num_critics: Number of critic networks. num_critics: Number of critic networks.
@ -54,6 +61,8 @@ class SACConfig(PreTrainedConfig):
critic_network_kwargs: Additional arguments for critic networks. critic_network_kwargs: Additional arguments for critic networks.
actor_network_kwargs: Additional arguments for actor network. actor_network_kwargs: Additional arguments for actor network.
policy_kwargs: Additional arguments for policy. policy_kwargs: Additional arguments for policy.
actor_learner_config: Configuration for actor-learner communication.
concurrency: Configuration for concurrency model.
""" """
# Input / output structure # Input / output structure
@ -86,6 +95,7 @@ class SACConfig(PreTrainedConfig):
# Architecture specifics # Architecture specifics
camera_number: int = 1 camera_number: int = 1
device: str = "cuda"
storage_device: str = "cpu" storage_device: str = "cpu"
# Set to "helper2424/resnet10" for hil serl # Set to "helper2424/resnet10" for hil serl
vision_encoder_name: str | None = None vision_encoder_name: str | None = None
@ -93,6 +103,13 @@ class SACConfig(PreTrainedConfig):
image_encoder_hidden_dim: int = 32 image_encoder_hidden_dim: int = 32
shared_encoder: bool = True shared_encoder: bool = True
# Training parameter
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 10000
online_step_before_learning: int = 100
policy_update_freq: int = 1
# SAC algorithm parameters # SAC algorithm parameters
discount: float = 0.99 discount: float = 0.99
temperature_init: float = 1.0 temperature_init: float = 1.0
@ -132,11 +149,17 @@ class SACConfig(PreTrainedConfig):
} }
) )
# Deprecated, kept for backward compatibility
actor_learner_config: dict[str, str | int] = field( actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: { default_factory=lambda: {
"learner_host": "127.0.0.1", "learner_host": "127.0.0.1",
"learner_port": 50051, "learner_port": 50051,
"policy_parameters_push_frequency": 4,
}
)
concurrency: dict[str, str] = field(
default_factory=lambda: {
"actor": "threads",
"learner": "threads"
} }
) )
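
The actor and learner servers later in this diff read the two new dictionaries directly; a short sketch of the access pattern (default `SACConfig` construction is assumed to work for illustration):

```python
from lerobot.common.policies.sac.modeling_sac import SACConfig  # as imported in the learner_server hunk

policy_cfg = SACConfig()

host = policy_cfg.actor_learner_config["learner_host"]       # "127.0.0.1"
port = policy_cfg.actor_learner_config["learner_port"]       # 50051
push_freq = policy_cfg.actor_learner_config["policy_parameters_push_frequency"]  # 4

run_actor_in_threads = policy_cfg.concurrency["actor"] == "threads"
run_learner_in_threads = policy_cfg.concurrency["learner"] == "threads"
```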

View File

@ -92,6 +92,8 @@ class WandBLogger:
resume="must" if cfg.resume else None, resume="must" if cfg.resume else None,
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online", mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
) )
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb self._wandb = wandb
@ -108,9 +110,24 @@ class WandBLogger:
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact) self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"): def log_dict(self, d: dict, step: int, mode: str = "train", custom_step_key: str | None = None):
if mode not in {"train", "eval"}: if mode not in {"train", "eval"}:
raise ValueError(mode) raise ValueError(mode)
if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.")
# NOTE: This is not straightforward. The wandb step must increase monotonically, and it is
# incremented with every wandb.log call. In asynchronous RL, however, several step counters
# coexist (the environment interaction step, the training step, the evaluation step, etc.),
# so we define a custom step key to log the correct step for each metric.
if custom_step_key is not None:
if self._wandb_custom_step_key is None:
self._wandb_custom_step_key = set()
new_custom_key = f"{mode}/{custom_step_key}"
if new_custom_key not in self._wandb_custom_step_key:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items(): for k, v in d.items():
if not isinstance(v, (int, float, str)): if not isinstance(v, (int, float, str)):
@ -118,7 +135,26 @@ class WandBLogger:
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
) )
continue continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
# Do not log the custom step key itself.
if (
self._wandb_custom_step_key is not None
and k in self._wandb_custom_step_key
):
continue
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
self._wandb.log(
{
f"{mode}/{k}": v,
f"{mode}/{custom_step_key}": value_custom_step,
}
)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
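
A condensed sketch of the custom-step pattern used in `log_dict`, stripped of the wrapper (the project name is illustrative; `wandb.define_metric(..., hidden=True)` and `wandb.log` are standard wandb calls):

```python
import wandb

wandb.init(project="hil-serl-example")  # illustrative project name

# Register the custom step key once and hide it from the default charts,
# mirroring what log_dict does the first time it sees a custom_step_key.
wandb.define_metric("train/Interaction step", hidden=True)

# Log metrics together with the custom step instead of the global wandb step,
# so asynchronous actors and learners can each keep their own counter.
wandb.log({
    "train/reward": 1.0,
    "train/Interaction step": 1234,
})
```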
def log_video(self, video_path: str, step: int, mode: str = "train"): def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}: if mode not in {"train", "eval"}:

View File

@ -34,11 +34,10 @@ TRAIN_CONFIG_NAME = "train_config.json"
@dataclass @dataclass
class TrainPipelineConfig(HubMixin): class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
env: envs.EnvConfig | None = None env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None policy: PreTrainedConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None output_dir: Path | None = None
job_name: str | None = None job_name: str | None = None
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure

View File

@ -21,26 +21,23 @@ from statistics import mean, quantiles
# from lerobot.scripts.eval import eval_policy # from lerobot.scripts.eval import eval_policy
import grpc import grpc
import hydra
import torch import torch
from omegaconf import DictConfig
from torch import nn from torch import nn
from torch.multiprocessing import Event, Queue from torch.multiprocessing import Event, Queue
# TODO: Remove the import of maniskill # TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env
# from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.robots.utils import Robot, make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
TimerManager, TimerManager,
get_safe_torch_device, get_safe_torch_device,
init_logging, init_logging,
set_global_seed,
) )
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import ( from lerobot.scripts.server.buffer import (
Transition, Transition,
@ -61,7 +58,7 @@ ACTOR_SHUTDOWN_TIMEOUT = 30
def receive_policy( def receive_policy(
cfg: DictConfig, cfg: TrainPipelineConfig,
parameters_queue: Queue, parameters_queue: Queue,
shutdown_event: any, # Event, shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
@ -72,12 +69,12 @@ def receive_policy(
if not use_threads(cfg): if not use_threads(cfg):
# Setup process handlers to handle shutdown signal # Setup process handlers to handle shutdown signal
# But use shutdown event from the main process # But use shutdown event from the main process
setup_process_handlers(False) setup_process_handlers(use_threads=False)
if grpc_channel is None or learner_client is None: if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client( learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host, host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.actor_learner_config.learner_port, port=cfg.policy.actor_learner_config["learner_port"],
) )
try: try:
@ -132,7 +129,7 @@ def interactions_stream(
def send_transitions( def send_transitions(
cfg: DictConfig, cfg: TrainPipelineConfig,
transitions_queue: Queue, transitions_queue: Queue,
shutdown_event: any, # Event, shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
@ -156,8 +153,8 @@ def send_transitions(
if grpc_channel is None or learner_client is None: if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client( learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host, host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.actor_learner_config.learner_port, port=cfg.policy.actor_learner_config["learner_port"],
) )
try: try:
@ -173,7 +170,7 @@ def send_transitions(
def send_interactions( def send_interactions(
cfg: DictConfig, cfg: TrainPipelineConfig,
interactions_queue: Queue, interactions_queue: Queue,
shutdown_event: any, # Event, shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
@ -196,8 +193,8 @@ def send_interactions(
if grpc_channel is None or learner_client is None: if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client( learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host, host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.actor_learner_config.learner_port, port=cfg.policy.actor_learner_config["learner_port"],
) )
try: try:
@ -269,7 +266,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
def act_with_policy( def act_with_policy(
cfg: DictConfig, cfg: TrainPipelineConfig,
robot: Robot, robot: Robot,
reward_classifier: nn.Module, reward_classifier: nn.Module,
shutdown_event: any, # Event, shutdown_event: any, # Event,
@ -291,7 +288,7 @@ def act_with_policy(
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg) online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
set_global_seed(cfg.seed) set_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True) device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@ -304,7 +301,7 @@ def act_with_policy(
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need make sac policy # TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy( policy: SACPolicy = make_policy(
hydra_cfg=cfg, cfg=cfg.policy,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats # Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None, dataset_stats=None,
@ -469,7 +466,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
return stats return stats
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int): def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
if policy_fps < cfg.fps: if policy_fps < cfg.fps:
logging.warning( logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}" f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
@ -497,25 +494,25 @@ def establish_learner_connection(
return False return False
def use_threads(cfg: DictConfig) -> bool: def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.actor_learner_config.concurrency.actor == "threads" return cfg.policy.concurrency["actor"] == "threads"
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs") @parser.wrap()
def actor_cli(cfg: dict): def actor_cli(cfg: TrainPipelineConfig):
if not use_threads(cfg): if not use_threads(cfg):
import torch.multiprocessing as mp import torch.multiprocessing as mp
mp.set_start_method("spawn") mp.set_start_method("spawn")
init_logging(log_file="actor.log") init_logging(log_file="actor.log")
robot = make_robot(cfg=cfg.robot) robot = make_robot(robot_type=cfg.env.robot)
shutdown_event = setup_process_handlers(use_threads(cfg)) shutdown_event = setup_process_handlers(use_threads(cfg))
learner_client, grpc_channel = learner_service_client( learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host, host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.actor_learner_config.learner_port, port=cfg.policy.actor_learner_config["learner_port"],
) )
logging.info("[ACTOR] Establishing connection with Learner") logging.info("[ACTOR] Establishing connection with Learner")
@ -570,22 +567,22 @@ def actor_cli(cfg: dict):
# TODO: Remove this once we merge into main # TODO: Remove this once we merge into main
reward_classifier = None reward_classifier = None
if ( if (
cfg.env.reward_classifier.pretrained_path is not None cfg.env.reward_classifier["pretrained_path"] is not None
and cfg.env.reward_classifier.config_path is not None and cfg.env.reward_classifier["config_path"] is not None
): ):
reward_classifier = get_classifier( reward_classifier = get_classifier(
pretrained_path=cfg.env.reward_classifier.pretrained_path, pretrained_path=cfg.env.reward_classifier["pretrained_path"],
config_path=cfg.env.reward_classifier.config_path, config_path=cfg.env.reward_classifier["config_path"],
) )
act_with_policy( act_with_policy(
cfg, cfg=cfg,
robot, robot=robot,
reward_classifier, reward_classifier=reward_classifier,
shutdown_event, shutdown_event=shutdown_event,
parameters_queue, parameters_queue=parameters_queue,
transitions_queue, transitions_queue=transitions_queue,
interactions_queue, interactions_queue=interactions_queue,
) )
logging.info("[ACTOR] Policy process joined") logging.info("[ACTOR] Policy process joined")

View File

@ -15,10 +15,12 @@ import json
from dataclasses import dataclass from dataclasses import dataclass
from lerobot.common.envs.utils import preprocess_observation from lerobot.common.envs.utils import preprocess_observation
from lerobot.configs.train import TrainPipelineConfig
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.robot_devices.control_utils import ( from lerobot.common.robot_devices.control_utils import (
busy_wait, busy_wait,
is_headless, is_headless,
reset_follower_position, # reset_follower_position,
) )
from typing import Optional from typing import Optional
@ -28,6 +30,7 @@ from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.robot_devices.robots.configs import RobotConfig from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.scripts.server.kinematics import RobotKinematics from lerobot.scripts.server.kinematics import RobotKinematics
from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill
from lerobot.configs import parser from lerobot.configs import parser
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -1094,7 +1097,10 @@ class ActionScaleWrapper(gym.ActionWrapper):
return action * self.scale_vector, is_intervention return action * self.scale_vector, is_intervention
def make_robot_env(cfg, robot) -> gym.vector.VectorEnv: @parser.wrap()
def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:
""" """
Factory function to create a vectorized robot environment. Factory function to create a vectorized robot environment.
@ -1106,7 +1112,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
Returns: Returns:
A vectorized gym environment with all the necessary wrappers applied. A vectorized gym environment with all the necessary wrappers applied.
""" """
if "maniskill" in cfg.env_name: if "maniskill" in cfg.name:
from lerobot.scripts.server.maniskill_manipulator import make_maniskill from lerobot.scripts.server.maniskill_manipulator import make_maniskill
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN") logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
@ -1115,7 +1121,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
n_envs=1, n_envs=1,
) )
return env return env
robot = cfg.robot
# Create base environment # Create base environment
env = HILSerlRobotEnv( env = HILSerlRobotEnv(
robot=robot, robot=robot,
@ -1329,80 +1335,82 @@ def replay_episode(env, repo_id, root=None, episode=0):
busy_wait(1 / 10 - dt_s) busy_wait(1 / 10 - dt_s)
@parser.wrap() # @parser.wrap()
def main(cfg: HILSerlRobotEnvConfig): # def main(cfg):
robot = make_robot_from_config(cfg.robot) # robot = make_robot_from_config(cfg.robot)
reward_classifier = None #get_classifier( # reward_classifier = None #get_classifier(
# cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file # # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
# ) # # )
user_relative_joint_positions = True # user_relative_joint_positions = True
env = make_robot_env(cfg, robot) # env = make_robot_env(cfg, robot)
if cfg.mode == "record": # if cfg.mode == "record":
policy = None # policy = None
if cfg.pretrained_policy_name_or_path is not None: # if cfg.pretrained_policy_name_or_path is not None:
from lerobot.common.policies.sac.modeling_sac import SACPolicy # from lerobot.common.policies.sac.modeling_sac import SACPolicy
policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) # policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
policy.to(cfg.device) # policy.to(cfg.device)
policy.eval() # policy.eval()
record_dataset( # record_dataset(
env, # env,
cfg.repo_id, # cfg.repo_id,
root=cfg.dataset_root, # root=cfg.dataset_root,
num_episodes=cfg.num_episodes, # num_episodes=cfg.num_episodes,
fps=cfg.fps, # fps=cfg.fps,
task_description=cfg.task, # task_description=cfg.task,
policy=policy, # policy=policy,
) # )
exit() # exit()
if cfg.mode == "replay": # if cfg.mode == "replay":
replay_episode( # replay_episode(
env, # env,
cfg.replay_repo_id, # cfg.replay_repo_id,
root=cfg.dataset_root, # root=cfg.dataset_root,
episode=cfg.replay_episode, # episode=cfg.replay_episode,
) # )
exit() # exit()
env.reset() # env.reset()
# Retrieve the robot's action space for joint commands. # # Retrieve the robot's action space for joint commands.
action_space_robot = env.action_space.spaces[0] # action_space_robot = env.action_space.spaces[0]
# Initialize the smoothed action as a random sample. # # Initialize the smoothed action as a random sample.
smoothed_action = action_space_robot.sample() # smoothed_action = action_space_robot.sample()
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in. # # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. # # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
alpha = 1.0 # alpha = 1.0
num_episode = 0 # num_episode = 0
sucesses = [] # sucesses = []
while num_episode < 20: # while num_episode < 20:
start_loop_s = time.perf_counter() # start_loop_s = time.perf_counter()
# Sample a new random action from the robot's action space. # # Sample a new random action from the robot's action space.
new_random_action = action_space_robot.sample() # new_random_action = action_space_robot.sample()
# Update the smoothed action using an exponential moving average. # # Update the smoothed action using an exponential moving average.
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action # smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
# Execute the step: wrap the NumPy action in a torch tensor. # # Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False)) # obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
if terminated or truncated: # if terminated or truncated:
sucesses.append(reward) # sucesses.append(reward)
env.reset() # env.reset()
num_episode += 1 # num_episode += 1
dt_s = time.perf_counter() - start_loop_s # dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / cfg.fps - dt_s) # busy_wait(1 / cfg.fps - dt_s)
logging.info(f"Success after 20 steps {sucesses}") # logging.info(f"Success after 20 steps {sucesses}")
logging.info(f"success rate {sum(sucesses) / len(sucesses)}") # logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
# if __name__ == "__main__":
# main()
if __name__ == "__main__": if __name__ == "__main__":
main() make_robot_env()

View File

@ -19,40 +19,45 @@ import shutil
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pprint import pformat from pprint import pformat
import os
from pathlib import Path
import draccus
import grpc import grpc
# Import generated stubs # Import generated stubs
import hilserl_pb2_grpc # type: ignore import hilserl_pb2_grpc # type: ignore
import hydra
import torch import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored from termcolor import colored
from torch import nn from torch import nn
# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs import parser
# TODO: Remove the import of maniskill # TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,
get_global_random_state,
get_safe_torch_device, get_safe_torch_device,
init_hydra_config,
init_logging, init_logging,
set_global_random_state,
set_global_seed,
) )
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.scripts.server import learner_service from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import ( from lerobot.scripts.server.buffer import (
ReplayBuffer, ReplayBuffer,
@ -64,102 +69,167 @@ from lerobot.scripts.server.buffer import (
state_to_bytes, state_to_bytes,
) )
from lerobot.scripts.server.utils import setup_process_handlers from lerobot.scripts.server.utils import setup_process_handlers
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig: def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
"""
Handle the resume logic for training.
If resume is True:
- Verifies that a checkpoint exists
- Loads the checkpoint configuration
- Logs resumption details
- Returns the checkpoint configuration
If resume is False:
- Checks if an output directory exists (to prevent accidental overwriting)
- Returns the original configuration
Args:
cfg (TrainPipelineConfig): The training configuration
Returns:
TrainPipelineConfig: The updated configuration
Raises:
RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
"""
out_dir = cfg.output_dir
# Case 1: Not resuming, but need to check if directory exists to prevent overwrites
if not cfg.resume: if not cfg.resume:
if Logger.get_last_checkpoint_dir(out_dir).exists(): checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if os.path.exists(checkpoint_dir):
raise RuntimeError( raise RuntimeError(
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. " f"Output directory {checkpoint_dir} already exists. "
"Use `resume=true` to resume training." "Use `resume=true` to resume training."
) )
return cfg return cfg
# if resume == True # Case 2: Resuming training
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir) checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if not checkpoint_dir.exists(): if not os.path.exists(checkpoint_dir):
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True") raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") # Log that we found a valid checkpoint and are resuming
logging.info( logging.info(
colored( colored(
"Resume=True detected, resuming previous run", "Valid checkpoint found: resume=True detected, resuming previous run",
color="yellow", color="yellow",
attrs=["bold"], attrs=["bold"],
) )
) )
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path) # Load config using Draccus
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)) checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
if len(diff) > 0:
logging.warning(
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
"Checkpoint configuration takes precedence."
)
# Ensure resume flag is set in returned config
checkpoint_cfg.resume = True checkpoint_cfg.resume = True
return checkpoint_cfg return checkpoint_cfg
def load_training_state( def load_training_state(
cfg: DictConfig, cfg: TrainPipelineConfig,
logger: Logger, optimizers: Optimizer | dict[str, Optimizer],
optimizers: Optimizer | dict,
): ):
"""
Loads the training state (optimizers, step count, etc.) from a checkpoint.
Args:
cfg (TrainPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into
Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
"""
if not cfg.resume: if not cfg.resume:
return None, None return None, None
training_state = torch.load( # Construct path to the last checkpoint directory
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
)
if isinstance(training_state["optimizer"], dict): logging.info(f"Loading training state from {checkpoint_dir}")
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
for k, v in training_state["optimizer"].items():
optimizers[k].load_state_dict(v)
else:
optimizers.load_state_dict(training_state["optimizer"])
set_global_random_state({k: training_state[k] for k in get_global_random_state()}) try:
return training_state["step"], training_state["interaction_step"] # Use the utility function from train_utils which loads the optimizer state
# The function returns (step, updated_optimizer, scheduler)
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# For interaction step, we still need to load the training_state.pt file
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
training_state = torch.load(training_state_path, weights_only=False)
interaction_step = training_state.get("interaction_step", 0)
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step
except Exception as e:
logging.error(f"Failed to load training state: {e}")
return None, None
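
Put together, the two functions above assume the following checkpoint paths (the constants come from `lerobot.common.constants`; the concrete directory names are whatever those constants hold):

```python
import os

from lerobot.common.constants import (
    CHECKPOINTS_DIR,
    LAST_CHECKPOINT_LINK,
    PRETRAINED_MODEL_DIR,
    TRAINING_STATE_DIR,
)

out_dir = "outputs/train/sac_maniskill"  # illustrative output_dir
last_checkpoint = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
train_config_path = os.path.join(last_checkpoint, PRETRAINED_MODEL_DIR, "train_config.json")
training_state_path = os.path.join(last_checkpoint, TRAINING_STATE_DIR, "training_state.pt")
# training_state.pt is expected to contain at least {"interaction_step": ...};
# the optimizer state itself is restored via utils_load_training_state.
```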
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None: def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
"""
Log information about the training process.
Args:
cfg (TrainPipelineConfig): Training configuration
policy (nn.Module): Policy model
"""
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.online_steps=}") logging.info(f"{cfg.policy.online_steps=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer( def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str, storage_device: str cfg: TrainPipelineConfig,
device: str,
storage_device: str
) -> ReplayBuffer: ) -> ReplayBuffer:
"""
Initialize a replay buffer, either empty or from a dataset if resuming.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
Returns:
ReplayBuffer: Initialized replay buffer
"""
if not cfg.resume: if not cfg.resume:
return ReplayBuffer( return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity, capacity=cfg.policy.online_buffer_capacity,
device=device, device=device,
state_keys=cfg.policy.input_shapes.keys(), state_keys=cfg.policy.input_features.keys(),
storage_device=storage_device, storage_device=storage_device,
optimize_memory=True, optimize_memory=True,
) )
logging.info("Resume training load the online dataset") logging.info("Resume training load the online dataset")
dataset_path = os.path.join(cfg.output_dir, "dataset")
dataset = LeRobotDataset( dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id, repo_id=cfg.dataset.dataset_repo_id,
local_files_only=True, local_files_only=True,
root=logger.log_dir / "dataset", root=dataset_path,
) )
return ReplayBuffer.from_lerobot_dataset( return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset, lerobot_dataset=dataset,
capacity=cfg.training.online_buffer_capacity, capacity=cfg.policy.online_buffer_capacity,
device=device, device=device,
state_keys=cfg.policy.input_shapes.keys(), state_keys=cfg.policy.input_shapes.keys(),
optimize_memory=True, optimize_memory=True,
@ -167,33 +237,45 @@ def initialize_replay_buffer(
def initialize_offline_replay_buffer( def initialize_offline_replay_buffer(
cfg: DictConfig, cfg: TrainPipelineConfig,
logger: Logger,
device: str, device: str,
storage_device: str, storage_device: str,
active_action_dims: list[int] | None = None, active_action_dims: list[int] | None = None,
) -> ReplayBuffer: ) -> ReplayBuffer:
"""
Initialize an offline replay buffer from a dataset.
Args:
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking
Returns:
ReplayBuffer: Initialized offline replay buffer
"""
if not cfg.resume: if not cfg.resume:
logging.info("make_dataset offline buffer") logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg) offline_dataset = make_dataset(cfg)
if cfg.resume: else:
logging.info("load offline dataset") logging.info("load offline dataset")
dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
offline_dataset = LeRobotDataset( offline_dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id, repo_id=cfg.dataset.dataset_repo_id,
local_files_only=True, local_files_only=True,
root=logger.log_dir / "dataset_offline", root=dataset_offline_path,
) )
logging.info("Convert to a offline replay buffer") logging.info("Convert to a offline replay buffer")
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset, offline_dataset,
device=device, device=device,
state_keys=cfg.policy.input_shapes.keys(), state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims, action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action, action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device, storage_device=storage_device,
optimize_memory=True, optimize_memory=True,
capacity=cfg.training.offline_buffer_capacity, capacity=cfg.policy.offline_buffer_capacity,
) )
return offline_replay_buffer return offline_replay_buffer
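
Further down, the training loop mixes the online and offline buffers half-and-half per update. A condensed sketch of that sampling pattern (the import path for `concatenate_batch_transitions` is assumed from the buffer imports above):

```python
from lerobot.scripts.server.buffer import concatenate_batch_transitions  # assumed location


def sample_training_batch(replay_buffer, offline_replay_buffer, batch_size: int):
    """Mirror the sampling pattern of the training loop below: draw an online batch
    and, when an offline (demonstration/intervention) buffer exists, concatenate an
    offline batch of the same size before computing the SAC losses."""
    batch = replay_buffer.sample(batch_size=batch_size)
    if offline_replay_buffer is not None:
        batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
        batch = concatenate_batch_transitions(
            left_batch_transitions=batch, right_batch_transition=batch_offline
        )
    return batch
```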
@ -215,16 +297,23 @@ def get_observation_features(
return observation_features, next_observation_features return observation_features, next_observation_features
def use_threads(cfg: DictConfig) -> bool: def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.actor_learner_config.concurrency.learner == "threads" return cfg.policy.concurrency["learner"] == "threads"
def start_learner_threads( def start_learner_threads(
cfg: DictConfig, cfg: TrainPipelineConfig,
logger: Logger, wandb_logger: WandBLogger | None,
out_dir: str,
shutdown_event: any, # Event, shutdown_event: any, # Event,
) -> None: ) -> None:
"""
Start the learner threads for training.
Args:
cfg (TrainPipelineConfig): Training configuration
wandb_logger (WandBLogger | None): Logger for metrics
shutdown_event: Event to signal shutdown
"""
# Create multiprocessing queues # Create multiprocessing queues
transition_queue = Queue() transition_queue = Queue()
interaction_message_queue = Queue() interaction_message_queue = Queue()
@ -255,13 +344,12 @@ def start_learner_threads(
communication_process.start() communication_process.start()
add_actor_information_and_train( add_actor_information_and_train(
cfg, cfg=cfg,
logger, wandb_logger=wandb_logger,
out_dir, shutdown_event=shutdown_event,
shutdown_event, transition_queue=transition_queue,
transition_queue, interaction_message_queue=interaction_message_queue,
interaction_message_queue, parameters_queue=parameters_queue,
parameters_queue,
) )
logging.info("[LEARNER] Training process stopped") logging.info("[LEARNER] Training process stopped")
@ -286,7 +374,7 @@ def start_learner_server(
transition_queue: Queue, transition_queue: Queue,
interaction_message_queue: Queue, interaction_message_queue: Queue,
shutdown_event: any, # Event, shutdown_event: any, # Event,
cfg: DictConfig, cfg: TrainPipelineConfig,
): ):
if not use_threads(cfg): if not use_threads(cfg):
# We need to init logging for MP separately # We need to init logging for MP separately
@ -298,11 +386,11 @@ def start_learner_server(
setup_process_handlers(False) setup_process_handlers(False)
service = learner_service.LearnerService( service = learner_service.LearnerService(
shutdown_event, shutdown_event=shutdown_event,
parameters_queue, parameters_queue=parameters_queue,
cfg.actor_learner_config.policy_parameters_push_frequency, seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
transition_queue, transition_queue=transition_queue,
interaction_message_queue, interaction_message_queue=interaction_message_queue,
) )
server = grpc.server( server = grpc.server(
@ -318,8 +406,8 @@ def start_learner_server(
server, server,
) )
host = cfg.actor_learner_config.learner_host host = cfg.policy.actor_learner_config["learner_host"]
port = cfg.actor_learner_config.learner_port port = cfg.policy.actor_learner_config["learner_port"]
server.add_insecure_port(f"{host}:{port}") server.add_insecure_port(f"{host}:{port}")
server.start() server.start()
@ -385,9 +473,8 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
def add_actor_information_and_train( def add_actor_information_and_train(
cfg, cfg: TrainPipelineConfig,
logger: Logger, wandb_logger: WandBLogger | None,
out_dir: str,
shutdown_event: any, # Event, shutdown_event: any, # Event,
transition_queue: Queue, transition_queue: Queue,
interaction_message_queue: Queue, interaction_message_queue: Queue,
@ -405,69 +492,60 @@ def add_actor_information_and_train(
- Periodically updates the actor, critic, and temperature optimizers. - Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency. - Logs training statistics, including loss values and optimization frequency.
**NOTE:**
- This function performs multiple responsibilities (data transfer, training, and logging).
It should ideally be split into smaller functions in the future.
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
significantly reduces performance. Instead, this function executes all operations in a single thread.
Args: Args:
cfg: Configuration object containing hyperparameters. cfg (TrainPipelineConfig): Configuration object containing hyperparameters.
device (str): The computing device (`"cpu"` or `"cuda"`). wandb_logger (WandBLogger | None): Logger for tracking training progress.
logger (Logger): Logger instance for tracking training progress.
out_dir (str): The output directory for storing training checkpoints and logs.
shutdown_event (Event): Event to signal shutdown. shutdown_event (Event): Event to signal shutdown.
transition_queue (Queue): Queue for receiving transitions from the actor. transition_queue (Queue): Queue for receiving transitions from the actor.
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor. interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
parameters_queue (Queue): Queue for sending policy parameters to the actor. parameters_queue (Queue): Queue for sending policy parameters to the actor.
""" """
device = get_safe_torch_device(cfg.device, log=True) device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device) storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
logging.info("Initializing policy") logging.info("Initializing policy")
### Instantiate the policy in both the actor and learner processes # Get checkpoint dir for resuming
### To avoid sending a SACPolicy object through the port, we create a policy intance checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
# TODO: At some point we should just need make sac policy
# TODO(Adil): This doesn't work anymore!
policy: SACPolicy = make_policy( policy: SACPolicy = make_policy(
hydra_cfg=cfg, cfg=cfg.policy,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # ds_meta=cfg.dataset,
# Hack: But if we do online traning, we do not need dataset_stats env_cfg=cfg.env
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
) )
# Update the policy config with the grad_clip_norm value from training config if it exists # Update the policy config with the grad_clip_norm value from training config if it exists
clip_grad_norm_value = cfg.training.grad_clip_norm clip_grad_norm_value: float = cfg.policy.grad_clip_norm
# compile policy # compile policy
policy = torch.compile(policy) policy = torch.compile(policy)
assert isinstance(policy, nn.Module) assert isinstance(policy, nn.Module)
policy.train()
push_actor_policy_to_queue(parameters_queue, policy) push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time() last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy) optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers) resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg, out_dir, policy) log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device) replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
batch_size = cfg.training.batch_size batch_size = cfg.batch_size
offline_replay_buffer = None offline_replay_buffer = None
if cfg.dataset_repo_id is not None: if cfg.dataset is not None:
active_action_dims = None active_action_dims = None
# TODO: FIX THIS
if cfg.env.wrapper.joint_masking_action_space is not None: if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [ active_action_dims = [
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
] ]
offline_replay_buffer = initialize_offline_replay_buffer( offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg, cfg=cfg,
logger=logger,
device=device, device=device,
storage_device=storage_device, storage_device=storage_device,
active_action_dims=active_action_dims, active_action_dims=active_action_dims,
@ -484,18 +562,22 @@ def add_actor_information_and_train(
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
# Extract variables from cfg # Extract variables from cfg
online_step_before_learning = cfg.training.online_step_before_learning online_step_before_learning = cfg.policy.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio utd_ratio = cfg.policy.utd_ratio
dataset_repo_id = cfg.dataset_repo_id
fps = cfg.fps dataset_repo_id = None
log_freq = cfg.training.log_freq if cfg.dataset is not None:
save_freq = cfg.training.save_freq dataset_repo_id = cfg.dataset.repo_id
device = cfg.device
storage_device = cfg.training.storage_device fps = cfg.env.fps
policy_update_freq = cfg.training.policy_update_freq log_freq = cfg.log_freq
policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency save_freq = cfg.save_freq
save_checkpoint = cfg.training.save_checkpoint device = cfg.policy.device
online_steps = cfg.training.online_steps storage_device = cfg.policy.storage_device
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
save_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
while True: while True:
if shutdown_event is not None and shutdown_event.is_set(): if shutdown_event is not None and shutdown_event.is_set():
@ -516,7 +598,7 @@ def add_actor_information_and_train(
continue continue
replay_buffer.add(**transition) replay_buffer.add(**transition)
if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get( if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
"is_intervention" "is_intervention"
): ):
offline_replay_buffer.add(**transition) offline_replay_buffer.add(**transition)
@ -528,7 +610,17 @@ def add_actor_information_and_train(
interaction_message = bytes_to_python_object(interaction_message) interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
# Log interaction messages with WandB if available
if wandb_logger:
wandb_logger.log_dict(
d=interaction_message,
mode="train",
custom_step_key="Interaction step"
)
else:
# Log to console if no WandB logger
logging.info(f"Interaction: {interaction_message}")
logging.debug("[LEARNER] Received interactions") logging.debug("[LEARNER] Received interactions")
@ -538,11 +630,11 @@ def add_actor_information_and_train(
logging.debug("[LEARNER] Starting optimization loop") logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time() time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1): for _ in range(utd_ratio - 1):
batch = replay_buffer.sample(batch_size) batch = replay_buffer.sample(batch_size=batch_size)
if dataset_repo_id is not None: if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size) batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions(batch, batch_offline) batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)
actions = batch["action"] actions = batch["action"]
rewards = batch["reward"] rewards = batch["reward"]
@ -552,7 +644,7 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features( observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations policy=policy, observations=observations, next_observations=next_observations
) )
loss_critic = policy.compute_loss_critic( loss_critic = policy.compute_loss_critic(
observations=observations, observations=observations,
@ -568,15 +660,15 @@ def add_actor_information_and_train(
# clip gradients # clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_( critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
) )
optimizers["critic"].step() optimizers["critic"].step()
batch = replay_buffer.sample(batch_size) batch = replay_buffer.sample(batch_size=batch_size)
if dataset_repo_id is not None: if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size) batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions( batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline left_batch_transitions=batch, right_batch_transition=batch_offline
) )
@ -590,7 +682,7 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

observation_features, next_observation_features = get_observation_features(
    policy=policy, observations=observations, next_observations=next_observations
)
loss_critic = policy.compute_loss_critic(
    observations=observations,
@ -606,7 +698,7 @@ def add_actor_information_and_train(
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
    parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()

optimizers["critic"].step()
@ -627,7 +719,7 @@ def add_actor_information_and_train(
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
    parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()

optimizers["actor"].step()
@ -645,7 +737,7 @@ def add_actor_information_and_train(
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
    parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()

optimizers["temperature"].step()
@ -655,7 +747,7 @@ def add_actor_information_and_train(
training_infos["temperature"] = policy.temperature training_infos["temperature"] = policy.temperature
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue, policy) push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time() last_time_policy_pushed = time.time()
policy.update_target_networks() policy.update_target_networks()
@ -665,15 +757,26 @@ def add_actor_information_and_train(
if offline_replay_buffer is not None:
    training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
training_infos["Optimization step"] = optimization_step

# Log training metrics
if wandb_logger:
    wandb_logger.log_dict(
        d=training_infos,
        mode="train",
        custom_step_key="Optimization step"
    )
else:
    # Log to console if no WandB logger
    logging.info(f"Training: {training_infos}")

time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

# Log optimization frequency
if wandb_logger:
    wandb_logger.log_dict(
        {
            "Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
            "Optimization step": optimization_step,
@ -693,35 +796,45 @@ def add_actor_information_and_train(
interaction_step = (
    interaction_message["Interaction step"] if interaction_message is not None else 0
)

# Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)

# Save checkpoint
save_checkpoint(
    checkpoint_dir,
    optimization_step,
    cfg,
    policy,
    optimizers,
    scheduler=None
)

# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)

# TODO: temporarily save replay buffer here, remove later when on the robot
# We want to control this with the keyboard inputs
dataset_dir = os.path.join(cfg.output_dir, "dataset")
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
    shutil.rmtree(dataset_dir)

# Save dataset
replay_buffer.to_lerobot_dataset(
    dataset_repo_id,
    fps=fps,
    root=dataset_dir
)

if offline_replay_buffer is not None:
    dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
    if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
        shutil.rmtree(dataset_offline_dir)

    offline_replay_buffer.to_lerobot_dataset(
        cfg.dataset.dataset_repo_id,
        fps=cfg.env.fps,
        root=dataset_offline_dir,
    )

logging.info("Resume training")
@ -756,12 +869,12 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
    optimizer_actor = torch.optim.Adam(
        # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
        params=policy.actor.parameters_to_optimize,
        lr=cfg.policy.actor_lr,
    )
    optimizer_critic = torch.optim.Adam(
        params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
    )
    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
    lr_scheduler = None
    optimizers = {
        "actor": optimizer_actor,
@ -771,19 +884,38 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
    return optimizers, lr_scheduler
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
    """
    Main training function that initializes and runs the training process.

    Args:
        cfg (TrainPipelineConfig): The training configuration
        job_name (str | None, optional): Job name for logging. Defaults to None.
    """
    if cfg.output_dir is None:
        raise ValueError("Output directory must be specified in config")

    if job_name is None:
        job_name = cfg.job_name

    if job_name is None:
        raise ValueError("Job name must be specified either in config or as a parameter")

    init_logging()
    logging.info(pformat(cfg.to_dict()))

    # Setup WandB logging if enabled
    if cfg.wandb.enable and cfg.wandb.project:
        from lerobot.common.utils.wandb_utils import WandBLogger

        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    # Handle resume logic
    cfg = handle_resume_logic(cfg)

    set_seed(seed=cfg.seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
@ -791,24 +923,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    shutdown_event = setup_process_handlers(use_threads(cfg))

    start_learner_threads(
        cfg=cfg,
        wandb_logger=wandb_logger,
        shutdown_event=shutdown_event,
    )


@parser.wrap()
def train_cli(cfg: TrainPipelineConfig):
    if not use_threads(cfg):
        import torch.multiprocessing as mp

        mp.set_start_method("spawn")

    # Use the job_name from the config
    train(
        cfg,
        job_name=cfg.job_name,
    )

    logging.info("[LEARNER] train_cli finished")
@ -816,5 +947,4 @@ def train_cli(cfg: dict):
if __name__ == "__main__": if __name__ == "__main__":
train_cli() train_cli()
logging.info("[LEARNER] main finished") logging.info("[LEARNER] main finished")
View File
@ -1,4 +1,7 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import einops
import gymnasium as gym
@ -6,7 +9,11 @@ import numpy as np
import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
from lerobot.configs import parser
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
def preprocess_maniskill_observation(
@ -46,9 +53,14 @@ def preprocess_maniskill_observation(
    return return_observations


class ManiSkillObservationWrapper(gym.ObservationWrapper):
    def __init__(self, env, device: torch.device = "cuda"):
        super().__init__(env)
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

    def observation(self, observation):
@ -108,76 +120,129 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
        return obs, reward, terminated, truncated, info


class BatchCompatibleWrapper(gym.ObservationWrapper):
    """Ensures observations are batch-compatible by adding a batch dimension if necessary."""

    def __init__(self, env):
        super().__init__(env)

    def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        for key in observation:
            if "image" in key and observation[key].dim() == 3:
                observation[key] = observation[key].unsqueeze(0)
            if "state" in key and observation[key].dim() == 1:
                observation[key] = observation[key].unsqueeze(0)
        return observation


class TimeLimitWrapper(gym.Wrapper):
    """Adds a time limit to the environment based on fps and control_time."""

    def __init__(self, env, control_time_s, fps):
        super().__init__(env)
        self.control_time_s = control_time_s
        self.fps = fps
        self.max_episode_steps = int(self.control_time_s * self.fps)
        self.current_step = 0

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1
        if self.current_step >= self.max_episode_steps:
            terminated = True
        return obs, reward, terminated, truncated, info

    def reset(self, *, seed=None, options=None):
        self.current_step = 0
        return super().reset(seed=seed, options=options)
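BatchCompatibleWrapper and TimeLimitWrapper are defined above but are not applied inside make_maniskill below. A minimal sketch of how they might be stacked on top of the factory output, assuming a hypothetical helper name and an arbitrary 20 s control budget (not part of this diff):

def make_single_maniskill_env(cfg: ManiskillEnvConfig) -> gym.Env:
    # Build the base environment (n_envs=1), then add batch and time-limit handling on top.
    env = make_maniskill(cfg, n_envs=1)
    env = BatchCompatibleWrapper(env)  # ensure image/state tensors carry a leading batch dimension
    env = TimeLimitWrapper(env, control_time_s=20.0, fps=cfg.fps)  # cut episodes after control_time_s * fps steps
    return env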
def make_maniskill(
    cfg: ManiskillEnvConfig,
    n_envs: int | None = None,
) -> gym.Env:
    """
    Factory function to create a ManiSkill environment with standard wrappers.

    Args:
        cfg: Configuration for the ManiSkill environment
        n_envs: Number of parallel environments

    Returns:
        A wrapped ManiSkill environment
    """
    env = gym.make(
        cfg.task,
        obs_mode=cfg.obs_type,
        control_mode=cfg.control_mode,
        render_mode=cfg.render_mode,
        sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
        num_envs=n_envs,
    )
    # Add video recording if enabled
    if cfg.video_record.enabled:
        env = RecordEpisode(
            env,
            output_dir=cfg.video_record.record_dir,
            save_trajectory=True,
            trajectory_name=cfg.video_record.trajectory_name,
            save_video=True,
            video_fps=30,
        )

    # Add observation and image processing
    env = ManiSkillObservationWrapper(env, device=cfg.device)
    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
    env._max_episode_steps = env.max_episode_steps = cfg.episode_length
    env.unwrapped.metadata["render_fps"] = cfg.fps

    # Add compatibility wrappers
    env = ManiSkillCompat(env)
    env = ManiSkillActionWrapper(env)
    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control

    return env
@parser.wrap()
def main(cfg: ManiskillEnvConfig):
    """Main function to run the ManiSkill environment."""
    # Create the ManiSkill environment
    env = make_maniskill(cfg, n_envs=1)

    # Reset the environment
    obs, info = env.reset()

    # Run a simple interaction loop
    sum_reward = 0
    for i in range(100):
        # Sample a random action
        action = env.action_space.sample()

        # Step the environment
        start_time = time.perf_counter()
        obs, reward, terminated, truncated, info = env.step(action)
        step_time = time.perf_counter() - start_time
        sum_reward += reward

        # Log information and reset if the episode terminated
        if terminated or truncated:
            logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
            sum_reward = 0
            obs, info = env.reset()

    # Close the environment
    env.close()
# if __name__ == "__main__":
# logging.basicConfig(level=logging.INFO)
# main()
if __name__ == "__main__":
    import draccus

    config = ManiskillEnvConfig()
    draccus.set_config_type("json")
    draccus.dump(config=config, stream=open(file='run_config.json', mode='w'))
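A minimal sketch of reading the dumped run_config.json back into a ManiskillEnvConfig, assuming draccus keeps the pyrallis-style load counterpart to the dump call above (worth verifying against the pinned draccus version):

import draccus

draccus.set_config_type("json")
with open("run_config.json") as f:
    loaded_cfg = draccus.load(ManiskillEnvConfig, f)  # assumed draccus.load(config_class, stream) API
env = make_maniskill(loaded_cfg, n_envs=1)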