[WIP] Non-functional yet

Add ManiSkill environment configuration and wrappers

- Introduced `VideoRecordConfig` for video recording settings.
- Added `ManiskillEnvConfig` to encapsulate environment-specific configuration.
- Implemented various wrappers for the ManiSkill environment, including observation and action scaling.
- Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing.
- Updated the `actor_server` and `learner_server` to use the new configuration structure.
- Refactored the training pipeline to accommodate the new environment and policy configurations.

This commit is contained in:
parent b7b6d8102f
commit dd37bd412e
@@ -154,3 +154,61 @@ class XarmEnv(EnvConfig):
             "visualization_height": self.visualization_height,
             "max_episode_steps": self.episode_length,
         }
+
+
+@dataclass
+class VideoRecordConfig:
+    """Configuration for video recording in ManiSkill environments."""
+    enabled: bool = False
+    record_dir: str = "videos"
+    trajectory_name: str = "trajectory"
+
+@EnvConfig.register_subclass("maniskill_push")
+@dataclass
+class ManiskillEnvConfig(EnvConfig):
+    """Configuration for the ManiSkill environment."""
+    name: str = "maniskill/pushcube"
+    task: str = "PushCube-v1"
+    image_size: int = 64
+    control_mode: str = "pd_ee_delta_pose"
+    state_dim: int = 25
+    action_dim: int = 7
+    fps: int = 400
+    episode_length: int = 50
+    obs_type: str = "rgb"
+    render_mode: str = "rgb_array"
+    render_size: int = 64
+    device: str = "cuda"
+    robot: str = "so100"  # This is a hack to make the robot config work
+    video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
+    features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
+            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
+            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(25,)),
+        }
+    )
+    features_map: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": ACTION,
+            "observation.image": OBS_IMAGE,
+            "observation.state": OBS_ROBOT,
+        }
+    )
+    reward_classifier: dict[str, str | None] = field(
+        default_factory=lambda: {
+            "pretrained_path": None,
+            "config_path": None,
+        }
+    )
+
+    @property
+    def gym_kwargs(self) -> dict:
+        return {
+            "obs_type": self.obs_type,
+            "render_mode": self.render_mode,
+            "max_episode_steps": self.episode_length,
+            "control_mode": self.control_mode,
+            "sensor_configs": {"width": self.image_size, "height": self.image_size},
+            "num_envs": 1,
+        }
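A quick way to sanity-check the new config is to instantiate it and route its values into `gym.make`, the same way the ManiSkill factory code in this commit does. This is only a sketch under assumptions: ManiSkill3 is installed, importing `mani_skill.envs` registers the tasks with gymnasium, and `ManiskillEnvConfig` is importable from the envs config module touched by this hunk.

import gymnasium as gym
import mani_skill.envs  # noqa: F401  (assumed: registers ManiSkill tasks with gymnasium)

from lerobot.common.envs.configs import ManiskillEnvConfig  # assumed import path for this hunk

cfg = ManiskillEnvConfig()          # defaults: PushCube-v1, 64x64 RGB, pd_ee_delta_pose
env = gym.make(
    cfg.task,
    obs_mode=cfg.obs_type,          # "rgb"
    control_mode=cfg.control_mode,
    render_mode=cfg.render_mode,
    sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
)
obs, info = env.reset(seed=0)
print(type(obs))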
@@ -69,88 +69,3 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
 
     return env
 
-
-def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
-    """Make ManiSkill3 gym environment"""
-    from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
-
-    env = gym.make(
-        cfg.env.task,
-        obs_mode=cfg.env.obs,
-        control_mode=cfg.env.control_mode,
-        render_mode=cfg.env.render_mode,
-        sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
-        num_envs=n_envs,
-    )
-    # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
-    env = ManiSkillVectorEnv(env, ignore_terminations=True)
-    # state should have the size of 25
-    # env = ConvertToLeRobotEnv(env, n_envs)
-    # env = PixelWrapper(cfg, env, n_envs)
-    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
-    env.unwrapped.metadata["render_fps"] = 20
-
-    return env
-
-
-class PixelWrapper(gym.Wrapper):
-    """
-    Wrapper for pixel observations. Works with Maniskill vectorized environments
-    """
-
-    def __init__(self, cfg, env, num_envs, num_frames=3):
-        super().__init__(env)
-        self.cfg = cfg
-        self.env = env
-        self.observation_space = gym.spaces.Box(
-            low=0,
-            high=255,
-            shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
-            dtype=np.uint8,
-        )
-        self._frames = deque([], maxlen=num_frames)
-        self._render_size = cfg.env.render_size
-
-    def _get_obs(self, obs):
-        frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
-        self._frames.append(frame)
-        return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
-
-    def reset(self, seed):
-        obs, info = self.env.reset()  # (seed=seed)
-        for _ in range(self._frames.maxlen):
-            obs_frames = self._get_obs(obs)
-        return obs_frames, info
-
-    def step(self, action):
-        obs, reward, terminated, truncated, info = self.env.step(action)
-        return self._get_obs(obs), reward, terminated, truncated, info
-
-
-# TODO: Remove this
-class ConvertToLeRobotEnv(gym.Wrapper):
-    def __init__(self, env, num_envs):
-        super().__init__(env)
-
-    def reset(self, seed=None, options=None):
-        obs, info = self.env.reset(seed=seed, options={})
-        return self._get_obs(obs), info
-
-    def step(self, action):
-        obs, reward, terminated, truncated, info = self.env.step(action)
-        return self._get_obs(obs), reward, terminated, truncated, info
-
-    def _get_obs(self, observation):
-        sensor_data = observation.pop("sensor_data")
-        del observation["sensor_param"]
-        images = []
-        for cam_data in sensor_data.values():
-            images.append(cam_data["rgb"])
-
-        images = torch.concat(images, axis=-1)
-        # flatten the rest of the data which should just be state data
-        observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
-        ret = dict()
-        ret["state"] = observation
-        ret["pixels"] = images
-        return ret
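The removed `PixelWrapper` keeps the last N frames in a fixed-length `collections.deque` and concatenates them along the channel axis. A self-contained sketch of that frame-stacking pattern, in plain NumPy with hypothetical 64x64 shapes, not the wrapper itself:

from collections import deque

import numpy as np

num_frames, h, w = 3, 64, 64
frames = deque(maxlen=num_frames)            # oldest frame falls off automatically

def stack(frame: np.ndarray) -> np.ndarray:
    """frame: (C, H, W) uint8; returns (num_frames * C, H, W)."""
    frames.append(frame)
    while len(frames) < num_frames:          # pad with the first frame after a reset
        frames.append(frame)
    return np.concatenate(list(frames), axis=0)

obs = stack(np.zeros((3, h, w), dtype=np.uint8))
assert obs.shape == (9, 64, 64)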
@@ -31,12 +31,19 @@ class SACConfig(PreTrainedConfig):
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy.
         normalization_mapping: Mapping from feature types to normalization modes.
+        dataset_stats: Statistics for normalizing different data types.
         camera_number: Number of cameras to use.
+        device: Device to use for training.
         storage_device: Device to use for storage.
         vision_encoder_name: Name of the vision encoder to use.
         freeze_vision_encoder: Whether to freeze the vision encoder.
         image_encoder_hidden_dim: Hidden dimension for the image encoder.
         shared_encoder: Whether to use a shared encoder.
+        online_steps: Total number of online training steps.
+        online_env_seed: Seed for the online environment.
+        online_buffer_capacity: Capacity of the online replay buffer.
+        online_step_before_learning: Number of steps to collect before starting learning.
+        policy_update_freq: Frequency of policy updates.
         discount: Discount factor for the RL algorithm.
         temperature_init: Initial temperature for entropy regularization.
         num_critics: Number of critic networks.
@@ -54,6 +61,8 @@ class SACConfig(PreTrainedConfig):
         critic_network_kwargs: Additional arguments for critic networks.
         actor_network_kwargs: Additional arguments for actor network.
         policy_kwargs: Additional arguments for policy.
+        actor_learner_config: Configuration for actor-learner communication.
+        concurrency: Configuration for concurrency model.
     """
 
     # Input / output structure
@@ -86,6 +95,7 @@ class SACConfig(PreTrainedConfig):
 
     # Architecture specifics
     camera_number: int = 1
+    device: str = "cuda"
     storage_device: str = "cpu"
     # Set to "helper2424/resnet10" for hil serl
    vision_encoder_name: str | None = None
@@ -93,6 +103,13 @@ class SACConfig(PreTrainedConfig):
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
 
+    # Training parameter
+    online_steps: int = 1000000
+    online_env_seed: int = 10000
+    online_buffer_capacity: int = 10000
+    online_step_before_learning: int = 100
+    policy_update_freq: int = 1
+
     # SAC algorithm parameters
     discount: float = 0.99
     temperature_init: float = 1.0
@@ -132,11 +149,17 @@ class SACConfig(PreTrainedConfig):
         }
     )
 
-    # Deprecated, kept for backward compatibility
     actor_learner_config: dict[str, str | int] = field(
         default_factory=lambda: {
             "learner_host": "127.0.0.1",
             "learner_port": 50051,
+            "policy_parameters_push_frequency": 4,
+        }
+    )
+    concurrency: dict[str, str] = field(
+        default_factory=lambda: {
+            "actor": "threads",
+            "learner": "threads"
         }
     )
 
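With these fields, networking and concurrency settings live on the policy config as plain dicts, so call sites index them by key rather than by attribute. A minimal sketch of that access pattern; the `SACConfig` import path and the assumption that all remaining fields have defaults are mine, not guaranteed by this diff:

from lerobot.common.policies.sac.configuration_sac import SACConfig  # assumed module path

cfg = SACConfig()
host = cfg.actor_learner_config["learner_host"]        # "127.0.0.1"
port = cfg.actor_learner_config["learner_port"]        # 50051
threaded_actor = cfg.concurrency["actor"] == "threads" # thread vs process actor loop
print(f"learner at {host}:{port}, threaded actor: {threaded_actor}")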
|
@ -92,6 +92,8 @@ class WandBLogger:
|
||||||
resume="must" if cfg.resume else None,
|
resume="must" if cfg.resume else None,
|
||||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||||
)
|
)
|
||||||
|
# Handle custom step key for rl asynchronous training.
|
||||||
|
self._wandb_custom_step_key: set[str] | None = None
|
||||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||||
self._wandb = wandb
|
self._wandb = wandb
|
||||||
|
@ -108,9 +110,24 @@ class WandBLogger:
|
||||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||||
self._wandb.log_artifact(artifact)
|
self._wandb.log_artifact(artifact)
|
||||||
|
|
||||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
def log_dict(self, d: dict, step: int, mode: str = "train", custom_step_key: str | None = None):
|
||||||
if mode not in {"train", "eval"}:
|
if mode not in {"train", "eval"}:
|
||||||
raise ValueError(mode)
|
raise ValueError(mode)
|
||||||
|
if step is None and custom_step_key is None:
|
||||||
|
raise ValueError("Either step or custom_step_key must be provided.")
|
||||||
|
|
||||||
|
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
|
||||||
|
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||||
|
# multiple time steps is possible for example, the interaction step with the environment,
|
||||||
|
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||||
|
# to log the correct step for each metric.
|
||||||
|
if custom_step_key is not None:
|
||||||
|
if self._wandb_custom_step_key is None:
|
||||||
|
self._wandb_custom_step_key = set()
|
||||||
|
new_custom_key = f"{mode}/{custom_step_key}"
|
||||||
|
if new_custom_key not in self._wandb_custom_step_key:
|
||||||
|
self._wandb_custom_step_key.add(new_custom_key)
|
||||||
|
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||||
|
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if not isinstance(v, (int, float, str)):
|
if not isinstance(v, (int, float, str)):
|
||||||
|
@ -118,7 +135,26 @@ class WandBLogger:
|
||||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
|
||||||
|
# Do not log the custom step key itself.
|
||||||
|
if (
|
||||||
|
self._wandb_custom_step_key is not None
|
||||||
|
and k in self._wandb_custom_step_key
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if custom_step_key is not None:
|
||||||
|
value_custom_step = d[custom_step_key]
|
||||||
|
self._wandb.log(
|
||||||
|
{
|
||||||
|
f"{mode}/{k}": v,
|
||||||
|
f"{mode}/{custom_step_key}": value_custom_step,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||||
|
|
||||||
|
|
||||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||||
if mode not in {"train", "eval"}:
|
if mode not in {"train", "eval"}:
|
||||||
|
|
|
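The custom step key added above works around W&B's requirement that the default step be monotonically increasing: each asynchronous counter (interaction step, optimization step, and so on) is registered once via `define_metric` and then logged as an ordinary value next to the metric. A condensed sketch of that logging pattern, independent of the wrapper; the project and metric names are hypothetical:

import wandb

run = wandb.init(project="hil-serl-debug", mode="offline")
wandb.define_metric("train/interaction_step", hidden=True)             # custom x-axis, not plotted itself
wandb.define_metric("train/*", step_metric="train/interaction_step")   # charts use it as the step

for interaction_step in range(0, 300, 100):
    wandb.log({"train/reward": interaction_step * 0.01,
               "train/interaction_step": interaction_step})            # no global step needed
run.finish()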
@@ -34,11 +34,10 @@ TRAIN_CONFIG_NAME = "train_config.json"
 
 @dataclass
 class TrainPipelineConfig(HubMixin):
-    dataset: DatasetConfig
+    dataset: DatasetConfig | None = None  # NOTE: In RL, we don't need a dataset
     env: envs.EnvConfig | None = None
     policy: PreTrainedConfig | None = None
-    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
-    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
     output_dir: Path | None = None
     job_name: str | None = None
     # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
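Making `dataset` optional means an online-RL run can be described by `env` and `policy` alone. A hedged sketch of such a config object; the field names come from this hunk, but any extra validation in `TrainPipelineConfig.__post_init__` may require more arguments than shown:

from lerobot.configs.train import TrainPipelineConfig
from lerobot.common.envs.configs import ManiskillEnvConfig  # added earlier in this commit (assumed path)

cfg = TrainPipelineConfig(
    dataset=None,                      # no offline dataset for online RL
    env=ManiskillEnvConfig(),
    policy=None,                       # normally filled from the CLI, e.g. --policy.type=sac
    output_dir=None,
    job_name="maniskill_sac_online",   # hypothetical run name
)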
@@ -21,26 +21,23 @@ from statistics import mean, quantiles
 
 # from lerobot.scripts.eval import eval_policy
 import grpc
-import hydra
 import torch
-from omegaconf import DictConfig
 from torch import nn
 from torch.multiprocessing import Event, Queue
 
 # TODO: Remove the import of maniskill
-# from lerobot.common.envs.factory import make_maniskill_env
-# from lerobot.common.envs.utils import preprocess_maniskill_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.robot_devices.robots.utils import Robot
+from lerobot.common.robot_devices.robots.utils import Robot, make_robot
 from lerobot.common.robot_devices.utils import busy_wait
+from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.utils import (
     TimerManager,
     get_safe_torch_device,
     init_logging,
-    set_global_seed,
 )
+from lerobot.configs import parser
+from lerobot.configs.train import TrainPipelineConfig
 from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
 from lerobot.scripts.server.buffer import (
     Transition,
@@ -61,7 +58,7 @@ ACTOR_SHUTDOWN_TIMEOUT = 30
 
 
 def receive_policy(
-    cfg: DictConfig,
+    cfg: TrainPipelineConfig,
     parameters_queue: Queue,
     shutdown_event: any,  # Event,
     learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
@@ -72,12 +69,12 @@ def receive_policy(
     if not use_threads(cfg):
         # Setup process handlers to handle shutdown signal
         # But use shutdown event from the main process
-        setup_process_handlers(False)
+        setup_process_handlers(use_threads=False)
 
     if grpc_channel is None or learner_client is None:
         learner_client, grpc_channel = learner_service_client(
-            host=cfg.actor_learner_config.learner_host,
-            port=cfg.actor_learner_config.learner_port,
+            host=cfg.policy.actor_learner_config["learner_host"],
+            port=cfg.policy.actor_learner_config["learner_port"],
         )
 
     try:
@@ -132,7 +129,7 @@ def interactions_stream(
 
 
 def send_transitions(
-    cfg: DictConfig,
+    cfg: TrainPipelineConfig,
     transitions_queue: Queue,
     shutdown_event: any,  # Event,
     learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
@@ -156,8 +153,8 @@ def send_transitions(
 
     if grpc_channel is None or learner_client is None:
         learner_client, grpc_channel = learner_service_client(
-            host=cfg.actor_learner_config.learner_host,
-            port=cfg.actor_learner_config.learner_port,
+            host=cfg.policy.actor_learner_config["learner_host"],
+            port=cfg.policy.actor_learner_config["learner_port"],
         )
 
     try:
@@ -173,7 +170,7 @@ def send_transitions(
 
 
 def send_interactions(
-    cfg: DictConfig,
+    cfg: TrainPipelineConfig,
     interactions_queue: Queue,
     shutdown_event: any,  # Event,
     learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
@@ -196,8 +193,8 @@ def send_interactions(
 
     if grpc_channel is None or learner_client is None:
         learner_client, grpc_channel = learner_service_client(
-            host=cfg.actor_learner_config.learner_host,
-            port=cfg.actor_learner_config.learner_port,
+            host=cfg.policy.actor_learner_config["learner_host"],
+            port=cfg.policy.actor_learner_config["learner_port"],
        )
 
     try:
@@ -269,7 +266,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
 
 
 def act_with_policy(
-    cfg: DictConfig,
+    cfg: TrainPipelineConfig,
     robot: Robot,
     reward_classifier: nn.Module,
     shutdown_event: any,  # Event,
@@ -291,7 +288,7 @@ def act_with_policy(
 
     online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
 
-    set_global_seed(cfg.seed)
+    set_seed(cfg.seed)
     device = get_safe_torch_device(cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
@@ -304,7 +301,7 @@ def act_with_policy(
     ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
     # TODO: At some point we should just need make sac policy
     policy: SACPolicy = make_policy(
-        hydra_cfg=cfg,
+        cfg=cfg.policy,
         # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
         # Hack: But if we do online training, we do not need dataset_stats
         dataset_stats=None,
@@ -469,7 +466,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
     return stats
 
 
-def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
+def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
     if policy_fps < cfg.fps:
         logging.warning(
             f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
@@ -497,25 +494,25 @@ def establish_learner_connection(
     return False
 
 
-def use_threads(cfg: DictConfig) -> bool:
-    return cfg.actor_learner_config.concurrency.actor == "threads"
+def use_threads(cfg: TrainPipelineConfig) -> bool:
+    return cfg.policy.concurrency["actor"] == "threads"
 
 
-@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
-def actor_cli(cfg: dict):
+@parser.wrap()
+def actor_cli(cfg: TrainPipelineConfig):
     if not use_threads(cfg):
         import torch.multiprocessing as mp
 
         mp.set_start_method("spawn")
 
     init_logging(log_file="actor.log")
-    robot = make_robot(cfg=cfg.robot)
+    robot = make_robot(robot_type=cfg.env.robot)
 
     shutdown_event = setup_process_handlers(use_threads(cfg))
 
     learner_client, grpc_channel = learner_service_client(
-        host=cfg.actor_learner_config.learner_host,
-        port=cfg.actor_learner_config.learner_port,
+        host=cfg.policy.actor_learner_config["learner_host"],
+        port=cfg.policy.actor_learner_config["learner_port"],
     )
 
     logging.info("[ACTOR] Establishing connection with Learner")
@@ -570,22 +567,22 @@ def actor_cli(cfg: dict):
     # TODO: Remove this once we merge into main
     reward_classifier = None
     if (
-        cfg.env.reward_classifier.pretrained_path is not None
-        and cfg.env.reward_classifier.config_path is not None
+        cfg.env.reward_classifier["pretrained_path"] is not None
+        and cfg.env.reward_classifier["config_path"] is not None
     ):
         reward_classifier = get_classifier(
-            pretrained_path=cfg.env.reward_classifier.pretrained_path,
-            config_path=cfg.env.reward_classifier.config_path,
+            pretrained_path=cfg.env.reward_classifier["pretrained_path"],
+            config_path=cfg.env.reward_classifier["config_path"],
        )
 
     act_with_policy(
-        cfg,
-        robot,
-        reward_classifier,
-        shutdown_event,
-        parameters_queue,
-        transitions_queue,
-        interactions_queue,
+        cfg=cfg,
+        robot=robot,
+        reward_classifier=reward_classifier,
+        shutdown_event=shutdown_event,
+        parameters_queue=parameters_queue,
+        transitions_queue=transitions_queue,
+        interactions_queue=interactions_queue,
     )
     logging.info("[ACTOR] Policy process joined")
 
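The actor entry point now follows the `@parser.wrap()` convention shown above instead of Hydra: the decorated function receives a parsed `TrainPipelineConfig` assembled from CLI overrides. A minimal standalone sketch of that pattern; the example CLI flags are assumptions, not flags this commit defines:

from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig


@parser.wrap()
def actor_cli(cfg: TrainPipelineConfig):
    # e.g. python actor_script.py --policy.type=sac --env.type=maniskill_push
    host = cfg.policy.actor_learner_config["learner_host"]
    port = cfg.policy.actor_learner_config["learner_port"]
    print(f"actor connecting to learner at {host}:{port}")


if __name__ == "__main__":
    actor_cli()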
@@ -15,10 +15,12 @@ import json
 from dataclasses import dataclass
 
 from lerobot.common.envs.utils import preprocess_observation
+from lerobot.configs.train import TrainPipelineConfig
+from lerobot.common.envs.configs import EnvConfig
 from lerobot.common.robot_devices.control_utils import (
     busy_wait,
     is_headless,
-    reset_follower_position,
+    # reset_follower_position,
 )
 
 from typing import Optional
@@ -28,6 +30,7 @@ from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.common.robot_devices.robots.configs import RobotConfig
 
 from lerobot.scripts.server.kinematics import RobotKinematics
+from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill
 from lerobot.configs import parser
 
 logging.basicConfig(level=logging.INFO)
@@ -1094,7 +1097,10 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention
 
 
-def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
+@parser.wrap()
+def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
+    # def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
+    # def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:
     """
     Factory function to create a vectorized robot environment.
 
@@ -1106,7 +1112,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    if "maniskill" in cfg.env_name:
+    if "maniskill" in cfg.name:
         from lerobot.scripts.server.maniskill_manipulator import make_maniskill
 
         logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
@@ -1115,7 +1121,7 @@ def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
             n_envs=1,
         )
         return env
+    robot = cfg.robot
     # Create base environment
     env = HILSerlRobotEnv(
         robot=robot,
@@ -1329,80 +1335,82 @@ def replay_episode(env, repo_id, root=None, episode=0):
         busy_wait(1 / 10 - dt_s)
 
 
-@parser.wrap()
-def main(cfg: HILSerlRobotEnvConfig):
+# @parser.wrap()
+# def main(cfg):
 
-    robot = make_robot_from_config(cfg.robot)
+# robot = make_robot_from_config(cfg.robot)
 
-    reward_classifier = None #get_classifier(
-    # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
-    # )
-    user_relative_joint_positions = True
+# reward_classifier = None #get_classifier(
+# # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
+# # )
+# user_relative_joint_positions = True
 
-    env = make_robot_env(cfg, robot)
+# env = make_robot_env(cfg, robot)
 
-    if cfg.mode == "record":
-        policy = None
-        if cfg.pretrained_policy_name_or_path is not None:
-            from lerobot.common.policies.sac.modeling_sac import SACPolicy
+# if cfg.mode == "record":
+#     policy = None
+#     if cfg.pretrained_policy_name_or_path is not None:
+#         from lerobot.common.policies.sac.modeling_sac import SACPolicy
 
-            policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
-            policy.to(cfg.device)
-            policy.eval()
+#         policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
+#         policy.to(cfg.device)
+#         policy.eval()
 
-        record_dataset(
-            env,
-            cfg.repo_id,
-            root=cfg.dataset_root,
-            num_episodes=cfg.num_episodes,
-            fps=cfg.fps,
-            task_description=cfg.task,
-            policy=policy,
-        )
-        exit()
+#     record_dataset(
+#         env,
+#         cfg.repo_id,
+#         root=cfg.dataset_root,
+#         num_episodes=cfg.num_episodes,
+#         fps=cfg.fps,
+#         task_description=cfg.task,
+#         policy=policy,
+#     )
+#     exit()
 
-    if cfg.mode == "replay":
-        replay_episode(
-            env,
-            cfg.replay_repo_id,
-            root=cfg.dataset_root,
-            episode=cfg.replay_episode,
-        )
-        exit()
+# if cfg.mode == "replay":
+#     replay_episode(
+#         env,
+#         cfg.replay_repo_id,
+#         root=cfg.dataset_root,
+#         episode=cfg.replay_episode,
+#     )
+#     exit()
 
-    env.reset()
+# env.reset()
 
-    # Retrieve the robot's action space for joint commands.
-    action_space_robot = env.action_space.spaces[0]
+# # Retrieve the robot's action space for joint commands.
+# action_space_robot = env.action_space.spaces[0]
 
-    # Initialize the smoothed action as a random sample.
-    smoothed_action = action_space_robot.sample()
+# # Initialize the smoothed action as a random sample.
+# smoothed_action = action_space_robot.sample()
 
-    # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
-    # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
-    alpha = 1.0
+# # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
+# # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
+# alpha = 1.0
 
-    num_episode = 0
-    sucesses = []
-    while num_episode < 20:
-        start_loop_s = time.perf_counter()
-        # Sample a new random action from the robot's action space.
-        new_random_action = action_space_robot.sample()
-        # Update the smoothed action using an exponential moving average.
-        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
+# num_episode = 0
+# sucesses = []
+# while num_episode < 20:
+#     start_loop_s = time.perf_counter()
+#     # Sample a new random action from the robot's action space.
+#     new_random_action = action_space_robot.sample()
+#     # Update the smoothed action using an exponential moving average.
+#     smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
 
-        # Execute the step: wrap the NumPy action in a torch tensor.
-        obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
-        if terminated or truncated:
-            sucesses.append(reward)
-            env.reset()
-            num_episode += 1
+#     # Execute the step: wrap the NumPy action in a torch tensor.
+#     obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
+#     if terminated or truncated:
+#         sucesses.append(reward)
+#         env.reset()
+#         num_episode += 1
 
-        dt_s = time.perf_counter() - start_loop_s
-        busy_wait(1 / cfg.fps - dt_s)
+#     dt_s = time.perf_counter() - start_loop_s
+#     busy_wait(1 / cfg.fps - dt_s)
 
-    logging.info(f"Success after 20 steps {sucesses}")
-    logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
+# logging.info(f"Success after 20 steps {sucesses}")
+# logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
 
+# if __name__ == "__main__":
+#     main()
 if __name__ == "__main__":
-    main()
+    make_robot_env()
@@ -19,40 +19,45 @@ import shutil
 import time
 from concurrent.futures import ThreadPoolExecutor
 from pprint import pformat
+import os
+from pathlib import Path
+
+import draccus
 import grpc
 
 # Import generated stubs
 import hilserl_pb2_grpc  # type: ignore
-import hydra
 import torch
-from deepdiff import DeepDiff
-from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import nn
 
-# from torch.multiprocessing import Event, Queue, Process
-# from threading import Event, Thread
-# from torch.multiprocessing import Queue, Event
 from torch.multiprocessing import Queue
 from torch.optim.optimizer import Optimizer
 
 from lerobot.common.datasets.factory import make_dataset
 
+from lerobot.configs.train import TrainPipelineConfig
+from lerobot.configs import parser
 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.sac.modeling_sac import SACPolicy
+from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
+from lerobot.common.utils.train_utils import (
+    get_step_checkpoint_dir,
+    get_step_identifier,
+    load_training_state as utils_load_training_state,
+    save_checkpoint,
+    update_last_checkpoint,
+)
+from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.utils import (
     format_big_number,
-    get_global_random_state,
     get_safe_torch_device,
-    init_hydra_config,
     init_logging,
-    set_global_random_state,
-    set_global_seed,
 )
+from lerobot.common.policies.utils import get_device_from_parameters
+from lerobot.common.utils.wandb_utils import WandBLogger
 from lerobot.scripts.server import learner_service
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
@@ -64,102 +69,167 @@ from lerobot.scripts.server.buffer import (
     state_to_bytes,
 )
 from lerobot.scripts.server.utils import setup_process_handlers
+from lerobot.common.constants import (
+    CHECKPOINTS_DIR,
+    LAST_CHECKPOINT_LINK,
+    PRETRAINED_MODEL_DIR,
+    TRAINING_STATE_DIR,
+    TRAINING_STEP,
+)
 
 
-def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
+def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
+    """
+    Handle the resume logic for training.
+
+    If resume is True:
+    - Verifies that a checkpoint exists
+    - Loads the checkpoint configuration
+    - Logs resumption details
+    - Returns the checkpoint configuration
+
+    If resume is False:
+    - Checks if an output directory exists (to prevent accidental overwriting)
+    - Returns the original configuration
+
+    Args:
+        cfg (TrainPipelineConfig): The training configuration
+
+    Returns:
+        TrainPipelineConfig: The updated configuration
+
+    Raises:
+        RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
+    """
+    out_dir = cfg.output_dir
+
+    # Case 1: Not resuming, but need to check if directory exists to prevent overwrites
     if not cfg.resume:
-        if Logger.get_last_checkpoint_dir(out_dir).exists():
+        checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
+        if os.path.exists(checkpoint_dir):
             raise RuntimeError(
-                f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
+                f"Output directory {checkpoint_dir} already exists. "
                 "Use `resume=true` to resume training."
             )
         return cfg
 
-    # if resume == True
-    checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
-    if not checkpoint_dir.exists():
+    # Case 2: Resuming training
+    checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
+    if not os.path.exists(checkpoint_dir):
         raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
 
-    checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+    # Log that we found a valid checkpoint and are resuming
     logging.info(
         colored(
-            "Resume=True detected, resuming previous run",
+            "Valid checkpoint found: resume=True detected, resuming previous run",
             color="yellow",
             attrs=["bold"],
         )
     )
 
-    checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
-    diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+    # Load config using Draccus
+    checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
+    checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
 
-    if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
-        del diff["values_changed"]["root['resume']"]
-
-    if len(diff) > 0:
-        logging.warning(
-            f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
-            "Checkpoint configuration takes precedence."
-        )
-
+    # Ensure resume flag is set in returned config
     checkpoint_cfg.resume = True
     return checkpoint_cfg
 
 
 def load_training_state(
-    cfg: DictConfig,
-    logger: Logger,
-    optimizers: Optimizer | dict,
+    cfg: TrainPipelineConfig,
+    optimizers: Optimizer | dict[str, Optimizer],
 ):
+    """
+    Loads the training state (optimizers, step count, etc.) from a checkpoint.
+
+    Args:
+        cfg (TrainPipelineConfig): Training configuration
+        optimizers (Optimizer | dict): Optimizers to load state into
+
+    Returns:
+        tuple: (optimization_step, interaction_step) or (None, None) if not resuming
+    """
     if not cfg.resume:
         return None, None
 
-    training_state = torch.load(
-        logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
-    )
+    # Construct path to the last checkpoint directory
+    checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
 
-    if isinstance(training_state["optimizer"], dict):
-        assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
-        for k, v in training_state["optimizer"].items():
-            optimizers[k].load_state_dict(v)
-    else:
-        optimizers.load_state_dict(training_state["optimizer"])
+    logging.info(f"Loading training state from {checkpoint_dir}")
 
-    set_global_random_state({k: training_state[k] for k in get_global_random_state()})
-    return training_state["step"], training_state["interaction_step"]
+    try:
+        # Use the utility function from train_utils which loads the optimizer state
+        # The function returns (step, updated_optimizer, scheduler)
+        step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
+
+        # For interaction step, we still need to load the training_state.pt file
+        training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
+        training_state = torch.load(training_state_path, weights_only=False)
+        interaction_step = training_state.get("interaction_step", 0)
+
+        logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
+        return step, interaction_step
+
+    except Exception as e:
+        logging.error(f"Failed to load training state: {e}")
+        return None, None
 
 
-def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
+def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
+    """
+    Log information about the training process.
+
+    Args:
+        cfg (TrainPipelineConfig): Training configuration
+        policy (nn.Module): Policy model
+    """
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
-    log_output_dir(out_dir)
+    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
     logging.info(f"{cfg.env.task=}")
-    logging.info(f"{cfg.training.online_steps=}")
+    logging.info(f"{cfg.policy.online_steps=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
 
 def initialize_replay_buffer(
-    cfg: DictConfig, logger: Logger, device: str, storage_device: str
+    cfg: TrainPipelineConfig,
+    device: str,
+    storage_device: str
 ) -> ReplayBuffer:
+    """
+    Initialize a replay buffer, either empty or from a dataset if resuming.
+
+    Args:
+        cfg (TrainPipelineConfig): Training configuration
+        device (str): Device to store tensors on
+        storage_device (str): Device for storage optimization
+
+    Returns:
+        ReplayBuffer: Initialized replay buffer
+    """
     if not cfg.resume:
         return ReplayBuffer(
-            capacity=cfg.training.online_buffer_capacity,
+            capacity=cfg.policy.online_buffer_capacity,
             device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
+            state_keys=cfg.policy.input_features.keys(),
             storage_device=storage_device,
             optimize_memory=True,
        )
 
     logging.info("Resume training load the online dataset")
+    dataset_path = os.path.join(cfg.output_dir, "dataset")
     dataset = LeRobotDataset(
-        repo_id=cfg.dataset_repo_id,
+        repo_id=cfg.dataset.dataset_repo_id,
         local_files_only=True,
-        root=logger.log_dir / "dataset",
+        root=dataset_path,
     )
     return ReplayBuffer.from_lerobot_dataset(
         lerobot_dataset=dataset,
-        capacity=cfg.training.online_buffer_capacity,
+        capacity=cfg.policy.online_buffer_capacity,
         device=device,
         state_keys=cfg.policy.input_shapes.keys(),
         optimize_memory=True,
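Resume handling above builds every path from the shared constants instead of the old `Logger` helpers. A small sketch of the resulting checkpoint layout lookup; the run directory is hypothetical:

import os

from lerobot.common.constants import CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, PRETRAINED_MODEL_DIR, TRAINING_STATE_DIR

output_dir = "outputs/train/maniskill_sac"            # hypothetical run directory
last = os.path.join(output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
print(os.path.join(last, PRETRAINED_MODEL_DIR, "train_config.json"))   # config reloaded on resume
print(os.path.join(last, TRAINING_STATE_DIR, "training_state.pt"))     # interaction step restored from here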
@ -167,33 +237,45 @@ def initialize_replay_buffer(
|
||||||
|
|
||||||
|
|
||||||
def initialize_offline_replay_buffer(
|
def initialize_offline_replay_buffer(
|
||||||
cfg: DictConfig,
|
cfg: TrainPipelineConfig,
|
||||||
logger: Logger,
|
|
||||||
device: str,
|
device: str,
|
||||||
storage_device: str,
|
storage_device: str,
|
||||||
active_action_dims: list[int] | None = None,
|
active_action_dims: list[int] | None = None,
|
||||||
) -> ReplayBuffer:
|
) -> ReplayBuffer:
|
||||||
|
"""
|
||||||
|
Initialize an offline replay buffer from a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (TrainPipelineConfig): Training configuration
|
||||||
|
device (str): Device to store tensors on
|
||||||
|
storage_device (str): Device for storage optimization
|
||||||
|
active_action_dims (list[int] | None): Active action dimensions for masking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReplayBuffer: Initialized offline replay buffer
|
||||||
|
"""
|
||||||
if not cfg.resume:
|
if not cfg.resume:
|
||||||
logging.info("make_dataset offline buffer")
|
logging.info("make_dataset offline buffer")
|
||||||
offline_dataset = make_dataset(cfg)
|
offline_dataset = make_dataset(cfg)
|
||||||
if cfg.resume:
|
else:
|
||||||
logging.info("load offline dataset")
|
logging.info("load offline dataset")
|
||||||
|
dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline")
|
||||||
offline_dataset = LeRobotDataset(
|
offline_dataset = LeRobotDataset(
|
||||||
repo_id=cfg.dataset_repo_id,
|
repo_id=cfg.dataset.dataset_repo_id,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
root=logger.log_dir / "dataset_offline",
|
root=dataset_offline_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Convert to a offline replay buffer")
|
logging.info("Convert to a offline replay buffer")
|
||||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||||
offline_dataset,
|
offline_dataset,
|
||||||
device=device,
|
device=device,
|
||||||
state_keys=cfg.policy.input_shapes.keys(),
|
state_keys=cfg.policy.input_features.keys(),
|
||||||
action_mask=active_action_dims,
|
action_mask=active_action_dims,
|
||||||
action_delta=cfg.env.wrapper.delta_action,
|
action_delta=cfg.env.wrapper.delta_action,
|
||||||
storage_device=storage_device,
|
storage_device=storage_device,
|
||||||
optimize_memory=True,
|
optimize_memory=True,
|
||||||
capacity=cfg.training.offline_buffer_capacity,
|
capacity=cfg.policy.offline_buffer_capacity,
|
||||||
)
|
)
|
||||||
return offline_replay_buffer
|
return offline_replay_buffer
|
||||||
|
|
||||||
|
@ -215,16 +297,23 @@ def get_observation_features(
|
||||||
return observation_features, next_observation_features
|
return observation_features, next_observation_features
|
||||||
|
|
||||||
|
|
||||||
def use_threads(cfg: DictConfig) -> bool:
|
def use_threads(cfg: TrainPipelineConfig) -> bool:
|
||||||
return cfg.actor_learner_config.concurrency.learner == "threads"
|
return cfg.policy.concurrency["learner"] == "threads"
|
||||||
|
|
||||||
|
|
||||||
def start_learner_threads(
|
def start_learner_threads(
|
||||||
cfg: DictConfig,
|
cfg: TrainPipelineConfig,
|
||||||
logger: Logger,
|
wandb_logger: WandBLogger | None,
|
||||||
out_dir: str,
|
|
||||||
shutdown_event: any, # Event,
|
shutdown_event: any, # Event,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Start the learner threads for training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (TrainPipelineConfig): Training configuration
|
||||||
|
wandb_logger (WandBLogger | None): Logger for metrics
|
||||||
|
shutdown_event: Event to signal shutdown
|
||||||
|
"""
|
||||||
# Create multiprocessing queues
|
# Create multiprocessing queues
|
||||||
transition_queue = Queue()
|
transition_queue = Queue()
|
||||||
interaction_message_queue = Queue()
|
interaction_message_queue = Queue()
|
||||||
|
@ -255,13 +344,12 @@ def start_learner_threads(
|
||||||
communication_process.start()
|
communication_process.start()
|
||||||
|
|
||||||
add_actor_information_and_train(
|
add_actor_information_and_train(
|
||||||
cfg,
|
cfg=cfg,
|
||||||
logger,
|
wandb_logger=wandb_logger,
|
||||||
out_dir,
|
shutdown_event=shutdown_event,
|
||||||
shutdown_event,
|
transition_queue=transition_queue,
|
||||||
transition_queue,
|
interaction_message_queue=interaction_message_queue,
|
||||||
interaction_message_queue,
|
parameters_queue=parameters_queue,
|
||||||
parameters_queue,
|
|
||||||
)
|
)
|
||||||
logging.info("[LEARNER] Training process stopped")
|
logging.info("[LEARNER] Training process stopped")
|
||||||
|
|
||||||
|
@ -286,7 +374,7 @@ def start_learner_server(
|
||||||
transition_queue: Queue,
|
transition_queue: Queue,
|
||||||
interaction_message_queue: Queue,
|
interaction_message_queue: Queue,
|
||||||
shutdown_event: any, # Event,
|
shutdown_event: any, # Event,
|
||||||
cfg: DictConfig,
|
cfg: TrainPipelineConfig,
|
||||||
):
|
):
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
# We need init logging for MP separataly
|
# We need init logging for MP separataly
|
||||||
|
@ -298,11 +386,11 @@ def start_learner_server(
|
||||||
setup_process_handlers(False)
|
setup_process_handlers(False)
|
||||||
|
|
||||||
service = learner_service.LearnerService(
|
service = learner_service.LearnerService(
|
||||||
shutdown_event,
|
shutdown_event=shutdown_event,
|
||||||
parameters_queue,
|
parameters_queue=parameters_queue,
|
||||||
cfg.actor_learner_config.policy_parameters_push_frequency,
|
seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
|
||||||
transition_queue,
|
transition_queue=transition_queue,
|
||||||
interaction_message_queue,
|
interaction_message_queue=interaction_message_queue,
|
||||||
)
|
)
|
||||||
|
|
||||||
server = grpc.server(
|
server = grpc.server(
|
||||||
|
@ -318,8 +406,8 @@ def start_learner_server(
|
||||||
server,
|
server,
|
||||||
)
|
)
|
||||||
|
|
||||||
host = cfg.actor_learner_config.learner_host
|
host = cfg.policy.actor_learner_config["learner_host"]
|
||||||
port = cfg.actor_learner_config.learner_port
|
port = cfg.policy.actor_learner_config["learner_port"]
|
||||||
|
|
||||||
server.add_insecure_port(f"{host}:{port}")
|
server.add_insecure_port(f"{host}:{port}")
|
||||||
server.start()
|
server.start()
|
||||||
|
@ -385,9 +473,8 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def add_actor_information_and_train(
|
def add_actor_information_and_train(
|
||||||
cfg,
|
cfg: TrainPipelineConfig,
|
||||||
logger: Logger,
|
wandb_logger: WandBLogger | None,
|
||||||
out_dir: str,
|
|
||||||
shutdown_event: any, # Event,
|
shutdown_event: any, # Event,
|
||||||
transition_queue: Queue,
|
transition_queue: Queue,
|
||||||
interaction_message_queue: Queue,
|
interaction_message_queue: Queue,
|
||||||
|
@ -405,69 +492,60 @@ def add_actor_information_and_train(
|
||||||
- Periodically updates the actor, critic, and temperature optimizers.
|
- Periodically updates the actor, critic, and temperature optimizers.
|
||||||
- Logs training statistics, including loss values and optimization frequency.
|
- Logs training statistics, including loss values and optimization frequency.
|
||||||
|
|
||||||
**NOTE:**
|
|
||||||
- This function performs multiple responsibilities (data transfer, training, and logging).
|
|
||||||
It should ideally be split into smaller functions in the future.
|
|
||||||
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
|
|
||||||
significantly reduces performance. Instead, this function executes all operations in a single thread.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Configuration object containing hyperparameters.
|
cfg (TrainPipelineConfig): Configuration object containing hyperparameters.
|
||||||
device (str): The computing device (`"cpu"` or `"cuda"`).
|
wandb_logger (WandBLogger | None): Logger for tracking training progress.
|
||||||
logger (Logger): Logger instance for tracking training progress.
|
|
||||||
out_dir (str): The output directory for storing training checkpoints and logs.
|
|
||||||
shutdown_event (Event): Event to signal shutdown.
|
shutdown_event (Event): Event to signal shutdown.
|
||||||
transition_queue (Queue): Queue for receiving transitions from the actor.
|
transition_queue (Queue): Queue for receiving transitions from the actor.
|
||||||
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
|
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
|
||||||
parameters_queue (Queue): Queue for sending policy parameters to the actor.
|
parameters_queue (Queue): Queue for sending policy parameters to the actor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device = get_safe_torch_device(cfg.device, log=True)
|
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
|
||||||
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
|
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
|
||||||
|
|
||||||
logging.info("Initializing policy")
|
logging.info("Initializing policy")
|
||||||
### Instantiate the policy in both the actor and learner processes
|
# Get checkpoint dir for resuming
|
||||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
|
||||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
|
||||||
# TODO: At some point we should just need make sac policy
|
|
||||||
|
|
||||||
|
# TODO(Adil): This doesn't work anymore!
|
||||||
policy: SACPolicy = make_policy(
|
policy: SACPolicy = make_policy(
|
||||||
hydra_cfg=cfg,
|
cfg=cfg.policy,
|
||||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
# ds_meta=cfg.dataset,
|
||||||
# Hack: But if we do online traning, we do not need dataset_stats
|
env_cfg=cfg.env
|
||||||
dataset_stats=None,
|
|
||||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the policy config with the grad_clip_norm value from training config if it exists
|
# Update the policy config with the grad_clip_norm value from training config if it exists
|
||||||
clip_grad_norm_value = cfg.training.grad_clip_norm
|
clip_grad_norm_value: float = cfg.policy.grad_clip_norm
|
||||||
|
|
||||||
# compile policy
|
# compile policy
|
||||||
policy = torch.compile(policy)
|
policy = torch.compile(policy)
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
policy.train()
|
||||||
|
|
||||||
push_actor_policy_to_queue(parameters_queue, policy)
|
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||||
|
|
||||||
last_time_policy_pushed = time.time()
|
last_time_policy_pushed = time.time()
|
||||||
|
|
||||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
|
||||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
|
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||||
|
|
||||||
log_training_info(cfg, out_dir, policy)
|
log_training_info(cfg=cfg, policy=policy)
|
||||||
|
|
||||||
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
|
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
|
||||||
batch_size = cfg.training.batch_size
|
batch_size = cfg.batch_size
|
||||||
offline_replay_buffer = None
|
offline_replay_buffer = None
|
||||||
|
|
||||||
if cfg.dataset_repo_id is not None:
|
if cfg.dataset is not None:
|
||||||
active_action_dims = None
|
active_action_dims = None
|
||||||
|
# TODO: FIX THIS
|
||||||
if cfg.env.wrapper.joint_masking_action_space is not None:
|
if cfg.env.wrapper.joint_masking_action_space is not None:
|
||||||
active_action_dims = [
|
active_action_dims = [
|
||||||
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
|
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
|
||||||
]
|
]
|
||||||
offline_replay_buffer = initialize_offline_replay_buffer(
|
offline_replay_buffer = initialize_offline_replay_buffer(
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
logger=logger,
|
|
||||||
device=device,
|
device=device,
|
||||||
storage_device=storage_device,
|
storage_device=storage_device,
|
||||||
active_action_dims=active_action_dims,
|
active_action_dims=active_action_dims,
|
||||||
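The active_action_dims comprehension above converts the boolean joint mask from cfg.env.wrapper.joint_masking_action_space into the list of action indices that remain active. A tiny standalone example of the same computation, with a made-up mask:

    joint_masking_action_space = [True, True, False, True, False, True]  # hypothetical mask
    active_action_dims = [i for i, mask in enumerate(joint_masking_action_space) if mask]
    assert active_action_dims == [0, 1, 3, 5]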
|
@@ -484,18 +562,22 @@ def add_actor_information_and_train(
|
||||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||||
|
|
||||||
# Extract variables from cfg
|
# Extract variables from cfg
|
||||||
online_step_before_learning = cfg.training.online_step_before_learning
|
online_step_before_learning = cfg.policy.online_step_before_learning
|
||||||
utd_ratio = cfg.policy.utd_ratio
|
utd_ratio = cfg.policy.utd_ratio
|
||||||
dataset_repo_id = cfg.dataset_repo_id
|
|
||||||
fps = cfg.fps
|
dataset_repo_id = None
|
||||||
log_freq = cfg.training.log_freq
|
if cfg.dataset is not None:
|
||||||
save_freq = cfg.training.save_freq
|
dataset_repo_id = cfg.dataset.repo_id
|
||||||
device = cfg.device
|
|
||||||
storage_device = cfg.training.storage_device
|
fps = cfg.env.fps
|
||||||
policy_update_freq = cfg.training.policy_update_freq
|
log_freq = cfg.log_freq
|
||||||
policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency
|
save_freq = cfg.save_freq
|
||||||
save_checkpoint = cfg.training.save_checkpoint
|
device = cfg.policy.device
|
||||||
online_steps = cfg.training.online_steps
|
storage_device = cfg.policy.storage_device
|
||||||
|
policy_update_freq = cfg.policy.policy_update_freq
|
||||||
|
policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
|
||||||
|
save_checkpoint = cfg.save_checkpoint
|
||||||
|
online_steps = cfg.policy.online_steps
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if shutdown_event is not None and shutdown_event.is_set():
|
if shutdown_event is not None and shutdown_event.is_set():
|
||||||
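The block above reads its hyperparameters from the new config layout: per-policy fields under cfg.policy, environment fields under cfg.env, the optional cfg.dataset, and top-level training fields such as batch_size, log_freq and save_freq. The sketch below only illustrates the shape this code assumes; the dataclass names and default values are invented for the example and are not the actual LeRobot definitions.

    from dataclasses import dataclass, field


    @dataclass
    class _PolicyConfigSketch:  # hypothetical stand-in for the SAC policy config
        device: str = "cuda"
        storage_device: str = "cpu"
        online_step_before_learning: int = 100
        online_steps: int = 1_000_000
        utd_ratio: int = 1
        policy_update_freq: int = 1
        grad_clip_norm: float = 10.0
        actor_learner_config: dict = field(
            default_factory=lambda: {
                "policy_parameters_push_frequency": 15,
                "learner_host": "127.0.0.1",
                "learner_port": 50051,
            }
        )


    @dataclass
    class _TrainConfigSketch:  # hypothetical stand-in for TrainPipelineConfig
        policy: _PolicyConfigSketch = field(default_factory=_PolicyConfigSketch)
        dataset: object | None = None
        batch_size: int = 256
        log_freq: int = 100
        save_freq: int = 1_000
        save_checkpoint: bool = True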
|
@@ -516,7 +598,7 @@ def add_actor_information_and_train(
|
||||||
continue
|
continue
|
||||||
replay_buffer.add(**transition)
|
replay_buffer.add(**transition)
|
||||||
|
|
||||||
if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
|
||||||
"is_intervention"
|
"is_intervention"
|
||||||
):
|
):
|
||||||
offline_replay_buffer.add(**transition)
|
offline_replay_buffer.add(**transition)
|
||||||
|
@@ -528,7 +610,17 @@ def add_actor_information_and_train(
|
||||||
interaction_message = bytes_to_python_object(interaction_message)
|
interaction_message = bytes_to_python_object(interaction_message)
|
||||||
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
||||||
interaction_message["Interaction step"] += interaction_step_shift
|
interaction_message["Interaction step"] += interaction_step_shift
|
||||||
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
|
|
||||||
|
# Log interaction messages with WandB if available
|
||||||
|
if wandb_logger:
|
||||||
|
wandb_logger.log_dict(
|
||||||
|
d=interaction_message,
|
||||||
|
mode="train",
|
||||||
|
custom_step_key="Interaction step"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Log to console if no WandB logger
|
||||||
|
logging.info(f"Interaction: {interaction_message}")
|
||||||
|
|
||||||
logging.debug("[LEARNER] Received interactions")
|
logging.debug("[LEARNER] Received interactions")
|
||||||
|
|
||||||
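The branch above, like the training-metrics branch later in this hunk series, repeats the same pattern: log to WandB when a logger was configured, otherwise fall back to the console. If that were factored out, a helper along these lines would cover both call sites; the helper itself is hypothetical, while log_dict's keyword arguments are taken from the calls shown in this diff.

    import logging


    def log_metrics_sketch(wandb_logger, metrics: dict, mode: str, custom_step_key: str) -> None:
        """Send metrics to WandB when available, otherwise to the console."""
        if wandb_logger is not None:
            wandb_logger.log_dict(d=metrics, mode=mode, custom_step_key=custom_step_key)
        else:
            logging.info(f"{mode}: {metrics}")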
|
@@ -538,11 +630,11 @@ def add_actor_information_and_train(
|
||||||
logging.debug("[LEARNER] Starting optimization loop")
|
logging.debug("[LEARNER] Starting optimization loop")
|
||||||
time_for_one_optimization_step = time.time()
|
time_for_one_optimization_step = time.time()
|
||||||
for _ in range(utd_ratio - 1):
|
for _ in range(utd_ratio - 1):
|
||||||
batch = replay_buffer.sample(batch_size)
|
batch = replay_buffer.sample(batch_size=batch_size)
|
||||||
|
|
||||||
if dataset_repo_id is not None:
|
if dataset_repo_id is not None:
|
||||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
|
||||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)
|
||||||
|
|
||||||
actions = batch["action"]
|
actions = batch["action"]
|
||||||
rewards = batch["reward"]
|
rewards = batch["reward"]
|
||||||
|
@@ -552,7 +644,7 @@ def add_actor_information_and_train(
|
||||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||||
|
|
||||||
observation_features, next_observation_features = get_observation_features(
|
observation_features, next_observation_features = get_observation_features(
|
||||||
policy, observations, next_observations
|
policy=policy, observations=observations, next_observations=next_observations
|
||||||
)
|
)
|
||||||
loss_critic = policy.compute_loss_critic(
|
loss_critic = policy.compute_loss_critic(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
|
@@ -568,15 +660,15 @@ def add_actor_information_and_train(
|
||||||
|
|
||||||
# clip gradients
|
# clip gradients
|
||||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.critic_ensemble.parameters(), clip_grad_norm_value
|
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizers["critic"].step()
|
optimizers["critic"].step()
|
||||||
|
|
||||||
batch = replay_buffer.sample(batch_size)
|
batch = replay_buffer.sample(batch_size=batch_size)
|
||||||
|
|
||||||
if dataset_repo_id is not None:
|
if dataset_repo_id is not None:
|
||||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
|
||||||
batch = concatenate_batch_transitions(
|
batch = concatenate_batch_transitions(
|
||||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||||
)
|
)
|
||||||
|
@@ -590,7 +682,7 @@ def add_actor_information_and_train(
|
||||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||||
|
|
||||||
observation_features, next_observation_features = get_observation_features(
|
observation_features, next_observation_features = get_observation_features(
|
||||||
policy, observations, next_observations
|
policy=policy, observations=observations, next_observations=next_observations
|
||||||
)
|
)
|
||||||
loss_critic = policy.compute_loss_critic(
|
loss_critic = policy.compute_loss_critic(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
|
@@ -606,7 +698,7 @@ def add_actor_information_and_train(
|
||||||
|
|
||||||
# clip gradients
|
# clip gradients
|
||||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.critic_ensemble.parameters(), clip_grad_norm_value
|
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||||
).item()
|
).item()
|
||||||
|
|
||||||
optimizers["critic"].step()
|
optimizers["critic"].step()
|
||||||
|
@@ -627,7 +719,7 @@ def add_actor_information_and_train(
|
||||||
|
|
||||||
# clip gradients
|
# clip gradients
|
||||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.actor.parameters_to_optimize, clip_grad_norm_value
|
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||||
).item()
|
).item()
|
||||||
|
|
||||||
optimizers["actor"].step()
|
optimizers["actor"].step()
|
||||||
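All three clipping calls in this file now pass keyword arguments to torch.nn.utils.clip_grad_norm_, which clips gradients in place and returns the total gradient norm as a tensor, hence the .item() where the value is logged. A standalone toy example of that behaviour:

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    total_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1.0)
    print(total_norm.item())  # norm of the gradients as measured before clipping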
|
@@ -645,7 +737,7 @@ def add_actor_information_and_train(
|
||||||
|
|
||||||
# clip gradients
|
# clip gradients
|
||||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
[policy.log_alpha], clip_grad_norm_value
|
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
|
||||||
).item()
|
).item()
|
||||||
|
|
||||||
optimizers["temperature"].step()
|
optimizers["temperature"].step()
|
||||||
|
@@ -655,7 +747,7 @@ def add_actor_information_and_train(
|
||||||
training_infos["temperature"] = policy.temperature
|
training_infos["temperature"] = policy.temperature
|
||||||
|
|
||||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||||
push_actor_policy_to_queue(parameters_queue, policy)
|
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||||
last_time_policy_pushed = time.time()
|
last_time_policy_pushed = time.time()
|
||||||
|
|
||||||
policy.update_target_networks()
|
policy.update_target_networks()
|
||||||
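push_actor_policy_to_queue is defined earlier in the file and is not shown in this hunk. Purely to illustrate the data flow, one plausible implementation serializes the state dict to bytes so it can cross the multiprocessing queue and, later, the gRPC channel; this sketch is not the repository's actual code.

    import io

    import torch
    from torch import nn


    def push_policy_sketch(parameters_queue, policy: nn.Module) -> None:
        """Hypothetical illustration: serialize the policy weights and enqueue them for the actor."""
        buffer = io.BytesIO()
        torch.save(policy.state_dict(), buffer)
        parameters_queue.put(buffer.getvalue())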
|
@@ -665,15 +757,26 @@ def add_actor_information_and_train(
|
||||||
if offline_replay_buffer is not None:
|
if offline_replay_buffer is not None:
|
||||||
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
|
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
|
||||||
training_infos["Optimization step"] = optimization_step
|
training_infos["Optimization step"] = optimization_step
|
||||||
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
|
|
||||||
# logging.info(f"Training infos: {training_infos}")
|
# Log training metrics
|
||||||
|
if wandb_logger:
|
||||||
|
wandb_logger.log_dict(
|
||||||
|
d=training_infos,
|
||||||
|
mode="train",
|
||||||
|
custom_step_key="Optimization step"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Log to console if no WandB logger
|
||||||
|
logging.info(f"Training: {training_infos}")
|
||||||
|
|
||||||
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||||
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
||||||
|
|
||||||
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
|
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
|
||||||
|
|
||||||
logger.log_dict(
|
# Log optimization frequency
|
||||||
|
if wandb_logger:
|
||||||
|
wandb_logger.log_dict(
|
||||||
{
|
{
|
||||||
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
|
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
|
||||||
"Optimization step": optimization_step,
|
"Optimization step": optimization_step,
|
||||||
|
@@ -693,35 +796,45 @@ def add_actor_information_and_train(
|
||||||
interaction_step = (
|
interaction_step = (
|
||||||
interaction_message["Interaction step"] if interaction_message is not None else 0
|
interaction_message["Interaction step"] if interaction_message is not None else 0
|
||||||
)
|
)
|
||||||
logger.save_checkpoint(
|
|
||||||
|
# Create checkpoint directory
|
||||||
|
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
save_checkpoint(
|
||||||
|
checkpoint_dir,
|
||||||
optimization_step,
|
optimization_step,
|
||||||
|
cfg,
|
||||||
policy,
|
policy,
|
||||||
optimizers,
|
optimizers,
|
||||||
scheduler=None,
|
scheduler=None
|
||||||
identifier=step_identifier,
|
|
||||||
interaction_step=interaction_step,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update the "last" symlink
|
||||||
|
update_last_checkpoint(checkpoint_dir)
|
||||||
|
|
||||||
# TODO: temporarily save the replay buffer here; remove later when running on the robot
|
# TODO: temporarily save the replay buffer here; remove later when running on the robot
|
||||||
# We want to control this with the keyboard inputs
|
# We want to control this with the keyboard inputs
|
||||||
dataset_dir = logger.log_dir / "dataset"
|
dataset_dir = os.path.join(cfg.output_dir, "dataset")
|
||||||
if dataset_dir.exists() and dataset_dir.is_dir():
|
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
|
||||||
shutil.rmtree(
|
shutil.rmtree(dataset_dir)
|
||||||
dataset_dir,
|
|
||||||
)
|
|
||||||
replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset")
|
|
||||||
if offline_replay_buffer is not None:
|
|
||||||
dataset_dir = logger.log_dir / "dataset_offline"
|
|
||||||
|
|
||||||
if dataset_dir.exists() and dataset_dir.is_dir():
|
# Save dataset
|
||||||
shutil.rmtree(
|
replay_buffer.to_lerobot_dataset(
|
||||||
dataset_dir,
|
dataset_repo_id,
|
||||||
|
fps=fps,
|
||||||
|
root=dataset_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if offline_replay_buffer is not None:
|
||||||
|
dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
|
||||||
|
if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
|
||||||
|
shutil.rmtree(dataset_offline_dir)
|
||||||
|
|
||||||
offline_replay_buffer.to_lerobot_dataset(
|
offline_replay_buffer.to_lerobot_dataset(
|
||||||
cfg.dataset_repo_id,
|
cfg.dataset.dataset_repo_id,
|
||||||
fps=cfg.fps,
|
fps=cfg.env.fps,
|
||||||
root=logger.log_dir / "dataset_offline",
|
root=dataset_offline_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
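The checkpointing code above deletes any stale dataset/dataset_offline directory before calling to_lerobot_dataset, presumably because the exporter expects to create its target directory from scratch. Grouping the wipe-and-export steps would look roughly like this; the helper name is invented, and to_lerobot_dataset's arguments are copied from the calls in this diff.

    import os
    import shutil


    def export_buffer_sketch(buffer, repo_id: str, fps: int, root: str) -> None:
        """Hypothetical helper: clear the target directory, then export a replay buffer as a dataset."""
        if os.path.exists(root) and os.path.isdir(root):
            shutil.rmtree(root)
        buffer.to_lerobot_dataset(repo_id, fps=fps, root=root)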
|
@@ -756,12 +869,12 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
||||||
optimizer_actor = torch.optim.Adam(
|
optimizer_actor = torch.optim.Adam(
|
||||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||||
params=policy.actor.parameters_to_optimize,
|
params=policy.actor.parameters_to_optimize,
|
||||||
lr=policy.config.actor_lr,
|
lr=cfg.policy.actor_lr,
|
||||||
)
|
)
|
||||||
optimizer_critic = torch.optim.Adam(
|
optimizer_critic = torch.optim.Adam(
|
||||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
|
||||||
)
|
)
|
||||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
optimizers = {
|
optimizers = {
|
||||||
"actor": optimizer_actor,
|
"actor": optimizer_actor,
|
||||||
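make_optimizers_and_scheduler now takes the learning rates from cfg.policy instead of policy.config and still returns a plain dict keyed by "actor", "critic" and "temperature", plus a None scheduler; note that the temperature optimizer reuses critic_lr both before and after this change. The toy snippet below reproduces that layout on stand-in modules (the 3e-4 learning rate is a placeholder, not a value from the config):

    import torch
    from torch import nn

    actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)  # stand-ins for the actor and critic ensemble
    log_alpha = torch.zeros(1, requires_grad=True)

    optimizers = {
        "actor": torch.optim.Adam(params=actor.parameters(), lr=3e-4),
        "critic": torch.optim.Adam(params=critic.parameters(), lr=3e-4),
        "temperature": torch.optim.Adam(params=[log_alpha], lr=3e-4),
    }
    lr_scheduler = None  # the function currently returns no scheduler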
|
@@ -771,19 +884,38 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
||||||
return optimizers, lr_scheduler
|
return optimizers, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
|
||||||
if out_dir is None:
|
"""
|
||||||
raise NotImplementedError()
|
Main training function that initializes and runs the training process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (TrainPipelineConfig): The training configuration
|
||||||
|
job_name (str | None, optional): Job name for logging. Defaults to None.
|
||||||
|
"""
|
||||||
|
if cfg.output_dir is None:
|
||||||
|
raise ValueError("Output directory must be specified in config")
|
||||||
|
|
||||||
if job_name is None:
|
if job_name is None:
|
||||||
raise NotImplementedError()
|
job_name = cfg.job_name
|
||||||
|
|
||||||
|
if job_name is None:
|
||||||
|
raise ValueError("Job name must be specified either in config or as a parameter")
|
||||||
|
|
||||||
init_logging()
|
init_logging()
|
||||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
logging.info(pformat(cfg.to_dict()))
|
||||||
|
|
||||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
# Setup WandB logging if enabled
|
||||||
cfg = handle_resume_logic(cfg, out_dir)
|
if cfg.wandb.enable and cfg.wandb.project:
|
||||||
|
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||||
|
wandb_logger = WandBLogger(cfg)
|
||||||
|
else:
|
||||||
|
wandb_logger = None
|
||||||
|
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||||
|
|
||||||
set_global_seed(cfg.seed)
|
# Handle resume logic
|
||||||
|
cfg = handle_resume_logic(cfg)
|
||||||
|
|
||||||
|
set_seed(seed=cfg.seed)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@@ -791,24 +923,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||||
|
|
||||||
start_learner_threads(
|
start_learner_threads(
|
||||||
cfg,
|
cfg=cfg,
|
||||||
logger,
|
wandb_logger=wandb_logger,
|
||||||
out_dir,
|
shutdown_event=shutdown_event,
|
||||||
shutdown_event,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
@parser.wrap()
|
||||||
def train_cli(cfg: dict):
|
def train_cli(cfg: TrainPipelineConfig):
|
||||||
|
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
mp.set_start_method("spawn")
|
mp.set_start_method("spawn")
|
||||||
|
|
||||||
|
# Use the job_name from the config
|
||||||
train(
|
train(
|
||||||
cfg,
|
cfg,
|
||||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
job_name=cfg.job_name,
|
||||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("[LEARNER] train_cli finished")
|
logging.info("[LEARNER] train_cli finished")
|
||||||
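train_cli sets the multiprocessing start method to "spawn" before creating the learner processes; the default "fork" method on Linux is unsafe once CUDA has been, or will be, initialized in a child. A minimal standalone illustration of the guard:

    import torch.multiprocessing as mp


    def _worker(rank: int) -> None:
        print(f"worker {rank} started")


    if __name__ == "__main__":
        mp.set_start_method("spawn")  # each child gets a fresh interpreter, which CUDA requires
        p = mp.Process(target=_worker, args=(0,))
        p.start()
        p.join()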
|
@@ -816,5 +947,4 @@ def train_cli(cfg: dict):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_cli()
|
train_cli()
|
||||||
|
|
||||||
logging.info("[LEARNER] main finished")
|
logging.info("[LEARNER] main finished")
|
||||||
|
|
|
@@ -1,4 +1,7 @@
|
||||||
from typing import Any
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
@@ -6,7 +9,11 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mani_skill.utils.wrappers.record import RecordEpisode
|
from mani_skill.utils.wrappers.record import RecordEpisode
|
||||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||||
from omegaconf import DictConfig
|
|
||||||
|
from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
|
||||||
|
from lerobot.configs import parser
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
|
||||||
|
|
||||||
|
|
||||||
def preprocess_maniskill_observation(
|
def preprocess_maniskill_observation(
|
||||||
|
@@ -46,9 +53,14 @@ def preprocess_maniskill_observation(
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
||||||
def __init__(self, env, device: torch.device = "cuda"):
|
def __init__(self, env, device: torch.device = "cuda"):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
if isinstance(device, str):
|
||||||
|
device = torch.device(device)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
|
@@ -108,76 +120,129 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
|
||||||
|
class BatchCompatibleWrapper(gym.ObservationWrapper):
|
||||||
|
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
|
||||||
|
def __init__(self, env):
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||||
|
for key in observation:
|
||||||
|
if "image" in key and observation[key].dim() == 3:
|
||||||
|
observation[key] = observation[key].unsqueeze(0)
|
||||||
|
if "state" in key and observation[key].dim() == 1:
|
||||||
|
observation[key] = observation[key].unsqueeze(0)
|
||||||
|
return observation
|
||||||
|
|
||||||
|
|
||||||
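BatchCompatibleWrapper simply promotes single observations to batched form so downstream policies always see a leading batch dimension. The snippet below checks that behaviour on plain tensors, mirroring the unsqueeze logic above without needing an environment:

    import torch

    obs = {
        "observation.image": torch.zeros(3, 64, 64),  # (C, H, W) -> expect (1, C, H, W)
        "observation.state": torch.zeros(25),         # (D,)      -> expect (1, D)
    }
    for key in obs:
        if "image" in key and obs[key].dim() == 3:
            obs[key] = obs[key].unsqueeze(0)
        if "state" in key and obs[key].dim() == 1:
            obs[key] = obs[key].unsqueeze(0)

    assert obs["observation.image"].shape == (1, 3, 64, 64)
    assert obs["observation.state"].shape == (1, 25)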
|
class TimeLimitWrapper(gym.Wrapper):
|
||||||
|
"""Adds a time limit to the environment based on fps and control_time."""
|
||||||
|
def __init__(self, env, control_time_s, fps):
|
||||||
|
super().__init__(env)
|
||||||
|
self.control_time_s = control_time_s
|
||||||
|
self.fps = fps
|
||||||
|
self.max_episode_steps = int(self.control_time_s * self.fps)
|
||||||
|
self.current_step = 0
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
|
self.current_step += 1
|
||||||
|
|
||||||
|
if self.current_step >= self.max_episode_steps:
|
||||||
|
terminated = True
|
||||||
|
|
||||||
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def reset(self, *, seed=None, options=None):
|
||||||
|
self.current_step = 0
|
||||||
|
return super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
|
|
||||||
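TimeLimitWrapper derives its step budget from wall-clock seconds and the control frequency (max_episode_steps = control_time_s * fps), then flags termination once that many steps have elapsed; control_time_s=20 at fps=10 gives a 200-step budget. A usage sketch with the wrapper defined above in scope; the CartPole task id is only an example:

    import gymnasium as gym

    env = TimeLimitWrapper(gym.make("CartPole-v1"), control_time_s=20, fps=10)  # 20 s * 10 Hz = 200 steps

    obs, info = env.reset()
    done = False
    while not done:
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        done = terminated or truncated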
def make_maniskill(
|
def make_maniskill(
|
||||||
cfg: DictConfig,
|
cfg: ManiskillEnvConfig,
|
||||||
n_envs: int | None = None,
|
n_envs: int | None = None,
|
||||||
) -> gym.Env:
|
) -> gym.Env:
|
||||||
"""
|
"""
|
||||||
Factory function to create a ManiSkill environment with standard wrappers.
|
Factory function to create a ManiSkill environment with standard wrappers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task: Name of the ManiSkill task
|
cfg: Configuration for the ManiSkill environment
|
||||||
obs_mode: Observation mode (rgb, rgbd, etc)
|
|
||||||
control_mode: Control mode for the robot
|
|
||||||
render_mode: Rendering mode
|
|
||||||
sensor_configs: Camera sensor configurations
|
|
||||||
n_envs: Number of parallel environments
|
n_envs: Number of parallel environments
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A wrapped ManiSkill environment
|
A wrapped ManiSkill environment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
env = gym.make(
|
env = gym.make(
|
||||||
cfg.env.task,
|
cfg.task,
|
||||||
obs_mode=cfg.env.obs,
|
obs_mode=cfg.obs_type,
|
||||||
control_mode=cfg.env.control_mode,
|
control_mode=cfg.control_mode,
|
||||||
render_mode=cfg.env.render_mode,
|
render_mode=cfg.render_mode,
|
||||||
sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
|
sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
|
||||||
num_envs=n_envs,
|
num_envs=n_envs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.env.video_record.enabled:
|
# Add video recording if enabled
|
||||||
|
if cfg.video_record.enabled:
|
||||||
env = RecordEpisode(
|
env = RecordEpisode(
|
||||||
env,
|
env,
|
||||||
output_dir=cfg.env.video_record.record_dir,
|
output_dir=cfg.video_record.record_dir,
|
||||||
save_trajectory=True,
|
save_trajectory=True,
|
||||||
trajectory_name=cfg.env.video_record.trajectory_name,
|
trajectory_name=cfg.video_record.trajectory_name,
|
||||||
save_video=True,
|
save_video=True,
|
||||||
video_fps=30,
|
video_fps=30,
|
||||||
)
|
)
|
||||||
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
|
|
||||||
|
# Add observation and image processing
|
||||||
|
env = ManiSkillObservationWrapper(env, device=cfg.device)
|
||||||
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
||||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
|
||||||
env.unwrapped.metadata["render_fps"] = 20
|
env.unwrapped.metadata["render_fps"] = cfg.fps
|
||||||
|
|
||||||
|
# Add compatibility wrappers
|
||||||
env = ManiSkillCompat(env)
|
env = ManiSkillCompat(env)
|
||||||
env = ManiSkillActionWrapper(env)
|
env = ManiSkillActionWrapper(env)
|
||||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)
|
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
@parser.wrap()
|
||||||
import argparse
|
def main(cfg: ManiskillEnvConfig):
|
||||||
|
"""Main function to run the ManiSkill environment."""
|
||||||
|
# Create the ManiSkill environment
|
||||||
|
env = make_maniskill(cfg, n_envs=1)
|
||||||
|
|
||||||
import hydra
|
# Reset the environment
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Initialize config
|
|
||||||
with hydra.initialize(version_base=None, config_path="../../configs"):
|
|
||||||
cfg = hydra.compose(config_name="env/maniskill_example.yaml")
|
|
||||||
|
|
||||||
env = make_maniskill(
|
|
||||||
task=cfg.env.task,
|
|
||||||
obs_mode=cfg.env.obs,
|
|
||||||
control_mode=cfg.env.control_mode,
|
|
||||||
render_mode=cfg.env.render_mode,
|
|
||||||
sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size},
|
|
||||||
)
|
|
||||||
|
|
||||||
print("env done")
|
|
||||||
obs, info = env.reset()
|
obs, info = env.reset()
|
||||||
random_action = env.action_space.sample()
|
|
||||||
obs, reward, terminated, truncated, info = env.step(random_action)
|
# Run a simple interaction loop
|
||||||
|
sum_reward = 0
|
||||||
|
for i in range(100):
|
||||||
|
# Sample a random action
|
||||||
|
action = env.action_space.sample()
|
||||||
|
|
||||||
|
# Step the environment
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
step_time = time.perf_counter() - start_time
|
||||||
|
sum_reward += reward
|
||||||
|
# Log information
|
||||||
|
|
||||||
|
# Reset if episode terminated
|
||||||
|
if terminated or truncated:
|
||||||
|
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
|
||||||
|
sum_reward = 0
|
||||||
|
obs, info = env.reset()
|
||||||
|
|
||||||
|
# Close the environment
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# logging.basicConfig(level=logging.INFO)
|
||||||
|
# main()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import draccus
|
||||||
|
config = ManiskillEnvConfig()
|
||||||
|
draccus.set_config_type("json")
|
||||||
|
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'))
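The __main__ block above writes a default ManiskillEnvConfig to run_config.json via draccus. Assuming draccus lays the dataclass fields out as top-level JSON keys, a quick sanity check with nothing but the standard library would be:

    import json

    with open("run_config.json") as f:
        dumped = json.load(f)

    print(sorted(dumped))  # expected to list fields such as "task", "control_mode" and "video_record"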
|