Improves Type Annotations (#252)
parent a06598678c
commit 54c9776bde
@@ -241,5 +241,6 @@ class Logger:
 
     def log_video(self, video_path: str, step: int, mode: str = "train"):
         assert mode in {"train", "eval"}
+        assert self._wandb is not None
         wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
         self._wandb.log({f"{mode}/video": wandb_video}, step=step)

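Note on the added assert: `self._wandb` is typed as optional, so without the check a type checker flags every attribute access on it. A minimal, self-contained sketch of the same narrowing pattern (the `VideoLog` class and its `_sink` attribute are illustrative stand-ins, not code from this repo):

class VideoLog:
    def __init__(self, sink: dict | None = None) -> None:
        # Stays None unless an experiment tracker was configured.
        self._sink: dict | None = sink

    def log_video(self, video_path: str, step: int) -> None:
        # Narrows Optional[dict] to dict for mypy/pyright before attribute use,
        # and fails loudly at runtime if logging was never configured.
        assert self._sink is not None
        self._sink[f"video/{step}"] = video_path


log = VideoLog(sink={})
log.log_video("episode_0.mp4", step=100)
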
@@ -57,7 +57,7 @@ class Policy(Protocol):
         other items should be logging-friendly, native Python types.
         """
 
-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Return one action to run in the environment (potentially in batch mode).
 
         When the model uses a history of observations, or outputs a sequence of actions, this method deals

@@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         self._prev_mean: torch.Tensor | None = None
 
     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations."""
         batch = self.normalize_inputs(batch)
         batch["observation.image"] = batch[self.input_image_key]

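Both the `Policy` protocol and `TDMPCPolicy` now declare `-> Tensor` on `select_action`; a checker verifies the match structurally, and callers no longer see `Any`. A small sketch of how that works (the `RandomPolicy` class and the batch contents are illustrative, not part of the change):

from typing import Protocol

import torch
from torch import Tensor


class SelectsActions(Protocol):
    def select_action(self, batch: dict[str, Tensor]) -> Tensor: ...


class RandomPolicy:
    # Satisfies the protocol structurally: same method name, compatible signature.
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        return torch.rand(batch["observation.state"].shape[0], 2)


def step_env(policy: SelectsActions, batch: dict[str, Tensor]) -> Tensor:
    # With the annotation, the checker knows this is a Tensor rather than Any.
    return policy.select_action(batch)


action = step_env(RandomPolicy(), {"observation.state": torch.zeros(4, 6)})
print(action.shape)  # torch.Size([4, 2])
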
@@ -61,7 +61,7 @@ from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from huggingface_hub.utils._validators import HFValidationError
 from PIL import Image as PILImage
-from torch import Tensor
+from torch import Tensor, nn
 from tqdm import trange
 
 from lerobot.common.datasets.factory import make_dataset

@@ -99,13 +99,13 @@ def rollout(
         "reward": A (batch, sequence) tensor of rewards received for applying the actions.
         "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
             environment termination/truncation).
-        "don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
+        "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
             the first True is followed by True's all the way till the end. This can be used for masking
             extraneous elements from the sequences above.
 
     Args:
         env: The batch of environments.
-        policy: The policy.
+        policy: The policy. Must be a PyTorch nn module.
         seeds: The environments are seeded once at the start of the rollout. If provided, this argument
             specifies the seeds for each of the environments.
         return_observations: Whether to include all observations in the returned rollout data. Observations

@@ -116,6 +116,7 @@ def rollout(
     Returns:
         The dictionary described above.
     """
+    assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
     device = get_device_from_parameters(policy)
 
     # Reset the policy and environments.

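The `Policy` protocol does not promise `nn.Module` members such as `.parameters()`, so the new assert enforces the docstring's requirement at runtime and narrows the type before `get_device_from_parameters` is called. A rough sketch of the idea, with a stand-in device helper rather than the repo's actual implementation:

import torch
from torch import nn


def device_of(module: nn.Module) -> torch.device:
    # Assumes the module has at least one parameter.
    return next(module.parameters()).device


def rollout_sketch(policy: object) -> torch.device:
    # `policy` may only be typed as a Protocol; assert the concrete base class
    # before relying on nn.Module-specific APIs.
    assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
    return device_of(policy)


print(rollout_sketch(nn.Linear(3, 1)))  # cpu (or whichever device the module lives on)
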
@@ -231,6 +232,10 @@ def eval_policy(
     Returns:
         Dictionary with metrics and data regarding the rollouts.
     """
+    if max_episodes_rendered > 0 and not videos_dir:
+        raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
+
+    assert isinstance(policy, Policy)
     start = time.time()
     policy.eval()
 

@@ -271,11 +276,16 @@ def eval_policy(
         if max_episodes_rendered > 0:
             ep_frames: list[np.ndarray] = []
 
-        seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
+        if start_seed is None:
+            seeds = None
+        else:
+            seeds = range(
+                start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
+            )
         rollout_data = rollout(
             env,
             policy,
-            seeds=seeds,
+            seeds=list(seeds) if seeds else None,
             return_observations=return_episode_data,
             render_callback=render_frame if max_episodes_rendered > 0 else None,
             enable_progbar=enable_inner_progbar,

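The seeding branch above now tolerates `start_seed=None`; otherwise each rollout batch receives a contiguous, non-overlapping slice of seeds. A standalone sketch of the arithmetic (the function name and example values are made up):

def seeds_for_batch(start_seed: int | None, batch_ix: int, num_envs: int) -> list[int] | None:
    # No start seed: leave the environments' RNG state alone.
    if start_seed is None:
        return None
    # Batch 0 gets [start, start + num_envs), batch 1 the next slice, and so on,
    # so every environment across all batches sees a distinct seed.
    return list(range(start_seed + batch_ix * num_envs, start_seed + (batch_ix + 1) * num_envs))


assert seeds_for_batch(None, batch_ix=0, num_envs=4) is None
assert seeds_for_batch(100, batch_ix=1, num_envs=4) == [104, 105, 106, 107]
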
@@ -285,7 +295,8 @@ def eval_policy(
         # this won't be included).
         n_steps = rollout_data["done"].shape[1]
         # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
-        done_indices = torch.argmax(rollout_data["done"].to(int), axis=1)  # (batch_size, rollout_steps)
+        done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)
+
         # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
         # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
         mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()

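`dim=` is PyTorch's keyword for `torch.argmax` (the old `axis=` was NumPy-style), and the comment above relies on argmax returning the first occurrence among ties. A small worked example of the done-index and mask computation on made-up values:

import einops
import torch

# Cumulative done flags for 2 rollouts of 5 steps.
done = torch.tensor([[0, 0, 1, 1, 1],
                     [0, 0, 0, 0, 1]])
n_steps = done.shape[1]

# The first maximal value wins ties, i.e. the first step at which done == 1.
done_indices = torch.argmax(done.to(int), dim=1)  # tensor([2, 4])

# Keep data up to one step past the first done, as in the code above.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
print(mask.tolist())  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
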
@@ -296,8 +307,12 @@ def eval_policy(
         max_rewards.extend(batch_max_rewards.tolist())
         batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
         all_successes.extend(batch_successes.tolist())
-        all_seeds.extend(seeds)
+        if seeds:
+            all_seeds.extend(seeds)
+        else:
+            all_seeds.append(None)
 
+        # FIXME: episode_data is either None or it doesn't exist
         if return_episode_data:
             this_episode_data = _compile_episode_data(
                 rollout_data,

@@ -347,6 +362,7 @@ def eval_policy(
             ):
                 if n_episodes_rendered >= max_episodes_rendered:
                     break
+
                 videos_dir.mkdir(parents=True, exist_ok=True)
                 video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
                 video_paths.append(str(video_path))

@@ -504,16 +520,17 @@ def _compile_episode_data(
 
 
 def main(
-    pretrained_policy_path: str | None = None,
+    pretrained_policy_path: Path | None = None,
     hydra_cfg_path: str | None = None,
     out_dir: str | None = None,
     config_overrides: list[str] | None = None,
 ):
     assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
-    if hydra_cfg_path is None:
-        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
+    if pretrained_policy_path is not None:
+        hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
     else:
         hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
 
     if out_dir is None:
         out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"

@@ -531,10 +548,12 @@ def main(
 
     logging.info("Making policy.")
     if hydra_cfg_path is None:
-        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
+        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
     else:
         # Note: We need the dataset stats to pass to the policy's normalization modules.
         policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
+
+    assert isinstance(policy, nn.Module)
     policy.eval()
 
     with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():

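With `pretrained_policy_path` now a `Path`, call sites that need a plain string (Hydra config loading and `make_policy`'s hub path) wrap it in `str(...)`. A tiny illustration of the `Path` handling involved (the directory name is made up):

from pathlib import Path

pretrained_policy_path = Path("outputs/train/example_run/last/pretrained_model")
config_file = pretrained_policy_path / "config.yaml"  # Path / str joins path segments

print(config_file)       # outputs/train/example_run/last/pretrained_model/config.yaml
print(str(config_file))  # explicit str conversion for APIs annotated to accept str only
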
@@ -66,6 +66,7 @@ import argparse
 import json
 import shutil
 from pathlib import Path
+from typing import Any
 
 import torch
 from huggingface_hub import HfApi

@@ -77,7 +78,7 @@ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_r
 from lerobot.common.datasets.utils import flatten_dict
 
 
-def get_from_raw_to_lerobot_format_fn(raw_format):
+def get_from_raw_to_lerobot_format_fn(raw_format: str):
     if raw_format == "pusht_zarr":
         from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
     elif raw_format == "umi_zarr":

@@ -96,7 +97,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
     return from_raw_to_lerobot_format
 
 
-def save_meta_data(info, stats, episode_data_index, meta_data_dir):
+def save_meta_data(
+    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
+):
     meta_data_dir.mkdir(parents=True, exist_ok=True)
 
     # save info

@@ -114,7 +117,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
     save_file(episode_data_index, ep_data_idx_path)
 
 
-def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
+def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
     """Expect all meta data files to be all stored in a single "meta_data" directory.
     On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
     """

@@ -128,7 +131,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
     )
 
 
-def push_videos_to_hub(repo_id, videos_dir, revision):
+def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
     """Expect mp4 files to be all stored in a single "videos" directory.
     On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
     """

@@ -209,6 +212,7 @@ def push_dataset_to_hub(
     save_meta_data(info, stats, episode_data_index, meta_data_dir)
 
     if not dry_run:
+        # TODO(rcadene): token needs to be a str | None
         hf_dataset.push_to_hub(repo_id, token=True, revision="main")
         hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
 

@@ -24,6 +24,7 @@ import torch
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
+from torch import nn
 from torch.cuda.amp import GradScaler
 
 from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps

@@ -292,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,
     # using the eval.py instead, with gym_dora environment and dora-rs.
+    eval_env = None
     if cfg.training.eval_freq > 0:
         logging.info("make_env")
         eval_env = make_env(cfg)

@@ -302,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         dataset_stats=offline_dataset.stats if not cfg.resume else None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
+    assert isinstance(policy, nn.Module)
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

@@ -333,6 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+                assert eval_env is not None
                 eval_info = eval_policy(
                     eval_env,
                     policy,

@@ -414,7 +417,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
         step += 1
 
-    eval_env.close()
+    if eval_env:
+        eval_env.close()
     logging.info("End of training")
 
 

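With `eval_env = None` as the default and the guard before `close()`, a run configured with `eval_freq <= 0` never touches an environment. A sketch of the pattern with a stand-in environment class (names are illustrative):

class StubEnv:
    def close(self) -> None:
        print("env closed")


def train_sketch(eval_freq: int) -> None:
    eval_env: StubEnv | None = None
    if eval_freq > 0:
        eval_env = StubEnv()
    # ... training loop would go here ...
    if eval_env:
        eval_env.close()


train_sketch(0)    # no environment created, nothing to close
train_sketch(500)  # prints "env closed"
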
@@ -66,28 +66,31 @@ import gc
 import logging
 import time
 from pathlib import Path
+from typing import Iterator
 
+import numpy as np
 import rerun as rr
 import torch
+import torch.utils.data
 import tqdm
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
 
 class EpisodeSampler(torch.utils.data.Sampler):
-    def __init__(self, dataset, episode_index):
+    def __init__(self, dataset: LeRobotDataset, episode_index: int):
         from_idx = dataset.episode_data_index["from"][episode_index].item()
         to_idx = dataset.episode_data_index["to"][episode_index].item()
         self.frame_ids = range(from_idx, to_idx)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator:
         return iter(self.frame_ids)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.frame_ids)
 
 
-def to_hwc_uint8_numpy(chw_float32_torch):
+def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
     assert chw_float32_torch.dtype == torch.float32
     assert chw_float32_torch.ndim == 3
     c, h, w = chw_float32_torch.shape

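With the annotations, `EpisodeSampler` reads as a plain `torch.utils.data.Sampler` over one episode's frame indices. A minimal usage sketch with a toy dataset standing in for `LeRobotDataset` (class names and index values are illustrative):

import torch
import torch.utils.data


class RangeSampler(torch.utils.data.Sampler):
    # Same shape as EpisodeSampler: yield a contiguous range of dataset indices.
    def __init__(self, from_idx: int, to_idx: int):
        self.frame_ids = range(from_idx, to_idx)

    def __iter__(self):
        return iter(self.frame_ids)

    def __len__(self) -> int:
        return len(self.frame_ids)


class ToyFrames(torch.utils.data.Dataset):
    def __len__(self) -> int:
        return 100

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor([idx])


loader = torch.utils.data.DataLoader(ToyFrames(), sampler=RangeSampler(10, 14), batch_size=2)
for batch in loader:
    print(batch)  # tensor([[10], [11]]) then tensor([[12], [13]])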