diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 853acbc3..bf578fcc 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -241,5 +241,6 @@ class Logger: def log_video(self, video_path: str, step: int, mode: str = "train"): assert mode in {"train", "eval"} + assert self._wandb is not None wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4") self._wandb.log({f"{mode}/video": wandb_video}, step=step) diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py index 38738a90..4e9e87af 100644 --- a/lerobot/common/policies/policy_protocol.py +++ b/lerobot/common/policies/policy_protocol.py @@ -57,7 +57,7 @@ class Policy(Protocol): other items should be logging-friendly, native Python types. """ - def select_action(self, batch: dict[str, Tensor]): + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Return one action to run in the environment (potentially in batch mode). When the model uses a history of observations, or outputs a sequence of actions, this method deals diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 7c873bf2..de9658e9 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): self._prev_mean: torch.Tensor | None = None @torch.no_grad() - def select_action(self, batch: dict[str, Tensor]): + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) batch["observation.image"] = batch[self.input_image_key] diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 4656130e..7bf8bde5 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -61,7 +61,7 @@ from huggingface_hub import snapshot_download from huggingface_hub.utils._errors import RepositoryNotFoundError from huggingface_hub.utils._validators import HFValidationError from PIL import Image as PILImage -from torch import Tensor +from torch import Tensor, nn from tqdm import trange from lerobot.common.datasets.factory import make_dataset @@ -99,13 +99,13 @@ def rollout( "reward": A (batch, sequence) tensor of rewards received for applying the actions. "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon environment termination/truncation). - "don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element, + "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element, the first True is followed by True's all the way till the end. This can be used for masking extraneous elements from the sequences above. Args: env: The batch of environments. - policy: The policy. + policy: The policy. Must be a PyTorch nn module. seeds: The environments are seeded once at the start of the rollout. If provided, this argument specifies the seeds for each of the environments. return_observations: Whether to include all observations in the returned rollout data. Observations @@ -116,6 +116,7 @@ def rollout( Returns: The dictionary described above. """ + assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module." device = get_device_from_parameters(policy) # Reset the policy and environments. @@ -231,6 +232,10 @@ def eval_policy( Returns: Dictionary with metrics and data regarding the rollouts. """ + if max_episodes_rendered > 0 and not videos_dir: + raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.") + + assert isinstance(policy, Policy) start = time.time() policy.eval() @@ -271,11 +276,16 @@ def eval_policy( if max_episodes_rendered > 0: ep_frames: list[np.ndarray] = [] - seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)) + if start_seed is None: + seeds = None + else: + seeds = range( + start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) + ) rollout_data = rollout( env, policy, - seeds=seeds, + seeds=list(seeds) if seeds else None, return_observations=return_episode_data, render_callback=render_frame if max_episodes_rendered > 0 else None, enable_progbar=enable_inner_progbar, @@ -285,7 +295,8 @@ def eval_policy( # this won't be included). n_steps = rollout_data["done"].shape[1] # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. - done_indices = torch.argmax(rollout_data["done"].to(int), axis=1) # (batch_size, rollout_steps) + done_indices = torch.argmax(rollout_data["done"].to(int), dim=1) + # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int() @@ -296,8 +307,12 @@ def eval_policy( max_rewards.extend(batch_max_rewards.tolist()) batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any") all_successes.extend(batch_successes.tolist()) - all_seeds.extend(seeds) + if seeds: + all_seeds.extend(seeds) + else: + all_seeds.append(None) + # FIXME: episode_data is either None or it doesn't exist if return_episode_data: this_episode_data = _compile_episode_data( rollout_data, @@ -347,6 +362,7 @@ def eval_policy( ): if n_episodes_rendered >= max_episodes_rendered: break + videos_dir.mkdir(parents=True, exist_ok=True) video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4" video_paths.append(str(video_path)) @@ -504,16 +520,17 @@ def _compile_episode_data( def main( - pretrained_policy_path: str | None = None, + pretrained_policy_path: Path | None = None, hydra_cfg_path: str | None = None, out_dir: str | None = None, config_overrides: list[str] | None = None, ): assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None) - if hydra_cfg_path is None: - hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides) + if pretrained_policy_path is not None: + hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides) else: hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides) + if out_dir is None: out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}" @@ -531,10 +548,12 @@ def main( logging.info("Making policy.") if hydra_cfg_path is None: - policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) + policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path)) else: # Note: We need the dataset stats to pass to the policy's normalization modules. policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) + + assert isinstance(policy, nn.Module) policy.eval() with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(): diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 52252b57..9cf72017 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -66,6 +66,7 @@ import argparse import json import shutil from pathlib import Path +from typing import Any import torch from huggingface_hub import HfApi @@ -77,7 +78,7 @@ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_r from lerobot.common.datasets.utils import flatten_dict -def get_from_raw_to_lerobot_format_fn(raw_format): +def get_from_raw_to_lerobot_format_fn(raw_format: str): if raw_format == "pusht_zarr": from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format elif raw_format == "umi_zarr": @@ -96,7 +97,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format): return from_raw_to_lerobot_format -def save_meta_data(info, stats, episode_data_index, meta_data_dir): +def save_meta_data( + info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path +): meta_data_dir.mkdir(parents=True, exist_ok=True) # save info @@ -114,7 +117,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir): save_file(episode_data_index, ep_data_idx_path) -def push_meta_data_to_hub(repo_id, meta_data_dir, revision): +def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None): """Expect all meta data files to be all stored in a single "meta_data" directory. On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root. """ @@ -128,7 +131,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision): ) -def push_videos_to_hub(repo_id, videos_dir, revision): +def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None): """Expect mp4 files to be all stored in a single "videos" directory. On the hugging face repositery, they will be uploaded in a "videos" directory at the root. """ @@ -209,6 +212,7 @@ def push_dataset_to_hub( save_meta_data(info, stats, episode_data_index, meta_data_dir) if not dry_run: + # TODO(rcadene): token needs to be a str | None hf_dataset.push_to_hub(repo_id, token=True, revision="main") hf_dataset.push_to_hub(repo_id, token=True, revision=revision) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5990d18a..01b2ef4f 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -24,6 +24,7 @@ import torch from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored +from torch import nn from torch.cuda.amp import GradScaler from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps @@ -292,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. + eval_env = None if cfg.training.eval_freq > 0: logging.info("make_env") eval_env = make_env(cfg) @@ -302,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No dataset_stats=offline_dataset.stats if not cfg.resume else None, pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, ) - + assert isinstance(policy, nn.Module) # Create optimizer and scheduler # Temporary hack to move optimizer out of policy optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) @@ -333,6 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0: logging.info(f"Eval policy at step {step}") with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): + assert eval_env is not None eval_info = eval_policy( eval_env, policy, @@ -414,7 +417,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No step += 1 - eval_env.close() + if eval_env: + eval_env.close() logging.info("End of training") diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 2ed76898..f947e610 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -66,28 +66,31 @@ import gc import logging import time from pathlib import Path +from typing import Iterator +import numpy as np import rerun as rr import torch +import torch.utils.data import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset class EpisodeSampler(torch.utils.data.Sampler): - def __init__(self, dataset, episode_index): + def __init__(self, dataset: LeRobotDataset, episode_index: int): from_idx = dataset.episode_data_index["from"][episode_index].item() to_idx = dataset.episode_data_index["to"][episode_index].item() self.frame_ids = range(from_idx, to_idx) - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.frame_ids) - def __len__(self): + def __len__(self) -> int: return len(self.frame_ids) -def to_hwc_uint8_numpy(chw_float32_torch): +def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.ndim == 3 c, h, w = chw_float32_torch.shape