Improves Type Annotations (#252)

Author: Wael Karkoub
Date: 2024-06-10 19:09:48 +01:00
Committed by: GitHub
Parent: a06598678c
Commit: 54c9776bde
7 changed files with 54 additions and 23 deletions

View File

@@ -241,5 +241,6 @@ class Logger:
     def log_video(self, video_path: str, step: int, mode: str = "train"):
         assert mode in {"train", "eval"}
+        assert self._wandb is not None
         wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
         self._wandb.log({f"{mode}/video": wandb_video}, step=step)
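The `assert self._wandb is not None` added above is the standard narrowing idiom for an optional attribute: it documents the invariant, makes a misconfigured logger fail fast, and lets mypy/pyright accept the attribute accesses that follow. A minimal, self-contained sketch of the pattern (the classes and names below are illustrative, not the project's real Logger):

    class WandbLike:
        """Hypothetical stand-in for a wandb run object, just for this sketch."""

        def log(self, data: dict, step: int) -> None:
            print(f"step={step}: {data}")


    class LoggerSketch:
        def __init__(self, wandb: WandbLike | None = None):
            self._wandb = wandb  # typed as WandbLike | None

        def log_scalar(self, name: str, value: float, step: int) -> None:
            # Without this assert, a type checker reports that `None` has no attribute `log`;
            # with it, the type is narrowed to WandbLike for the rest of the method.
            assert self._wandb is not None
            self._wandb.log({name: value}, step=step)


    LoggerSketch(WandbLike()).log_scalar("loss", 0.5, step=1)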

View File

@@ -57,7 +57,7 @@ class Policy(Protocol):
         other items should be logging-friendly, native Python types.
         """

-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Return one action to run in the environment (potentially in batch mode).

         When the model uses a history of observations, or outputs a sequence of actions, this method deals

View File

@@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         self._prev_mean: torch.Tensor | None = None

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations."""
         batch = self.normalize_inputs(batch)
         batch["observation.image"] = batch[self.input_image_key]

View File

@@ -61,7 +61,7 @@ from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from huggingface_hub.utils._validators import HFValidationError
 from PIL import Image as PILImage
-from torch import Tensor
+from torch import Tensor, nn
 from tqdm import trange

 from lerobot.common.datasets.factory import make_dataset
@@ -99,13 +99,13 @@ def rollout(
         "reward": A (batch, sequence) tensor of rewards received for applying the actions.
         "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
             environment termination/truncation).
-        "don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
+        "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
             the first True is followed by True's all the way till the end. This can be used for masking
             extraneous elements from the sequences above.

     Args:
         env: The batch of environments.
-        policy: The policy.
+        policy: The policy. Must be a PyTorch nn module.
         seeds: The environments are seeded once at the start of the rollout. If provided, this argument
             specifies the seeds for each of the environments.
         return_observations: Whether to include all observations in the returned rollout data. Observations
@@ -116,6 +116,7 @@ def rollout(
     Returns:
         The dictionary described above.
     """
+    assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
     device = get_device_from_parameters(policy)

     # Reset the policy and environments.
@@ -231,6 +232,10 @@ def eval_policy(
     Returns:
         Dictionary with metrics and data regarding the rollouts.
     """
+    if max_episodes_rendered > 0 and not videos_dir:
+        raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
+
+    assert isinstance(policy, Policy)
     start = time.time()
     policy.eval()
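The `assert isinstance(policy, Policy)` added just above relies on the Policy protocol being runtime-checkable; `isinstance` against a plain Protocol raises TypeError, and even a runtime-checkable one only verifies that the expected methods exist, not that their signatures match. A self-contained sketch of the mechanism (the protocol and policy below are illustrative stand-ins, not the project's definitions):

    from typing import Protocol, runtime_checkable

    import torch
    from torch import Tensor, nn


    @runtime_checkable
    class PolicyLike(Protocol):
        def select_action(self, batch: dict[str, Tensor]) -> Tensor: ...


    class RandomPolicy(nn.Module):
        def select_action(self, batch: dict[str, Tensor]) -> Tensor:
            # A dummy action for each element of the batch.
            return torch.rand(batch["observation.state"].shape[0], 2)


    policy = RandomPolicy()
    assert isinstance(policy, PolicyLike)  # structural check: method presence only
    assert isinstance(policy, nn.Module)
    print(policy.select_action({"observation.state": torch.zeros(3, 4)}).shape)  # torch.Size([3, 2])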
@@ -271,11 +276,16 @@ def eval_policy(
         if max_episodes_rendered > 0:
             ep_frames: list[np.ndarray] = []

-        seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
+        if start_seed is None:
+            seeds = None
+        else:
+            seeds = range(
+                start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
+            )
         rollout_data = rollout(
             env,
             policy,
-            seeds=seeds,
+            seeds=list(seeds) if seeds else None,
             return_observations=return_episode_data,
             render_callback=render_frame if max_episodes_rendered > 0 else None,
             enable_progbar=enable_inner_progbar,
@@ -285,7 +295,8 @@ def eval_policy(
         # this won't be included).
         n_steps = rollout_data["done"].shape[1]
         # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
-        done_indices = torch.argmax(rollout_data["done"].to(int), axis=1)  # (batch_size, rollout_steps)
+        done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)
+
         # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
         # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
         mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
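For context on the masking logic in the hunk above: because "done" is cumulative, `argmax` over its 0/1 values returns the index of the first done step (argmax breaks ties by first occurrence), and `done_indices + 1` keeps the done step itself inside the mask. A small self-contained sketch with illustrative values, using plain broadcasting in place of `einops.repeat`:

    import torch

    done = torch.tensor([[0, 0, 1, 1, 1],
                         [0, 1, 1, 1, 1]], dtype=torch.bool)  # (batch, n_steps), cumulative
    n_steps = done.shape[1]
    done_indices = torch.argmax(done.to(int), dim=1)  # first done step per batch element
    mask = (torch.arange(n_steps).unsqueeze(0) <= (done_indices + 1).unsqueeze(1)).int()
    print(mask)
    # tensor([[1, 1, 1, 1, 0],
    #         [1, 1, 1, 0, 0]], dtype=torch.int32)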
@@ -296,8 +307,12 @@ def eval_policy(
         max_rewards.extend(batch_max_rewards.tolist())
         batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
         all_successes.extend(batch_successes.tolist())
-        all_seeds.extend(seeds)
+        if seeds:
+            all_seeds.extend(seeds)
+        else:
+            all_seeds.append(None)

+        # FIXME: episode_data is either None or it doesn't exist
         if return_episode_data:
             this_episode_data = _compile_episode_data(
                 rollout_data,
@@ -347,6 +362,7 @@ def eval_policy(
         ):
             if n_episodes_rendered >= max_episodes_rendered:
                 break
+
             videos_dir.mkdir(parents=True, exist_ok=True)
             video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
             video_paths.append(str(video_path))
@@ -504,16 +520,17 @@ def _compile_episode_data(
 def main(
-    pretrained_policy_path: str | None = None,
+    pretrained_policy_path: Path | None = None,
     hydra_cfg_path: str | None = None,
     out_dir: str | None = None,
     config_overrides: list[str] | None = None,
 ):
     assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
-    if hydra_cfg_path is None:
-        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
+    if pretrained_policy_path is not None:
+        hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
     else:
         hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)

     if out_dir is None:
         out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
@@ -531,10 +548,12 @@ def main(
     logging.info("Making policy.")
     if hydra_cfg_path is None:
-        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
+        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
     else:
         # Note: We need the dataset stats to pass to the policy's normalization modules.
         policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)

+    assert isinstance(policy, nn.Module)
     policy.eval()

     with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():

View File

@@ -66,6 +66,7 @@ import argparse
 import json
 import shutil
 from pathlib import Path
+from typing import Any

 import torch
 from huggingface_hub import HfApi
@@ -77,7 +78,7 @@ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_r
 from lerobot.common.datasets.utils import flatten_dict


-def get_from_raw_to_lerobot_format_fn(raw_format):
+def get_from_raw_to_lerobot_format_fn(raw_format: str):
     if raw_format == "pusht_zarr":
         from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
     elif raw_format == "umi_zarr":
@@ -96,7 +97,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
     return from_raw_to_lerobot_format


-def save_meta_data(info, stats, episode_data_index, meta_data_dir):
+def save_meta_data(
+    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
+):
     meta_data_dir.mkdir(parents=True, exist_ok=True)

     # save info
@@ -114,7 +117,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
     save_file(episode_data_index, ep_data_idx_path)


-def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
+def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
     """Expect all meta data files to be all stored in a single "meta_data" directory.
     On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
     """
@@ -128,7 +131,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
     )


-def push_videos_to_hub(repo_id, videos_dir, revision):
+def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
     """Expect mp4 files to be all stored in a single "videos" directory.
     On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
     """
@@ -209,6 +212,7 @@ def push_dataset_to_hub(
     save_meta_data(info, stats, episode_data_index, meta_data_dir)

     if not dry_run:
+        # TODO(rcadene): token needs to be a str | None
         hf_dataset.push_to_hub(repo_id, token=True, revision="main")
         hf_dataset.push_to_hub(repo_id, token=True, revision=revision)

View File

@@ -24,6 +24,7 @@ import torch
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
+from torch import nn
 from torch.cuda.amp import GradScaler

 from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
@@ -292,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,
     # using the eval.py instead, with gym_dora environment and dora-rs.
+    eval_env = None
     if cfg.training.eval_freq > 0:
         logging.info("make_env")
         eval_env = make_env(cfg)
@@ -302,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         dataset_stats=offline_dataset.stats if not cfg.resume else None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
+    assert isinstance(policy, nn.Module)

     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -333,6 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+                assert eval_env is not None
                 eval_info = eval_policy(
                     eval_env,
                     policy,
@@ -414,7 +417,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
             step += 1

-    eval_env.close()
+    if eval_env:
+        eval_env.close()

     logging.info("End of training")

View File

@@ -66,28 +66,31 @@ import gc
 import logging
 import time
 from pathlib import Path
+from typing import Iterator

+import numpy as np
 import rerun as rr
 import torch
+import torch.utils.data
 import tqdm

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


 class EpisodeSampler(torch.utils.data.Sampler):
-    def __init__(self, dataset, episode_index):
+    def __init__(self, dataset: LeRobotDataset, episode_index: int):
         from_idx = dataset.episode_data_index["from"][episode_index].item()
         to_idx = dataset.episode_data_index["to"][episode_index].item()
         self.frame_ids = range(from_idx, to_idx)

-    def __iter__(self):
+    def __iter__(self) -> Iterator:
         return iter(self.frame_ids)

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.frame_ids)


-def to_hwc_uint8_numpy(chw_float32_torch):
+def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
     assert chw_float32_torch.dtype == torch.float32
     assert chw_float32_torch.ndim == 3
     c, h, w = chw_float32_torch.shape
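As a usage note for the now-annotated EpisodeSampler: a sampler that yields a contiguous range of frame indices plugs directly into a DataLoader to iterate over one episode in order. A self-contained toy sketch (a list stands in for LeRobotDataset; names are illustrative):

    from torch.utils.data import DataLoader, Sampler


    class ToyEpisodeSampler(Sampler):
        def __init__(self, from_idx: int, to_idx: int):
            self.frame_ids = range(from_idx, to_idx)

        def __iter__(self):
            return iter(self.frame_ids)

        def __len__(self) -> int:
            return len(self.frame_ids)


    dataset = list(range(100))  # stands in for a frame-indexed dataset
    loader = DataLoader(dataset, batch_size=4, sampler=ToyEpisodeSampler(10, 22))
    for batch in loader:
        print(batch)  # tensors of frame values 10..21, in order, four at a time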