Merge branch 'main' into aloha_hd5_to_dataset_v2

commit f7b84fa2cd
Claudio Coppola, 2025-02-12 09:15:33 +00:00, committed by GitHub
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
409 changed files with 1515 additions and 5153 deletions

View File

@ -4,8 +4,6 @@ on:
workflow_dispatch:
workflow_call:
pull_request:
branches:
- main
push:
branches:
- main

View File

@ -4,8 +4,6 @@ name: Test Dockerfiles
on:
pull_request:
branches:
- main
paths:
# Run only when DockerFile files are modified
- "docker/**"

View File

@ -2,8 +2,6 @@ name: Tests
on:
pull_request:
branches:
- main
paths:
- "lerobot/**"
- "tests/**"

View File

@ -39,8 +39,8 @@ test-act-ete-train:
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
--offline.steps=4 \
--online.steps=0 \
--steps=4 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_freq=2 \
@ -76,8 +76,8 @@ test-diffusion-ete-train:
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
--offline.steps=2 \
--online.steps=0 \
--steps=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_checkpoint=true \
@ -106,8 +106,8 @@ test-tdmpc-ete-train:
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
--offline.steps=2 \
--online.steps=0 \
--steps=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_checkpoint=true \
@ -126,30 +126,3 @@ test-tdmpc-ete-eval:
--eval.n_episodes=1 \
--eval.batch_size=1 \
--device=$(DEVICE)
# TODO(rcadene): fix online buffer to storing "task"
# test-tdmpc-ete-train-with-online:
# python lerobot/scripts/train.py \
# --policy.type=tdmpc \
# --env.type=pusht \
# --env.obs_type=environment_state_agent_pos \
# --env.episode_length=5 \
# --dataset.repo_id=lerobot/pusht_keypoints \
# --dataset.image_transforms.enable=true \
# --dataset.episodes="[0]" \
# --batch_size=2 \
# --offline.steps=2 \
# --online.steps=20 \
# --online.rollout_n_episodes=2 \
# --online.rollout_batch_size=2 \
# --online.steps_between_rollouts=10 \
# --online.buffer_capacity=1000 \
# --online.env_seed=10000 \
# --save_checkpoint=false \
# --save_freq=10 \
# --log_freq=1 \
# --eval.use_async_envs=true \
# --eval.n_episodes=1 \
# --eval.batch_size=1 \
# --device=$(DEVICE) \
# --output_dir=tests/outputs/tdmpc_online/

View File

@ -86,8 +86,7 @@ def main():
while not done:
for batch in dataloader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
output_dict = policy.forward(batch)
loss = output_dict["loss"]
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()

View File

@ -161,13 +161,13 @@ python lerobot/scripts/train.py \
```
You should see from the logging that your training picks up from where it left off.
Another reason you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--offline.steps`, which defaults to 100 000.
Another reason you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which defaults to 100 000.
You could double the number of steps of the previous run with:
```bash
python lerobot/scripts/train.py \
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
--resume=true \
--offline.steps=200000
--steps=200000
```
## Outputs of a run
@ -175,12 +175,16 @@ In the output directory, there will be a folder called `checkpoints` with the fo
```bash
outputs/train/run_resumption/checkpoints
├── 000100 # checkpoint_dir for training step 100
│   ├── pretrained_model
│   │   ├── config.json # pretrained policy config
│   │   ├── model.safetensors # model weights
│   │   ├── train_config.json # train config
│   │   └── README.md # model card
│   └── training_state.pth # optimizer/scheduler/rng state and training step
│   ├── pretrained_model/
│   │   ├── config.json # policy config
│   │   ├── model.safetensors # policy weights
│   │   └── train_config.json # train config
│   └── training_state/
│       ├── optimizer_param_groups.json # optimizer param groups
│       ├── optimizer_state.safetensors # optimizer state
│       ├── rng_state.safetensors # rng states
│       ├── scheduler_state.json # scheduler state
│       └── training_step.json # training step
├── 000200
└── last -> 000200 # symlink to the last available checkpoint
```
@ -250,7 +254,7 @@ python lerobot/scripts/train.py \
python lerobot/scripts/train.py \
--config_path=checkpoint/pretrained_model/ \
--resume=true \
--offline.steps=200000 # <- you can change some training parameters
--steps=200000 # <- you can change some training parameters
```
#### Fine-tuning

View File

@ -75,9 +75,9 @@ def main():
n_examples_evaluated = 0
for batch in val_dataloader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
output_dict = policy.forward(batch)
loss, _ = policy.forward(batch)
loss_cumsum += output_dict["loss"].item()
loss_cumsum += loss.item()
n_examples_evaluated += batch["index"].shape[0]
# Calculate the average loss over the validation set.

View File

@ -4,3 +4,14 @@ OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"
# files & directories
CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"
PRETRAINED_MODEL_DIR = "pretrained_model"
TRAINING_STATE_DIR = "training_state"
RNG_STATE = "rng_state.safetensors"
TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
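These constants pin down the on-disk checkpoint layout used throughout the refactor. A small sketch of how they compose into paths (the output directory is hypothetical):
```python
from pathlib import Path

from lerobot.common.constants import (
    CHECKPOINTS_DIR,
    LAST_CHECKPOINT_LINK,
    OPTIMIZER_STATE,
    PRETRAINED_MODEL_DIR,
    TRAINING_STATE_DIR,
)

output_dir = Path("outputs/train/run_resumption")  # hypothetical output directory
step_dir = output_dir / CHECKPOINTS_DIR / "000100"

print(step_dir / PRETRAINED_MODEL_DIR / "model.safetensors")
print(step_dir / TRAINING_STATE_DIR / OPTIMIZER_STATE)
print(output_dir / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK)  # symlink to the latest checkpoint
```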

View File

@ -1,240 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
# TODO(rcadene, alexander-soare): clean this file
"""
import logging
import os
import re
from dataclasses import asdict
from glob import glob
from pathlib import Path
import draccus
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_global_random_state
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode
PRETRAINED_MODEL = "pretrained_model"
TRAINING_STATE = "training_state.pth"
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
f"dataset:{cfg.dataset.repo_id}",
f"seed:{cfg.seed}",
]
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
return lst if return_list else "-".join(lst)
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
# Get the WandB run ID.
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
if len(paths) != 1:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
if match is None:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
wandb_run_id = match.groups(0)[0]
return wandb_run_id
class Logger:
"""Primary logger object. Logs either locally or using wandb.
The logger creates the following directory structure:
provided_log_dir
├── checkpoints
│   ├── specific_checkpoint_name
│   │   ├── pretrained_model  # Hugging Face pretrained model directory
│   │   │   └── ...
│   │   └── training_state.pth  # optimizer, scheduler, and random states + training step
│   ├── another_specific_checkpoint_name
│   │   └── ...
│   ├── ...
│   └── last  # a softlink to the last logged checkpoint
"""
pretrained_model_dir_name = PRETRAINED_MODEL
training_state_file_name = TRAINING_STATE
def __init__(self, cfg: TrainPipelineConfig):
self._cfg = cfg
self.log_dir = cfg.output_dir
self.log_dir.mkdir(parents=True, exist_ok=True)
self.job_name = cfg.job_name
self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir)
# Set up WandB.
self._group = cfg_to_group(cfg)
run_offline = not cfg.wandb.enable or not cfg.wandb.project
if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb_run_id = None
if cfg.resume:
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
wandb.init(
id=wandb_run_id,
project=cfg.wandb.project,
entity=cfg.wandb.entity,
name=self.job_name,
notes=cfg.wandb.notes,
tags=cfg_to_group(cfg, return_list=True),
dir=self.log_dir,
config=asdict(self._cfg),
# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
resume="must" if cfg.resume else None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@classmethod
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which checkpoints will be saved."""
return Path(log_dir) / "checkpoints"
@classmethod
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
return cls.get_checkpoints_dir(log_dir) / "last"
@classmethod
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
"""
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
be saved.
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: PreTrainedPolicy, wandb_artifact_name: str | None = None):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
Optionally also upload the model to WandB.
"""
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
register_features_types()
policy.save_pretrained(save_dir)
# Also save the full config for the env configuration.
self._cfg.save_pretrained(save_dir)
if self._wandb and not self._cfg.wandb.disable_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
if self.last_checkpoint_dir.exists():
os.remove(self.last_checkpoint_dir)
def save_training_state(
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
training_state = {}
training_state["step"] = train_step
training_state.update(get_global_random_state())
if optimizer is not None:
training_state["optimizer"] = optimizer.state_dict()
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
def save_checkpoint(
self,
train_step: int,
identifier: str,
policy: PreTrainedPolicy,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
wandb_artifact_name = (
None
if self._wandb is None
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
relative_target = checkpoint_dir.relative_to(self.last_checkpoint_dir.parent)
self.last_checkpoint_dir.symlink_to(relative_target)
def log_dict(self, d: dict, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
assert self._wandb is not None
wandb_video = self._wandb.Video(video_path, fps=self._cfg.env.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
def register_features_types():
draccus.decode.register(FeatureType, lambda x: FeatureType[x])
draccus.encode.register(FeatureType, lambda x: x.name)
draccus.decode.register(NormalizationMode, lambda x: NormalizationMode[x])
draccus.encode.register(NormalizationMode, lambda x: x.name)

View File

@ -14,15 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.logger import TRAINING_STATE
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
from lerobot.configs.train import TrainPipelineConfig
@ -40,22 +36,5 @@ def make_optimizer_and_scheduler(
"""
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
optimizer = cfg.optimizer.build(params)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.offline.steps) if cfg.scheduler is not None else None
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
return optimizer, lr_scheduler
def load_training_state(checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the checkpoint directory, load the optimizer state, scheduler state, and random state, and
return the global training step.
"""
# TODO(aliberts): use safetensors instead as weights_only=False is unsafe
training_state = torch.load(checkpoint_dir / TRAINING_STATE, weights_only=False)
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
raise ValueError("The checkpoint contains a scheduler state_dict, but no LRScheduler was provided.")
# Small HACK to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"], optimizer, scheduler

View File

@ -1,8 +1,32 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass
from pathlib import Path
import draccus
import torch
from safetensors.torch import load_file, save_file
from lerobot.common.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object
@dataclass
@ -68,3 +92,27 @@ class SGDConfig(OptimizerConfig):
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.SGD(params, **kwargs)
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
state = optimizer.state_dict()
param_groups = state.pop("param_groups")
flat_state = flatten_dict(state)
save_file(flat_state, save_dir / OPTIMIZER_STATE)
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
if "param_groups" in current_state_dict:
param_groups = deserialize_json_into_object(
save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
)
loaded_state_dict["param_groups"] = param_groups
optimizer.load_state_dict(loaded_state_dict)
return optimizer
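A round-trip sketch with these helpers, assuming a toy model and a hypothetical save directory; one optimization step is taken first so the optimizer actually has per-parameter state to serialize:
```python
from pathlib import Path

import torch
from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one step so the optimizer holds per-parameter state (exp_avg, exp_avg_sq, ...).
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

save_dir = Path("outputs/example_training_state")  # hypothetical directory
save_dir.mkdir(parents=True, exist_ok=True)
save_optimizer_state(optimizer, save_dir)  # writes optimizer_state.safetensors + optimizer_param_groups.json

# On resume: rebuild an identically configured optimizer, then load the saved state into it.
resumed = torch.optim.Adam(model.parameters(), lr=1e-3)
resumed = load_optimizer_state(resumed, save_dir)
```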

View File

@ -1,11 +1,31 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import math
from dataclasses import asdict, dataclass
from pathlib import Path
import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from lerobot.common.constants import SCHEDULER_STATE
from lerobot.common.datasets.utils import write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object
@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
@ -89,3 +109,14 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
return cosine_decay_schedule(current_step)
return LambdaLR(optimizer, lr_lambda, -1)
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
state_dict = scheduler.state_dict()
write_json(state_dict, save_dir / SCHEDULER_STATE)
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
scheduler.load_state_dict(state_dict)
return scheduler
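The scheduler counterpart follows the same pattern but stores the `state_dict` as JSON. A minimal sketch with a hypothetical save directory:
```python
from pathlib import Path

import torch
from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state

optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 0.95**step)

save_dir = Path("outputs/example_training_state")  # hypothetical directory
save_dir.mkdir(parents=True, exist_ok=True)
save_scheduler_state(scheduler, save_dir)  # writes scheduler_state.json

# On resume: rebuild an identically configured scheduler and load the saved state into it.
resumed = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 0.95**step)
resumed = load_scheduler_state(resumed, save_dir)
```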

View File

@ -144,7 +144,7 @@ class ACTPolicy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
@ -169,11 +169,11 @@ class ACTPolicy(PreTrainedPolicy):
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
loss = l1_loss + mean_kld * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
loss = l1_loss
return loss_dict
return loss, loss_dict
class ACTTemporalEnsembler:

View File

@ -143,7 +143,7 @@ class DiffusionPolicy(PreTrainedPolicy):
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
@ -153,7 +153,8 @@ class DiffusionPolicy(PreTrainedPolicy):
)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
# no output_dict so returning None
return loss, None
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:

View File

@ -163,12 +163,17 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
raise NotImplementedError
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
@abc.abstractmethod
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""_summary_
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
Args:
batch (dict[str, Tensor]): The input batch of observations (and targets) to compute the loss on.
Returns:
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
is a Tensor, all other items should be logging-friendly, native Python types.
"""
raise NotImplementedError

View File

@ -302,7 +302,7 @@ class TDMPCPolicy(PreTrainedPolicy):
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
@ -495,7 +495,6 @@ class TDMPCPolicy(PreTrainedPolicy):
"Q_value_loss": q_value_loss.item(),
"V_value_loss": v_value_loss.item(),
"pi_loss": pi_loss.item(),
"loss": loss,
"sum_loss": loss.item() * self.config.horizon,
}
)
@ -505,7 +504,7 @@ class TDMPCPolicy(PreTrainedPolicy):
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
return info
return loss, info
def update(self):
"""Update the target model's parameters with an EMA step."""

View File

@ -156,7 +156,7 @@ class VQBeTPolicy(PreTrainedPolicy):
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
@ -170,16 +170,16 @@ class VQBeTPolicy(PreTrainedPolicy):
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
)
return {
"loss": loss,
return loss, {
"n_different_codes": n_different_codes,
"n_different_combinations": n_different_combinations,
"recon_l1_error": recon_l1_error,
}
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
_, loss_dict = self.vqbet(batch, rollout=False)
loss = loss_dict.pop("loss")
return loss_dict
return loss, loss_dict
class SpatialSoftmax(nn.Module):
@ -342,7 +342,7 @@ class VQBeTModel(nn.Module):
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.images"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@ -482,7 +482,7 @@ class VQBeTHead(nn.Module):
param.requires_grad = False
return loss, n_different_codes, n_different_combinations, recon_l1_error
def forward(self, x, **kwargs):
def forward(self, x, **kwargs) -> dict:
# N is the batch size, and T is the number of action query tokens, which are processed through the same GPT
N, T, _ = x.shape
# We process the N and T dimensions in parallel. Thus, the dimensions would be

View File

@ -13,10 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import warnings
from pathlib import Path
from typing import TypeVar
import imageio
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
T = TypeVar("T", bound=JsonLike)
def write_video(video_path, stacked_frames, fps):
# Filter out DeprecationWarnings raised from pkg_resources
@ -25,3 +31,81 @@ def write_video(video_path, stacked_frames, fps):
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
)
imageio.mimsave(video_path, stacked_frames, fps=fps)
def deserialize_json_into_object(fpath: Path, obj: T) -> T:
"""
Loads the JSON data from `fpath` and recursively fills `obj` with the
corresponding values (strictly matching structure and types).
Tuples in `obj` are expected to be lists in the JSON data, which will be
converted back into tuples.
"""
with open(fpath, encoding="utf-8") as f:
data = json.load(f)
def _deserialize(target, source):
"""
Recursively overwrite the structure in `target` with data from `source`,
performing strict checks on structure and type.
Returns the updated version of `target` (especially important for tuples).
"""
# If the target is a dictionary, source must be a dictionary as well.
if isinstance(target, dict):
if not isinstance(source, dict):
raise TypeError(f"Type mismatch: expected dict, got {type(source)}")
# Check that they have exactly the same set of keys.
if target.keys() != source.keys():
raise ValueError(
f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
)
# Recursively update each key.
for k in target:
target[k] = _deserialize(target[k], source[k])
return target
# If the target is a list, source must be a list as well.
elif isinstance(target, list):
if not isinstance(source, list):
raise TypeError(f"Type mismatch: expected list, got {type(source)}")
# Check length
if len(target) != len(source):
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
# Recursively update each element.
for i in range(len(target)):
target[i] = _deserialize(target[i], source[i])
return target
# If the target is a tuple, the source must be a list in JSON,
# which we'll convert back to a tuple.
elif isinstance(target, tuple):
if not isinstance(source, list):
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
if len(target) != len(source):
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
# Convert each element, forming a new tuple.
converted_items = []
for t_item, s_item in zip(target, source, strict=False):
converted_items.append(_deserialize(t_item, s_item))
# Return a brand new tuple (tuples are immutable in Python).
return tuple(converted_items)
# Otherwise, we're dealing with a "primitive" (int, float, str, bool, None).
else:
# Check the exact type. If these must match 1:1, do:
if type(target) is not type(source):
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
return source
# Perform the in-place/recursive deserialization
updated_obj = _deserialize(obj, data)
return updated_obj
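A small round-trip sketch of this strict deserialization (the path and values are made up; `write_json` is the JSON helper used elsewhere in this diff). The template object fixes the expected structure and types, and tuples come back as tuples:
```python
from pathlib import Path

from lerobot.common.datasets.utils import write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object

fpath = Path("outputs/example_state.json")  # hypothetical path
write_json({"step": 200, "lrs": [2e-4, 2e-5]}, fpath)

template = {"step": 0, "lrs": (1e-4, 1e-5)}  # tuple in the template -> tuple after loading
restored = deserialize_json_into_object(fpath, template)
assert restored == {"step": 200, "lrs": (2e-4, 2e-5)}
```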

View File

@ -0,0 +1,163 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.common.utils.utils import format_big_number
class AverageMeter:
"""
Computes and stores the average and current value
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
"""
def __init__(self, name: str, fmt: str = ":f"):
self.name = name
self.fmt = fmt
self.reset()
def reset(self) -> None:
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0.0
def update(self, val: float, n: int = 1) -> None:
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = "{name}:{avg" + self.fmt + "}"
return fmtstr.format(**self.__dict__)
class MetricsTracker:
"""
A helper class to track and log metrics over time.
Usage pattern:
```python
# initialize, potentially with non-zero initial step (e.g. if resuming run)
metrics = {"loss": AverageMeter("loss", ":.3f")}
train_metrics = MetricsTracker(cfg.batch_size, dataset.num_frames, dataset.num_episodes, metrics, initial_step=step)
# update metrics derived from step (samples, episodes, epochs) at each training step
train_metrics.step()
# update various metrics
loss, _ = policy.forward(batch)
train_metrics.loss = loss.item()
# display current metrics
logging.info(train_metrics)
# export for wandb
wandb.log(train_metrics.to_dict())
# reset averages after logging
train_metrics.reset_averages()
```
"""
__keys__ = [
"_batch_size",
"_num_frames",
"_avg_samples_per_ep",
"metrics",
"steps",
"samples",
"episodes",
"epochs",
]
def __init__(
self,
batch_size: int,
num_frames: int,
num_episodes: int,
metrics: dict[str, AverageMeter],
initial_step: int = 0,
):
self.__dict__.update({k: None for k in self.__keys__})
self._batch_size = batch_size
self._num_frames = num_frames
self._avg_samples_per_ep = num_frames / num_episodes
self.metrics = metrics
self.steps = initial_step
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
self.samples = self.steps * self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__:
return self.__dict__[name]
elif name in self.metrics:
return self.metrics[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __setattr__(self, name: str, value: Any) -> None:
if name in self.__dict__:
super().__setattr__(name, value)
elif name in self.metrics:
self.metrics[name].update(value)
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def step(self) -> None:
"""
Updates metrics that depend on 'step' for one step.
"""
self.steps += 1
self.samples += self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def __str__(self) -> str:
display_list = [
f"step:{format_big_number(self.steps)}",
# number of samples seen during training
f"smpl:{format_big_number(self.samples)}",
# number of episodes seen during training
f"ep:{format_big_number(self.episodes)}",
# number of time all unique samples are seen
f"epch:{self.epochs:.2f}",
*[str(m) for m in self.metrics.values()],
]
return " ".join(display_list)
def to_dict(self, use_avg: bool = True) -> dict[str, int | float]:
"""
Returns the current metric values (or averages if `use_avg=True`) as a dict.
"""
return {
"steps": self.steps,
"samples": self.samples,
"episodes": self.episodes,
"epochs": self.epochs,
**{k: m.avg if use_avg else m.val for k, m in self.metrics.items()},
}
def reset_averages(self) -> None:
"""Resets average meters."""
for m in self.metrics.values():
m.reset()
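A compact sketch of the tracker outside the training script (the numbers are made up; in `train.py` they come from the config and dataset):
```python
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker

metrics = {"loss": AverageMeter("loss", ":.3f"), "lr": AverageMeter("lr", ":0.1e")}
tracker = MetricsTracker(batch_size=8, num_frames=1000, num_episodes=50, metrics=metrics, initial_step=0)

for step in range(3):
    tracker.step()                   # advances steps, samples, episodes, epochs
    tracker.loss = 0.5 / (step + 1)  # routed to the AverageMeter via __setattr__
    tracker.lr = 1e-4

print(tracker)            # one-line summary: step/smpl/ep/epch plus the metric values
print(tracker.to_dict())  # averages, ready for wandb.log
tracker.reset_averages()  # clear the meters after logging
```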

View File

@ -0,0 +1,191 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Generator
import numpy as np
import torch
from safetensors.torch import load_file, save_file
from lerobot.common.constants import RNG_STATE
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using
`safetensors.save_file()` or `torch.save()`.
"""
py_state = random.getstate()
return {
"py_rng_version": torch.tensor([py_state[0]], dtype=torch.int64),
"py_rng_state": torch.tensor(py_state[1], dtype=torch.int64),
}
def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
"""
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
random.setstate(py_state)
def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using
`safetensors.save_file()` or `torch.save()`.
"""
np_state = np.random.get_state()
# Ensure no breaking changes from numpy
assert np_state[0] == "MT19937"
return {
"np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
"np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
"np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
"np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
}
def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `numpy` from a dictionary produced by `serialize_numpy_rng_state()`.
"""
np_state = (
"MT19937",
rng_state_dict["np_rng_state_values"].numpy(),
rng_state_dict["np_rng_state_index"].item(),
rng_state_dict["np_rng_has_gauss"].item(),
rng_state_dict["np_rng_cached_gaussian"].item(),
)
np.random.set_state(np_state)
def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `torch` in the form of a flat dict[str, torch.Tensor] to be saved using
`safetensors.save_file()` or `torch.save()`.
"""
torch_rng_state_dict = {"torch_rng_state": torch.get_rng_state()}
if torch.cuda.is_available():
torch_rng_state_dict["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
return torch_rng_state_dict
def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `torch` from a dictionary produced by `serialize_torch_rng_state()`.
"""
torch.set_rng_state(rng_state_dict["torch_rng_state"])
if torch.cuda.is_available() and "torch_cuda_rng_state" in rng_state_dict:
torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"])
def serialize_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `random`, `numpy`, and `torch`, in the form of a flat
dict[str, torch.Tensor] to be saved using `safetensors.save_file()` or `torch.save()`.
"""
py_rng_state_dict = serialize_python_rng_state()
np_rng_state_dict = serialize_numpy_rng_state()
torch_rng_state_dict = serialize_torch_rng_state()
return {
**py_rng_state_dict,
**np_rng_state_dict,
**torch_rng_state_dict,
}
def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `random`, `numpy`, and `torch` from a dictionary produced by
`serialize_rng_state()`.
"""
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
deserialize_python_rng_state(py_rng_state_dict)
deserialize_numpy_rng_state(np_rng_state_dict)
deserialize_torch_rng_state(torch_rng_state_dict)
def save_rng_state(save_dir: Path) -> None:
rng_state_dict = serialize_rng_state()
flat_rng_state_dict = flatten_dict(rng_state_dict)
save_file(flat_rng_state_dict, save_dir / RNG_STATE)
def load_rng_state(save_dir: Path) -> None:
flat_rng_state_dict = load_file(save_dir / RNG_STATE)
rng_state_dict = unflatten_dict(flat_rng_state_dict)
deserialize_rng_state(rng_state_dict)
def get_rng_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
random_state_dict = {
"random_state": random.getstate(),
"numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
return random_state_dict
def set_rng_state(random_state_dict: dict[str, Any]):
"""Set the random state for `random`, `numpy`, and `torch`.
Args:
random_state_dict: A dictionary of the form returned by `get_rng_state`.
"""
random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"])
if torch.cuda.is_available():
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_seed(seed) -> None:
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
"""Set the seed when entering a context, and restore the prior random state at exit.
Example usage:
```
a = random.random() # produces some random number
with seeded_context(1337):
b = random.random() # produces some other random number
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
```
"""
random_state_dict = get_rng_state()
set_seed(seed)
yield None
set_rng_state(random_state_dict)
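A minimal reproducibility sketch using the new save/load helpers (the directory is hypothetical): restoring the saved rng state reproduces the same draws.
```python
from pathlib import Path

import torch
from lerobot.common.utils.random_utils import load_rng_state, save_rng_state, set_seed

save_dir = Path("outputs/example_training_state")  # hypothetical directory
save_dir.mkdir(parents=True, exist_ok=True)

set_seed(1337)
save_rng_state(save_dir)  # writes rng_state.safetensors (python, numpy, and torch states)
a = torch.rand(3)

load_rng_state(save_dir)  # restore the generator states saved above
b = torch.rand(3)
assert torch.equal(a, b)  # identical draws after restoring the rng state
```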

View File

@ -0,0 +1,161 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.common.datasets.utils import load_json, write_json
from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state
from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.random_utils import load_rng_state, save_rng_state
from lerobot.configs.train import TrainPipelineConfig
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def get_step_identifier(step: int, total_steps: int) -> str:
num_digits = max(6, len(str(total_steps)))
return f"{step:0{num_digits}d}"
def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
"""Returns the checkpoint sub-directory corresponding to the step number."""
step_identifier = get_step_identifier(step, total_steps)
return output_dir / CHECKPOINTS_DIR / step_identifier
def save_training_step(step: int, save_dir: Path) -> None:
write_json({"step": step}, save_dir / TRAINING_STEP)
def load_training_step(save_dir: Path) -> int:
training_step = load_json(save_dir / TRAINING_STEP)
return training_step["step"]
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
if last_checkpoint_dir.is_symlink():
last_checkpoint_dir.unlink()
relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
last_checkpoint_dir.symlink_to(relative_target)
def save_checkpoint(
checkpoint_dir: Path,
step: int,
cfg: TrainPipelineConfig,
policy: PreTrainedPolicy,
optimizer: Optimizer,
scheduler: LRScheduler | None = None,
) -> None:
"""This function creates the following directory structure:
005000/ # training step at checkpoint
pretrained_model/
config.json # policy config
model.safetensors # policy weights
train_config.json # train config
training_state/
optimizer_param_groups.json # optimizer param groups
optimizer_state.safetensors # optimizer state
rng_state.safetensors # rng states
scheduler_state.json # scheduler state
training_step.json # training step
Args:
checkpoint_dir (Path): The checkpoint directory in which to save the pretrained model and training state.
cfg (TrainPipelineConfig): The training config used for this run.
step (int): The training step at that checkpoint.
policy (PreTrainedPolicy): The policy to save.
optimizer (Optimizer): The optimizer to save the state from.
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
cfg.save_pretrained(pretrained_dir)
save_training_state(checkpoint_dir, step, optimizer, scheduler)
def save_training_state(
checkpoint_dir: Path,
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
) -> None:
"""
Saves the training step, optimizer state, scheduler state, and rng state.
Args:
checkpoint_dir (Path): The checkpoint directory; the training state is written to its "training_state" subdirectory.
train_step (int): Current training step.
optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict.
Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
Defaults to None.
"""
save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir)
save_rng_state(save_dir)
if optimizer is not None:
save_optimizer_state(optimizer, save_dir)
if scheduler is not None:
save_scheduler_state(scheduler, save_dir)
def load_training_state(
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]:
"""
Loads the training step, optimizer state, scheduler state, and rng state.
This is used to resume a training run.
Args:
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
optimizer (Optimizer): The optimizer to load the state_dict to.
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
Raises:
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
Returns:
tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their
state_dict loaded.
"""
training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
if not training_state_dir.is_dir():
raise NotADirectoryError(training_state_dir)
load_rng_state(training_state_dir)
step = load_training_step(training_state_dir)
optimizer = load_optimizer_state(optimizer, training_state_dir)
if scheduler is not None:
scheduler = load_scheduler_state(scheduler, training_state_dir)
return step, optimizer, scheduler
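Putting the checkpoint helpers together, a resumption sketch; `cfg`, `policy`, `optimizer`, and `scheduler` are assumed to exist already (e.g. from `make_policy` and `make_optimizer_and_scheduler`), so this is an outline rather than a runnable script:
```python
from lerobot.common.utils.train_utils import (
    get_step_checkpoint_dir,
    load_training_state,
    save_checkpoint,
    update_last_checkpoint,
)

# Save a checkpoint at step 5000 of a cfg.steps-long run.
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, 5000)
save_checkpoint(checkpoint_dir, 5000, cfg, policy, optimizer, scheduler)
update_last_checkpoint(checkpoint_dir)  # repoint the "last" symlink to this checkpoint

# On resume: load the training step, optimizer, scheduler, and rng state back.
step, optimizer, scheduler = load_training_state(checkpoint_dir, optimizer, scheduler)
```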

View File

@ -17,14 +17,10 @@ import logging
import os
import os.path as osp
import platform
import random
from contextlib import contextmanager
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Generator
import numpy as np
import torch
@ -106,59 +102,6 @@ def is_amp_available(device: str):
raise ValueError(f"Unknown device '{device}.")
def get_global_random_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
random_state_dict = {
"random_state": random.getstate(),
"numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
return random_state_dict
def set_global_random_state(random_state_dict: dict[str, Any]):
"""Set the random state for `random`, `numpy`, and `torch`.
Args:
random_state_dict: A dictionary of the form returned by `get_global_random_state`.
"""
random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"])
if torch.cuda.is_available():
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_global_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
"""Set the seed when entering a context, and restore the prior random state at exit.
Example usage:
```
a = random.random() # produces some random number
with seeded_context(1337):
b = random.random() # produces some other random number
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
```
"""
random_state_dict = get_global_random_state()
set_global_seed(seed)
yield None
set_global_random_state(random_state_dict)
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View File

@ -0,0 +1,121 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
from glob import glob
from pathlib import Path
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from termcolor import colored
from lerobot.common.constants import PRETRAINED_MODEL_DIR
from lerobot.configs.train import TrainPipelineConfig
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
f"dataset:{cfg.dataset.repo_id}",
f"seed:{cfg.seed}",
]
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
return lst if return_list else "-".join(lst)
def get_wandb_run_id_from_filesystem(log_dir: Path) -> str:
# Get the WandB run ID.
paths = glob(str(log_dir / "wandb/latest-run/run-*"))
if len(paths) != 1:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
if match is None:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
wandb_run_id = match.groups(0)[0]
return wandb_run_id
def get_safe_wandb_artifact_name(name: str):
"""WandB artifacts don't accept ":" or "/" in their name."""
return name.replace(":", "_").replace("/", "_")
class WandBLogger:
"""A helper class to log object using wandb."""
def __init__(self, cfg: TrainPipelineConfig):
self.cfg = cfg.wandb
self.log_dir = cfg.output_dir
self.job_name = cfg.job_name
self.env_fps = cfg.env.fps if cfg.env else None
self._group = cfg_to_group(cfg)
# Set up WandB.
os.environ["WANDB_SILENT"] = "True"
import wandb
wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
wandb.init(
id=wandb_run_id,
project=self.cfg.project,
entity=self.cfg.entity,
name=self.job_name,
notes=self.cfg.notes,
tags=cfg_to_group(cfg, return_list=True),
dir=self.log_dir,
config=cfg.to_dict(),
# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
resume="must" if cfg.resume else None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
def log_policy(self, checkpoint_dir: Path):
"""Checkpoints the policy to wandb."""
if self.cfg.disable_artifact:
return
step_id = checkpoint_dir.name
artifact_name = f"{self._group}-{step_id}"
artifact_name = get_safe_wandb_artifact_name(artifact_name)
artifact = self._wandb.Artifact(artifact_name, type="model")
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
raise ValueError(mode)
for k, v in d.items():
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
raise ValueError(mode)
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)

View File

@ -21,68 +21,6 @@ from lerobot.configs.policies import PreTrainedConfig
TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class OfflineConfig:
steps: int = 100_000
@dataclass
class OnlineConfig:
"""
The online training loop looks something like:
```python
for i in range(steps):
do_online_rollout_and_update_online_buffer()
for j in range(steps_between_rollouts):
batch = next(dataloader_with_offline_and_online_data)
loss = policy(batch)
loss.backward()
optimizer.step()
```
Note that the online training loop adopts most of the options from the offline loop unless specified
otherwise.
"""
steps: int = 0
# How many episodes to collect at once when we reach the online rollout part of the training loop.
rollout_n_episodes: int = 1
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
# the policy. Ideally you should set this to by an even divisor of rollout_n_episodes.
rollout_batch_size: int = 1
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
steps_between_rollouts: int | None = None
# The proportion of online samples (vs offline samples) to include in the online training batches.
sampling_ratio: float = 0.5
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
env_seed: int | None = None
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
# FIFO.
buffer_capacity: int | None = None
# The minimum number of frames to have in the online buffer before commencing online training.
# If buffer_seed_size > rollout_n_episodes, the rollout will be run multiple times until the
# seed size condition is satisfied.
buffer_seed_size: int = 0
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
# + eval + environment rendering simultaneously.
do_rollout_async: bool = False
def __post_init__(self):
if self.steps == 0:
return
if self.steps_between_rollouts is None:
raise ValueError(
"'steps_between_rollouts' must be set to a positive integer, but it is currently None."
)
if self.env_seed is None:
raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
if self.buffer_capacity is None:
raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
@ -107,13 +45,12 @@ class TrainPipelineConfig(HubMixin):
# Number of workers for the dataloader.
num_workers: int = 4
batch_size: int = 8
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: int = 20_000
offline: OfflineConfig = field(default_factory=OfflineConfig)
online: OnlineConfig = field(default_factory=OnlineConfig)
use_policy_training_preset: bool = True
optimizer: OptimizerConfig | None = None
scheduler: LRSchedulerConfig | None = None
@ -168,11 +105,8 @@ class TrainPipelineConfig(HubMixin):
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if self.online.steps > 0:
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
if self.env is None:
raise ValueError("An environment is required for online training")
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
@ -185,6 +119,9 @@ class TrainPipelineConfig(HubMixin):
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def to_dict(self) -> dict:
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)

View File

@ -61,21 +61,21 @@ import einops
import gymnasium as gym
import numpy as np
import torch
from termcolor import colored
from torch import Tensor, nn
from tqdm import trange
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
get_safe_torch_device,
init_logging,
inside_slurm,
set_global_seed,
)
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
@ -125,9 +125,6 @@ def rollout(
# Reset the policy and environments.
policy.reset()
if hasattr(policy, "use_ema_modules"):
policy.use_ema_modules()
observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)
@ -463,9 +460,9 @@ def eval(cfg: EvalPipelineConfig):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
set_seed(cfg.seed)
log_output_dir(cfg.output_dir)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info("Making environment.")
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

View File

@ -15,58 +15,61 @@
# limitations under the License.
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from pprint import pformat
from threading import Lock
from typing import Any
import numpy as np
import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.optim.factory import load_training_state, make_optimizer_and_scheduler
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.utils.utils import (
format_big_number,
get_safe_dtype,
get_safe_torch_device,
has_method,
init_logging,
set_global_seed,
)
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
def update_policy(
policy,
batch,
optimizer,
grad_clip_norm,
train_metrics: MetricsTracker,
policy: PreTrainedPolicy,
batch: Any,
optimizer: Optimizer,
grad_clip_norm: float,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
lock=None,
):
"""Returns a dictionary of items for logging."""
) -> tuple[MetricsTracker, dict]:
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
output_dict = policy.forward(batch)
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
@ -87,9 +90,6 @@ def update_policy(
optimizer.zero_grad()
if hasattr(policy, "update_ema_modules"):
policy.update_ema_modules()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
@ -98,113 +98,34 @@ def update_policy(
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.perf_counter() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
}
info.update({k: v for k, v in output_dict.items() if k not in info})
return info
def log_train_info(
logger: Logger, info: dict, step: int, cfg: TrainPipelineConfig, dataset: LeRobotDataset, is_online: bool
):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
update_s = info["update_s"]
dataloading_s = info["dataloading_s"]
# A sample is an (observation, action) pair, where the observation and action
# can span multiple timestamps. In a batch, we have `batch_size` samples.
num_samples = (step + 1) * cfg.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_big_number(num_episodes)}",
# number of times all unique samples are seen
f"epch:{num_epochs:.2f}",
f"loss:{loss:.3f}",
f"grdn:{grad_norm:.3f}",
f"lr:{lr:0.1e}",
# in seconds
f"updt_s:{update_s:.3f}",
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
]
logging.info(" ".join(log_items))
info["step"] = step
info["num_samples"] = num_samples
info["num_episodes"] = num_episodes
info["num_epochs"] = num_epochs
info["is_online"] = is_online
logger.log_dict(info, step, mode="train")
def log_eval_info(logger, info, step, cfg, dataset, is_online):
eval_s = info["eval_s"]
avg_sum_reward = info["avg_sum_reward"]
pc_success = info["pc_success"]
# A sample is an (observation, action) pair, where the observation and action
# can span multiple timestamps. In a batch, we have `batch_size` samples.
num_samples = (step + 1) * cfg.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_big_number(num_episodes)}",
# number of times all unique samples are seen
f"epch:{num_epochs:.2f}",
f"∑rwrd:{avg_sum_reward:.3f}",
f"success:{pc_success:.1f}%",
f"eval_s:{eval_s:.3f}",
]
logging.info(" ".join(log_items))
info["step"] = step
info["num_samples"] = num_samples
info["num_episodes"] = num_episodes
info["num_epochs"] = num_epochs
info["is_online"] = is_online
logger.log_dict(info, step, mode="eval")
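The sample/episode/epoch bookkeeping computed by the two logging helpers removed just above (and which the new MetricsTracker, constructed with `batch_size`, `num_frames` and `num_episodes`, presumably still derives) is simple arithmetic; a quick worked example with illustrative numbers only:

```python
# Illustrative numbers: a dataset of 50_000 frames across 50 episodes,
# trained with batch_size=8, right after the 1000th optimizer step (step index 999).
batch_size = 8
num_frames = 50_000
num_episodes = 50
step = 999

num_samples = (step + 1) * batch_size             # 8_000 (observation, action) pairs seen
avg_samples_per_ep = num_frames / num_episodes    # 1_000 frames per episode on average
episodes_seen = num_samples / avg_samples_per_ep  # ~8 episodes' worth of samples
epochs_seen = num_samples / num_frames            # 0.16 passes over the unique samples
```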
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict
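For readers following the rewritten `update_policy`, the AMP-scaled update it performs is the standard GradScaler pattern, with the policy's forward now returning `(loss, output_dict)`. A minimal self-contained sketch with a toy torch model standing in for the policy (no lerobot imports, illustrative only):

```python
import torch
from torch.amp import GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(10, 1).to(device)  # toy stand-in for the policy
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
grad_scaler = GradScaler(device.type, enabled=device.type == "cuda")
grad_clip_norm = 10.0

x = torch.randn(32, 10, device=device)
y = torch.randn(32, 1, device=device)

with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
    # Stand-in for `loss, output_dict = policy.forward(batch)`.
    loss = torch.nn.functional.mse_loss(model(x), y)

grad_scaler.scale(loss).backward()
# Unscale before clipping so the clip threshold applies to the true gradients.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm, error_if_nonfinite=False)
grad_scaler.step(optimizer)  # skipped internally if gradients contain inf/NaN
grad_scaler.update()
optimizer.zero_grad()
```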
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
logging.info(pformat(cfg.to_dict()))
logging.info(pformat(asdict(cfg)))
# log metrics to terminal and wandb
logger = Logger(cfg)
if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
set_global_seed(cfg.seed)
set_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("Creating dataset")
offline_dataset = make_dataset(cfg)
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@ -218,7 +139,7 @@ def train(cfg: TrainPipelineConfig):
policy = make_policy(
cfg=cfg.policy,
device=device,
ds_meta=offline_dataset.meta,
ds_meta=dataset.meta,
)
logging.info("Creating optimizer and scheduler")
@ -233,65 +154,29 @@ def train(cfg: TrainPipelineConfig):
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(cfg.output_dir)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline.steps=} ({format_big_number(cfg.offline.steps)})")
logging.info(f"{cfg.online.steps=}")
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Note: this helper will be used in offline and online training loops.
def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
_num_digits = max(6, len(str(cfg.offline.steps + cfg.online.steps)))
step_identifier = f"{step:0{_num_digits}d}"
if cfg.env is not None and cfg.eval_freq > 0 and step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_identifier}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
if cfg.save_checkpoint and (
step % cfg.save_freq == 0 or step == cfg.offline.steps + cfg.online.steps
):
logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
logger.save_checkpoint(
step,
step_identifier,
policy,
optimizer,
lr_scheduler,
)
logging.info("Resume training")
# create dataloader for offline training
if getattr(cfg.policy, "drop_n_last_frames", None):
if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
offline_dataset.episode_data_index,
dataset.episode_data_index,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
dataloader = torch.utils.data.DataLoader(
offline_dataset,
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
@ -303,23 +188,30 @@ def train(cfg: TrainPipelineConfig):
policy.train()
if hasattr(policy, "init_ema_modules"):
policy.init_ema_modules()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"),
"lr": AverageMeter("lr", ":0.1e"),
"update_s": AverageMeter("updt_s", ":.3f"),
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
train_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
)
offline_step = 0
for _ in range(step, cfg.offline.steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
train_tracker.dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
train_info = update_policy(
train_tracker, output_dict = update_policy(
train_tracker,
policy,
batch,
optimizer,
@ -329,231 +221,57 @@ def train(cfg: TrainPipelineConfig):
use_amp=cfg.use_amp,
)
train_info["dataloading_s"] = dataloading_s
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1, is_online=False)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
offline_step += 1 # noqa: SIM113
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if cfg.online.steps == 0:
if eval_env:
eval_env.close()
logging.info("End of training")
return
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
wandb_logger.log_dict(wandb_log_dict)
train_tracker.reset_averages()
# Online training.
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
# Create an env dedicated to online episodes collection from policy rollout.
online_env = make_env(cfg.env, n_envs=cfg.online.rollout_batch_size)
delta_timestamps = resolve_delta_timestamps(cfg.policy, offline_dataset.meta)
online_buffer_path = logger.log_dir / "online_buffer"
if cfg.resume and not online_buffer_path.exists():
# If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
# buffer.
logging.warning(
"When online training is resumed, we load the latest online buffer from the prior run, "
"and this might not coincide with the state of the buffer as it was at the moment the checkpoint "
"was made. This is because the online buffer is updated on disk during training, independently "
"of our explicit checkpointing mechanisms."
)
online_dataset = OnlineBuffer(
online_buffer_path,
data_spec={
**{
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
for key, ft in policy.config.input_features.items()
},
**{
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
for key, ft in policy.config.output_features.items()
},
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
"next.done": {"shape": (), "dtype": np.dtype("?")},
"task_index": {"shape": (), "dtype": np.dtype("int64")},
# FIXME: 'task' is a string
# "task": {"shape": (), "dtype": np.dtype("?")},
# FIXME: 'next.success' is expected by pusht env but not xarm
"next.success": {"shape": (), "dtype": np.dtype("?")},
},
buffer_capacity=cfg.online.buffer_capacity,
fps=online_env.unwrapped.metadata["render_fps"],
delta_timestamps=delta_timestamps,
)
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
# makes it possible to do online rollouts in parallel with training updates).
online_rollout_policy = deepcopy(policy) if cfg.online.do_rollout_async else policy
# Create dataloader for online training.
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
sampler_weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.online.sampling_ratio,
)
sampler = torch.utils.data.WeightedRandomSampler(
sampler_weights,
num_samples=len(concat_dataset),
replacement=True,
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=True,
)
dl_iter = cycle(dataloader)
if cfg.online.do_rollout_async:
# Lock and thread pool executor for asynchronous online rollouts.
lock = Lock()
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
# parallelization of rollouts is handled within the job.
executor = ThreadPoolExecutor(max_workers=1)
else:
lock = None
online_step = 0
online_rollout_s = 0  # time taken to do online rollout
update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
# online rollout option.
await_update_online_buffer_s = 0
rollout_start_seed = cfg.online.env_seed
while True:
if online_step == cfg.online.steps:
break
if online_step == 0:
logging.info("Start online training by interacting with environment")
def sample_trajectory_and_update_buffer():
nonlocal rollout_start_seed
with lock if lock is not None else nullcontext():
online_rollout_policy.load_state_dict(policy.state_dict())
online_rollout_policy.eval()
start_rollout_time = time.perf_counter()
with torch.no_grad():
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
online_env,
online_rollout_policy,
n_episodes=cfg.online.rollout_n_episodes,
max_episodes_rendered=min(10, cfg.online.rollout_n_episodes),
videos_dir=logger.log_dir / "online_rollout_videos",
return_episode_data=True,
start_seed=(rollout_start_seed := (rollout_start_seed + cfg.batch_size) % 1000000),
)
online_rollout_s = time.perf_counter() - start_rollout_time
if len(offline_dataset.meta.tasks) > 1:
raise NotImplementedError("Add support for multi task.")
# TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
total_num_frames = eval_info["episodes"]["index"].shape[0]
eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames
with lock if lock is not None else nullcontext():
start_update_buffer_time = time.perf_counter()
online_dataset.add_data(eval_info["episodes"])
# Update the concatenated dataset length used during sampling.
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# Update the sampling weights.
sampler.weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.online.sampling_ratio,
)
sampler.num_frames = len(concat_dataset)
update_online_buffer_s = time.perf_counter() - start_update_buffer_time
return online_rollout_s, update_online_buffer_s
if lock is None:
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
else:
future = executor.submit(sample_trajectory_and_update_buffer)
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
if len(online_dataset) <= cfg.online.buffer_seed_size:
online_rollout_s, update_online_buffer_s = future.result()
if len(online_dataset) <= cfg.online.buffer_seed_size:
logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
continue
policy.train()
for _ in range(cfg.online.steps_between_rollouts):
with lock if lock is not None else nullcontext():
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
dtype = get_safe_dtype(batch[key].dtype, device)
batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True)
train_info = update_policy(
eval_env,
policy,
batch,
optimizer,
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
lock=lock,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
train_info["dataloading_s"] = dataloading_s
train_info["online_rollout_s"] = online_rollout_s
train_info["update_online_buffer_s"] = update_online_buffer_s
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
with lock if lock is not None else nullcontext():
train_info["online_buffer_size"] = len(online_dataset)
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1, is_online=True)
step += 1
online_step += 1
# If we're doing async rollouts, we should now wait until we've completed them before proceeding
# to do the next batch of rollouts.
if cfg.online.do_rollout_async and future.running():
start = time.perf_counter()
online_rollout_s, update_online_buffer_s = future.result()
await_update_online_buffer_s = time.perf_counter() - start
if online_step >= cfg.online.steps:
break
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
)
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
logging.info(eval_tracker)
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
if eval_env:
eval_env.close()

poetry.lock (generated)
View File

@ -1257,12 +1257,14 @@ pyarrow = "*"
[[package]]
name = "draccus"
version = "0.9.3"
version = "0.10.0"
description = "A slightly opinionated framework for simple dataclass-based configurations based on Pyrallis."
optional = false
python-versions = ">=3.8"
files = []
develop = false
python-versions = ">=3.9"
files = [
{file = "draccus-0.10.0-py3-none-any.whl", hash = "sha256:90243418ae0e9271c390a59cafb6acfd37001193696ed36fcc8525f791a83282"},
{file = "draccus-0.10.0.tar.gz", hash = "sha256:8dd08304219becdcd66cd16058ba98e9c3e6b7bfe48ccb9579dae39f8d37ae19"},
]
[package.dependencies]
mergedeep = ">=1.3,<2.0"
@ -1274,12 +1276,6 @@ typing-inspect = ">=0.9.0,<0.10.0"
[package.extras]
dev = ["black", "mypy", "pre-commit", "pytest", "ruff"]
[package.source]
type = "git"
url = "https://github.com/dlwh/draccus.git"
reference = "HEAD"
resolved_reference = "9b690730ca108930519f48cc5dead72a72fd27cb"
[[package]]
name = "drawnow"
version = "0.72.5"
@ -4911,8 +4907,6 @@ files = [
{file = "PyAudio-0.2.14-cp311-cp311-win_amd64.whl", hash = "sha256:bbeb01d36a2f472ae5ee5e1451cacc42112986abe622f735bb870a5db77cf903"},
{file = "PyAudio-0.2.14-cp312-cp312-win32.whl", hash = "sha256:5fce4bcdd2e0e8c063d835dbe2860dac46437506af509353c7f8114d4bacbd5b"},
{file = "PyAudio-0.2.14-cp312-cp312-win_amd64.whl", hash = "sha256:12f2f1ba04e06ff95d80700a78967897a489c05e093e3bffa05a84ed9c0a7fa3"},
{file = "PyAudio-0.2.14-cp313-cp313-win32.whl", hash = "sha256:95328285b4dab57ea8c52a4a996cb52be6d629353315be5bfda403d15932a497"},
{file = "PyAudio-0.2.14-cp313-cp313-win_amd64.whl", hash = "sha256:692d8c1446f52ed2662120bcd9ddcb5aa2b71f38bda31e58b19fb4672fffba69"},
{file = "PyAudio-0.2.14-cp38-cp38-win32.whl", hash = "sha256:858caf35b05c26d8fc62f1efa2e8f53d5fa1a01164842bd622f70ddc41f55000"},
{file = "PyAudio-0.2.14-cp38-cp38-win_amd64.whl", hash = "sha256:2dac0d6d675fe7e181ba88f2de88d321059b69abd52e3f4934a8878e03a7a074"},
{file = "PyAudio-0.2.14-cp39-cp39-win32.whl", hash = "sha256:f745109634a7c19fa4d6b8b7d6967c3123d988c9ade0cd35d4295ee1acdb53e9"},
@ -6934,7 +6928,7 @@ test = ["pytest", "ruff"]
name = "tokenizers"
version = "0.21.0"
description = ""
optional = false
optional = true
python-versions = ">=3.7"
files = [
{file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"},
@ -7168,7 +7162,7 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,
name = "transformers"
version = "4.48.0"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false
optional = true
python-versions = ">=3.9.0"
files = [
{file = "transformers-4.48.0-py3-none-any.whl", hash = "sha256:6d3de6d71cb5f2a10f9775ccc17abce9620195caaf32ec96542bd2a6937f25b0"},
@ -7936,4 +7930,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "089df92a0455bb58d2f600ad645b99362c9ff2b800a0c6108edc09f51509c716"
content-hash = "3df78b6f373b1e4ee79530932f386fb4cfd302c1ceffe26617d26fb1a7e751ce"

View File

@ -69,8 +69,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true}
jsonlines = ">=4.0.0"
transformers = ">=4.48.0"
draccus = {git = "https://github.com/dlwh/draccus.git"} # replace with draccus = ">=0.9.4" when https://github.com/dlwh/draccus/pull/29 is in release
transformers = {version = ">=4.48.0", optional = true}
draccus = ">=0.10.0"
[tool.poetry.extras]

View File

@ -28,6 +28,7 @@ pytest_plugins = [
"tests.fixtures.dataset_factories",
"tests.fixtures.files",
"tests.fixtures.hub",
"tests.fixtures.optimizers",
]
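The new `tests.fixtures.optimizers` plugin suggests optimizer-related pytest fixtures now back the optimizer/scheduler tests. As an illustrative sketch only (the file name, fixture names, and contents below are assumptions, not the actual fixture module):

```python
# tests/fixtures/optimizers.py -- hypothetical sketch, not the real module
import pytest
import torch


@pytest.fixture
def dummy_model():
    # A tiny model so optimizer fixtures have real parameters to act on.
    return torch.nn.Linear(4, 2)


@pytest.fixture
def optimizer(dummy_model):
    # An Adam optimizer over the dummy model's parameters.
    return torch.optim.Adam(dummy_model.parameters(), lr=1e-3)
```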

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7841afb9ef99c0601448c43a20c25eb029440c73816319c67c5d7e1c5cde2445
size 136

View File

@ -1,11 +0,0 @@
{
"codebase_version": "v1.6",
"fps": 50,
"video": true,
"encoding": {
"vcodec": "libsvtav1",
"pix_fmt": "yuv420p",
"g": 2,
"crf": 30
}
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03508d82db846a804aef1a28aec3cb9572e3105b55a02b6ddbb09b2522d57b84
size 4344

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7009b3d2f14d6af497eeb32a52332e79cb9c07db24a6c2bbfbeffbaa8151dd69
size 592448

View File

@ -1,61 +0,0 @@
{
"citation": "",
"description": "",
"features": {
"observation.images.cam_high": {
"_type": "VideoFrame"
},
"observation.images.cam_left_wrist": {
"_type": "VideoFrame"
},
"observation.images.cam_right_wrist": {
"_type": "VideoFrame"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"observation.effort": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -1,13 +0,0 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "3e76021c95d21c4d",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": null
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70cc17019407cf6bee44fa2c78b4f29e48eb1696aa1a4ff4c048ba256574523
size 6356921

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2b35992036e6dcee7d4df6d1675d55d1dd2d658b2d65442737e709895699a2f0
size 5084448

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3aa92e6b6bd0e39f6de530ea6a270671db7350cdc101c9d9030c775539c708c1
size 5441406

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ee862b1a6dc1d11df77c36c47ea00db88ad35a48e4d71c2940ad26b55fe2167
size 136

View File

@ -1,11 +0,0 @@
{
"codebase_version": "v1.6",
"fps": 50,
"video": true,
"encoding": {
"vcodec": "libsvtav1",
"pix_fmt": "yuv420p",
"g": 2,
"crf": 30
}
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:095c30bfe3c5da168c85aceef905e74e2142866332282965aa6812f6e6e48448
size 4344

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:98859f2d87e1a0abb9a930a82af623504b3efb26f70fe576f05bab7f19024427
size 788528

View File

@ -1,61 +0,0 @@
{
"citation": "",
"description": "",
"features": {
"observation.images.cam_high": {
"_type": "VideoFrame"
},
"observation.images.cam_left_wrist": {
"_type": "VideoFrame"
},
"observation.images.cam_right_wrist": {
"_type": "VideoFrame"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"observation.effort": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -1,13 +0,0 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "872117944c4ecdff",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": null
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:596dda720d378a44b6b61a6a72b44bec3e55e85198bca37f9dace6fe84af7ff0
size 16062396

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c614bbaf93d65354a82001b357682a0bd36f9603685f6c735c5e377b763d0bdb
size 10317415

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:868788028a38334b6b566cb17ffcc2ace2ec2b2b68ff2a58b6d29eb3c3e2ec1f
size 9516445

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f365a02b052a2697b1558f4ab9b813f0d4ba46a5bc6ae3da30bbc4b135426aa6
size 136

View File

@ -1,11 +0,0 @@
{
"codebase_version": "v1.6",
"fps": 50,
"video": true,
"encoding": {
"vcodec": "libsvtav1",
"pix_fmt": "yuv420p",
"g": 2,
"crf": 30
}
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5c96f47b569b7af82e05200213d733626664150aa7c5ae3298fd04a2138a2023
size 4344

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:75f53d221827f17cc2ded3908452e24331b39b79dc3a26f2b9d89a6e6894baab
size 887728

View File

@ -1,61 +0,0 @@
{
"citation": "",
"description": "",
"features": {
"observation.images.cam_high": {
"_type": "VideoFrame"
},
"observation.images.cam_left_wrist": {
"_type": "VideoFrame"
},
"observation.images.cam_right_wrist": {
"_type": "VideoFrame"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"observation.effort": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -1,13 +0,0 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "ba1d9dc6ea5a9717",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": null
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73ddb898f83589b4bcabe978e46e75f20be215492f115bf6ebc98f1d01e1eff8
size 9696507

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d3d993977bee96882732d4a9c9d082c356fc9fcd8199c027b016207d60494c2f
size 8957007

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c9321627184c14af4a6ba64d02e86f7253bc1f563a3adef17036d68480d2bb3e
size 9938178

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:88346956fdf58f17dba7b08cc858364ed8278a7baa20febd9c68ae959d2c9c82
size 136

View File

@ -1,11 +0,0 @@
{
"codebase_version": "v1.6",
"fps": 50,
"video": true,
"encoding": {
"vcodec": "libsvtav1",
"pix_fmt": "yuv420p",
"g": 2,
"crf": 30
}
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de80d5afc044be903a89ee08f30cfef5fb4c1e928d8ba8f4d81ea9d0bb4fb011
size 4344

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:79c2a3da1024fa140d23e8438b2756d27cf5db65ac70d7ac4215260b55ca55f8
size 1477064

View File

@ -1,61 +0,0 @@
{
"citation": "",
"description": "",
"features": {
"observation.images.cam_high": {
"_type": "VideoFrame"
},
"observation.images.cam_left_wrist": {
"_type": "VideoFrame"
},
"observation.images.cam_right_wrist": {
"_type": "VideoFrame"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"observation.effort": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -1,13 +0,0 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "d95a3a7eae59566d",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": null
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3fc89b720dfb7511d5dd9eba31494cf720e6a89519067b7b5a4d65f0a539c811
size 35137505

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26b8d97a096aa8a1d686d86fc93bde1dcdd50a9dc273f49f3b6a700fe6610e88
size 20387806

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:72000be2803259f40da6d093279d17ed194ead3ebc508bf2d77cb463bcb67c4d
size 17594265

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb6de86fee6ff3cc5d61d591fe480a50feb289c05770e3f4b76e24138b571c65
size 136

View File

@ -1,11 +0,0 @@
{
"codebase_version": "v1.6",
"fps": 50,
"video": true,
"encoding": {
"vcodec": "libsvtav1",
"pix_fmt": "yuv420p",
"g": 2,
"crf": 30
}
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d79027c2513c01a7e360f3177e62ab955e5d3f704f1e7127a6e1e852158ec42c
size 4344

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a2c1f98c816728136291fcb7530cd0ebcf4ea47b0f6750836da56b8324d64c1
size 435600

View File

@ -1,61 +0,0 @@
{
"citation": "",
"description": "",
"features": {
"observation.images.cam_high": {
"_type": "VideoFrame"
},
"observation.images.cam_left_wrist": {
"_type": "VideoFrame"
},
"observation.images.cam_right_wrist": {
"_type": "VideoFrame"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"observation.effort": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -1,13 +0,0 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "09ba9b66b7f468bc",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": null
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e298db7d820e2ff9f0b9c250e800e8ada3521fdeae3c4127452dd62700e9ac8
size 10980189

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:29b46c2e823d62b1329b98a3d7efffd24fc6c904e9cea115e2f0adb1bb45db44
size 7229025

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f34ddbd109b212260c758d54a0930f75a38666a178a0d26eeefa846cfeac86c0
size 5944469

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1386f9030607facefe56f429c93e50df0e22017914ce3f21ab67edc87b936d9d
size 136

View File

@ -1,11 +0,0 @@
{
"codebase_version": "v1.6",
"fps": 50,
"video": true,
"encoding": {
"vcodec": "libsvtav1",
"pix_fmt": "yuv420p",
"g": 2,
"crf": 30
}
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7ffb173891cebb47a4d24d051f5fdd2ec44493d0a1a48d11f4d1410049aadd5b
size 4344

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ae1760af2d3bf13c6e868643f203e76e1faf81a237715f72f2b81c3199e95e96
size 514056

View File

@ -1,61 +0,0 @@
{
"citation": "",
"description": "",
"features": {
"observation.images.cam_high": {
"_type": "VideoFrame"
},
"observation.images.cam_left_wrist": {
"_type": "VideoFrame"
},
"observation.images.cam_right_wrist": {
"_type": "VideoFrame"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"observation.effort": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -1,13 +0,0 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "5a33376149b4e966",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": null
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1489dac711fb99b192f064f9dbe56ed0e9e80fedc34da469e85acc7d5b4d75bf
size 12316772

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:20edc20184b5e4eb45194016fe7a0a5673665e3105286e0c6563767b5ff461f3
size 6365474

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ccdc96d9fe560a841e45e9fa636b69ef35b518271982339516517a4ae47d04f
size 7449799

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9ee4f3c571ce6822e157e60133bee02245febee93eba5d35458d3c83345f7b87
size 136

View File

@ -1,11 +0,0 @@
{
"codebase_version": "v1.6",
"fps": 50,
"video": true,
"encoding": {
"vcodec": "libsvtav1",
"pix_fmt": "yuv420p",
"g": 2,
"crf": 30
}
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b05f933aa67d559e44f062c8428b2f85ee7b49d3bf0e0302b9b83fb7d48ed0a3
size 2904

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8698f98e3fe36e321ba99a9b60facaab4abffb26916042b021adc1b41e8fb877
size 100040

View File

@ -1,47 +0,0 @@
{
"citation": "",
"description": "",
"features": {
"observation.images.top": {
"_type": "VideoFrame"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -1,13 +0,0 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "3aa08798f073758b",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": null
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a57aade7d8510ef1cc8778f90cfa86749c95fa0c5a5e80cb166b2edd0f7189a
size 1788513

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e7dbc214a415689ca7fb83b6f8e12ec7824dfe34a66024b0b24bfeb3aeefd0e4
size 928

View File

@ -1,5 +0,0 @@
{
"codebase_version": "v1.6",
"fps": 50,
"video": 0
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f98bd8f6347590aecdddaceed95d921f2d9f7bf35fbe742c37bdf12cba11dca6
size 2904

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0013aea549ec290af94bddde1b559fb8d0967d4c43ef14319177c4e62ed1e91
size 14545712

View File

@ -1,47 +0,0 @@
{
"citation": "",
"description": "",
"features": {
"observation.images.top": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -1,13 +0,0 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "9fe3a4bf575a8a67",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": null
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4e910eac6a1c94f4c194b05e908dcc973dd4227b18eb80c374d7a1150f166c34
size 136

View File

@ -1,11 +0,0 @@
{
"codebase_version": "v1.6",
"fps": 50,
"video": true,
"encoding": {
"vcodec": "libsvtav1",
"pix_fmt": "yuv420p",
"g": 2,
"crf": 30
}
}

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a85e57264325cc0927450e30a85dd0eacb0a70ebdb00c4e2ac043a57f9c200e2
size 2904

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:171a9efc9c45601688821936ec9a1dcf91f16b1bbab4e8246f18b4d4cc6ac6ee
size 80432

Some files were not shown because too many files have changed in this diff.