Merge remote-tracking branch 'upstream/main' into refactor_dp

This commit is contained in:
Alexander Soare 2024-04-11 17:52:10 +01:00
commit 94cc22da9e
29 changed files with 545 additions and 603 deletions

2
.github/poetry/cpu/poetry.lock generated vendored
View File

@ -940,7 +940,7 @@ mujoco = "^2.3.7"
type = "git"
url = "git@github.com:huggingface/gym-xarm.git"
reference = "HEAD"
resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c"
[[package]]
name = "gymnasium"

View File

@ -142,6 +142,7 @@ jobs:
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
eval_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
@ -159,17 +160,6 @@ jobs:
device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt
# TODO(aliberts): This takes ~2mn to run, needs to be improved
# - name: Test eval ACT on ALOHA end-to-end (policy is None)
# run: |
# source .venv/bin/activate
# python lerobot/scripts/eval.py \
# --config lerobot/configs/default.yaml \
# policy=act \
# env=aloha \
# eval_episodes=1 \
# device=cpu
- name: Test train Diffusion on PushT end-to-end
run: |
source .venv/bin/activate
@ -179,9 +169,11 @@ jobs:
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
eval_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/diffusion/
- name: Test eval Diffusion on PushT end-to-end
@ -194,16 +186,6 @@ jobs:
device=cpu \
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
- name: Test eval Diffusion on PushT end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=diffusion \
env=pusht \
eval_episodes=1 \
device=cpu
- name: Test train TDMPC on Simxarm end-to-end
run: |
source .venv/bin/activate
@ -213,9 +195,11 @@ jobs:
wandb.enable=False \
offline_steps=1 \
online_steps=1 \
eval_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/tdmpc/
- name: Test eval TDMPC on Simxarm end-to-end
@ -227,13 +211,3 @@ jobs:
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
- name: Test eval TDPMC on Simxarm end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=tdmpc \
env=xarm \
eval_episodes=1 \
device=cpu

3
.gitignore vendored
View File

