This commit is contained in:
Remi Cadene 2024-08-06 17:17:07 +03:00
parent 1da5caaf4b
commit 9ddbbd8e80
14 changed files with 162 additions and 140 deletions

View File

@ -44,7 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_id: str,
root: Path | None = DATA_DIR,
root: Path | None = None,
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
@ -53,22 +53,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_id = repo_id
self.root = root
if self.root is None and DATA_DIR is not None:
self.root = DATA_DIR
self.split = split
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, self.root, split)
if split == "train":
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
else:
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
self.hf_dataset = reset_episode_index(self.hf_dataset)
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
self.info = load_info(repo_id, CODEBASE_VERSION, root)
self.stats = load_stats(repo_id, CODEBASE_VERSION, self.root)
self.info = load_info(repo_id, CODEBASE_VERSION, self.root)
if self.video:
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, self.root)
self.video_backend = video_backend if video_backend is not None else "pyav"
@property

View File

@ -233,9 +233,6 @@ class Logger:
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)

View File

@ -134,25 +134,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
bsize = actions_hat.shape[0]
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
out_dict = {}
out_dict["l1_loss"] = l1_loss
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
out_dict["loss"] = l1_loss
return loss_dict
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
return out_dict
class ACTTemporalEnsembler:

View File

@ -341,7 +341,11 @@ class DiffusionModel(nn.Module):
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
# Compute average per item in the batch
bsize = loss.shape[0]
loss = loss.reshape(bsize, -1).mean(1)
return loss
class SpatialSoftmax(nn.Module):

View File

