diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock index fe4ed7a0..98a3d58d 100644 --- a/.github/poetry/cpu/poetry.lock +++ b/.github/poetry/cpu/poetry.lock @@ -940,7 +940,7 @@ mujoco = "^2.3.7" type = "git" url = "git@github.com:huggingface/gym-xarm.git" reference = "HEAD" -resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d" +resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c" [[package]] name = "gymnasium" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index afdcc41f..b3411e11 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -142,6 +142,7 @@ jobs: wandb.enable=False \ offline_steps=2 \ online_steps=0 \ + eval_episodes=1 \ device=cpu \ save_model=true \ save_freq=2 \ @@ -159,17 +160,6 @@ device=cpu \ policy.pretrained_model_path=tests/outputs/act/models/2.pt - # TODO(aliberts): This takes ~2mn to run, needs to be improved # - name: Test eval ACT on ALOHA end-to-end (policy is None) # run: | # source .venv/bin/activate # python lerobot/scripts/eval.py \ # --config lerobot/configs/default.yaml \ # policy=act \ # env=aloha \ # eval_episodes=1 \ # device=cpu - - name: Test train Diffusion on PushT end-to-end run: | source .venv/bin/activate @@ -179,9 +169,11 @@ wandb.enable=False \ offline_steps=2 \ online_steps=0 \ + eval_episodes=1 \ device=cpu \ save_model=true \ save_freq=2 \ + policy.batch_size=2 \ hydra.run.dir=tests/outputs/diffusion/ - name: Test eval Diffusion on PushT end-to-end @@ -194,16 +186,6 @@ device=cpu \ policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt - - name: Test eval Diffusion on PushT end-to-end (policy is None) - run: | - source .venv/bin/activate - python lerobot/scripts/eval.py \ - --config lerobot/configs/default.yaml \ - policy=diffusion \ - env=pusht \ - eval_episodes=1 \ - device=cpu - - name: Test train TDMPC on Simxarm end-to-end run: | source .venv/bin/activate @@ -213,9 +195,11 @@ wandb.enable=False \ offline_steps=1 \ online_steps=1 \ + eval_episodes=1 \ device=cpu \ save_model=true \ save_freq=2 \ + policy.batch_size=2 \ hydra.run.dir=tests/outputs/tdmpc/ - name: Test eval TDMPC on Simxarm end-to-end @@ -227,13 +211,3 @@ env.episode_length=8 \ device=cpu \ policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt - - - name: Test eval TDPMC on Simxarm end-to-end (policy is None) - run: | - source .venv/bin/activate - python lerobot/scripts/eval.py \ - --config lerobot/configs/default.yaml \ - policy=tdmpc \ - env=xarm \ - eval_episodes=1 \ - device=cpu diff --git a/.gitignore b/.gitignore index ad9892d4..3132aba0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ rl nautilus/*.yaml *.key +# Slurm +sbatch*.sh + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index 51e03d65..25b8d1e4 100644 --- a/README.md +++ b/README.md @@ -120,34 +120,32 @@ wandb login You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities: ```python """ Copy pasted from `examples/1_visualize_dataset.py` """ +import os +from pathlib import Path + import lerobot -from lerobot.common.datasets.aloha import AlohaDataset -from torchrl.data.replay_buffers import SamplerWithoutReplacement +from lerobot.common.datasets.pusht import PushtDataset from lerobot.scripts.visualize_dataset import render_dataset print(lerobot.available_datasets) # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
'pusht', 'xarm_lift_medium'] -# we use this sampler to sample 1 frame after the other -sampler = SamplerWithoutReplacement(shuffle=False) - -dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler) +# TODO(rcadene): remove DATA_DIR +dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR"))) video_paths = render_dataset( dataset, out_dir="outputs/visualize_dataset/example", - max_num_samples=300, - fps=50, + max_num_episodes=1, ) print(video_paths) -# >>> ['outputs/visualize_dataset/example/episode_0.mp4'] +# ['outputs/visualize_dataset/example/episode_0.mp4'] ``` Or you can achieve the same result by executing our script from the command line: ```bash python lerobot/scripts/visualize_dataset.py \ -env=aloha \ -task=sim_sim_transfer_cube_human \ +env=pusht \ hydra.run.dir=outputs/visualize_dataset/example # >>> ['outputs/visualize_dataset/example/episode_0.mp4'] ``` diff --git a/examples/1_visualize_dataset.py b/examples/1_visualize_dataset.py index f52ab76a..15e0e54d 100644 --- a/examples/1_visualize_dataset.py +++ b/examples/1_visualize_dataset.py @@ -1,24 +1,20 @@ import os - -from torchrl.data.replay_buffers import SamplerWithoutReplacement +from pathlib import Path import lerobot -from lerobot.common.datasets.aloha import AlohaDataset +from lerobot.common.datasets.pusht import PushtDataset from lerobot.scripts.visualize_dataset import render_dataset print(lerobot.available_datasets) # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium'] -# we use this sampler to sample 1 frame after the other -sampler = SamplerWithoutReplacement(shuffle=False) - -dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR")) +# TODO(rcadene): remove DATA_DIR +dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR"))) video_paths = render_dataset( dataset, out_dir="outputs/visualize_dataset/example", - max_num_samples=300, - fps=50, + max_num_episodes=1, ) print(video_paths) # ['outputs/visualize_dataset/example/episode_0.mp4'] diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 6e01a5d5..238f953d 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -9,9 +9,8 @@ from pathlib import Path import torch from omegaconf import OmegaConf -from tqdm import trange -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.policies.diffusion.policy import DiffusionPolicy from lerobot.common.utils import init_hydra_config @@ -37,19 +36,33 @@ policy = DiffusionPolicy( cfg_obs_encoder=cfg.obs_encoder, cfg_optimizer=cfg.optimizer, cfg_ema=cfg.ema, - n_action_steps=cfg.n_action_steps, **cfg.policy, ) policy.train() -offline_buffer = make_offline_buffer(cfg) +dataset = make_dataset(cfg) + +# create dataloader for offline training +dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=cfg.policy.batch_size, + shuffle=True, + pin_memory=cfg.device != "cpu", + drop_last=True, +) + +for step, batch in enumerate(dataloader): + info = policy(batch, step) + + if step % cfg.log_freq == 0: + num_samples = (step + 1) * cfg.policy.batch_size + loss = info["loss"] + update_s = info["update_s"] + print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)") -for offline_step in trange(cfg.offline_steps): - train_info = policy.update(offline_buffer,
offline_step) - if offline_step % cfg.log_freq == 0: - print(train_info) # Save the policy, configuration, and normalization stats for later use. policy.save(output_directory / "model.pt") OmegaConf.save(cfg, output_directory / "config.yaml") -torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth") +torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth") diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 4673aab0..8ab95df8 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -12,14 +12,11 @@ Example: print(lerobot.available_policies) ``` -Note: - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class +When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps: +- For datasets, set the required class attribute: `available_datasets`. +- For policies, set the required class attribute: `name`. +- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_env`, `available_policies`). +- Update variables in `tests/test_available.py` by importing your new class. """ from lerobot.__version__ import __version__ # noqa: F401 @@ -32,11 +29,11 @@ available_envs = [ available_tasks_per_env = { "aloha": [ - "sim_insertion", - "sim_transfer_cube", + "AlohaInsertion-v0", + "AlohaTransferCube-v0", ], - "pusht": ["pusht"], - "xarm": ["lift"], + "pusht": ["PushT-v0"], + "xarm": ["XarmLift-v0"], } available_datasets_per_env = { diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 50bf819a..4b241ad8 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset): @property def num_samples(self) -> int: - return len(self.data_dict["index"]) + return len(self.data_dict["index"]) if "index" in self.data_dict else 0 @property def num_episodes(self) -> int: diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 0dab5d4b..4ae161f6 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,10 +1,11 @@ +import logging import os from pathlib import Path import torch from torchvision.transforms import v2 -from lerobot.common.datasets.utils import compute_or_load_stats +from lerobot.common.datasets.utils import compute_stats from lerobot.common.transforms import NormalizeTransform, Prod # DATA_DIR specifies the location where datasets are loaded. By default, DATA_DIR is None and @@ -40,7 +41,8 @@ def make_dataset( if normalize: # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, # min_max_from_spec - # stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path) + # TODO(rcadene): remove this and put it in config.
Ideally we want to reproduce SOTA results just with mean_std + normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": stats = {} stats["action"] = {} stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) - else: + elif stats_path is None: # instantiate a one frame dataset with light transform stats_dataset = clsfunc( dataset_id=cfg.dataset_id, root=DATA_DIR, transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), ) - stats = compute_or_load_stats(stats_dataset) - # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std - normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" + # load stats if the file exists already or compute stats and save it + precomputed_stats_path = stats_dataset.data_dir / "stats.pth" + if precomputed_stats_path.exists(): + stats = torch.load(precomputed_stats_path) + else: + logging.info(f"compute_stats and save to {precomputed_stats_path}") + stats = compute_stats(stats_dataset) + torch.save(stats, precomputed_stats_path) + else: + stats = torch.load(stats_path) transforms = v2.Compose( [ - # TODO(rcadene): we need to do something about image_keys Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), NormalizeTransform( stats, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 9af6f3a1..47253b15 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -2,11 +2,8 @@ from pathlib import Path import einops import numpy as np -import pygame -import pymunk import torch import tqdm -from gym_pusht.envs.pusht import pymunk_to_shapely from lerobot.common.datasets._diffusion_policy_replay_buffer import ( ReplayBuffer as DiffusionPolicyReplayBuffer, ) @@ -20,64 +17,6 @@ PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip" PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr") -def get_goal_pose_body(pose): - mass = 1 - inertia = pymunk.moment_for_box(mass, (50, 100)) - body = pymunk.Body(mass, inertia) - # preserving the legacy assignment order for compatibility - # the order here doesn't matter somehow, maybe because CoM is aligned with body origin - body.position = pose[:2].tolist() - body.angle = pose[2] - return body - - -def add_segment(space, a, b, radius): - shape = pymunk.Segment(space.static_body, a, b, radius) - shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names - return shape - - -def add_tee( - space, - position, - angle, - scale=30, - color="LightSlateGray", - mask=None, -): - if mask is None: - mask = pymunk.ShapeFilter.ALL_MASKS() - mass = 1 - length = 4 - vertices1 = [ - (-length * scale / 2, scale), - (length * scale / 2, scale), - (length * scale / 2, 0), - (-length * scale / 2, 0), - ] - inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1) - vertices2 = [ - (-scale / 2, scale), - (-scale / 2, length * scale), - (scale / 2, length * scale), - (scale / 2, scale), - ] - inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1) - body = pymunk.Body(mass, inertia1 + inertia2) - shape1 = pymunk.Poly(body, vertices1) - shape2 = pymunk.Poly(body, vertices2) - shape1.color = pygame.Color(color) - shape2.color = pygame.Color(color) - shape1.filter = pymunk.ShapeFilter(mask=mask) - shape2.filter = pymunk.ShapeFilter(mask=mask) - body.center_of_gravity =
(shape1.center_of_gravity + shape2.center_of_gravity) / 2 - body.position = position - body.angle = angle - body.friction = 1 - space.add(body, shape1, shape2) - return body - - class PushtDataset(torch.utils.data.Dataset): """ @@ -121,7 +60,7 @@ class PushtDataset(torch.utils.data.Dataset): @property def num_samples(self) -> int: - return len(self.data_dict["index"]) + return len(self.data_dict["index"]) if "index" in self.data_dict else 0 @property def num_episodes(self) -> int: @@ -158,6 +97,13 @@ class PushtDataset(torch.utils.data.Dataset): return item def _download_and_preproc_obsolete(self): + try: + import pymunk + from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely + except ModuleNotFoundError as e: + print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`") + raise e + assert self.root is not None raw_dir = self.root / f"{self.dataset_id}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() @@ -182,7 +128,7 @@ class PushtDataset(torch.utils.data.Dataset): # TODO: verify that goal pose is expected to be fixed goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) - goal_body = get_goal_pose_body(goal_pos_angle) + goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle) imgs = torch.from_numpy(dataset_dict["img"]) imgs = einops.rearrange(imgs, "b h w c -> b c h w") @@ -201,6 +147,9 @@ class PushtDataset(torch.utils.data.Dataset): assert (episode_ids[idx0:idx1] == episode_id).all() image = imgs[idx0:idx1] + assert image.min() >= 0.0 + assert image.max() <= 255.0 + image = image.type(torch.uint8) state = states[idx0:idx1] agent_pos = state[:, :2] @@ -217,14 +166,14 @@ class PushtDataset(torch.utils.data.Dataset): # Add walls. walls = [ - add_segment(space, (5, 506), (5, 5), 2), - add_segment(space, (5, 5), (506, 5), 2), - add_segment(space, (506, 5), (506, 506), 2), - add_segment(space, (5, 506), (506, 506), 2), + PushTEnv.add_segment(space, (5, 506), (5, 5), 2), + PushTEnv.add_segment(space, (5, 5), (506, 5), 2), + PushTEnv.add_segment(space, (506, 5), (506, 506), 2), + PushTEnv.add_segment(space, (5, 506), (506, 506), 2), ] space.add(*walls) - block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item()) + block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item()) goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) block_geom = pymunk_to_shapely(block_body, block_body.shapes) intersection_area = goal_geom.intersection(block_geom).area @@ -265,16 +214,3 @@ class PushtDataset(torch.utils.data.Dataset): self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) self.data_dict["index"] = torch.arange(0, total_frames, 1) - - -if __name__ == "__main__": - dataset = PushtDataset( - "pusht", - root=Path("data"), - delta_timestamps={ - "observation.image": [0, -1, -0.2, -0.1], - "observation.state": [0, -1, -0.2, -0.1], - "action": [-0.1, 0, 1, 2, 3], - }, - ) - dataset[10] diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 3b4aacfc..e67d8a04 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,5 +1,4 @@ import io -import logging import zipfile from copy import deepcopy from math import ceil @@ -35,52 +34,56 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool: return False -def euclidean_distance_matrix(mat0, mat1): - # Compute the square of the distance matrix - sq0 = torch.sum(mat0**2, dim=1, keepdim=True) - sq1 = torch.sum(mat1**2, dim=1, keepdim=True) - distance_sq = sq0 + 
sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1) - - # Taking the square root to get the euclidean distance - distance = torch.sqrt(torch.clamp(distance_sq, min=0)) - return distance - - -def is_contiguously_true_or_false(bool_vector): - assert bool_vector.ndim == 1 - assert bool_vector.dtype == torch.bool - - # Compare each element with its neighbor to find changes - changes = bool_vector[1:] != bool_vector[:-1] - - # Count the number of changes - num_changes = changes.sum().item() - - # If there's more than one change, the list is not contiguous - return num_changes <= 1 - - # examples = [ - # ([True, False, True, False, False, False], False), - # ([True, True, True, False, False, False], True), - # ([False, False, False, False, False, False], True) - # ] - # for bool_list, expected in examples: - # result = is_contiguously_true_or_false(bool_list) - - def load_data_with_delta_timestamps( - data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode + data_dict: dict[str, torch.Tensor], + data_ids_per_episode: dict[int, torch.Tensor], + delta_timestamps: dict[str, list[float]], + key: str, + current_ts: float, + episode: int, + tol: float = 0.04, ): + """ + Given a current timestamp (e.g. current_ts=0.6) and a list of timestamp differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]), + this function computes the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image"). + + Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError. + When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp, + the violation of the tolerance doesn't raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range. + For instance, this boolean array is useful during batched training to not supervise actions associated with timestamps coming after the end of the episode, + or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode. + + Parameters: + - data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). + - data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode. + - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps. + - key (str): The key specifying which data modality is to be retrieved from the data_dict. + - current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps. + - episode (int): The identifier of the episode from which frames are to be retrieved. + - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04. + + Returns: + - tuple: A tuple containing two elements: + - The first element is the data retrieved from the specified modality based on the closest match to the query timestamps. + - The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level). + + Raises: + - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection. + """ # get indices of the frames associated with the episode, and their timestamps ep_data_ids = data_ids_per_episode[episode] ep_timestamps = data_dict["timestamp"][ep_data_ids] + # we make the assumption that the timestamps are sorted + ep_first_ts = ep_timestamps[0] + ep_last_ts = ep_timestamps[-1] + # get timestamps used as query to retrieve data of previous/future frames delta_ts = delta_timestamps[key] query_ts = current_ts + torch.tensor(delta_ts) # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode - dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None]) + dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1) min_, argmin_ = dist.min(1) # get the indices of the data that are closest to the query timestamps @@ -92,24 +95,29 @@ def load_data_with_delta_timestamps( # TODO(rcadene): synchronize timestamps + interpolation if needed - tol = 0.04 is_pad = min_ > tol - assert is_contiguously_true_or_false(is_pad), ( - f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})." + # check that the violated query timestamps all fall outside the episode range + assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), ( + f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside the episode range." "This might be due to synchronization issues with timestamps during data collection." ) return data, is_pad
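To make the padding semantics documented above concrete, here is a minimal usage sketch of `load_data_with_delta_timestamps` with toy tensors (hypothetical values, mirroring the unit tests added in `tests/test_datasets.py` further down): a query that falls before the first frame of the episode is returned clamped to that frame and flagged as padding rather than raising.

```python
import torch

from lerobot.common.datasets.utils import load_data_with_delta_timestamps

# toy episode: 5 frames of a single modality, timestamps spaced 0.1s apart
data_dict = {
    "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
    "index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.arange(5)}
delta_timestamps = {"index": [-0.3, -0.1, 0.0, 0.1]}

# query timestamps are 0.2 + deltas = [-0.1, 0.1, 0.2, 0.3]; the first query
# precedes the episode start, so it is clamped to frame 0 and flagged as padding
data, is_pad = load_data_with_delta_timestamps(
    data_dict, data_ids_per_episode, delta_timestamps,
    key="index", current_ts=0.2, episode=0, tol=0.04,
)
print(data)    # tensor([0, 0, 1, 2])
print(is_pad)  # tensor([ True, False, False, False])
```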
-def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): - stats_path = dataset.data_dir / "stats.pth" - if stats_path.exists(): - return torch.load(stats_path) +def get_stats_einops_patterns(dataset): + """These einops patterns will be used to aggregate batches and compute statistics.""" + stats_patterns = { + "action": "b c -> c", + "observation.state": "b c -> c", + } + for key in dataset.image_keys: + stats_patterns[key] = "b c h w -> c 1 1" + return stats_patterns - logging.info(f"compute_stats and save to {stats_path}") +def compute_stats(dataset, batch_size=32, max_num_samples=None): if max_num_samples is None: max_num_samples = len(dataset) else: @@ -124,13 +132,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): drop_last=False, ) - # these einops patterns will be used to aggregate batches and compute statistics - stats_patterns = { - "action": "b c -> c", - "observation.state": "b c -> c", - } - for key in dataset.image_keys: - stats_patterns[key] = "b c h w -> c 1 1" + # get einops patterns to aggregate batches and compute statistics + stats_patterns = get_stats_einops_patterns(dataset) # mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {} @@ -201,7 +204,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): "min": min[key], } - torch.save(stats, stats_path) return stats diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py index 733267ab..0dfcc5c9 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/xarm.py @@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset): @property def num_samples(self) -> int: - return len(self.data_dict["index"]) + return len(self.data_dict["index"]) if "index" in self.data_dict else 0 @property def num_episodes(self) -> int: @@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset): image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) action = torch.tensor(dataset_dict["actions"][idx0:idx1]) - # TODO(rcadene): concat the last "next_observations" to "observations" + # TODO(rcadene): we have a missing last frame which is the observation when the env is done + # it is critical to have this frame for tdmpc to predict a "done observation/state" # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 9d0fb853..4d31ddb2 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -19,6 +19,7 @@ def preprocess_observation(observation, transform=None): img = einops.rearrange(img, "b h w c -> b c h w") obs[imgkey] = img + # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos" obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float() # apply same transforms as in training diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 325b5608..a287614d 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -29,9 +29,9 @@ def make_policy(cfg): if cfg.policy.pretrained_model_path: # TODO(rcadene): hack for old pretrained models from fowm if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path: - if "offline" in cfg.pretrained_model_path: + if "offline" in cfg.policy.pretrained_model_path: policy.step[0] = 25000 - elif "final" in cfg.pretrained_model_path: + elif "final" in cfg.policy.pretrained_model_path: policy.step[0] = 100000 else: raise NotImplementedError() diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 942ee9b1..14728576 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module): """Main update function. 
Corresponds to one iteration of the model learning.""" start_time = time.time() - # num_slices = self.cfg.batch_size - # batch_size = self.cfg.horizon * num_slices - - # if demo_buffer is None: - # demo_batch_size = 0 - # else: - # # Update oversampling ratio - # demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step) - # demo_num_slices = int(demo_pc_batch * self.batch_size) - # demo_batch_size = self.cfg.horizon * demo_num_slices - # batch_size -= demo_batch_size - # num_slices -= demo_num_slices - # replay_buffer._sampler.num_slices = num_slices - # demo_buffer._sampler.num_slices = demo_num_slices - - # assert demo_batch_size % self.cfg.horizon == 0 - # assert demo_batch_size % demo_num_slices == 0 - - # assert batch_size % self.cfg.horizon == 0 - # assert batch_size % num_slices == 0 - - # # Sample from interaction dataset - - # def process_batch(batch, horizon, num_slices): - # # trajectory t = 256, horizon h = 5 - # # (t h) ... -> h t ... - # batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous() - - # obs = { - # "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True), - # "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True), - # } - # action = batch["action"].to(self.device, non_blocking=True) - # next_obses = { - # "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True), - # "state": batch["next", "observation", "state"].to(self.device, non_blocking=True), - # } - # reward = batch["next", "reward"].to(self.device, non_blocking=True) - - # idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True) - # weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True) - - # # TODO(rcadene): rearrange directly in offline dataset - # if reward.ndim == 2: - # reward = einops.rearrange(reward, "h t -> h t 1") - - # assert reward.ndim == 3 - # assert reward.shape == (horizon, num_slices, 1) - # # We dont use `batch["next", "done"]` since it only indicates the end of an - # # episode, but not the end of the trajectory of an episode. 
- # # Neither does `batch["next", "terminated"]` - # done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) - # mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) - # return obs, action, next_obses, reward, mask, done, idxs, weights - - # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample() - - # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch( - # batch, self.cfg.horizon, num_slices - # ) - - # Sample from demonstration dataset - # if demo_batch_size > 0: - # demo_batch = demo_buffer.sample(demo_batch_size) - # ( - # demo_obs, - # demo_action, - # demo_next_obses, - # demo_reward, - # demo_mask, - # demo_done, - # demo_idxs, - # demo_weights, - # ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices) - - # if isinstance(obs, dict): - # obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs} - # next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses} - # else: - # obs = torch.cat([obs, demo_obs]) - # next_obses = torch.cat([next_obses, demo_next_obses], dim=1) - # action = torch.cat([action, demo_action], dim=1) - # reward = torch.cat([reward, demo_reward], dim=1) - # mask = torch.cat([mask, demo_mask], dim=1) - # done = torch.cat([done, demo_done], dim=1) - # idxs = torch.cat([idxs, demo_idxs]) - # weights = torch.cat([weights, demo_weights]) - batch_size = batch["index"].shape[0] # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels) @@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module): ) self.optim.step() + # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion # if self.cfg.per: # # Update priorities # priorities = priority_loss.clamp(max=1e4).detach() diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index 28270bac..81b3d986 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -99,6 +99,7 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D def print_cuda_memory_usage(): + """Use this function to locate and debug memory leak.""" import gc gc.collect() diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 7a8d8b58..6b836795 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -18,7 +18,6 @@ env: from_pixels: True pixels_only: False image_size: [3, 480, 640] - action_repeat: 1 episode_length: 400 fps: ${fps} diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index a5fbcc25..a7097ffd 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -18,7 +18,6 @@ env: from_pixels: True pixels_only: False image_size: 96 - action_repeat: 1 episode_length: 300 fps: ${fps} diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 8b3c72ef..bcba659e 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -17,7 +17,6 @@ env: from_pixels: True pixels_only: False image_size: 84 - # action_repeat: 2 # we can remove if policy has n_action_steps=2 episode_length: 25 fps: ${fps} diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 2ebaad9b..4fd2b6bb 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -36,6 +36,7 @@ policy: log_std_max: 2 # learning + batch_size: 256 max_buffer_size: 10000 horizon: 5 reward_coef: 0.5 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 
06459a85..2b8906d7 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -32,6 +32,7 @@ import json import logging import threading import time +from copy import deepcopy from datetime import datetime as dt from pathlib import Path @@ -56,15 +57,15 @@ def write_video(video_path, stacked_frames, fps): def eval_policy( env: gym.vector.VectorEnv, - policy, - save_video: bool = False, + policy: torch.nn.Module, + max_episodes_rendered: int = 0, video_dir: Path = None, # TODO(rcadene): make it possible to overwrite fps? we should use env.fps - fps: int = 15, - return_first_video: bool = False, transform: callable = None, seed=None, ): + fps = env.unwrapped.metadata["render_fps"] + if policy is not None: policy.eval() device = "cpu" if policy is None else next(policy.parameters()).device @@ -83,14 +84,11 @@ def eval_policy( # needed as I'm currently taking a ceil. ep_frames = [] - def maybe_render_frame(env): - if save_video: # noqa: B023 - if return_first_video: - visu = env.envs[0].render() - visu = visu[None, ...] # add batch dim - else: - visu = np.stack([env.render() for env in env.envs]) - ep_frames.append(visu) # noqa: B023 + def render_frame(env): + # noqa: B023 + eps_rendered = min(max_episodes_rendered, len(env.envs)) + visu = np.stack([env.envs[i].render() for i in range(eps_rendered)]) + ep_frames.append(visu) # noqa: B023 for _ in range(num_episodes): seeds.append("TODO") @@ -104,8 +102,14 @@ def eval_policy( # reset the environment observation, info = env.reset(seed=seed) - maybe_render_frame(env) + if max_episodes_rendered > 0: + render_frame(env) + observations = [] + actions = [] + # episode + # frame_id + # timestamp rewards = [] successes = [] dones = [] @@ -113,8 +117,13 @@ def eval_policy( done = torch.tensor([False for _ in env.envs]) step = 0 while not done.all(): + # format from env keys to lerobot keys + observation = preprocess_observation(observation) + observations.append(deepcopy(observation)) + # apply transform to normalize the observations - observation = preprocess_observation(observation, transform) + for key in observation: + observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]]) # send observation to device/gpu observation = {key: observation[key].to(device, non_blocking=True) for key in observation} @@ -126,11 +135,13 @@ def eval_policy( # apply inverse transform to unnormalize the action action = postprocess_action(action, transform) - # apply the next + # apply the next action observation, reward, terminated, truncated, info = env.step(action) - maybe_render_frame(env) + if max_episodes_rendered > 0: + render_frame(env) # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?) 
+ action = torch.from_numpy(action) reward = torch.from_numpy(reward) terminated = torch.from_numpy(terminated) truncated = torch.from_numpy(truncated) @@ -147,12 +158,24 @@ def eval_policy( success = [False for _ in env.envs] success = torch.tensor(success) + actions.append(action) rewards.append(reward) dones.append(done) successes.append(success) step += 1 + env.close() + + # add the last observation when the env is done + observation = preprocess_observation(observation) + observations.append(deepcopy(observation)) + + new_obses = {} + for key in observations[0].keys(): # noqa: SIM118 + new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1) + observations = new_obses + actions = torch.stack(actions, dim=1) rewards = torch.stack(rewards, dim=1) successes = torch.stack(successes, dim=1) dones = torch.stack(dones, dim=1) @@ -172,29 +195,61 @@ def eval_policy( max_rewards.extend(batch_max_reward.tolist()) all_successes.extend(batch_success.tolist()) - env.close() + # similar logic is implemented in dataset preprocessing + ep_dicts = [] + num_episodes = dones.shape[0] + total_frames = 0 + idx0 = idx1 = 0 + data_ids_per_episode = {} + for ep_id in range(num_episodes): + num_frames = done_indices[ep_id].item() + 1 + # TODO(rcadene): We need to add a missing last frame which is the observation + # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state" + ep_dict = { + "action": actions[ep_id, :num_frames], + "episode": torch.tensor([ep_id] * num_frames), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / fps, + "next.done": dones[ep_id, :num_frames], + "next.reward": rewards[ep_id, :num_frames].type(torch.float32), + } + for key in observations: + ep_dict[key] = observations[key][ep_id, :num_frames] + ep_dicts.append(ep_dict) - if save_video or return_first_video: + total_frames += num_frames + idx1 += num_frames + + data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1) + + idx0 = idx1 + + # similar logic is implemented in dataset preprocessing + data_dict = {} + keys = ep_dicts[0].keys() + for key in keys: + data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + data_dict["index"] = torch.arange(0, total_frames, 1) + + if max_episodes_rendered > 0: batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *) - if save_video: - for stacked_frames, done_index in zip( - batch_stacked_frames, done_indices.flatten().tolist(), strict=False - ): - if episode_counter >= num_episodes: - continue - video_dir.mkdir(parents=True, exist_ok=True) - video_path = video_dir / f"eval_episode_{episode_counter}.mp4" - thread = threading.Thread( - target=write_video, - args=(str(video_path), stacked_frames[:done_index], fps), - ) - thread.start() - threads.append(thread) - episode_counter += 1 + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): + if episode_counter >= num_episodes: + continue + video_dir.mkdir(parents=True, exist_ok=True) + video_path = video_dir / f"eval_episode_{episode_counter}.mp4" + thread = threading.Thread( + target=write_video, + args=(str(video_path), stacked_frames[:done_index], fps), + ) + thread.start() + threads.append(thread) + episode_counter += 1 - if return_first_video: - first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) + videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w") for thread in threads: thread.join() @@ -225,9 +280,13 @@ def eval_policy( "eval_s": time.time() - start, 
"eval_ep_s": (time.time() - start) / num_episodes, }, + "episodes": { + "data_dict": data_dict, + "data_ids_per_episode": data_ids_per_episode, + }, } - if return_first_video: - return info, first_video + if max_episodes_rendered > 0: + info["videos"] = videos return info @@ -253,16 +312,14 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("Making environment.") env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) - # when policy is None, rollout a random policy - policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None + logging.info("Making policy.") + policy = make_policy(cfg) info = eval_policy( env, - policy=policy, - save_video=True, + policy, + max_episodes_rendered=10, video_dir=Path(out_dir) / "eval", - fps=cfg.env.fps, - # TODO(rcadene): what should we do with the transform? transform=transform, seed=cfg.seed, ) @@ -270,6 +327,9 @@ def eval(cfg: dict, out_dir=None, stats_path=None): # Save info with open(Path(out_dir) / "eval_info.json", "w") as f: + # remove pytorch tensors which are not serializable to save the evaluation results only + del info["episodes"] + del info["videos"] json.dump(info, f, indent=2) logging.info("End of eval") diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index dd3da978..300a8617 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,8 +1,8 @@ import logging +from copy import deepcopy from pathlib import Path import hydra -import numpy as np import torch from lerobot.common.datasets.factory import make_dataset @@ -108,6 +108,64 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline): logger.log_dict(info, step, mode="eval") +def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): + """ + Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average). + + Parameters: + - n_off (int): Number of offline samples, each with a sampling weight of 1. + - n_on (int): Number of online samples. + - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5). + + The total weight of offline samples is n_off * 1.0. + The total weight of offline samples is n_on * w. + The total combined weight of all samples is n_off + n_on * w. + The fraction of the weight that is online is n_on * w / (n_off + n_on * w). + We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on. 
+ + +def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples): + data_dict = episodes["data_dict"] + data_ids_per_episode = episodes["data_ids_per_episode"] + + if len(online_dataset) == 0: + # initialize online dataset + online_dataset.data_dict = data_dict + online_dataset.data_ids_per_episode = data_ids_per_episode + else: + # compute the episode and frame index offsets from the episodes already in online_dataset + start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1 + start_index = online_dataset.data_dict["index"][-1].item() + 1 + data_dict["episode"] += start_episode + data_dict["index"] += start_index + + # extend online dataset + for key in data_dict: + # TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure + online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]]) + for ep_id in data_ids_per_episode: + online_dataset.data_ids_per_episode[ep_id + start_episode] = ( + data_ids_per_episode[ep_id] + start_index + ) + + # update the concatenated dataset length used during sampling + concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) + + # update the sampling weights for each frame so that online frames get sampled a certain percentage of times + len_online = len(online_dataset) + len_offline = len(concat_dataset) - len_online + weight_offline = 1.0 + weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples) + sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len_online) + + # update the total number of samples used during sampling + sampler.num_samples = len(concat_dataset) + + def train(cfg: dict, out_dir=None, job_name=None): if out_dir is None: raise NotImplementedError() @@ -126,26 +184,7 @@ set_global_seed(cfg.seed) logging.info("make_dataset") - dataset = make_dataset(cfg) - - # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy - # if cfg.policy.balanced_sampling: - # logging.info("make online_buffer") - # num_traj_per_batch = cfg.policy.batch_size - - # online_sampler = PrioritizedSliceSampler( - # max_capacity=100_000, - # alpha=cfg.policy.per_alpha, - # beta=cfg.policy.per_beta, - # num_slices=num_traj_per_batch, - # strict_length=True, - # ) - - # online_buffer = TensorDictReplayBuffer( - # storage=LazyMemmapStorage(100_000), - # sampler=online_sampler, - # transform=dataset.transform, - # ) + offline_dataset = make_dataset(cfg) logging.info("make_env") env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) @@ -163,9 +202,8 @@ logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.online_steps=}") - logging.info(f"{cfg.env.action_repeat=}") - logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})") - logging.info(f"{dataset.num_episodes=}") + logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})") + logging.info(f"{offline_dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") @@ -173,18
+211,17 @@ def train(cfg: dict, out_dir=None, job_name=None): def _maybe_eval_and_maybe_save(step): if step % cfg.eval_freq == 0: logging.info(f"Eval policy at step {step}") - eval_info, first_video = eval_policy( + eval_info = eval_policy( env, policy, - return_first_video=True, video_dir=Path(out_dir) / "eval", - save_video=True, - transform=dataset.transform, + max_episodes_rendered=4, + transform=offline_dataset.transform, seed=cfg.seed, ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline) + log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) if cfg.wandb.enable: - logger.log_video(first_video, step, mode="eval") + logger.log_video(eval_info["videos"][0], step, mode="eval") logging.info("Resume training") if cfg.save_model and step % cfg.save_freq == 0: @@ -192,18 +229,19 @@ def train(cfg: dict, out_dir=None, job_name=None): logger.save_model(policy, identifier=step) logging.info("Resume training") - step = 0 # number of policy update (forward + backward + optim) - - is_offline = True + # create dataloader for offline training dataloader = torch.utils.data.DataLoader( - dataset, + offline_dataset, num_workers=4, batch_size=cfg.policy.batch_size, shuffle=True, pin_memory=cfg.device != "cpu", - drop_last=True, + drop_last=False, ) dl_iter = cycle(dataloader) + + step = 0 # number of policy update (forward + backward + optim) + is_offline = True for offline_step in range(cfg.offline_steps): if offline_step == 0: logging.info("Start offline training on a fixed dataset") @@ -217,7 +255,7 @@ def train(cfg: dict, out_dir=None, job_name=None): # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: - log_train_info(logger, train_info, step, cfg, dataset, is_offline) + log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in # step + 1. @@ -225,61 +263,60 @@ def train(cfg: dict, out_dir=None, job_name=None): step += 1 - raise NotImplementedError() + # create an env dedicated to online episodes collection from policy rollout + rollout_env = make_env(cfg, num_parallel_envs=1) + + # create an empty online dataset similar to offline dataset + online_dataset = deepcopy(offline_dataset) + online_dataset.data_dict = {} + online_dataset.data_ids_per_episode = {} + + # create dataloader for online training + concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) + weights = [1.0] * len(concat_dataset) + sampler = torch.utils.data.WeightedRandomSampler( + weights, num_samples=len(concat_dataset), replacement=True + ) + dataloader = torch.utils.data.DataLoader( + concat_dataset, + num_workers=4, + batch_size=cfg.policy.batch_size, + sampler=sampler, + pin_memory=cfg.device != "cpu", + drop_last=False, + ) + dl_iter = cycle(dataloader) - demo_buffer = dataset if cfg.policy.balanced_sampling else None online_step = 0 is_offline = False for env_step in range(cfg.online_steps): if env_step == 0: logging.info("Start online training by interacting with environment") - # TODO: add configurable number of rollout? 
(default=1) + with torch.no_grad(): - rollout = env.rollout( - max_steps=cfg.env.episode_length, - policy=policy, - auto_cast_to_device=True, + eval_info = eval_policy( + rollout_env, + policy, + transform=offline_dataset.transform, + seed=cfg.seed, ) - assert ( - len(rollout.batch_size) == 2 - ), "2 dimensions expected: number of env in parallel x max number of steps during rollout" - - num_parallel_env = rollout.batch_size[0] - if num_parallel_env != 1: - # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests - raise NotImplementedError() - - num_max_steps = rollout.batch_size[1] - assert num_max_steps <= cfg.env.episode_length - - # reshape to have a list of steps to insert into online_buffer - rollout = rollout.reshape(num_parallel_env * num_max_steps) - - # set same episode index for all time steps contained in this rollout - rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) - # online_buffer.extend(rollout) - - ep_sum_reward = rollout["next", "reward"].sum() - ep_max_reward = rollout["next", "reward"].max() - ep_success = rollout["next", "success"].any() - rollout_info = { - "avg_sum_reward": np.nanmean(ep_sum_reward), - "avg_max_reward": np.nanmean(ep_max_reward), - "pc_success": np.nanmean(ep_success) * 100, - "env_step": env_step, - "ep_length": len(rollout), - } + online_pc_sampling = cfg.get("demo_schedule", 0.5) + add_episodes_inplace( + eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling + ) for _ in range(cfg.policy.utd): - train_info = policy.update( - # online_buffer, - step, - demo_buffer=demo_buffer, - ) + policy.train() + batch = next(dl_iter) + + for key in batch: + batch[key] = batch[key].to(cfg.device, non_blocking=True) + + train_info = policy(batch, step) + if step % cfg.log_freq == 0: - train_info.update(rollout_info) - log_train_info(logger, train_info, step, cfg, dataset, is_offline) + log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass # in step + 1. 
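The online-training loop above mutates the dataloader's sampler and `ConcatDataset` in place as rollout frames arrive. The sketch below (toy stand-in datasets, assumed 50% online share) isolates that pattern: `ConcatDataset` caches `cumulative_sizes` and `WeightedRandomSampler` holds `weights` and `num_samples`, so both must be refreshed after the online dataset grows, which is what `add_episodes_inplace` does.

```python
import torch
from torch.utils.data import ConcatDataset, WeightedRandomSampler

offline = list(range(8))  # stand-in for offline_dataset (8 frames)
online = []               # stand-in for online_dataset (starts empty)
concat = ConcatDataset([offline, online])
sampler = WeightedRandomSampler([1.0] * len(concat), num_samples=len(concat), replacement=True)

online.extend([100, 101])  # two new frames collected from a rollout
concat.cumulative_sizes = concat.cumsum(concat.datasets)  # refresh cached sizes
w_on = -(len(offline) * 0.5) / (len(online) * (0.5 - 1))  # 50% online share -> 4.0
sampler.weights = torch.tensor([1.0] * len(offline) + [w_on] * len(online))
sampler.num_samples = len(concat)

assert len(list(sampler)) == 10  # the sampler now draws over all 10 frames
```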
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 93315e90..4b7b7d6c 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -6,9 +6,6 @@ import einops import hydra import imageio import torch -from torchrl.data.replay_buffers import ( - SamplerWithoutReplacement, -) from lerobot.common.datasets.factory import make_dataset from lerobot.common.logger import log_output_dir @@ -39,19 +36,11 @@ def visualize_dataset(cfg: dict, out_dir=None): init_logging() log_output_dir(out_dir) - # we expect frames of each episode to be stored next to each others sequentially - sampler = SamplerWithoutReplacement( - shuffle=False, - ) - logging.info("make_dataset") dataset = make_dataset( cfg, - overwrite_sampler=sampler, # remove all transformations such as rescale images from [0,255] to [0,1] or normalization normalize=False, - overwrite_batch_size=1, - overwrite_prefetch=12, ) logging.info("Start rendering episodes from offline buffer") @@ -60,64 +49,49 @@ def visualize_dataset(cfg: dict, out_dir=None): logging.info(video_path) -def render_dataset(dataset, out_dir, max_num_samples, fps): +def render_dataset(dataset, out_dir, max_num_episodes): out_dir = Path(out_dir) video_paths = [] threads = [] - frames = {} - current_ep_idx = 0 - logging.info(f"Visualizing episode {current_ep_idx}") - for i in range(max_num_samples): - # TODO(rcadene): make it work with bsize > 1 - ep_td = dataset.sample(1) - ep_idx = ep_td["episode"][FIRST_FRAME].item() - # TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames - num_frames_left = dataset._sampler._sample_list.numel() - episode_is_done = ep_idx != current_ep_idx + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=1, + shuffle=False, + ) + dl_iter = iter(dataloader) - if episode_is_done: - logging.info(f"Rendering episode {current_ep_idx}") + num_episodes = len(dataset.data_ids_per_episode) + for ep_id in range(min(max_num_episodes, num_episodes)): + logging.info(f"Rendering episode {ep_id}") - for im_key in dataset.image_keys: - if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1): + frames = {} + for _ in dataset.data_ids_per_episode[ep_id]: + item = next(dl_iter) + + for im_key in dataset.image_keys: # when first frame of episode, initialize frames dict if im_key not in frames: frames[im_key] = [] # add current frame to list of frames to render - frames[im_key].append(ep_td[im_key]) + frames[im_key].append(item[im_key]) + + out_dir.mkdir(parents=True, exist_ok=True) + for im_key in dataset.image_keys: + if len(dataset.image_keys) > 1: + im_name = im_key.replace("observation.images.", "") + video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4" else: - # When episode has no more frame in its list of observation, - # one frame still remains. It is the result of the last action taken. - # It is stored in `"next"`, so we add it to the list of frames to render. 
- frames[im_key].append(ep_td["next"][im_key]) + video_path = out_dir / f"episode_{ep_id}.mp4" + video_paths.append(video_path) - out_dir.mkdir(parents=True, exist_ok=True) - if len(dataset.image_keys) > 1: - camera = im_key[-1] - video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4" - else: - video_path = out_dir / f"episode_{current_ep_idx}.mp4" - video_paths.append(str(video_path)) - - thread = threading.Thread( - target=cat_and_write_video, - args=(str(video_path), frames[im_key], fps), - ) - thread.start() - threads.append(thread) - - current_ep_idx = ep_idx - - # reset list of frames - del frames[im_key] - - if num_frames_left == 0: - logging.info("Ran out of frames") - break - - if current_ep_idx == NUM_EPISODES_TO_RENDER: - break + thread = threading.Thread( + target=cat_and_write_video, + args=(str(video_path), frames[im_key], dataset.fps), + ) + thread.start() + threads.append(thread) for thread in threads: thread.join() diff --git a/poetry.lock b/poetry.lock index faeb70f1..0133b3ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -921,7 +921,7 @@ shapely = "^2.0.3" type = "git" url = "git@github.com:huggingface/gym-pusht.git" reference = "HEAD" -resolved_reference = "824b22832cc8d71a4b4e96a57563510cf47e30c1" +resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06" [[package]] name = "gym-xarm" @@ -941,7 +941,7 @@ mujoco = "^2.3.7" type = "git" url = "git@github.com:huggingface/gym-xarm.git" reference = "HEAD" -resolved_reference = "ce294c0d30def08414d9237e2bf9f373d448ca07" +resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c" [[package]] name = "gymnasium" diff --git a/sbatch.sh b/sbatch.sh deleted file mode 100644 index c08f7055..00000000 --- a/sbatch.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -#SBATCH --nodes=1 # total number of nodes (N to be defined) -#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU) -#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs) -#SBATCH --cpus-per-task=8 # number of cores per task (8x8 = 64 cores, or all the cores) -#SBATCH --time=2-00:00:00 -#SBATCH --output=/home/rcadene/slurm/%j.out -#SBATCH --error=/home/rcadene/slurm/%j.err -#SBATCH --qos=low -#SBATCH --mail-user=re.cadene@gmail.com -#SBATCH --mail-type=ALL - -CMD=$@ -echo "command: $CMD" - -apptainer exec --nv \ -~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL - -source ~/.bashrc -#conda activate fowm -conda activate lerobot - -export DATA_DIR="data" - -srun $CMD diff --git a/sbatch_hopper.sh b/sbatch_hopper.sh deleted file mode 100644 index cc410048..00000000 --- a/sbatch_hopper.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash -#SBATCH --nodes=1 # total number of nodes (N to be defined) -#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU) -#SBATCH --qos=normal # number of GPUs reserved per node (here 8, or all the GPUs) -#SBATCH --partition=hopper-prod -#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs) -#SBATCH --cpus-per-task=12 # number of cores per task -#SBATCH --mem-per-cpu=11G -#SBATCH --time=12:00:00 -#SBATCH --output=/admin/home/remi_cadene/slurm/%j.out -#SBATCH --error=/admin/home/remi_cadene/slurm/%j.err -#SBATCH --mail-user=remi_cadene@huggingface.co -#SBATCH --mail-type=ALL - 
-CMD=$@ -echo "command: $CMD" -srun $CMD diff --git a/tests/data/pusht/data_dict.pth b/tests/data/pusht/data_dict.pth index 40d96a51..a083c86c 100644 Binary files a/tests/data/pusht/data_dict.pth and b/tests/data/pusht/data_dict.pth differ diff --git a/tests/test_available.py b/tests/test_available.py index 8df2c945..be74a42a 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -1,64 +1,53 @@ """ -This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be sucessfully -imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) corresponds. +This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be successfully +imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) are valid. -Note: - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class +When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps: +- For datasets, set the required class attribute: `available_datasets`. +- For policies, set the required class attribute: `name`. +- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_env`, `available_policies`). +- Update variables in `tests/test_available.py` by importing your new class. """ +import importlib import pytest import lerobot +import gymnasium as gym -# from lerobot.common.envs.aloha.env import AlohaEnv -# from gym_pusht.envs import PushtEnv -# from gym_xarm.envs import SimxarmEnv +from lerobot.common.datasets.xarm import XarmDataset +from lerobot.common.datasets.aloha import AlohaDataset +from lerobot.common.datasets.pusht import PushtDataset -# from lerobot.common.datasets.xarm import SimxarmDataset -# from lerobot.common.datasets.aloha import AlohaDataset -# from lerobot.common.datasets.pusht import PushtDataset - -# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy -# from lerobot.common.policies.diffusion.policy import DiffusionPolicy -# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy +from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy +from lerobot.common.policies.diffusion.policy import DiffusionPolicy +from lerobot.common.policies.tdmpc.policy import TDMPCPolicy -# def test_available(): -# pol_classes = [ -# ActionChunkingTransformerPolicy, -# DiffusionPolicy, -# TDMPCPolicy, -# ] +def test_available(): + policy_classes = [ + ActionChunkingTransformerPolicy, + DiffusionPolicy, + TDMPCPolicy, + ] -# env_classes = [ -# AlohaEnv, -# PushtEnv, -# SimxarmEnv, -# ] - -# dat_classes = [ -# AlohaDataset, -# PushtDataset, -# SimxarmDataset, -# ] + dataset_class_per_env = { + "aloha": AlohaDataset, + "pusht": PushtDataset, + "xarm": XarmDataset, + } -# policies = [pol_cls.name for pol_cls in pol_classes] -# assert set(policies) == set(lerobot.available_policies) + policies = [pol_cls.name for pol_cls in policy_classes] + assert set(policies) == set(lerobot.available_policies), policies -# envs = [env_cls.name for env_cls in env_classes] -# assert set(envs) == set(lerobot.available_envs) + for
+    for env_name in lerobot.available_envs:
+        for task_name in lerobot.available_tasks_per_env[env_name]:
+            package_name = f"gym_{env_name}"
+            importlib.import_module(package_name)
+            gym_handle = f"{package_name}/{task_name}"
+            assert gym_handle in gym.envs.registry.keys(), gym_handle
 
-#     tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
-#     for env in envs:
-#         assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
-
-#     datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
-#     for env in envs:
-#         assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
+        dataset_class = dataset_class_per_env[env_name]
+        available_datasets = lerobot.available_datasets_per_env[env_name]
+        assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"
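The registry check in the rewritten `test_available` relies on a gymnasium convention: importing a `gym_*` package runs its registration side effects, after which the `"<package>/<task>"` handle resolves in the global registry. A hedged sketch of that pattern follows; the task id used here is illustrative and not confirmed by this diff:

```python
import importlib

import gymnasium as gym

# Importing the package executes its `gymnasium.register(...)` calls,
# which is what makes the handle show up in `gym.envs.registry`.
importlib.import_module("gym_pusht")

assert "gym_pusht/PushT-v0" in gym.envs.registry  # hypothetical task id
env = gym.make("gym_pusht/PushT-v0")
observation, info = env.reset()
```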
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e24d7b4d..71eefa9c 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,6 +1,12 @@
+import os
+from pathlib import Path
+import einops
 import pytest
 import torch
 
+from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_data_with_delta_timestamps
+from lerobot.common.datasets.xarm import XarmDataset
+from lerobot.common.transforms import Prod
 from lerobot.common.utils import init_hydra_config
 import logging
 from lerobot.common.datasets.factory import make_dataset
@@ -45,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name):
             keys_ndim_required.append(
                 (key, 3, True),
             )
+            assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
 
     # test number of dimensions
     for key, ndim, required in keys_ndim_required:
@@ -81,28 +88,104 @@ def test_factory(env_name, dataset_id, policy_name):
         assert key in item, f"{key}"
 
 
-# def test_compute_stats():
-#     """Check that the statistics are computed correctly according to the stats_patterns property.
+def test_compute_stats():
+    """Check that the statistics are computed correctly according to the stats_patterns property.
+
+    We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
+    because we are working with a small dataset).
+    """
+    DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+
+    # get transform to convert images from uint8 [0,255] to float32 [0,1]
+    transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
+
+    dataset = XarmDataset(
+        dataset_id="xarm_lift_medium",
+        root=DATA_DIR,
+        transform=transform,
+    )
+
+    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
+    # computation of the statistics. While doing this, we also make sure it works when we don't divide the
+    # dataset into even batches.
+    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
+
+    # get einops patterns to aggregate batches and compute statistics
+    stats_patterns = get_stats_einops_patterns(dataset)
+
+    # get all frames from the dataset in the same dtype and range as during compute_stats
+    data_dict = transform(dataset.data_dict)
+
+    # compute stats based on all frames from the dataset without any batching
+    expected_stats = {}
+    for k, pattern in stats_patterns.items():
+        expected_stats[k] = {}
+        expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean")
+        expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
+        expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min")
+        expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max")
+
+    # test computed stats match expected stats
+    for k in stats_patterns:
+        assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
+        assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
+        assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
+        assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
+
+    # TODO(rcadene): check that the stats used for training are correct too
+    # # load stats that are expected to match the ones returned by computed_stats
+    # assert (dataset.data_dir / "stats.pth").exists()
+    # loaded_stats = torch.load(dataset.data_dir / "stats.pth")
+
+    # # test loaded stats match expected stats
+    # for k in stats_patterns:
+    #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
+    #     assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
+    #     assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
+    #     assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
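The `test_compute_stats` added above pins down the contract for `compute_stats`: aggregating over batches, including an uneven final batch, must match the single-pass statistics within `torch.allclose`. The sketch below shows one size-weighted, two-pass way to satisfy that contract for a single key; the function name, the `key`/`pattern` arguments, and the DataLoader usage are illustrative assumptions, not the repository's implementation:

```python
import einops
import torch


def batched_stats_sketch(dataset, key: str, pattern: str, batch_size: int):
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Pass 1: running min/max and a size-weighted running mean, so a smaller
    # final batch contributes proportionally instead of being over-weighted.
    mean, count, minimum, maximum = 0.0, 0, None, None
    for batch in loader:
        x = batch[key]
        b = x.shape[0]
        mean = mean + (einops.reduce(x, pattern, "mean") - mean) * b / (count + b)
        count += b
        b_min = einops.reduce(x, pattern, "min")
        b_max = einops.reduce(x, pattern, "max")
        minimum = b_min if minimum is None else torch.minimum(minimum, b_min)
        maximum = b_max if maximum is None else torch.maximum(maximum, b_max)

    # Pass 2: variance against the finalized mean, again size-weighted.
    var, count = 0.0, 0
    for batch in loader:
        x = batch[key]
        b = x.shape[0]
        var = var + (einops.reduce((x - mean) ** 2, pattern, "mean") - var) * b / (count + b)
        count += b

    return {"mean": mean, "std": torch.sqrt(var), "min": minimum, "max": maximum}
```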
+
+
+def test_load_data_with_delta_timestamps_within_tolerance():
+    data_dict = {
+        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+        "index": torch.tensor([0, 1, 2, 3, 4]),
+    }
+    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+    delta_timestamps = {"index": [-0.2, 0, 0.139]}
+    key = "index"
+    current_ts = 0.3
+    episode = 0
+    tol = 0.04
+    data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+    assert not is_pad.any(), "Unexpected padding detected"
+    assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
+
+def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
+    data_dict = {
+        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+        "index": torch.tensor([0, 1, 2, 3, 4]),
+    }
+    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+    delta_timestamps = {"index": [-0.2, 0, 0.141]}
+    key = "index"
+    current_ts = 0.3
+    episode = 0
+    tol = 0.04
+    with pytest.raises(AssertionError):
+        load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+
+def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
+    data_dict = {
+        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+        "index": torch.tensor([0, 1, 2, 3, 4]),
+    }
+    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+    delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
+    key = "index"
+    current_ts = 0.3
+    episode = 0
+    tol = 0.04
+    data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
+    assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
 
-#     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
-#     because we are working with a small dataset).
-#     """
-#     cfg = init_hydra_config(
-#         DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
-#     )
-#     dataset = make_dataset(cfg)
-#     # Get all of the data.
-#     all_data = dataset.data_dict
-#     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
-#     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
-#     # dataset into even batches.
-#     computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
-#     for k, pattern in buffer.stats_patterns.items():
-#         expected_mean = einops.reduce(all_data[k], pattern, "mean")
-#         assert torch.allclose(computed_stats[k]["mean"], expected_mean)
-#         assert torch.allclose(
-#             computed_stats[k]["std"],
-#             torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
-#         )
-#         assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
-#         assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
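Taken together, the three tests above specify the semantics of `load_data_with_delta_timestamps`: each query timestamp `current_ts + delta` maps to the nearest frame in the episode; a query that misses every frame by more than `tol` is flagged as padding and clamped to a boundary frame, unless it falls inside the episode's time span, in which case it is a hard error. The sketch below is logic consistent with those tests (up to floating-point rounding exactly at the tolerance boundary); it is an inference from the tests, not the repository's implementation:

```python
import torch


def load_data_with_delta_timestamps_sketch(
    data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol
):
    ep_ids = data_ids_per_episode[episode]
    ep_ts = data_dict["timestamp"][ep_ids]

    # Query timestamps relative to the current frame.
    query_ts = torch.tensor(delta_timestamps[key]) + current_ts

    # Nearest frame in the episode for every query.
    dist = torch.abs(query_ts[:, None] - ep_ts[None, :])
    min_dist, argmin = dist.min(dim=1)

    # A query that misses every frame by more than `tol` is padding...
    is_pad = min_dist > tol

    # ...but only if it also falls outside the episode's time span;
    # inside the span it means the requested delta is unsatisfiable.
    inside_span = (query_ts >= ep_ts[0]) & (query_ts <= ep_ts[-1])
    assert not (is_pad & inside_span).any(), (
        f"Query timestamps inside episode {episode} are farther than {tol} from any frame."
    )

    # Padded queries clamp to the nearest (boundary) frame.
    return data_dict[key][ep_ids[argmin]], is_pad
```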