parent 7ec76ee235
commit e3b9f1c19b

Makefile (19 lines changed)
@@ -40,7 +40,7 @@ test-act-ete-train:
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	device=cpu \
-	training.save_model=true \
+	training.save_checkpoint=true \
 	training.save_freq=2 \
 	policy.n_action_steps=20 \
 	policy.chunk_size=20 \
@@ -49,7 +49,7 @@ test-act-ete-train:
 test-act-ete-eval:
 	python lerobot/scripts/eval.py \
-	-p tests/outputs/act/checkpoints/000002 \
+	-p tests/outputs/act/checkpoints/000002/pretrained_model \
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	env.episode_length=8 \
@@ -66,17 +66,17 @@ test-act-ete-train-amp:
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	device=cpu \
-	training.save_model=true \
+	training.save_checkpoint=true \
 	training.save_freq=2 \
 	policy.n_action_steps=20 \
 	policy.chunk_size=20 \
 	training.batch_size=2 \
-	hydra.run.dir=tests/outputs/act/ \
+	hydra.run.dir=tests/outputs/act_amp/ \
 	use_amp=true

 test-act-ete-eval-amp:
 	python lerobot/scripts/eval.py \
-	-p tests/outputs/act/checkpoints/000002 \
+	-p tests/outputs/act_amp/checkpoints/000002/pretrained_model \
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	env.episode_length=8 \
@@ -96,14 +96,14 @@ test-diffusion-ete-train:
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	device=cpu \
-	training.save_model=true \
+	training.save_checkpoint=true \
 	training.save_freq=2 \
 	training.batch_size=2 \
 	hydra.run.dir=tests/outputs/diffusion/

 test-diffusion-ete-eval:
 	python lerobot/scripts/eval.py \
-	-p tests/outputs/diffusion/checkpoints/000002 \
+	-p tests/outputs/diffusion/checkpoints/000002/pretrained_model \
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	env.episode_length=8 \
@@ -123,14 +123,14 @@ test-tdmpc-ete-train:
 	eval.batch_size=1 \
 	env.episode_length=2 \
 	device=cpu \
-	training.save_model=true \
+	training.save_checkpoint=true \
 	training.save_freq=2 \
 	training.batch_size=2 \
 	hydra.run.dir=tests/outputs/tdmpc/

 test-tdmpc-ete-eval:
 	python lerobot/scripts/eval.py \
-	-p tests/outputs/tdmpc/checkpoints/000002 \
+	-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	env.episode_length=8 \
@@ -144,7 +144,6 @@ test-default-ete-eval:
 	env.episode_length=8 \
 	device=cpu \
-

 test-act-pusht-tutorial:
 	cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
 	python lerobot/scripts/train.py \

README.md (21 lines changed)
@@ -154,9 +154,9 @@ python lerobot/scripts/eval.py \
 ```

 Note: After training your own policy, you can re-evaluate the checkpoints with:

 ```bash
-python lerobot/scripts/eval.py \
-    -p PATH/TO/TRAIN/OUTPUT/FOLDER
+python lerobot/scripts/eval.py -p {OUTPUT_DIR}/checkpoints/last/pretrained_model
 ```

 See `python lerobot/scripts/eval.py --help` for more instructions.

@@ -180,6 +180,19 @@ The experiment directory is automatically generated and will show up in yellow i
 hydra.run.dir=your/new/experiment/dir
 ```

+In the experiment directory there will be a folder called `checkpoints` which will have the following structure:
+
+```bash
+checkpoints
+├── 000250  # checkpoint_dir for training step 250
+│   ├── pretrained_model  # Hugging Face pretrained model dir
+│   │   ├── config.json  # Hugging Face pretrained model config
+│   │   ├── config.yaml  # consolidated Hydra config
+│   │   ├── model.safetensors  # model weights
+│   │   └── README.md  # Hugging Face model card
+│   └── training_state.pth  # optimizer/scheduler/rng state and training step
+```
+
 To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:

 ```bash

@@ -233,14 +246,14 @@ If your dataset format is not supported, implement your own in `lerobot/common/d

 Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).

-You first need to find the checkpoint located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). It should contain:
+You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
 - `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
 - `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
 - `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.

 To upload these to the hub, run the following:
 ```bash
-huggingface-cli upload ${hf_user}/${repo_name} path/to/checkpoint/dir
+huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
 ```

 See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.
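As a quick illustration of what a consumer of the uploaded folder can then do, here is a minimal sketch of loading such a policy from the hub. It assumes the policy class inherits `PyTorchModelHubMixin` (as LeRobot policies do) and uses the `lerobot/diffusion_pusht` hub id from the example above; treat the module path as an assumption if your checkout differs:

```python
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Downloads config.json and model.safetensors from the hub and
# instantiates the policy with the stored weights.
policy = DiffusionPolicy.from_pretrained("lerobot/diffusion_pusht")
policy.eval()
```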

@@ -165,7 +165,7 @@ Note: here we use regular syntax for providing CLI arguments to a Python script,
 As a concrete example, this becomes particularly handy when you have a folder with training outputs, and would like to re-run the training. For example, say you previously ran the training script with one of the earlier commands and have `outputs/train/my_experiment/checkpoints/pretrained_model/config.yaml`. This `config.yaml` file will have the full set of configuration parameters within it. To run the training with the same configuration again, do:

 ```bash
-python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpoints/pretrained_model --config-name config
+python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpoints/last/pretrained_model --config-name config
 ```

 Note that you may still use the regular syntax for config parameter overrides (eg: by adding `training.offline_steps=200000`).

@@ -0,0 +1,37 @@
+This tutorial explains how to resume a training run that you've started with the training script. If you don't know how our training script and configuration system works, please read [4_train_policy_with_script.md](./4_train_policy_with_script.md) first.
+
+## Basic training resumption
+
+Let's consider the example of training ACT for one of the ALOHA tasks. Here's a command that can achieve that:
+
+```bash
+python lerobot/scripts/train.py \
+    hydra.run.dir=outputs/train/run_resumption \
+    policy=act \
+    dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
+    env=aloha \
+    env.task=AlohaTransferCube-v0 \
+    training.log_freq=25 \
+    training.save_checkpoint=true \
+    training.save_freq=100
+```
+
+Here we're using the default dataset and environment for ACT, and we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can test resumption. You should be able to see some logging and have a first checkpoint within 1 minute. Please interrupt the training after the first checkpoint.
+
+To resume, all that we have to do is run the training script, providing the run directory, and the resume option:
+
+```bash
+python lerobot/scripts/train.py \
+    hydra.run.dir=outputs/train/run_resumption \
+    resume=true
+```
+
+You should see from the logging that your training picks up from where it left off.
+
+Note that with `resume=true`, the configuration file from the last checkpoint in the training output directory is loaded. So it doesn't matter that we haven't provided all the other configuration parameters from our previous command (although there may be warnings to notify you that your command has a different configuration than the checkpoint).
+
+---
+
+Now you should know how to resume your training run in case it gets interrupted or you want to extend a finished training run.
+
+Happy coding! 🤗
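To sanity-check that resumption really picked up from the saved step, one can inspect the training state file directly. A minimal sketch, assuming the run directory used in the tutorial above:

```python
import torch

# "last" is a symlink maintained by the Logger, pointing at the newest checkpoint.
state = torch.load("outputs/train/run_resumption/checkpoints/last/training_state.pth")
print(state["step"])   # global training step the checkpoint was saved at
print(state.keys())    # also holds optimizer, scheduler, and RNG states
```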

@@ -21,6 +21,20 @@ from omegaconf import OmegaConf
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


+def resolve_delta_timestamps(cfg):
+    """Resolves delta_timestamps config key (in-place) by using `eval`.
+
+    Doesn't do anything if delta_timestamps is not specified or has already been resolved (as evidenced by
+    the data type of its values).
+    """
+    delta_timestamps = cfg.training.get("delta_timestamps")
+    if delta_timestamps is not None:
+        for key in delta_timestamps:
+            if isinstance(delta_timestamps[key], str):
+                # TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
+                cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
+
+
 def make_dataset(
     cfg,
     split="train",

@@ -31,18 +45,14 @@ def make_dataset(
             f"environment ({cfg.env.name=})."
         )

-    delta_timestamps = cfg.training.get("delta_timestamps")
-    if delta_timestamps is not None:
-        for key in delta_timestamps:
-            if isinstance(delta_timestamps[key], str):
-                delta_timestamps[key] = eval(delta_timestamps[key])
+    resolve_delta_timestamps(cfg)

     # TODO(rcadene): add data augmentations

     dataset = LeRobotDataset(
         cfg.dataset_repo_id,
         split=split,
-        delta_timestamps=delta_timestamps,
+        delta_timestamps=cfg.training.get("delta_timestamps"),
     )

     if cfg.get("override_dataset_stats"):
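The new `resolve_delta_timestamps` helper simply `eval`s any string-valued entries in place, so list-comprehension expressions from YAML configs become concrete lists of floats. A minimal sketch (the config values here are hypothetical):

```python
from omegaconf import OmegaConf

from lerobot.common.datasets.factory import resolve_delta_timestamps

cfg = OmegaConf.create(
    {"training": {"delta_timestamps": {"action": "[i / 10 for i in range(4)]"}}}
)
resolve_delta_timestamps(cfg)
print(cfg.training.delta_timestamps["action"])  # [0.0, 0.1, 0.2, 0.3]
```

Calling it a second time is a no-op, since the values are no longer strings.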
@@ -13,25 +13,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
+"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
+
+# TODO(rcadene, alexander-soare): clean this file
+"""

 import logging
 import os
+import re
+from glob import glob
 from pathlib import Path

 import torch
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler

 from lerobot.common.policies.policy_protocol import Policy
+from lerobot.common.utils.utils import get_global_random_state, set_global_random_state


 def log_output_dir(out_dir):
     logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")


-def cfg_to_group(cfg, return_list=False):
+def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
     """Return a group name for logging. Optionally returns group name as list."""
     lst = [
         f"policy:{cfg.policy.name}",
@@ -42,22 +50,54 @@ def cfg_to_group(cfg, return_list=False):
     return lst if return_list else "-".join(lst)


-class Logger:
-    """Primary logger object. Logs either locally or using wandb."""
+def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
+    # Get the WandB run ID.
+    paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
+    if len(paths) != 1:
+        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
+    match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
+    if match is None:
+        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
+    wandb_run_id = match.groups(0)[0]
+    return wandb_run_id

-    def __init__(self, log_dir, job_name, cfg):
-        self._log_dir = Path(log_dir)
-        self._log_dir.mkdir(parents=True, exist_ok=True)
-        self._job_name = job_name
-        self._model_dir = self._log_dir / "checkpoints"
-        self._buffer_dir = self._log_dir / "buffers"
-        self._save_model = cfg.training.save_model
-        self._disable_wandb_artifact = cfg.wandb.disable_artifact
-        self._save_buffer = cfg.training.get("save_buffer", False)
-        self._group = cfg_to_group(cfg)
-        self._seed = cfg.seed

+class Logger:
+    """Primary logger object. Logs either locally or using wandb.
+
+    The logger creates the following directory structure:
+
+    provided_log_dir
+    ├── .hydra  # hydra's configuration cache
+    ├── checkpoints
+    │   ├── specific_checkpoint_name
+    │   │   ├── pretrained_model  # Hugging Face pretrained model directory
+    │   │   │   ├── ...
+    │   │   └── training_state.pth  # optimizer, scheduler, and random states + training step
+    │   ├── another_specific_checkpoint_name
+    │   │   ├── ...
+    │   ├── ...
+    │   └── last  # a symlink to the last logged checkpoint
+    """
+
+    pretrained_model_dir_name = "pretrained_model"
+    training_state_file_name = "training_state.pth"
+
+    def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
+        """
+        Args:
+            log_dir: The directory to save all logs and training outputs to.
+            wandb_job_name: The WandB job name.
+        """
+        self._cfg = cfg
+        self._eval = []
+        self.log_dir = Path(log_dir)
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+        self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
+        self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
+        self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
+
         # Set up WandB.
         self._group = cfg_to_group(cfg)
         project = cfg.get("wandb", {}).get("project")
         entity = cfg.get("wandb", {}).get("entity")
         enable_wandb = cfg.get("wandb", {}).get("enable", False)
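The run ID that `get_wandb_run_id_from_filesystem` recovers comes from the name of the `.wandb` file that the WandB client leaves under `wandb/latest-run/`. A minimal sketch of the extraction (the run ID here is made up):

```python
import re

filename = "run-1a2b3c4d.wandb"  # hypothetical file under wandb/latest-run/
match = re.search(r"run-([^\.]+).wandb", filename)
assert match is not None
print(match.groups(0)[0])  # "1a2b3c4d"
```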
@@ -69,65 +109,127 @@ class Logger:
             os.environ["WANDB_SILENT"] = "true"
             import wandb

+            wandb_run_id = None
+            if cfg.resume:
+                wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
+
             wandb.init(
+                id=wandb_run_id,
                 project=project,
                 entity=entity,
-                name=job_name,
+                name=wandb_job_name,
                 notes=cfg.get("wandb", {}).get("notes"),
                 # group=self._group,
                 tags=cfg_to_group(cfg, return_list=True),
-                dir=self._log_dir,
+                dir=log_dir,
                 config=OmegaConf.to_container(cfg, resolve=True),
                 # TODO(rcadene): try set to True
                 save_code=False,
                 # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
                 job_type="train_eval",
-                # TODO(rcadene): add resume option
-                resume=None,
+                resume="must" if cfg.resume else None,
             )
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb

-    def save_model(self, policy: Policy, identifier):
-        if self._save_model:
-            self._model_dir.mkdir(parents=True, exist_ok=True)
-            save_dir = self._model_dir / str(identifier)
-            policy.save_pretrained(save_dir)
-            # Also save the full Hydra config for the env configuration.
-            OmegaConf.save(self._cfg, save_dir / "config.yaml")
-            if self._wandb and not self._disable_wandb_artifact:
-                # note wandb artifact does not accept ":" or "/" in its name
-                artifact = self._wandb.Artifact(
-                    f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-                    type="model",
-                )
-                artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
-                self._wandb.log_artifact(artifact)
-
-    def save_buffer(self, buffer, identifier):
-        self._buffer_dir.mkdir(parents=True, exist_ok=True)
-        fp = self._buffer_dir / f"{str(identifier)}.pkl"
-        buffer.save(fp)
-        if self._wandb and not self._disable_wandb_artifact:
-            # note wandb artifact does not accept ":" or "/" in its name
-            artifact = self._wandb.Artifact(
-                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-                type="buffer",
-            )
-            artifact.add_file(fp)
-            self._wandb.log_artifact(artifact)
-
-    def finish(self, agent, buffer):
-        if self._save_model:
-            self.save_model(agent, identifier="final")
-        if self._save_buffer:
-            self.save_buffer(buffer, identifier="buffer")
-        if self._wandb:
-            self._wandb.finish()
+    @classmethod
+    def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which checkpoints will be saved."""
+        return Path(log_dir) / "checkpoints"
+
+    @classmethod
+    def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
+        return cls.get_checkpoints_dir(log_dir) / "last"
+
+    @classmethod
+    def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
+        """
+        Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
+        be saved.
+        """
+        return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
+
+    def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
+        """Save the weights of the Policy model using PyTorchModelHubMixin.
+
+        The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
+
+        Optionally also upload the model to WandB.
+        """
+        self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
+        policy.save_pretrained(save_dir)
+        # Also save the full Hydra config for the env configuration.
+        OmegaConf.save(self._cfg, save_dir / "config.yaml")
+        if self._wandb and not self._cfg.wandb.disable_artifact:
+            # note wandb artifact does not accept ":" or "/" in its name
+            artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
+            artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
+            self._wandb.log_artifact(artifact)
+        if self.last_checkpoint_dir.exists():
+            os.remove(self.last_checkpoint_dir)
+
+    def save_training_state(
+        self,
+        save_dir: Path,
+        train_step: int,
+        optimizer: Optimizer,
+        scheduler: LRScheduler | None,
+    ):
+        """Checkpoint the global training_step, optimizer state, scheduler state, and random state.
+
+        All of these are saved as "training_state.pth" under the checkpoint directory.
+        """
+        training_state = {
+            "step": train_step,
+            "optimizer": optimizer.state_dict(),
+            **get_global_random_state(),
+        }
+        if scheduler is not None:
+            training_state["scheduler"] = scheduler.state_dict()
+        torch.save(training_state, save_dir / self.training_state_file_name)
+
+    def save_checkpoint(
+        self,
+        train_step: int,
+        policy: Policy,
+        optimizer: Optimizer,
+        scheduler: LRScheduler | None,
+        identifier: str,
+    ):
+        """Checkpoint the model weights and the training state."""
+        checkpoint_dir = self.checkpoints_dir / str(identifier)
+        wandb_artifact_name = (
+            None
+            if self._wandb is None
+            else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
+        )
+        self.save_model(
+            checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
+        )
+        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
+        os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
+
+    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
+        """
+        Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
+        random state, and return the global training step.
+        """
+        training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
+        optimizer.load_state_dict(training_state["optimizer"])
+        if scheduler is not None:
+            scheduler.load_state_dict(training_state["scheduler"])
+        elif "scheduler" in training_state:
+            raise ValueError(
+                "The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
+            )
+        # Small hack to get the expected keys: use `get_global_random_state`.
+        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+        return training_state["step"]

     def log_dict(self, d, step, mode="train"):
         assert mode in {"train", "eval"}
         # TODO(alexander-soare): Add local text log.
         if self._wandb is not None:
             for k, v in d.items():
                 if not isinstance(v, (int, float, str)):
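Because the checkpoint locations are exposed as classmethods, other code (like the resume check in `train.py` further down) can compute them without constructing a `Logger`. A small sketch with a hypothetical log directory:

```python
from lerobot.common.logger import Logger

log_dir = "outputs/train/run_resumption"  # hypothetical
print(Logger.get_checkpoints_dir(log_dir))            # outputs/train/run_resumption/checkpoints
print(Logger.get_last_checkpoint_dir(log_dir))        # .../checkpoints/last
print(Logger.get_last_pretrained_model_dir(log_dir))  # .../checkpoints/last/pretrained_model
```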
@@ -19,7 +19,7 @@ import random
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Generator
+from typing import Any, Generator

 import hydra
 import numpy as np

@@ -48,12 +48,38 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
     return device


+def get_global_random_state() -> dict[str, Any]:
+    """Get the random state for `random`, `numpy`, and `torch`."""
+    random_state_dict = {
+        "random_state": random.getstate(),
+        "numpy_random_state": np.random.get_state(),
+        "torch_random_state": torch.random.get_rng_state(),
+    }
+    if torch.cuda.is_available():
+        random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
+    return random_state_dict
+
+
+def set_global_random_state(random_state_dict: dict[str, Any]):
+    """Set the random state for `random`, `numpy`, and `torch`.
+
+    Args:
+        random_state_dict: A dictionary of the form returned by `get_global_random_state`.
+    """
+    random.setstate(random_state_dict["random_state"])
+    np.random.set_state(random_state_dict["numpy_random_state"])
+    torch.random.set_rng_state(random_state_dict["torch_random_state"])
+    if torch.cuda.is_available():
+        torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
+
+
 def set_global_seed(seed):
     """Set seed for reproducibility."""
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)


 @contextmanager

@@ -69,16 +95,10 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
     c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
     ```
     """
-    random_state = random.getstate()
-    np_random_state = np.random.get_state()
-    torch_random_state = torch.random.get_rng_state()
-    torch_cuda_random_state = torch.cuda.random.get_rng_state()
+    random_state_dict = get_global_random_state()
     set_global_seed(seed)
     yield None
-    random.setstate(random_state)
-    np.random.set_state(np_random_state)
-    torch.random.set_rng_state(torch_random_state)
-    torch.cuda.random.set_rng_state(torch_cuda_random_state)
+    set_global_random_state(random_state_dict)


 def init_logging():
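The refactored `seeded_context` now snapshots and restores all RNGs through the new helpers, so a temporarily seeded block leaves the surrounding random stream untouched. Usage, following the docstring shown above:

```python
import random

from lerobot.common.utils.utils import seeded_context, set_global_seed

set_global_seed(0)
a = random.random()
with seeded_context(1337):
    x = random.random()  # determined by seed 1337, reproducible across runs
b = random.random()  # the same value it would have been had the block never run
```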

@@ -5,10 +5,17 @@ defaults:

 hydra:
   run:
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
     dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
   job:
     name: default

+# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
+# `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
+# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
+# regardless of what's provided with the training command at the time of resumption.
+resume: false
 device: cuda  # cpu
 # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
 # automatic gradient scaling is used.

@@ -29,7 +36,7 @@ training:
   eval_freq: ???
   save_freq: ???
   log_freq: 250
-  save_model: true
+  save_checkpoint: true

 eval:
   n_episodes: 1

@@ -40,7 +47,7 @@ eval:

 wandb:
   enable: false
-  # Set to true to disable saving an artifact despite save_model == True
+  # Set to true to disable saving an artifact despite save_checkpoint == True
   disable_artifact: false
   project: lerobot
   notes: ""
@@ -15,7 +15,7 @@ training:
   eval_freq: 10000
   save_freq: 100000
   log_freq: 250
-  save_model: true
+  save_checkpoint: true

   batch_size: 8
   lr: 1e-5

@@ -27,7 +27,7 @@ training:
   eval_freq: 5000
   save_freq: 5000
   log_freq: 250
-  save_model: true
+  save_checkpoint: true

   batch_size: 64
   grad_clip_norm: 10

@@ -28,7 +28,7 @@ OR, you want to evaluate a model checkpoint from the LeRobot training script for

 ```
 python lerobot/scripts/eval.py \
-    -p outputs/train/diffusion_pusht/checkpoints/005000 \
+    -p outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
     eval.n_episodes=10
 ```

@@ -18,13 +18,16 @@ import time
 from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path
+from pprint import pformat

 import hydra
 import torch
-from omegaconf import DictConfig
+from deepdiff import DeepDiff
+from omegaconf import DictConfig, OmegaConf
+from termcolor import colored
 from torch.cuda.amp import GradScaler

-from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
@@ -34,6 +37,7 @@ from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
+    init_hydra_config,
     init_logging,
     set_global_seed,
 )
@@ -140,24 +144,6 @@ def update_policy(
     return info


-@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
-def train_cli(cfg: dict):
-    train(
-        cfg,
-        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
-        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
-    )
-
-
-def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
-    from hydra import compose, initialize
-
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    initialize(config_path=config_path)
-    cfg = compose(config_name=config_name)
-    train(cfg, out_dir=out_dir, job_name=job_name)
-
-
 def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
@@ -237,15 +223,60 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):

     init_logging()

+    # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
+    # to check for any differences between the provided config and the checkpoint's config.
+    if cfg.resume:
+        if not Logger.get_last_checkpoint_dir(out_dir).exists():
+            raise RuntimeError(
+                "You have set resume=True, but there is no model checkpoint in "
+                f"{Logger.get_last_checkpoint_dir(out_dir)}"
+            )
+        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+        logging.info(
+            colored(
+                "You have set resume=True, indicating that you wish to resume a run",
+                color="yellow",
+                attrs=["bold"],
+            )
+        )
+        # Get the configuration file from the last checkpoint.
+        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
+        # Check for differences between the checkpoint configuration and provided configuration.
+        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
+        resolve_delta_timestamps(cfg)
+        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+        # Ignore the `resume` parameter.
+        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
+            del diff["values_changed"]["root['resume']"]
+        # Log a warning about differences between the checkpoint configuration and the provided
+        # configuration.
+        if len(diff) > 0:
+            logging.warning(
+                "At least one difference was detected between the checkpoint configuration and "
+                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
+                "takes precedence.",
+            )
+        # Use the checkpoint config instead of the provided config (but keep `resume` parameter).
+        cfg = checkpoint_cfg
+        cfg.resume = True
+    elif Logger.get_last_checkpoint_dir(out_dir).exists():
+        raise RuntimeError(
+            f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists."
+        )
+
+    # log metrics to terminal and wandb
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
+
     if cfg.training.online_steps > 0:
         raise NotImplementedError("Online training is not implemented yet.")

+    set_global_seed(cfg.seed)
+
     # Check device is available
     device = get_safe_torch_device(cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
-    set_global_seed(cfg.seed)

     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
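The config comparison relies on `deepdiff.DeepDiff` over the plain-container versions of the two configs; keys in the result use the `root['…']` path syntax that the code above deletes for the `resume` entry. A minimal sketch with hypothetical configs:

```python
from deepdiff import DeepDiff

checkpoint_cfg = {"resume": False, "training": {"lr": 1e-5}}
provided_cfg = {"resume": True, "training": {"lr": 1e-4}}

diff = DeepDiff(checkpoint_cfg, provided_cfg)
# {'values_changed': {"root['resume']": {'new_value': True, 'old_value': False},
#                     "root['training']['lr']": {'new_value': 0.0001, 'old_value': 1e-05}}}
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
    del diff["values_changed"]["root['resume']"]  # only the lr difference remains
```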
@@ -254,19 +285,25 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     eval_env = make_env(cfg)

     logging.info("make_policy")
-    policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
+    policy = make_policy(
+        hydra_cfg=cfg,
+        dataset_stats=offline_dataset.stats if not cfg.resume else None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
+    )

     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
     grad_scaler = GradScaler(enabled=cfg.use_amp)

+    step = 0  # number of policy updates (forward + backward + optim)
+
+    if cfg.resume:
+        step = logger.load_last_training_state(optimizer, lr_scheduler)
+
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

-    # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
-
     log_output_dir(out_dir)
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
@@ -294,12 +331,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
             logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")

-        if cfg.training.save_model and step % cfg.training.save_freq == 0:
+        if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
             logging.info(f"Checkpoint policy after step {step}")
             # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
             # needed (choose 6 as a minimum for consistency without being overkill).
-            logger.save_model(
+            logger.save_checkpoint(
+                step,
                 policy,
+                optimizer,
+                lr_scheduler,
                 identifier=str(step).zfill(
                     max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
                 ),
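The checkpoint identifier is just the step number left-padded with zeros, wide enough for the total planned steps but never narrower than six digits:

```python
offline_steps, online_steps = 100_000, 0  # hypothetical training budget
step = 250
identifier = str(step).zfill(max(6, len(str(offline_steps + online_steps))))
print(identifier)  # "000250" -- matching directories like checkpoints/000250
```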
@@ -319,7 +359,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):

     policy.train()
     is_offline = True
-    for step in range(cfg.training.offline_steps):
+    for _ in range(step, cfg.training.offline_steps):
         if step == 0:
             logging.info("Start offline training on a fixed dataset")
         batch = next(dl_iter)
@@ -337,7 +377,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
             use_amp=cfg.use_amp,
         )

-        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.training.log_freq == 0:
             log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)

@@ -345,6 +384,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         # so we pass in step + 1.
         evaluate_and_checkpoint_if_needed(step + 1)

+        step += 1
+
     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
     online_dataset.hf_dataset = {}
@@ -369,5 +410,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     logging.info("End of training")


+@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
 if __name__ == "__main__":
     train_cli()

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

 [[package]]
 name = "absl-py"

@@ -595,6 +595,24 @@ files = [
     {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
 ]

+[[package]]
+name = "deepdiff"
+version = "7.0.1"
+description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"},
+    {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"},
+]
+
+[package.dependencies]
+ordered-set = ">=4.1.0,<4.2.0"
+
+[package.extras]
+cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"]
+optimize = ["orjson"]
+
 [[package]]
 name = "diffusers"
 version = "0.27.2"

@@ -1251,13 +1269,13 @@ files = [

 [[package]]
 name = "huggingface-hub"
-version = "0.23.0"
+version = "0.23.1"
 description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
 optional = false
 python-versions = ">=3.8.0"
 files = [
-    {file = "huggingface_hub-0.23.0-py3-none-any.whl", hash = "sha256:075c30d48ee7db2bba779190dc526d2c11d422aed6f9044c5e2fdc2c432fdb91"},
-    {file = "huggingface_hub-0.23.0.tar.gz", hash = "sha256:7126dedd10a4c6fac796ced4d87a8cf004efc722a5125c2c09299017fa366fa9"},
+    {file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"},
+    {file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"},
 ]

 [package.dependencies]

@@ -1655,9 +1673,13 @@ files = [
     {file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"},
     {file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"},
+    {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"},
+    {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"},
     {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"},
+    {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"},
+    {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"},
     {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"},
     {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"},
     {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"},
     {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"},
     {file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"},
     {file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"},
@@ -2292,13 +2314,13 @@ files = [

 [[package]]
 name = "nvidia-nvjitlink-cu12"
-version = "12.4.127"
+version = "12.5.40"
 description = "Nvidia JIT LTO Library"
 optional = false
 python-versions = ">=3"
 files = [
-    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
-    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
+    {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"},
+    {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"},
 ]

 [[package]]

@@ -2351,6 +2373,20 @@ numpy = [
     {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
 ]

+[[package]]
+name = "ordered-set"
+version = "4.1.0"
+description = "An OrderedSet is a custom MutableSet that remembers its order, so that every"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"},
+    {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"},
+]
+
+[package.extras]
+dev = ["black", "mypy", "pytest"]
+
 [[package]]
 name = "packaging"
 version = "24.0"

@@ -2941,13 +2977,13 @@ files = [

 [[package]]
 name = "pytest"
-version = "8.2.0"
+version = "8.2.1"
 description = "pytest: simple powerful testing with Python"
 optional = true
 python-versions = ">=3.8"
 files = [
-    {file = "pytest-8.2.0-py3-none-any.whl", hash = "sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233"},
-    {file = "pytest-8.2.0.tar.gz", hash = "sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f"},
+    {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"},
+    {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"},
 ]

 [package.dependencies]

@@ -3029,6 +3065,7 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -3153,13 +3190,13 @@ files = [

 [[package]]
 name = "requests"
-version = "2.31.0"
+version = "2.32.2"
 description = "Python HTTP for Humans."
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
-    {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
+    {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
+    {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
 ]

 [package.dependencies]

@@ -3364,36 +3401,36 @@ test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=7.0)", "pytest-

 [[package]]
 name = "scipy"
-version = "1.13.0"
+version = "1.13.1"
 description = "Fundamental algorithms for scientific computing in Python"
 optional = true
 python-versions = ">=3.9"
 files = [
-    {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
-    {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
-    {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
-    {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
-    {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
-    {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
-    {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
-    {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
-    {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
-    {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
-    {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
-    {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
-    {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
-    {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
-    {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
-    {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
-    {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
-    {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
-    {file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"},
-    {file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"},
-    {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"},
-    {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"},
-    {file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"},
-    {file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"},
-    {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
+    {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"},
+    {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"},
+    {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"},
+    {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"},
+    {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"},
+    {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"},
+    {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"},
+    {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"},
+    {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"},
+    {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"},
+    {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"},
+    {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"},
+    {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"},
+    {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"},
+    {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"},
+    {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"},
+    {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"},
+    {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"},
+    {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"},
+    {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"},
+    {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"},
+    {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"},
+    {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"},
+    {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"},
+    {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"},
 ]

 [package.dependencies]

@@ -3406,13 +3443,13 @@ test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "po

 [[package]]
 name = "sentry-sdk"
-version = "2.2.0"
+version = "2.3.1"
 description = "Python client for Sentry (https://sentry.io)"
 optional = false
 python-versions = ">=3.6"
 files = [
-    {file = "sentry_sdk-2.2.0-py2.py3-none-any.whl", hash = "sha256:674f58da37835ea7447fe0e34c57b4a4277fad558b0a7cb4a6c83bcb263086be"},
-    {file = "sentry_sdk-2.2.0.tar.gz", hash = "sha256:70eca103cf4c6302365a9d7cf522e7ed7720828910eb23d43ada8e50d1ecda9d"},
+    {file = "sentry_sdk-2.3.1-py2.py3-none-any.whl", hash = "sha256:c5aeb095ba226391d337dd42a6f9470d86c9fc236ecc71cfc7cd1942b45010c6"},
+    {file = "sentry_sdk-2.3.1.tar.gz", hash = "sha256:139a71a19f5e9eb5d3623942491ce03cf8ebc14ea2e39ba3e6fe79560d8a5b1f"},
 ]

 [package.dependencies]
@@ -3434,7 +3471,7 @@ django = ["django (>=1.8)"]
 falcon = ["falcon (>=1.4)"]
 fastapi = ["fastapi (>=0.79.0)"]
 flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"]
-grpcio = ["grpcio (>=1.21.1)"]
+grpcio = ["grpcio (>=1.21.1)", "protobuf (>=3.8.0)"]
 httpx = ["httpx (>=0.16.0)"]
 huey = ["huey (>=2)"]
 huggingface-hub = ["huggingface-hub (>=0.22)"]

@@ -3556,19 +3593,18 @@ test = ["pytest"]

 [[package]]
 name = "setuptools"
-version = "69.5.1"
+version = "70.0.0"
 description = "Easily download, build, install, upgrade, and uninstall Python packages"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "setuptools-69.5.1-py3-none-any.whl", hash = "sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32"},
-    {file = "setuptools-69.5.1.tar.gz", hash = "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987"},
+    {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"},
+    {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"},
 ]

 [package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
-testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
-testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
+testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]

 [[package]]
 name = "shapely"

@@ -3703,13 +3739,13 @@ tests = ["pytest", "pytest-cov"]

 [[package]]
 name = "tifffile"
-version = "2024.5.10"
+version = "2024.5.22"
 description = "Read and write TIFF files"
 optional = true
 python-versions = ">=3.9"
 files = [
-    {file = "tifffile-2024.5.10-py3-none-any.whl", hash = "sha256:4154f091aa24d4e75bfad9ab2d5424a68c70e67b8220188066dc61946d4551bd"},
-    {file = "tifffile-2024.5.10.tar.gz", hash = "sha256:aa1e1b12be952ab20717d6848bd6d4a5ee88d2aa319f1152bff4354ad728ec86"},
+    {file = "tifffile-2024.5.22-py3-none-any.whl", hash = "sha256:e281781c15d7d197d7e12749849c965651413aa905f97a48b0f84bd90a3b4c6f"},
+    {file = "tifffile-2024.5.22.tar.gz", hash = "sha256:3a105801d1b86d55692a98812a170c39d3f0447aeacb1d94635d38077cb328c4"},
 ]

 [package.dependencies]

@@ -3865,13 +3901,13 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"]

 [[package]]
 name = "typing-extensions"
-version = "4.11.0"
+version = "4.12.0"
 description = "Backported and Experimental Type Hints for Python 3.8+"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
-    {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
+    {file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"},
+    {file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"},
 ]

 [[package]]

@@ -4231,4 +4267,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "c3044329cfad91ffd91b411e85f16d8dfdcdfd7b9186d38fff5e18f4ee647e7b"
+content-hash = "1ad6ef0f88f0056ab639e60e033e586f7460a9c5fc3676a477bbd47923f41cb6"

@@ -58,6 +58,7 @@ imagecodecs = { version = ">=2024.1.1", optional = true }
 pyav = ">=12.0.5"
 moviepy = ">=1.0.3"
 rerun-sdk = ">=0.15.1"
+deepdiff = ">=7.0.1"


 [tool.poetry.extras]

@@ -11,22 +11,24 @@ from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
     reset_episode_index,
 )
-from lerobot.common.utils.utils import seeded_context, set_global_seed
+from lerobot.common.utils.utils import (
+    get_global_random_state,
+    seeded_context,
+    set_global_random_state,
+    set_global_seed,
+)

+# Random generation functions for testing the seeding and random state get/set.
+rand_fns = [
+    random.random,
+    np.random.random,
+    lambda: torch.rand(1).item(),
+]
+if torch.cuda.is_available():
+    rand_fns.append(lambda: torch.rand(1, device="cuda"))

-@pytest.mark.parametrize(
-    "rand_fn",
-    (
-        [
-            random.random,
-            np.random.random,
-            lambda: torch.rand(1).item(),
-        ]
-        + [lambda: torch.rand(1, device="cuda")]
-        if torch.cuda.is_available()
-        else []
-    ),
-)
+
+@pytest.mark.parametrize("rand_fn", rand_fns)
 def test_seeding(rand_fn: Callable[[], int]):
     set_global_seed(0)
     a = rand_fn()

@@ -46,6 +48,15 @@ def test_seeding(rand_fn: Callable[[], int]):
     assert c_ == c


+def test_get_set_random_state():
+    """Check that getting the random state, then setting it results in the same random number generation."""
+    random_state_dict = get_global_random_state()
+    rand_numbers = [rand_fn() for rand_fn in rand_fns]
+    set_global_random_state(random_state_dict)
+    rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
+    assert rand_numbers_ == rand_numbers
+
+
 def test_calculate_episode_data_index():
     dataset = Dataset.from_dict(
         {