[WIP] Non functional yet
Add ManiSkill environment configuration and wrappers - Introduced `VideoRecordConfig` for video recording settings. - Added `ManiskillEnvConfig` to encapsulate environment-specific configurations. - Implemented various wrappers for the ManiSkill environment, including observation and action scaling. - Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing. - Updated the `actor_server` and `learner_server` to utilize the new configuration structure. - Refactored the training pipeline to accommodate the new environment and policy configurations.
This commit is contained in:
parent
b7b6d8102f
commit
dd37bd412e
|
@ -154,3 +154,61 @@ class XarmEnv(EnvConfig):
|
|||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
|
||||
@EnvConfig.register_subclass("maniskill_push")
|
||||
@dataclass
|
||||
class ManiskillEnvConfig(EnvConfig):
|
||||
"""Configuration for the ManiSkill environment."""
|
||||
name: str = "maniskill/pushcube"
|
||||
task: str = "PushCube-v1"
|
||||
image_size: int = 64
|
||||
control_mode: str = "pd_ee_delta_pose"
|
||||
state_dim: int = 25
|
||||
action_dim: int = 7
|
||||
fps: int = 400
|
||||
episode_length: int = 50
|
||||
obs_type: str = "rgb"
|
||||
render_mode: str = "rgb_array"
|
||||
render_size: int = 64
|
||||
device: str = "cuda"
|
||||
robot: str = "so100" # This is a hack to make the robot config work
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(25,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_ROBOT,
|
||||
}
|
||||
)
|
||||
reward_classifier: dict[str, str | None] = field(
|
||||
default_factory=lambda: {
|
||||
"pretrained_path": None,
|
||||
"config_path": None,
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"max_episode_steps": self.episode_length,
|
||||
"control_mode": self.control_mode,
|
||||
"sensor_configs": {"width": self.image_size, "height": self.image_size},
|
||||
"num_envs": 1,
|
||||
}
|
|
@ -69,88 +69,3 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
|||
|
||||
return env
|
||||
|
||||
|
||||
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
env = gym.make(
|
||||
cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
|
||||
num_envs=n_envs,
|
||||
)
|
||||
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True)
|
||||
# state should have the size of 25
|
||||
# env = ConvertToLeRobotEnv(env, n_envs)
|
||||
# env = PixelWrapper(cfg, env, n_envs)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class PixelWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for pixel observations. Works with Maniskill vectorized environments
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env, num_envs, num_frames=3):
|
||||
super().__init__(env)
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
self._frames = deque([], maxlen=num_frames)
|
||||
self._render_size = cfg.env.render_size
|
||||
|
||||
def _get_obs(self, obs):
|
||||
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||
self._frames.append(frame)
|
||||
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
|
||||
|
||||
def reset(self, seed):
|
||||
obs, info = self.env.reset() # (seed=seed)
|
||||
for _ in range(self._frames.maxlen):
|
||||
obs_frames = self._get_obs(obs)
|
||||
return obs_frames, info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
|
||||
# TODO: Remove this
|
||||
class ConvertToLeRobotEnv(gym.Wrapper):
|
||||
def __init__(self, env, num_envs):
|
||||
super().__init__(env)
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
obs, info = self.env.reset(seed=seed, options={})
|
||||
return self._get_obs(obs), info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
def _get_obs(self, observation):
|
||||
sensor_data = observation.pop("sensor_data")
|
||||
del observation["sensor_param"]
|
||||
images = []
|
||||
for cam_data in sensor_data.values():
|
||||
images.append(cam_data["rgb"])
|
||||
|
||||
images = torch.concat(images, axis=-1)
|
||||
# flatten the rest of the data which should just be state data
|
||||
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
|
||||
ret = dict()
|
||||
ret["state"] = observation
|
||||
ret["pixels"] = images
|
||||
return ret
|
||||
|
|
|
@ -31,12 +31,19 @@ class SACConfig(PreTrainedConfig):
|
|||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy.
|
||||
normalization_mapping: Mapping from feature types to normalization modes.
|
||||
dataset_stats: Statistics for normalizing different data types.
|
||||
camera_number: Number of cameras to use.
|
||||
device: Device to use for training.
|
||||
storage_device: Device to use for storage.
|
||||
vision_encoder_name: Name of the vision encoder to use.
|
||||
freeze_vision_encoder: Whether to freeze the vision encoder.
|
||||
image_encoder_hidden_dim: Hidden dimension for the image encoder.
|
||||
shared_encoder: Whether to use a shared encoder.
|
||||
online_steps: Total number of online training steps.
|
||||
online_env_seed: Seed for the online environment.
|
||||
online_buffer_capacity: Capacity of the online replay buffer.
|
||||
online_step_before_learning: Number of steps to collect before starting learning.
|
||||
policy_update_freq: Frequency of policy updates.
|
||||
discount: Discount factor for the RL algorithm.
|
||||
temperature_init: Initial temperature for entropy regularization.
|
||||
num_critics: Number of critic networks.
|
||||
|
@ -54,6 +61,8 @@ class SACConfig(PreTrainedConfig):
|
|||
critic_network_kwargs: Additional arguments for critic networks.
|
||||
actor_network_kwargs: Additional arguments for actor network.
|
||||
policy_kwargs: Additional arguments for policy.
|
||||
actor_learner_config: Configuration for actor-learner communication.
|
||||
concurrency: Configuration for concurrency model.
|
||||
"""
|
||||
|
||||
# Input / output structure
|
||||
|
@ -86,13 +95,21 @@ class SACConfig(PreTrainedConfig):
|
|||
|
||||
# Architecture specifics
|
||||
camera_number: int = 1
|
||||
device: str = "cuda"
|
||||
storage_device: str = "cpu"
|
||||
# Set to "helper2424/resnet10" for hil serl
|
||||
vision_encoder_name: str | None = None
|
||||
freeze_vision_encoder: bool = True
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
|
||||
|
||||
# Training parameter
|
||||
online_steps: int = 1000000
|
||||
online_env_seed: int = 10000
|
||||
online_buffer_capacity: int = 10000
|
||||
online_step_before_learning: int = 100
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
|
@ -132,11 +149,17 @@ class SACConfig(PreTrainedConfig):
|
|||
}
|
||||
)
|
||||
|
||||
# Deprecated, kept for backward compatibility
|
||||
actor_learner_config: dict[str, str | int] = field(
|
||||
default_factory=lambda: {
|
||||
"learner_host": "127.0.0.1",
|
||||
"learner_port": 50051,
|
||||
"policy_parameters_push_frequency": 4,
|
||||
}
|
||||
)
|
||||
concurrency: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"actor": "threads",
|
||||
"learner": "threads"
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -92,6 +92,8 @@ class WandBLogger:
|
|||
resume="must" if cfg.resume else None,
|
||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||
)
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key: set[str] | None = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
@ -108,9 +110,24 @@ class WandBLogger:
|
|||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train", custom_step_key: str | None = None):
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
if step is None and custom_step_key is None:
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
|
||||
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
|
||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||
# multiple time steps is possible for example, the interaction step with the environment,
|
||||
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||
# to log the correct step for each metric.
|
||||
if custom_step_key is not None:
|
||||
if self._wandb_custom_step_key is None:
|
||||
self._wandb_custom_step_key = set()
|
||||
new_custom_key = f"{mode}/{custom_step_key}"
|
||||
if new_custom_key not in self._wandb_custom_step_key:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
|
@ -118,7 +135,26 @@ class WandBLogger:
|
|||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if (
|
||||
self._wandb_custom_step_key is not None
|
||||
and k in self._wandb_custom_step_key
|
||||
):
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
self._wandb.log(
|
||||
{
|
||||
f"{mode}/{k}": v,
|
||||
f"{mode}/{custom_step_key}": value_custom_step,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
|
|
@ -34,11 +34,10 @@ TRAIN_CONFIG_NAME = "train_config.json"
|
|||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
||||
|
|
|
@ -21,26 +21,23 @@ from statistics import mean, quantiles
|
|||
|
||||
# from lerobot.scripts.eval import eval_policy
|
||||
import grpc
|
||||
import hydra
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
# from lerobot.common.envs.factory import make_maniskill_env
|
||||
# from lerobot.common.envs.utils import preprocess_maniskill_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot, make_robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
|
||||
from lerobot.scripts.server.buffer import (
|
||||
Transition,
|
||||
|
@ -61,7 +58,7 @@ ACTOR_SHUTDOWN_TIMEOUT = 30
|
|||
|
||||
|
||||
def receive_policy(
|
||||
cfg: DictConfig,
|
||||
cfg: TrainPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
|
@ -72,12 +69,12 @@ def receive_policy(
|
|||
if not use_threads(cfg):
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
setup_process_handlers(use_threads=False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
host=cfg.policy.actor_learner_config["learner_host"],
|
||||
port=cfg.policy.actor_learner_config["learner_port"],
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -132,7 +129,7 @@ def interactions_stream(
|
|||
|
||||
|
||||
def send_transitions(
|
||||
cfg: DictConfig,
|
||||
cfg: TrainPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
|
@ -156,8 +153,8 @@ def send_transitions(
|
|||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
host=cfg.policy.actor_learner_config["learner_host"],
|
||||
port=cfg.policy.actor_learner_config["learner_port"],
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -173,7 +170,7 @@ def send_transitions(
|
|||
|
||||
|
||||
def send_interactions(
|
||||
cfg: DictConfig,
|
||||
cfg: TrainPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
|
@ -196,8 +193,8 @@ def send_interactions(
|
|||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
host=cfg.policy.actor_learner_config["learner_host"],
|
||||
port=cfg.policy.actor_learner_config["learner_port"],
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -269,7 +266,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
|
|||
|
||||
|
||||
def act_with_policy(
|
||||
cfg: DictConfig,
|
||||
cfg: TrainPipelineConfig,
|
||||
robot: Robot,
|
||||
reward_classifier: nn.Module,
|
||||
shutdown_event: any, # Event,
|
||||
|
@ -291,7 +288,7 @@ def act_with_policy(
|
|||
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
@ -304,7 +301,7 @@ def act_with_policy(
|
|||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
cfg=cfg.policy,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online training, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
|
@ -469,7 +466,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
|||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
|
||||
if policy_fps < cfg.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
|
||||
|
@ -497,25 +494,25 @@ def establish_learner_connection(
|
|||
return False
|
||||
|
||||
|
||||
def use_threads(cfg: DictConfig) -> bool:
|
||||
return cfg.actor_learner_config.concurrency.actor == "threads"
|
||||
def use_threads(cfg: TrainPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency["actor"] == "threads"
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def actor_cli(cfg: dict):
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainPipelineConfig):
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
init_logging(log_file="actor.log")
|
||||
robot = make_robot(cfg=cfg.robot)
|
||||
robot = make_robot(robot_type=cfg.env.robot)
|
||||
|
||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
host=cfg.policy.actor_learner_config["learner_host"],
|
||||
port=cfg.policy.actor_learner_config["learner_port"],
|
||||
)
|
||||
|
||||
logging.info("[ACTOR] Establishing connection with Learner")
|
||||
|
@ -570,22 +567,22 @@ def actor_cli(cfg: dict):
|
|||
# TODO: Remove this once we merge into main
|
||||
reward_classifier = None
|
||||
if (
|
||||
cfg.env.reward_classifier.pretrained_path is not None
|
||||
and cfg.env.reward_classifier.config_path is not None
|
||||
cfg.env.reward_classifier["pretrained_path"] is not None
|
||||
and cfg.env.reward_classifier["config_path"] is not None
|
||||
):
|
||||
reward_classifier = get_classifier(
|
||||
pretrained_path=cfg.env.reward_classifier.pretrained_path,
|
||||
config_path=cfg.env.reward_classifier.config_path,
|
||||
pretrained_path=cfg.env.reward_classifier["pretrained_path"],
|
||||
config_path=cfg.env.reward_classifier["config_path"],
|
||||
)
|
||||
|
||||
act_with_policy(
|
||||
cfg,
|
||||
robot,
|
||||
reward_classifier,
|
||||
shutdown_event,
|
||||
parameters_queue,
|
||||
transitions_queue,
|
||||
interactions_queue,
|
||||
cfg=cfg,
|
||||
robot=robot,
|
||||
reward_classifier=reward_classifier,
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
transitions_queue=transitions_queue,
|
||||
interactions_queue=interactions_queue,
|
||||
)
|
||||
logging.info("[ACTOR] Policy process joined")
|
||||
|
||||
|
|
|
@ -15,10 +15,12 @@ import json
|
|||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
busy_wait,
|
||||
is_headless,
|
||||
reset_follower_position,
|
||||
# reset_follower_position,
|
||||
)
|
||||
|
||||
from typing import Optional
|
||||
|
@ -28,6 +30,7 @@ from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
|||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill
|
||||
from lerobot.configs import parser
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
@ -1094,7 +1097,10 @@ class ActionScaleWrapper(gym.ActionWrapper):
|
|||
return action * self.scale_vector, is_intervention
|
||||
|
||||
|
||||
def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
|
||||
@parser.wrap()
|
||||
def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
|
||||
# def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
|
||||
# def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:
|
||||
"""
|
||||
Factory function to create a vectorized robot environment.
|
||||
|
||||
|
@ -1106,7 +1112,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
|
|||
Returns:
|
||||
A vectorized gym environment with all the necessary wrappers applied.
|
||||
"""
|
||||
if "maniskill" in cfg.env_name:
|
||||
if "maniskill" in cfg.name:
|
||||
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
|
||||
|
||||
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
|
||||
|
@ -1115,7 +1121,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
|
|||
n_envs=1,
|
||||
)
|
||||
return env
|
||||
|
||||
robot = cfg.robot
|
||||
# Create base environment
|
||||
env = HILSerlRobotEnv(
|
||||
robot=robot,
|
||||
|
@ -1329,80 +1335,82 @@ def replay_episode(env, repo_id, root=None, episode=0):
|
|||
busy_wait(1 / 10 - dt_s)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: HILSerlRobotEnvConfig):
|
||||
# @parser.wrap()
|
||||
# def main(cfg):
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
# robot = make_robot_from_config(cfg.robot)
|
||||
|
||||
reward_classifier = None #get_classifier(
|
||||
# cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
|
||||
# )
|
||||
user_relative_joint_positions = True
|
||||
# reward_classifier = None #get_classifier(
|
||||
# # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
|
||||
# # )
|
||||
# user_relative_joint_positions = True
|
||||
|
||||
env = make_robot_env(cfg, robot)
|
||||
# env = make_robot_env(cfg, robot)
|
||||
|
||||
if cfg.mode == "record":
|
||||
policy = None
|
||||
if cfg.pretrained_policy_name_or_path is not None:
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
# if cfg.mode == "record":
|
||||
# policy = None
|
||||
# if cfg.pretrained_policy_name_or_path is not None:
|
||||
# from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
|
||||
policy.to(cfg.device)
|
||||
policy.eval()
|
||||
# policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
|
||||
# policy.to(cfg.device)
|
||||
# policy.eval()
|
||||
|
||||
record_dataset(
|
||||
env,
|
||||
cfg.repo_id,
|
||||
root=cfg.dataset_root,
|
||||
num_episodes=cfg.num_episodes,
|
||||
fps=cfg.fps,
|
||||
task_description=cfg.task,
|
||||
policy=policy,
|
||||
)
|
||||
exit()
|
||||
# record_dataset(
|
||||
# env,
|
||||
# cfg.repo_id,
|
||||
# root=cfg.dataset_root,
|
||||
# num_episodes=cfg.num_episodes,
|
||||
# fps=cfg.fps,
|
||||
# task_description=cfg.task,
|
||||
# policy=policy,
|
||||
# )
|
||||
# exit()
|
||||
|
||||
if cfg.mode == "replay":
|
||||
replay_episode(
|
||||
env,
|
||||
cfg.replay_repo_id,
|
||||
root=cfg.dataset_root,
|
||||
episode=cfg.replay_episode,
|
||||
)
|
||||
exit()
|
||||
# if cfg.mode == "replay":
|
||||
# replay_episode(
|
||||
# env,
|
||||
# cfg.replay_repo_id,
|
||||
# root=cfg.dataset_root,
|
||||
# episode=cfg.replay_episode,
|
||||
# )
|
||||
# exit()
|
||||
|
||||
env.reset()
|
||||
# env.reset()
|
||||
|
||||
# Retrieve the robot's action space for joint commands.
|
||||
action_space_robot = env.action_space.spaces[0]
|
||||
# # Retrieve the robot's action space for joint commands.
|
||||
# action_space_robot = env.action_space.spaces[0]
|
||||
|
||||
# Initialize the smoothed action as a random sample.
|
||||
smoothed_action = action_space_robot.sample()
|
||||
# # Initialize the smoothed action as a random sample.
|
||||
# smoothed_action = action_space_robot.sample()
|
||||
|
||||
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
|
||||
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
|
||||
alpha = 1.0
|
||||
# # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
|
||||
# # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
|
||||
# alpha = 1.0
|
||||
|
||||
num_episode = 0
|
||||
sucesses = []
|
||||
while num_episode < 20:
|
||||
start_loop_s = time.perf_counter()
|
||||
# Sample a new random action from the robot's action space.
|
||||
new_random_action = action_space_robot.sample()
|
||||
# Update the smoothed action using an exponential moving average.
|
||||
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
|
||||
# num_episode = 0
|
||||
# sucesses = []
|
||||
# while num_episode < 20:
|
||||
# start_loop_s = time.perf_counter()
|
||||
# # Sample a new random action from the robot's action space.
|
||||
# new_random_action = action_space_robot.sample()
|
||||
# # Update the smoothed action using an exponential moving average.
|
||||
# smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
|
||||
|
||||
# Execute the step: wrap the NumPy action in a torch tensor.
|
||||
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
|
||||
if terminated or truncated:
|
||||
sucesses.append(reward)
|
||||
env.reset()
|
||||
num_episode += 1
|
||||
# # Execute the step: wrap the NumPy action in a torch tensor.
|
||||
# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
|
||||
# if terminated or truncated:
|
||||
# sucesses.append(reward)
|
||||
# env.reset()
|
||||
# num_episode += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_s
|
||||
busy_wait(1 / cfg.fps - dt_s)
|
||||
# dt_s = time.perf_counter() - start_loop_s
|
||||
# busy_wait(1 / cfg.fps - dt_s)
|
||||
|
||||
logging.info(f"Success after 20 steps {sucesses}")
|
||||
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
|
||||
# logging.info(f"Success after 20 steps {sucesses}")
|
||||
# logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
make_robot_env()
|
|
@ -19,40 +19,45 @@ import shutil
|
|||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pprint import pformat
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
import hydra
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
|
||||
# from torch.multiprocessing import Event, Queue, Process
|
||||
# from threading import Event, Thread
|
||||
# from torch.multiprocessing import Queue, Event
|
||||
from torch.multiprocessing import Queue
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs import parser
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
|
||||
from lerobot.common.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state as utils_load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_global_random_state,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_random_state,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||
from lerobot.scripts.server import learner_service
|
||||
from lerobot.scripts.server.buffer import (
|
||||
ReplayBuffer,
|
||||
|
@ -64,102 +69,167 @@ from lerobot.scripts.server.buffer import (
|
|||
state_to_bytes,
|
||||
)
|
||||
from lerobot.scripts.server.utils import setup_process_handlers
|
||||
from lerobot.common.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
|
||||
|
||||
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
|
||||
def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
|
||||
"""
|
||||
Handle the resume logic for training.
|
||||
|
||||
If resume is True:
|
||||
- Verifies that a checkpoint exists
|
||||
- Loads the checkpoint configuration
|
||||
- Logs resumption details
|
||||
- Returns the checkpoint configuration
|
||||
|
||||
If resume is False:
|
||||
- Checks if an output directory exists (to prevent accidental overwriting)
|
||||
- Returns the original configuration
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The training configuration
|
||||
|
||||
Returns:
|
||||
TrainPipelineConfig: The updated configuration
|
||||
|
||||
Raises:
|
||||
RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
|
||||
"""
|
||||
out_dir = cfg.output_dir
|
||||
|
||||
# Case 1: Not resuming, but need to check if directory exists to prevent overwrites
|
||||
if not cfg.resume:
|
||||
if Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
|
||||
if os.path.exists(checkpoint_dir):
|
||||
raise RuntimeError(
|
||||
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
|
||||
f"Output directory {checkpoint_dir} already exists. "
|
||||
"Use `resume=true` to resume training."
|
||||
)
|
||||
return cfg
|
||||
|
||||
# if resume == True
|
||||
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
|
||||
if not checkpoint_dir.exists():
|
||||
# Case 2: Resuming training
|
||||
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
|
||||
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
# Log that we found a valid checkpoint and are resuming
|
||||
logging.info(
|
||||
colored(
|
||||
"Resume=True detected, resuming previous run",
|
||||
"Valid checkpoint found: resume=True detected, resuming previous run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
|
||||
"Checkpoint configuration takes precedence."
|
||||
)
|
||||
|
||||
# Load config using Draccus
|
||||
checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
|
||||
checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
|
||||
|
||||
# Ensure resume flag is set in returned config
|
||||
checkpoint_cfg.resume = True
|
||||
return checkpoint_cfg
|
||||
|
||||
|
||||
def load_training_state(
|
||||
cfg: DictConfig,
|
||||
logger: Logger,
|
||||
optimizers: Optimizer | dict,
|
||||
cfg: TrainPipelineConfig,
|
||||
optimizers: Optimizer | dict[str, Optimizer],
|
||||
):
|
||||
"""
|
||||
Loads the training state (optimizers, step count, etc.) from a checkpoint.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
optimizers (Optimizer | dict): Optimizers to load state into
|
||||
|
||||
Returns:
|
||||
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
|
||||
"""
|
||||
if not cfg.resume:
|
||||
return None, None
|
||||
|
||||
training_state = torch.load(
|
||||
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
|
||||
)
|
||||
|
||||
if isinstance(training_state["optimizer"], dict):
|
||||
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
|
||||
for k, v in training_state["optimizer"].items():
|
||||
optimizers[k].load_state_dict(v)
|
||||
else:
|
||||
optimizers.load_state_dict(training_state["optimizer"])
|
||||
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"], training_state["interaction_step"]
|
||||
# Construct path to the last checkpoint directory
|
||||
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
|
||||
|
||||
logging.info(f"Loading training state from {checkpoint_dir}")
|
||||
|
||||
try:
|
||||
# Use the utility function from train_utils which loads the optimizer state
|
||||
# The function returns (step, updated_optimizer, scheduler)
|
||||
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
|
||||
|
||||
# For interaction step, we still need to load the training_state.pt file
|
||||
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
|
||||
training_state = torch.load(training_state_path, weights_only=False)
|
||||
interaction_step = training_state.get("interaction_step", 0)
|
||||
|
||||
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
|
||||
return step, interaction_step
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load training state: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
|
||||
def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
|
||||
"""
|
||||
Log information about the training process.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
policy (nn.Module): Policy model
|
||||
"""
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{cfg.policy.online_steps=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
|
||||
def initialize_replay_buffer(
|
||||
cfg: DictConfig, logger: Logger, device: str, storage_device: str
|
||||
cfg: TrainPipelineConfig,
|
||||
device: str,
|
||||
storage_device: str
|
||||
) -> ReplayBuffer:
|
||||
"""
|
||||
Initialize a replay buffer, either empty or from a dataset if resuming.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
device (str): Device to store tensors on
|
||||
storage_device (str): Device for storage optimization
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: Initialized replay buffer
|
||||
"""
|
||||
if not cfg.resume:
|
||||
return ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
capacity=cfg.policy.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
state_keys=cfg.policy.input_features.keys(),
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
)
|
||||
|
||||
logging.info("Resume training load the online dataset")
|
||||
dataset_path = os.path.join(cfg.output_dir, "dataset")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset_repo_id,
|
||||
repo_id=cfg.dataset.dataset_repo_id,
|
||||
local_files_only=True,
|
||||
root=logger.log_dir / "dataset",
|
||||
root=dataset_path,
|
||||
)
|
||||
return ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=dataset,
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
capacity=cfg.policy.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
optimize_memory=True,
|
||||
|
@ -167,33 +237,45 @@ def initialize_replay_buffer(
|
|||
|
||||
|
||||
def initialize_offline_replay_buffer(
|
||||
cfg: DictConfig,
|
||||
logger: Logger,
|
||||
cfg: TrainPipelineConfig,
|
||||
device: str,
|
||||
storage_device: str,
|
||||
active_action_dims: list[int] | None = None,
|
||||
) -> ReplayBuffer:
|
||||
"""
|
||||
Initialize an offline replay buffer from a dataset.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
device (str): Device to store tensors on
|
||||
storage_device (str): Device for storage optimization
|
||||
active_action_dims (list[int] | None): Active action dimensions for masking
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: Initialized offline replay buffer
|
||||
"""
|
||||
if not cfg.resume:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
if cfg.resume:
|
||||
else:
|
||||
logging.info("load offline dataset")
|
||||
dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
|
||||
offline_dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset_repo_id,
|
||||
repo_id=cfg.dataset.dataset_repo_id,
|
||||
local_files_only=True,
|
||||
root=logger.log_dir / "dataset_offline",
|
||||
root=dataset_offline_path,
|
||||
)
|
||||
|
||||
logging.info("Convert to a offline replay buffer")
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
state_keys=cfg.policy.input_features.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
capacity=cfg.training.offline_buffer_capacity,
|
||||
capacity=cfg.policy.offline_buffer_capacity,
|
||||
)
|
||||
return offline_replay_buffer
|
||||
|
||||
|
@ -215,16 +297,23 @@ def get_observation_features(
|
|||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def use_threads(cfg: DictConfig) -> bool:
|
||||
return cfg.actor_learner_config.concurrency.learner == "threads"
|
||||
def use_threads(cfg: TrainPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency["learner"] == "threads"
|
||||
|
||||
|
||||
def start_learner_threads(
|
||||
cfg: DictConfig,
|
||||
logger: Logger,
|
||||
out_dir: str,
|
||||
cfg: TrainPipelineConfig,
|
||||
wandb_logger: WandBLogger | None,
|
||||
shutdown_event: any, # Event,
|
||||
) -> None:
|
||||
"""
|
||||
Start the learner threads for training.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
wandb_logger (WandBLogger | None): Logger for metrics
|
||||
shutdown_event: Event to signal shutdown
|
||||
"""
|
||||
# Create multiprocessing queues
|
||||
transition_queue = Queue()
|
||||
interaction_message_queue = Queue()
|
||||
|
@ -255,13 +344,12 @@ def start_learner_threads(
|
|||
communication_process.start()
|
||||
|
||||
add_actor_information_and_train(
|
||||
cfg,
|
||||
logger,
|
||||
out_dir,
|
||||
shutdown_event,
|
||||
transition_queue,
|
||||
interaction_message_queue,
|
||||
parameters_queue,
|
||||
cfg=cfg,
|
||||
wandb_logger=wandb_logger,
|
||||
shutdown_event=shutdown_event,
|
||||
transition_queue=transition_queue,
|
||||
interaction_message_queue=interaction_message_queue,
|
||||
parameters_queue=parameters_queue,
|
||||
)
|
||||
logging.info("[LEARNER] Training process stopped")
|
||||
|
||||
|
@ -286,7 +374,7 @@ def start_learner_server(
|
|||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
cfg: DictConfig,
|
||||
cfg: TrainPipelineConfig,
|
||||
):
|
||||
if not use_threads(cfg):
|
||||
# We need init logging for MP separataly
|
||||
|
@ -298,11 +386,11 @@ def start_learner_server(
|
|||
setup_process_handlers(False)
|
||||
|
||||
service = learner_service.LearnerService(
|
||||
shutdown_event,
|
||||
parameters_queue,
|
||||
cfg.actor_learner_config.policy_parameters_push_frequency,
|
||||
transition_queue,
|
||||
interaction_message_queue,
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
|
||||
transition_queue=transition_queue,
|
||||
interaction_message_queue=interaction_message_queue,
|
||||
)
|
||||
|
||||
server = grpc.server(
|
||||
|
@ -318,8 +406,8 @@ def start_learner_server(
|
|||
server,
|
||||
)
|
||||
|
||||
host = cfg.actor_learner_config.learner_host
|
||||
port = cfg.actor_learner_config.learner_port
|
||||
host = cfg.policy.actor_learner_config["learner_host"]
|
||||
port = cfg.policy.actor_learner_config["learner_port"]
|
||||
|
||||
server.add_insecure_port(f"{host}:{port}")
|
||||
server.start()
|
||||
|
@ -385,9 +473,8 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
|||
|
||||
|
||||
def add_actor_information_and_train(
|
||||
cfg,
|
||||
logger: Logger,
|
||||
out_dir: str,
|
||||
cfg: TrainPipelineConfig,
|
||||
wandb_logger: WandBLogger | None,
|
||||
shutdown_event: any, # Event,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
|
@ -405,69 +492,60 @@ def add_actor_information_and_train(
|
|||
- Periodically updates the actor, critic, and temperature optimizers.
|
||||
- Logs training statistics, including loss values and optimization frequency.
|
||||
|
||||
**NOTE:**
|
||||
- This function performs multiple responsibilities (data transfer, training, and logging).
|
||||
It should ideally be split into smaller functions in the future.
|
||||
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
|
||||
significantly reduces performance. Instead, this function executes all operations in a single thread.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
device (str): The computing device (`"cpu"` or `"cuda"`).
|
||||
logger (Logger): Logger instance for tracking training progress.
|
||||
out_dir (str): The output directory for storing training checkpoints and logs.
|
||||
cfg (TrainPipelineConfig): Configuration object containing hyperparameters.
|
||||
wandb_logger (WandBLogger | None): Logger for tracking training progress.
|
||||
shutdown_event (Event): Event to signal shutdown.
|
||||
transition_queue (Queue): Queue for receiving transitions from the actor.
|
||||
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
|
||||
parameters_queue (Queue): Queue for sending policy parameters to the actor.
|
||||
"""
|
||||
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
|
||||
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
|
||||
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
|
||||
|
||||
logging.info("Initializing policy")
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
|
||||
# Get checkpoint dir for resuming
|
||||
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
|
||||
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
|
||||
|
||||
# TODO(Adil): This don't work anymore !
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
cfg=cfg.policy,
|
||||
# ds_meta=cfg.dataset,
|
||||
env_cfg=cfg.env
|
||||
)
|
||||
|
||||
# Update the policy config with the grad_clip_norm value from training config if it exists
|
||||
clip_grad_norm_value = cfg.training.grad_clip_norm
|
||||
clip_grad_norm_value:float = cfg.policy.grad_clip_norm
|
||||
|
||||
# compile policy
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
policy.train()
|
||||
|
||||
push_actor_policy_to_queue(parameters_queue, policy)
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
log_training_info(cfg, out_dir, policy)
|
||||
log_training_info(cfg=cfg, policy= policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
|
||||
batch_size = cfg.training.batch_size
|
||||
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
|
||||
batch_size = cfg.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
if cfg.dataset is not None:
|
||||
active_action_dims = None
|
||||
# TODO: FIX THIS
|
||||
if cfg.env.wrapper.joint_masking_action_space is not None:
|
||||
active_action_dims = [
|
||||
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
|
||||
]
|
||||
offline_replay_buffer = initialize_offline_replay_buffer(
|
||||
cfg=cfg,
|
||||
logger=logger,
|
||||
device=device,
|
||||
storage_device=storage_device,
|
||||
active_action_dims=active_action_dims,
|
||||
|
@ -484,18 +562,22 @@ def add_actor_information_and_train(
|
|||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
|
||||
# Extract variables from cfg
|
||||
online_step_before_learning = cfg.training.online_step_before_learning
|
||||
online_step_before_learning = cfg.policy.online_step_before_learning
|
||||
utd_ratio = cfg.policy.utd_ratio
|
||||
dataset_repo_id = cfg.dataset_repo_id
|
||||
fps = cfg.fps
|
||||
log_freq = cfg.training.log_freq
|
||||
save_freq = cfg.training.save_freq
|
||||
device = cfg.device
|
||||
storage_device = cfg.training.storage_device
|
||||
policy_update_freq = cfg.training.policy_update_freq
|
||||
policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency
|
||||
save_checkpoint = cfg.training.save_checkpoint
|
||||
online_steps = cfg.training.online_steps
|
||||
|
||||
dataset_repo_id = None
|
||||
if cfg.dataset is not None:
|
||||
dataset_repo_id = cfg.dataset.repo_id
|
||||
|
||||
fps = cfg.env.fps
|
||||
log_freq = cfg.log_freq
|
||||
save_freq = cfg.save_freq
|
||||
device = cfg.policy.device
|
||||
storage_device = cfg.policy.storage_device
|
||||
policy_update_freq = cfg.policy.policy_update_freq
|
||||
policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
|
||||
save_checkpoint = cfg.save_checkpoint
|
||||
online_steps = cfg.policy.online_steps
|
||||
|
||||
while True:
|
||||
if shutdown_event is not None and shutdown_event.is_set():
|
||||
|
@ -516,7 +598,7 @@ def add_actor_information_and_train(
|
|||
continue
|
||||
replay_buffer.add(**transition)
|
||||
|
||||
if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
"is_intervention"
|
||||
):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
@ -528,7 +610,17 @@ def add_actor_information_and_train(
|
|||
interaction_message = bytes_to_python_object(interaction_message)
|
||||
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
||||
interaction_message["Interaction step"] += interaction_step_shift
|
||||
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
|
||||
|
||||
# Log interaction messages with WandB if available
|
||||
if wandb_logger:
|
||||
wandb_logger.log_dict(
|
||||
d=interaction_message,
|
||||
mode="train",
|
||||
custom_step_key="Interaction step"
|
||||
)
|
||||
else:
|
||||
# Log to console if no WandB logger
|
||||
logging.info(f"Interaction: {interaction_message}")
|
||||
|
||||
logging.debug("[LEARNER] Received interactions")
|
||||
|
||||
|
@ -538,11 +630,11 @@ def add_actor_information_and_train(
|
|||
logging.debug("[LEARNER] Starting optimization loop")
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = replay_buffer.sample(batch_size=batch_size)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
|
||||
batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
|
@ -552,7 +644,7 @@ def add_actor_information_and_train(
|
|||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy, observations, next_observations
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
|
@ -568,15 +660,15 @@ def add_actor_information_and_train(
|
|||
|
||||
# clip gradients
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.critic_ensemble.parameters(), clip_grad_norm_value
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
|
||||
optimizers["critic"].step()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = replay_buffer.sample(batch_size=batch_size)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
@ -590,7 +682,7 @@ def add_actor_information_and_train(
|
|||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy, observations, next_observations
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
|
@ -606,7 +698,7 @@ def add_actor_information_and_train(
|
|||
|
||||
# clip gradients
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.critic_ensemble.parameters(), clip_grad_norm_value
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["critic"].step()
|
||||
|
@ -627,7 +719,7 @@ def add_actor_information_and_train(
|
|||
|
||||
# clip gradients
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.actor.parameters_to_optimize, clip_grad_norm_value
|
||||
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["actor"].step()
|
||||
|
@ -645,7 +737,7 @@ def add_actor_information_and_train(
|
|||
|
||||
# clip gradients
|
||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
[policy.log_alpha], clip_grad_norm_value
|
||||
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["temperature"].step()
|
||||
|
@ -655,7 +747,7 @@ def add_actor_information_and_train(
|
|||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue, policy)
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
policy.update_target_networks()
|
||||
|
@ -665,22 +757,33 @@ def add_actor_information_and_train(
|
|||
if offline_replay_buffer is not None:
|
||||
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
|
||||
training_infos["Optimization step"] = optimization_step
|
||||
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
|
||||
# logging.info(f"Training infos: {training_infos}")
|
||||
|
||||
# Log training metrics
|
||||
if wandb_logger:
|
||||
wandb_logger.log_dict(
|
||||
d=training_infos,
|
||||
mode="train",
|
||||
custom_step_key="Optimization step"
|
||||
)
|
||||
else:
|
||||
# Log to console if no WandB logger
|
||||
logging.info(f"Training: {training_infos}")
|
||||
|
||||
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
||||
|
||||
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
|
||||
|
||||
logger.log_dict(
|
||||
{
|
||||
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
|
||||
"Optimization step": optimization_step,
|
||||
},
|
||||
mode="train",
|
||||
custom_step_key="Optimization step",
|
||||
)
|
||||
# Log optimization frequency
|
||||
if wandb_logger:
|
||||
wandb_logger.log_dict(
|
||||
{
|
||||
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
|
||||
"Optimization step": optimization_step,
|
||||
},
|
||||
mode="train",
|
||||
custom_step_key="Optimization step",
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % log_freq == 0:
|
||||
|
@ -693,35 +796,45 @@ def add_actor_information_and_train(
|
|||
interaction_step = (
|
||||
interaction_message["Interaction step"] if interaction_message is not None else 0
|
||||
)
|
||||
logger.save_checkpoint(
|
||||
optimization_step,
|
||||
policy,
|
||||
|
||||
# Create checkpoint directory
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
|
||||
|
||||
# Save checkpoint
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
optimization_step,
|
||||
cfg,
|
||||
policy,
|
||||
optimizers,
|
||||
scheduler=None,
|
||||
identifier=step_identifier,
|
||||
interaction_step=interaction_step,
|
||||
scheduler=None
|
||||
)
|
||||
|
||||
# Update the "last" symlink
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
|
||||
# TODO : temporarly save replay buffer here, remove later when on the robot
|
||||
# We want to control this with the keyboard inputs
|
||||
dataset_dir = logger.log_dir / "dataset"
|
||||
if dataset_dir.exists() and dataset_dir.is_dir():
|
||||
shutil.rmtree(
|
||||
dataset_dir,
|
||||
)
|
||||
replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset")
|
||||
dataset_dir = os.path.join(cfg.output_dir, "dataset")
|
||||
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
|
||||
shutil.rmtree(dataset_dir)
|
||||
|
||||
# Save dataset
|
||||
replay_buffer.to_lerobot_dataset(
|
||||
dataset_repo_id,
|
||||
fps=fps,
|
||||
root=dataset_dir
|
||||
)
|
||||
|
||||
if offline_replay_buffer is not None:
|
||||
dataset_dir = logger.log_dir / "dataset_offline"
|
||||
|
||||
if dataset_dir.exists() and dataset_dir.is_dir():
|
||||
shutil.rmtree(
|
||||
dataset_dir,
|
||||
)
|
||||
dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
|
||||
if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
|
||||
shutil.rmtree(dataset_offline_dir)
|
||||
|
||||
offline_replay_buffer.to_lerobot_dataset(
|
||||
cfg.dataset_repo_id,
|
||||
fps=cfg.fps,
|
||||
root=logger.log_dir / "dataset_offline",
|
||||
cfg.dataset.dataset_repo_id,
|
||||
fps=cfg.env.fps,
|
||||
root=dataset_offline_dir,
|
||||
)
|
||||
|
||||
logging.info("Resume training")
|
||||
|
@ -756,12 +869,12 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
|||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
lr=policy.config.actor_lr,
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
||||
params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
|
@ -771,19 +884,38 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
|||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
|
||||
"""
|
||||
Main training function that initializes and runs the training process.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The training configuration
|
||||
job_name (str | None, optional): Job name for logging. Defaults to None.
|
||||
"""
|
||||
if cfg.output_dir is None:
|
||||
raise ValueError("Output directory must be specified in config")
|
||||
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
job_name = cfg.job_name
|
||||
|
||||
if job_name is None:
|
||||
raise ValueError("Job name must be specified either in config or as a parameter")
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
cfg = handle_resume_logic(cfg, out_dir)
|
||||
# Setup WandB logging if enabled
|
||||
if cfg.wandb.enable and cfg.wandb.project:
|
||||
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||
wandb_logger = WandBLogger(cfg)
|
||||
else:
|
||||
wandb_logger = None
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
|
||||
# Handle resume logic
|
||||
cfg = handle_resume_logic(cfg)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
set_seed(seed=cfg.seed)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
@ -791,24 +923,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||
|
||||
start_learner_threads(
|
||||
cfg,
|
||||
logger,
|
||||
out_dir,
|
||||
shutdown_event,
|
||||
cfg=cfg,
|
||||
wandb_logger=wandb_logger,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def train_cli(cfg: dict):
|
||||
@parser.wrap()
|
||||
def train_cli(cfg: TrainPipelineConfig):
|
||||
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
# Use the job_name from the config
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
job_name=cfg.job_name,
|
||||
)
|
||||
|
||||
logging.info("[LEARNER] train_cli finished")
|
||||
|
@ -816,5 +947,4 @@ def train_cli(cfg: dict):
|
|||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
|
||||
logging.info("[LEARNER] main finished")
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
from typing import Any
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
|
@ -6,7 +9,11 @@ import numpy as np
|
|||
import torch
|
||||
from mani_skill.utils.wrappers.record import RecordEpisode
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(
|
||||
|
@ -46,9 +53,14 @@ def preprocess_maniskill_observation(
|
|||
return return_observations
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env, device: torch.device = "cuda"):
|
||||
super().__init__(env)
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.device = device
|
||||
|
||||
def observation(self, observation):
|
||||
|
@ -108,76 +120,129 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
|||
return obs, reward, terminated, truncated, info
|
||||
|
||||
|
||||
class BatchCompatibleWrapper(gym.ObservationWrapper):
|
||||
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
for key in observation:
|
||||
if "image" in key and observation[key].dim() == 3:
|
||||
observation[key] = observation[key].unsqueeze(0)
|
||||
if "state" in key and observation[key].dim() == 1:
|
||||
observation[key] = observation[key].unsqueeze(0)
|
||||
return observation
|
||||
|
||||
|
||||
class TimeLimitWrapper(gym.Wrapper):
|
||||
"""Adds a time limit to the environment based on fps and control_time."""
|
||||
def __init__(self, env, control_time_s, fps):
|
||||
super().__init__(env)
|
||||
self.control_time_s = control_time_s
|
||||
self.fps = fps
|
||||
self.max_episode_steps = int(self.control_time_s * self.fps)
|
||||
self.current_step = 0
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
self.current_step += 1
|
||||
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
terminated = True
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
self.current_step = 0
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
|
||||
def make_maniskill(
|
||||
cfg: DictConfig,
|
||||
cfg: ManiskillEnvConfig,
|
||||
n_envs: int | None = None,
|
||||
) -> gym.Env:
|
||||
"""
|
||||
Factory function to create a ManiSkill environment with standard wrappers.
|
||||
|
||||
Args:
|
||||
task: Name of the ManiSkill task
|
||||
obs_mode: Observation mode (rgb, rgbd, etc)
|
||||
control_mode: Control mode for the robot
|
||||
render_mode: Rendering mode
|
||||
sensor_configs: Camera sensor configurations
|
||||
cfg: Configuration for the ManiSkill environment
|
||||
n_envs: Number of parallel environments
|
||||
|
||||
Returns:
|
||||
A wrapped ManiSkill environment
|
||||
"""
|
||||
|
||||
env = gym.make(
|
||||
cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
|
||||
cfg.task,
|
||||
obs_mode=cfg.obs_type,
|
||||
control_mode=cfg.control_mode,
|
||||
render_mode=cfg.render_mode,
|
||||
sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
|
||||
num_envs=n_envs,
|
||||
)
|
||||
|
||||
if cfg.env.video_record.enabled:
|
||||
# Add video recording if enabled
|
||||
if cfg.video_record.enabled:
|
||||
env = RecordEpisode(
|
||||
env,
|
||||
output_dir=cfg.env.video_record.record_dir,
|
||||
output_dir=cfg.video_record.record_dir,
|
||||
save_trajectory=True,
|
||||
trajectory_name=cfg.env.video_record.trajectory_name,
|
||||
trajectory_name=cfg.video_record.trajectory_name,
|
||||
save_video=True,
|
||||
video_fps=30,
|
||||
)
|
||||
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
|
||||
|
||||
# Add observation and image processing
|
||||
env = ManiSkillObservationWrapper(env, device=cfg.device)
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
|
||||
env.unwrapped.metadata["render_fps"] = cfg.fps
|
||||
|
||||
# Add compatibility wrappers
|
||||
env = ManiSkillCompat(env)
|
||||
env = ManiSkillActionWrapper(env)
|
||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)
|
||||
|
||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
|
||||
|
||||
return env
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
import hydra
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize config
|
||||
with hydra.initialize(version_base=None, config_path="../../configs"):
|
||||
cfg = hydra.compose(config_name="env/maniskill_example.yaml")
|
||||
|
||||
env = make_maniskill(
|
||||
task=cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size},
|
||||
)
|
||||
|
||||
print("env done")
|
||||
@parser.wrap()
|
||||
def main(cfg: ManiskillEnvConfig):
|
||||
"""Main function to run the ManiSkill environment."""
|
||||
# Create the ManiSkill environment
|
||||
env = make_maniskill(cfg, n_envs=1)
|
||||
|
||||
# Reset the environment
|
||||
obs, info = env.reset()
|
||||
random_action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(random_action)
|
||||
|
||||
# Run a simple interaction loop
|
||||
sum_reward = 0
|
||||
for i in range(100):
|
||||
# Sample a random action
|
||||
action = env.action_space.sample()
|
||||
|
||||
# Step the environment
|
||||
start_time = time.perf_counter()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
step_time = time.perf_counter() - start_time
|
||||
sum_reward += reward
|
||||
# Log information
|
||||
|
||||
# Reset if episode terminated
|
||||
if terminated or truncated:
|
||||
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
|
||||
sum_reward = 0
|
||||
obs, info = env.reset()
|
||||
|
||||
# Close the environment
|
||||
env.close()
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# logging.basicConfig(level=logging.INFO)
|
||||
# main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import draccus
|
||||
config = ManiskillEnvConfig()
|
||||
draccus.set_config_type("json")
|
||||
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
|
Loading…
Reference in New Issue