Merge branch 'main' into aloha_hd5_to_dataset_v2

commit dca5c22f9c by Claudio Coppola, 2025-02-24 12:29:28 +00:00, committed by GitHub
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
34 changed files with 280 additions and 8191 deletions

View File

@@ -8,6 +8,8 @@ on:
 schedule:
 - cron: "0 1 * * *"
+permissions: {}
 env:
 PYTHON_VERSION: "3.10"

@@ -25,11 +27,14 @@ jobs:
 - name: Set up Docker Buildx
 uses: docker/setup-buildx-action@v3
+with:
+cache-binary: false
 - name: Check out code
 uses: actions/checkout@v4
 with:
 lfs: true
+persist-credentials: false
 - name: Login to DockerHub
 uses: docker/login-action@v3

@@ -60,11 +65,14 @@ jobs:
 - name: Set up Docker Buildx
 uses: docker/setup-buildx-action@v3
+with:
+cache-binary: false
 - name: Check out code
 uses: actions/checkout@v4
 with:
 lfs: true
+persist-credentials: false
 - name: Login to DockerHub
 uses: docker/login-action@v3

@@ -89,9 +97,13 @@ jobs:
 steps:
 - name: Set up Docker Buildx
 uses: docker/setup-buildx-action@v3
+with:
+cache-binary: false
 - name: Check out code
 uses: actions/checkout@v4
+with:
+persist-credentials: false
 - name: Login to DockerHub
 uses: docker/login-action@v3

View File

@@ -7,6 +7,8 @@ on:
 schedule:
 - cron: "0 2 * * *"
+permissions: {}
 # env:
 #   SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
 jobs:

View File

@@ -8,6 +8,8 @@ on:
 branches:
 - main
+permissions: {}
 env:
 PYTHON_VERSION: "3.10"

@@ -17,7 +19,9 @@ jobs:
 runs-on: ubuntu-latest
 steps:
 - name: Checkout Repository
-uses: actions/checkout@v3
+uses: actions/checkout@v4
+with:
+persist-credentials: false
 - name: Set up Python
 uses: actions/setup-python@v4

@@ -34,49 +38,7 @@ jobs:
 run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
 - name: Ruff check
-run: ruff check
+run: ruff check --output-format=github
 - name: Ruff format
 run: ruff format --diff
-poetry_check:
-name: Poetry check
-runs-on: ubuntu-latest
-steps:
-- name: Checkout Repository
-uses: actions/checkout@v3
-- name: Install poetry
-run: pipx install "poetry<2.0.0"
-- name: Poetry check
-run: poetry check
-poetry_relax:
-name: Poetry relax
-runs-on: ubuntu-latest
-steps:
-- name: Checkout Repository
-uses: actions/checkout@v3
-- name: Install poetry
-run: pipx install "poetry<2.0.0"
-- name: Install poetry-relax
-run: poetry self add poetry-relax
-- name: Poetry relax
-id: poetry_relax
-run: |
-output=$(poetry relax --check 2>&1)
-if echo "$output" | grep -q "Proposing updates"; then
-echo "$output"
-echo ""
-echo "Some dependencies have caret '^' version requirement added by poetry by default."
-echo "Please replace them with '>='. You can do this by hand or use poetry-relax to do this."
-exit 1
-else
-echo "$output"
-fi

View File

@@ -8,6 +8,8 @@ on:
 # Run only when DockerFile files are modified
 - "docker/**"
+permissions: {}
 env:
 PYTHON_VERSION: "3.10"

@@ -20,6 +22,8 @@ jobs:
 steps:
 - name: Check out code
 uses: actions/checkout@v4
+with:
+persist-credentials: false
 - name: Get changed files
 id: changed-files

@@ -28,15 +32,12 @@ jobs:
 files: docker/**
 json: "true"
-- name: Run step if only the files listed above change
+- name: Run step if only the files listed above change # zizmor: ignore[template-injection]
 if: steps.changed-files.outputs.any_changed == 'true'
 id: set-matrix
-env:
-ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
 run: |
 echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
 build_modified_dockerfiles:
 name: Build modified Docker images
 needs: get_changed_files

@@ -50,9 +51,13 @@ jobs:
 steps:
 - name: Set up Docker Buildx
 uses: docker/setup-buildx-action@v3
+with:
+cache-binary: false
 - name: Check out code
 uses: actions/checkout@v4
+with:
+persist-credentials: false
 - name: Build Docker image
 uses: docker/build-push-action@v5

View File

@@ -7,7 +7,8 @@ on:
 - "tests/**"
 - "examples/**"
 - ".github/**"
-- "poetry.lock"
+- "pyproject.toml"
+- ".pre-commit-config.yaml"
 - "Makefile"
 - ".cache/**"
 push:

@@ -18,10 +19,16 @@ on:
 - "tests/**"
 - "examples/**"
 - ".github/**"
-- "poetry.lock"
+- "pyproject.toml"
+- ".pre-commit-config.yaml"
 - "Makefile"
 - ".cache/**"
+permissions: {}
+env:
+UV_VERSION: "0.6.0"
 jobs:
 pytest:
 name: Pytest

@@ -32,6 +39,7 @@ jobs:
 - uses: actions/checkout@v4
 with:
 lfs: true # Ensure LFS files are pulled
+persist-credentials: false
 - name: Install apt dependencies
 # portaudio19-dev is needed to install pyaudio

@@ -39,25 +47,19 @@ jobs:
 sudo apt-get update && \
 sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
-- name: Install poetry
-run: |
-pipx install poetry && poetry config virtualenvs.in-project true
-echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
-# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
-- name: Set up Python 3.10
-uses: actions/setup-python@v5
+- name: Install uv and python
+uses: astral-sh/setup-uv@v5
 with:
+enable-cache: true
+version: ${{ env.UV_VERSION }}
 python-version: "3.10"
-cache: "poetry"
-- name: Install poetry dependencies
-run: |
-poetry install --all-extras
+- name: Install lerobot (all extras)
+run: uv sync --all-extras
 - name: Test with pytest
 run: |
-pytest tests -v --cov=./lerobot --durations=0 \
+uv run pytest tests -v --cov=./lerobot --durations=0 \
 -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
 -W ignore::UserWarning:torch.utils.data.dataloader:558 \
 -W ignore::UserWarning:gymnasium.utils.env_checker:247 \

@@ -72,28 +74,24 @@ jobs:
 - uses: actions/checkout@v4
 with:
 lfs: true # Ensure LFS files are pulled
+persist-credentials: false
 - name: Install apt dependencies
 run: sudo apt-get update && sudo apt-get install -y ffmpeg
-- name: Install poetry
-run: |
-pipx install poetry && poetry config virtualenvs.in-project true
-echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
-# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
-- name: Set up Python 3.10
-uses: actions/setup-python@v5
+- name: Install uv and python
+uses: astral-sh/setup-uv@v5
 with:
+enable-cache: true
+version: ${{ env.UV_VERSION }}
 python-version: "3.10"
-- name: Install poetry dependencies
-run: |
-poetry install --extras "test"
+- name: Install lerobot
+run: uv sync --extra "test"
 - name: Test with pytest
 run: |
-pytest tests -v --cov=./lerobot --durations=0 \
+uv run pytest tests -v --cov=./lerobot --durations=0 \
 -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
 -W ignore::UserWarning:torch.utils.data.dataloader:558 \
 -W ignore::UserWarning:gymnasium.utils.env_checker:247 \

@@ -108,6 +106,7 @@ jobs:
 - uses: actions/checkout@v4
 with:
 lfs: true # Ensure LFS files are pulled
+persist-credentials: false
 - name: Install apt dependencies
 # portaudio19-dev is needed to install pyaudio

@@ -115,20 +114,21 @@ jobs:
 sudo apt-get update && \
 sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
-- name: Install poetry
-run: |
-pipx install poetry && poetry config virtualenvs.in-project true
-echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
-- name: Set up Python 3.10
-uses: actions/setup-python@v5
+- name: Install uv and python
+uses: astral-sh/setup-uv@v5
 with:
+enable-cache: true
+version: ${{ env.UV_VERSION }}
 python-version: "3.10"
-cache: "poetry"
-- name: Install poetry dependencies
+- name: Install lerobot (all extras)
 run: |
-poetry install --all-extras
+uv venv
+uv sync --all-extras
+- name: venv
+run: |
+echo "PYTHON_PATH=${{ github.workspace }}/.venv/bin/python" >> $GITHUB_ENV
 - name: Test end-to-end
 run: |
View File

@@ -3,8 +3,7 @@ on:
 name: Secret Leaks
-permissions:
-contents: read
+permissions: {}
 jobs:
 trufflehog:

@@ -14,6 +13,8 @@ jobs:
 uses: actions/checkout@v4
 with:
 fetch-depth: 0
+persist-credentials: false
 - name: Secret Scanning
 uses: trufflesecurity/trufflehog@main
 with:

.gitignore (vendored), 4 changed lines
View File

@@ -49,6 +49,10 @@ share/python-wheels/
 *.egg
 MANIFEST
+
+# uv/poetry lock files
+poetry.lock
+uv.lock
 # PyInstaller
 # Usually these files are written by a python script from a template
 # before PyInstaller builds the exe, so as to inject date/other infos into it.

View File

@@ -14,24 +14,20 @@ repos:
 - id: end-of-file-fixer
 - id: trailing-whitespace
 - repo: https://github.com/asottile/pyupgrade
-rev: v3.19.0
+rev: v3.19.1
 hooks:
 - id: pyupgrade
 - repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.8.2
+rev: v0.9.6
 hooks:
 - id: ruff
 args: [--fix]
 - id: ruff-format
-- repo: https://github.com/python-poetry/poetry
-rev: 1.8.0
-hooks:
-- id: poetry-check
-- id: poetry-lock
-args:
-- "--check"
-- "--no-update"
 - repo: https://github.com/gitleaks/gitleaks
-rev: v8.21.2
+rev: v8.23.3
 hooks:
 - id: gitleaks
+- repo: https://github.com/woodruffw/zizmor-pre-commit
+rev: v1.3.1
+hooks:
+- id: zizmor

View File

@@ -129,38 +129,71 @@ Follow these steps to start contributing:
 🚨 **Do not** work on the `main` branch.
-4. for development, we use `poetry` instead of just `pip` to easily track our dependencies.
-If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
+4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
+Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
 Set up a development environment with conda or miniconda:
 ```bash
 conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
 ```
-To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
+If you're using `uv`, it can manage python versions so you can instead do:
 ```bash
-poetry install --sync --extras "dev test"
+uv venv --python 3.10 && source .venv/bin/activate
 ```
+To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
+using `poetry`
+```bash
+poetry sync --extras "dev test"
+```
+using `uv`
+```bash
+uv sync --extra dev --extra test
+```
 You can also install the project with all its dependencies (including environments):
+using `poetry`
 ```bash
-poetry install --sync --all-extras
+poetry sync --all-extras
+```
+using `uv`
+```bash
+uv sync --all-extras
 ```
 > **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing.
-Whichever command you chose to install the project (e.g. `poetry install --sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
+Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
 The equivalent of `pip install some-package`, would just be:
+using `poetry`
 ```bash
 poetry add some-package
 ```
-When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
+using `uv`
 ```bash
-poetry lock --no-update
+uv add some-package
 ```
+When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
+using `poetry`
+```bash
+poetry lock
+```
+using `uv`
+```bash
+uv lock
+```
 5. Develop the features on your branch.
 As you work on the features, you should make sure that the test suite

View File

@@ -2,10 +2,10 @@
 PYTHON_PATH := $(shell which python)
-# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
-POETRY_CHECK := $(shell command -v poetry)
-ifneq ($(POETRY_CHECK),)
-PYTHON_PATH := $(shell poetry run which python)
+# If uv is installed and a virtual environment exists, use it
+UV_CHECK := $(shell command -v uv)
+ifneq ($(UV_CHECK),)
+PYTHON_PATH := $(shell .venv/bin/python)
 endif
 export PATH := $(dir $(PYTHON_PATH)):$(PATH)

View File

@@ -58,7 +58,7 @@ RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
 RUN ln -s /usr/bin/python3 /usr/bin/python
 # Install poetry
-RUN curl -sSL https://install.python-poetry.org | python - --version 1.8.5
+RUN curl -sSL https://install.python-poetry.org | python -
 ENV PATH="/root/.local/bin:$PATH"
 RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
 RUN poetry config virtualenvs.create false

View File

@@ -36,9 +36,14 @@ Using `pip`:
 pip install -e ".[dynamixel]"
 ```
-Or using `poetry`:
+Using `poetry`:
 ```bash
-poetry install --sync --extras "dynamixel"
+poetry sync --extras "dynamixel"
+```
+
+Using `uv`:
+```bash
+uv sync --extra "dynamixel"
 ```
 /!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:

View File

@@ -104,7 +104,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
 )
 logging.info(
 "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
-f"{pformat(dataset.repo_id_to_index , indent=2)}"
+f"{pformat(dataset.repo_id_to_index, indent=2)}"
 )
 if cfg.dataset.use_imagenet_stats:

View File

@@ -72,7 +72,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
 # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
 # are too far appart.
 direction="nearest",
-tolerance=pd.Timedelta(f"{1/fps} seconds"),
+tolerance=pd.Timedelta(f"{1 / fps} seconds"),
 )
 # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
 df = df[df["episode_index"] != -1]
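The hunk above only touches f-string spacing, but the surrounding call is what matters: camera frames are aligned to the reference timestamps with `pd.merge_asof`, and anything farther apart than one frame period is left unmatched. A minimal, self-contained sketch of that pattern (the fps value and toy frames are invented for illustration, not taken from the script):

```python
import pandas as pd

fps = 50  # assumed frame rate, for illustration only
ref = pd.DataFrame({"timestamp": pd.to_datetime([0.00, 0.02, 0.04], unit="s"), "state": [1, 2, 3]})
cam = pd.DataFrame({"timestamp": pd.to_datetime([0.001, 0.021, 0.50], unit="s"), "frame": [10, 11, 12]})

merged = pd.merge_asof(
    cam.sort_values("timestamp"),
    ref.sort_values("timestamp"),
    on="timestamp",
    direction="nearest",                           # may match a slightly *future* reference timestamp
    tolerance=pd.Timedelta(f"{1 / fps} seconds"),  # rows further apart than one frame period stay unmatched
)
print(merged)  # the 0.50 s camera frame gets NaN for "state": no reference row lies within 1/fps of it
```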

View File

@@ -409,9 +409,9 @@ class ACT(nn.Module):
 latent dimension.
 """
 if self.config.use_vae and self.training:
-assert (
-"action" in batch
-), "actions must be provided when using the variational objective in training mode."
+assert "action" in batch, (
+"actions must be provided when using the variational objective in training mode."
+)
 batch_size = (
 batch["observation.images"]

View File

@@ -221,7 +221,7 @@ class DiffusionConfig(PreTrainedConfig):
 for key, image_ft in self.image_features.items():
 if image_ft.shape != first_image_ft.shape:
 raise ValueError(
-f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
 )
 @property

View File

@@ -300,7 +300,7 @@ class PI0Policy(PreTrainedPolicy):
 self._action_queue.extend(actions.transpose(0, 1))
 return self._action_queue.popleft()
-def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
+def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
 """Do a full training forward pass to compute the loss"""
 if self.config.adapt_to_pi_aloha:
 batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])

@@ -328,12 +328,12 @@ class PI0Policy(PreTrainedPolicy):
 losses = losses[:, :, : self.config.max_action_dim]
 loss_dict["losses_after_rm_padding"] = losses.clone()
-loss = losses.mean()
 # For backward pass
-loss_dict["loss"] = loss
+loss = losses.mean()
 # For logging
 loss_dict["l2_loss"] = loss.item()
-return loss_dict
+return loss, loss_dict
 def prepare_images(self, batch):
 """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and

View File

@@ -594,9 +594,9 @@ class TDMPCTOLD(nn.Module):
 self.apply(_apply_fn)
 for m in [self._reward, *self._Qs]:
-assert isinstance(
-m[-1], nn.Linear
-), "Sanity check. The last linear layer needs 0 initialization on weights."
+assert isinstance(m[-1], nn.Linear), (
+"Sanity check. The last linear layer needs 0 initialization on weights."
+)
 nn.init.zeros_(m[-1].weight)
 nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure

View File

@@ -184,7 +184,7 @@ class VQBeTConfig(PreTrainedConfig):
 for key, image_ft in self.image_features.items():
 if image_ft.shape != first_image_ft.shape:
 raise ValueError(
-f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
 )
 @property

View File

@@ -203,9 +203,9 @@ class GPT(nn.Module):
 def forward(self, input, targets=None):
 device = input.device
 b, t, d = input.size()
-assert (
-t <= self.config.gpt_block_size
-), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+assert t <= self.config.gpt_block_size, (
+f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+)
 # positional encodings that are added to the input embeddings
 pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

@@ -273,10 +273,10 @@ class GPT(nn.Module):
 assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
 str(inter_params)
 )
-assert (
-len(param_dict.keys() - union_params) == 0
-), "parameters {} were not separated into either decay/no_decay set!".format(
-str(param_dict.keys() - union_params),
+assert len(param_dict.keys() - union_params) == 0, (
+"parameters {} were not separated into either decay/no_decay set!".format(
+str(param_dict.keys() - union_params),
+)
 )
 decay = [param_dict[pn] for pn in sorted(decay)]

@@ -419,9 +419,9 @@ class ResidualVQ(nn.Module):
 # and the network should be able to reconstruct
 if quantize_dim < self.num_quantizers:
-assert (
-self.quantize_dropout > 0.0
-), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+assert self.quantize_dropout > 0.0, (
+"quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+)
 indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
 # get ready for gathering

@@ -472,9 +472,9 @@ class ResidualVQ(nn.Module):
 all_indices = []
 if return_loss:
-assert not torch.any(
-indices == -1
-), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+assert not torch.any(indices == -1), (
+"some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+)
 ce_losses = []
 should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

@@ -887,9 +887,9 @@ class VectorQuantize(nn.Module):
 # only calculate orthogonal loss for the activated codes for this batch
 if self.orthogonal_reg_active_codes_only:
-assert not (
-is_multiheaded and self.separate_codebook_per_head
-), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+assert not (is_multiheaded and self.separate_codebook_per_head), (
+"orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+)
 unique_code_ids = torch.unique(embed_ind)
 codebook = codebook[:, unique_code_ids]

@@ -999,9 +999,9 @@ def gumbel_sample(
 ind = sampling_logits.argmax(dim=dim)
 one_hot = F.one_hot(ind, size).type(dtype)
-assert not (
-reinmax and not straight_through
-), "reinmax can only be turned on if using straight through gumbel softmax"
+assert not (reinmax and not straight_through), (
+"reinmax can only be turned on if using straight through gumbel softmax"
+)
 if not straight_through or temperature <= 0.0 or not training:
 return ind, one_hot

@@ -1209,9 +1209,9 @@ class EuclideanCodebook(nn.Module):
 self.gumbel_sample = gumbel_sample
 self.sample_codebook_temp = sample_codebook_temp
-assert not (
-use_ddp and num_codebooks > 1 and kmeans_init
-), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
+"kmeans init is not compatible with multiple codebooks in distributed environment for now"
+)
 self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
 self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop

View File

@@ -33,7 +33,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
 def log_dt(shortname, dt_val_s):
 nonlocal log_items, fps
-info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
+info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
 if fps is not None:
 actual_fps = 1 / dt_val_s
 if actual_fps < fps - 1:

View File

@@ -58,7 +58,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
 # Check that they have exactly the same set of keys.
 if target.keys() != source.keys():
 raise ValueError(
-f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
+f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
 )
 # Recursively update each key.

View File

@@ -102,7 +102,7 @@ class WandBLogger:
 self._wandb.log_artifact(artifact)
 def log_dict(self, d: dict, step: int, mode: str = "train"):
-if mode in {"train", "eval"}:
+if mode not in {"train", "eval"}:
 raise ValueError(mode)
 for k, v in d.items():

@@ -114,7 +114,7 @@ class WandBLogger:
 self._wandb.log({f"{mode}/{k}": v}, step=step)
 def log_video(self, video_path: str, step: int, mode: str = "train"):
-if mode in {"train", "eval"}:
+if mode not in {"train", "eval"}:
 raise ValueError(mode)
 wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
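The one-character change above fixes an inverted guard: the old test raised for the valid modes ("train"/"eval") and silently accepted everything else. A tiny standalone illustration of the corrected check (not the logger class itself):

```python
def check_mode(mode: str) -> str:
    if mode not in {"train", "eval"}:  # previously `if mode in {...}`, which rejected the valid values
        raise ValueError(mode)
    return mode


check_mode("train")    # ok
check_mode("eval")     # ok
# check_mode("test")   # would raise ValueError("test")
```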

View File

@@ -85,6 +85,11 @@ class TrainPipelineConfig(HubMixin):
 config_path = parser.parse_arg("config_path")
 if not config_path:
 raise ValueError("A config_path is expected when resuming a run.")
+if not Path(config_path).resolve().exists():
+raise NotADirectoryError(
+f"{config_path=} is expected to be a local path. "
+"Resuming from the hub is not supported for now."
+)
 policy_path = Path(config_path).parent
 self.policy.pretrained_path = policy_path
 self.checkpoint_path = policy_path.parent

View File

@@ -151,7 +151,9 @@ def rollout(
 if return_observations:
 all_observations.append(deepcopy(observation))
-observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
+observation = {
+key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
+}
 with torch.inference_mode():
 action = policy.select_action(observation)
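The rollout change makes the asynchronous copy conditional on the device, since `non_blocking=True` only does useful work for CUDA transfers (ideally from pinned memory) and is a no-op elsewhere. A self-contained sketch of the same pattern, with invented tensor names:

```python
import torch


def to_device(observation: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    non_blocking = device.type == "cuda"  # async host-to-device copies are a CUDA-only optimization
    return {key: tensor.to(device, non_blocking=non_blocking) for key, tensor in observation.items()}


obs = {"observation.state": torch.zeros(1, 6)}
obs = to_device(obs, torch.device("cuda" if torch.cuda.is_available() else "cpu"))
```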

View File

@@ -232,8 +232,10 @@ def train(cfg: TrainPipelineConfig):
 if is_log_step:
 logging.info(train_tracker)
 if wandb_logger:
-wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
-wandb_logger.log_dict(wandb_log_dict)
+wandb_log_dict = train_tracker.to_dict()
+if output_dict:
+wandb_log_dict.update(output_dict)
+wandb_logger.log_dict(wandb_log_dict, step)
 train_tracker.reset_averages()
 if cfg.save_checkpoint and is_saving_step:

@@ -271,6 +273,7 @@ def train(cfg: TrainPipelineConfig):
 logging.info(eval_tracker)
 if wandb_logger:
 wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
+wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
 wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
 if eval_env:

View File

@@ -111,9 +111,9 @@ def visualize_dataset(
 output_dir: Path | None = None,
 ) -> Path | None:
 if save:
-assert (
-output_dir is not None
-), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+assert output_dir is not None, (
+"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+)
 repo_id = dataset.repo_id

poetry.lock (generated), 7933 changed lines

File diff suppressed because it is too large.

View File

@@ -1,18 +1,24 @@
-[tool.poetry]
+[project.urls]
+homepage = "https://github.com/huggingface/lerobot"
+issues = "https://github.com/huggingface/lerobot/issues"
+discord = "https://discord.gg/s3KuuzsPFb"
+
+[project]
 name = "lerobot"
 version = "0.1.0"
 description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
 authors = [
-"Rémi Cadène <re.cadene@gmail.com>",
-"Simon Alibert <alibert.sim@gmail.com>",
-"Alexander Soare <alexander.soare159@gmail.com>",
-"Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
-"Adil Zouitine <adilzouitinegm@gmail.com>",
-"Thomas Wolf <thomaswolfcontact@gmail.com>",
+{name = "Rémi Cadène", email = "re.cadene@gmail.com"},
+{name = "Simon Alibert", email = "alibert.sim@gmail.com"},
+{name = "Alexander Soare", email = "alexander.soare159@gmail.com"},
+{name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr"},
+{name = "Adil Zouitine", email = "adilzouitinegm@gmail.com"},
+{name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com"},
 ]
-repository = "https://github.com/huggingface/lerobot"
 readme = "README.md"
-license = "Apache-2.0"
+license = {text = "Apache-2.0"}
+requires-python = ">=3.10"
+keywords = ["robotics", "deep learning", "pytorch"]
 classifiers=[
 "Development Status :: 3 - Alpha",
 "Intended Audience :: Developers",

@@ -23,70 +29,56 @@ classifiers=[
 "License :: OSI Approved :: Apache Software License",
 "Programming Language :: Python :: 3.10",
 ]
-packages = [{include = "lerobot"}]
+dependencies = [
+"cmake>=3.29.0.1",
+"datasets>=2.19.0",
+"deepdiff>=7.0.1",
+"diffusers>=0.27.2",
+"draccus>=0.10.0",
+"einops>=0.8.0",
+"flask>=3.0.3",
+"gdown>=5.1.0",
+"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
+"h5py>=3.10.0",
+"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
+"hydra-core>=1.3.2",
+"imageio[ffmpeg]>=2.34.0",
+"jsonlines>=4.0.0",
+"numba>=0.59.0",
+"omegaconf>=2.3.0",
+"opencv-python>=4.9.0",
+"pyav>=12.0.5",
+"pymunk>=6.6.0",
+"rerun-sdk>=0.21.0",
+"termcolor>=2.4.0",
+"torch>=2.2.1",
+"torchvision>=0.21.0",
+"wandb>=0.16.3",
+"zarr>=2.17.0"
+]
+
+[project.optional-dependencies]
+aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
+dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
+dora = ["gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'"]
+dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
+feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
+intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
+pi0 = ["transformers>=4.48.0"]
+pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
+stretch = [
+"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
+"pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
+"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
+"pynput>=1.7.7"
+]
+test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
+umi = ["imagecodecs>=2024.1.1"]
+video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
+xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
-[tool.poetry.dependencies]
-python = ">=3.10,<3.13"
-termcolor = ">=2.4.0"
-wandb = ">=0.16.3"
-imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
-gdown = ">=5.1.0"
-einops = ">=0.8.0"
-pymunk = ">=6.6.0"
-zarr = ">=2.17.0"
-numba = ">=0.59.0"
-torch = ">=2.2.1"
-opencv-python = ">=4.9.0"
-diffusers = ">=0.27.2"
-torchvision = ">=0.21.0"
-h5py = ">=3.10.0"
-huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.27.1"}
-gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work
-cmake = ">=3.29.0.1"
-gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
-gym-pusht = { version = ">=0.1.5", optional = true}
-gym-xarm = { version = ">=0.1.1", optional = true}
-gym-aloha = { version = ">=0.1.1", optional = true}
-pre-commit = {version = ">=3.7.0", optional = true}
-debugpy = {version = ">=1.8.1", optional = true}
-pytest = {version = ">=8.1.0", optional = true}
-pytest-cov = {version = ">=5.0.0", optional = true}
-datasets = ">=2.19.0"
-imagecodecs = { version = ">=2024.1.1", optional = true }
-pyav = ">=12.0.5"
-rerun-sdk = ">=0.21.0"
-deepdiff = ">=7.0.1"
-flask = ">=3.0.3"
-pandas = {version = ">=2.2.2", optional = true}
-scikit-image = {version = ">=0.23.2", optional = true}
-dynamixel-sdk = {version = ">=3.7.31", optional = true}
-pynput = {version = ">=1.7.7", optional = true}
-feetech-servo-sdk = {version = ">=1.0.0", optional = true}
-setuptools = {version = "!=71.0.1", optional = true} # TODO(rcadene, aliberts): 71.0.1 has a bug
-pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'", optional = true} # TODO(rcadene, aliberts): Fix on Mac
-pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true}
-hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
-pyserial = {version = ">=3.5", optional = true}
-jsonlines = ">=4.0.0"
-transformers = {version = ">=4.48.0", optional = true}
-draccus = ">=0.10.0"
-[tool.poetry.extras]
-dora = ["gym-dora"]
-pusht = ["gym-pusht"]
-xarm = ["gym-xarm"]
-aloha = ["gym-aloha"]
-dev = ["pre-commit", "debugpy"]
-test = ["pytest", "pytest-cov", "pyserial"]
-umi = ["imagecodecs"]
-video_benchmark = ["scikit-image", "pandas"]
-dynamixel = ["dynamixel-sdk", "pynput"]
-feetech = ["feetech-servo-sdk", "pynput"]
-intelrealsense = ["pyrealsense2"]
-stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"]
-pi0 = ["transformers"]
+
+[tool.poetry]
+requires-poetry = ">=2.1"

 [tool.ruff]
 line-length = 110

View File

@@ -49,17 +49,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
 # save 2 first frames of first episode
 i = dataset.episode_data_index["from"][0].item()
 save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
 # save 2 frames at the middle of first episode
 i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
 save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
 # save 2 last frames of first episode
 i = dataset.episode_data_index["to"][0].item()
-save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
-save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
+save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
+save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
 # TODO(rcadene): Enable testing on second and last episode
 # We currently cant because our test dataset only contains the first episode

View File

@@ -336,9 +336,9 @@ def test_backward_compatibility(repo_id):
 assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
 for key in new_frame:
-assert torch.isclose(
-new_frame[key], old_frame[key]
-).all(), f"{key=} for index={i} does not contain the same value"
+assert torch.isclose(new_frame[key], old_frame[key]).all(), (
+f"{key=} for index={i} does not contain the same value"
+)
 # test2 first frames of first episode
 i = dataset.episode_data_index["from"][0].item()

View File

@@ -343,13 +343,13 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
 # Check if the combined transforms directory exists and contains the right files
 combined_transforms_dir = tmp_path / "all"
 assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
-assert any(
-combined_transforms_dir.iterdir()
-), "No transformed images found in combined transforms directory."
+assert any(combined_transforms_dir.iterdir()), (
+"No transformed images found in combined transforms directory."
+)
 for i in range(1, n_examples + 1):
-assert (
-combined_transforms_dir / f"{i}.png"
-).exists(), f"Combined transform image {i}.png was not found."
+assert (combined_transforms_dir / f"{i}.png").exists(), (
+f"Combined transform image {i}.png was not found."
+)
 def test_save_each_transform(img_tensor_factory, tmp_path):

@@ -369,6 +369,6 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
 # Check for specific files within each transform directory
 expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
 for file_name in expected_files:
-assert (
-transform_dir / file_name
-).exists(), f"{file_name} was not found in {transform} directory."
+assert (transform_dir / file_name).exists(), (
+f"{file_name} was not found in {transform} directory."
+)

View File

@@ -132,9 +132,9 @@ def test_fifo():
 buffer.add_data(new_data)
 n_more_episodes = 2
 # Developer sanity check (in case someone changes the global `buffer_capacity`).
-assert (
-n_episodes + n_more_episodes
-) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
+assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
+"Something went wrong with the test code."
+)
 more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
 buffer.add_data(more_new_data)
 assert len(buffer) == buffer_capacity, "The buffer should be full."

@@ -203,9 +203,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
 item = buffer[2]
 data, is_pad = item["index"], item["index_is_pad"]
 assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
-assert torch.equal(
-is_pad, torch.tensor([True, False, False, True, True])
-), "Padding does not match expected values"
+assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
+"Padding does not match expected values"
+)

 # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
View File

@@ -193,12 +193,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
 observation_ = deepcopy(observation)
 with torch.inference_mode():
 action = policy.select_action(observation).cpu().numpy()
-assert set(observation) == set(
-observation_
-), "Observation batch keys are not the same after a forward pass."
-assert all(
-torch.equal(observation[k], observation_[k]) for k in observation
-), "Observation batch values are not the same after a forward pass."
+assert set(observation) == set(observation_), (
+"Observation batch keys are not the same after a forward pass."
+)
+assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
+"Observation batch values are not the same after a forward pass."
+)
 # Test step through policy
 env.step(action)