finish examples 2 and 3

Alexander Soare 2024-03-26 16:13:40 +00:00
parent cb6d1e0871
commit 1ed0110900
10 changed files with 196 additions and 42 deletions

View File

@ -1 +1,39 @@
# TODO
"""
This script demonstrates how to evaluate a pretrained policy from the Hugging Face Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
"""
from pathlib import Path
from huggingface_hub import snapshot_download
from lerobot.common.utils import init_hydra_config
from lerobot.scripts.eval import eval
# Get a pretrained policy from the hub.
hub_id = "lerobot/diffusion_policy_pusht_image"
folder = Path(snapshot_download(hub_id))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
folder = Path("outputs/train/example_pusht_diffusion")
config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
stats_path = folder / "stats.pth" # normalization stats
# Override some config parameters relevant to evaluation.
overrides = [
f"policy.pretrained_model_path={weights_path}",
"eval_episodes=10",
"rollout_batch_size=10",
"device=cuda",
]
# Create a Hydra config.
cfg = init_hydra_config(config_path, overrides)
# Evaluate the policy and save the outputs including metrics and videos.
eval(
cfg,
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)
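For a quick smoke test of this example, the overrides above can be swapped for a single CPU episode; this is the exact substitution the test added at the end of this commit performs on the script. A minimal sketch, with everything else in the script unchanged:

overrides = [
    f"policy.pretrained_model_path={weights_path}",  # unchanged
    "eval_episodes=1",
    "rollout_batch_size=1",
    "device=cpu",
]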

View File

@ -1 +1,55 @@
# TODO
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from tqdm import trange
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.utils import init_hydra_config
output_directory = Path("outputs/train/example_pusht_diffusion")
os.makedirs(output_directory, exist_ok=True)
overrides = [
"env=pusht",
"policy=diffusion",
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
"offline_steps=5000",
"log_freq=250",
"device=cuda",
]
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
policy.train()
offline_buffer = make_offline_buffer(cfg)
for offline_step in trange(cfg.offline_steps):
train_info = policy.update(offline_buffer, offline_step)
if offline_step % cfg.log_freq == 0:
print(train_info)
# Save the policy, configuration, and normalization stats for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(cfg, output_directory / "config.yaml")
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
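The three artifacts saved here are exactly what example 2 consumes: config.yaml is reloaded via init_hydra_config, model.pt is passed through the policy.pretrained_model_path override, and stats.pth supplies the normalization stats. A minimal sketch for inspecting the saved config and stats afterwards, assuming this script has been run with the default output_directory:

import torch
from omegaconf import OmegaConf

# Reload the Hydra config and the normalization stats object saved above.
cfg = OmegaConf.load("outputs/train/example_pusht_diffusion/config.yaml")
stats = torch.load("outputs/train/example_pusht_diffusion/stats.pth")
print(cfg.policy.name, type(stats))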

View File

@ -59,6 +59,8 @@ class AbstractDataset(TensorDictReplayBuffer):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
# Don't actually load any data. This is a stand-in solution to get the transforms.
dummy: bool = False,
):
assert (
self.available_datasets is not None
@ -77,7 +79,7 @@ class AbstractDataset(TensorDictReplayBuffer):
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
)
storage = self._download_or_load_dataset()
storage = self._download_or_load_dataset() if not dummy else None
super().__init__(
storage=storage,

View File

@ -97,6 +97,7 @@ class AlohaDataset(AbstractDataset):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
dummy: bool = False,
):
super().__init__(
dataset_id,
@ -110,6 +111,7 @@ class AlohaDataset(AbstractDataset):
collate_fn=collate_fn,
writer=writer,
transform=transform,
dummy=dummy,
)
@property

View File

@ -21,7 +21,12 @@ def make_offline_buffer(
overwrite_batch_size=None,
overwrite_prefetch=None,
stats_path=None,
# Don't actually load any data. This is a stand-in solution to get the transforms.
dummy=False,
):
if dummy and normalize and stats_path is None:
raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
batch_size = None
@ -93,6 +98,7 @@ def make_offline_buffer(
root=DATA_DIR,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
dummy=dummy,
)
if cfg.policy.name == "tdmpc":
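With the new dummy flag, a caller that only needs the dataset transforms (e.g. normalization at evaluation time) can skip downloading and loading data entirely, provided precomputed stats are supplied via stats_path. A minimal sketch mirroring the call this commit adds to lerobot/scripts/eval.py, assuming a prior run of examples/3_train_policy.py:

from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.utils import init_hydra_config

cfg = init_hydra_config("outputs/train/example_pusht_diffusion/config.yaml", ["device=cpu"])
stats_path = "outputs/train/example_pusht_diffusion/stats.pth"

# dummy=True skips the storage entirely; stats_path provides the normalization stats the transform needs.
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=True)
env = make_env(cfg, transform=offline_buffer.transform)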

View File

@ -100,6 +100,7 @@ class PushtDataset(AbstractDataset):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
dummy: bool = False,
):
super().__init__(
dataset_id,
@ -113,6 +114,7 @@ class PushtDataset(AbstractDataset):
collate_fn=collate_fn,
writer=writer,
transform=transform,
dummy=dummy,
)
def _download_and_preproc_obsolete(self):

View File

@ -51,6 +51,7 @@ class SimxarmDataset(AbstractDataset):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
dummy: bool = False,
):
super().__init__(
dataset_id,
@ -64,6 +65,7 @@ class SimxarmDataset(AbstractDataset):
collate_fn=collate_fn,
writer=writer,
transform=transform,
dummy=dummy,
)
def _download_and_preproc_obsolete(self):

View File

@ -1,9 +1,13 @@
import logging
import os.path as osp
import random
from datetime import datetime
from pathlib import Path
import hydra
import numpy as np
import torch
from omegaconf import DictConfig
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
@ -63,3 +67,29 @@ def format_big_number(num):
num /= divisor
return num
def _relative_path_between(path1: Path, path2: Path) -> Path:
"""Returns path1 relative to path2."""
path1 = path1.absolute()
path2 = path2.absolute()
try:
return path1.relative_to(path2)
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
"""Initialize a Hydra config given only the path to the relevant config file.
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
"""
# Hydra needs a path relative to this file.
hydra.initialize(
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg
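A worked example of the fallback branch in _relative_path_between, using illustrative POSIX paths and assuming this module is lerobot.common.utils (as the imports elsewhere in this commit indicate): when path1 is not under path2, the result climbs to the common ancestor and descends again.

from pathlib import Path
from lerobot.common.utils import _relative_path_between  # private helper defined above

# The common ancestor of /a/b/c and /a/d is /a: one ".." then down into b/c.
print(_relative_path_between(Path("/a/b/c"), Path("/a/d")))  # ../b/c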

View File

@ -30,14 +30,12 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
import argparse
import json
import logging
import os.path as osp
import threading
import time
from datetime import datetime as dt
from pathlib import Path
import einops
import hydra
import imageio
import numpy as np
import torch
@ -52,7 +50,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
def write_video(video_path, stacked_frames, fps):
@ -195,7 +193,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
log_output_dir(out_dir)
logging.info("Making transforms.")
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
# TODO(alexander-soare): Completely decouple datasets from evaluation.
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None)
logging.info("Making environment.")
env = make_env(cfg, transform=offline_buffer.transform)
@ -229,19 +228,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("End of eval")
def _relative_path_between(path1: Path, path2: Path) -> Path:
"""Returns path1 relative to path2."""
path1 = path1.absolute()
path2 = path2.absolute()
try:
return path1.relative_to(path2)
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
@ -259,19 +245,14 @@ if __name__ == "__main__":
if args.config is not None:
# Note: For the config_path, Hydra wants a path relative to this script file.
hydra.initialize(
config_path=str(
_relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent)
)
)
cfg = hydra.compose(Path(args.config).stem, args.overrides)
cfg = init_hydra_config(args.config, args.overrides)
# TODO(alexander-soare): Save and load stats in trained model directory.
stats_path = None
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision="v1.0"))
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
cfg = hydra.compose("config", args.overrides)
cfg.policy.pretrained_model_path = folder / "model.pt"
cfg = init_hydra_config(
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
)
stats_path = folder / "stats.pth"
eval(

View File

@ -1,19 +1,56 @@
import pytest
from pathlib import Path
@pytest.mark.parametrize(
"path",
[
"examples/1_visualize_dataset.py",
"examples/2_evaluate_pretrained_policy.py",
"examples/3_train_policy.py",
],
)
def test_example(path):
with open(path, 'r') as file:
def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
for f, r in zip(finds, replaces):
assert f in text
text = text.replace(f, r)
return text
def test_example_1():
path = "examples/1_visualize_dataset.py"
with open(path, "r") as file:
file_contents = file.read()
exec(file_contents)
if path == "examples/1_visualize_dataset.py":
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
def test_examples_3_and_2():
"""
Train a model with example 3, check the outputs.
Evaluate the trained model with example 2, check the outputs.
"""
path = "examples/3_train_policy.py"
with open(path, "r") as file:
file_contents = file.read()
# Do fewer steps and use CPU.
file_contents = _find_and_replace(
file_contents,
['"offline_steps=5000"', '"device=cuda"'],
['"offline_steps=1"', '"device=cpu"'],
)
exec(file_contents)
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py"
with open(path, "r") as file:
file_contents = file.read()
# Do fewer evals and use CPU.
file_contents = _find_and_replace(
file_contents,
['"eval_episodes=10"', '"rollout_batch_size=10"', '"device=cuda"'],
['"eval_episodes=1"', '"rollout_batch_size=1"','"device=cpu"'],
)
assert Path(f"outputs/train/example_pusht_diffusion").exists()