Improves Type Annotations (#252)

This commit is contained in:
Wael Karkoub 2024-06-10 19:09:48 +01:00 committed by GitHub
parent a06598678c
commit 54c9776bde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 54 additions and 23 deletions

View File

@ -241,5 +241,6 @@ class Logger:
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
assert self._wandb is not None
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)

View File

@ -57,7 +57,7 @@ class Policy(Protocol):
other items should be logging-friendly, native Python types.
"""
def select_action(self, batch: dict[str, Tensor]):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals

View File

@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
self._prev_mean: torch.Tensor | None = None
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]

View File

@ -61,7 +61,7 @@ from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from PIL import Image as PILImage
from torch import Tensor
from torch import Tensor, nn
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
@ -99,13 +99,13 @@ def rollout(
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
environment termination/truncation).
"don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
the first True is followed by True's all the way till the end. This can be used for masking
extraneous elements from the sequences above.
Args:
env: The batch of environments.
policy: The policy.
policy: The policy. Must be a PyTorch nn module.
seeds: The environments are seeded once at the start of the rollout. If provided, this argument
specifies the seeds for each of the environments.
return_observations: Whether to include all observations in the returned rollout data. Observations
@ -116,6 +116,7 @@ def rollout(
Returns:
The dictionary described above.
"""
assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
device = get_device_from_parameters(policy)
# Reset the policy and environments.
@ -231,6 +232,10 @@ def eval_policy(
Returns:
Dictionary with metrics and data regarding the rollouts.
"""
if max_episodes_rendered > 0 and not videos_dir:
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
assert isinstance(policy, Policy)
start = time.time()
policy.eval()
@ -271,11 +276,16 @@ def eval_policy(
if max_episodes_rendered > 0:
ep_frames: list[np.ndarray] = []
seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
if start_seed is None:
seeds = None
else:
seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
)
rollout_data = rollout(
env,
policy,
seeds=seeds,
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
enable_progbar=enable_inner_progbar,
@ -285,7 +295,8 @@ def eval_policy(
# this won't be included).
n_steps = rollout_data["done"].shape[1]
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
done_indices = torch.argmax(rollout_data["done"].to(int), axis=1) # (batch_size, rollout_steps)
done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
@ -296,8 +307,12 @@ def eval_policy(
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
all_successes.extend(batch_successes.tolist())
if seeds:
all_seeds.extend(seeds)
else:
all_seeds.append(None)
# FIXME: episode_data is either None or it doesn't exist
if return_episode_data:
this_episode_data = _compile_episode_data(
rollout_data,
@ -347,6 +362,7 @@ def eval_policy(
):
if n_episodes_rendered >= max_episodes_rendered:
break
videos_dir.mkdir(parents=True, exist_ok=True)
video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
video_paths.append(str(video_path))
@ -504,16 +520,17 @@ def _compile_episode_data(
def main(
pretrained_policy_path: str | None = None,
pretrained_policy_path: Path | None = None,
hydra_cfg_path: str | None = None,
out_dir: str | None = None,
config_overrides: list[str] | None = None,
):
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
if hydra_cfg_path is None:
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
if pretrained_policy_path is not None:
hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
else:
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
if out_dir is None:
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
@ -531,10 +548,12 @@ def main(
logging.info("Making policy.")
if hydra_cfg_path is None:
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
assert isinstance(policy, nn.Module)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():

View File

@ -66,6 +66,7 @@ import argparse
import json
import shutil
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import HfApi
@ -77,7 +78,7 @@ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_r
from lerobot.common.datasets.utils import flatten_dict
def get_from_raw_to_lerobot_format_fn(raw_format):
def get_from_raw_to_lerobot_format_fn(raw_format: str):
if raw_format == "pusht_zarr":
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
elif raw_format == "umi_zarr":
@ -96,7 +97,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
return from_raw_to_lerobot_format
def save_meta_data(info, stats, episode_data_index, meta_data_dir):
def save_meta_data(
info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
):
meta_data_dir.mkdir(parents=True, exist_ok=True)
# save info
@ -114,7 +117,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
save_file(episode_data_index, ep_data_idx_path)
def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
"""Expect all meta data files to be all stored in a single "meta_data" directory.
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
"""
@ -128,7 +131,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
)
def push_videos_to_hub(repo_id, videos_dir, revision):
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
"""Expect mp4 files to be all stored in a single "videos" directory.
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
"""
@ -209,6 +212,7 @@ def push_dataset_to_hub(
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if not dry_run:
# TODO(rcadene): token needs to be a str | None
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)

View File

@ -24,6 +24,7 @@ import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
@ -292,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.training.eval_freq > 0:
logging.info("make_env")
eval_env = make_env(cfg)
@ -302,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
dataset_stats=offline_dataset.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
assert isinstance(policy, nn.Module)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@ -333,6 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
assert eval_env is not None
eval_info = eval_policy(
eval_env,
policy,
@ -414,6 +417,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
step += 1
if eval_env:
eval_env.close()
logging.info("End of training")

View File

@ -66,28 +66,31 @@ import gc
import logging
import time
from pathlib import Path
from typing import Iterator
import numpy as np
import rerun as rr
import torch
import torch.utils.data
import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset, episode_index):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
self.frame_ids = range(from_idx, to_idx)
def __iter__(self):
def __iter__(self) -> Iterator:
return iter(self.frame_ids)
def __len__(self):
def __len__(self) -> int:
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch):
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape