Merge remote-tracking branch 'upstream/main' into refactor_dp
commit 94cc22da9e
@@ -940,7 +940,7 @@ mujoco = "^2.3.7"
 type = "git"
 url = "git@github.com:huggingface/gym-xarm.git"
 reference = "HEAD"
-resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
+resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c"

 [[package]]
 name = "gymnasium"
@@ -142,6 +142,7 @@ jobs:
     wandb.enable=False \
     offline_steps=2 \
     online_steps=0 \
+    eval_episodes=1 \
     device=cpu \
     save_model=true \
     save_freq=2 \
@@ -159,17 +160,6 @@ jobs:
     device=cpu \
     policy.pretrained_model_path=tests/outputs/act/models/2.pt

-# TODO(aliberts): This takes ~2mn to run, needs to be improved
-# - name: Test eval ACT on ALOHA end-to-end (policy is None)
-#   run: |
-#     source .venv/bin/activate
-#     python lerobot/scripts/eval.py \
-#       --config lerobot/configs/default.yaml \
-#       policy=act \
-#       env=aloha \
-#       eval_episodes=1 \
-#       device=cpu

 - name: Test train Diffusion on PushT end-to-end
   run: |
     source .venv/bin/activate
@@ -179,9 +169,11 @@ jobs:
     wandb.enable=False \
     offline_steps=2 \
     online_steps=0 \
+    eval_episodes=1 \
     device=cpu \
     save_model=true \
     save_freq=2 \
+    policy.batch_size=2 \
     hydra.run.dir=tests/outputs/diffusion/

 - name: Test eval Diffusion on PushT end-to-end
@@ -194,16 +186,6 @@ jobs:
     device=cpu \
     policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt

-- name: Test eval Diffusion on PushT end-to-end (policy is None)
-  run: |
-    source .venv/bin/activate
-    python lerobot/scripts/eval.py \
-      --config lerobot/configs/default.yaml \
-      policy=diffusion \
-      env=pusht \
-      eval_episodes=1 \
-      device=cpu
-
 - name: Test train TDMPC on Simxarm end-to-end
   run: |
     source .venv/bin/activate
@@ -213,9 +195,11 @@ jobs:
     wandb.enable=False \
     offline_steps=1 \
     online_steps=1 \
+    eval_episodes=1 \
     device=cpu \
     save_model=true \
     save_freq=2 \
+    policy.batch_size=2 \
     hydra.run.dir=tests/outputs/tdmpc/

 - name: Test eval TDMPC on Simxarm end-to-end
@@ -227,13 +211,3 @@ jobs:
     env.episode_length=8 \
     device=cpu \
     policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
-
-- name: Test eval TDPMC on Simxarm end-to-end (policy is None)
-  run: |
-    source .venv/bin/activate
-    python lerobot/scripts/eval.py \
-      --config lerobot/configs/default.yaml \
-      policy=tdmpc \
-      env=xarm \
-      eval_episodes=1 \
-      device=cpu
@@ -11,6 +11,9 @@ rl
 nautilus/*.yaml
 *.key

+# Slurm
+sbatch*.sh
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
README.md (18 changes)
@@ -120,34 +120,32 @@ wandb login
 You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
 ```python
 """ Copy pasted from `examples/1_visualize_dataset.py` """
+import os
+from pathlib import Path

 import lerobot
 from lerobot.common.datasets.aloha import AlohaDataset
-from torchrl.data.replay_buffers import SamplerWithoutReplacement
 from lerobot.scripts.visualize_dataset import render_dataset

 print(lerobot.available_datasets)
 # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']

-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
+# TODO(rcadene): remove DATA_DIR
+dataset = AlohaDataset("pusht", root=Path(os.environ.get("DATA_DIR")))

 video_paths = render_dataset(
     dataset,
     out_dir="outputs/visualize_dataset/example",
-    max_num_samples=300,
-    fps=50,
+    max_num_episodes=1,
 )
 print(video_paths)
-# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
+# ['outputs/visualize_dataset/example/episode_0.mp4']
 ```

 Or you can achieve the same result by executing our script from the command line:
 ```bash
 python lerobot/scripts/visualize_dataset.py \
-env=aloha \
-task=sim_sim_transfer_cube_human \
+env=pusht \
 hydra.run.dir=outputs/visualize_dataset/example
 # >>> ['outputs/visualize_dataset/example/episode_0.mp4']
 ```
@@ -1,24 +1,20 @@
 import os
+from pathlib import Path

-from torchrl.data.replay_buffers import SamplerWithoutReplacement
-
 import lerobot
-from lerobot.common.datasets.aloha import AlohaDataset
+from lerobot.common.datasets.pusht import PushtDataset
 from lerobot.scripts.visualize_dataset import render_dataset

 print(lerobot.available_datasets)
 # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']

-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
+# TODO(rcadene): remove DATA_DIR
+dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))

 video_paths = render_dataset(
     dataset,
     out_dir="outputs/visualize_dataset/example",
-    max_num_samples=300,
-    fps=50,
+    max_num_episodes=1,
 )
 print(video_paths)
 # ['outputs/visualize_dataset/example/episode_0.mp4']
@@ -9,9 +9,8 @@ from pathlib import Path

 import torch
 from omegaconf import OmegaConf
-from tqdm import trange

-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config
@@ -37,19 +36,33 @@ policy = DiffusionPolicy(
     cfg_obs_encoder=cfg.obs_encoder,
     cfg_optimizer=cfg.optimizer,
     cfg_ema=cfg.ema,
-    n_action_steps=cfg.n_action_steps,
     **cfg.policy,
 )
 policy.train()

-offline_buffer = make_offline_buffer(cfg)
+dataset = make_dataset(cfg)

+# create dataloader for offline training
+dataloader = torch.utils.data.DataLoader(
+    dataset,
+    num_workers=4,
+    batch_size=cfg.policy.batch_size,
+    shuffle=True,
+    pin_memory=cfg.device != "cpu",
+    drop_last=True,
+)
+
+for step, batch in enumerate(dataloader):
+    info = policy(batch, step)
+
+    if step % cfg.log_freq == 0:
+        num_samples = (step + 1) * cfg.policy.batch_size
+        loss = info["loss"]
+        update_s = info["update_s"]
+        print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")

-for offline_step in trange(cfg.offline_steps):
-    train_info = policy.update(offline_buffer, offline_step)
-    if offline_step % cfg.log_freq == 0:
-        print(train_info)

 # Save the policy, configuration, and normalization stats for later use.
 policy.save(output_directory / "model.pt")
 OmegaConf.save(cfg, output_directory / "config.yaml")
-torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
+torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
@@ -12,14 +12,11 @@ Example:
     print(lerobot.available_policies)
 ```

-Note:
-    When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-    1. set the required class attributes:
-        - for classes inheriting from `AbstractDataset`: `available_datasets`
-        - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-        - for classes inheriting from `AbstractPolicy`: `name`
-    2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-    3. update variables in `tests/test_available.py` by importing your new class
+When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
+- Set the required class attributes: `available_datasets`.
+- Set the required class attributes: `name`.
+- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
+- Update variables in `tests/test_available.py` by importing your new class
 """

 from lerobot.__version__ import __version__  # noqa: F401
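A minimal sketch of the checklist above for a hypothetical new dataset; the class name `MyDataset` and id `my_dataset_id` are illustrative, not from the repo:

```python
# Hypothetical example following the checklist in the new docstring.
import torch


class MyDataset(torch.utils.data.Dataset):
    # required class attribute (first item of the checklist)
    available_datasets = ["my_dataset_id"]

    def __init__(self, dataset_id: str, root=None, transform=None):
        self.dataset_id = dataset_id
        self.root = root
        self.transform = transform
        self.data_dict = {}

    def __len__(self):
        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

    def __getitem__(self, idx):
        raise NotImplementedError  # load and transform one frame here
```

The remaining items then register the class in `lerobot/__init__.py` and import it in `tests/test_available.py`.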
@@ -32,11 +29,11 @@ available_envs = [

 available_tasks_per_env = {
     "aloha": [
-        "sim_insertion",
-        "sim_transfer_cube",
+        "AlohaInsertion-v0",
+        "AlohaTransferCube-v0",
     ],
-    "pusht": ["pusht"],
-    "xarm": ["lift"],
+    "pusht": ["PushT-v0"],
+    "xarm": ["XarmLift-v0"],
 }

 available_datasets_per_env = {
@@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset):

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"])
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

     @property
     def num_episodes(self) -> int:
@@ -1,10 +1,11 @@
+import logging
 import os
 from pathlib import Path

 import torch
 from torchvision.transforms import v2

-from lerobot.common.datasets.utils import compute_or_load_stats
+from lerobot.common.datasets.utils import compute_stats
 from lerobot.common.transforms import NormalizeTransform, Prod

 # DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
@@ -40,7 +41,8 @@ def make_dataset(
     if normalize:
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
         # min_max_from_spec
-        # stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
+        # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
+        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"

         if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
             stats = {}
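For context, a rough sketch of what the two normalization modes selected above compute. This is not lerobot's `NormalizeTransform`, and the `min_max` output range of [-1, 1] shown here is an assumed convention:

```python
# Illustrative only; the actual transform in the repo may differ.
import torch


def normalize(x: torch.Tensor, stats: dict[str, torch.Tensor], mode: str) -> torch.Tensor:
    if mode == "mean_std":
        return (x - stats["mean"]) / (stats["std"] + 1e-8)
    if mode == "min_max":
        # map [min, max] to [-1, 1] (assumed range)
        return (x - stats["min"]) / (stats["max"] - stats["min"]) * 2 - 1
    raise ValueError(mode)
```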
@@ -51,21 +53,27 @@ def make_dataset(
             stats["action"] = {}
             stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
             stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-        else:
+        elif stats_path is None:
             # instantiate a one frame dataset with light transform
             stats_dataset = clsfunc(
                 dataset_id=cfg.dataset_id,
                 root=DATA_DIR,
                 transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
             )
-            stats = compute_or_load_stats(stats_dataset)

-    # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
-    normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
+            # load stats if the file exists already or compute stats and save it
+            precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
+            if precomputed_stats_path.exists():
+                stats = torch.load(precomputed_stats_path)
+            else:
+                logging.info(f"compute_stats and save to {precomputed_stats_path}")
+                stats = compute_stats(stats_dataset)
+                torch.save(stats, stats_path)
+        else:
+            stats = torch.load(stats_path)

         transforms = v2.Compose(
             [
-                # TODO(rcadene): we need to do something about image_keys
                 Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
                 NormalizeTransform(
                     stats,
@@ -2,11 +2,8 @@ from pathlib import Path

 import einops
 import numpy as np
-import pygame
-import pymunk
 import torch
 import tqdm
-from gym_pusht.envs.pusht import pymunk_to_shapely

 from lerobot.common.datasets._diffusion_policy_replay_buffer import (
     ReplayBuffer as DiffusionPolicyReplayBuffer,
@@ -20,64 +17,6 @@ PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
 PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")


-def get_goal_pose_body(pose):
-    mass = 1
-    inertia = pymunk.moment_for_box(mass, (50, 100))
-    body = pymunk.Body(mass, inertia)
-    # preserving the legacy assignment order for compatibility
-    # the order here doesn't matter somehow, maybe because CoM is aligned with body origin
-    body.position = pose[:2].tolist()
-    body.angle = pose[2]
-    return body
-
-
-def add_segment(space, a, b, radius):
-    shape = pymunk.Segment(space.static_body, a, b, radius)
-    shape.color = pygame.Color("LightGray")  # https://htmlcolorcodes.com/color-names
-    return shape
-
-
-def add_tee(
-    space,
-    position,
-    angle,
-    scale=30,
-    color="LightSlateGray",
-    mask=None,
-):
-    if mask is None:
-        mask = pymunk.ShapeFilter.ALL_MASKS()
-    mass = 1
-    length = 4
-    vertices1 = [
-        (-length * scale / 2, scale),
-        (length * scale / 2, scale),
-        (length * scale / 2, 0),
-        (-length * scale / 2, 0),
-    ]
-    inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
-    vertices2 = [
-        (-scale / 2, scale),
-        (-scale / 2, length * scale),
-        (scale / 2, length * scale),
-        (scale / 2, scale),
-    ]
-    inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
-    body = pymunk.Body(mass, inertia1 + inertia2)
-    shape1 = pymunk.Poly(body, vertices1)
-    shape2 = pymunk.Poly(body, vertices2)
-    shape1.color = pygame.Color(color)
-    shape2.color = pygame.Color(color)
-    shape1.filter = pymunk.ShapeFilter(mask=mask)
-    shape2.filter = pymunk.ShapeFilter(mask=mask)
-    body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
-    body.position = position
-    body.angle = angle
-    body.friction = 1
-    space.add(body, shape1, shape2)
-    return body
-
-
 class PushtDataset(torch.utils.data.Dataset):
     """
@@ -121,7 +60,7 @@ class PushtDataset(torch.utils.data.Dataset):

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"])
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

     @property
     def num_episodes(self) -> int:
@@ -158,6 +97,13 @@ class PushtDataset(torch.utils.data.Dataset):
         return item

     def _download_and_preproc_obsolete(self):
+        try:
+            import pymunk
+            from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
+        except ModuleNotFoundError as e:
+            print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
+            raise e
+
         assert self.root is not None
         raw_dir = self.root / f"{self.dataset_id}_raw"
         zarr_path = (raw_dir / PUSHT_ZARR).resolve()
@@ -182,7 +128,7 @@ class PushtDataset(torch.utils.data.Dataset):

         # TODO: verify that goal pose is expected to be fixed
         goal_pos_angle = np.array([256, 256, np.pi / 4])  # x, y, theta (in radians)
-        goal_body = get_goal_pose_body(goal_pos_angle)
+        goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

         imgs = torch.from_numpy(dataset_dict["img"])
         imgs = einops.rearrange(imgs, "b h w c -> b c h w")
@@ -201,6 +147,9 @@ class PushtDataset(torch.utils.data.Dataset):
             assert (episode_ids[idx0:idx1] == episode_id).all()

             image = imgs[idx0:idx1]
+            assert image.min() >= 0.0
+            assert image.max() <= 255.0
+            image = image.type(torch.uint8)

             state = states[idx0:idx1]
             agent_pos = state[:, :2]
@@ -217,14 +166,14 @@ class PushtDataset(torch.utils.data.Dataset):

             # Add walls.
             walls = [
-                add_segment(space, (5, 506), (5, 5), 2),
-                add_segment(space, (5, 5), (506, 5), 2),
-                add_segment(space, (506, 5), (506, 506), 2),
-                add_segment(space, (5, 506), (506, 506), 2),
+                PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
+                PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
+                PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
+                PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
             ]
             space.add(*walls)

-            block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
+            block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
             goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
             block_geom = pymunk_to_shapely(block_body, block_body.shapes)
             intersection_area = goal_geom.intersection(block_geom).area
@@ -265,16 +214,3 @@ class PushtDataset(torch.utils.data.Dataset):
             self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])

         self.data_dict["index"] = torch.arange(0, total_frames, 1)
-
-
-if __name__ == "__main__":
-    dataset = PushtDataset(
-        "pusht",
-        root=Path("data"),
-        delta_timestamps={
-            "observation.image": [0, -1, -0.2, -0.1],
-            "observation.state": [0, -1, -0.2, -0.1],
-            "action": [-0.1, 0, 1, 2, 3],
-        },
-    )
-    dataset[10]
@@ -1,5 +1,4 @@
 import io
-import logging
 import zipfile
 from copy import deepcopy
 from math import ceil
@@ -35,52 +34,56 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     return False


-def euclidean_distance_matrix(mat0, mat1):
-    # Compute the square of the distance matrix
-    sq0 = torch.sum(mat0**2, dim=1, keepdim=True)
-    sq1 = torch.sum(mat1**2, dim=1, keepdim=True)
-    distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1)
-
-    # Taking the square root to get the euclidean distance
-    distance = torch.sqrt(torch.clamp(distance_sq, min=0))
-    return distance
-
-
-def is_contiguously_true_or_false(bool_vector):
-    assert bool_vector.ndim == 1
-    assert bool_vector.dtype == torch.bool
-
-    # Compare each element with its neighbor to find changes
-    changes = bool_vector[1:] != bool_vector[:-1]
-
-    # Count the number of changes
-    num_changes = changes.sum().item()
-
-    # If there's more than one change, the list is not contiguous
-    return num_changes <= 1
-
-    # examples = [
-    #     ([True, False, True, False, False, False], False),
-    #     ([True, True, True, False, False, False], True),
-    #     ([False, False, False, False, False, False], True)
-    # ]
-    # for bool_list, expected in examples:
-    #     result = is_contiguously_true_or_false(bool_list)
-
-
 def load_data_with_delta_timestamps(
-    data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
+    data_dict: dict[torch.Tensor],
+    data_ids_per_episode: dict[torch.Tensor],
+    delta_timestamps: list[float],
+    key: str,
+    current_ts: float,
+    episode: int,
+    tol: float = 0.04,
 ):
+    """
+    Given a current timestamp (e.g. current_ts=0.6) and a list of timestamps differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
+    this function compute the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
+
+    Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
+    When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
+    the violation of the tolerance doesnt raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range.
+    For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode,
+    or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
+
+    Parameters:
+    - data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
+    - data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
+    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
+    - key (str): The key specifying which data modality is to be retrieved from the data_dict.
+    - current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
+    - episode (int): The identifier of the episode from which frames are to be retrieved.
+    - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
+
+    Returns:
+    - tuple: A tuple containing two elements:
+        - The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
+        - The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
+
+    Raises:
+    - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
+    """
     # get indices of the frames associated to the episode, and their timestamps
     ep_data_ids = data_ids_per_episode[episode]
     ep_timestamps = data_dict["timestamp"][ep_data_ids]

+    # we make the assumption that the timestamps are sorted
+    ep_first_ts = ep_timestamps[0]
+    ep_last_ts = ep_timestamps[-1]
+
     # get timestamps used as query to retrieve data of previous/future frames
     delta_ts = delta_timestamps[key]
     query_ts = current_ts + torch.tensor(delta_ts)

     # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
-    dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
+    dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
     min_, argmin_ = dist.min(1)

     # get the indices of the data that are closest to the query timestamps
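A hedged usage sketch of the new signature with made-up timestamps; the import path follows the `compute_stats` import shown earlier, and the tensor contents are hypothetical:

```python
import torch

from lerobot.common.datasets.utils import load_data_with_delta_timestamps

# hypothetical episode of 4 frames sampled at 10 Hz
data_dict = {
    "timestamp": torch.tensor([0.0, 0.1, 0.2, 0.3]),
    "observation.image": torch.zeros(4, 3, 96, 96),
}
data_ids_per_episode = {0: torch.arange(4)}
delta_timestamps = {"observation.image": [-0.2, 0.0, 0.2]}

frames, is_pad = load_data_with_delta_timestamps(
    data_dict, data_ids_per_episode, delta_timestamps,
    key="observation.image", current_ts=0.1, episode=0,
)
# query timestamps are [-0.1, 0.1, 0.3]; the first falls before the episode
# start, so is_pad[0] is True and the nearest (first) frame is returned for it.
```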
@@ -92,24 +95,29 @@ def load_data_with_delta_timestamps(

     # TODO(rcadene): synchronize timestamps + interpolation if needed

-    tol = 0.04
     is_pad = min_ > tol

-    assert is_contiguously_true_or_false(is_pad), (
-        f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})."
+    # check violated query timestamps are all outside the episode range
+    assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
+        f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
         "This might be due to synchronization issues with timestamps during data collection."
     )

     return data, is_pad


-def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
-    stats_path = dataset.data_dir / "stats.pth"
-    if stats_path.exists():
-        return torch.load(stats_path)
-
-    logging.info(f"compute_stats and save to {stats_path}")
-
+def get_stats_einops_patterns(dataset):
+    """These einops patterns will be used to aggregate batches and compute statistics."""
+    stats_patterns = {
+        "action": "b c -> c",
+        "observation.state": "b c -> c",
+    }
+    for key in dataset.image_keys:
+        stats_patterns[key] = "b c h w -> c 1 1"
+    return stats_patterns
+
+
+def compute_stats(dataset, batch_size=32, max_num_samples=None):
     if max_num_samples is None:
         max_num_samples = len(dataset)
     else:
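The new assertion in isolation, with the same made-up values as the sketch above: a padded query is only acceptable when it falls outside the episode's timestamp range.

```python
import torch

query_ts = torch.tensor([-0.1, 0.1, 0.3])  # made-up values
is_pad = torch.tensor([True, False, False])
ep_first_ts, ep_last_ts = 0.0, 0.3
# every padded query lies outside [ep_first_ts, ep_last_ts], so this passes
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all()
```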
@@ -124,13 +132,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
         drop_last=False,
     )

-    # these einops patterns will be used to aggregate batches and compute statistics
-    stats_patterns = {
-        "action": "b c -> c",
-        "observation.state": "b c -> c",
-    }
-    for key in dataset.image_keys:
-        stats_patterns[key] = "b c h w -> c 1 1"
+    # get einops patterns to aggregate batches and compute statistics
+    stats_patterns = get_stats_einops_patterns(dataset)

     # mean and std will be computed incrementally while max and min will track the running value.
     mean, std, max, min = {}, {}, {}, {}
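In isolation, the `"b c h w -> c 1 1"` pattern returned by `get_stats_einops_patterns` reduces a batch of images to per-channel statistics with a broadcastable shape; a standalone sketch (the repo's loop computes mean/std incrementally rather than in a single call):

```python
import torch
from einops import reduce

batch = torch.rand(32, 3, 96, 96)                 # (b, c, h, w) images
mean = reduce(batch, "b c h w -> c 1 1", "mean")  # per-channel mean
maxi = reduce(batch, "b c h w -> c 1 1", "max")   # per-channel running-max candidate
assert mean.shape == maxi.shape == (3, 1, 1)
```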
@@ -201,7 +204,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
             "min": min[key],
         }

-    torch.save(stats, stats_path)
     return stats
@@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset):

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"])
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

     @property
     def num_episodes(self) -> int:
@@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset):
         image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
         state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
         action = torch.tensor(dataset_dict["actions"][idx0:idx1])
-        # TODO(rcadene): concat the last "next_observations" to "observations"
+        # TODO(rcadene): we have a missing last frame which is the observation when the env is done
+        # it is critical to have this frame for tdmpc to predict a "done observation/state"
         # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
         # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
         next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
@@ -19,6 +19,7 @@ def preprocess_observation(observation, transform=None):
         img = einops.rearrange(img, "b h w c -> b c h w")
         obs[imgkey] = img

+    # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
     obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()

     # apply same transforms as in training
@@ -29,9 +29,9 @@ def make_policy(cfg):
     if cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.pretrained_model_path:
+            if "offline" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.pretrained_model_path:
+            elif "final" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
@@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()

-        # num_slices = self.cfg.batch_size
-        # batch_size = self.cfg.horizon * num_slices
-
-        # if demo_buffer is None:
-        #     demo_batch_size = 0
-        # else:
-        #     # Update oversampling ratio
-        #     demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
-        #     demo_num_slices = int(demo_pc_batch * self.batch_size)
-        #     demo_batch_size = self.cfg.horizon * demo_num_slices
-        #     batch_size -= demo_batch_size
-        #     num_slices -= demo_num_slices
-        #     replay_buffer._sampler.num_slices = num_slices
-        #     demo_buffer._sampler.num_slices = demo_num_slices
-
-        #     assert demo_batch_size % self.cfg.horizon == 0
-        #     assert demo_batch_size % demo_num_slices == 0
-
-        # assert batch_size % self.cfg.horizon == 0
-        # assert batch_size % num_slices == 0
-
-        # # Sample from interaction dataset
-
-        # def process_batch(batch, horizon, num_slices):
-        #     # trajectory t = 256, horizon h = 5
-        #     # (t h) ... -> h t ...
-        #     batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-
-        #     obs = {
-        #         "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
-        #         "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
-        #     }
-        #     action = batch["action"].to(self.device, non_blocking=True)
-        #     next_obses = {
-        #         "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
-        #         "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
-        #     }
-        #     reward = batch["next", "reward"].to(self.device, non_blocking=True)
-
-        #     idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
-        #     weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
-
-        #     # TODO(rcadene): rearrange directly in offline dataset
-        #     if reward.ndim == 2:
-        #         reward = einops.rearrange(reward, "h t -> h t 1")
-
-        #     assert reward.ndim == 3
-        #     assert reward.shape == (horizon, num_slices, 1)
-        #     # We dont use `batch["next", "done"]` since it only indicates the end of an
-        #     # episode, but not the end of the trajectory of an episode.
-        #     # Neither does `batch["next", "terminated"]`
-        #     done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-        #     mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        #     return obs, action, next_obses, reward, mask, done, idxs, weights
-
-        # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
-
-        # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
-        #     batch, self.cfg.horizon, num_slices
-        # )
-
-        # Sample from demonstration dataset
-        # if demo_batch_size > 0:
-        #     demo_batch = demo_buffer.sample(demo_batch_size)
-        #     (
-        #         demo_obs,
-        #         demo_action,
-        #         demo_next_obses,
-        #         demo_reward,
-        #         demo_mask,
-        #         demo_done,
-        #         demo_idxs,
-        #         demo_weights,
-        #     ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
-
-        #     if isinstance(obs, dict):
-        #         obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
-        #         next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
-        #     else:
-        #         obs = torch.cat([obs, demo_obs])
-        #         next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
-        #     action = torch.cat([action, demo_action], dim=1)
-        #     reward = torch.cat([reward, demo_reward], dim=1)
-        #     mask = torch.cat([mask, demo_mask], dim=1)
-        #     done = torch.cat([done, demo_done], dim=1)
-        #     idxs = torch.cat([idxs, demo_idxs])
-        #     weights = torch.cat([weights, demo_weights])
-
         batch_size = batch["index"].shape[0]

         # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
@@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module):
         )
         self.optim.step()

+        # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
         # if self.cfg.per:
         #     # Update priorities
         #     priorities = priority_loss.clamp(max=1e4).detach()
@@ -99,6 +99,7 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D


 def print_cuda_memory_usage():
+    """Use this function to locate and debug memory leak."""
     import gc

     gc.collect()
@@ -18,7 +18,6 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: [3, 480, 640]
-  action_repeat: 1
   episode_length: 400
   fps: ${fps}
@@ -18,7 +18,6 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 96
-  action_repeat: 1
   episode_length: 300
   fps: ${fps}
@@ -17,7 +17,6 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 84
-  # action_repeat: 2 # we can remove if policy has n_action_steps=2
   episode_length: 25
   fps: ${fps}
@@ -36,6 +36,7 @@ policy:
   log_std_max: 2

   # learning
+  batch_size: 256
   max_buffer_size: 10000
   horizon: 5
   reward_coef: 0.5
@@ -32,6 +32,7 @@ import json
 import logging
 import threading
 import time
+from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
@@ -56,15 +57,15 @@ def write_video(video_path, stacked_frames, fps):

 def eval_policy(
     env: gym.vector.VectorEnv,
-    policy,
-    save_video: bool = False,
+    policy: torch.nn.Module,
+    max_episodes_rendered: int = 0,
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
-    fps: int = 15,
-    return_first_video: bool = False,
     transform: callable = None,
     seed=None,
 ):
+    fps = env.unwrapped.metadata["render_fps"]
+
     if policy is not None:
         policy.eval()
     device = "cpu" if policy is None else next(policy.parameters()).device
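`eval_policy` can drop its `fps` parameter because gymnasium environments conventionally publish their render rate in `metadata`; a standalone check with a stock environment (not one of lerobot's):

```python
import gymnasium as gym

env = gym.make("CartPole-v1", render_mode="rgb_array")
fps = env.unwrapped.metadata["render_fps"]
print(fps)  # 50 for CartPole-v1
```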
@@ -83,14 +84,11 @@ def eval_policy(
     # needed as I'm currently taking a ceil.
     ep_frames = []

-    def maybe_render_frame(env):
-        if save_video:  # noqa: B023
-            if return_first_video:
-                visu = env.envs[0].render()
-                visu = visu[None, ...]  # add batch dim
-            else:
-                visu = np.stack([env.render() for env in env.envs])
-            ep_frames.append(visu)  # noqa: B023
+    def render_frame(env):
+        # noqa: B023
+        eps_rendered = min(max_episodes_rendered, len(env.envs))
+        visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
+        ep_frames.append(visu)  # noqa: B023

     for _ in range(num_episodes):
         seeds.append("TODO")
@@ -104,8 +102,14 @@ def eval_policy(

     # reset the environment
     observation, info = env.reset(seed=seed)
-    maybe_render_frame(env)
+    if max_episodes_rendered > 0:
+        render_frame(env)

+    observations = []
+    actions = []
+    # episode
+    # frame_id
+    # timestamp
     rewards = []
     successes = []
     dones = []
@@ -113,8 +117,13 @@ def eval_policy(
     done = torch.tensor([False for _ in env.envs])
     step = 0
     while not done.all():
+        # format from env keys to lerobot keys
+        observation = preprocess_observation(observation)
+        observations.append(deepcopy(observation))
+
         # apply transform to normalize the observations
-        observation = preprocess_observation(observation, transform)
+        for key in observation:
+            observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])

         # send observation to device/gpu
         observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
@@ -126,11 +135,13 @@ def eval_policy(
         # apply inverse transform to unnormalize the action
         action = postprocess_action(action, transform)

-        # apply the next
+        # apply the next action
         observation, reward, terminated, truncated, info = env.step(action)
-        maybe_render_frame(env)
+        if max_episodes_rendered > 0:
+            render_frame(env)

         # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
+        action = torch.from_numpy(action)
         reward = torch.from_numpy(reward)
         terminated = torch.from_numpy(terminated)
         truncated = torch.from_numpy(truncated)
@@ -147,12 +158,24 @@ def eval_policy(
             success = [False for _ in env.envs]
         success = torch.tensor(success)

+        actions.append(action)
         rewards.append(reward)
         dones.append(done)
         successes.append(success)

        step += 1

+    env.close()
+
+    # add the last observation when the env is done
+    observation = preprocess_observation(observation)
+    observations.append(deepcopy(observation))
+
+    new_obses = {}
+    for key in observations[0].keys():  # noqa: SIM118
+        new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
+    observations = new_obses
+    actions = torch.stack(actions, dim=1)
     rewards = torch.stack(rewards, dim=1)
     successes = torch.stack(successes, dim=1)
     dones = torch.stack(dones, dim=1)
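A shape sketch for the stacking above, with made-up sizes: each step appends a `(num_envs, ...)` tensor, and stacking on `dim=1` yields `(num_envs, num_steps, ...)`:

```python
import torch

num_envs, num_steps = 4, 3
rewards = [torch.zeros(num_envs) for _ in range(num_steps)]  # one entry per step
stacked = torch.stack(rewards, dim=1)
assert stacked.shape == (num_envs, num_steps)
```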
@@ -172,29 +195,61 @@ def eval_policy(
     max_rewards.extend(batch_max_reward.tolist())
     all_successes.extend(batch_success.tolist())

-    env.close()
+    # similar logic is implemented in dataset preprocessing
+    ep_dicts = []
+    num_episodes = dones.shape[0]
+    total_frames = 0
+    idx0 = idx1 = 0
+    data_ids_per_episode = {}
+    for ep_id in range(num_episodes):
+        num_frames = done_indices[ep_id].item() + 1
+        # TODO(rcadene): We need to add a missing last frame which is the observation
+        # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
+        ep_dict = {
+            "action": actions[ep_id, :num_frames],
+            "episode": torch.tensor([ep_id] * num_frames),
+            "frame_id": torch.arange(0, num_frames, 1),
+            "timestamp": torch.arange(0, num_frames, 1) / fps,
+            "next.done": dones[ep_id, :num_frames],
+            "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
+        }
+        for key in observations:
+            ep_dict[key] = observations[key][ep_id, :num_frames]
+        ep_dicts.append(ep_dict)

-    if save_video or return_first_video:
+        total_frames += num_frames
+        idx1 += num_frames
+
+        data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
+
+        idx0 = idx1
+
+    # similar logic is implemented in dataset preprocessing
+    data_dict = {}
+    keys = ep_dicts[0].keys()
+    for key in keys:
+        data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+
+    if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)

-        if save_video:
-            for stacked_frames, done_index in zip(
-                batch_stacked_frames, done_indices.flatten().tolist(), strict=False
-            ):
-                if episode_counter >= num_episodes:
-                    continue
-                video_dir.mkdir(parents=True, exist_ok=True)
-                video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
-                thread = threading.Thread(
-                    target=write_video,
-                    args=(str(video_path), stacked_frames[:done_index], fps),
-                )
-                thread.start()
-                threads.append(thread)
-                episode_counter += 1
+        for stacked_frames, done_index in zip(
+            batch_stacked_frames, done_indices.flatten().tolist(), strict=False
+        ):
+            if episode_counter >= num_episodes:
+                continue
+            video_dir.mkdir(parents=True, exist_ok=True)
+            video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+            thread = threading.Thread(
+                target=write_video,
+                args=(str(video_path), stacked_frames[:done_index], fps),
+            )
+            thread.start()
+            threads.append(thread)
+            episode_counter += 1

-        if return_first_video:
-            first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
+        videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")

     for thread in threads:
         thread.join()
@@ -225,9 +280,13 @@ def eval_policy(
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
+        "episodes": {
+            "data_dict": data_dict,
+            "data_ids_per_episode": data_ids_per_episode,
+        },
     }
-    if return_first_video:
-        return info, first_video
+    if max_episodes_rendered > 0:
+        info["videos"] = videos
     return info
@@ -253,16 +312,14 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     logging.info("Making environment.")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

-    # when policy is None, rollout a random policy
-    policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
+    logging.info("Making policy.")
+    policy = make_policy(cfg)

     info = eval_policy(
         env,
-        policy=policy,
-        save_video=True,
+        policy,
+        max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
-        fps=cfg.env.fps,
-        # TODO(rcadene): what should we do with the transform?
         transform=transform,
         seed=cfg.seed,
     )
@@ -270,6 +327,9 @@ def eval(cfg: dict, out_dir=None, stats_path=None):

     # Save info
     with open(Path(out_dir) / "eval_info.json", "w") as f:
+        # remove pytorch tensors which are not serializable to save the evaluation results only
+        del info["episodes"]
+        del info["videos"]
         json.dump(info, f, indent=2)

     logging.info("End of eval")
@@ -1,8 +1,8 @@
 import logging
+from copy import deepcopy
 from pathlib import Path

 import hydra
-import numpy as np
 import torch

 from lerobot.common.datasets.factory import make_dataset
@@ -108,6 +108,64 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
    logger.log_dict(info, step, mode="eval")


+def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
+    """
+    Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
+
+    Parameters:
+    - n_off (int): Number of offline samples, each with a sampling weight of 1.
+    - n_on (int): Number of online samples.
+    - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
+
+    The total weight of offline samples is n_off * 1.0.
+    The total weight of online samples is n_on * w.
+    The total combined weight of all samples is n_off + n_on * w.
+    The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
+    We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
+    The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
+    """
+    assert 0.0 <= pc_on <= 1.0
+    return -(n_off * pc_on) / (n_on * (pc_on - 1))
+
+
+def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
+    data_dict = episodes["data_dict"]
+    data_ids_per_episode = episodes["data_ids_per_episode"]
+
+    if len(online_dataset) == 0:
+        # initialize online dataset
+        online_dataset.data_dict = data_dict
+        online_dataset.data_ids_per_episode = data_ids_per_episode
+    else:
+        # find episode index and data frame indices according to previous episode in online_dataset
+        start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
+        start_index = online_dataset.data_dict["index"][-1].item() + 1
+        data_dict["episode"] += start_episode
+        data_dict["index"] += start_index
+
+        # extend online dataset
+        for key in data_dict:
+            # TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
+            online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
+        for ep_id in data_ids_per_episode:
+            online_dataset.data_ids_per_episode[ep_id + start_episode] = (
+                data_ids_per_episode[ep_id] + start_index
+            )
+
+    # update the concatenated dataset length used during sampling
+    concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+
+    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
+    len_online = len(online_dataset)
+    len_offline = len(concat_dataset) - len_online
+    weight_offline = 1.0
+    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
+    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
+
+    # update the total number of samples used during sampling
+    sampler.num_samples = len(concat_dataset)
+
+
 def train(cfg: dict, out_dir=None, job_name=None):
    if out_dir is None:
        raise NotImplementedError()

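Note: the closed-form weight above can be sanity-checked numerically. A minimal standalone sketch reusing only the helper from this hunk (the concrete sample counts are made up for illustration):

    def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
        assert 0.0 <= pc_on <= 1.0
        return -(n_off * pc_on) / (n_on * (pc_on - 1))

    # Hypothetical counts: 1000 offline frames (weight 1.0 each), 50 online frames,
    # targeting 50% online samples per batch on average.
    w = calculate_online_sample_weight(1000, 50, 0.5)  # -> 20.0
    online_fraction = (50 * w) / (1000 * 1.0 + 50 * w)  # -> 0.5
    assert abs(online_fraction - 0.5) < 1e-9
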
@@ -126,26 +184,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
    set_global_seed(cfg.seed)

    logging.info("make_dataset")
-   dataset = make_dataset(cfg)
-
-   # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
-   # if cfg.policy.balanced_sampling:
-   #     logging.info("make online_buffer")
-   #     num_traj_per_batch = cfg.policy.batch_size
-
-   #     online_sampler = PrioritizedSliceSampler(
-   #         max_capacity=100_000,
-   #         alpha=cfg.policy.per_alpha,
-   #         beta=cfg.policy.per_beta,
-   #         num_slices=num_traj_per_batch,
-   #         strict_length=True,
-   #     )
-
-   #     online_buffer = TensorDictReplayBuffer(
-   #         storage=LazyMemmapStorage(100_000),
-   #         sampler=online_sampler,
-   #         transform=dataset.transform,
-   #     )
+   offline_dataset = make_dataset(cfg)

    logging.info("make_env")
    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

@@ -163,9 +202,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
    logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
    logging.info(f"{cfg.online_steps=}")
-   logging.info(f"{cfg.env.action_repeat=}")
-   logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
-   logging.info(f"{dataset.num_episodes=}")
+   logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
+   logging.info(f"{offline_dataset.num_episodes=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

@@ -173,18 +211,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
    def _maybe_eval_and_maybe_save(step):
        if step % cfg.eval_freq == 0:
            logging.info(f"Eval policy at step {step}")
-           eval_info, first_video = eval_policy(
+           eval_info = eval_policy(
                env,
                policy,
-               return_first_video=True,
                video_dir=Path(out_dir) / "eval",
-               save_video=True,
-               transform=dataset.transform,
+               max_episodes_rendered=4,
+               transform=offline_dataset.transform,
                seed=cfg.seed,
            )
-           log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
+           log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
            if cfg.wandb.enable:
-               logger.log_video(first_video, step, mode="eval")
+               logger.log_video(eval_info["videos"][0], step, mode="eval")
            logging.info("Resume training")

        if cfg.save_model and step % cfg.save_freq == 0:

@@ -192,18 +229,19 @@ def train(cfg: dict, out_dir=None, job_name=None):
            logger.save_model(policy, identifier=step)
            logging.info("Resume training")

-   step = 0  # number of policy update (forward + backward + optim)
-
-   is_offline = True
+   # create dataloader for offline training
    dataloader = torch.utils.data.DataLoader(
-       dataset,
+       offline_dataset,
        num_workers=4,
        batch_size=cfg.policy.batch_size,
        shuffle=True,
        pin_memory=cfg.device != "cpu",
-       drop_last=True,
+       drop_last=False,
    )
    dl_iter = cycle(dataloader)

+   step = 0  # number of policy update (forward + backward + optim)
+   is_offline = True
    for offline_step in range(cfg.offline_steps):
        if offline_step == 0:
            logging.info("Start offline training on a fixed dataset")

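Note: `cycle(dataloader)` wraps the finite DataLoader into an endless iterator so the loop above can draw batches by step count rather than by epoch. The helper's definition is not shown in this diff; a typical implementation looks like:

    def cycle(iterable):
        # Restart the underlying iterator whenever it is exhausted,
        # so `next(dl_iter)` never raises StopIteration mid-training.
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)
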
@@ -217,7 +255,7 @@ def train(cfg: dict, out_dir=None, job_name=None):

        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
        if step % cfg.log_freq == 0:
-           log_train_info(logger, train_info, step, cfg, dataset, is_offline)
+           log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)

        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
        # step + 1.

@@ -225,61 +263,60 @@ def train(cfg: dict, out_dir=None, job_name=None):

        step += 1

-   raise NotImplementedError()
+   # create an env dedicated to online episodes collection from policy rollout
+   rollout_env = make_env(cfg, num_parallel_envs=1)
+
+   # create an empty online dataset similar to offline dataset
+   online_dataset = deepcopy(offline_dataset)
+   online_dataset.data_dict = {}
+   online_dataset.data_ids_per_episode = {}
+
+   # create dataloader for online training
+   concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
+   weights = [1.0] * len(concat_dataset)
+   sampler = torch.utils.data.WeightedRandomSampler(
+       weights, num_samples=len(concat_dataset), replacement=True
+   )
+   dataloader = torch.utils.data.DataLoader(
+       concat_dataset,
+       num_workers=4,
+       batch_size=cfg.policy.batch_size,
+       sampler=sampler,
+       pin_memory=cfg.device != "cpu",
+       drop_last=False,
+   )
+   dl_iter = cycle(dataloader)

-   demo_buffer = dataset if cfg.policy.balanced_sampling else None
    online_step = 0
    is_offline = False
    for env_step in range(cfg.online_steps):
        if env_step == 0:
            logging.info("Start online training by interacting with environment")
-       # TODO: add configurable number of rollout? (default=1)
        with torch.no_grad():
-           rollout = env.rollout(
-               max_steps=cfg.env.episode_length,
-               policy=policy,
-               auto_cast_to_device=True,
-           )
+           eval_info = eval_policy(
+               rollout_env,
+               policy,
+               transform=offline_dataset.transform,
+               seed=cfg.seed,
+           )

-       assert (
-           len(rollout.batch_size) == 2
-       ), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
-
-       num_parallel_env = rollout.batch_size[0]
-       if num_parallel_env != 1:
-           # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
-           raise NotImplementedError()
-
-       num_max_steps = rollout.batch_size[1]
-       assert num_max_steps <= cfg.env.episode_length
-
-       # reshape to have a list of steps to insert into online_buffer
-       rollout = rollout.reshape(num_parallel_env * num_max_steps)
-
-       # set same episode index for all time steps contained in this rollout
-       rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
-       # online_buffer.extend(rollout)
-
-       ep_sum_reward = rollout["next", "reward"].sum()
-       ep_max_reward = rollout["next", "reward"].max()
-       ep_success = rollout["next", "success"].any()
-       rollout_info = {
-           "avg_sum_reward": np.nanmean(ep_sum_reward),
-           "avg_max_reward": np.nanmean(ep_max_reward),
-           "pc_success": np.nanmean(ep_success) * 100,
-           "env_step": env_step,
-           "ep_length": len(rollout),
-       }
+       online_pc_sampling = cfg.get("demo_schedule", 0.5)
+       add_episodes_inplace(
+           eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
+       )

        for _ in range(cfg.policy.utd):
-           train_info = policy.update(
-               # online_buffer,
-               step,
-               demo_buffer=demo_buffer,
-           )
+           policy.train()
+           batch = next(dl_iter)
+
+           for key in batch:
+               batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+           train_info = policy(batch, step)

            if step % cfg.log_freq == 0:
-               train_info.update(rollout_info)
-               log_train_info(logger, train_info, step, cfg, dataset, is_offline)
+               log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
        # in step + 1.

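Note: the offline/online mix is controlled entirely through the sampler. Both datasets sit in one ConcatDataset, and add_episodes_inplace rewrites sampler.weights so online frames are drawn pc_online_samples of the time on average. A self-contained sketch of the mechanism with toy tensor datasets (the sizes and the 0.5 target are illustrative, not from the diff):

    import torch

    offline = torch.utils.data.TensorDataset(torch.zeros(100))  # stand-in for offline frames
    online = torch.utils.data.TensorDataset(torch.ones(10))     # stand-in for online frames
    concat = torch.utils.data.ConcatDataset([offline, online])

    pc_on = 0.5
    w_on = -(len(offline) * pc_on) / (len(online) * (pc_on - 1))  # same formula as above -> 10.0
    weights = [1.0] * len(offline) + [w_on] * len(online)
    sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(concat), replacement=True)
    loader = torch.utils.data.DataLoader(concat, batch_size=32, sampler=sampler)

    batch = next(iter(loader))[0]
    print(batch.mean())  # roughly 0.5: about half the draws are online frames
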
@@ -6,9 +6,6 @@ import einops
 import hydra
 import imageio
 import torch
-from torchrl.data.replay_buffers import (
-    SamplerWithoutReplacement,
-)

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.logger import log_output_dir

@@ -39,19 +36,11 @@ def visualize_dataset(cfg: dict, out_dir=None):
    init_logging()
    log_output_dir(out_dir)

-   # we expect frames of each episode to be stored next to each others sequentially
-   sampler = SamplerWithoutReplacement(
-       shuffle=False,
-   )
-
    logging.info("make_dataset")
    dataset = make_dataset(
        cfg,
-       overwrite_sampler=sampler,
        # remove all transformations such as rescale images from [0,255] to [0,1] or normalization
        normalize=False,
-       overwrite_batch_size=1,
-       overwrite_prefetch=12,
    )

    logging.info("Start rendering episodes from offline buffer")

@@ -60,64 +49,49 @@ def visualize_dataset(cfg: dict, out_dir=None):
    logging.info(video_path)


-def render_dataset(dataset, out_dir, max_num_samples, fps):
+def render_dataset(dataset, out_dir, max_num_episodes):
    out_dir = Path(out_dir)
    video_paths = []
    threads = []
-   frames = {}
-   current_ep_idx = 0
-   logging.info(f"Visualizing episode {current_ep_idx}")
-   for i in range(max_num_samples):
-       # TODO(rcadene): make it work with bsize > 1
-       ep_td = dataset.sample(1)
-       ep_idx = ep_td["episode"][FIRST_FRAME].item()
-
-       # TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
-       num_frames_left = dataset._sampler._sample_list.numel()
-       episode_is_done = ep_idx != current_ep_idx
-
-       if episode_is_done:
-           logging.info(f"Rendering episode {current_ep_idx}")
-
-       for im_key in dataset.image_keys:
-           if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
+   dataloader = torch.utils.data.DataLoader(
+       dataset,
+       num_workers=4,
+       batch_size=1,
+       shuffle=False,
+   )
+   dl_iter = iter(dataloader)
+
+   num_episodes = len(dataset.data_ids_per_episode)
+   for ep_id in range(min(max_num_episodes, num_episodes)):
+       logging.info(f"Rendering episode {ep_id}")
+
+       frames = {}
+       for _ in dataset.data_ids_per_episode[ep_id]:
+           item = next(dl_iter)
+
+           for im_key in dataset.image_keys:
                # when first frame of episode, initialize frames dict
                if im_key not in frames:
                    frames[im_key] = []
                # add current frame to list of frames to render
-               frames[im_key].append(ep_td[im_key])
-           else:
-               # When episode has no more frame in its list of observation,
-               # one frame still remains. It is the result of the last action taken.
-               # It is stored in `"next"`, so we add it to the list of frames to render.
-               frames[im_key].append(ep_td["next"][im_key])
-
-               out_dir.mkdir(parents=True, exist_ok=True)
-               if len(dataset.image_keys) > 1:
-                   camera = im_key[-1]
-                   video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
-               else:
-                   video_path = out_dir / f"episode_{current_ep_idx}.mp4"
-               video_paths.append(str(video_path))
-
-               thread = threading.Thread(
-                   target=cat_and_write_video,
-                   args=(str(video_path), frames[im_key], fps),
-               )
-               thread.start()
-               threads.append(thread)
-
-               current_ep_idx = ep_idx
-
-               # reset list of frames
-               del frames[im_key]
-
-       if num_frames_left == 0:
-           logging.info("Ran out of frames")
-           break
-
-       if current_ep_idx == NUM_EPISODES_TO_RENDER:
-           break
+               frames[im_key].append(item[im_key])
+
+       out_dir.mkdir(parents=True, exist_ok=True)
+       for im_key in dataset.image_keys:
+           if len(dataset.image_keys) > 1:
+               im_name = im_key.replace("observation.images.", "")
+               video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
+           else:
+               video_path = out_dir / f"episode_{ep_id}.mp4"
+           video_paths.append(video_path)
+
+           thread = threading.Thread(
+               target=cat_and_write_video,
+               args=(str(video_path), frames[im_key], dataset.fps),
+           )
+           thread.start()
+           threads.append(thread)

    for thread in threads:
        thread.join()

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.

 [[package]]
 name = "absl-py"

@@ -921,7 +921,7 @@ shapely = "^2.0.3"
 type = "git"
 url = "git@github.com:huggingface/gym-pusht.git"
 reference = "HEAD"
-resolved_reference = "824b22832cc8d71a4b4e96a57563510cf47e30c1"
+resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06"

 [[package]]
 name = "gym-xarm"

@@ -941,7 +941,7 @@ mujoco = "^2.3.7"
 type = "git"
 url = "git@github.com:huggingface/gym-xarm.git"
 reference = "HEAD"
-resolved_reference = "ce294c0d30def08414d9237e2bf9f373d448ca07"
+resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c"

 [[package]]
 name = "gymnasium"

sbatch.sh (deleted, 25 lines)
@@ -1,25 +0,0 @@
-#!/bin/bash
-#SBATCH --nodes=1 # total number of nodes (N to be defined)
-#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
-#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
-#SBATCH --cpus-per-task=8 # number of cores per task (8x8 = 64 cores, or all the cores)
-#SBATCH --time=2-00:00:00
-#SBATCH --output=/home/rcadene/slurm/%j.out
-#SBATCH --error=/home/rcadene/slurm/%j.err
-#SBATCH --qos=low
-#SBATCH --mail-user=re.cadene@gmail.com
-#SBATCH --mail-type=ALL
-
-CMD=$@
-echo "command: $CMD"
-
-apptainer exec --nv \
-    ~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
-
-source ~/.bashrc
-#conda activate fowm
-conda activate lerobot
-
-export DATA_DIR="data"
-
-srun $CMD

@@ -1,17 +0,0 @@
-#!/bin/bash
-#SBATCH --nodes=1 # total number of nodes (N to be defined)
-#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
-#SBATCH --qos=normal # number of GPUs reserved per node (here 8, or all the GPUs)
-#SBATCH --partition=hopper-prod
-#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
-#SBATCH --cpus-per-task=12 # number of cores per task
-#SBATCH --mem-per-cpu=11G
-#SBATCH --time=12:00:00
-#SBATCH --output=/admin/home/remi_cadene/slurm/%j.out
-#SBATCH --error=/admin/home/remi_cadene/slurm/%j.err
-#SBATCH --mail-user=remi_cadene@huggingface.co
-#SBATCH --mail-type=ALL
-
-CMD=$@
-echo "command: $CMD"
-srun $CMD

Binary file not shown.

@@ -1,64 +1,53 @@
 """
 This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be successfully
-imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) corresponds.
+imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) are valid.

-Note:
-    When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-    1. set the required class attributes:
-        - for classes inheriting from `AbstractDataset`: `available_datasets`
-        - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-        - for classes inheriting from `AbstractPolicy`: `name`
-    2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-    3. update variables in `tests/test_available.py` by importing your new class
+When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
+- Set the required class attributes: `available_datasets`.
+- Set the required class attributes: `name`.
+- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
+- Update variables in `tests/test_available.py` by importing your new class
 """

+import importlib
 import pytest

 import lerobot
+import gymnasium as gym

-# from lerobot.common.envs.aloha.env import AlohaEnv
-# from gym_pusht.envs import PushtEnv
-# from gym_xarm.envs import SimxarmEnv
-
-# from lerobot.common.datasets.xarm import SimxarmDataset
-# from lerobot.common.datasets.aloha import AlohaDataset
-# from lerobot.common.datasets.pusht import PushtDataset
-
-# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
-# from lerobot.common.policies.diffusion.policy import DiffusionPolicy
-# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+from lerobot.common.datasets.xarm import XarmDataset
+from lerobot.common.datasets.aloha import AlohaDataset
+from lerobot.common.datasets.pusht import PushtDataset
+
+from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
+from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+from lerobot.common.policies.tdmpc.policy import TDMPCPolicy


-# def test_available():
-#     pol_classes = [
-#         ActionChunkingTransformerPolicy,
-#         DiffusionPolicy,
-#         TDMPCPolicy,
-#     ]
-
-#     env_classes = [
-#         AlohaEnv,
-#         PushtEnv,
-#         SimxarmEnv,
-#     ]
-
-#     dat_classes = [
-#         AlohaDataset,
-#         PushtDataset,
-#         SimxarmDataset,
-#     ]
-
-#     policies = [pol_cls.name for pol_cls in pol_classes]
-#     assert set(policies) == set(lerobot.available_policies)
-
-#     envs = [env_cls.name for env_cls in env_classes]
-#     assert set(envs) == set(lerobot.available_envs)
-
-#     tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
-#     for env in envs:
-#         assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
-
-#     datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
-#     for env in envs:
-#         assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
+def test_available():
+    policy_classes = [
+        ActionChunkingTransformerPolicy,
+        DiffusionPolicy,
+        TDMPCPolicy,
+    ]
+
+    dataset_class_per_env = {
+        "aloha": AlohaDataset,
+        "pusht": PushtDataset,
+        "xarm": XarmDataset,
+    }
+
+    policies = [pol_cls.name for pol_cls in policy_classes]
+    assert set(policies) == set(lerobot.available_policies), policies
+
+    for env_name in lerobot.available_envs:
+        for task_name in lerobot.available_tasks_per_env[env_name]:
+            package_name = f"gym_{env_name}"
+            importlib.import_module(package_name)
+            gym_handle = f"{package_name}/{task_name}"
+            assert gym_handle in gym.envs.registry.keys(), gym_handle
+
+        dataset_class = dataset_class_per_env[env_name]
+        available_datasets = lerobot.available_datasets_per_env[env_name]
+        assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"

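Note: the handle the test asserts on, f"{package_name}/{task_name}", is the id gymnasium expects at construction time once the package's import has registered its environments. A hedged usage sketch (the concrete task id below is illustrative, not taken from this diff):

    import gymnasium as gym
    import gym_pusht  # noqa: F401  (importing the package registers its environments)

    # Handle format "gym_{env_name}/{task_name}", e.g. a hypothetical PushT task id:
    env = gym.make("gym_pusht/PushT-v0")
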
@@ -1,6 +1,12 @@
+import os
+from pathlib import Path
+import einops
 import pytest
 import torch

+from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_data_with_delta_timestamps
+from lerobot.common.datasets.xarm import XarmDataset
+from lerobot.common.transforms import Prod
 from lerobot.common.utils import init_hydra_config
 import logging
 from lerobot.common.datasets.factory import make_dataset

@@ -45,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name):
            keys_ndim_required.append(
                (key, 3, True),
            )
+           assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"

    # test number of dimensions
    for key, ndim, required in keys_ndim_required:

@@ -81,28 +88,104 @@ def test_factory(env_name, dataset_id, policy_name):
        assert key in item, f"{key}"


-# def test_compute_stats():
-#     """Check that the statistics are computed correctly according to the stats_patterns property.
-
-#     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
-#     because we are working with a small dataset).
-#     """
-#     cfg = init_hydra_config(
-#         DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
-#     )
-#     dataset = make_dataset(cfg)
-#     # Get all of the data.
-#     all_data = dataset.data_dict
-#     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
-#     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
-#     # dataset into even batches.
-#     computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
-#     for k, pattern in buffer.stats_patterns.items():
-#         expected_mean = einops.reduce(all_data[k], pattern, "mean")
-#         assert torch.allclose(computed_stats[k]["mean"], expected_mean)
-#         assert torch.allclose(
-#             computed_stats[k]["std"],
-#             torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
-#         )
-#         assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
-#         assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
+def test_compute_stats():
+    """Check that the statistics are computed correctly according to the stats_patterns property.
+
+    We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
+    because we are working with a small dataset).
+    """
+    DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+
+    # get transform to convert images from uint8 [0,255] to float32 [0,1]
+    transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
+
+    dataset = XarmDataset(
+        dataset_id="xarm_lift_medium",
+        root=DATA_DIR,
+        transform=transform,
+    )
+
+    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
+    # computation of the statistics. While doing this, we also make sure it works when we don't divide the
+    # dataset into even batches.
+    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
+
+    # get einops patterns to aggregate batches and compute statistics
+    stats_patterns = get_stats_einops_patterns(dataset)
+
+    # get all frames from the dataset in the same dtype and range as during compute_stats
+    data_dict = transform(dataset.data_dict)
+
+    # compute stats based on all frames from the dataset without any batching
+    expected_stats = {}
+    for k, pattern in stats_patterns.items():
+        expected_stats[k] = {}
+        expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean")
+        expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
+        expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min")
+        expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max")
+
+    # test computed stats match expected stats
+    for k in stats_patterns:
+        assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
+        assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
+        assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
+        assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
+
+    # TODO(rcadene): check that the stats used for training are correct too
+    # # load stats that are expected to match the ones returned by computed_stats
+    # assert (dataset.data_dir / "stats.pth").exists()
+    # loaded_stats = torch.load(dataset.data_dir / "stats.pth")
+
+    # # test loaded stats match expected stats
+    # for k in stats_patterns:
+    #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
+    #     assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
+    #     assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
+    #     assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
+
+
+def test_load_data_with_delta_timestamps_within_tolerance():
+    data_dict = {
+        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+        "index": torch.tensor([0, 1, 2, 3, 4]),
+    }
+    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+    delta_timestamps = {"index": [-0.2, 0, 0.139]}
+    key = "index"
+    current_ts = 0.3
+    episode = 0
+    tol = 0.04
+    data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+    assert not is_pad.any(), "Unexpected padding detected"
+    assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
+
+
+def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
+    data_dict = {
+        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+        "index": torch.tensor([0, 1, 2, 3, 4]),
+    }
+    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+    delta_timestamps = {"index": [-0.2, 0, 0.141]}
+    key = "index"
+    current_ts = 0.3
+    episode = 0
+    tol = 0.04
+    with pytest.raises(AssertionError):
+        load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+
+
+def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
+    data_dict = {
+        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+        "index": torch.tensor([0, 1, 2, 3, 4]),
+    }
+    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+    delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
+    key = "index"
+    current_ts = 0.3
+    episode = 0
+    tol = 0.04
+    data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
+    assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"

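Note: the three tests pin down the contract of load_data_with_delta_timestamps: each requested offset is matched to the nearest frame in the episode; a mismatch inside the episode's time range but beyond tol must raise, while queries falling off either end of the episode are clamped to the boundary frame and flagged in is_pad. A rough, self-contained approximation of that lookup (not the library's actual implementation) that reproduces the expected values above:

    import torch

    def nearest_with_padding(timestamps, values, deltas, current_ts, tol):
        # Query times relative to the current frame.
        query_ts = torch.tensor([current_ts + d for d in deltas])
        # L1 distance from every query to every frame timestamp.
        dist = torch.cdist(query_ts[:, None], timestamps[:, None], p=1)
        min_dist, argmin = dist.min(dim=1)
        outside_range = (query_ts < timestamps[0]) | (query_ts > timestamps[-1])
        # Padded = beyond tolerance AND outside the episode range; the boundary frame is reused.
        is_pad = (min_dist > tol) & outside_range
        # Beyond-tolerance queries *inside* the range indicate a sync problem and must fail.
        assert (min_dist[~outside_range] <= tol).all(), "timestamp too far from any frame inside episode range"
        return values[argmin], is_pad

    timestamps = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
    values = torch.tensor([0, 1, 2, 3, 4])
    data, is_pad = nearest_with_padding(timestamps, values, [-0.3, -0.24, 0, 0.26, 0.3], 0.3, 0.04)
    print(data)    # tensor([0, 0, 2, 4, 4])
    print(is_pad)  # tensor([ True, False, False,  True,  True])
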