Provide more information to the user (#358)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Co-authored-by: Remi <re.cadene@gmail.com>
parent b5ad79a7d3
commit a2592a5563
@ -267,13 +267,20 @@ checkpoints
│ └── training_state.pth # optimizer/scheduler/rng state and training step
```

To resume training from a checkpoint, you can add these to the `train.py` python command:
```bash
hydra.run.dir=your/original/experiment/dir resume=true
```

This will load the pretrained model along with the optimizer and scheduler states, and resume training from there. For more information, please see our tutorial on resuming training [here](https://github.com/huggingface/lerobot/blob/main/examples/5_resume_training.md).

To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
```bash
wandb.enable=true
```

A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:

A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for an explanation of some commonly used metrics in the logs.


|
||||
|
||||
|
|
|
@ -170,6 +170,36 @@ python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpo
Note that you may still use the regular syntax for config parameter overrides (eg: by adding `training.offline_steps=200000`).
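For reference, the same Hydra override syntax can also be applied programmatically. Below is a minimal sketch using the `init_hydra_config` helper that appears later in this diff; the config path `lerobot/configs/default.yaml` is an assumption about your checkout and may differ.

```python
# Minimal sketch: compose a config with Hydra-style overrides.
# Assumption: `lerobot/configs/default.yaml` is the default config in your checkout.
from lerobot.common.utils.utils import init_hydra_config

cfg = init_hydra_config(
    "lerobot/configs/default.yaml",
    overrides=["training.offline_steps=200000", "wandb.enable=true"],
)
print(cfg.training.offline_steps)  # -> 200000
```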
## Typical logs and metrics
When you start the training process, you will first see your full configuration printed in the terminal. You can check it to make sure you configured everything correctly and that your config is not overridden by other files. The final configuration will also be saved with the checkpoint.

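As a point of reference, this configuration dump is simply a pretty-printed OmegaConf container, similar to what `lerobot/scripts/train.py` does at startup (see the `train.py` hunk further below in this diff). Here is a minimal, self-contained sketch with a stand-in config:

```python
# Minimal sketch: pretty-print a resolved OmegaConf config, as train.py does at startup.
# The config content below is a stand-in, not the real training config.
from pprint import pformat

from omegaconf import OmegaConf

cfg = OmegaConf.create({"training": {"offline_steps": 200000}, "wandb": {"enable": True}})
print(pformat(OmegaConf.to_container(cfg)))
```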
After that, you will see a training log line like this one:

```
INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774
```

or an evaluation log line like this one:

```
INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266
```

These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here are the meanings of some of the abbreviations:

- `smpl`: number of samples seen during training.
- `ep`: number of episodes seen during training. An episode contains multiple samples from one complete manipulation task.
- `epch`: number of times all unique samples have been seen (epochs).
- `grdn`: gradient norm.
- `∑rwrd`: sum of rewards over each evaluation episode, averaged across episodes.
- `success`: average success rate of the eval episodes. Reward and success are usually different, except in the sparse-reward setting, where reward=1 only when the task is completed successfully.
- `eval_s`: time to evaluate the policy in the environment, in seconds.
- `updt_s`: time to update the network parameters, in seconds.
- `data_s`: time to load a batch of data, in seconds.
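To make these fields easier to eyeball, here is a minimal, self-contained sketch (not part of LeRobot) that splits such a log line into a metrics dictionary:

```python
# Minimal sketch: parse the key:value pairs of a training log line into a dict.
# The line below is copied from the example above; the parsing rule is a heuristic.
line = (
    "INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 "
    "loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774"
)

metrics = {}
for token in line.split():
    key, sep, value = token.partition(":")
    # Keep name:value metric tokens; skip the timestamp and the source location.
    if sep and "/" not in key and not any(c.isdigit() for c in key):
        metrics[key] = value

print(metrics)  # {'step': '0', 'smpl': '64', ..., 'data_s': '4.774'}
```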
Some metrics are useful for initial performance profiling. For example, if you find via the `nvidia-smi` command that GPU utilization is low and that `data_s` is sometimes too high, you may need to adjust the batch size or the number of dataloading workers to speed up dataloading. We also recommend the [pytorch profiler](https://github.com/huggingface/lerobot?tab=readme-ov-file#improve-your-code-with-profiling) for detailed performance probing.

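Here is a minimal, self-contained sketch of that kind of probe; the tiny model and random dataset are placeholders rather than LeRobot code, but the same pattern can be wrapped around a few real training iterations:

```python
# Minimal sketch: profile a few training iterations to see whether compute or
# data loading dominates. The model and dataset below are placeholders, not LeRobot code.
import torch
from torch.profiler import ProfilerActivity, profile
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 32), torch.randn(256, 8))
loader = DataLoader(dataset, batch_size=64, num_workers=0)  # try raising num_workers to speed up loading
model = torch.nn.Linear(32, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    for x, y in loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()

# Dataloader rows dominating the table suggest raising num_workers or the batch size.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```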
---
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import os.path as osp
import random
from contextlib import contextmanager

@ -27,6 +28,12 @@ import torch
from omegaconf import DictConfig


def inside_slurm():
    """Check whether the python process was launched through slurm"""
    # TODO(rcadene): return False for interactive mode `--pty bash`
    return "SLURM_JOB_ID" in os.environ


def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
    """Given a string, return a torch.device with checks on whether the device is available."""
    match cfg_device:
@ -158,7 +165,15 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
        version_base="1.2",
    )
    cfg = hydra.compose(Path(config_path).stem, overrides)

    if cfg.eval.batch_size > cfg.eval.n_episodes:
        raise ValueError(
            "The eval batch size is greater than the number of eval episodes "
            f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
            f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
            "This might significantly slow down evaluation. To fix this, you should update your command "
            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
            f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
        )
    return cfg

@ -120,7 +120,7 @@ eval:
  # `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
  batch_size: 1
  # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
  use_async_envs: false
  use_async_envs: true

wandb:
  enable: false

@ -2,6 +2,11 @@
fps: 50

eval:
  # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
  # set it to false to avoid some problems of the aloha env
  use_async_envs: false

env:
  name: aloha
  task: AlohaInsertion-v0

@ -2,6 +2,11 @@
fps: 15

eval:
  # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
  # set it to false to avoid some problems of the aloha env
  use_async_envs: false

env:
  name: xarm
  task: XarmLift-v0

@ -70,7 +70,13 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.common.utils.utils import (
    get_safe_torch_device,
    init_hydra_config,
    init_logging,
    inside_slurm,
    set_global_seed,
)


def rollout(
@ -79,7 +85,6 @@ def rollout(
    seeds: list[int] | None = None,
    return_observations: bool = False,
    render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
    enable_progbar: bool = False,
) -> dict:
    """Run a batched policy rollout once through a batch of environments.

@ -109,7 +114,6 @@ def rollout(
            are returned optionally because they typically take more memory to cache. Defaults to False.
        render_callback: Optional rendering callback to be used after the environments are reset, and after
            every step.
        enable_progbar: Enable a progress bar over rollout steps.
    Returns:
        The dictionary described above.
    """
@ -136,7 +140,7 @@ def rollout(
    progbar = trange(
        max_steps,
        desc=f"Running rollout with at most {max_steps} steps",
        disable=not enable_progbar,
        disable=inside_slurm(),  # we don't want a progress bar when we use slurm, since it clutters the logs
        leave=False,
    )
    while not np.all(done):
@ -210,8 +214,6 @@ def eval_policy(
    videos_dir: Path | None = None,
    return_episode_data: bool = False,
    start_seed: int | None = None,
    enable_progbar: bool = False,
    enable_inner_progbar: bool = False,
) -> dict:
    """
    Args:
@ -224,8 +226,6 @@ def eval_policy(
            the "episodes" key of the returned dictionary.
        start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
            seed is incremented by 1. If not provided, the environments are not manually seeded.
        enable_progbar: Enable progress bar over batches.
        enable_inner_progbar: Enable progress bar over steps in each batch.
    Returns:
        Dictionary with metrics and data regarding the rollouts.
    """
@ -266,7 +266,8 @@ def eval_policy(
    if return_episode_data:
        episode_data: dict | None = None

    progbar = trange(n_batches, desc="Stepping through eval batches", disable=not enable_progbar)
    # we don't want a progress bar when we use slurm, since it clutters the logs
    progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
    for batch_ix in progbar:
        # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
        # step.
@ -285,7 +286,6 @@ def eval_policy(
            seeds=list(seeds) if seeds else None,
            return_observations=return_episode_data,
            render_callback=render_frame if max_episodes_rendered > 0 else None,
            enable_progbar=enable_inner_progbar,
        )

        # Figure out where in each rollout sequence the first done condition was encountered (results after
@ -487,8 +487,6 @@ def main(
            max_episodes_rendered=10,
            videos_dir=Path(out_dir) / "videos",
            start_seed=hydra_cfg.seed,
            enable_progbar=True,
            enable_inner_progbar=True,
        )
    print(info["aggregated"])

@ -241,6 +241,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        raise NotImplementedError()

    init_logging()
    logging.info(pformat(OmegaConf.to_container(cfg)))

    if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
        raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")