fix environment seeding
add fixes for reproducibility only try to start env if it is closed revision fix normalization and data type Improve README Improve README Tests are passing, Eval pretrained model works, Add gif Update gif Update gif Update gif Update gif Update README Update README update minor Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Address suggestions Update thumbnail + stats Update thumbnail + stats Update README.md Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Add more comments Add test_examples.py
This commit is contained in:
parent
203bcd7ca5
commit
1a1308d62f
495
README.md
495
README.md
|
@ -1,83 +1,356 @@
|
|||
# Le Robot
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="media/lerobot-logo-thumbnail.png">
|
||||
<source media="(prefers-color-scheme: light)" srcset="media/lerobot-logo-thumbnail.png">
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="media/lerobot-logo-thumbnail.png" style="max-width: 100%;">
|
||||
</picture>
|
||||
<br/>
|
||||
<br/>
|
||||
</p>
|
||||
|
||||
#### State-of-the-art machine learning for real-world robotics
|
||||
# LeRobot
|
||||
|
||||
Le Robot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
**State-of-the-art machine learning for real-world robotics**
|
||||
|
||||
Le Robot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
|
||||
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
|
||||
Le Robot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more supports for real-world robotics on the most affordable and capable robots out there.
|
||||
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
|
||||
|
||||
Le Robot is built upon [TorchRL](https://github.com/pytorch/rl) which provides abstractions and utilities for Reinforcement Learning.
|
||||
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
|
||||
|
||||
## Acknowledgment
|
||||
🤗 LeRobot hosts pretrained models and datasets on this HuggingFace community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
|
||||
|
||||
- Our ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
|
||||
- Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
|
||||
- Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
|
||||
#### Examples of pretrained models and environments
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="http://remicadene.com/assets/gif/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
<td align="center">TDMPC policy on SimXArm env</td>
|
||||
<td align="center">Diffusion policy on PushT env</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
|
||||
- Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
|
||||
- TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
|
||||
- Abstractions and utilities for Reinforcement Learning come from [TorchRL](https://github.com/pytorch/rl)
|
||||
|
||||
## Installation
|
||||
|
||||
Create a virtual environment with Python 3.10, e.g. using `conda`:
|
||||
```
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
[Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already)
|
||||
```
|
||||
```bash
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
```
|
||||
```bash
|
||||
poetry install
|
||||
```
|
||||
|
||||
If you encounter a disk space error, try to change your tmp dir to a location where you have enough disk space, e.g.
|
||||
```
|
||||
```bash
|
||||
mkdir ~/tmp
|
||||
export TMPDIR='~/tmp'
|
||||
```
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
|
||||
```
|
||||
```bash
|
||||
wandb login
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Train
|
||||
## Walkthrough
|
||||
|
||||
```
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.job.name=pusht \
|
||||
env=pusht
|
||||
```
|
||||
|
||||
### Visualize offline buffer
|
||||
.
|
||||
├── lerobot
|
||||
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
||||
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
||||
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, simxarm.yaml
|
||||
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
|
||||
| ├── common # contains classes and utilities
|
||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, simxarm
|
||||
| | ├── envs # various sim environments: aloha, pusht, simxarm
|
||||
| | └── policies # various policies: act, diffusion, tdmpc
|
||||
| └── scripts # contains functions to execute via command line
|
||||
| ├── visualize_dataset.py # load a dataset and render its demonstrations
|
||||
| ├── eval.py # load policy and evaluate it on an environment
|
||||
| └── train.py # train a policy via imitation learning and/or reinforcement learning
|
||||
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
|
||||
├── .github
|
||||
| └── workflows
|
||||
| └── test.yml # defines install settings for continuous integration and specifies end-to-end tests
|
||||
└── tests # contains pytest utilities for continuous integration
|
||||
|
||||
```
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
|
||||
```python
|
||||
""" Copy pasted from `examples/1_visualize_dataset.py` """
|
||||
import lerobot
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
from torchrl.data.replay_buffers import SamplerWithoutReplacement
|
||||
from lerobot.scripts.visualize_dataset import render_dataset
|
||||
|
||||
print(lerobot.available_datasets)
|
||||
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
|
||||
|
||||
# we use this sampler to sample 1 frame after the other
|
||||
sampler = SamplerWithoutReplacement(shuffle=False)
|
||||
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
|
||||
|
||||
video_paths = render_dataset(
|
||||
dataset,
|
||||
out_dir="outputs/visualize_dataset/example",
|
||||
max_num_samples=300,
|
||||
fps=50,
|
||||
)
|
||||
print(video_paths)
|
||||
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||
```
|
||||
|
||||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
|
||||
env=pusht
|
||||
env=aloha \
|
||||
task=sim_sim_transfer_cube_human \
|
||||
hydra.run.dir=outputs/visualize_dataset/example
|
||||
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||
```
|
||||
|
||||
### Eval
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
Run `python lerobot/scripts/eval.py --help` for instructions.
|
||||
You can import our environment class, download pretrained policies from the HuggingFace hub, and use our rollout utilities with rendering:
|
||||
```python
|
||||
""" Copy pasted from `examples/2_evaluate_pretrained_policy.py`
|
||||
# TODO
|
||||
```
|
||||
|
||||
## TODO
|
||||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--hub-id lerobot/diffusion_policy_pusht_image \
|
||||
--revision v1.0 \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_hub
|
||||
```
|
||||
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)
|
||||
After launching training of your own policy, you can also re-evaluate the checkpoints with:
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--config PATH/TO/FOLDER/config.yaml \
|
||||
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_dir
|
||||
```
|
||||
|
||||
Ask [Remi Cadene](re.cadene@gmail.com) for access if needed.
|
||||
See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub):
|
||||
```python
|
||||
""" Copy pasted from `examples/3_train_policy.py`
|
||||
# TODO
|
||||
```
|
||||
|
||||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/example
|
||||
```
|
||||
|
||||
You can easily train any policy on any environment:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
env=aloha \
|
||||
task=sim_insertion \
|
||||
dataset_id=aloha_sim_insertion_scripted \
|
||||
policy=act \
|
||||
hydra.run.dir=outputs/train/aloha_act
|
||||
```
|
||||
|
||||
## Contribute
|
||||
|
||||
Feel free to open issues and PRs, and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
|
||||
|
||||
**TODO**
|
||||
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
||||
|
||||
**Follow our style**
|
||||
|
||||
```bash
|
||||
# install if needed
|
||||
pre-commit install
|
||||
# apply style and linter checks before git commit
|
||||
pre-commit
|
||||
```
|
||||
|
||||
**Add dependencies**
|
||||
|
||||
Instead of `pip install some-package`, we use `poetry` to track the versions of our dependencies:
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated CI. To do this, you should create a separate environment and add the new package there as well. For example:
|
||||
```bash
|
||||
# add the new package to your main poetry env
|
||||
poetry add some-package
|
||||
# add the same package to the CPU-only env dedicated to CI
|
||||
conda create -y -n lerobot-ci python=3.10
|
||||
conda activate lerobot-ci
|
||||
cd .github/poetry/cpu
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
**Run tests locally**
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
|
||||
On Mac:
|
||||
```bash
|
||||
brew install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
On Ubuntu:
|
||||
```bash
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/data](tests/data)
|
||||
```bash
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
When adding a new dataset, mock it with
|
||||
```bash
|
||||
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
|
||||
```
|
||||
|
||||
Run tests
|
||||
```bash
|
||||
DATA_DIR="tests/data" pytest -sx tests
|
||||
```
|
||||
|
||||
**Add a new dataset**
|
||||
|
||||
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then you can upload it to the hub with:
|
||||
```bash
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
|
||||
--repo-type dataset \
|
||||
--revision v1.0
|
||||
```
|
||||
|
||||
You will need to set the corresponding version as a default argument in your dataset class:
|
||||
```python
|
||||
version: str | None = "v1.0",
|
||||
```
|
||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||
|
||||
For instance, for [lerobot/pusht](https://huggingface.co/datasets/lerobot/pusht), we used:
|
||||
```bash
|
||||
HF_USER=lerobot
|
||||
DATASET=pusht
|
||||
```
|
||||
|
||||
If you want to improve an existing dataset, you can download it locally with:
|
||||
```bash
|
||||
mkdir -p data/$DATASET
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
|
||||
--repo-type dataset \
|
||||
--local-dir data/$DATASET \
|
||||
--local-dir-use-symlinks=False \
|
||||
--revision v1.0
|
||||
```
|
||||
|
||||
Iterate on your code and dataset with:
|
||||
```bash
|
||||
DATA_DIR=data python train.py
|
||||
```
|
||||
|
||||
Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
|
||||
```bash
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
|
||||
--repo-type dataset \
|
||||
--revision v1.1 \
|
||||
--delete "*"
|
||||
```
|
||||
|
||||
Then you will need to set the corresponding version as a default argument in your dataset class:
|
||||
```python
|
||||
version: str | None = "v1.1",
|
||||
```
|
||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||
|
||||
|
||||
## Profile
|
||||
Finally, you might want to mock the dataset if you need to update the unit tests as well:
|
||||
```bash
|
||||
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
|
||||
```
|
||||
|
||||
**Example**
|
||||
**Add a pretrained policy**
|
||||
|
||||
Once you have trained a policy you may upload it to the HuggingFace hub.
|
||||
|
||||
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
|
||||
|
||||
Secondly, assuming you have trained a policy, you need:
|
||||
|
||||
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
|
||||
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
|
||||
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
|
||||
|
||||
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
|
||||
|
||||
```
|
||||
to_upload
|
||||
├── config.yaml
|
||||
├── model.pt
|
||||
└── stats.pth
|
||||
```
|
||||
|
||||
With the folder prepared, run the following with a desired revision ID.
|
||||
|
||||
```bash
|
||||
huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID
|
||||
```
|
||||
|
||||
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
|
||||
|
||||
```bash
|
||||
huggingface-cli upload $HUB_ID to_upload
|
||||
```
|
||||
|
||||
See `eval.py` for an example of how a user may use your policy.
|
||||
|
||||
|
||||
**Improve your code with profiling**
|
||||
|
||||
An example of a code snippet to profile the evaluation of a policy:
|
||||
```python
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
|
@ -96,160 +369,12 @@ with profile(
|
|||
with record_function("eval_policy"):
|
||||
for i in range(num_episodes):
|
||||
prof.step()
|
||||
# insert code to profile, potentially whole body of eval_policy function
|
||||
```
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--config /home/rcadene/code/fowm/logs/xarm_lift/all/default/2/.hydra/config.yaml \
|
||||
pretrained_model_path=/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt \
|
||||
eval_episodes=7
|
||||
--config outputs/pusht/.hydra/config.yaml \
|
||||
pretrained_model_path=outputs/pusht/model.pt \
|
||||
eval_episodes=7
|
||||
```
|
||||
|
||||
## Contribute
|
||||
|
||||
**Style**
|
||||
```
|
||||
# install if needed
|
||||
pre-commit install
|
||||
# apply style and linter checks before git commit
|
||||
pre-commit run -a
|
||||
```
|
||||
|
||||
**Adding dependencies (temporary)**
|
||||
|
||||
Right now, for the CI to work, whenever a new dependency is added it needs to be also added to the cpu env, eg:
|
||||
|
||||
```
|
||||
# Run in this directory, adds the package to the main env with cuda
|
||||
poetry add some-package
|
||||
|
||||
# Adds the same package to the cpu env
|
||||
cd .github/poetry/cpu && poetry add some-package
|
||||
```
|
||||
|
||||
**Tests**
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
|
||||
On Mac:
|
||||
```
|
||||
brew install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
On Ubuntu:
|
||||
```
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/data](tests/data)
|
||||
```
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
When adding a new dataset, mock it with
|
||||
```
|
||||
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
|
||||
```
|
||||
|
||||
Run tests
|
||||
```
|
||||
DATA_DIR="tests/data" pytest -sx tests
|
||||
```
|
||||
|
||||
**Datasets**
|
||||
|
||||
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
|
||||
```
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then you can upload it to the hub with:
|
||||
```
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
|
||||
--repo-type dataset \
|
||||
--revision v1.0
|
||||
```
|
||||
|
||||
You will need to set the corresponding version as a default argument in your dataset class:
|
||||
```python
|
||||
version: str | None = "v1.0",
|
||||
```
|
||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||
|
||||
For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used:
|
||||
```
|
||||
HF_USER=cadene
|
||||
DATASET=pusht
|
||||
```
|
||||
|
||||
If you want to improve an existing dataset, you can download it locally with:
|
||||
```
|
||||
mkdir -p data/$DATASET
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
|
||||
--repo-type dataset \
|
||||
--local-dir data/$DATASET \
|
||||
--local-dir-use-symlinks=False \
|
||||
--revision v1.0
|
||||
```
|
||||
|
||||
Iterate on your code and dataset with:
|
||||
```
|
||||
DATA_DIR=data python train.py
|
||||
```
|
||||
|
||||
Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
|
||||
```
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
|
||||
--repo-type dataset \
|
||||
--revision v1.1 \
|
||||
--delete "*"
|
||||
```
|
||||
|
||||
Then you will need to set the corresponding version as a default argument in your dataset class:
|
||||
```python
|
||||
version: str | None = "v1.1",
|
||||
```
|
||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||
|
||||
|
||||
Finally, you might want to mock the dataset if you need to update the unit tests as well:
|
||||
```
|
||||
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
|
||||
```
|
||||
|
||||
**Models**
|
||||
|
||||
Once you have trained a model you may upload it to the HuggingFace hub.
|
||||
|
||||
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
|
||||
|
||||
Secondly, assuming you have trained a model, you need:
|
||||
|
||||
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
|
||||
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
|
||||
- `staths.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
|
||||
|
||||
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
|
||||
|
||||
```
|
||||
to_upload
|
||||
├── config.yaml
|
||||
├── model.pt
|
||||
└── stats.pth
|
||||
```
|
||||
|
||||
With the folder prepared, run the following with a desired revision ID.
|
||||
|
||||
```
|
||||
huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID
|
||||
```
|
||||
|
||||
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
|
||||
|
||||
```
|
||||
huggingface-cli upload $HUB_ID to_upload
|
||||
```
|
||||
|
||||
See `eval.py` for an example of how a user may use your model.
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
from torchrl.data.replay_buffers import SamplerWithoutReplacement
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
from lerobot.scripts.visualize_dataset import render_dataset
|
||||
|
||||
print(lerobot.available_datasets)
|
||||
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
|
||||
|
||||
# we use this sampler to sample 1 frame after the other
|
||||
sampler = SamplerWithoutReplacement(shuffle=False)
|
||||
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
|
||||
|
||||
video_paths = render_dataset(
|
||||
dataset,
|
||||
out_dir="outputs/visualize_dataset/example",
|
||||
max_num_samples=300,
|
||||
fps=50,
|
||||
)
|
||||
print(video_paths)
|
||||
# ['outputs/visualize_dataset/example/episode_0.mp4']
|
|
@ -0,0 +1 @@
|
|||
# TODO
|
|
@ -0,0 +1 @@
|
|||
# TODO
|
|
@ -1 +1,59 @@
|
|||
"""
|
||||
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
|
||||
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import lerobot
|
||||
print(lerobot.available_envs)
|
||||
print(lerobot.available_tasks_per_env)
|
||||
print(lerobot.available_datasets_per_env)
|
||||
print(lerobot.available_datasets)
|
||||
print(lerobot.available_policies)
|
||||
```
|
||||
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
from lerobot.__version__ import __version__ # noqa: F401
|
||||
|
||||
available_envs = [
|
||||
"aloha",
|
||||
"pusht",
|
||||
"simxarm",
|
||||
]
|
||||
|
||||
available_tasks_per_env = {
|
||||
"aloha": [
|
||||
"sim_insertion",
|
||||
"sim_transfer_cube",
|
||||
],
|
||||
"pusht": ["pusht"],
|
||||
"simxarm": ["lift"],
|
||||
}
|
||||
|
||||
available_datasets_per_env = {
|
||||
"aloha": [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
],
|
||||
"pusht": ["pusht"],
|
||||
"simxarm": ["xarm_lift_medium"],
|
||||
}
|
||||
|
||||
available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
|
||||
|
||||
available_policies = [
|
||||
"act",
|
||||
"diffusion",
|
||||
"tdmpc",
|
||||
]
|
||||
|
|
|
@ -9,7 +9,7 @@ import tqdm
|
|||
from huggingface_hub import snapshot_download
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.envs.transforms.transforms import Compose
|
||||
|
@ -17,22 +17,56 @@ from torchrl.envs.transforms.transforms import Compose
|
|||
HF_USER = "lerobot"
|
||||
|
||||
|
||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
class AbstractDataset(TensorDictReplayBuffer):
|
||||
"""
|
||||
AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
|
||||
This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
|
||||
These implementations can vary based on the source of the data, the environment the data pertains to,
|
||||
or the specific kind of data manipulation applied.
|
||||
|
||||
Note:
|
||||
- `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
|
||||
functionality for storing and retrieving `TensorDict`-like data.
|
||||
- `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
|
||||
It is expected that these variants correspond to a HuggingFace dataset on the hub.
|
||||
For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
|
||||
- [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
||||
- [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
||||
- [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
||||
- [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
||||
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
available_datasets: list[str] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = None,
|
||||
batch_size: int = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
assert (
|
||||
self.available_datasets is not None
|
||||
), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
|
||||
assert (
|
||||
dataset_id in self.available_datasets
|
||||
), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
|
||||
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.shuffle = shuffle
|
||||
|
|
|
@ -9,11 +9,11 @@ import torch
|
|||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
|
||||
DATASET_IDS = [
|
||||
"aloha_sim_insertion_human",
|
||||
|
@ -80,24 +80,24 @@ def download(data_dir, dataset_id):
|
|||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
|
||||
class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
class AlohaDataset(AbstractDataset):
|
||||
available_datasets = DATASET_IDS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
assert dataset_id in DATASET_IDS
|
||||
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
|
|
|
@ -5,7 +5,7 @@ from pathlib import Path
|
|||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
|
||||
from lerobot.common.envs.transforms import NormalizeTransform, Prod
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||
|
@ -16,6 +16,7 @@ DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
|||
def make_offline_buffer(
|
||||
cfg,
|
||||
overwrite_sampler=None,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
overwrite_batch_size=None,
|
||||
overwrite_prefetch=None,
|
||||
|
@ -64,25 +65,27 @@ def make_offline_buffer(
|
|||
sampler = overwrite_sampler
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
from lerobot.common.datasets.simxarm import SimxarmDataset
|
||||
|
||||
clsfunc = SimxarmExperienceReplay
|
||||
dataset_id = f"xarm_{cfg.env.task}_medium"
|
||||
clsfunc = SimxarmDataset
|
||||
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
|
||||
clsfunc = PushtExperienceReplay
|
||||
dataset_id = "pusht"
|
||||
clsfunc = PushtDataset
|
||||
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.datasets.aloha import AlohaExperienceReplay
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
|
||||
clsfunc = AlohaExperienceReplay
|
||||
dataset_id = f"aloha_{cfg.env.task}"
|
||||
clsfunc = AlohaDataset
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
# TODO(rcadene): backward compatiblity to load pretrained pusht policy
|
||||
dataset_id = cfg.get("dataset_id")
|
||||
if dataset_id is None and cfg.env.name == "pusht":
|
||||
dataset_id = "pusht"
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=dataset_id,
|
||||
sampler=sampler,
|
||||
|
@ -100,36 +103,40 @@ def make_offline_buffer(
|
|||
else:
|
||||
img_keys = offline_buffer.image_keys
|
||||
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
if normalize:
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
|
||||
# we only normalize the state and action, since the images are usually normalized inside the model for
|
||||
# now (except for tdmpc: see the following)
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
# we only normalize the state and action, since the images are usually normalized inside the model for
|
||||
# now (except for tdmpc: see the following)
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||
in_keys += img_keys
|
||||
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||
in_keys += [("next", *key) for key in img_keys]
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
if cfg.policy.name == "tdmpc":
|
||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||
in_keys += img_keys
|
||||
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||
in_keys += [("next", *key) for key in img_keys]
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor(
|
||||
[13.456424, 32.938293], dtype=torch.float32
|
||||
)
|
||||
stats["observation", "state", "max"] = torch.tensor(
|
||||
[496.14618, 510.9579], dtype=torch.float32
|
||||
)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||
|
||||
offline_buffer.set_transform(transforms)
|
||||
offline_buffer.set_transform(transforms)
|
||||
|
||||
if not overwrite_sampler:
|
||||
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||
|
|
|
@ -9,11 +9,11 @@ import torch
|
|||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
|
@ -83,20 +83,22 @@ def add_tee(
|
|||
return body
|
||||
|
||||
|
||||
class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
class PushtDataset(AbstractDataset):
|
||||
available_datasets = ["pusht"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
super().__init__(
|
||||
|
|
|
@ -8,12 +8,12 @@ import torchrl
|
|||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
SliceSampler,
|
||||
Sampler,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
|
||||
|
||||
def download():
|
||||
|
@ -32,7 +32,7 @@ def download():
|
|||
Path(download_path).unlink()
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
class SimxarmDataset(AbstractDataset):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
|
@ -41,15 +41,15 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
|
|||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
batch_size: int = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
super().__init__(
|
||||
|
|
|
@ -8,6 +8,20 @@ from lerobot.common.utils import set_global_seed
|
|||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
"""
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
name: str | None = None # same name should be used to instantiate the environment in factory.py
|
||||
available_tasks: list[str] | None = None # for instance: sim_insertion, sim_transfer_cube, pusht, lift
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
|
@ -21,6 +35,14 @@ class AbstractEnv(EnvBase):
|
|||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
assert self.name is not None, "Subclasses of `AbstractEnv` should set the `name` class attribute."
|
||||
assert (
|
||||
self.available_tasks is not None
|
||||
), "Subclasses of `AbstractEnv` should set the `available_tasks` class attribute."
|
||||
assert (
|
||||
task in self.available_tasks
|
||||
), f"The provided task ({task}) is not on the list of available tasks {self.available_tasks}."
|
||||
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
|
|
|
@ -35,6 +35,8 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None
|
|||
|
||||
|
||||
class AlohaEnv(AbstractEnv):
|
||||
name = "aloha"
|
||||
available_tasks = ["sim_insertion", "sim_transfer_cube"]
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -22,6 +22,8 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None
|
|||
|
||||
|
||||
class PushtEnv(AbstractEnv):
|
||||
name = "pusht"
|
||||
available_tasks = ["pusht"]
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -24,6 +24,9 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None
|
|||
|
||||
|
||||
class SimxarmEnv(AbstractEnv):
|
||||
name = "simxarm"
|
||||
available_tasks = ["lift"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
|
|
|
@ -9,8 +9,19 @@ class AbstractPolicy(nn.Module):
|
|||
|
||||
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
|
||||
documentation for more information.
|
||||
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
name: str | None = None # same name should be used to instantiate the policy in factory.py
|
||||
|
||||
def __init__(self, n_action_steps: int | None):
|
||||
"""
|
||||
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
||||
|
@ -18,6 +29,7 @@ class AbstractPolicy(nn.Module):
|
|||
adds that dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
|
||||
self.n_action_steps = n_action_steps
|
||||
self.clear_action_queue()
|
||||
|
||||
|
|
|
@ -42,6 +42,8 @@ def kl_divergence(mu, logvar):
|
|||
|
||||
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
|
|
|
@ -13,6 +13,8 @@ from lerobot.common.utils import get_safe_torch_device
|
|||
|
||||
|
||||
class DiffusionPolicy(AbstractPolicy):
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
|
|
|
@ -3,9 +3,9 @@ def make_policy(cfg):
|
|||
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPC
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
policy = TDMPC(cfg.policy, cfg.device)
|
||||
policy = TDMPCPolicy(cfg.policy, cfg.device)
|
||||
elif cfg.policy.name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
|
||||
|
|
|
@ -87,9 +87,11 @@ class TOLD(nn.Module):
|
|||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
|
||||
class TDMPC(AbstractPolicy):
|
||||
class TDMPCPolicy(AbstractPolicy):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, cfg, device):
|
||||
super().__init__(None)
|
||||
self.action_dim = cfg.action_dim
|
||||
|
|
|
@ -26,6 +26,8 @@ fps: ???
|
|||
|
||||
offline_prioritized_sampler: true
|
||||
|
||||
dataset_id: ???
|
||||
|
||||
n_action_steps: ???
|
||||
n_obs_steps: ???
|
||||
env: ???
|
||||
|
|
|
@ -10,9 +10,11 @@ online_steps: 25000
|
|||
|
||||
fps: 50
|
||||
|
||||
dataset_id: aloha_sim_insertion_human
|
||||
|
||||
env:
|
||||
name: aloha
|
||||
task: sim_insertion_human
|
||||
task: sim_insertion
|
||||
from_pixels: True
|
||||
pixels_only: False
|
||||
image_size: [3, 480, 640]
|
||||
|
|
|
@ -10,6 +10,8 @@ online_steps: 25000
|
|||
|
||||
fps: 10
|
||||
|
||||
dataset_id: pusht
|
||||
|
||||
env:
|
||||
name: pusht
|
||||
task: pusht
|
||||
|
|
|
@ -9,6 +9,8 @@ online_steps: 25000
|
|||
|
||||
fps: 15
|
||||
|
||||
dataset_id: xarm_lift_medium
|
||||
|
||||
env:
|
||||
name: simxarm
|
||||
task: lift
|
||||
|
|
|
@ -13,8 +13,10 @@ Examples:
|
|||
You have a specific config file to go with trained model weights, and want to run 10 episodes.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval.py --config PATH/TO/FOLDER/config.yaml \
|
||||
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth` eval_episodes=10
|
||||
python lerobot/scripts/eval.py \
|
||||
--config PATH/TO/FOLDER/config.yaml \
|
||||
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
|
||||
eval_episodes=10
|
||||
```
|
||||
|
||||
You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case,
|
||||
|
|
|
@ -25,7 +25,7 @@ def visualize_dataset_cli(cfg: dict):
|
|||
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
# Expects images in [0, 1].
|
||||
# Expects images in [0, 255].
|
||||
frames = torch.cat(frames)
|
||||
assert frames.max() <= 1 and frames.min() >= 0
|
||||
frames = (255 * frames).to(dtype=torch.uint8)
|
||||
|
@ -47,44 +47,63 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
|||
|
||||
logging.info("make_offline_buffer")
|
||||
offline_buffer = make_offline_buffer(
|
||||
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
|
||||
cfg,
|
||||
overwrite_sampler=sampler,
|
||||
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
|
||||
normalize=False,
|
||||
overwrite_batch_size=1,
|
||||
overwrite_prefetch=12,
|
||||
)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
|
||||
for video_path in video_paths:
|
||||
logging.info(video_path)
|
||||
|
||||
|
||||
def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
|
||||
out_dir = Path(out_dir)
|
||||
video_paths = []
|
||||
threads = []
|
||||
frames = {}
|
||||
current_ep_idx = 0
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
|
||||
for i in range(max_num_samples):
|
||||
# TODO(rcadene): make it work with bsize > 1
|
||||
ep_td = offline_buffer.sample(1)
|
||||
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
||||
|
||||
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
||||
new_episode = ep_idx != current_ep_idx
|
||||
num_frames_left = offline_buffer._sampler._sample_list.numel()
|
||||
episode_is_done = ep_idx != current_ep_idx
|
||||
|
||||
if new_episode:
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
if episode_is_done:
|
||||
logging.info(f"Rendering episode {current_ep_idx}")
|
||||
|
||||
for im_key in offline_buffer.image_keys:
|
||||
if new_episode or no_more_frames:
|
||||
# append last observed frames (the ones after last action taken)
|
||||
frames[im_key].append(offline_buffer.transform(ep_td["next"])[im_key])
|
||||
|
||||
video_dir = Path(out_dir) / "visualize_dataset"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
|
||||
# when first frame of episode, initialize frames dict
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
# add current frame to list of frames to render
|
||||
frames[im_key].append(ep_td[im_key])
|
||||
else:
|
||||
# When episode has no more frame in its list of observation,
|
||||
# one frame still remains. It is the result of the last action taken.
|
||||
# It is stored in `"next"`, so we add it to the list of frames to render.
|
||||
frames[im_key].append(ep_td["next"][im_key])
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
if len(offline_buffer.image_keys) > 1:
|
||||
camera = im_key[-1]
|
||||
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||
else:
|
||||
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
|
||||
video_path = out_dir / f"episode_{current_ep_idx}.mp4"
|
||||
video_paths.append(str(video_path))
|
||||
|
||||
thread = threading.Thread(
|
||||
target=cat_and_write_video,
|
||||
args=(str(video_path), frames[im_key], cfg.fps),
|
||||
args=(str(video_path), frames[im_key], fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
@ -94,12 +113,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
|||
# reset list of frames
|
||||
del frames[im_key]
|
||||
|
||||
# append current cameras images to list of frames
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
frames[im_key].append(ep_td[im_key])
|
||||
|
||||
if no_more_frames:
|
||||
if num_frames_left == 0:
|
||||
logging.info("Ran out of frames")
|
||||
break
|
||||
|
||||
|
@ -110,6 +124,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
|||
thread.join()
|
||||
|
||||
logging.info("End of visualize_dataset")
|
||||
return video_paths
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 199 KiB |
Binary file not shown.
After Width: | Height: | Size: 160 KiB |
|
@ -0,0 +1,64 @@
|
|||
"""
|
||||
This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be sucessfully
|
||||
imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) corresponds.
|
||||
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import lerobot
|
||||
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
from lerobot.common.envs.pusht.env import PushtEnv
|
||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
||||
|
||||
from lerobot.common.datasets.simxarm import SimxarmDataset
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
|
||||
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
|
||||
def test_available():
|
||||
pol_classes = [
|
||||
ActionChunkingTransformerPolicy,
|
||||
DiffusionPolicy,
|
||||
TDMPCPolicy,
|
||||
]
|
||||
|
||||
env_classes = [
|
||||
AlohaEnv,
|
||||
PushtEnv,
|
||||
SimxarmEnv,
|
||||
]
|
||||
|
||||
dat_classes = [
|
||||
AlohaDataset,
|
||||
PushtDataset,
|
||||
SimxarmDataset,
|
||||
]
|
||||
|
||||
policies = [pol_cls.name for pol_cls in pol_classes]
|
||||
assert set(policies) == set(lerobot.available_policies)
|
||||
|
||||
envs = [env_cls.name for env_cls in env_classes]
|
||||
assert set(envs) == set(lerobot.available_envs)
|
||||
|
||||
tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
|
||||
for env in envs:
|
||||
assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
|
||||
|
||||
datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
|
||||
for env in envs:
|
||||
assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
|
||||
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import pytest
|
||||
from tensordict import TensorDict
|
||||
import torch
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"examples/1_visualize_dataset.py",
|
||||
"examples/2_evaluate_pretrained_policy.py",
|
||||
"examples/3_train_policy.py",
|
||||
],
|
||||
)
|
||||
def test_example(path):
|
||||
|
||||
with open(path, 'r') as file:
|
||||
file_contents = file.read()
|
||||
exec(file_contents)
|
||||
|
||||
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
|
|
@ -13,17 +13,16 @@ from lerobot.common.policies.abstract import AbstractPolicy
|
|||
|
||||
from .utils import DEVICE, init_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
("simxarm", "tdmpc", ["policy.mpc=true"]),
|
||||
("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
("pusht", "diffusion", []),
|
||||
("aloha", "act", ["env.task=sim_insertion_scripted"]),
|
||||
("aloha", "act", ["env.task=sim_insertion_human"]),
|
||||
("aloha", "act", ["env.task=sim_transfer_cube_scripted"]),
|
||||
("aloha", "act", ["env.task=sim_transfer_cube_human"]),
|
||||
("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
|
||||
("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
|
||||
("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
|
||||
("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
|
||||
# TODO(aliberts): simxarm not working with diffusion
|
||||
# ("simxarm", "diffusion", []),
|
||||
],
|
||||
|
@ -106,6 +105,8 @@ def test_abstract_policy_forward():
|
|||
return
|
||||
|
||||
class StubPolicy(AbstractPolicy):
|
||||
name = "stub"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(n_action_steps)
|
||||
self.n_policy_invocations = 0
|
||||
|
|
Loading…
Reference in New Issue