@ -11,6 +11,9 @@ rl
nautilus/*.yaml
*.key
# Slurm
sbatch*.sh
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@ -120,34 +120,32 @@ wandb login
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
```python
""" Copy pasted from `examples/1_visualize_dataset.py` """
import os
from pathlib import Path
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
# we use this sampler to sample 1 frame after the other
sampler = SamplerWithoutReplacement(shuffle=False)
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
# TODO(rcadene): remove DATA_DIR
dataset = AlohaDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
max_num_samples=300,
fps=50,
max_num_episodes=1,
)
print(video_paths)
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
# ['outputs/visualize_dataset/example/episode_0.mp4']
```
Or you can achieve the same result by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
env=aloha \
task=sim_sim_transfer_cube_human \
env=pusht \
hydra.run.dir=outputs/visualize_dataset/example
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
```

View File

@ -1,24 +1,20 @@
import os
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from pathlib import Path
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
# we use this sampler to sample 1 frame after the other
sampler = SamplerWithoutReplacement(shuffle=False)
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
# TODO(rcadene): remove DATA_DIR
dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
max_num_samples=300,
fps=50,
max_num_episodes=1,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']

View File

@ -9,9 +9,8 @@ from pathlib import Path
import torch
from omegaconf import OmegaConf
from tqdm import trange
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.utils import init_hydra_config
@ -37,19 +36,33 @@ policy = DiffusionPolicy(
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps,
**cfg.policy,
)
policy.train()
offline_buffer = make_offline_buffer(cfg)
dataset = make_dataset(cfg)
# create dataloader for offline training
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
shuffle=True,
pin_memory=cfg.device != "cpu",
drop_last=True,
)
for step, batch in enumerate(dataloader):
info = policy(batch, step)
if step % cfg.log_freq == 0:
num_samples = (step + 1) * cfg.policy.batch_size
loss = info["loss"]
update_s = info["update_s"]
print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")
for offline_step in trange(cfg.offline_steps):
train_info = policy.update(offline_buffer, offline_step)
if offline_step % cfg.log_freq == 0:
print(train_info)
# Save the policy, configuration, and normalization stats for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(cfg, output_directory / "config.yaml")
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")

View File

@ -12,14 +12,11 @@ Example:
print(lerobot.available_policies)
```
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
- Set the required class attributes: `available_datasets`.
- Set the required class attributes: `name`.
- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- Update variables in `tests/test_available.py` by importing your new class
"""
from lerobot.__version__ import __version__ # noqa: F401
@ -32,11 +29,11 @@ available_envs = [
available_tasks_per_env = {
"aloha": [
"sim_insertion",
"sim_transfer_cube",
"AlohaInsertion-v0",
"AlohaTransferCube-v0",
],
"pusht": ["pusht"],
"xarm": ["lift"],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
}
available_datasets_per_env = {

View File

@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
def num_episodes(self) -> int:

View File

@ -1,10 +1,11 @@
import logging
import os
from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_or_load_stats
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
@ -40,7 +41,8 @@ def make_dataset(
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
@ -51,21 +53,27 @@ def make_dataset(
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
else:
elif stats_path is None:
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_or_load_stats(stats_dataset)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
# load stats if the file exists already or compute stats and save it
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
logging.info(f"compute_stats and save to {precomputed_stats_path}")
stats = compute_stats(stats_dataset)
torch.save(stats, stats_path)
else:
stats = torch.load(stats_path)
transforms = v2.Compose(
[
# TODO(rcadene): we need to do something about image_keys
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
NormalizeTransform(
stats,

View File

@ -2,11 +2,8 @@ from pathlib import Path
import einops
import numpy as np
import pygame
import pymunk
import torch
import tqdm
from gym_pusht.envs.pusht import pymunk_to_shapely
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,
@ -20,64 +17,6 @@ PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
def get_goal_pose_body(pose):
mass = 1
inertia = pymunk.moment_for_box(mass, (50, 100))
body = pymunk.Body(mass, inertia)
# preserving the legacy assignment order for compatibility
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
body.position = pose[:2].tolist()
body.angle = pose[2]
return body
def add_segment(space, a, b, radius):
shape = pymunk.Segment(space.static_body, a, b, radius)
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
return shape
def add_tee(
space,
position,
angle,
scale=30,
color="LightSlateGray",
mask=None,
):
if mask is None:
mask = pymunk.ShapeFilter.ALL_MASKS()
mass = 1
length = 4
vertices1 = [
(-length * scale / 2, scale),
(length * scale / 2, scale),
(length * scale / 2, 0),
(-length * scale / 2, 0),
]
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
vertices2 = [
(-scale / 2, scale),
(-scale / 2, length * scale),
(scale / 2, length * scale),
(scale / 2, scale),
]
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
body = pymunk.Body(mass, inertia1 + inertia2)
shape1 = pymunk.Poly(body, vertices1)
shape2 = pymunk.Poly(body, vertices2)
shape1.color = pygame.Color(color)
shape2.color = pygame.Color(color)
shape1.filter = pymunk.ShapeFilter(mask=mask)
shape2.filter = pymunk.ShapeFilter(mask=mask)
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
body.position = position
body.angle = angle
body.friction = 1
space.add(body, shape1, shape2)
return body
class PushtDataset(torch.utils.data.Dataset):
"""
@ -121,7 +60,7 @@ class PushtDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
def num_episodes(self) -> int:
@ -158,6 +97,13 @@ class PushtDataset(torch.utils.data.Dataset):
return item
def _download_and_preproc_obsolete(self):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
raise e
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
@ -182,7 +128,7 @@ class PushtDataset(torch.utils.data.Dataset):
# TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = get_goal_pose_body(goal_pos_angle)
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"])
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
@ -201,6 +147,9 @@ class PushtDataset(torch.utils.data.Dataset):
assert (episode_ids[idx0:idx1] == episode_id).all()
image = imgs[idx0:idx1]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
state = states[idx0:idx1]
agent_pos = state[:, :2]
@ -217,14 +166,14 @@ class PushtDataset(torch.utils.data.Dataset):
# Add walls.
walls = [
add_segment(space, (5, 506), (5, 5), 2),
add_segment(space, (5, 5), (506, 5), 2),
add_segment(space, (506, 5), (506, 506), 2),
add_segment(space, (5, 506), (506, 506), 2),
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
@ -265,16 +214,3 @@ class PushtDataset(torch.utils.data.Dataset):
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)
if __name__ == "__main__":
dataset = PushtDataset(
"pusht",
root=Path("data"),
delta_timestamps={
"observation.image": [0, -1, -0.2, -0.1],
"observation.state": [0, -1, -0.2, -0.1],
"action": [-0.1, 0, 1, 2, 3],
},
)
dataset[10]

View File

@ -1,5 +1,4 @@
import io
import logging
import zipfile
from copy import deepcopy
from math import ceil
@ -35,52 +34,56 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return False
def euclidean_distance_matrix(mat0, mat1):
# Compute the square of the distance matrix
sq0 = torch.sum(mat0**2, dim=1, keepdim=True)
sq1 = torch.sum(mat1**2, dim=1, keepdim=True)
distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1)
# Taking the square root to get the euclidean distance
distance = torch.sqrt(torch.clamp(distance_sq, min=0))
return distance
def is_contiguously_true_or_false(bool_vector):
assert bool_vector.ndim == 1
assert bool_vector.dtype == torch.bool
# Compare each element with its neighbor to find changes
changes = bool_vector[1:] != bool_vector[:-1]
# Count the number of changes
num_changes = changes.sum().item()
# If there's more than one change, the list is not contiguous
return num_changes <= 1
# examples = [
# ([True, False, True, False, False, False], False),
# ([True, True, True, False, False, False], True),
# ([False, False, False, False, False, False], True)
# ]
# for bool_list, expected in examples:
# result = is_contiguously_true_or_false(bool_list)
def load_data_with_delta_timestamps(
data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
data_dict: dict[torch.Tensor],
data_ids_per_episode: dict[torch.Tensor],
delta_timestamps: list[float],
key: str,
current_ts: float,
episode: int,
tol: float = 0.04,
):
"""
Given a current timestamp (e.g. current_ts=0.6) and a list of timestamps differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
this function compute the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
the violation of the tolerance doesnt raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range.
For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode,
or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
Parameters:
- data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
- key (str): The key specifying which data modality is to be retrieved from the data_dict.
- current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
- episode (int): The identifier of the episode from which frames are to be retrieved.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
Returns:
- tuple: A tuple containing two elements:
- The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
- The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
Raises:
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_data_ids = data_ids_per_episode[episode]
ep_timestamps = data_dict["timestamp"][ep_data_ids]
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
ep_last_ts = ep_timestamps[-1]
# get timestamps used as query to retrieve data of previous/future frames
delta_ts = delta_timestamps[key]
query_ts = current_ts + torch.tensor(delta_ts)
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(1)
# get the indices of the data that are closest to the query timestamps
@ -92,24 +95,29 @@ def load_data_with_delta_timestamps(
# TODO(rcadene): synchronize timestamps + interpolation if needed
tol = 0.04
is_pad = min_ > tol
assert is_contiguously_true_or_false(is_pad), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})."
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
return data, is_pad
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
stats_path = dataset.data_dir / "stats.pth"
if stats_path.exists():
return torch.load(stats_path)
def get_stats_einops_patterns(dataset):
"""These einops patterns will be used to aggregate batches and compute statistics."""
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
}
for key in dataset.image_keys:
stats_patterns[key] = "b c h w -> c 1 1"
return stats_patterns
logging.info(f"compute_stats and save to {stats_path}")
def compute_stats(dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(dataset)
else:
@ -124,13 +132,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
drop_last=False,
)
# these einops patterns will be used to aggregate batches and compute statistics
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
}
for key in dataset.image_keys:
stats_patterns[key] = "b c h w -> c 1 1"
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
@ -201,7 +204,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
"min": min[key],
}
torch.save(stats, stats_path)
return stats

View File

@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
def num_episodes(self) -> int:
@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset):
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
# TODO(rcadene): concat the last "next_observations" to "observations"
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])

View File

@ -19,6 +19,7 @@ def preprocess_observation(observation, transform=None):
img = einops.rearrange(img, "b h w c -> b c h w")
obs[imgkey] = img
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
# apply same transforms as in training

View File

@ -29,9 +29,9 @@ def make_policy(cfg):
if cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
if "offline" in cfg.policy.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
elif "final" in cfg.policy.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()

View File

@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()
# num_slices = self.cfg.batch_size
# batch_size = self.cfg.horizon * num_slices
# if demo_buffer is None:
# demo_batch_size = 0
# else:
# # Update oversampling ratio
# demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
# demo_num_slices = int(demo_pc_batch * self.batch_size)
# demo_batch_size = self.cfg.horizon * demo_num_slices
# batch_size -= demo_batch_size
# num_slices -= demo_num_slices
# replay_buffer._sampler.num_slices = num_slices
# demo_buffer._sampler.num_slices = demo_num_slices
# assert demo_batch_size % self.cfg.horizon == 0
# assert demo_batch_size % demo_num_slices == 0
# assert batch_size % self.cfg.horizon == 0
# assert batch_size % num_slices == 0
# # Sample from interaction dataset
# def process_batch(batch, horizon, num_slices):
# # trajectory t = 256, horizon h = 5
# # (t h) ... -> h t ...
# batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
# obs = {
# "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
# "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
# }
# action = batch["action"].to(self.device, non_blocking=True)
# next_obses = {
# "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
# "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
# }
# reward = batch["next", "reward"].to(self.device, non_blocking=True)
# idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
# weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
# # TODO(rcadene): rearrange directly in offline dataset
# if reward.ndim == 2:
# reward = einops.rearrange(reward, "h t -> h t 1")
# assert reward.ndim == 3
# assert reward.shape == (horizon, num_slices, 1)
# # We dont use `batch["next", "done"]` since it only indicates the end of an
# # episode, but not the end of the trajectory of an episode.
# # Neither does `batch["next", "terminated"]`
# done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
# mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
# return obs, action, next_obses, reward, mask, done, idxs, weights
# batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
# obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
# batch, self.cfg.horizon, num_slices
# )
# Sample from demonstration dataset
# if demo_batch_size > 0:
# demo_batch = demo_buffer.sample(demo_batch_size)
# (
# demo_obs,
# demo_action,
# demo_next_obses,
# demo_reward,
# demo_mask,
# demo_done,
# demo_idxs,
# demo_weights,
# ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
# if isinstance(obs, dict):
# obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
# next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
# else:
# obs = torch.cat([obs, demo_obs])
# next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
# action = torch.cat([action, demo_action], dim=1)
# reward = torch.cat([reward, demo_reward], dim=1)
# mask = torch.cat([mask, demo_mask], dim=1)
# done = torch.cat([done, demo_done], dim=1)
# idxs = torch.cat([idxs, demo_idxs])
# weights = torch.cat([weights, demo_weights])
batch_size = batch["index"].shape[0]
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module):
)
self.optim.step()
# TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
# if self.cfg.per:
# # Update priorities
# priorities = priority_loss.clamp(max=1e4).detach()

View File

@ -99,6 +99,7 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
def print_cuda_memory_usage():
"""Use this function to locate and debug memory leak."""
import gc
gc.collect()

View File

@ -18,7 +18,6 @@ env:
from_pixels: True
pixels_only: False
image_size: [3, 480, 640]
action_repeat: 1
episode_length: 400
fps: ${fps}

View File

@ -18,7 +18,6 @@ env:
from_pixels: True
pixels_only: False
image_size: 96
action_repeat: 1
episode_length: 300
fps: ${fps}

View File

@ -17,7 +17,6 @@ env:
from_pixels: True
pixels_only: False
image_size: 84
# action_repeat: 2 # we can remove if policy has n_action_steps=2
episode_length: 25
fps: ${fps}

View File

@ -36,6 +36,7 @@ policy:
log_std_max: 2
# learning
batch_size: 256
max_buffer_size: 10000
horizon: 5
reward_coef: 0.5

View File

@ -32,6 +32,7 @@ import json
import logging
import threading
import time
from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
@ -56,15 +57,15 @@ def write_video(video_path, stacked_frames, fps):
def eval_policy(
env: gym.vector.VectorEnv,
policy,
save_video: bool = False,
policy: torch.nn.Module,
max_episodes_rendered: int = 0,
video_dir: Path = None,
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
fps: int = 15,
return_first_video: bool = False,
transform: callable = None,
seed=None,
):
fps = env.unwrapped.metadata["render_fps"]
if policy is not None:
policy.eval()
device = "cpu" if policy is None else next(policy.parameters()).device
@ -83,14 +84,11 @@ def eval_policy(
# needed as I'm currently taking a ceil.
ep_frames = []
def maybe_render_frame(env):
if save_video: # noqa: B023
if return_first_video:
visu = env.envs[0].render()
visu = visu[None, ...] # add batch dim
else:
visu = np.stack([env.render() for env in env.envs])
ep_frames.append(visu) # noqa: B023
def render_frame(env):
# noqa: B023
eps_rendered = min(max_episodes_rendered, len(env.envs))
visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
ep_frames.append(visu) # noqa: B023
for _ in range(num_episodes):
seeds.append("TODO")
@ -104,8 +102,14 @@ def eval_policy(
# reset the environment
observation, info = env.reset(seed=seed)
maybe_render_frame(env)
if max_episodes_rendered > 0:
render_frame(env)
observations = []
actions = []
# episode
# frame_id
# timestamp
rewards = []
successes = []
dones = []
@ -113,8 +117,13 @@ def eval_policy(
done = torch.tensor([False for _ in env.envs])
step = 0
while not done.all():
# format from env keys to lerobot keys
observation = preprocess_observation(observation)
observations.append(deepcopy(observation))
# apply transform to normalize the observations
observation = preprocess_observation(observation, transform)
for key in observation:
observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
# send observation to device/gpu
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
@ -126,11 +135,13 @@ def eval_policy(
# apply inverse transform to unnormalize the action
action = postprocess_action(action, transform)
# apply the next
# apply the next action
observation, reward, terminated, truncated, info = env.step(action)
maybe_render_frame(env)
if max_episodes_rendered > 0:
render_frame(env)
# TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
action = torch.from_numpy(action)
reward = torch.from_numpy(reward)
terminated = torch.from_numpy(terminated)
truncated = torch.from_numpy(truncated)
@ -147,12 +158,24 @@ def eval_policy(
success = [False for _ in env.envs]
success = torch.tensor(success)
actions.append(action)
rewards.append(reward)
dones.append(done)
successes.append(success)
step += 1
env.close()
# add the last observation when the env is done
observation = preprocess_observation(observation)
observations.append(deepcopy(observation))
new_obses = {}
for key in observations[0].keys(): # noqa: SIM118
new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
observations = new_obses
actions = torch.stack(actions, dim=1)
rewards = torch.stack(rewards, dim=1)
successes = torch.stack(successes, dim=1)
dones = torch.stack(dones, dim=1)
@ -172,29 +195,61 @@ def eval_policy(
max_rewards.extend(batch_max_reward.tolist())
all_successes.extend(batch_success.tolist())
env.close()
# similar logic is implemented in dataset preprocessing
ep_dicts = []
num_episodes = dones.shape[0]
total_frames = 0
idx0 = idx1 = 0
data_ids_per_episode = {}
for ep_id in range(num_episodes):
num_frames = done_indices[ep_id].item() + 1
# TODO(rcadene): We need to add a missing last frame which is the observation
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
}
for key in observations:
ep_dict[key] = observations[key][ep_id, :num_frames]
ep_dicts.append(ep_dict)
if save_video or return_first_video:
total_frames += num_frames
idx1 += num_frames
data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
idx0 = idx1
# similar logic is implemented in dataset preprocessing
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(0, total_frames, 1)
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
if save_video:
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if episode_counter >= num_episodes:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames[:done_index], fps),
)
thread.start()
threads.append(thread)
episode_counter += 1
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if episode_counter >= num_episodes:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames[:done_index], fps),
)
thread.start()
threads.append(thread)
episode_counter += 1
if return_first_video:
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")
for thread in threads:
thread.join()
@ -225,9 +280,13 @@ def eval_policy(
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
"episodes": {
"data_dict": data_dict,
"data_ids_per_episode": data_ids_per_episode,
},
}
if return_first_video:
return info, first_video
if max_episodes_rendered > 0:
info["videos"] = videos
return info
@ -253,16 +312,14 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
# when policy is None, rollout a random policy
policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
logging.info("Making policy.")
policy = make_policy(cfg)
info = eval_policy(
env,
policy=policy,
save_video=True,
policy,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
fps=cfg.env.fps,
# TODO(rcadene): what should we do with the transform?
transform=transform,
seed=cfg.seed,
)
@ -270,6 +327,9 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
# Save info
with open(Path(out_dir) / "eval_info.json", "w") as f:
# remove pytorch tensors which are not serializable to save the evaluation results only
del info["episodes"]
del info["videos"]
json.dump(info, f, indent=2)
logging.info("End of eval")

View File

@ -1,8 +1,8 @@
import logging
from copy import deepcopy
from pathlib import Path
import hydra
import numpy as np
import torch
from lerobot.common.datasets.factory import make_dataset
@ -108,6 +108,64 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
logger.log_dict(info, step, mode="eval")
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
"""
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
Parameters:
- n_off (int): Number of offline samples, each with a sampling weight of 1.
- n_on (int): Number of online samples.
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
The total weight of offline samples is n_off * 1.0.
The total weight of offline samples is n_on * w.
The total combined weight of all samples is n_off + n_on * w.
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
"""
assert 0.0 <= pc_on <= 1.0
return -(n_off * pc_on) / (n_on * (pc_on - 1))
def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
data_dict = episodes["data_dict"]
data_ids_per_episode = episodes["data_ids_per_episode"]
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.data_dict = data_dict
online_dataset.data_ids_per_episode = data_ids_per_episode
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
start_index = online_dataset.data_dict["index"][-1].item() + 1
data_dict["episode"] += start_episode
data_dict["index"] += start_index
# extend online dataset
for key in data_dict:
# TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
for ep_id in data_ids_per_episode:
online_dataset.data_ids_per_episode[ep_id + start_episode] = (
data_ids_per_episode[ep_id] + start_index
)
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
len_online = len(online_dataset)
len_offline = len(concat_dataset) - len_online
weight_offline = 1.0
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
# update the total number of samples used during sampling
sampler.num_samples = len(concat_dataset)
def train(cfg: dict, out_dir=None, job_name=None):
if out_dir is None:
raise NotImplementedError()
@ -126,26 +184,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
set_global_seed(cfg.seed)
logging.info("make_dataset")
dataset = make_dataset(cfg)
# TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
# if cfg.policy.balanced_sampling:
# logging.info("make online_buffer")
# num_traj_per_batch = cfg.policy.batch_size
# online_sampler = PrioritizedSliceSampler(
# max_capacity=100_000,
# alpha=cfg.policy.per_alpha,
# beta=cfg.policy.per_beta,
# num_slices=num_traj_per_batch,
# strict_length=True,
# )
# online_buffer = TensorDictReplayBuffer(
# storage=LazyMemmapStorage(100_000),
# sampler=online_sampler,
# transform=dataset.transform,
# )
offline_dataset = make_dataset(cfg)
logging.info("make_env")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@ -163,9 +202,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
logging.info(f"{cfg.online_steps=}")
logging.info(f"{cfg.env.action_repeat=}")
logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@ -173,18 +211,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
def _maybe_eval_and_maybe_save(step):
if step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
eval_info, first_video = eval_policy(
eval_info = eval_policy(
env,
policy,
return_first_video=True,
video_dir=Path(out_dir) / "eval",
save_video=True,
transform=dataset.transform,
max_episodes_rendered=4,
transform=offline_dataset.transform,
seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval")
logger.log_video(eval_info["videos"][0], step, mode="eval")
logging.info("Resume training")
if cfg.save_model and step % cfg.save_freq == 0:
@ -192,18 +229,19 @@ def train(cfg: dict, out_dir=None, job_name=None):
logger.save_model(policy, identifier=step)
logging.info("Resume training")
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
# create dataloader for offline training
dataloader = torch.utils.data.DataLoader(
dataset,
offline_dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
shuffle=True,
pin_memory=cfg.device != "cpu",
drop_last=True,
drop_last=False,
)
dl_iter = cycle(dataloader)
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
for offline_step in range(cfg.offline_steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
@ -217,7 +255,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, dataset, is_offline)
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
# step + 1.
@ -225,61 +263,60 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1
raise NotImplementedError()
# create an env dedicated to online episodes collection from policy rollout
rollout_env = make_env(cfg, num_parallel_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.data_dict = {}
online_dataset.data_ids_per_episode = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(
weights, num_samples=len(concat_dataset), replacement=True
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
sampler=sampler,
pin_memory=cfg.device != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
demo_buffer = dataset if cfg.policy.balanced_sampling else None
online_step = 0
is_offline = False
for env_step in range(cfg.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
# TODO: add configurable number of rollout? (default=1)
with torch.no_grad():
rollout = env.rollout(
max_steps=cfg.env.episode_length,
policy=policy,
auto_cast_to_device=True,
eval_info = eval_policy(
rollout_env,
policy,
transform=offline_dataset.transform,
seed=cfg.seed,
)
assert (
len(rollout.batch_size) == 2
), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
num_parallel_env = rollout.batch_size[0]
if num_parallel_env != 1:
# TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
raise NotImplementedError()
num_max_steps = rollout.batch_size[1]
assert num_max_steps <= cfg.env.episode_length
# reshape to have a list of steps to insert into online_buffer
rollout = rollout.reshape(num_parallel_env * num_max_steps)
# set same episode index for all time steps contained in this rollout
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
# online_buffer.extend(rollout)
ep_sum_reward = rollout["next", "reward"].sum()
ep_max_reward = rollout["next", "reward"].max()
ep_success = rollout["next", "success"].any()
rollout_info = {
"avg_sum_reward": np.nanmean(ep_sum_reward),
"avg_max_reward": np.nanmean(ep_max_reward),
"pc_success": np.nanmean(ep_success) * 100,
"env_step": env_step,
"ep_length": len(rollout),
}
online_pc_sampling = cfg.get("demo_schedule", 0.5)
add_episodes_inplace(
eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
)
for _ in range(cfg.policy.utd):
train_info = policy.update(
# online_buffer,
step,
demo_buffer=demo_buffer,
)
policy.train()
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy(batch, step)
if step % cfg.log_freq == 0:
train_info.update(rollout_info)
log_train_info(logger, train_info, step, cfg, dataset, is_offline)
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
# in step + 1.

View File

@ -6,9 +6,6 @@ import einops
import hydra
import imageio
import torch
from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir
@ -39,19 +36,11 @@ def visualize_dataset(cfg: dict, out_dir=None):
init_logging()
log_output_dir(out_dir)
# we expect frames of each episode to be stored next to each others sequentially
sampler = SamplerWithoutReplacement(
shuffle=False,
)
logging.info("make_dataset")
dataset = make_dataset(
cfg,
overwrite_sampler=sampler,
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
normalize=False,
overwrite_batch_size=1,
overwrite_prefetch=12,
)
logging.info("Start rendering episodes from offline buffer")
@ -60,64 +49,49 @@ def visualize_dataset(cfg: dict, out_dir=None):
logging.info(video_path)
def render_dataset(dataset, out_dir, max_num_samples, fps):
def render_dataset(dataset, out_dir, max_num_episodes):
out_dir = Path(out_dir)
video_paths = []
threads = []
frames = {}
current_ep_idx = 0
logging.info(f"Visualizing episode {current_ep_idx}")
for i in range(max_num_samples):
# TODO(rcadene): make it work with bsize > 1
ep_td = dataset.sample(1)
ep_idx = ep_td["episode"][FIRST_FRAME].item()
# TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
num_frames_left = dataset._sampler._sample_list.numel()
episode_is_done = ep_idx != current_ep_idx
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=1,
shuffle=False,
)
dl_iter = iter(dataloader)
if episode_is_done:
logging.info(f"Rendering episode {current_ep_idx}")
num_episodes = len(dataset.data_ids_per_episode)
for ep_id in range(min(max_num_episodes, num_episodes)):
logging.info(f"Rendering episode {ep_id}")
for im_key in dataset.image_keys:
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
frames = {}
for _ in dataset.data_ids_per_episode[ep_id]:
item = next(dl_iter)
for im_key in dataset.image_keys:
# when first frame of episode, initialize frames dict
if im_key not in frames:
frames[im_key] = []
# add current frame to list of frames to render
frames[im_key].append(ep_td[im_key])
frames[im_key].append(item[im_key])
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys:
if len(dataset.image_keys) > 1:
im_name = im_key.replace("observation.images.", "")
video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
else:
# When episode has no more frame in its list of observation,
# one frame still remains. It is the result of the last action taken.
# It is stored in `"next"`, so we add it to the list of frames to render.
frames[im_key].append(ep_td["next"][im_key])
video_path = out_dir / f"episode_{ep_id}.mp4"
video_paths.append(video_path)
out_dir.mkdir(parents=True, exist_ok=True)
if len(dataset.image_keys) > 1:
camera = im_key[-1]
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
else:
video_path = out_dir / f"episode_{current_ep_idx}.mp4"
video_paths.append(str(video_path))
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], fps),
)
thread.start()
threads.append(thread)
current_ep_idx = ep_idx
# reset list of frames
del frames[im_key]
if num_frames_left == 0:
logging.info("Ran out of frames")
break
if current_ep_idx == NUM_EPISODES_TO_RENDER:
break
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], dataset.fps),
)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()

6
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@ -921,7 +921,7 @@ shapely = "^2.0.3"
type = "git"
url = "git@github.com:huggingface/gym-pusht.git"
reference = "HEAD"
resolved_reference = "824b22832cc8d71a4b4e96a57563510cf47e30c1"
resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06"
[[package]]
name = "gym-xarm"
@ -941,7 +941,7 @@ mujoco = "^2.3.7"
type = "git"
url = "git@github.com:huggingface/gym-xarm.git"
reference = "HEAD"
resolved_reference = "ce294c0d30def08414d9237e2bf9f373d448ca07"
resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c"
[[package]]
name = "gymnasium"

View File

@ -1,25 +0,0 @@
#!/bin/bash
#SBATCH --nodes=1 # total number of nodes (N to be defined)
#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
#SBATCH --cpus-per-task=8 # number of cores per task (8x8 = 64 cores, or all the cores)
#SBATCH --time=2-00:00:00
#SBATCH --output=/home/rcadene/slurm/%j.out
#SBATCH --error=/home/rcadene/slurm/%j.err
#SBATCH --qos=low
#SBATCH --mail-user=re.cadene@gmail.com
#SBATCH --mail-type=ALL
CMD=$@
echo "command: $CMD"
apptainer exec --nv \
~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
source ~/.bashrc
#conda activate fowm
conda activate lerobot
export DATA_DIR="data"
srun $CMD

View File

@ -1,17 +0,0 @@
#!/bin/bash
#SBATCH --nodes=1 # total number of nodes (N to be defined)
#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
#SBATCH --qos=normal # number of GPUs reserved per node (here 8, or all the GPUs)
#SBATCH --partition=hopper-prod
#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
#SBATCH --cpus-per-task=12 # number of cores per task
#SBATCH --mem-per-cpu=11G
#SBATCH --time=12:00:00
#SBATCH --output=/admin/home/remi_cadene/slurm/%j.out
#SBATCH --error=/admin/home/remi_cadene/slurm/%j.err
#SBATCH --mail-user=remi_cadene@huggingface.co
#SBATCH --mail-type=ALL
CMD=$@
echo "command: $CMD"
srun $CMD

Binary file not shown.

View File

@ -1,64 +1,53 @@
"""
This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be sucessfully
imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) corresponds.
imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) are valid.
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
- Set the required class attributes: `available_datasets`.
- Set the required class attributes: `name`.
- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- Update variables in `tests/test_available.py` by importing your new class
"""
import importlib
import pytest
import lerobot
import gymnasium as gym
# from lerobot.common.envs.aloha.env import AlohaEnv
# from gym_pusht.envs import PushtEnv
# from gym_xarm.envs import SimxarmEnv
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset
# from lerobot.common.datasets.xarm import SimxarmDataset
# from lerobot.common.datasets.aloha import AlohaDataset
# from lerobot.common.datasets.pusht import PushtDataset
# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
# from lerobot.common.policies.diffusion.policy import DiffusionPolicy
# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
# def test_available():
# pol_classes = [
# ActionChunkingTransformerPolicy,
# DiffusionPolicy,
# TDMPCPolicy,
# ]
def test_available():
policy_classes = [
ActionChunkingTransformerPolicy,
DiffusionPolicy,
TDMPCPolicy,
]
# env_classes = [
# AlohaEnv,
# PushtEnv,
# SimxarmEnv,
# ]
dataset_class_per_env = {
"aloha": AlohaDataset,
"pusht": PushtDataset,
"xarm": XarmDataset,
}
# dat_classes = [
# AlohaDataset,
# PushtDataset,
# SimxarmDataset,
# ]
policies = [pol_cls.name for pol_cls in policy_classes]
assert set(policies) == set(lerobot.available_policies), policies
# policies = [pol_cls.name for pol_cls in pol_classes]
# assert set(policies) == set(lerobot.available_policies)
for env_name in lerobot.available_envs:
for task_name in lerobot.available_tasks_per_env[env_name]:
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry.keys(), gym_handle
# envs = [env_cls.name for env_cls in env_classes]
# assert set(envs) == set(lerobot.available_envs)
# tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
# for env in envs:
# assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
# datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
# for env in envs:
# assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
dataset_class = dataset_class_per_env[env_name]
available_datasets = lerobot.available_datasets_per_env[env_name]
assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"

View File

@ -1,6 +1,12 @@
import os
from pathlib import Path
import einops
import pytest
import torch
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_data_with_delta_timestamps
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.transforms import Prod
from lerobot.common.utils import init_hydra_config
import logging
from lerobot.common.datasets.factory import make_dataset
@ -45,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required.append(
(key, 3, True),
)
assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
# test number of dimensions
for key, ndim, required in keys_ndim_required:
@ -81,28 +88,104 @@ def test_factory(env_name, dataset_id, policy_name):
assert key in item, f"{key}"
# def test_compute_stats():
# """Check that the statistics are computed correctly according to the stats_patterns property.
def test_compute_stats():
"""Check that the statistics are computed correctly according to the stats_patterns property.
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
root=DATA_DIR,
transform=transform,
)
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
data_dict = transform(dataset.data_dict)
# compute stats based on all frames from the dataset without any batching
expected_stats = {}
for k, pattern in stats_patterns.items():
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max")
# test computed stats match expected stats
for k in stats_patterns:
assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# TODO(rcadene): check that the stats used for training are correct too
# # load stats that are expected to match the ones returned by computed_stats
# assert (dataset.data_dir / "stats.pth").exists()
# loaded_stats = torch.load(dataset.data_dir / "stats.pth")
# # test loaded stats match expected stats
# for k in stats_patterns:
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
def test_load_data_with_delta_timestamps_within_tolerance():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
assert not is_pad.any(), "Unexpected padding detected"
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
with pytest.raises(AssertionError):
load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
# We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
# because we are working with a small dataset).
# """
# cfg = init_hydra_config(
# DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
# )
# dataset = make_dataset(cfg)
# # Get all of the data.
# all_data = dataset.data_dict
# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# # computation of the statistics. While doing this, we also make sure it works when we don't divide the
# # dataset into even batches.
# computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
# for k, pattern in buffer.stats_patterns.items():
# expected_mean = einops.reduce(all_data[k], pattern, "mean")
# assert torch.allclose(computed_stats[k]["mean"], expected_mean)
# assert torch.allclose(
# computed_stats[k]["std"],
# torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
# )
# assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
# assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))