finish examples 2 and 3
This commit is contained in:
parent
cb6d1e0871
commit
1ed0110900
|
@ -1 +1,39 @@
|
|||
# TODO
|
||||
"""
|
||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
from lerobot.scripts.eval import eval
|
||||
|
||||
# Get a pretrained policy from the hub.
|
||||
hub_id = "lerobot/diffusion_policy_pusht_image"
|
||||
folder = Path(snapshot_download(hub_id))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
folder = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
config_path = folder / "config.yaml"
|
||||
weights_path = folder / "model.pt"
|
||||
stats_path = folder / "stats.pth" # normalization stats
|
||||
|
||||
# Override some config parameters to do with evaluation.
|
||||
overrides = [
|
||||
f"policy.pretrained_model_path={weights_path}",
|
||||
"eval_episodes=10",
|
||||
"rollout_batch_size=10",
|
||||
"device=cuda",
|
||||
]
|
||||
|
||||
# Create a Hydra config.
|
||||
cfg = init_hydra_config(config_path, overrides)
|
||||
|
||||
# Evaluate the policy and save the outputs including metrics and videos.
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
|
||||
stats_path=stats_path,
|
||||
)
|
||||
|
|
|
@ -1 +1,55 @@
|
|||
# TODO
|
||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||
|
||||
Once you have trained a model with this script, you can try to evaluate it on
|
||||
examples/2_evaluate_pretrained_policy.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
overrides = [
|
||||
"env=pusht",
|
||||
"policy=diffusion",
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
"offline_steps=5000",
|
||||
"log_freq=250",
|
||||
"device=cuda",
|
||||
]
|
||||
|
||||
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
|
||||
|
||||
policy = DiffusionPolicy(
|
||||
cfg=cfg.policy,
|
||||
cfg_device=cfg.device,
|
||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
policy.train()
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
for offline_step in trange(cfg.offline_steps):
|
||||
train_info = policy.update(offline_buffer, offline_step)
|
||||
if offline_step % cfg.log_freq == 0:
|
||||
print(train_info)
|
||||
|
||||
# Save the policy, configuration, and normalization stats for later use.
|
||||
policy.save(output_directory / "model.pt")
|
||||
OmegaConf.save(cfg, output_directory / "config.yaml")
|
||||
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
|
||||
|
|
|
@ -59,6 +59,8 @@ class AbstractDataset(TensorDictReplayBuffer):
|
|||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
# Don't actually load any data. This is a stand-in solution to get the transforms.
|
||||
dummy: bool = False,
|
||||
):
|
||||
assert (
|
||||
self.available_datasets is not None
|
||||
|
@ -77,7 +79,7 @@ class AbstractDataset(TensorDictReplayBuffer):
|
|||
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
|
||||
)
|
||||
|
||||
storage = self._download_or_load_dataset()
|
||||
storage = self._download_or_load_dataset() if not dummy else None
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
|
|
|
@ -97,6 +97,7 @@ class AlohaDataset(AbstractDataset):
|
|||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
dummy: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
|
@ -110,6 +111,7 @@ class AlohaDataset(AbstractDataset):
|
|||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
@ -21,7 +21,12 @@ def make_offline_buffer(
|
|||
overwrite_batch_size=None,
|
||||
overwrite_prefetch=None,
|
||||
stats_path=None,
|
||||
# Don't actually load any data. This is a stand-in solution to get the transforms.
|
||||
dummy=False,
|
||||
):
|
||||
if dummy and normalize and stats_path is None:
|
||||
raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
|
||||
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
|
@ -93,6 +98,7 @@ def make_offline_buffer(
|
|||
root=DATA_DIR,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
|
|
|
@ -100,6 +100,7 @@ class PushtDataset(AbstractDataset):
|
|||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
dummy: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
|
@ -113,6 +114,7 @@ class PushtDataset(AbstractDataset):
|
|||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
|
|
|
@ -51,6 +51,7 @@ class SimxarmDataset(AbstractDataset):
|
|||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
dummy: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
|
@ -64,6 +65,7 @@ class SimxarmDataset(AbstractDataset):
|
|||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
import logging
|
||||
import os.path as osp
|
||||
import random
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||
|
@ -63,3 +67,29 @@ def format_big_number(num):
|
|||
num /= divisor
|
||||
|
||||
return num
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
path2 = path2.absolute()
|
||||
try:
|
||||
return path1.relative_to(path2)
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
)
|
||||
|
||||
|
||||
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
|
||||
"""Initialize a Hydra config given only the path to the relevant config file.
|
||||
|
||||
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
|
||||
"""
|
||||
# Hydra needs a path relative to this file.
|
||||
hydra.initialize(
|
||||
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
return cfg
|
||||
|
|
|
@ -30,14 +30,12 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
|
|||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os.path as osp
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -52,7 +50,7 @@ from lerobot.common.envs.factory import make_env
|
|||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed
|
||||
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
|
@ -195,7 +193,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
|||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("Making transforms.")
|
||||
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
|
||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None)
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
|
@ -229,19 +228,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
|||
logging.info("End of eval")
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
path2 = path2.absolute()
|
||||
try:
|
||||
return path1.relative_to(path2)
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
|
@ -259,19 +245,14 @@ if __name__ == "__main__":
|
|||
|
||||
if args.config is not None:
|
||||
# Note: For the config_path, Hydra wants a path relative to this script file.
|
||||
hydra.initialize(
|
||||
config_path=str(
|
||||
_relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent)
|
||||
)
|
||||
)
|
||||
cfg = hydra.compose(Path(args.config).stem, args.overrides)
|
||||
cfg = init_hydra_config(args.config, args.overrides)
|
||||
# TODO(alexander-soare): Save and load stats in trained model directory.
|
||||
stats_path = None
|
||||
elif args.hub_id is not None:
|
||||
folder = Path(snapshot_download(args.hub_id, revision="v1.0"))
|
||||
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
|
||||
cfg = hydra.compose("config", args.overrides)
|
||||
cfg.policy.pretrained_model_path = folder / "model.pt"
|
||||
cfg = init_hydra_config(
|
||||
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
|
||||
)
|
||||
stats_path = folder / "stats.pth"
|
||||
|
||||
eval(
|
||||
|
|
|
@ -1,19 +1,56 @@
|
|||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"examples/1_visualize_dataset.py",
|
||||
"examples/2_evaluate_pretrained_policy.py",
|
||||
"examples/3_train_policy.py",
|
||||
],
|
||||
)
|
||||
def test_example(path):
|
||||
|
||||
with open(path, 'r') as file:
|
||||
def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
|
||||
for f, r in zip(finds, replaces):
|
||||
assert f in text
|
||||
text = text.replace(f, r)
|
||||
return text
|
||||
|
||||
|
||||
def test_example_1():
|
||||
path = "examples/1_visualize_dataset.py"
|
||||
|
||||
with open(path, "r") as file:
|
||||
file_contents = file.read()
|
||||
exec(file_contents)
|
||||
|
||||
if path == "examples/1_visualize_dataset.py":
|
||||
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
|
||||
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
|
||||
|
||||
|
||||
def test_examples_3_and_2():
|
||||
"""
|
||||
Train a model with example 3, check the outputs.
|
||||
Evaluate the trained model with example 2, check the outputs.
|
||||
"""
|
||||
|
||||
path = "examples/3_train_policy.py"
|
||||
|
||||
with open(path, "r") as file:
|
||||
file_contents = file.read()
|
||||
|
||||
# Do less steps and use CPU.
|
||||
file_contents = _find_and_replace(
|
||||
file_contents,
|
||||
['"offline_steps=5000"', '"device=cuda"'],
|
||||
['"offline_steps=1"', '"device=cpu"'],
|
||||
)
|
||||
|
||||
exec(file_contents)
|
||||
|
||||
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
|
||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
||||
|
||||
path = "examples/2_evaluate_pretrained_policy.py"
|
||||
|
||||
with open(path, "r") as file:
|
||||
file_contents = file.read()
|
||||
|
||||
# Do less evals and use CPU.
|
||||
file_contents = _find_and_replace(
|
||||
file_contents,
|
||||
['"eval_episodes=10"', '"rollout_batch_size=10"', '"device=cuda"'],
|
||||
['"eval_episodes=1"', '"rollout_batch_size=1"','"device=cpu"'],
|
||||
)
|
||||
|
||||
assert Path(f"outputs/train/example_pusht_diffusion").exists()
|
Loading…
Reference in New Issue