# lerobot/lerobot/scripts/train.py
#
# 206 lines
# 6.5 KiB
# Python

import time
import hydra
import numpy as np
import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger
from lerobot.common.tdmpc import TDMPC
from lerobot.common.utils import set_seed
from lerobot.scripts.eval import eval_policy
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def train_cli(cfg: dict):
    """Command-line entry point: run ``train`` with the Hydra-provided config.

    The output directory and job name are taken from the active Hydra
    configuration (``run.dir`` and ``job.name``) and forwarded to ``train``.
    """
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    train(cfg, out_dir=hydra_cfg.run.dir, job_name=hydra_cfg.job.name)
def train_notebook(
    out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
    """Run ``train`` without the Hydra CLI decorator (e.g. from a notebook).

    The config is composed with Hydra's ``initialize``/``compose`` API from
    ``config_path``/``config_name`` and passed to ``train`` together with
    ``out_dir`` and ``job_name``.
    """
    from hydra import compose, initialize

    # Clear the global Hydra instance before calling initialize() again.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(config_path=config_path)
    train(
        compose(config_name=config_name),
        out_dir=out_dir,
        job_name=job_name,
    )
def train(cfg: dict, out_dir=None, job_name=None):
    """Train a TDMPC policy: offline updates first, then online rollouts.

    While ``step < cfg.offline_steps`` each loop iteration performs
    ``cfg.episode_length`` updates from the offline buffer. Afterwards each
    iteration collects one rollout from the environment, appends it to the
    online replay buffer and performs ``len(rollout) * cfg.utd`` updates.
    Training metrics are logged every iteration; evaluation and checkpointing
    happen periodically, measured in environment steps
    (``step * cfg.action_repeat``).

    Args:
        cfg: Experiment configuration, accessed by attribute (``cfg.seed``,
            ``cfg.train_steps``, ...). It is also handed to ``make_env``,
            ``TDMPC``, ``make_offline_buffer`` and ``Logger``.
        out_dir: Output directory given to the logger. Required.
        job_name: Job name given to the logger. Required.

    Raises:
        NotImplementedError: If ``out_dir`` or ``job_name`` is ``None``, or if
            ``cfg.pretrained_model_path`` contains ``"fowm"`` but neither
            ``"offline"`` nor ``"final"``.
        AssertionError: If CUDA is not available, or a rollout is longer than
            ``cfg.episode_length``.
    """
    if out_dir is None:
        raise NotImplementedError("train() requires an explicit out_dir")
    if job_name is None:
        raise NotImplementedError("train() requires an explicit job_name")

    assert torch.cuda.is_available()
    set_seed(cfg.seed)
    print(colored("Work dir:", "yellow", attrs=["bold"]), out_dir)

    env = make_env(cfg)
    policy = TDMPC(cfg)
    if cfg.pretrained_model_path:
        # TODO(rcadene): hack for old pretrained models from fowm
        if "fowm" in cfg.pretrained_model_path:
            if "offline" in cfg.pretrained_model_path:
                policy.step = 25000
            elif "final" in cfg.pretrained_model_path:
                policy.step = 100000
            else:
                raise NotImplementedError(
                    "cannot infer policy.step from fowm checkpoint path: "
                    f"{cfg.pretrained_model_path}"
                )
        policy.load(cfg.pretrained_model_path)

    td_policy = TensorDictModule(
        policy,
        in_keys=["observation", "step_count"],
        out_keys=["action"],
    )

    # initialize offline dataset
    offline_buffer = make_offline_buffer(cfg)

    # The online phase below always reads and writes `online_buffer`, and it
    # explicitly supports cfg.balanced_sampling being false (demo_buffer=None).
    # Building the buffer only under balanced sampling therefore left it
    # undefined in the non-balanced online case. Build it whenever it was built
    # before, or whenever the online phase can be reached.
    if cfg.balanced_sampling or cfg.offline_steps < cfg.train_steps:
        online_sampler = PrioritizedSliceSampler(
            max_capacity=100_000,
            alpha=cfg.per_alpha,
            beta=cfg.per_beta,
            num_slices=cfg.batch_size,
            strict_length=True,
        )
        online_buffer = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(100_000),
            sampler=online_sampler,
        )

    logger = Logger(out_dir, job_name, cfg)

    online_episode_idx = 0
    start_time = time.time()
    step = 0
    last_log_step = 0
    last_save_step = 0

    while step < cfg.train_steps:
        # Offline defaults: one "episode" worth of updates, no rollout metrics.
        is_offline = True
        num_updates = cfg.episode_length
        _step = step + num_updates
        rollout_metrics = {}

        if step >= cfg.offline_steps:
            is_offline = False

            # TODO: use SyncDataCollector for that?
            with torch.no_grad():
                rollout = env.rollout(
                    max_steps=cfg.episode_length,
                    policy=td_policy,
                    auto_cast_to_device=True,
                )
            assert len(rollout) <= cfg.episode_length
            # Tag every transition with the index of the episode it came from.
            rollout["episode"] = torch.tensor(
                [online_episode_idx] * len(rollout), dtype=torch.int
            )
            online_buffer.extend(rollout)

            ep_sum_reward = rollout["next", "reward"].sum()
            ep_max_reward = rollout["next", "reward"].max()
            ep_success = rollout["next", "success"].any()

            online_episode_idx += 1
            rollout_metrics = {
                "avg_sum_reward": np.nanmean(ep_sum_reward),
                "avg_max_reward": np.nanmean(ep_max_reward),
                "pc_success": np.nanmean(ep_success) * 100,
            }
            # cfg.utd updates per collected transition; the step counter
            # advances by the rollout length, capped at cfg.train_steps.
            num_updates = len(rollout) * cfg.utd
            _step = min(step + len(rollout), cfg.train_steps)

        # Update model
        for i in range(num_updates):
            if is_offline:
                train_metrics = policy.update(offline_buffer, step + i)
            else:
                # The step passed to the policy advances once every cfg.utd
                # updates.
                train_metrics = policy.update(
                    online_buffer,
                    step + i // cfg.utd,
                    demo_buffer=offline_buffer if cfg.balanced_sampling else None,
                )

        # Log training metrics (those returned by the last update above)
        env_step = int(_step * cfg.action_repeat)
        common_metrics = {
            "episode": online_episode_idx,
            "step": _step,
            "env_step": env_step,
            "total_time": time.time() - start_time,
            "is_offline": float(is_offline),
        }
        train_metrics.update(common_metrics)
        train_metrics.update(rollout_metrics)
        logger.log(train_metrics, category="train")

        # Evaluate policy periodically
        if step == 0 or env_step - last_log_step >= cfg.eval_freq:
            eval_metrics = eval_policy(
                env,
                td_policy,
                num_episodes=cfg.eval_episodes,
                env_step=env_step,
                wandb=logger._wandb,
            )
            common_metrics.update(eval_metrics)
            logger.log(common_metrics, category="eval")
            # Snap to the last multiple of eval_freq so the schedule does not
            # drift.
            last_log_step = env_step - env_step % cfg.eval_freq

        # Save model periodically
        if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
            logger.save_model(policy, identifier=env_step)
            print(f"Model has been checkpointed at step {env_step}")
            last_save_step = env_step - env_step % cfg.save_freq

        if cfg.save_model and is_offline and _step >= cfg.offline_steps:
            # save the model after offline training
            logger.save_model(policy, identifier="offline")

        step = _step
# NOTE: disabled example dataset constructions, kept for reference only.
# `SamplerWithoutReplacement` is not among the imports at the top of this file.
# dataset_d4rl = D4RLExperienceReplay(
#     dataset_id="maze2d-umaze-v1",
#     split_trajs=False,
#     batch_size=1,
#     sampler=SamplerWithoutReplacement(drop_last=False),
#     prefetch=4,
#     direct_download=True,
# )
# dataset_openx = OpenXExperienceReplay(
#     "cmu_stretch",
#     batch_size=1,
#     num_slices=1,
#     #download="force",
#     streaming=False,
#     root="data",
# )
# Script entry point: hand control to the Hydra-decorated CLI function.
if __name__ == "__main__":
    train_cli()