Improves Type Annotations (#252)
This commit is contained in:
parent a06598678c
commit 54c9776bde
@@ -241,5 +241,6 @@ class Logger:
     def log_video(self, video_path: str, step: int, mode: str = "train"):
         assert mode in {"train", "eval"}
         assert self._wandb is not None
         wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
         self._wandb.log({f"{mode}/video": wandb_video}, step=step)
@@ -57,7 +57,7 @@ class Policy(Protocol):
     other items should be logging-friendly, native Python types.
     """

-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Return one action to run in the environment (potentially in batch mode).

         When the model uses a history of observations, or outputs a sequence of actions, this method deals
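
For context: the explicit `-> Tensor` return type tells callers and static type checkers that select_action always yields a single batched action tensor. A minimal sketch of a policy that satisfies the annotated protocol (RandomPolicy and its shapes are illustrative, not part of this commit):

import torch
from torch import Tensor, nn


class RandomPolicy(nn.Module):
    """Toy policy used only to illustrate the annotated protocol."""

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        # One action per batch element, as the protocol promises.
        batch_size = next(iter(batch.values())).shape[0]
        return torch.rand(batch_size, self.action_dim)


policy = RandomPolicy(action_dim=2)
action = policy.select_action({"observation.state": torch.zeros(8, 4)})
print(action.shape)  # torch.Size([8, 2])
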
@@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         self._prev_mean: torch.Tensor | None = None

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations."""
         batch = self.normalize_inputs(batch)
         batch["observation.image"] = batch[self.input_image_key]
@@ -61,7 +61,7 @@ from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from huggingface_hub.utils._validators import HFValidationError
 from PIL import Image as PILImage
-from torch import Tensor
+from torch import Tensor, nn
 from tqdm import trange

 from lerobot.common.datasets.factory import make_dataset
@@ -99,13 +99,13 @@ def rollout(
         "reward": A (batch, sequence) tensor of rewards received for applying the actions.
         "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
             environment termination/truncation).
-        "don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
+        "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
             the first True is followed by True's all the way till the end. This can be used for masking
             extraneous elements from the sequences above.

     Args:
         env: The batch of environments.
-        policy: The policy.
+        policy: The policy. Must be a PyTorch nn module.
         seeds: The environments are seeded once at the start of the rollout. If provided, this argument
             specifies the seeds for each of the environments.
         return_observations: Whether to include all observations in the returned rollout data. Observations
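
The "done" entry follows a cumulative convention: once a batch element terminates, every later step stays True. A small illustrative sketch of that convention (torch.cummax is just one way to produce it; the actual rollout code may build it differently):

import torch

# Per-step termination flags for one element that terminates at step 2 of 5.
per_step_done = torch.tensor([False, False, True, False, False])

# Cumulative form described in the docstring: once True, always True.
cumulative_done = torch.cummax(per_step_done.int(), dim=0).values.bool()
print(cumulative_done)  # tensor([False, False,  True,  True,  True])
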
@@ -116,6 +116,7 @@ def rollout(
     Returns:
         The dictionary described above.
     """
+    assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
     device = get_device_from_parameters(policy)

     # Reset the policy and environments.
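
With policy now guaranteed to be an nn.Module, the device lookup can rely on its parameters. A hedged sketch of what a get_device_from_parameters helper typically does (the real helper lives elsewhere in the codebase and may differ):

import torch
from torch import nn


def get_device_from_parameters(module: nn.Module) -> torch.device:
    # Assumes the module has at least one parameter and that all parameters
    # live on the same device.
    return next(iter(module.parameters())).device


print(get_device_from_parameters(nn.Linear(2, 2)))  # cpu
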
@@ -231,6 +232,10 @@ def eval_policy(
     Returns:
         Dictionary with metrics and data regarding the rollouts.
     """
+    if max_episodes_rendered > 0 and not videos_dir:
+        raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
+
+    assert isinstance(policy, Policy)
     start = time.time()
     policy.eval()

@@ -271,11 +276,16 @@ def eval_policy(
         if max_episodes_rendered > 0:
             ep_frames: list[np.ndarray] = []

-        seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
+        if start_seed is None:
+            seeds = None
+        else:
+            seeds = range(
+                start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
+            )
         rollout_data = rollout(
             env,
             policy,
-            seeds=seeds,
+            seeds=list(seeds) if seeds else None,
             return_observations=return_episode_data,
             render_callback=render_frame if max_episodes_rendered > 0 else None,
             enable_progbar=enable_inner_progbar,
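
Seeding is now optional: when start_seed is set, each batch of environments receives a contiguous, non-overlapping block of seeds; when it is None, seeding is skipped entirely. A quick illustration with made-up numbers:

start_seed, num_envs = 1000, 4  # illustrative values

for batch_ix in range(2):
    seeds = range(start_seed + (batch_ix * num_envs), start_seed + ((batch_ix + 1) * num_envs))
    print(list(seeds))
# [1000, 1001, 1002, 1003]
# [1004, 1005, 1006, 1007]
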
@@ -285,7 +295,8 @@ def eval_policy(
         # this won't be included).
         n_steps = rollout_data["done"].shape[1]
         # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
-        done_indices = torch.argmax(rollout_data["done"].to(int), axis=1)  # (batch_size, rollout_steps)
+        done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)

         # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
         # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
         mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
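
The switch from axis=1 to dim=1 uses PyTorch's native keyword; the behaviour is unchanged. A runnable sketch of the surrounding logic, showing that torch.argmax returns the first occurrence of the maximum and how the mask keeps data up to (one step past) the first done; values are illustrative:

import einops
import torch

done = torch.tensor([[0, 0, 1, 1, 1],
                     [0, 0, 0, 0, 1]])  # cumulative "done", shape (batch=2, n_steps=5)
n_steps = done.shape[1]

# First True per batch element: tensor([2, 4]).
done_indices = torch.argmax(done.to(int), dim=1)

mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
print(mask)
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
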
@@ -296,8 +307,12 @@ def eval_policy(
         max_rewards.extend(batch_max_rewards.tolist())
         batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
         all_successes.extend(batch_successes.tolist())
-        all_seeds.extend(seeds)
+        if seeds:
+            all_seeds.extend(seeds)
+        else:
+            all_seeds.append(None)

+        # FIXME: episode_data is either None or it doesn't exist
         if return_episode_data:
             this_episode_data = _compile_episode_data(
                 rollout_data,
@@ -347,6 +362,7 @@ def eval_policy(
         ):
             if n_episodes_rendered >= max_episodes_rendered:
                 break

+            videos_dir.mkdir(parents=True, exist_ok=True)
             video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
             video_paths.append(str(video_path))
@@ -504,16 +520,17 @@ def _compile_episode_data(


 def main(
-    pretrained_policy_path: str | None = None,
+    pretrained_policy_path: Path | None = None,
     hydra_cfg_path: str | None = None,
     out_dir: str | None = None,
     config_overrides: list[str] | None = None,
 ):
     assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
-    if hydra_cfg_path is None:
-        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
+    if pretrained_policy_path is not None:
+        hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
     else:
         hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)

     if out_dir is None:
         out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
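
pretrained_policy_path is now typed as a Path, and the XOR assert guarantees that exactly one of the two entry points is supplied; the path is converted back to str only where Hydra expects a string. A small illustration of that invariant (the checkpoint path below is made up):

from pathlib import Path

pretrained_policy_path = Path("outputs/train/some_checkpoint")  # hypothetical path
hydra_cfg_path = None

# Passes: exactly one of the two is provided.
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)

print(str(pretrained_policy_path / "config.yaml"))  # outputs/train/some_checkpoint/config.yaml
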
@@ -531,10 +548,12 @@ def main(

     logging.info("Making policy.")
     if hydra_cfg_path is None:
-        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
+        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
     else:
         # Note: We need the dataset stats to pass to the policy's normalization modules.
         policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)

+    assert isinstance(policy, nn.Module)
     policy.eval()

     with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
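
The final `with` line combines torch.no_grad() with an optional autocast region; when AMP is disabled, contextlib.nullcontext() stands in so the statement stays valid. A self-contained sketch of that pattern (flags and shapes are illustrative):

from contextlib import nullcontext

import torch

use_amp = False  # stands in for hydra_cfg.use_amp
device = torch.device("cpu")

with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
    out = torch.randn(2, 2) @ torch.randn(2, 2)

print(out.requires_grad)  # False: computed under no_grad
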
@@ -66,6 +66,7 @@ import argparse
 import json
 import shutil
 from pathlib import Path
+from typing import Any

 import torch
 from huggingface_hub import HfApi
@@ -77,7 +78,7 @@ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_r
 from lerobot.common.datasets.utils import flatten_dict


-def get_from_raw_to_lerobot_format_fn(raw_format):
+def get_from_raw_to_lerobot_format_fn(raw_format: str):
     if raw_format == "pusht_zarr":
         from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
     elif raw_format == "umi_zarr":
@@ -96,7 +97,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
     return from_raw_to_lerobot_format


-def save_meta_data(info, stats, episode_data_index, meta_data_dir):
+def save_meta_data(
+    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
+):
     meta_data_dir.mkdir(parents=True, exist_ok=True)

     # save info
@@ -114,7 +117,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
     save_file(episode_data_index, ep_data_idx_path)


-def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
+def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
     """Expect all meta data files to be all stored in a single "meta_data" directory.
     On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
     """
@@ -128,7 +131,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
     )


-def push_videos_to_hub(repo_id, videos_dir, revision):
+def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
     """Expect mp4 files to be all stored in a single "videos" directory.
     On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
     """
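
The bodies of these upload helpers are not shown in this diff. A hedged sketch of how such a helper can be written with the huggingface_hub client; the parameter choices below are assumptions, not necessarily what the repository does:

from pathlib import Path

from huggingface_hub import HfApi


def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
    """Upload every mp4 in videos_dir under a "videos" directory at the repo root."""
    HfApi().upload_folder(
        folder_path=videos_dir,
        path_in_repo="videos",
        repo_id=repo_id,
        repo_type="dataset",
        revision=revision,
        allow_patterns="*.mp4",
    )
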
@@ -209,6 +212,7 @@ def push_dataset_to_hub(
     save_meta_data(info, stats, episode_data_index, meta_data_dir)

     if not dry_run:
         # TODO(rcadene): token needs to be a str | None
-        hf_dataset.push_to_hub(repo_id, token=True, revision="main")
+        hf_dataset.push_to_hub(repo_id, token=True, revision=revision)

@@ -24,6 +24,7 @@ import torch
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
+from torch import nn
 from torch.cuda.amp import GradScaler

 from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
@@ -292,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,
     # using the eval.py instead, with gym_dora environment and dora-rs.
+    eval_env = None
     if cfg.training.eval_freq > 0:
         logging.info("make_env")
         eval_env = make_env(cfg)
@@ -302,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         dataset_stats=offline_dataset.stats if not cfg.resume else None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )

+    assert isinstance(policy, nn.Module)
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -333,6 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+                assert eval_env is not None
                 eval_info = eval_policy(
                     eval_env,
                     policy,
@@ -414,6 +417,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

         step += 1

+    if eval_env:
         eval_env.close()
     logging.info("End of training")

@@ -66,28 +66,31 @@ import gc
 import logging
 import time
 from pathlib import Path
+from typing import Iterator

 import numpy as np
 import rerun as rr
 import torch
 import torch.utils.data
 import tqdm

+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+

 class EpisodeSampler(torch.utils.data.Sampler):
-    def __init__(self, dataset, episode_index):
+    def __init__(self, dataset: LeRobotDataset, episode_index: int):
         from_idx = dataset.episode_data_index["from"][episode_index].item()
         to_idx = dataset.episode_data_index["to"][episode_index].item()
         self.frame_ids = range(from_idx, to_idx)

-    def __iter__(self):
+    def __iter__(self) -> Iterator:
         return iter(self.frame_ids)

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.frame_ids)


-def to_hwc_uint8_numpy(chw_float32_torch):
+def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
     assert chw_float32_torch.dtype == torch.float32
     assert chw_float32_torch.ndim == 3
     c, h, w = chw_float32_torch.shape
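
With the new annotations, EpisodeSampler documents that it slices a LeRobotDataset by episode. A runnable sketch of how the sampler restricts a DataLoader to a single episode, assuming the EpisodeSampler class above is in scope; the tiny stand-in dataset mimics the episode_data_index layout so nothing needs to be downloaded:

import torch


class FakeDataset(torch.utils.data.Dataset):
    # Same episode_data_index layout as LeRobotDataset: episode 0 covers
    # frames 0-2, episode 1 covers frames 3-4.
    episode_data_index = {"from": torch.tensor([0, 3]), "to": torch.tensor([3, 5])}

    def __getitem__(self, idx):
        return {"index": idx}

    def __len__(self):
        return 5


sampler = EpisodeSampler(FakeDataset(), episode_index=1)
print(list(sampler))  # [3, 4] -> only the frames of episode 1

loader = torch.utils.data.DataLoader(FakeDataset(), batch_size=8, sampler=sampler)
print(next(iter(loader))["index"])  # tensor([3, 4])
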