Merge branch 'main' of github.com:huggingface/lerobot

commit c0166949ad
@@ -14,21 +14,9 @@ env:
 jobs:
   latest-cpu:
     name: CPU
-    runs-on: ubuntu-latest
+    runs-on:
+      group: aws-general-8-plus
     steps:
-      - name: Cleanup disk
-        run: |
-          sudo df -h
-          # sudo ls -l /usr/local/lib/
-          # sudo ls -l /usr/share/
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo rm -rf /usr/local/lib/android
-          sudo rm -rf /usr/share/dotnet
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo df -h
-
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
@@ -55,20 +43,9 @@ jobs:
 
   latest-cuda:
     name: GPU
-    runs-on: ubuntu-latest
+    runs-on:
+      group: aws-general-8-plus
     steps:
-      - name: Cleanup disk
-        run: |
-          sudo df -h
-          # sudo ls -l /usr/local/lib/
-          # sudo ls -l /usr/share/
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo rm -rf /usr/local/lib/android
-          sudo rm -rf /usr/share/dotnet
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo df -h
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
 
@@ -95,20 +72,9 @@ jobs:
 
   latest-cuda-dev:
     name: GPU Dev
-    runs-on: ubuntu-latest
+    runs-on:
+      group: aws-general-8-plus
     steps:
-      - name: Cleanup disk
-        run: |
-          sudo df -h
-          # sudo ls -l /usr/local/lib/
-          # sudo ls -l /usr/share/
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo rm -rf /usr/local/lib/android
-          sudo rm -rf /usr/share/dotnet
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo df -h
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
 
@@ -16,7 +16,8 @@ jobs:
     name: CPU
     strategy:
       fail-fast: false
-    runs-on: ubuntu-latest
+    runs-on:
+      group: aws-general-8-plus
     container:
       image: huggingface/lerobot-cpu:latest
       options: --shm-size "16gb"
@@ -43,7 +44,8 @@ jobs:
     name: GPU
     strategy:
       fail-fast: false
-    runs-on: [single-gpu, nvidia-gpu, t4, ci]
+    runs-on:
+      group: aws-g6-4xlarge-plus
     env:
       CUDA_VISIBLE_DEVICES: "0"
       TEST_TYPE: "single_gpu"
@@ -42,26 +42,14 @@ jobs:
   build_modified_dockerfiles:
     name: Build modified Docker images
     needs: get_changed_files
-    runs-on: ubuntu-latest
+    runs-on:
+      group: aws-general-8-plus
     if: ${{ needs.get_changed_files.outputs.matrix }} != ''
     strategy:
       fail-fast: false
       matrix:
         docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
     steps:
-      - name: Cleanup disk
-        run: |
-          sudo df -h
-          # sudo ls -l /usr/local/lib/
-          # sudo ls -l /usr/share/
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo rm -rf /usr/local/lib/android
-          sudo rm -rf /usr/share/dotnet
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo df -h
-
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
Makefile
@@ -26,6 +26,7 @@ test-end-to-end:
 	${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
 	${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
 	${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
+	${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online
 	${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
 	${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
 	${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
@@ -113,7 +114,6 @@ test-diffusion-ete-eval:
 		env.episode_length=8 \
 		device=$(DEVICE) \
 
-# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
 test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
 		policy=tdmpc \
@@ -133,6 +133,28 @@ test-tdmpc-ete-train:
 		training.image_transforms.enable=true \
 		hydra.run.dir=tests/outputs/tdmpc/
 
+test-tdmpc-ete-train-with-online:
+	python lerobot/scripts/train.py \
+		env=pusht \
+		env.gym.obs_type=environment_state_agent_pos \
+		policy=tdmpc_pusht_keypoints \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		env.episode_length=10 \
+		device=$(DEVICE) \
+		training.offline_steps=2 \
+		training.online_steps=20 \
+		training.save_checkpoint=false \
+		training.save_freq=10 \
+		training.batch_size=2 \
+		training.online_rollout_n_episodes=2 \
+		training.online_rollout_batch_size=2 \
+		training.online_steps_between_rollouts=10 \
+		training.online_buffer_capacity=15 \
+		eval.use_async_envs=true \
+		hydra.run.dir=tests/outputs/tdmpc_online/
+
 test-tdmpc-ete-eval:
 	python lerobot/scripts/eval.py \
 		-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
README.md
@@ -65,12 +65,14 @@
 
 Download our source code:
 ```bash
-git clone https://github.com/huggingface/lerobot.git && cd lerobot
+git clone https://github.com/huggingface/lerobot.git
+cd lerobot
 ```
 
 Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
 ```bash
-conda create -y -n lerobot python=3.10 && conda activate lerobot
+conda create -y -n lerobot python=3.10
+conda activate lerobot
 ```
 
 Install 🤗 LeRobot:
@@ -180,8 +182,10 @@ dataset attributes:
 │ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
 │ ...
 ├ info: a dictionary of metadata on the dataset
+│ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
 │ ├ fps (float): frame per second the dataset is recorded/synchronized to
-│ └ video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files
+│ ├ video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files
+│ └ encoding (dict): if video, this documents the main options that were used with ffmpeg to encode the videos
 ├ videos_dir (Path): where the mp4 videos or png images are stored/accessed
 └ camera_keys (list of string): the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
 ```
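Note: the two new `info` entries can be read straight off a loaded dataset. A minimal sketch, assuming network access and that the repo (here `lerobot/pusht`, just an example from the list above) has been re-pushed with this codebase version:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Inspect the metadata fields documented in the tree above.
dataset = LeRobotDataset("lerobot/pusht")
print(dataset.info["codebase_version"])  # e.g. "v1.6"
print(dataset.info["fps"])
if dataset.info["video"]:
    # New in this commit: the main ffmpeg options used at encoding time.
    print(dataset.info.get("encoding"))
```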
@@ -257,10 +257,10 @@ def benchmark_encoding_decoding(
         imgs_dir=imgs_dir,
         video_path=video_path,
         fps=fps,
-        video_codec=encoding_cfg["vcodec"],
-        pixel_format=encoding_cfg["pix_fmt"],
-        group_of_pictures_size=encoding_cfg.get("g"),
-        constant_rate_factor=encoding_cfg.get("crf"),
+        vcodec=encoding_cfg["vcodec"],
+        pix_fmt=encoding_cfg["pix_fmt"],
+        g=encoding_cfg.get("g"),
+        crf=encoding_cfg.get("crf"),
         # fast_decode=encoding_cfg.get("fastdecode"),
         overwrite=True,
     )
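The rename aligns the benchmark with the keyword names `encode_video_frames` actually takes (`vcodec`, `pix_fmt`, `g`, `crf`), which the format converters later in this commit also forward via `**(encoding or {})`. A sketch of a direct call; the option values mirror the defaults used elsewhere in this commit, and the paths are placeholders:

```python
from pathlib import Path

from lerobot.common.datasets.video_utils import encode_video_frames

# Encode a directory of frames into one mp4 with explicit ffmpeg options.
encode_video_frames(
    imgs_dir=Path("tmp/imgs"),
    video_path=Path("tmp/episode_000000.mp4"),
    fps=30,
    vcodec="libsvtav1",
    pix_fmt="yuv420p",
    g=2,
    crf=30,
    overwrite=True,
)
```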
@@ -125,6 +125,10 @@ available_real_world_datasets = [
     "lerobot/aloha_static_vinh_cup_left",
     "lerobot/aloha_static_ziploc_slide",
     "lerobot/umi_cup_in_the_wild",
+    "lerobot/unitreeh1_fold_clothes",
+    "lerobot/unitreeh1_rearrange_objects",
+    "lerobot/unitreeh1_two_robot_greeting",
+    "lerobot/unitreeh1_warehouse",
 ]
 
 available_datasets = list(
@@ -35,9 +35,8 @@ from lerobot.common.datasets.utils import (
 )
 from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
 
-# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/codebase_version.md
-CODEBASE_VERSION = "v1.5"
+# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
+CODEBASE_VERSION = "v1.6"
 
 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -0,0 +1,384 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""An online buffer for the online training loop in train.py
+
+Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
+consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
+faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
+supports in-place slicing and mutation which is very handy for a dynamic buffer.
+"""
+
+import os
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+
+def _make_memmap_safe(**kwargs) -> np.memmap:
+    """Make a numpy memmap with checks on available disk space first.
+
+    Expected kwargs are: "filename", "dtype" (must be np.dtype), "mode" and "shape"
+
+    For information on dtypes:
+    https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
+    """
+    if kwargs["mode"].startswith("w"):
+        required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"])  # bytes
+        stats = os.statvfs(Path(kwargs["filename"]).parent)
+        available_space = stats.f_bavail * stats.f_frsize  # bytes
+        if required_space >= available_space * 0.8:
+            raise RuntimeError(
+                f"You're about to take up {required_space} of {available_space} bytes available."
+            )
+    return np.memmap(**kwargs)
+
+
+class OnlineBuffer(torch.utils.data.Dataset):
+    """FIFO data buffer for the online training loop in train.py.
+
+    Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
+    loop in the same way that a LeRobotDataset would be used.
+
+    The underlying data structure will have data inserted in a circular fashion. Always insert after the
+    last index, and when you reach the end, wrap around to the start.
+
+    The data is stored in a numpy memmap.
+    """
+
+    NEXT_INDEX_KEY = "_next_index"
+    OCCUPANCY_MASK_KEY = "_occupancy_mask"
+    INDEX_KEY = "index"
+    FRAME_INDEX_KEY = "frame_index"
+    EPISODE_INDEX_KEY = "episode_index"
+    TIMESTAMP_KEY = "timestamp"
+    IS_PAD_POSTFIX = "_is_pad"
+
+    def __init__(
+        self,
+        write_dir: str | Path,
+        data_spec: dict[str, Any] | None,
+        buffer_capacity: int | None,
+        fps: float | None = None,
+        delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
+    ):
+        """
+        The online buffer can be provided from scratch or you can load an existing online buffer by passing
+        a `write_dir` associated with an existing buffer.
+
+        Args:
+            write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
+                Note that if the files already exist, they are opened in read-write mode (used for training
+                resumption.)
+            data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
+                "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
+                but note that "index", "frame_index" and "episode_index" are already accounted for by this
+                class, so you don't need to include them.
+            buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
+                system's available disk space when choosing this.
+            fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
+                delta_timestamps logic. You can pass None if you are not using delta_timestamps.
+            delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
+                converted to dict[str, np.ndarray] for optimization purposes.
+        """
+        self.set_delta_timestamps(delta_timestamps)
+        self._fps = fps
+        # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
+        # the requested frames. It is only used when `delta_timestamps` is provided.
+        # minus 1e-4 to account for possible numerical error
+        self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
+        self._buffer_capacity = buffer_capacity
+        data_spec = self._make_data_spec(data_spec, buffer_capacity)
+        Path(write_dir).mkdir(parents=True, exist_ok=True)
+        self._data = {}
+        for k, v in data_spec.items():
+            self._data[k] = _make_memmap_safe(
+                filename=Path(write_dir) / k,
+                dtype=v["dtype"] if v is not None else None,
+                mode="r+" if (Path(write_dir) / k).exists() else "w+",
+                shape=tuple(v["shape"]) if v is not None else None,
+            )
+
+    @property
+    def delta_timestamps(self) -> dict[str, np.ndarray] | None:
+        return self._delta_timestamps
+
+    def set_delta_timestamps(self, value: dict[str, list[float]] | None):
+        """Set delta_timestamps converting the values to numpy arrays.
+
+        The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
+        need to be converted into numpy arrays.
+        """
+        if value is not None:
+            self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
+        else:
+            self._delta_timestamps = None
+
+    def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
+        """Makes the data spec for np.memmap."""
+        if any(k.startswith("_") for k in data_spec):
+            raise ValueError(
+                "data_spec keys should not start with '_'. This prefix is reserved for internal logic."
+            )
+        preset_keys = {
+            OnlineBuffer.INDEX_KEY,
+            OnlineBuffer.FRAME_INDEX_KEY,
+            OnlineBuffer.EPISODE_INDEX_KEY,
+            OnlineBuffer.TIMESTAMP_KEY,
+        }
+        if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
+            raise ValueError(
+                f"data_spec should not contain any of {preset_keys} as these are handled internally. "
+                f"The provided data_spec has {intersection}."
+            )
+        complete_data_spec = {
+            # _next_index will be a pointer to the next index that we should start filling from when we add
+            # more data.
+            OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
+            # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
+            # with real data rather than the dummy initialization.
+            OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
+            OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
+            OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
+            OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
+            OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
+        }
+        for k, v in data_spec.items():
+            complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
+        return complete_data_spec
+
+    def add_data(self, data: dict[str, np.ndarray]):
+        """Add new data to the buffer, which could potentially mean shifting old data out.
+
+        The new data should contain all the frames (in order) of any number of episodes. The indices should
+        start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
+        `eval_policy` functions in `eval.py` for more information on how the data is constructed.
+
+        Shift the incoming data index and episode_index to continue on from the last frame. Note that this
+        will be done in place!
+        """
+        if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
+            raise ValueError(f"Missing data keys: {missing_keys}")
+        new_data_length = len(data[self.data_keys[0]])
+        if not all(len(data[k]) == new_data_length for k in self.data_keys):
+            raise ValueError("All data items should have the same length")
+
+        next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
+
+        # Sanity check to make sure that the new data indices start from 0.
+        assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
+        assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
+
+        # Shift the incoming indices if necessary.
+        if self.num_samples > 0:
+            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
+            last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
+            data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
+            data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
+
+        # Insert the new data starting from next_index. It may be necessary to wrap around to the start.
+        n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
+        for k in self.data_keys:
+            if n_surplus == 0:
+                slc = slice(next_index, next_index + new_data_length)
+                self._data[k][slc] = data[k]
+                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
+            else:
+                self._data[k][next_index:] = data[k][:-n_surplus]
+                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
+                self._data[k][:n_surplus] = data[k][-n_surplus:]
+        if n_surplus == 0:
+            self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
+        else:
+            self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
+
+    @property
+    def data_keys(self) -> list[str]:
+        keys = set(self._data)
+        keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
+        keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
+        return sorted(keys)
+
+    @property
+    def fps(self) -> float | None:
+        return self._fps
+
+    @property
+    def num_episodes(self) -> int:
+        return len(
+            np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
+        )
+
+    @property
+    def num_samples(self) -> int:
+        return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
+
+    def __len__(self):
+        return self.num_samples
+
+    def _item_to_tensors(self, item: dict) -> dict:
+        item_ = {}
+        for k, v in item.items():
+            if isinstance(v, torch.Tensor):
+                item_[k] = v
+            elif isinstance(v, np.ndarray):
+                item_[k] = torch.from_numpy(v)
+            else:
+                item_[k] = torch.tensor(v)
+        return item_
+
+    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+        if idx >= len(self) or idx < -len(self):
+            raise IndexError
+
+        item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
+
+        if self.delta_timestamps is None:
+            return self._item_to_tensors(item)
+
+        episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
+        current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
+        episode_data_indices = np.where(
+            np.bitwise_and(
+                self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
+                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
+            )
+        )[0]
+        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
+
+        for data_key in self.delta_timestamps:
+            # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
+            # Get timestamps used as query to retrieve data of previous/future frames.
+            query_ts = current_ts + self.delta_timestamps[data_key]
+
+            # Compute distances between each query timestamp and all timestamps of all the frames belonging to
+            # the episode.
+            dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
+            argmin_ = np.argmin(dist, axis=1)
+            min_ = dist[np.arange(dist.shape[0]), argmin_]
+
+            is_pad = min_ > self.tolerance_s
+
+            # Check violated query timestamps are all outside the episode range.
+            assert (
+                (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
+            ).all(), (
+                f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
+                ") inside the episode range."
+            )
+
+            # Load frames for this data key.
+            item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
+
+            item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
+
+        return self._item_to_tensors(item)
+
+    def get_data_by_key(self, key: str) -> torch.Tensor:
+        """Returns all data for a given data key as a Tensor."""
+        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
+
+
+def compute_sampler_weights(
+    offline_dataset: LeRobotDataset,
+    offline_drop_n_last_frames: int = 0,
+    online_dataset: OnlineBuffer | None = None,
+    online_sampling_ratio: float | None = None,
+    online_drop_n_last_frames: int = 0,
+) -> torch.Tensor:
+    """Compute the sampling weights for the online training dataloader in train.py.
+
+    Args:
+        offline_dataset: The LeRobotDataset used for offline pre-training.
+        offline_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
+        online_dataset: The OnlineBuffer used in online training.
+        online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
+            online dataset is provided, this value must also be provided.
+        online_drop_n_last_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
+            dataset.
+    Returns:
+        Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
+
+    Notes to maintainers:
+        - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
+        - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
+          `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
+          is the ability to turn shuffling off.
+        - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
+          included here to avoid adding complexity.
+    """
+    if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
+        raise ValueError("At least one of `offline_dataset` or `online_dataset` should contain data.")
+    if (online_dataset is None) ^ (online_sampling_ratio is None):
+        raise ValueError(
+            "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
+        )
+    offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
+
+    weights = []
+
+    if len(offline_dataset) > 0:
+        offline_data_mask_indices = []
+        for start_index, end_index in zip(
+            offline_dataset.episode_data_index["from"],
+            offline_dataset.episode_data_index["to"],
+            strict=True,
+        ):
+            offline_data_mask_indices.extend(
+                range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
+            )
+        offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
+        offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
+        weights.append(
+            torch.full(
+                size=(len(offline_dataset),),
+                fill_value=offline_sampling_ratio / offline_data_mask.sum(),
+            )
+            * offline_data_mask
+        )
+
+    if online_dataset is not None and len(online_dataset) > 0:
+        online_data_mask_indices = []
+        episode_indices = online_dataset.get_data_by_key("episode_index")
+        for episode_idx in torch.unique(episode_indices):
+            where_episode = torch.where(episode_indices == episode_idx)
+            start_index = where_episode[0][0]
+            end_index = where_episode[0][-1] + 1
+            online_data_mask_indices.extend(
+                range(start_index.item(), end_index.item() - online_drop_n_last_frames)
+            )
+        online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
+        online_data_mask[torch.tensor(online_data_mask_indices)] = True
+        weights.append(
+            torch.full(
+                size=(len(online_dataset),),
+                fill_value=online_sampling_ratio / online_data_mask.sum(),
+            )
+            * online_data_mask
+        )
+
+    weights = torch.cat(weights)
+
+    if weights.sum() == 0:
+        weights += 1 / len(weights)
+    else:
+        weights /= weights.sum()
+
+    return weights
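Taken together, the buffer and the weight helper slot into the online loop roughly as follows. A minimal sketch, assuming the new module lands at `lerobot/common/datasets/online_buffer.py` (the diff does not show the file path) and using a toy single-key `data_spec`; none of the wiring below is part of the diff itself:

```python
import numpy as np
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights

# Toy spec: one 2-D state vector per frame (hypothetical shape/dtype).
buffer = OnlineBuffer(
    write_dir="tmp/online_buffer",
    data_spec={"observation.state": {"shape": (2,), "dtype": np.dtype("float32")}},
    buffer_capacity=1000,
    fps=10,
)

# One 5-frame episode; episode/data indices start from 0 as `add_data` requires.
buffer.add_data(
    {
        "observation.state": np.zeros((5, 2), dtype=np.float32),
        "index": np.arange(5),
        "frame_index": np.arange(5),
        "episode_index": np.zeros(5, dtype=np.int64),
        "timestamp": np.arange(5) / 10,
    }
)

# Mix offline and online data 50/50 in the dataloader, as the docstring suggests.
offline_dataset = LeRobotDataset("lerobot/pusht")  # assumed to be available
weights = compute_sampler_weights(
    offline_dataset,
    online_dataset=buffer,
    online_sampling_ratio=0.5,
)
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=64, replacement=True)
```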
@@ -10,7 +10,8 @@ For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) h
 - [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
 - [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
 - [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
-- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
+- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
+- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
 - [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
 
 Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
@@ -45,13 +46,11 @@ for repo_id in available_datasets:
     dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
     branches = [b.name for b in dataset_info.branches]
     if CODEBASE_VERSION in branches:
-        # First check if the newer version already exists.
-        print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
-        print("Exiting early")
-        break
+        print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
+        continue
     else:
         # Now create a branch named after the new version by branching out from "main"
         # which is expected to be the preceding version
         api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
-        print(f"{repo_id} successfully updated")
+        print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
 ```
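Because each codebase version lives on its own branch, an older snapshot of a dataset can be pinned by revision. A short sketch using `huggingface_hub` (already a dependency in this diff):

```python
from huggingface_hub import snapshot_download

# Download the dataset files as they were at a given codebase version;
# branch names follow CODEBASE_VERSION, per the list above.
local_path = snapshot_download(
    "lerobot/pusht",
    repo_type="dataset",
    revision="v1.5",  # any of the version branches listed above
)
print(local_path)
```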
@@ -19,8 +19,8 @@ This file contains download scripts for raw datasets.
 Example of usage:
 ```
 python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
---raw-dir data/cadene/pusht_raw \
---repo-id cadene/pusht_raw
+--raw-dir data/lerobot-raw/pusht_raw \
+--repo-id lerobot-raw/pusht_raw
 ```
 """
@@ -31,63 +31,65 @@ from pathlib import Path
 
 from huggingface_hub import snapshot_download
 
-AVAILABLE_RAW_REPO_IDS = [
-    "lerobot-raw/aloha_mobile_cabinet_raw",
-    "lerobot-raw/aloha_mobile_chair_raw",
-    "lerobot-raw/aloha_mobile_elevator_raw",
-    "lerobot-raw/aloha_mobile_shrimp_raw",
-    "lerobot-raw/aloha_mobile_wash_pan_raw",
-    "lerobot-raw/aloha_mobile_wipe_wine_raw",
-    "lerobot-raw/aloha_sim_insertion_human_raw",
-    "lerobot-raw/aloha_sim_insertion_scripted_raw",
-    "lerobot-raw/aloha_sim_transfer_cube_human_raw",
-    "lerobot-raw/aloha_sim_transfer_cube_scripted_raw",
-    "lerobot-raw/aloha_static_battery_raw",
-    "lerobot-raw/aloha_static_candy_raw",
-    "lerobot-raw/aloha_static_coffee_new_raw",
-    "lerobot-raw/aloha_static_coffee_raw",
-    "lerobot-raw/aloha_static_cups_open_raw",
-    "lerobot-raw/aloha_static_fork_pick_up_raw",
-    "lerobot-raw/aloha_static_pingpong_test_raw",
-    "lerobot-raw/aloha_static_pro_pencil_raw",
-    "lerobot-raw/aloha_static_screw_driver_raw",
-    "lerobot-raw/aloha_static_tape_raw",
-    "lerobot-raw/aloha_static_thread_velcro_raw",
-    "lerobot-raw/aloha_static_towel_raw",
-    "lerobot-raw/aloha_static_vinh_cup_left_raw",
-    "lerobot-raw/aloha_static_vinh_cup_raw",
-    "lerobot-raw/aloha_static_ziploc_slide_raw",
-    "lerobot-raw/pusht_raw",
-    "lerobot-raw/umi_cup_in_the_wild_raw",
-    "lerobot-raw/unitreeh1_fold_clothes_raw",
-    "lerobot-raw/unitreeh1_rearrange_objects_raw",
-    "lerobot-raw/unitreeh1_two_robot_greeting_raw",
-    "lerobot-raw/unitreeh1_warehouse_raw",
-    "lerobot-raw/xarm_lift_medium_raw",
-    "lerobot-raw/xarm_lift_medium_replay_raw",
-    "lerobot-raw/xarm_push_medium_raw",
-    "lerobot-raw/xarm_push_medium_replay_raw",
-]
+from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
+
+# {raw_repo_id: raw_format}
+AVAILABLE_RAW_REPO_IDS = {
+    "lerobot-raw/aloha_mobile_cabinet_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_mobile_chair_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_mobile_elevator_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_mobile_shrimp_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_mobile_wash_pan_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_mobile_wipe_wine_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_sim_insertion_human_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_sim_insertion_scripted_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_sim_transfer_cube_human_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_sim_transfer_cube_scripted_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_battery_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_candy_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_coffee_new_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_coffee_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_cups_open_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_fork_pick_up_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_pingpong_test_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_pro_pencil_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_screw_driver_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_tape_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_thread_velcro_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_towel_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_vinh_cup_left_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_vinh_cup_raw": "aloha_hdf5",
+    "lerobot-raw/aloha_static_ziploc_slide_raw": "aloha_hdf5",
+    "lerobot-raw/pusht_raw": "pusht_zarr",
+    "lerobot-raw/umi_cup_in_the_wild_raw": "umi_zarr",
+    "lerobot-raw/unitreeh1_fold_clothes_raw": "aloha_hdf5",
+    "lerobot-raw/unitreeh1_rearrange_objects_raw": "aloha_hdf5",
+    "lerobot-raw/unitreeh1_two_robot_greeting_raw": "aloha_hdf5",
+    "lerobot-raw/unitreeh1_warehouse_raw": "aloha_hdf5",
+    "lerobot-raw/xarm_lift_medium_raw": "xarm_pkl",
+    "lerobot-raw/xarm_lift_medium_replay_raw": "xarm_pkl",
+    "lerobot-raw/xarm_push_medium_raw": "xarm_pkl",
+    "lerobot-raw/xarm_push_medium_replay_raw": "xarm_pkl",
+}
 
 
 def download_raw(raw_dir: Path, repo_id: str):
-    # Check repo_id is well formated
-    if len(repo_id.split("/")) != 2:
-        raise ValueError(
-            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
-        )
+    check_repo_id(repo_id)
     user_id, dataset_id = repo_id.split("/")
 
     if not dataset_id.endswith("_raw"):
         warnings.warn(
-            f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
+            f"""`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this
+            naming convention by renaming your repository is advised, but not mandatory.""",
             stacklevel=1,
         )
 
     # Send warning if raw_dir isn't well formated
     if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
         warnings.warn(
-            f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
+            f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that
+            match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised,
+            but not mandatory.""",
            stacklevel=1,
         )
     raw_dir.mkdir(parents=True, exist_ok=True)
@@ -97,7 +99,8 @@ def download_raw(raw_dir: Path, repo_id: str):
     logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
 
 
-def download_all_raw_datasets():
-    data_dir = Path("data")
+def download_all_raw_datasets(data_dir: Path | None = None):
+    if data_dir is None:
+        data_dir = Path("data")
     for repo_id in AVAILABLE_RAW_REPO_IDS:
         raw_dir = data_dir / repo_id
@@ -106,7 +109,8 @@ def download_all_raw_datasets():
 
 def main():
     parser = argparse.ArgumentParser(
-        description=f"A script to download raw datasets from Hugging Face hub to a local directory. Here is a non exhaustive list of available repositories to use in `--repo-id`: {AVAILABLE_RAW_REPO_IDS}",
+        description=f"""A script to download raw datasets from Hugging Face hub to a local directory. Here is a
+        non exhaustive list of available repositories to use in `--repo-id`: {AVAILABLE_RAW_REPO_IDS}""",
     )
 
     parser.add_argument(
@@ -119,7 +123,8 @@ def main():
         "--repo-id",
         type=str,
         required=True,
-        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
+        help="""Repositery identifier on Hugging Face: a community or a user name `/` the name of
+        the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).""",
     )
     args = parser.parse_args()
     download_raw(**vars(args))
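The diff imports `check_repo_id` from `push_dataset_to_hub/utils.py` without showing its body; it presumably factors out the inline validation removed above. A sketch consistent with that removed code (not necessarily the actual implementation):

```python
def check_repo_id(repo_id: str) -> None:
    """Raise if `repo_id` is not of the form '<user_or_community>/<dataset>'."""
    if len(repo_id.split("/")) != 2:
        raise ValueError(
            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset "
            f"(e.g. 'lerobot/pusht'), but contains '{repo_id}'."
        )
```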
@@ -0,0 +1,184 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Use this script to batch encode lerobot dataset from their raw format to LeRobotDataset and push their updated
+version to the hub. Under the hood, this script reuses 'push_dataset_to_hub.py'. It assumes that you already
+downloaded raw datasets, which you can do with the related '_download_raw.py' script.
+
+For instance, for codebase_version = 'v1.6', the following command was run, assuming raw datasets from
+lerobot-raw were downloaded in 'raw/datasets/directory':
+```bash
+python lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py \
+--raw-dir raw/datasets/directory \
+--raw-repo-ids lerobot-raw \
+--local-dir push/datasets/directory \
+--tests-data-dir tests/data \
+--push-repo lerobot \
+--vcodec libsvtav1 \
+--pix-fmt yuv420p \
+--g 2 \
+--crf 30
+```
+"""
+
+import argparse
+from pathlib import Path
+
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
+from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
+from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
+from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
+
+
+def get_push_repo_id_from_raw(raw_repo_id: str, push_repo: str) -> str:
+    dataset_id_raw = raw_repo_id.split("/")[1]
+    dataset_id = dataset_id_raw.removesuffix("_raw")
+    return f"{push_repo}/{dataset_id}"
+
+
+def encode_datasets(
+    raw_dir: Path,
+    raw_repo_ids: list[str],
+    push_repo: str,
+    vcodec: str,
+    pix_fmt: str,
+    g: int,
+    crf: int,
+    local_dir: Path | None = None,
+    tests_data_dir: Path | None = None,
+    raw_format: str | None = None,
+    dry_run: bool = False,
+) -> None:
+    if len(raw_repo_ids) == 1 and raw_repo_ids[0].lower() == "lerobot-raw":
+        raw_repo_ids_format = AVAILABLE_RAW_REPO_IDS
+    else:
+        if raw_format is None:
+            raise ValueError(raw_format)
+        raw_repo_ids_format = {id_: raw_format for id_ in raw_repo_ids}
+
+    for raw_repo_id, repo_raw_format in raw_repo_ids_format.items():
+        check_repo_id(raw_repo_id)
+        dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
+        dataset_raw_dir = raw_dir / raw_repo_id
+        dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
+        encoding = {
+            "vcodec": vcodec,
+            "pix_fmt": pix_fmt,
+            "g": g,
+            "crf": crf,
+        }
+
+        if not (dataset_raw_dir).is_dir():
+            raise NotADirectoryError(dataset_raw_dir)
+
+        if not dry_run:
+            push_dataset_to_hub(
+                dataset_raw_dir,
+                raw_format=repo_raw_format,
+                repo_id=dataset_repo_id_push,
+                local_dir=dataset_dir,
+                resume=True,
+                encoding=encoding,
+                tests_data_dir=tests_data_dir,
+            )
+        else:
+            print(
+                f"DRY RUN: {dataset_raw_dir} --> {dataset_dir} --> {dataset_repo_id_push}@{CODEBASE_VERSION}"
+            )
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--raw-dir",
+        type=Path,
+        default=Path("data"),
+        help="Directory where raw datasets are located.",
+    )
+    parser.add_argument(
+        "--raw-repo-ids",
+        type=str,
+        nargs="*",
+        default=["lerobot-raw"],
+        help="""Raw dataset repo ids. if 'lerobot-raw', the keys from `AVAILABLE_RAW_REPO_IDS` will be
+        used and raw datasets will be fetched from the 'lerobot-raw/' repo and pushed with their
+        associated format. It is assumed that each dataset is located at `raw_dir / raw_repo_id` """,
+    )
+    parser.add_argument(
+        "--raw-format",
+        type=str,
+        default=None,
+        help="""Raw format to use for the raw repo-ids. Must be specified if --raw-repo-ids is not
+        'lerobot-raw'""",
+    )
+    parser.add_argument(
+        "--local-dir",
+        type=Path,
+        default=None,
+        help="""When provided, writes the dataset converted to LeRobotDataset format in this directory
+        (e.g. `data/lerobot/aloha_mobile_chair`).""",
+    )
+    parser.add_argument(
+        "--push-repo",
+        type=str,
+        default="lerobot",
+        help="Repo to upload datasets to",
+    )
+    parser.add_argument(
+        "--vcodec",
+        type=str,
+        default="libsvtav1",
+        help="Codec to use for encoding videos",
+    )
+    parser.add_argument(
+        "--pix-fmt",
+        type=str,
+        default="yuv420p",
+        help="Pixel formats (chroma subsampling) to be used for encoding",
+    )
+    parser.add_argument(
+        "--g",
+        type=int,
+        default=2,
+        help="Group of pictures sizes to be used for encoding.",
+    )
+    parser.add_argument(
+        "--crf",
+        type=int,
+        default=30,
+        help="Constant rate factors to be used for encoding.",
+    )
+    parser.add_argument(
+        "--tests-data-dir",
+        type=Path,
+        default=None,
+        help=(
+            "When provided, save tests artifacts into the given directory "
+            "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
+        ),
+    )
+    parser.add_argument(
+        "--dry-run",
+        type=int,
+        default=0,
+        help="If not set to 0, this script won't download or upload anything.",
+    )
+    args = parser.parse_args()
+    encode_datasets(**vars(args))
+
+
+if __name__ == "__main__":
+    main()
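Besides the CLI shown in the docstring, `encode_datasets` can be driven directly from Python. A sketch re-encoding a single raw repo; the paths are placeholders, and the raw dataset is assumed to have been downloaded to `data/lerobot-raw/pusht_raw` first (the directory check runs even in dry-run mode):

```python
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub._encode_datasets import encode_datasets

encode_datasets(
    raw_dir=Path("data"),
    raw_repo_ids=["lerobot-raw/pusht_raw"],
    raw_format="pusht_zarr",  # required whenever --raw-repo-ids is not 'lerobot-raw'
    push_repo="lerobot",
    vcodec="libsvtav1",
    pix_fmt="yuv420p",
    g=2,
    crf=30,
    dry_run=True,  # print what would happen without downloading or uploading
)
```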
@@ -29,7 +29,11 @@ from datasets import Dataset, Features, Image, Sequence, Value
 from PIL import Image as PILImage
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
+from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    concatenate_episodes,
+    get_default_encoding,
+    save_images_concurrently,
+)
 from lerobot.common.datasets.utils import (
     calculate_episode_data_index,
     hf_transform_to_torch,
@@ -72,7 +76,14 @@ def check_format(raw_dir) -> bool:
     assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
 
 
-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def load_from_raw(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int,
+    video: bool,
+    episodes: list[int] | None = None,
+    encoding: dict | None = None,
+):
     # only frames from simulation are uncompressed
     compressed_images = "sim" not in raw_dir.name
 
@@ -123,7 +134,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
                 # encode images to a mp4 video
                 fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                 video_path = videos_dir / fname
-                encode_video_frames(tmp_imgs_dir, video_path, fps)
+                encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
 
                 # clean temporary images directory
                 shutil.rmtree(tmp_imgs_dir)
@@ -200,6 +211,7 @@ def from_raw_to_lerobot_format(
     fps: int | None = None,
     video: bool = True,
     episodes: list[int] | None = None,
+    encoding: dict | None = None,
 ):
     # sanity check
     check_format(raw_dir)
@@ -207,7 +219,7 @@ def from_raw_to_lerobot_format(
     if fps is None:
         fps = 50
 
-    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
@@ -215,4 +227,7 @@ def from_raw_to_lerobot_format(
         "fps": fps,
         "video": video,
     }
+    if video:
+        info["encoding"] = get_default_encoding()
+
     return hf_dataset, episode_data_index, info
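`get_default_encoding` is imported from `push_dataset_to_hub/utils.py` but its body is outside this diff. A plausible sketch, assuming it simply reports the default encoding kwargs of `encode_video_frames`; the key filter list is an assumption, chosen to match the CLI defaults in `_encode_datasets.py` above:

```python
import inspect

from lerobot.common.datasets.video_utils import encode_video_frames


def get_default_encoding() -> dict:
    """Sketch: collect the default encoding kwargs of `encode_video_frames`."""
    signature = inspect.signature(encode_video_frames)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
    }
```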
@@ -81,8 +81,9 @@ def from_raw_to_lerobot_format(
     fps: int | None = None,
     video: bool = True,
     episodes: list[int] | None = None,
+    encoding: dict | None = None,
 ):
-    if video or episodes is not None:
+    if video or episodes or encoding is not None:
         # TODO(aliberts): support this
         raise NotImplementedError
@@ -18,6 +18,7 @@ Contains utilities to process raw data format from dora-record
 """
 
 import re
+import warnings
 from pathlib import Path
 
 import pandas as pd
@@ -199,6 +200,7 @@ def from_raw_to_lerobot_format(
     fps: int | None = None,
     video: bool = True,
     episodes: list[int] | None = None,
+    encoding: dict | None = None,
 ):
     # sanity check
     check_format(raw_dir)
@@ -211,6 +213,12 @@ def from_raw_to_lerobot_format(
     if not video:
         raise NotImplementedError()
 
+    if encoding is not None:
+        warnings.warn(
+            "Video encoding is currently done outside of LeRobot for the dora_parquet format.",
+            stacklevel=1,
+        )
+
     data_df = load_from_raw(raw_dir, videos_dir, fps, episodes)
     hf_dataset = to_hf_dataset(data_df, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
@@ -219,4 +227,7 @@ def from_raw_to_lerobot_format(
         "fps": fps,
         "video": video,
     }
+    if video:
+        info["encoding"] = "unknown"
+
     return hf_dataset, episode_data_index, info
@@ -26,7 +26,11 @@ from datasets import Dataset, Features, Image, Sequence, Value
 from PIL import Image as PILImage
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
+from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    concatenate_episodes,
+    get_default_encoding,
+    save_images_concurrently,
+)
 from lerobot.common.datasets.utils import (
     calculate_episode_data_index,
     hf_transform_to_torch,
@@ -62,6 +66,7 @@ def load_from_raw(
     video: bool,
     episodes: list[int] | None = None,
     keypoints_instead_of_image: bool = False,
+    encoding: dict | None = None,
 ):
     try:
         import pymunk
@@ -172,7 +177,7 @@ def load_from_raw(
             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
             video_path = videos_dir / fname
-            encode_video_frames(tmp_imgs_dir, video_path, fps)
+            encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
 
             # clean temporary images directory
             shutil.rmtree(tmp_imgs_dir)
@@ -244,6 +249,7 @@ def from_raw_to_lerobot_format(
     fps: int | None = None,
     video: bool = True,
     episodes: list[int] | None = None,
+    encoding: dict | None = None,
 ):
     # Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
     # with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
@@ -255,7 +261,7 @@ def from_raw_to_lerobot_format(
     if fps is None:
         fps = 10
 
-    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
     hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
@@ -263,4 +269,7 @@ def from_raw_to_lerobot_format(
         "fps": fps,
         "video": video if not keypoints_instead_of_image else 0,
     }
+    if video:
+        info["encoding"] = get_default_encoding()
+
     return hf_dataset, episode_data_index, info
@@ -27,7 +27,11 @@ from PIL import Image as PILImage
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
 from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
-from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
+from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    concatenate_episodes,
+    get_default_encoding,
+    save_images_concurrently,
+)
 from lerobot.common.datasets.utils import (
     calculate_episode_data_index,
     hf_transform_to_torch,
@@ -60,7 +64,14 @@ def check_format(raw_dir) -> bool:
     assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
 
 
-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def load_from_raw(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int,
+    video: bool,
+    episodes: list[int] | None = None,
+    encoding: dict | None = None,
+):
     zarr_path = raw_dir / "cup_in_the_wild.zarr"
     zarr_data = zarr.open(zarr_path, mode="r")
 
@@ -88,9 +99,14 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         to_ids.append(to_idx)
         from_idx = to_idx
 
+    ep_dicts_dir = videos_dir / "ep_dicts"
+    ep_dicts_dir.mkdir(exist_ok=True, parents=True)
     ep_dicts = []
 
     ep_ids = episodes if episodes else range(num_episodes)
     for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        ep_dict_path = ep_dicts_dir / f"{ep_idx}"
+        if not ep_dict_path.is_file():
             from_idx = from_ids[selected_ep_idx]
             to_idx = to_ids[selected_ep_idx]
             num_frames = to_idx - from_idx
@@ -105,20 +121,23 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
             imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
             img_key = "observation.image"
             if video:
+                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
+                video_path = videos_dir / fname
+                if not video_path.is_file():
                     # save png images in temporary directory
                     tmp_imgs_dir = videos_dir / "tmp_images"
                     save_images_concurrently(imgs_array, tmp_imgs_dir)
 
                     # encode images to a mp4 video
-                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-                video_path = videos_dir / fname
-                encode_video_frames(tmp_imgs_dir, video_path, fps)
+                    encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
 
                     # clean temporary images directory
                     shutil.rmtree(tmp_imgs_dir)
 
                 # store the reference to the video frame
-                ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
+                ep_dict[img_key] = [
+                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+                ]
             else:
                 ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
 
@@ -131,6 +150,10 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
             ep_dict["end_pose"] = end_pose[from_idx:to_idx]
             ep_dict["start_pos"] = start_pos[from_idx:to_idx]
             ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
+            torch.save(ep_dict, ep_dict_path)
+        else:
+            ep_dict = torch.load(ep_dict_path)
+
         ep_dicts.append(ep_dict)
 
     data_dict = concatenate_episodes(ep_dicts)
@@ -183,6 +206,7 @@ def from_raw_to_lerobot_format(
     fps: int | None = None,
     video: bool = True,
     episodes: list[int] | None = None,
+    encoding: dict | None = None,
 ):
     # sanity check
     check_format(raw_dir)
@@ -196,7 +220,7 @@ def from_raw_to_lerobot_format(
             "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
         )
 
-    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
@@ -204,4 +228,7 @@ def from_raw_to_lerobot_format(
         "fps": fps,
         "video": video,
     }
+    if video:
+        info["encoding"] = get_default_encoding()
+
     return hf_dataset, episode_data_index, info
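The UMI converter now checkpoints its work: each converted episode dict is saved under `videos_dir / "ep_dicts"` and reloaded on a rerun, and already-encoded videos are skipped via the `video_path.is_file()` guard. A stripped-down sketch of that resume pattern, with the per-episode conversion reduced to a placeholder:

    from pathlib import Path

    import torch


    def convert_episodes(videos_dir: Path, num_episodes: int) -> list[dict]:
        ep_dicts_dir = videos_dir / "ep_dicts"
        ep_dicts_dir.mkdir(exist_ok=True, parents=True)
        ep_dicts = []
        for ep_idx in range(num_episodes):
            ep_dict_path = ep_dicts_dir / f"{ep_idx}"
            if not ep_dict_path.is_file():
                # Placeholder for the real (expensive) episode conversion.
                ep_dict = {"episode_index": torch.tensor([ep_idx])}
                torch.save(ep_dict, ep_dict_path)
            else:
                # An interrupted run picks up where it left off.
                ep_dict = torch.load(ep_dict_path)
            ep_dicts.append(ep_dict)
        return ep_dicts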
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 
@@ -20,6 +21,8 @@ import numpy
 import PIL
 import torch
 
+from lerobot.common.datasets.video_utils import encode_video_frames
+
 
 def concatenate_episodes(ep_dicts):
     data_dict = {}
@@ -51,3 +54,21 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
     num_images = len(imgs_array)
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
+
+
+def get_default_encoding() -> dict:
+    """Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
+    signature = inspect.signature(encode_video_frames)
+    return {
+        k: v.default
+        for k, v in signature.parameters.items()
+        if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
+    }
+
+
+def check_repo_id(repo_id: str) -> None:
+    if len(repo_id.split("/")) != 2:
+        raise ValueError(
+            f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
+            (e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
+        )
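`get_default_encoding` reads the defaults straight off the signature of `encode_video_frames`, so the `info["encoding"]` metadata written by the converters can never drift from what the encoder actually defaults to. The same introspection pattern in isolation (the `encode` function below is a stand-in, not the real encoder):

    import inspect


    def encode(vcodec: str = "libsvtav1", pix_fmt: str = "yuv420p", g: int | None = 2, crf: int | None = 30):
        """Stand-in with the same keyword defaults as the real encoder."""


    def defaults_of(fn) -> dict:
        # Keep only keywords that have an explicit default and are on the allow-list.
        signature = inspect.signature(fn)
        return {
            k: v.default
            for k, v in signature.parameters.items()
            if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
        }


    print(defaults_of(encode))
    # {'vcodec': 'libsvtav1', 'pix_fmt': 'yuv420p', 'g': 2, 'crf': 30}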
@@ -26,7 +26,11 @@ from datasets import Dataset, Features, Image, Sequence, Value
 from PIL import Image as PILImage
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
+from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    concatenate_episodes,
+    get_default_encoding,
+    save_images_concurrently,
+)
 from lerobot.common.datasets.utils import (
     calculate_episode_data_index,
     hf_transform_to_torch,
@@ -56,7 +60,14 @@ def check_format(raw_dir):
     assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
 
 
-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def load_from_raw(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int,
+    video: bool,
+    episodes: list[int] | None = None,
+    encoding: dict | None = None,
+):
     pkl_path = raw_dir / "buffer.pkl"
 
     with open(pkl_path, "rb") as f:
@@ -105,7 +116,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
             video_path = videos_dir / fname
-            encode_video_frames(tmp_imgs_dir, video_path, fps)
+            encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
 
             # clean temporary images directory
             shutil.rmtree(tmp_imgs_dir)
@@ -167,6 +178,7 @@ def from_raw_to_lerobot_format(
     fps: int | None = None,
     video: bool = True,
     episodes: list[int] | None = None,
+    encoding: dict | None = None,
 ):
     # sanity check
     check_format(raw_dir)
@@ -174,7 +186,7 @@ def from_raw_to_lerobot_format(
     if fps is None:
         fps = 15
 
-    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
@@ -182,4 +194,7 @@ def from_raw_to_lerobot_format(
         "fps": fps,
         "video": video,
     }
+    if video:
+        info["encoding"] = get_default_encoding()
+
     return hf_dataset, episode_data_index, info
@@ -166,10 +166,10 @@ def encode_video_frames(
     imgs_dir: Path,
     video_path: Path,
     fps: int,
-    video_codec: str = "libsvtav1",
-    pixel_format: str = "yuv420p",
-    group_of_pictures_size: int | None = 2,
-    constant_rate_factor: int | None = 30,
+    vcodec: str = "libsvtav1",
+    pix_fmt: str = "yuv420p",
+    g: int | None = 2,
+    crf: int | None = 30,
     fast_decode: int = 0,
     log_level: str | None = "error",
     overwrite: bool = False,
@@ -183,20 +183,20 @@ def encode_video_frames(
             ("-f", "image2"),
             ("-r", str(fps)),
             ("-i", str(imgs_dir / "frame_%06d.png")),
-            ("-vcodec", video_codec),
-            ("-pix_fmt", pixel_format),
+            ("-vcodec", vcodec),
+            ("-pix_fmt", pix_fmt),
         ]
     )
 
-    if group_of_pictures_size is not None:
-        ffmpeg_args["-g"] = str(group_of_pictures_size)
+    if g is not None:
+        ffmpeg_args["-g"] = str(g)
 
-    if constant_rate_factor is not None:
-        ffmpeg_args["-crf"] = str(constant_rate_factor)
+    if crf is not None:
+        ffmpeg_args["-crf"] = str(crf)
 
     if fast_decode:
-        key = "-svtav1-params" if video_codec == "libsvtav1" else "-tune"
-        value = f"fast-decode={fast_decode}" if video_codec == "libsvtav1" else "fastdecode"
+        key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
+        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
         ffmpeg_args[key] = value
 
     if log_level is not None:
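The parameter renames (`video_codec` -> `vcodec`, `pixel_format` -> `pix_fmt`, `group_of_pictures_size` -> `g`, `constant_rate_factor` -> `crf`) make the Python keywords identical to the raw ffmpeg flag names, which is what lets `get_default_encoding` and the `**(encoding or {})` call sites round-trip the settings verbatim. A rough sketch of how such keywords map onto a command line (the argument assembly here is simplified relative to the real function):

    def ffmpeg_video_args(fps: int, vcodec: str = "libsvtav1", pix_fmt: str = "yuv420p",
                          g: int | None = 2, crf: int | None = 30) -> list[str]:
        # Keyword name == flag name, so "-{name} {value}" can be emitted mechanically.
        args = {"-r": str(fps), "-vcodec": vcodec, "-pix_fmt": pix_fmt}
        if g is not None:
            args["-g"] = str(g)
        if crf is not None:
            args["-crf"] = str(crf)
        return [part for pair in args.items() for part in pair]


    encoding = {"crf": 22}  # as passed down from from_raw_to_lerobot_format(..., encoding=...)
    print(ffmpeg_video_args(30, **(encoding or {})))
    # ['-r', '30', '-vcodec', 'libsvtav1', '-pix_fmt', 'yuv420p', '-g', '2', '-crf', '22']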
@@ -101,6 +101,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
 
         batch = self.normalize_inputs(batch)
         if len(self.expected_image_keys) > 0:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
 
         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
@@ -128,6 +129,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if len(self.expected_image_keys) > 0:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -467,10 +469,9 @@ class ACT(nn.Module):
         if self.use_images:
             all_cam_features = []
             all_cam_pos_embeds = []
-            images = batch["observation.images"]
 
-            for cam_index in range(images.shape[-4]):
-                cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            for cam_index in range(batch["observation.images"].shape[-4]):
+                cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
                 # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
                 # buffer
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
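Both ACT call sites now take `dict(batch)` before writing the stacked key. Stacking along `dim=-4` inserts a camera axis just before the (channel, height, width) dims, and the shallow copy keeps the new `"observation.images"` key out of the caller's dict. A small shape check, assuming (batch, channel, height, width) tensors per camera (the key names below are illustrative):

    import torch

    keys = ["observation.images.top", "observation.images.wrist"]
    batch = {k: torch.zeros(8, 3, 96, 96) for k in keys}

    copy = dict(batch)  # shallow copy: shares tensors, but new keys stay local
    copy["observation.images"] = torch.stack([copy[k] for k in keys], dim=-4)

    print(copy["observation.images"].shape)  # torch.Size([8, 2, 3, 96, 96])
    print("observation.images" in batch)     # False: the caller's dict is untouched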
@@ -111,17 +111,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         Schematically this looks like:
             ----------------------------------------------------------------------------------------------
             (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
-            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... |n-o+1+h|
-            |observation is used | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    | NO    |
+            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
+            |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
             |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
             |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
             ----------------------------------------------------------------------------------------------
-        Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
+        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
         "horizon" may not the best name to describe what the variable actually means, because this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
         batch = self.normalize_inputs(batch)
         if len(self.expected_image_keys) > 0:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)
@@ -143,6 +144,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if len(self.expected_image_keys) > 0:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
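The docstring fix relaxes the bound to `n_action_steps <= horizon - n_obs_steps + 1` and marks timestep n as observed. A quick check of the window arithmetic, with illustrative values (o = 2, h = 16 are not taken from this diff):

    n_obs_steps, horizon, n_action_steps = 2, 16, 15

    # Timesteps covered by one plan, indexed relative to the current step n:
    window = set(range(-(n_obs_steps - 1), horizon - n_obs_steps + 1))  # n-o+1 ... n-o+h
    used_actions = set(range(n_action_steps))                            # n ... n+a-1

    # The boundary case a == h - o + 1 (here 15) is now allowed ("<=" rather than "<"):
    assert used_actions <= window
    assert not set(range(n_action_steps + 1)) <= window  # a == 16 would fall outside the plan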
@@ -132,6 +132,7 @@ class Normalize(nn.Module):
     # TODO(rcadene): should we remove torch.no_grad?
     @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, mode in self.modes.items():
             buffer = getattr(self, "buffer_" + key.replace(".", "_"))
 
@@ -197,6 +198,7 @@ class Unnormalize(nn.Module):
     # TODO(rcadene): should we remove torch.no_grad?
     @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        batch = dict(batch)  # shallow copy avoids mutating the input batch
        for key, mode in self.modes.items():
             buffer = getattr(self, "buffer_" + key.replace(".", "_"))
 
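Copying only the dict is sufficient here because `forward` rebinds `batch[key]` to freshly computed tensors rather than writing into them in place (assuming statements of the form `batch[key] = (batch[key] - mean) / std`, as in these modules). The distinction in a few lines:

    import torch

    original = {"observation.state": torch.ones(2)}

    batch = dict(original)  # shallow copy: tensors are shared
    batch["observation.state"] = batch["observation.state"] * 2.0  # rebinding -> new tensor
    print(original["observation.state"])  # tensor([1., 1.]): the caller is unaffected

    batch = dict(original)
    batch["observation.state"].mul_(2.0)  # an in-place write WOULD leak through the copy
    print(original["observation.state"])  # tensor([2., 2.])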
@@ -25,12 +25,16 @@ class TDMPCConfig:
             camera observations.
 
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
-    Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`.
+    Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
 
     Args:
         n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
             action repeats in Q-learning or ask your favorite chatbot)
         horizon: Horizon for model predictive control.
+        n_action_steps: Number of action steps to take from the plan given by model predictive control. This
+            is an alternative to using action repeats. If this is set to more than 1, then we require
+            `n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
+            approach of using multiple steps from the plan is not in the original implementation.
         input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
             the input data name, and the value is a list indicating the dimensions of the corresponding data.
             For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
@@ -100,6 +104,7 @@ class TDMPCConfig:
     # Input / output structure.
     n_action_repeats: int = 2
     horizon: int = 5
+    n_action_steps: int = 1
 
     input_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
@@ -158,10 +163,11 @@ class TDMPCConfig:
         """Input validation (not exhaustive)."""
         # There should only be one image key.
         image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
-        if len(image_keys) != 1:
+        if len(image_keys) > 1:
             raise ValueError(
-                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+                f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
             )
+        if len(image_keys) > 0:
             image_key = next(iter(image_keys))
             if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
                 # TODO(alexander-soare): This limitation is solely because of code in the random shift
@@ -179,3 +185,12 @@ class TDMPCConfig:
                 f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
                 "information."
             )
+        if self.n_action_steps > 1:
+            if self.n_action_repeats != 1:
+                raise ValueError(
+                    "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
+                )
+            if not self.use_mpc:
+                raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
+            if self.n_action_steps > self.horizon:
+                raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
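The added `__post_init__` checks make the new `n_action_steps > 1` mode fail fast when combined with incompatible options. A reduced stand-in showing the same validation (the real `TDMPCConfig` has many more fields):

    from dataclasses import dataclass


    @dataclass
    class MiniTDMPCConfig:
        n_action_repeats: int = 2
        horizon: int = 5
        n_action_steps: int = 1
        use_mpc: bool = False

        def __post_init__(self):
            if self.n_action_steps > 1:
                if self.n_action_repeats != 1:
                    raise ValueError("If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1.")
                if not self.use_mpc:
                    raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
                if self.n_action_steps > self.horizon:
                    raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")


    MiniTDMPCConfig(n_action_repeats=1, n_action_steps=5, use_mpc=True)  # ok: 5 <= horizon
    try:
        MiniTDMPCConfig(n_action_steps=5, use_mpc=True)  # n_action_repeats is still 2
    except ValueError as err:
        print(err)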
@@ -19,14 +19,10 @@
 The comments in this code may sometimes refer to these references:
     TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
     FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
-
-TODO(alexander-soare): Make rollout work for batch sizes larger than 1.
-TODO(alexander-soare): Use batch-first throughout.
 """
 
 # ruff: noqa: N806
 
-import logging
 from collections import deque
 from copy import deepcopy
 from functools import partial
@@ -56,7 +52,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
       process communication to use the xarm environment from FOWM. This is because our xarm
       environment uses newer dependencies and does not match the environment in FOWM. See
       https://github.com/huggingface/lerobot/pull/103 for implementation details.
-    - We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
+    - We have NOT checked that training on LeRobot reproduces the results from FOWM.
+    - Nevertheless, we have verified that we can train TD-MPC for PushT. See
+      `lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`.
     - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
       match our xarm environment.
     """
@@ -74,22 +72,6 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             that they will be passed with a call to `load_state_dict` before the policy is used.
         """
         super().__init__()
-        logging.warning(
-            """
-            Please note several warnings for this policy.
-
-            - Evaluation of pretrained weights created with the original FOWM code
-                (https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
-                model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
-                to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
-                process communication to use the xarm environment from FOWM. This is because our xarm
-                environment uses newer dependencies and does not match the environment in FOWM. See
-                https://github.com/huggingface/lerobot/pull/103 for implementation details.
-            - We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
-            - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
-                match our xarm environment.
-            """
-        )
 
         if config is None:
             config = TDMPCConfig()
@@ -114,8 +96,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
 
         image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
         # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        self._use_image = False
+        self._use_env_state = False
+        if len(image_keys) > 0:
             assert len(image_keys) == 1
+            self._use_image = True
             self.input_image_key = image_keys[0]
+        if "observation.environment_state" in config.input_shapes:
+            self._use_env_state = True
 
         self.reset()
 
@@ -125,10 +113,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         called on `env.reset()`
         """
         self._queues = {
-            "observation.image": deque(maxlen=1),
             "observation.state": deque(maxlen=1),
-            "action": deque(maxlen=self.config.n_action_repeats),
+            "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
         }
+        if self._use_image:
+            self._queues["observation.image"] = deque(maxlen=1)
+        if self._use_env_state:
+            self._queues["observation.environment_state"] = deque(maxlen=1)
         # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None
@@ -137,6 +128,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations."""
         batch = self.normalize_inputs(batch)
+        if self._use_image:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.image"] = batch[self.input_image_key]
 
         self._queues = populate_queues(self._queues, batch)
@@ -151,49 +144,57 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 batch[key] = batch[key][:, 0]
 
             # NOTE: Order of observations matters here.
-            z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]})
-            if self.config.use_mpc:
-                batch_size = batch["observation.image"].shape[0]
-                # Batch processing is not handled in MPC mode, so process the batch in a loop.
-                action = []  # will be a batch of actions for one step
-                for i in range(batch_size):
-                    # Note: self.plan does not handle batches, hence the squeeze.
-                    action.append(self.plan(z[i]))
-                action = torch.stack(action)
+            encode_keys = []
+            if self._use_image:
+                encode_keys.append("observation.image")
+            if self._use_env_state:
+                encode_keys.append("observation.environment_state")
+            encode_keys.append("observation.state")
+            z = self.model.encode({k: batch[k] for k in encode_keys})
+            if self.config.use_mpc:  # noqa: SIM108
+                actions = self.plan(z)  # (horizon, batch, action_dim)
             else:
-                # Plan with the policy (π) alone.
-                action = self.model.pi(z)
+                # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
+                # sequence dimension like in the MPC branch.
+                actions = self.model.pi(z).unsqueeze(0)
 
-            self.unnormalize_outputs({"action": action})["action"]
+            actions = torch.clamp(actions, -1, +1)
 
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+
+            if self.config.n_action_repeats > 1:
                 for _ in range(self.config.n_action_repeats):
-                    self._queues["action"].append(action)
+                    self._queues["action"].append(actions[0])
+            else:
+                # Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
+                self._queues["action"].extend(actions[: self.config.n_action_steps])
 
         action = self._queues["action"].popleft()
-        return torch.clamp(action, -1, 1)
+        return action
 
     @torch.no_grad()
     def plan(self, z: Tensor) -> Tensor:
-        """Plan next action using TD-MPC inference.
+        """Plan sequence of actions using TD-MPC inference.
 
         Args:
-            z: (latent_dim,) tensor for the initial state.
+            z: (batch, latent_dim,) tensor for the initial state.
         Returns:
-            (action_dim,) tensor for the next action.
-
-        TODO(alexander-soare) Extend this to be able to work with batches.
+            (horizon, batch, action_dim,) tensor for the planned trajectory of actions.
         """
         device = get_device_from_parameters(self)
 
+        batch_size = z.shape[0]
+
         # Sample Nπ trajectories from the policy.
         pi_actions = torch.empty(
             self.config.horizon,
             self.config.n_pi_samples,
+            batch_size,
            self.config.output_shapes["action"][0],
            device=device,
        )
        if self.config.n_pi_samples > 0:
-            _z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples)
+            _z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
             for t in range(self.config.horizon):
                 # Note: Adding a small amount of noise here doesn't hurt during inference and may even be
                 # helpful for CEM.
@@ -202,12 +203,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
 
         # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
         # trajectories.
-        z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
+        z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
 
         # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
         # algorithm.
         # The initial mean and standard deviation for the cross-entropy method (CEM).
-        mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device)
+        mean = torch.zeros(
+            self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
+        )
         # Maybe warm start CEM with the mean from the previous step.
         if self._prev_mean is not None:
             mean[:-1] = self._prev_mean[1:]
@@ -218,6 +221,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             std_normal_noise = torch.randn(
                 self.config.horizon,
                 self.config.n_gaussian_samples,
+                batch_size,
                 self.config.output_shapes["action"][0],
                 device=std.device,
             )
@@ -226,21 +230,24 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             # Compute elite actions.
             actions = torch.cat([gaussian_actions, pi_actions], dim=1)
             value = self.estimate_value(z, actions).nan_to_num_(0)
-            elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices
-            elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
+            elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices  # (n_elites, batch)
+            elite_value = value.take_along_dim(elite_idxs, dim=0)  # (n_elites, batch)
+            # (horizon, n_elites, batch, action_dim)
+            elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
 
-            # Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites.
-            max_value = elite_value.max(0)[0]
+            # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
+            max_value = elite_value.max(0, keepdim=True)[0]  # (1, batch)
             # The weighting is a softmax over trajectory values. Note that this is not the same as the usage
             # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
             # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
             score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
-            score /= score.sum()
-            _mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1)
+            score /= score.sum(axis=0, keepdim=True)
+            # (horizon, batch, action_dim)
+            _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
             _std = torch.sqrt(
                 torch.sum(
-                    einops.rearrange(score, "n -> n 1")
-                    * (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2,
+                    einops.rearrange(score, "n b -> n b 1")
+                    * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
                     dim=1,
                 )
             )
@@ -255,11 +262,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
 
         # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
        # scores from the last iteration.
-        actions = elite_actions[:, torch.multinomial(score, 1).item()]
+        actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
 
-        # Select only the first action
-        action = actions[0]
-        return action
+        return actions
 
     @torch.no_grad()
     def estimate_value(self, z: Tensor, actions: Tensor):
@@ -311,11 +316,16 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
         return G
 
-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Run the batch through the model and compute the loss."""
+    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
+        """Run the batch through the model and compute the loss.
+
+        Returns a dictionary with loss as a tensor, and other information as native floats.
+        """
         device = get_device_from_parameters(self)
 
         batch = self.normalize_inputs(batch)
+        if self._use_image:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)
 
@@ -326,12 +336,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             if batch[key].ndim > 1:
                 batch[key] = batch[key].transpose(1, 0)
 
-        action = batch["action"]  # (t, b)
-        reward = batch["next.reward"]  # (t,)
+        action = batch["action"]  # (t, b, action_dim)
+        reward = batch["next.reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
         # Apply random image augmentations.
-        if self.config.max_random_shift_ratio > 0:
+        if self._use_image and self.config.max_random_shift_ratio > 0:
             observations["observation.image"] = flatten_forward_unflatten(
                 partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
                 observations["observation.image"],
@@ -343,7 +353,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         for k in observations:
             current_observation[k] = observations[k][0]
             next_observations[k] = observations[k][1:]
-        horizon = next_observations["observation.image"].shape[0]
+        horizon, batch_size = next_observations[
+            "observation.image" if self._use_image else "observation.environment_state"
+        ].shape[:2]
 
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
@@ -413,7 +425,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         q_value_loss = (
             (
-                F.mse_loss(
+                temporal_loss_coeffs
+                * F.mse_loss(
                     q_preds_ensemble,
                     einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
                     reduction="none",
@@ -462,10 +475,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         action_preds = self.model.pi(z_preds[:-1])  # (t, b, a)
         # Calculate the MSE between the actions and the action predictions.
         # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
-        # gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
-        # the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
-        # as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
-        # parameter for it (see below where we compute the total loss).
+        # gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
+        # multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
+        # dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
+        # the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
+        # loss).
         mse = F.mse_loss(action_preds, action, reduction="none").sum(-1)  # (t, b)
         # NOTE: The original implementation does not take the sum over the temporal dimension like with the
         # other losses.
@@ -726,6 +740,16 @@ class TDMPCObservationEncoder(nn.Module):
                 nn.LayerNorm(config.latent_dim),
                 nn.Sigmoid(),
             )
+        if "observation.environment_state" in config.input_shapes:
+            self.env_state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
+                ),
+                nn.ELU(),
+                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
+                nn.LayerNorm(config.latent_dim),
+                nn.Sigmoid(),
+            )
 
     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
@@ -734,8 +758,11 @@ class TDMPCObservationEncoder(nn.Module):
         over all features.
         """
         feat = []
+        # NOTE: Order of observations matters here.
         if "observation.image" in self.config.input_shapes:
             feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
+        if "observation.environment_state" in self.config.input_shapes:
+            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes:
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))
         return torch.stack(feat, dim=0).mean(0)
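The core of the batched CEM rewrite is swapping fancy indexing for `take_along_dim`, so each batch element keeps its own set of elites. Shape bookkeeping under the diff's conventions (`value` is (n_samples, batch); `actions` is (horizon, n_samples, batch, action_dim); the sizes below are arbitrary):

    import torch

    horizon, n_samples, batch_size, action_dim, n_elites = 5, 8, 3, 4, 2
    value = torch.randn(n_samples, batch_size)
    actions = torch.randn(horizon, n_samples, batch_size, action_dim)

    elite_idxs = torch.topk(value, n_elites, dim=0).indices   # (n_elites, batch)
    elite_value = value.take_along_dim(elite_idxs, dim=0)      # (n_elites, batch)
    # Broadcast the per-batch sample indices across horizon and action_dim.
    elite_actions = actions.take_along_dim(elite_idxs[None, :, :, None], dim=1)
    print(elite_actions.shape)  # torch.Size([5, 2, 3, 4])

    # The score normalization likewise needs explicit dims to stay per-batch-column:
    score = torch.softmax(elite_value, dim=0)
    assert torch.allclose(score.sum(dim=0), torch.ones(batch_size))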
@@ -98,6 +98,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         """
 
         batch = self.normalize_inputs(batch)
+        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
         batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)
@@ -123,6 +124,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
         batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
         # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
@@ -32,19 +32,54 @@ video_backend: pyav
 training:
   offline_steps: ???
-  # NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
-  online_steps: ???
-  online_steps_between_rollouts: ???
-  online_sampling_ratio: 0.5
-  # `online_env_seed` is used for environments for online training data rollouts.
-  online_env_seed: ???
+  # Number of workers for the offline training dataloader.
+  num_workers: 4
+  batch_size: ???
   eval_freq: ???
   log_freq: 200
   save_checkpoint: true
   # Checkpoint is saved every `save_freq` training iterations and after the last training step.
   save_freq: ???
-  num_workers: 4
-  batch_size: ???
+
+  # Online training. Note that the online training loop adopts most of the options above apart from the
+  # dataloader options. Unless otherwise specified.
+  # The online training loop looks something like:
+  #
+  # for i in range(online_steps):
+  #     do_online_rollout_and_update_online_buffer()
+  #     for j in range(online_steps_between_rollouts):
+  #         batch = next(dataloader_with_offline_and_online_data)
+  #         loss = policy(batch)
+  #         loss.backward()
+  #         optimizer.step()
+  #
+  online_steps: ???
+  # How many episodes to collect at once when we reach the online rollout part of the training loop.
+  online_rollout_n_episodes: 1
+  # The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
+  # the policy. Ideally you should set this to be an even divisor of online_rollout_n_episodes.
+  online_rollout_batch_size: 1
+  # How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
+  online_steps_between_rollouts: null
+  # The proportion of online samples (vs offline samples) to include in the online training batches.
+  online_sampling_ratio: 0.5
+  # First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
+  online_env_seed: null
+  # Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
+  # FIFO.
+  online_buffer_capacity: null
+  # The minimum number of frames to have in the online buffer before commencing online training.
+  # If online_buffer_seed_size > online_rollout_n_episodes, the rollout will be run multiple times until the
+  # seed size condition is satisfied.
+  online_buffer_seed_size: 0
+  # Whether to run the online rollouts asynchronously. This means we can run the online training steps in
+  # parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
+  # + eval + environment rendering simultaneously.
+  do_online_rollout_async: false
 
 image_transforms:
   # These transforms are all using standard torchvision.transforms.v2
   # You can find out how these transformations affect images here:
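For intuition on how `online_sampling_ratio` can be enforced, here is a rough sketch of turning a fixed ratio into per-sample weights for a weighted sampler over concatenated offline and online datasets. This is my own illustration of the idea; the commit's actual helper is `compute_sampler_weights`, whose call sites appear in the train.py hunks further down.

```python
import torch

def make_sampler_weights(n_offline: int, n_online: int, online_sampling_ratio: float) -> torch.Tensor:
    """Weights for WeightedRandomSampler over [offline | online] samples.

    Each offline sample gets (1 - ratio) / n_offline and each online sample
    ratio / n_online, so batches contain online samples in the given proportion.
    """
    offline_weights = torch.full((n_offline,), (1 - online_sampling_ratio) / n_offline)
    online_weights = torch.full((n_online,), online_sampling_ratio / n_online)
    return torch.cat([offline_weights, online_weights])

weights = make_sampler_weights(n_offline=1000, n_online=250, online_sampling_ratio=0.5)
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
```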
@@ -9,7 +9,7 @@ env:
   state_dim: 4
   action_dim: 4
   fps: ${fps}
-  episode_length: 25
+  episode_length: 200
   gym:
     obs_type: pixels_agent_pos
     render_mode: rgb_array
@@ -4,19 +4,30 @@ seed: 1
 dataset_repo_id: lerobot/xarm_lift_medium
 
 training:
-  offline_steps: 25000
-  # TODO(alexander-soare): uncomment when online training gets reinstated
-  online_steps: 0 # 25000 not implemented yet
-  eval_freq: 5000
-  online_steps_between_rollouts: 1
-  online_sampling_ratio: 0.5
-  online_env_seed: 10000
-  log_freq: 100
+  offline_steps: 50000
+  num_workers: 4
 
   batch_size: 256
   grad_clip_norm: 10.0
   lr: 3e-4
 
+  eval_freq: 5000
+  log_freq: 100
+
+  online_steps: 50000
+  online_rollout_n_episodes: 1
+  online_rollout_batch_size: 1
+  # Note: in FOWM `online_steps_between_rollouts` is actually dynamically set to match exactly the length of
+  # the last sampled episode.
+  online_steps_between_rollouts: 50
+  online_sampling_ratio: 0.5
+  online_env_seed: 10000
+  # FOWM Push uses 10000 for `online_buffer_capacity`. Given that their maximum episode length for this task
+  # is 25, 10000 is approx 400 of their episodes worth. Since our episodes are about 8 times longer, we'll use
+  # 80000.
+  online_buffer_capacity: 80000
 
 delta_timestamps:
   observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
   observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
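The `online_buffer_capacity` comment above compresses a small calculation; spelled out, using the episode_length of 200 set in the env hunk above:

```python
# FOWM's buffer: 10_000 frames at 25 frames per episode -> ~400 episodes worth.
episodes_worth = 10_000 // 25  # 400
# Our episodes run ~8x longer (episode_length: 200), so the same episode count needs:
capacity = episodes_worth * 200
assert capacity == 80_000  # matches online_buffer_capacity above
```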
@@ -31,6 +42,7 @@ policy:
   # Input / output structure.
   n_action_repeats: 2
   horizon: 5
+  n_action_steps: 1
 
   input_shapes:
     # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
@@ -0,0 +1,105 @@
+# @package _global_
+
+# Train with:
+#
+# python lerobot/scripts/train.py \
+#   env=pusht \
+#   env.gym.obs_type=environment_state_agent_pos \
+#   policy=tdmpc_pusht_keypoints \
+#   eval.batch_size=50 \
+#   eval.n_episodes=50 \
+#   eval.use_async_envs=true \
+#   device=cuda \
+#   use_amp=true
+
+seed: 1
+dataset_repo_id: lerobot/pusht_keypoints
+
+training:
+  offline_steps: 0
+
+  # Offline training dataloader
+  num_workers: 4
+
+  batch_size: 256
+  grad_clip_norm: 10.0
+  lr: 3e-4
+
+  eval_freq: 10000
+  log_freq: 500
+  save_freq: 50000
+
+  online_steps: 1000000
+  online_rollout_n_episodes: 10
+  online_rollout_batch_size: 10
+  online_steps_between_rollouts: 1000
+  online_sampling_ratio: 1.0
+  online_env_seed: 10000
+  online_buffer_capacity: 40000
+  online_buffer_seed_size: 0
+  do_online_rollout_async: false
+
+  delta_timestamps:
+    observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+    observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+    action: "[i / ${fps} for i in range(${policy.horizon})]"
+    next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
+
+policy:
+  name: tdmpc
+
+  pretrained_model_path:
+
+  # Input / output structure.
+  n_action_repeats: 1
+  horizon: 5
+  n_action_steps: 5
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.environment_state: [16]
+    observation.state: ["${env.state_dim}"]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.environment_state: min_max
+    observation.state: min_max
+  output_normalization_modes:
+    action: min_max
+
+  # Architecture / modeling.
+  # Neural networks.
+  image_encoder_hidden_dim: 32
+  state_encoder_hidden_dim: 256
+  latent_dim: 50
+  q_ensemble_size: 5
+  mlp_dim: 512
+  # Reinforcement learning.
+  discount: 0.98
+
+  # Inference.
+  use_mpc: true
+  cem_iterations: 6
+  max_std: 2.0
+  min_std: 0.05
+  n_gaussian_samples: 512
+  n_pi_samples: 51
+  uncertainty_regularizer_coeff: 1.0
+  n_elites: 50
+  elite_weighting_temperature: 0.5
+  gaussian_mean_momentum: 0.1
+
+  # Training and loss computation.
+  max_random_shift_ratio: 0.0476
+  # Loss coefficients.
+  reward_coeff: 0.5
+  expectile_weight: 0.9
+  value_coeff: 0.1
+  consistency_coeff: 20.0
+  advantage_scaling: 3.0
+  pi_coeff: 0.5
+  temporal_decay_coeff: 0.5
+  # Target model.
+  target_model_momentum: 0.995
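To make the inference parameters above less opaque, here is a minimal sketch of the CEM-style elite weighting that settings such as `n_gaussian_samples`, `n_elites`, and `elite_weighting_temperature` control in TD-MPC planning. This illustrates the general technique only; it is not the repo's implementation, and the exact temperature convention may differ.

```python
import torch

def cem_step(scores: torch.Tensor, samples: torch.Tensor, n_elites: int, temperature: float):
    """One CEM update: keep the top-scoring action samples and softmax-weight them.

    scores: (n_samples,) estimated returns; samples: (n_samples, action_dim).
    Returns the weighted mean/std used to refit the sampling Gaussian.
    """
    elite_scores, elite_idx = scores.topk(n_elites)
    elites = samples[elite_idx]
    # With softmax(scores / temperature), a higher temperature flattens the weighting.
    weights = torch.softmax(elite_scores / temperature, dim=0)
    mean = (weights[:, None] * elites).sum(dim=0)
    std = ((weights[:, None] * (elites - mean) ** 2).sum(dim=0)).sqrt()
    return mean, std

scores = torch.randn(512)      # cf. n_gaussian_samples: 512
samples = torch.randn(512, 4)  # cf. env.action_dim: 4
mean, std = cem_step(scores, samples, n_elites=50, temperature=0.5)
```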
@@ -101,7 +101,7 @@ from termcolor import colored
 from lerobot.common.datasets.compute_stats import compute_stats
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
-from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
+from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
 from lerobot.common.datasets.utils import calculate_episode_data_index
 from lerobot.common.datasets.video_utils import encode_video_frames
 from lerobot.common.policies.factory import make_policy

@@ -479,6 +479,8 @@ def record_dataset(
         "fps": fps,
         "video": video,
     }
+    if video:
+        info["encoding"] = get_default_encoding()
 
     lerobot_dataset = LeRobotDataset.from_preloaded(
         repo_id=repo_id,

@@ -56,16 +56,13 @@ import einops
 import gymnasium as gym
 import numpy as np
 import torch
-from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from huggingface_hub.utils._validators import HFValidationError
-from PIL import Image as PILImage
 from torch import Tensor, nn
 from tqdm import trange
 
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.utils import hf_transform_to_torch
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.logger import log_output_dir
@@ -318,41 +315,17 @@ def eval_policy(
                 rollout_data,
                 done_indices,
                 start_episode_index=batch_ix * env.num_envs,
-                start_data_index=(
-                    0 if episode_data is None else (episode_data["episode_data_index"]["to"][-1].item())
-                ),
+                start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
                 fps=env.unwrapped.metadata["render_fps"],
             )
             if episode_data is None:
                 episode_data = this_episode_data
             else:
-                # Some sanity checks to make sure we are not correctly compiling the data.
-                assert (
-                    episode_data["hf_dataset"]["episode_index"][-1] + 1
-                    == this_episode_data["hf_dataset"]["episode_index"][0]
-                )
-                assert (
-                    episode_data["hf_dataset"]["index"][-1] + 1 == this_episode_data["hf_dataset"]["index"][0]
-                )
-                assert torch.equal(
-                    episode_data["episode_data_index"]["to"][-1],
-                    this_episode_data["episode_data_index"]["from"][0],
-                )
+                # Some sanity checks to make sure we are correctly compiling the data.
+                assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
+                assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
                 # Concatenate the episode data.
-                episode_data = {
-                    "hf_dataset": concatenate_datasets(
-                        [episode_data["hf_dataset"], this_episode_data["hf_dataset"]]
-                    ),
-                    "episode_data_index": {
-                        k: torch.cat(
-                            [
-                                episode_data["episode_data_index"][k],
-                                this_episode_data["episode_data_index"][k],
-                            ]
-                        )
-                        for k in ["from", "to"]
-                    },
-                }
+                episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
 
         # Maybe render video for visualization.
         if max_episodes_rendered > 0 and len(ep_frames) > 0:
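The rewritten concatenation above works because episode data is now a flat dict of tensors rather than a nested `hf_dataset`/`episode_data_index` structure, so merging two batches of episodes reduces to a per-key `torch.cat`. A toy example of the pattern:

```python
import torch

a = {"index": torch.tensor([0, 1, 2]), "episode_index": torch.tensor([0, 0, 0])}
b = {"index": torch.tensor([3, 4]), "episode_index": torch.tensor([1, 1])}

# Same one-liner shape as the new eval.py code: concatenate per key.
merged = {k: torch.cat([a[k], b[k]]) for k in a}
assert merged["index"].tolist() == [0, 1, 2, 3, 4]
```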
@@ -434,89 +407,39 @@ def _compile_episode_data(
     Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
     """
     ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
     total_frames = 0
-    data_index_from = start_data_index
     for ep_ix in range(rollout_data["action"].shape[0]):
-        num_frames = done_indices[ep_ix].item() + 1  # + 1 to include the first done frame
+        # + 2 to include the first done frame and the last observation frame.
+        num_frames = done_indices[ep_ix].item() + 2
         total_frames += num_frames
 
-        # TODO(rcadene): We need to add a missing last frame which is the observation
-        # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
+        # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
         ep_dict = {
-            "action": rollout_data["action"][ep_ix, :num_frames],
-            "episode_index": torch.tensor([start_episode_index + ep_ix] * num_frames),
-            "frame_index": torch.arange(0, num_frames, 1),
-            "timestamp": torch.arange(0, num_frames, 1) / fps,
-            "next.done": rollout_data["done"][ep_ix, :num_frames],
-            "next.reward": rollout_data["reward"][ep_ix, :num_frames].type(torch.float32),
+            "action": rollout_data["action"][ep_ix, : num_frames - 1],
+            "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
+            "frame_index": torch.arange(0, num_frames - 1, 1),
+            "timestamp": torch.arange(0, num_frames - 1, 1) / fps,
+            "next.done": rollout_data["done"][ep_ix, : num_frames - 1],
+            "next.success": rollout_data["success"][ep_ix, : num_frames - 1],
+            "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
         }
 
+        # For the last observation frame, all other keys will just be copy padded.
+        for k in ep_dict:
+            ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
+
         for key in rollout_data["observation"]:
-            ep_dict[key] = rollout_data["observation"][key][ep_ix][:num_frames]
+            ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
 
         ep_dicts.append(ep_dict)
 
-        episode_data_index["from"].append(data_index_from)
-        episode_data_index["to"].append(data_index_from + num_frames)
-
-        data_index_from += num_frames
-
     data_dict = {}
     for key in ep_dicts[0]:
-        if "image" not in key:
-            data_dict[key] = torch.cat([x[key] for x in ep_dicts])
-        else:
-            if key not in data_dict:
-                data_dict[key] = []
-            for ep_dict in ep_dicts:
-                for img in ep_dict[key]:
-                    # sanity check that images are channel first
-                    c, h, w = img.shape
-                    assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
-
-                    # sanity check that images are float32 in range [0,1]
-                    assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
-                    assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
-                    assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
-
-                    # from float32 in range [0,1] to uint8 in range [0,255]
-                    img *= 255
-                    img = img.type(torch.uint8)
-
-                    # convert to channel last and numpy as expected by PIL
-                    img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
-
-                    data_dict[key].append(img)
+        data_dict[key] = torch.cat([x[key] for x in ep_dicts])
 
     data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
-    episode_data_index["from"] = torch.tensor(episode_data_index["from"])
-    episode_data_index["to"] = torch.tensor(episode_data_index["to"])
-
-    # TODO(rcadene): clean this
-    features = {}
-    for key in rollout_data["observation"]:
-        if "image" in key:
-            features[key] = Image()
-        else:
-            features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
-    features.update(
-        {
-            "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
-            "episode_index": Value(dtype="int64", id=None),
-            "frame_index": Value(dtype="int64", id=None),
-            "timestamp": Value(dtype="float32", id=None),
-            "next.reward": Value(dtype="float32", id=None),
-            "next.done": Value(dtype="bool", id=None),
-            #'next.success': Value(dtype='bool', id=None),
-            "index": Value(dtype="int64", id=None),
-        }
-    )
-    features = Features(features)
-    hf_dataset = Dataset.from_dict(data_dict, features=features)
-    hf_dataset.set_transform(hf_transform_to_torch)
-    return {
-        "hf_dataset": hf_dataset,
-        "episode_data_index": episode_data_index,
-    }
+    return data_dict
 
 
 def main(
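The copy-padding step in the hunk above is easy to miss: every non-observation key is built with `num_frames - 1` entries, then its last element is repeated once so the row count matches the `num_frames` observation frames (the extra frame being the final observation after the done step). In isolation:

```python
import torch

ep_dict = {
    "action": torch.arange(4, dtype=torch.float32),        # num_frames - 1 = 4 real actions
    "next.done": torch.tensor([False, False, False, True]),
}
# Repeat the last element of each key once; observation keys are filled separately
# over the full num_frames range, so only these keys carry a copied value.
for k in ep_dict:
    ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
assert len(ep_dict["action"]) == 5  # now matches num_frames
```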
@@ -55,6 +55,7 @@ from safetensors.torch import save_file
 
 from lerobot.common.datasets.compute_stats import compute_stats
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
 from lerobot.common.datasets.utils import flatten_dict

@@ -140,14 +141,12 @@ def push_dataset_to_hub(
     num_workers: int = 8,
     episodes: list[int] | None = None,
     force_override: bool = False,
+    resume: bool = False,
     cache_dir: Path = Path("/tmp"),
     tests_data_dir: Path | None = None,
+    encoding: dict | None = None,
 ):
-    # Check repo_id is well formated
-    if len(repo_id.split("/")) != 2:
-        raise ValueError(
-            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
-        )
+    check_repo_id(repo_id)
     user_id, dataset_id = repo_id.split("/")
 
     # Robustify when `raw_dir` is str instead of Path

@@ -173,7 +172,7 @@ def push_dataset_to_hub(
     if local_dir.exists():
         if force_override:
             shutil.rmtree(local_dir)
-        else:
+        elif not resume:
             raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
 
     meta_data_dir = local_dir / "meta_data"

@@ -191,7 +190,7 @@ def push_dataset_to_hub(
     # convert dataset from original raw format to LeRobot format
     from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
-        raw_dir, videos_dir, fps, video, episodes
+        raw_dir, videos_dir, fps, video, episodes, encoding
     )
 
     lerobot_dataset = LeRobotDataset.from_preloaded(

@@ -315,6 +314,12 @@ def main():
         default=0,
         help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
     )
+    parser.add_argument(
+        "--resume",
+        type=int,
+        default=0,
+        help="When set to 1, resumes a previous run.",
+    )
     parser.add_argument(
         "--tests-data-dir",
         type=Path,
@@ -15,20 +15,25 @@
 # limitations under the License.
 import logging
 import time
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
+from copy import deepcopy
 from pathlib import Path
 from pprint import pformat
+from threading import Lock
 
 import hydra
+import numpy as np
 import torch
 from deepdiff import DeepDiff
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig, ListConfig, OmegaConf
 from termcolor import colored
 from torch import nn
 from torch.cuda.amp import GradScaler
 
 from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
+from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env

@@ -107,6 +112,7 @@ def update_policy(
     grad_scaler: GradScaler,
     lr_scheduler=None,
     use_amp: bool = False,
+    lock=None,
 ):
     """Returns a dictionary of items for logging."""
     start_time = time.perf_counter()

@@ -129,6 +135,7 @@ def update_policy(
 
     # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
     # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-    grad_scaler.step(optimizer)
+    with lock if lock is not None else nullcontext():
+        grad_scaler.step(optimizer)
     # Updates the scale for next iteration.
     grad_scaler.update()
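The `lock if lock is not None else nullcontext()` idiom above lets the same optimizer-step code serve both the offline loop (no lock passed) and the threaded online loop. A minimal demonstration of the idiom:

```python
from contextlib import nullcontext
from threading import Lock

def critical_update(lock=None):
    # nullcontext() is a no-op context manager, so lock-free callers pay nothing.
    with lock if lock is not None else nullcontext():
        pass  # stand-in for grad_scaler.step(optimizer)

critical_update()             # single-threaded: no synchronization
critical_update(lock=Lock())  # multi-threaded: serialized against other holders
```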
@@ -149,11 +156,12 @@ def update_policy(
         "update_s": time.perf_counter() - start_time,
         **{k: v for k, v in output_dict.items() if k != "loss"},
     }
+    info.update({k: v for k, v in output_dict.items() if k not in info})
 
     return info
 
 
-def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
+def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
     lr = info["lr"]

@@ -187,12 +195,12 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     info["num_samples"] = num_samples
     info["num_episodes"] = num_episodes
     info["num_epochs"] = num_epochs
-    info["is_offline"] = is_offline
+    info["is_online"] = is_online
 
     logger.log_dict(info, step, mode="train")
 
 
-def log_eval_info(logger, info, step, cfg, dataset, is_offline):
+def log_eval_info(logger, info, step, cfg, dataset, is_online):
     eval_s = info["eval_s"]
     avg_sum_reward = info["avg_sum_reward"]
     pc_success = info["pc_success"]

@@ -221,7 +229,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
     info["num_samples"] = num_samples
     info["num_episodes"] = num_episodes
     info["num_epochs"] = num_epochs
-    info["is_offline"] = is_offline
+    info["is_online"] = is_online
 
     logger.log_dict(info, step, mode="eval")

@@ -234,6 +242,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     init_logging()
 
+    if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
+        raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
+
     # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
     # to check for any differences between the provided config and the checkpoint's config.
     if cfg.resume:

@@ -279,9 +290,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # log metrics to terminal and wandb
     logger = Logger(cfg, out_dir, wandb_job_name=job_name)
 
-    if cfg.training.online_steps > 0:
-        raise NotImplementedError("Online training is not implemented yet.")
-
     set_global_seed(cfg.seed)
 
     # Check device is available

@@ -336,7 +344,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
     # Note: this helper will be used in offline and online training loops.
-    def evaluate_and_checkpoint_if_needed(step):
+    def evaluate_and_checkpoint_if_needed(step, is_online):
         _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
         step_identifier = f"{step:0{_num_digits}d}"

@@ -352,7 +360,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 max_episodes_rendered=4,
                 start_seed=cfg.seed,
             )
-            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
+            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
             if cfg.wandb.enable:
                 logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")

@@ -396,8 +404,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     dl_iter = cycle(dataloader)
 
     policy.train()
+    offline_step = 0
     for _ in range(step, cfg.training.offline_steps):
-        if step == 0:
+        if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
 
         start_time = time.perf_counter()

@@ -420,13 +429,207 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         train_info["dataloading_s"] = dataloading_s
 
         if step % cfg.training.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
+            log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
 
         # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
         # so we pass in step + 1.
-        evaluate_and_checkpoint_if_needed(step + 1)
+        evaluate_and_checkpoint_if_needed(step + 1, is_online=False)
 
         step += 1
+        offline_step += 1  # noqa: SIM113
+
+    if cfg.training.online_steps == 0:
+        if eval_env:
+            eval_env.close()
+        logging.info("End of training")
+        return
+
+    # Online training.
+
+    # Create an env dedicated to online episodes collection from policy rollout.
+    online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
+    resolve_delta_timestamps(cfg)
+    online_buffer_path = logger.log_dir / "online_buffer"
+    if cfg.resume and not online_buffer_path.exists():
+        # If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
+        # buffer.
+        logging.warning(
+            "When online training is resumed, we load the latest online buffer from the prior run, "
+            "and this might not coincide with the state of the buffer as it was at the moment the checkpoint "
+            "was made. This is because the online buffer is updated on disk during training, independently "
+            "of our explicit checkpointing mechanisms."
+        )
+    online_dataset = OnlineBuffer(
+        online_buffer_path,
+        data_spec={
+            **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()},
+            **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
+            "next.reward": {"shape": (), "dtype": np.dtype("float32")},
+            "next.done": {"shape": (), "dtype": np.dtype("?")},
+            "next.success": {"shape": (), "dtype": np.dtype("?")},
+        },
+        buffer_capacity=cfg.training.online_buffer_capacity,
+        fps=online_env.unwrapped.metadata["render_fps"],
+        delta_timestamps=cfg.training.delta_timestamps,
+    )
+
+    # If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
+    # makes it possible to do online rollouts in parallel with training updates).
+    online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy
+
+    # Create dataloader for online training.
+    concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
+    sampler_weights = compute_sampler_weights(
+        offline_dataset,
+        offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
+        online_dataset=online_dataset,
+        # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
+        # this final observation in the offline datasets, but we might add them in future.
+        online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
+        online_sampling_ratio=cfg.training.online_sampling_ratio,
+    )
+    sampler = torch.utils.data.WeightedRandomSampler(
+        sampler_weights,
+        num_samples=len(concat_dataset),
+        replacement=True,
+    )
+    dataloader = torch.utils.data.DataLoader(
+        concat_dataset,
+        batch_size=cfg.training.batch_size,
+        num_workers=cfg.training.num_workers,
+        sampler=sampler,
+        pin_memory=device.type != "cpu",
+        drop_last=True,
+    )
+    dl_iter = cycle(dataloader)
+
+    # Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
+    # these are still used but effectively do nothing.
+    lock = Lock()
+    # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
+    # parallelization of rollouts is handled within the job.
+    executor = ThreadPoolExecutor(max_workers=1)
+
+    online_step = 0
+    online_rollout_s = 0  # time taken to do online rollout
+    update_online_buffer_s = 0  # time taken to update the online buffer with the online rollout data
+    # Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
+    # online rollout option.
+    await_update_online_buffer_s = 0
+    rollout_start_seed = cfg.training.online_env_seed
+
+    while True:
+        if online_step == cfg.training.online_steps:
+            break
+
+        if online_step == 0:
+            logging.info("Start online training by interacting with environment")
+
+        def sample_trajectory_and_update_buffer():
+            nonlocal rollout_start_seed
+            with lock:
+                online_rollout_policy.load_state_dict(policy.state_dict())
+            online_rollout_policy.eval()
+            start_rollout_time = time.perf_counter()
+            with torch.no_grad():
+                eval_info = eval_policy(
+                    online_env,
+                    online_rollout_policy,
+                    n_episodes=cfg.training.online_rollout_n_episodes,
+                    max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes),
+                    videos_dir=logger.log_dir / "online_rollout_videos",
+                    return_episode_data=True,
+                    start_seed=(
+                        rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000
+                    ),
+                )
+            online_rollout_s = time.perf_counter() - start_rollout_time
+
+            with lock:
+                start_update_buffer_time = time.perf_counter()
+                online_dataset.add_data(eval_info["episodes"])
+
+                # Update the concatenated dataset length used during sampling.
+                concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+
+                # Update the sampling weights.
+                sampler.weights = compute_sampler_weights(
+                    offline_dataset,
+                    offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
+                    online_dataset=online_dataset,
+                    # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
+                    # this final observation in the offline datasets, but we might add them in future.
+                    online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
+                    online_sampling_ratio=cfg.training.online_sampling_ratio,
+                )
+                sampler.num_samples = len(concat_dataset)
+
+                update_online_buffer_s = time.perf_counter() - start_update_buffer_time
+
+            return online_rollout_s, update_online_buffer_s
+
+        future = executor.submit(sample_trajectory_and_update_buffer)
+        # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
+        # here until the rollout and buffer update is done, before proceeding to the policy update steps.
+        if (
+            not cfg.training.do_online_rollout_async
+            or len(online_dataset) <= cfg.training.online_buffer_seed_size
+        ):
+            online_rollout_s, update_online_buffer_s = future.result()
+
+        if len(online_dataset) <= cfg.training.online_buffer_seed_size:
+            logging.info(
+                f"Seeding online buffer: {len(online_dataset)}/{cfg.training.online_buffer_seed_size}"
+            )
+            continue
+
+        policy.train()
+        for _ in range(cfg.training.online_steps_between_rollouts):
+            with lock:
+                start_time = time.perf_counter()
+                batch = next(dl_iter)
+                dataloading_s = time.perf_counter() - start_time
+
+            for key in batch:
+                batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+            train_info = update_policy(
+                policy,
+                batch,
+                optimizer,
+                cfg.training.grad_clip_norm,
+                grad_scaler=grad_scaler,
+                lr_scheduler=lr_scheduler,
+                use_amp=cfg.use_amp,
+                lock=lock,
+            )
+
+            train_info["dataloading_s"] = dataloading_s
+            train_info["online_rollout_s"] = online_rollout_s
+            train_info["update_online_buffer_s"] = update_online_buffer_s
+            train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
+            with lock:
+                train_info["online_buffer_size"] = len(online_dataset)
+
+            if step % cfg.training.log_freq == 0:
+                log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
+
+            # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
+            # so we pass in step + 1.
+            evaluate_and_checkpoint_if_needed(step + 1, is_online=True)
+
+            step += 1
+            online_step += 1
+
+        # If we're doing async rollouts, we should now wait until we've completed them before proceeding
+        # to do the next batch of rollouts.
+        if future.running():
+            start = time.perf_counter()
+            online_rollout_s, update_online_buffer_s = future.result()
+            await_update_online_buffer_s = time.perf_counter() - start
+
+        if online_step >= cfg.training.online_steps:
+            break
 
     if eval_env:
         eval_env.close()
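The concurrency skeleton of the new online loop above is small enough to isolate: a single-worker ThreadPoolExecutor keeps at most one rollout in flight while the main thread runs optimization steps, and a Lock guards everything both threads touch (policy weights, buffer, sampler). A stripped-down sketch of that pattern with stand-in bodies (illustrative only, not the repo's API):

```python
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

lock = Lock()
executor = ThreadPoolExecutor(max_workers=1)  # one rollout at a time

def rollout_and_update_buffer() -> float:
    start = time.perf_counter()
    time.sleep(0.01)  # stand-in for the environment rollout
    with lock:
        pass  # stand-in for writing episodes into the shared buffer
    return time.perf_counter() - start

for _ in range(3):  # stand-in for the while-True online loop
    future = executor.submit(rollout_and_update_buffer)
    for _ in range(5):  # stand-in for online_steps_between_rollouts updates
        with lock:
            pass  # stand-in for sampling a batch / stepping the optimizer
    rollout_s = future.result()  # wait for the rollout before the next one
```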
@@ -3212,23 +3212,6 @@ pytest = ">=4.6"
 [package.extras]
 testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
 
-[[package]]
-name = "pytest-mock"
-version = "3.14.0"
-description = "Thin-wrapper around the mock package for easier use with pytest"
-optional = true
-python-versions = ">=3.8"
-files = [
-    {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
-    {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
-]
-
-[package.dependencies]
-pytest = ">=6.2.5"
-
-[package.extras]
-dev = ["pre-commit", "pytest-asyncio", "tox"]
-
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"

@@ -4494,7 +4477,7 @@ dev = ["debugpy", "pre-commit"]
 dora = ["gym-dora"]
 koch = ["dynamixel-sdk", "pynput"]
 pusht = ["gym-pusht"]
-test = ["pytest", "pytest-cov", "pytest-mock"]
+test = ["pytest", "pytest-cov"]
 umi = ["imagecodecs"]
 video-benchmark = ["pandas", "scikit-image"]
 xarm = ["gym-xarm"]

@@ -4502,4 +4485,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "882b44dada0890dd4e1c727d3363d95cbe1a4adf1d80aa5263080597d80be42c"
+content-hash = "dfe9c6a54e0382156e62e7bd2c7aab1be6372da76d30c61b06d27232276638cb"

@@ -62,7 +62,6 @@ rerun-sdk = ">=0.15.1"
 deepdiff = ">=7.0.1"
 scikit-image = {version = ">=0.23.2", optional = true}
 pandas = {version = ">=2.2.2", optional = true}
-pytest-mock = {version = ">=3.14.0", optional = true}
 dynamixel-sdk = {version = ">=3.7.31", optional = true}
 pynput = {version = ">=1.7.7", optional = true}

@@ -74,7 +73,7 @@ pusht = ["gym-pusht"]
 xarm = ["gym-xarm"]
 aloha = ["gym-aloha"]
 dev = ["pre-commit", "debugpy"]
-test = ["pytest", "pytest-cov", "pytest-mock"]
+test = ["pytest", "pytest-cov"]
 umi = ["imagecodecs"]
 video_benchmark = ["scikit-image", "pandas"]
 koch = ["dynamixel-sdk", "pynput"]

@@ -110,7 +109,6 @@ exclude = [
 
 [tool.ruff.lint]
 select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
-ignore-init-module-imports = true
 
 
 [build-system]
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:9f9347c8d9ac90ee44e6dd86f65043438168df6bbe4bab2d2b875e55ef7376ef
|
oid sha256:7841afb9ef99c0601448c43a20c25eb029440c73816319c67c5d7e1c5cde2445
|
||||||
size 1488
|
size 136
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
|
||||||
size 33
|
size 188
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:02fc4ea25766269f65752a60b0594c43d799b0ae528cd773bf024b064b5aa329
|
oid sha256:03508d82db846a804aef1a28aec3cb9572e3105b55a02b6ddbb09b2522d57b84
|
||||||
size 4344
|
size 4344
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:55d7b1a06fe3e3051482752740074348bdb5fc98fb2e305b06d6203994117b27
|
oid sha256:7009b3d2f14d6af497eeb32a52332e79cb9c07db24a6c2bbfbeffbaa8151dd69
|
||||||
size 592448
|
size 592448
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:98329e4b40e9be0d63f7d36da9d86c44bbe7eeeb1b10d3ba973c923f3be70867
|
oid sha256:34ece24fb6b302db0b68987858509f31713fb299faa9a9d34b8fd68f10bc3100
|
||||||
size 247
|
size 247
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:54e42cdfd016a0ced2ab1fe2966a8c15a2384e0dbe1a2fe87433a2d1b8209ac0
|
oid sha256:a70cc17019407cf6bee44fa2c78b4f29e48eb1696aa1a4ff4c048ba256574523
|
||||||
size 5220057
|
size 6356921
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:af1ded2a244cb47a96255b75f584a643edf6967e13bb5464b330ffdd9d7ad859
|
oid sha256:2b35992036e6dcee7d4df6d1675d55d1dd2d658b2d65442737e709895699a2f0
|
||||||
size 5284692
|
size 5084448
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:13d1bebabd79984fd6715971be758ef9a354495adea5e8d33f4e7904365e112b
|
oid sha256:3aa92e6b6bd0e39f6de530ea6a270671db7350cdc101c9d9030c775539c708c1
|
||||||
size 5258380
|
size 5441406
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:f33bc6810f0b91817a42610364cb49ed1b99660f058f0f9407e6f5920d0aee02
|
oid sha256:4ee862b1a6dc1d11df77c36c47ea00db88ad35a48e4d71c2940ad26b55fe2167
|
||||||
size 1008
|
size 136
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
|
||||||
size 33
|
size 188
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:7b58d6c89e936a781a307805ebecf0dd473fbc02d52a7094da62e54bffb9454a
|
oid sha256:095c30bfe3c5da168c85aceef905e74e2142866332282965aa6812f6e6e48448
|
||||||
size 4344
|
size 4344
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:a08be578285cbe2d35b78f150d464ff3e10604a9865398c976983e0d711774f9
|
oid sha256:98859f2d87e1a0abb9a930a82af623504b3efb26f70fe576f05bab7f19024427
|
||||||
size 788528
|
size 788528
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:34e36233477c8aa0b0840314ddace072062d4f486d06546bbd6550832c370065
|
oid sha256:38cf4116a65cb92a5c43f9b9da7a7b81cfa9168b17605c8c456f7d3a3a23b77a
|
||||||
size 247
|
size 247
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:66e7349a4a82ca6042a7189608d01eb1cfa38d100d039b5445ae1a9e65d824ab
|
oid sha256:596dda720d378a44b6b61a6a72b44bec3e55e85198bca37f9dace6fe84af7ff0
|
||||||
size 14470946
|
size 16062396
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:a2146f0c10c9f2611e57e617983aa4f91ad681b4fc50d91b992b97abd684f926
|
oid sha256:c614bbaf93d65354a82001b357682a0bd36f9603685f6c735c5e377b763d0bdb
|
||||||
size 11662185
|
size 10317415
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:5affbaf1c48895ba3c626e0d8cf1309e5f4ec6bbaa135313096f52a22de66c05
|
oid sha256:868788028a38334b6b566cb17ffcc2ace2ec2b2b68ff2a58b6d29eb3c3e2ec1f
|
||||||
size 11410342
|
size 9516445
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:6c2b195ca91b88fd16422128d386d2cabd808a1862c6d127e6bf2e83e1fe819a
|
oid sha256:f365a02b052a2697b1558f4ab9b813f0d4ba46a5bc6ae3da30bbc4b135426aa6
|
||||||
size 448
|
size 136
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
|
||||||
size 33
|
size 188
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:b360b6b956d2adcb20589947c553348ef1eb6b70743c989dcbe95243d8592ce5
|
oid sha256:5c96f47b569b7af82e05200213d733626664150aa7c5ae3298fd04a2138a2023
|
||||||
size 4344
|
size 4344
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:3f5c3926b4d4da9271abefcdf6a8952bb1f13258a9c39fe0fd223f548dc89dcb
|
oid sha256:75f53d221827f17cc2ded3908452e24331b39b79dc3a26f2b9d89a6e6894baab
|
||||||
size 887728
|
size 887728
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:4993b05fb026619eec5eb70db8cadaa041ba4ab92d38b4a387167ace03b1018b
|
oid sha256:d394d451929b805f2d94f9fc5b12d15c31cfc494df76d7d642b63378b8ba0131
|
||||||
size 247
|
size 247
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:bd25d17ef5b7500386761b5e32920879bbdcafe0e17a8a8845628525d861e644
|
oid sha256:73ddb898f83589b4bcabe978e46e75f20be215492f115bf6ebc98f1d01e1eff8
|
||||||
size 10231081
|
size 9696507
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:5b557acbfeb0681c0a38e47263d945f6cd3a03461298d8b17209c81e3fd0aae8
|
oid sha256:d3d993977bee96882732d4a9c9d082c356fc9fcd8199c027b016207d60494c2f
|
||||||
size 9701371
|
size 8957007
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:da8f3b4f9f965da63819652b2c042d4cf7e07d14631113ea072087d56370310e
|
oid sha256:c9321627184c14af4a6ba64d02e86f7253bc1f563a3adef17036d68480d2bb3e
|
||||||
size 10473741
|
size 9938178
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:a053506017d8a78cfd307b2912eeafa1ac1485a280cf90913985fcc40120b5ec
|
oid sha256:88346956fdf58f17dba7b08cc858364ed8278a7baa20febd9c68ae959d2c9c82
|
||||||
size 416
|
size 136
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
|
||||||
size 33
|
size 188
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:d6d172d1bca02face22ceb4c21ea2b054cf3463025485dce64711b6f36b31f8a
|
oid sha256:de80d5afc044be903a89ee08f30cfef5fb4c1e928d8ba8f4d81ea9d0bb4fb011
|
||||||
size 4344
|
size 4344
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:7e5ce817a2c188041f57f8d4c465dab3b9c3e4e1aeb7a9fb270230d1b36df530
|
oid sha256:79c2a3da1024fa140d23e8438b2756d27cf5db65ac70d7ac4215260b55ca55f8
|
||||||
size 1477064
|
size 1477064
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:4eb2dc373e4ea7d474742590f9073d66a773f6ab94b9e73a8673df19f93fae6d
|
oid sha256:69435f30146a309c8d7d0eb01216555bf0547095db1fc9c20218d481d6fe62c8
|
||||||
size 247
|
size 247
|
||||||
|
|
|
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d2c55b146fabe78b18c8a28a7746ab56e1ee7a6918e9e3dad9bd196f97975895
-size 26158915
+oid sha256:3fc89b720dfb7511d5dd9eba31494cf720e6a89519067b7b5a4d65f0a539c811
+size 35137505

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:71e1958d77f56843acf1ec48da4f04311a5836c87a0e77dbe26aa47c27c6347e
-size 18786848
+oid sha256:26b8d97a096aa8a1d686d86fc93bde1dcdd50a9dc273f49f3b6a700fe6610e88
+size 20387806

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:20780718399b5759ff9a3a79824986310524793066198e3b9a307222f11a93df
-size 17769988
+oid sha256:72000be2803259f40da6d093279d17ed194ead3ebc508bf2d77cb463bcb67c4d
+size 17594265

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:279916f7689ae46af90e92a46eba9486a71fc762e3e2679ab5441eb37126827b
-size 928
+oid sha256:fb6de86fee6ff3cc5d61d591fe480a50feb289c05770e3f4b76e24138b571c65
+size 136

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
-size 33
+oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
+size 188

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7a7731051b521694b52b5631470720a7f05331915f4ac4e7f8cd83f9ff459bce
+oid sha256:d79027c2513c01a7e360f3177e62ab955e5d3f704f1e7127a6e1e852158ec42c
 size 4344

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:99608258e8c9fe5191f1a12edc29b47d307790104149dffb6d3046ddad6aeb1b
+oid sha256:0a2c1f98c816728136291fcb7530cd0ebcf4ea47b0f6750836da56b8324d64c1
 size 435600

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ae6735b7b394914824e974a7461019373a10f9e2d84ddf834bec8ea268d9ec1e
+oid sha256:921505133c62906bd53034a613a827996994875d84c8b26d69d188df9a7ffeba
 size 247

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:833e288c5fdacbbe10a5d048cb6f49fe1a396d91b2117b827e130ec11069256a
-size 8397615
+oid sha256:7e298db7d820e2ff9f0b9c250e800e8ada3521fdeae3c4127452dd62700e9ac8
+size 10980189

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2cb870acb4855fef70f19c5f632d94e4c25eef59eeea92f4b1167a44b1b36b33
-size 5912007
+oid sha256:29b46c2e823d62b1329b98a3d7efffd24fc6c904e9cea115e2f0adb1bb45db44
+size 7229025

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8be36298141b455ea51d17a78e4bbc6619639302139fe2db605bdfa3ff5e91bd
-size 4794018
+oid sha256:f34ddbd109b212260c758d54a0930f75a38666a178a0d26eeefa846cfeac86c0
+size 5944469

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:348d0ee38a71929b2017d540de870b9dff6d79efdd0cbc5352fa9697e350134a
-size 928
+oid sha256:1386f9030607facefe56f429c93e50df0e22017914ce3f21ab67edc87b936d9d
+size 136

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
-size 33
+oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
+size 188

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c5c2996f58d5277fa19cf56ec143334fbee940d1de37530452496a6f0aa11f88
+oid sha256:7ffb173891cebb47a4d24d051f5fdd2ec44493d0a1a48d11f4d1410049aadd5b
 size 4344

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:da3a8efea9ba60d1fdd209d45a3387df22a09f7c156904ecb03f10456736fb74
+oid sha256:ae1760af2d3bf13c6e868643f203e76e1faf81a237715f72f2b81c3199e95e96
 size 514056

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6b7111ff1ef5c4d6a2990f5f39f42398f061da8c4e81adf46b9d9150ec2feeaf
+oid sha256:505a42c408d56c8a7d3e2367280b41e27667b58334f32e84c937c44c38217bd6
 size 247

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2ac8c2755d940534042595ecad33ebea358974ec67bc041c8675e53b7d2272ff
-size 9182551
+oid sha256:1489dac711fb99b192f064f9dbe56ed0e9e80fedc34da469e85acc7d5b4d75bf
+size 12316772

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b35aaa37e66dd5563d93e6059d5b645e112e020e03bd398f7098a5289970953a
-size 6378566
+oid sha256:20edc20184b5e4eb45194016fe7a0a5673665e3105286e0c6563767b5ff461f3
+size 6365474

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6138247ba7160a3de6c50111e6fcc5ae075044086d8527ae5d435b1f8a7c7a93
-size 6439183
+oid sha256:4ccdc96d9fe560a841e45e9fa636b69ef35b518271982339516517a4ae47d04f
+size 7449799

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e11c127b1ab12f0761bce6651fa5a4882093924df230294f2f34309bc74b0707
-size 672
+oid sha256:9ee4f3c571ce6822e157e60133bee02245febee93eba5d35458d3c83345f7b87
+size 136

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
-size 33
+oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
+size 188

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5efab3606f50ee7ac8bb0c88cfeefd86bfd060dbb75d063e01d09456da020026
+oid sha256:b05f933aa67d559e44f062c8428b2f85ee7b49d3bf0e0302b9b83fb7d48ed0a3
 size 2904

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:724a26cc4a3fc54ea5deb835816afa4a1c9712958ba402cd3067c22f4556a532
+oid sha256:8698f98e3fe36e321ba99a9b60facaab4abffb26916042b021adc1b41e8fb877
 size 100040

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:71d6ba89bee5a4ee2761220452999e415bc838a44bebf1b5a2e4ba8622369798
+oid sha256:c0b18566cbf59e399ea40f1630df12ffbbb9f73bbc733d1d4eba62d675b1fda5
 size 247

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:12cd101db746125d40cf2e27c79340a3786c2906feca11a34e380c5d88280d36
-size 1329662
+oid sha256:5a57aade7d8510ef1cc8778f90cfa86749c95fa0c5a5e80cb166b2edd0f7189a
+size 1788513

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:083db9efc5c9e3396c5e1159d020c2a3786f1f1a4b069719d327ed7fbc65c34d
-size 33
+oid sha256:e7ab5c2bd7d176d4d7902a600240318c2828b7d75f4a888d0887327e4eff089d
+size 65

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4500f31e62f0928a837fa71783acacda0db516c7b00d0586a41ea5fd8fc5e772
-size 928
+oid sha256:4e910eac6a1c94f4c194b05e908dcc973dd4227b18eb80c374d7a1150f166c34
+size 136

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
-size 33
+oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
+size 188

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ae67bab70f3b313427fdcb447ed0a1a3d09581ff7ae8cc64ddd2243ef9ccb6c0
+oid sha256:a85e57264325cc0927450e30a85dd0eacb0a70ebdb00c4e2ac043a57f9c200e2
 size 2904

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:277340fe2c2ca9f40a2cf27caf66dbb47089d690917a076e341d3be586b874d1
+oid sha256:171a9efc9c45601688821936ec9a1dcf91f16b1bbab4e8246f18b4d4cc6ac6ee
 size 80432

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e48156ce4f71ac15d78732312fbc7e199f0ecdaac3604231e6be2e3e5b31a0ad
+oid sha256:5fd5fe80657788d044cdc8a1baf1456c7695cc951049347a469165002a83c6c7
 size 247

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:292e6815ae7431d07ee75a5a770fdc8fd6fe8479eb104c33774ef0049f0dd768
-size 963206
+oid sha256:cb4810728c3d642326bf5fa2cd1632a60e68880faace4ec7368c6ee7992dabfb
+size 1297818

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:083db9efc5c9e3396c5e1159d020c2a3786f1f1a4b069719d327ed7fbc65c34d
-size 33
+oid sha256:e7ab5c2bd7d176d4d7902a600240318c2828b7d75f4a888d0887327e4eff089d
+size 65

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4500f31e62f0928a837fa71783acacda0db516c7b00d0586a41ea5fd8fc5e772
-size 928
+oid sha256:4e910eac6a1c94f4c194b05e908dcc973dd4227b18eb80c374d7a1150f166c34
+size 136

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
-size 33
+oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
+size 188
Some files were not shown because too many files have changed in this diff.