@ -396,51 +396,39 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
# predicted from the (target model's) observation encoder.
consistency_loss = (
(
temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# `z_targets` depends on the next observation.
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# `z_targets` depends on the next observation.
* ~batch["observation.state_is_pad"][1:]
).sum(0)
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
# rewards.
reward_loss = (
(
temporal_loss_coeffs
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
temporal_loss_coeffs
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).sum(0)
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
q_value_loss = (
(
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
).sum(0)
# Compute state value loss as in eqn 3 of FOWM.
diff = v_targets - v_preds
# Expectile loss penalizes:
@ -450,16 +438,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
) * (diff**2)
v_value_loss = (
(
temporal_loss_coeffs
* raw_v_value_loss
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
temporal_loss_coeffs
* raw_v_value_loss
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).sum(0)
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
# We won't need these gradients again so detach.
@ -492,7 +476,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
).sum(0)
loss = (
self.config.consistency_coeff * consistency_loss
@ -504,13 +488,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
info.update(
{
"consistency_loss": consistency_loss.item(),
"reward_loss": reward_loss.item(),
"Q_value_loss": q_value_loss.item(),
"V_value_loss": v_value_loss.item(),
"pi_loss": pi_loss.item(),
"consistency_loss": consistency_loss,
"reward_loss": reward_loss,
"Q_value_loss": q_value_loss,
"V_value_loss": v_value_loss,
"pi_loss": pi_loss,
"loss": loss,
"sum_loss": loss.item() * self.config.horizon,
"sum_loss": loss * self.config.horizon,
}
)

View File

@ -13,7 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from torch import nn
@ -47,3 +53,26 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
return pretrained_policy_path

View File

@ -24,7 +24,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50

View File

@ -50,7 +50,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50

View File

@ -48,7 +48,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50

View File

@ -56,9 +56,6 @@ import einops
import gymnasium as gym
import numpy as np
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from torch import Tensor, nn
from tqdm import trange
@ -68,7 +65,7 @@ from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.policies.utils import get_device_from_parameters, get_pretrained_policy_path
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
@ -501,29 +498,6 @@ def main(
logging.info("End of eval")
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
return pretrained_policy_path
if __name__ == "__main__":
init_logging()

View File

@ -120,8 +120,7 @@ def update_policy(
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
loss = output_dict["loss"].mean()
grad_scaler.scale(loss).backward()
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
@ -150,14 +149,12 @@ def update_policy(
policy.update()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.perf_counter() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
**{k: v.detach().mean().item() for k, v in output_dict.items() if "loss" in k},
**{k: v for k, v in output_dict.items() if "loss" not in k},
}
info.update({k: v for k, v in output_dict.items() if k not in info})
return info

View File

@ -13,6 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Save the policy tests artifacts.
Note: Run on the cluster
Example of usage:
```bash
DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py
```
"""
import platform
import shutil
from pathlib import Path
@ -54,7 +66,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
loss = output_dict["loss"]
loss.backward()
loss.mean().backward()
grad_stats = {}
for key, param in policy.named_parameters():
if param.requires_grad:
@ -96,10 +108,21 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
print(f" - Validate with: `git add {env_policy_dir}`")
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
from safetensors.torch import load_file
if (env_policy_dir / "output_dict.safetensors").exists():
prev_loss = load_file(env_policy_dir / "output_dict.safetensors")["loss"]
print(f"Previous loss={prev_loss}")
print(f"New loss={output_dict['loss'].mean()}")
print()
if env_policy_dir.exists():
shutil.rmtree(env_policy_dir)
env_policy_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
@ -107,27 +130,32 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
if __name__ == "__main__":
if platform.machine() != "x86_64":
raise OSError("Generate policy artifacts on x86_64 machine since it is used for the unit tests. ")
env_policies = [
# ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
# (
# "pusht",
# "diffusion",
# [
# "policy.n_action_steps=8",
# "policy.num_inference_steps=10",
# "policy.down_dims=[128, 256, 512]",
# ],
# "",
# ),
# ("aloha", "act", ["policy.n_action_steps=10"], ""),
# ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
(
"pusht",
"diffusion",
[
"policy.n_action_steps=8",
"policy.num_inference_steps=10",
"policy.down_dims=[128, 256, 512]",
],
"",
),
("aloha", "act", ["policy.n_action_steps=10"], ""),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
]
if len(env_policies) == 0:
raise RuntimeError("No policies were provided!")
for env, policy, extra_overrides, file_name_extra in env_policies:
print(f"env={env} policy={policy} extra_overrides={extra_overrides}")
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
)
print()

View File

@ -147,10 +147,11 @@ def test_policy(env_name, policy_name, extra_overrides):
# Check that we run select_actions and get the appropriate output.
env = make_env(cfg, n_envs=2)
batch_size = 2
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=2,
batch_size=batch_size,
shuffle=True,
pin_memory=DEVICE != "cpu",
drop_last=True,
@ -164,12 +165,19 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
out = policy.forward(batch)
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
assert all(
torch.equal(batch[k], batch_[k]) for k in batch
), "Batch values are not the same after a forward pass."
# Test loss can be visualized using visualize_dataset_html.py
for key in out:
if "loss" in key:
assert (
out[key].ndim == 1 and out[key].shape[0] == batch_size
), f"1 loss value per item in the batch is expected, but {out[key].shape} provided instead."
# reset the policy and environment
policy.reset()
observation, _ = env.reset(seed=cfg.seed)
@ -234,6 +242,7 @@ def test_policy_defaults(policy_name: str):
[
("xarm", "tdmpc"),
("pusht", "diffusion"),
("pusht", "vqbet"),
("aloha", "act"),
],
)
@ -250,7 +259,7 @@ def test_yaml_matches_dataclass(env_name: str, policy_name: str):
def test_save_and_load_pretrained(policy_name: str):
policy_cls, _ = get_policy_and_config_classes(policy_name)
policy: Policy = policy_cls()
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
save_dir = f"/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
policy_ = policy_cls.from_pretrained(save_dir)
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
@ -365,6 +374,7 @@ def test_normalize(insert_temporal_dim):
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
"",
),
("pusht", "vqbet", "[]", ""),
("aloha", "act", ["policy.n_action_steps=10"], ""),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
@ -461,7 +471,3 @@ def test_act_temporal_ensembler():
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
if __name__ == "__main__":
test_act_temporal_ensembler()

View File

@ -25,13 +25,13 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
["lerobot/pusht"],
)
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
def test_visualize_local_dataset(tmpdir, repo_id, root):
def test_visualize_dataset_root(tmpdir, repo_id, root):
rrd_path = visualize_dataset(
repo_id,
root=root,
episode_index=0,
batch_size=32,
save=True,
output_dir=tmpdir,
root=root,
)
assert rrd_path.exists()