This commit is contained in:
Alexander Soare 2024-05-20 08:40:57 +01:00
parent dbed2ee1aa
commit 43ebb3033e
2 changed files with 15 additions and 26 deletions

View File

@ -14,7 +14,7 @@ LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/s
Explaining the ins and outs of [Hydra](https://hydra.cc/docs/intro/) is beyond the scope of this document, but here we'll share the main points you need to know.
First, consider that `lerobot/configs` might have a directory structure like this (this is the case at the time of writing):
First, `lerobot/configs` might have a directory structure like this:
```
.
@ -31,9 +31,15 @@ First, consider that `lerobot/configs` might have a directory structure like thi
**_For brevity, in the rest of this document we'll drop the leading `lerobot/configs` path. So `default.yaml` really refers to `lerobot/configs/default.yaml`._**
When you run the training script, Hydra takes over via the `@hydra.main` decorator. If you take a look at the `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. This means Hydra looks for `default.yaml` in `../configs` (which resolves to `lerobot/configs`).
When you run the training script with
Among regular configuration hyperparameters like `device: cuda`, `default.yaml` has a `defaults` section. It might look like this.
```python
python lerobot/scripts/train.py
```
Hydra takes over via the `@hydra.main` decorator. If you take a look at the `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. This means Hydra looks for `default.yaml` in `../configs` (which resolves to `lerobot/configs`).
Therefore, `default.yaml` is the first configuration file that Hydra considers. At the top of the file, is a `defaults` section which looks likes this:
```yaml
defaults:
@ -42,9 +48,9 @@ defaults:
- policy: diffusion
```
So, Hydra will grab `env/pusht.yaml` and `policy/diffusion.yaml` and incorporate their configuration parameters (any configuration parameters already present in `default.yaml` are overriden).
So, Hydra then grabs `env/pusht.yaml` and `policy/diffusion.yaml` and incorporates their configuration parameters as well (any configuration parameters already present in `default.yaml` are overriden).
## Running the training script with our provided configurations
Below the `defaults` section, `default.yaml` also contains regular configuration parameters.
If you want to train Diffusion Policy with PushT, you really only need to run:
@ -66,6 +72,8 @@ python lerobot/scripts/train.py policy=act env=aloha
**Notice, how the config overrides are passed** as `param_name=param_value`. This is the format the Hydra excepts for parsing the overrides.
_As an aside: we've set up our configurations so that they reproduce state-of-the-art results from papers in the literature._
## Overriding configuration parameters in the CLI
If you look in `env/aloha.yaml` you might see:
@ -144,14 +152,12 @@ There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_
---
Now, why don't you try running:
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```
That was a little mean of us, because if you did try running that code, you almost certainly got an exception of sorts. That's because there are aspects of the ACT configuration that are specific to the ALOHA environments, and here we have tried to use PushT.
Please, head on over to our advanced [tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md).
Please, head on over to our [advanced tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md) to learn more.
Or in the meantime, happy coding! 🤗

View File

@ -1,17 +0,0 @@
python lerobot/scripts/train.py \
hydra.job.name=act_pusht \
hydra.run.dir=outputs/train/act_pusht \
env=aloha \
env.task=AlohaInsertion-v0 \
dataset_repo_id=lerobot/pusht \
policy=act \
policy.use_vae=true \
training.eval_freq=10000 \
training.log_freq=250 \
training.offline_steps=100000 \
training.save_model=true \
training.save_freq=25000 \
eval.n_episodes=50 \
eval.batch_size=50 \
wandb.enable=false \
device=cuda \