finish examples 2 and 3
This commit is contained in:
parent cb6d1e0871
commit 1ed0110900
@@ -1 +1,39 @@
-# TODO
+"""
+This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
+training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
+"""
+
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+
+from lerobot.common.utils import init_hydra_config
+from lerobot.scripts.eval import eval
+
+# Get a pretrained policy from the hub.
+hub_id = "lerobot/diffusion_policy_pusht_image"
+folder = Path(snapshot_download(hub_id))
+# OR use a policy from the local outputs/train folder (this overrides the hub policy above).
+folder = Path("outputs/train/example_pusht_diffusion")
+
+config_path = folder / "config.yaml"
+weights_path = folder / "model.pt"
+stats_path = folder / "stats.pth"  # normalization stats
+
+# Override some config parameters to do with evaluation.
+overrides = [
+    f"policy.pretrained_model_path={weights_path}",
+    "eval_episodes=10",
+    "rollout_batch_size=10",
+    "device=cuda",
+]
+
+# Create a Hydra config.
+cfg = init_hydra_config(config_path, overrides)
+
+# Evaluate the policy and save the outputs including metrics and videos.
+eval(
+    cfg,
+    out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
+    stats_path=stats_path,
+)
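Note (editor's sketch): for intuition only, the override list above behaves loosely like applying a dotlist of key=value strings to the saved training config. The actual resolution goes through Hydra via init_hydra_config, not this snippet; the local folder below assumes the training output of example 3.

from pathlib import Path

from omegaconf import OmegaConf

folder = Path("outputs/train/example_pusht_diffusion")
cfg = OmegaConf.load(folder / "config.yaml")  # config saved by examples/3_train_policy.py
cfg.merge_with_dotlist(
    [f"policy.pretrained_model_path={folder / 'model.pt'}", "eval_episodes=10", "device=cpu"]
)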
@@ -1 +1,55 @@
-# TODO
+"""This script demonstrates how to train Diffusion Policy on the PushT environment.
+
+Once you have trained a model with this script, you can try to evaluate it on
+examples/2_evaluate_pretrained_policy.py
+"""
+
+import os
+from pathlib import Path
+
+import torch
+from omegaconf import OmegaConf
+from tqdm import trange
+
+from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+from lerobot.common.utils import init_hydra_config
+
+output_directory = Path("outputs/train/example_pusht_diffusion")
+os.makedirs(output_directory, exist_ok=True)
+
+overrides = [
+    "env=pusht",
+    "policy=diffusion",
+    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
+    "offline_steps=5000",
+    "log_freq=250",
+    "device=cuda",
+]
+
+cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
+
+policy = DiffusionPolicy(
+    cfg=cfg.policy,
+    cfg_device=cfg.device,
+    cfg_noise_scheduler=cfg.noise_scheduler,
+    cfg_rgb_model=cfg.rgb_model,
+    cfg_obs_encoder=cfg.obs_encoder,
+    cfg_optimizer=cfg.optimizer,
+    cfg_ema=cfg.ema,
+    n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+    **cfg.policy,
+)
+policy.train()
+
+offline_buffer = make_offline_buffer(cfg)
+
+for offline_step in trange(cfg.offline_steps):
+    train_info = policy.update(offline_buffer, offline_step)
+    if offline_step % cfg.log_freq == 0:
+        print(train_info)
+
+# Save the policy, configuration, and normalization stats for later use.
+policy.save(output_directory / "model.pt")
+OmegaConf.save(cfg, output_directory / "config.yaml")
+torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
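Note (editor's sketch): the three files written at the end are exactly what example 2 and the new test look for. A minimal way to inspect them after training, assuming the default output directory above; the exact structure of the stats object is not spelled out here.

from pathlib import Path

import torch
from omegaconf import OmegaConf

out = Path("outputs/train/example_pusht_diffusion")
assert (out / "model.pt").exists()                          # weights written by policy.save()
cfg = OmegaConf.load(out / "config.yaml")                   # resolved training config
stats = torch.load(out / "stats.pth", map_location="cpu")   # normalization stats from the buffer transform
print(cfg.policy.name, type(stats))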
@@ -59,6 +59,8 @@ class AbstractDataset(TensorDictReplayBuffer):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
+        # Don't actually load any data. This is a stand-in solution to get the transforms.
+        dummy: bool = False,
     ):
         assert (
             self.available_datasets is not None
@@ -77,7 +79,7 @@ class AbstractDataset(TensorDictReplayBuffer):
                 f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
             )
 
-        storage = self._download_or_load_dataset()
+        storage = self._download_or_load_dataset() if not dummy else None
 
         super().__init__(
             storage=storage,
@@ -97,6 +97,7 @@ class AlohaDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
+        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,
@@ -110,6 +111,7 @@ class AlohaDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
+            dummy=dummy,
         )
 
     @property
@@ -21,7 +21,12 @@ def make_offline_buffer(
     overwrite_batch_size=None,
     overwrite_prefetch=None,
     stats_path=None,
+    # Don't actually load any data. This is a stand-in solution to get the transforms.
+    dummy=False,
 ):
+    if dummy and normalize and stats_path is None:
+        raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
+
     if cfg.policy.balanced_sampling:
         assert cfg.online_steps > 0
         batch_size = None
@@ -93,6 +98,7 @@ def make_offline_buffer(
         root=DATA_DIR,
         pin_memory=pin_memory,
         prefetch=prefetch if isinstance(prefetch, int) else None,
+        dummy=dummy,
     )
 
     if cfg.policy.name == "tdmpc":
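Note (editor's sketch): the dummy flag lets evaluation build a buffer purely for its normalization transforms, skipping any dataset download or loading. This mirrors how eval.py later in this commit calls it when a stats_path is available; the paths below assume the example 3 training output.

from pathlib import Path

from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.utils import init_hydra_config

folder = Path("outputs/train/example_pusht_diffusion")
cfg = init_hydra_config(folder / "config.yaml", ["device=cpu"])
# No episodes are fetched; only the transforms (fed by stats.pth) are of interest.
offline_buffer = make_offline_buffer(cfg, stats_path=folder / "stats.pth", dummy=True)
transform = offline_buffer.transform  # e.g. handed to make_env(cfg, transform=...) as in eval.py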
@@ -100,6 +100,7 @@ class PushtDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
+        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,
@@ -113,6 +114,7 @@ class PushtDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
+            dummy=dummy,
         )
 
     def _download_and_preproc_obsolete(self):
@@ -51,6 +51,7 @@ class SimxarmDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
+        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,
@@ -64,6 +65,7 @@ class SimxarmDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
+            dummy=dummy,
         )
 
     def _download_and_preproc_obsolete(self):
@@ -1,9 +1,13 @@
 import logging
+import os.path as osp
 import random
 from datetime import datetime
+from pathlib import Path
 
+import hydra
 import numpy as np
 import torch
+from omegaconf import DictConfig
 
 
 def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
@@ -63,3 +67,29 @@ def format_big_number(num):
         num /= divisor
 
     return num
+
+
+def _relative_path_between(path1: Path, path2: Path) -> Path:
+    """Returns path1 relative to path2."""
+    path1 = path1.absolute()
+    path2 = path2.absolute()
+    try:
+        return path1.relative_to(path2)
+    except ValueError:  # most likely because path1 is not a subpath of path2
+        common_parts = Path(osp.commonpath([path1, path2])).parts
+        return Path(
+            "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
+        )
+
+
+def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
+    """Initialize a Hydra config given only the path to the relevant config file.
+
+    For config resolution, it is assumed that the config file's parent is the Hydra config dir.
+    """
+    # Hydra needs a path relative to this file.
+    hydra.initialize(
+        str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
+    )
+    cfg = hydra.compose(Path(config_path).stem, overrides)
+    return cfg
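Note (editor's sketch): _relative_path_between exists only because hydra.initialize wants a config path expressed relative to the calling file. A quick sanity check of both branches, using hypothetical POSIX paths (/repo/... is purely illustrative):

from pathlib import Path

from lerobot.common.utils import _relative_path_between

# path1 is a subpath of path2: relative_to() succeeds directly.
assert _relative_path_between(Path("/repo/lerobot/configs"), Path("/repo/lerobot")) == Path("configs")
# Only a common prefix (/repo): fall back to ".." segments plus the remaining tail of path1.
assert _relative_path_between(Path("/repo/outputs/train"), Path("/repo/lerobot/scripts")) == Path(
    "../../outputs/train"
)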
@@ -30,14 +30,12 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
 import argparse
 import json
 import logging
-import os.path as osp
 import threading
 import time
 from datetime import datetime as dt
 from pathlib import Path
 
 import einops
-import hydra
 import imageio
 import numpy as np
 import torch
@@ -52,7 +50,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed
+from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
 
 
 def write_video(video_path, stacked_frames, fps):
@@ -195,7 +193,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     log_output_dir(out_dir)
 
     logging.info("Making transforms.")
-    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
+    # TODO(alexander-soare): Completely decouple datasets from evaluation.
+    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None)
 
     logging.info("Making environment.")
     env = make_env(cfg, transform=offline_buffer.transform)
@@ -229,19 +228,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     logging.info("End of eval")
 
 
-def _relative_path_between(path1: Path, path2: Path) -> Path:
-    """Returns path1 relative to path2."""
-    path1 = path1.absolute()
-    path2 = path2.absolute()
-    try:
-        return path1.relative_to(path2)
-    except ValueError:  # most likely because path1 is not a subpath of path2
-        common_parts = Path(osp.commonpath([path1, path2])).parts
-        return Path(
-            "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
-        )
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
@@ -259,19 +245,14 @@ if __name__ == "__main__":
 
     if args.config is not None:
         # Note: For the config_path, Hydra wants a path relative to this script file.
-        hydra.initialize(
-            config_path=str(
-                _relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent)
-            )
-        )
-        cfg = hydra.compose(Path(args.config).stem, args.overrides)
+        cfg = init_hydra_config(args.config, args.overrides)
        # TODO(alexander-soare): Save and load stats in trained model directory.
         stats_path = None
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision="v1.0"))
-        cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
-        cfg = hydra.compose("config", args.overrides)
-        cfg.policy.pretrained_model_path = folder / "model.pt"
+        cfg = init_hydra_config(
+            folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
+        )
         stats_path = folder / "stats.pth"
 
     eval(
@@ -1,19 +1,56 @@
-import pytest
 from pathlib import Path
 
 
-@pytest.mark.parametrize(
-    "path",
-    [
-        "examples/1_visualize_dataset.py",
-        "examples/2_evaluate_pretrained_policy.py",
-        "examples/3_train_policy.py",
-    ],
-)
-def test_example(path):
-    with open(path, 'r') as file:
+def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
+    for f, r in zip(finds, replaces):
+        assert f in text
+        text = text.replace(f, r)
+    return text
+
+
+def test_example_1():
+    path = "examples/1_visualize_dataset.py"
+
+    with open(path, "r") as file:
         file_contents = file.read()
     exec(file_contents)
 
-    if path == "examples/1_visualize_dataset.py":
-        assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
+    assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
+
+
+def test_examples_3_and_2():
+    """
+    Train a model with example 3, check the outputs.
+    Evaluate the trained model with example 2, check the outputs.
+    """
+
+    path = "examples/3_train_policy.py"
+
+    with open(path, "r") as file:
+        file_contents = file.read()
+
+    # Do less steps and use CPU.
+    file_contents = _find_and_replace(
+        file_contents,
+        ['"offline_steps=5000"', '"device=cuda"'],
+        ['"offline_steps=1"', '"device=cpu"'],
+    )
+
+    exec(file_contents)
+
+    for file_name in ["model.pt", "stats.pth", "config.yaml"]:
+        assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
+
+    path = "examples/2_evaluate_pretrained_policy.py"
+
+    with open(path, "r") as file:
+        file_contents = file.read()
+
+    # Do less evals and use CPU.
+    file_contents = _find_and_replace(
+        file_contents,
+        ['"eval_episodes=10"', '"rollout_batch_size=10"', '"device=cuda"'],
+        ['"eval_episodes=1"', '"rollout_batch_size=1"', '"device=cpu"'],
+    )
+
+    assert Path("outputs/train/example_pusht_diffusion").exists()
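Note (editor's sketch): _find_and_replace asserts that every snippet it patches is actually present, so these tests fail loudly if the example scripts drift from the strings listed above. A tiny illustration, assuming the helper defined in this test file:

patched = _find_and_replace('overrides = ["device=cuda"]', ['"device=cuda"'], ['"device=cpu"'])
assert patched == 'overrides = ["device=cpu"]'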
|
Loading…
Reference in New Issue