WIP

parent 1da5caaf4b
commit 9ddbbd8e80
@@ -44,7 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        root: Path | None = DATA_DIR,
+        root: Path | None = None,
         split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
@@ -53,22 +53,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
         super().__init__()
         self.repo_id = repo_id
         self.root = root
+        if self.root is None and DATA_DIR is not None:
+            self.root = DATA_DIR
         self.split = split
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         # load data from hub or locally when root is provided
         # TODO(rcadene, aliberts): implement faster transfer
         # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
-        self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
+        self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, self.root, split)
         if split == "train":
             self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
         else:
             self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
             self.hf_dataset = reset_episode_index(self.hf_dataset)
-        self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
-        self.info = load_info(repo_id, CODEBASE_VERSION, root)
+        self.stats = load_stats(repo_id, CODEBASE_VERSION, self.root)
+        self.info = load_info(repo_id, CODEBASE_VERSION, self.root)
         if self.video:
-            self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
+            self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, self.root)
             self.video_backend = video_backend if video_backend is not None else "pyav"

     @property
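Editor's note: the new constructor logic gives an explicitly passed `root` precedence over the `DATA_DIR` default. A minimal standalone sketch of that precedence rule; the `DATA_DIR` stand-in below mirrors how lerobot derives it from an environment variable, but is written locally just for the example:

```python
import os
from pathlib import Path

# Stand-in for lerobot's module-level DATA_DIR (derived from the DATA_DIR env variable).
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None


def resolve_root(root: Path | None) -> Path | None:
    """Mirror the constructor's fallback: keep an explicit root, else fall back to DATA_DIR."""
    if root is None and DATA_DIR is not None:
        return DATA_DIR
    return root


# An explicit root always wins; DATA_DIR only fills in when root is None.
assert resolve_root(Path("/datasets/pusht")) == Path("/datasets/pusht")
print(resolve_root(None))  # DATA_DIR if the env variable is set, otherwise None
```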
@@ -233,9 +233,6 @@ class Logger:
         if self._wandb is not None:
             for k, v in d.items():
                 if not isinstance(v, (int, float, str)):
-                    logging.warning(
-                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
-                    )
                     continue
                 self._wandb.log({f"{mode}/{k}": v}, step=step)

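Editor's note: after this change, values that are not plain scalars or strings are skipped silently instead of triggering a warning. A tiny sketch of that filter, using a plain dict instead of the real wandb run object:

```python
import torch


def keep_wandb_loggable(d: dict) -> dict:
    """Return only the entries the scalar logging path handles (int, float, str); skip the rest silently."""
    return {k: v for k, v in d.items() if isinstance(v, (int, float, str))}


metrics = {"loss": 0.42, "lr": 1e-4, "note": "warmup", "action": torch.zeros(2, 10)}
print(keep_wandb_loggable(metrics))  # {'loss': 0.42, 'lr': 0.0001, 'note': 'warmup'}
```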
@@ -134,25 +134,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        bsize = actions_hat.shape[0]
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        l1_loss = l1_loss.view(bsize, -1).mean(dim=1)

-        loss_dict = {"l1_loss": l1_loss.item()}
+        out_dict = {}
+        out_dict["l1_loss"] = l1_loss
+
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
-            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
+            kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
         else:
-            loss_dict["loss"] = l1_loss
+            out_dict["loss"] = l1_loss

-        return loss_dict
+        out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
+        return out_dict


 class ACTTemporalEnsembler:
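Editor's note: the rewritten ACT loss keeps one value per batch element instead of collapsing to a scalar inside the policy. A self-contained sketch (made-up shapes, not the real model outputs) of the same reduction: the L1 error is zeroed where actions are padding, averaged within each sample, and the training loop later takes the batch mean before calling `backward()`:

```python
import torch
import torch.nn.functional as F

bsize, chunk_size, action_dim = 4, 10, 2
actions = torch.randn(bsize, chunk_size, action_dim)
actions_hat = torch.randn(bsize, chunk_size, action_dim)
# True where the action step is padding and must not contribute to the loss.
action_is_pad = torch.zeros(bsize, chunk_size, dtype=torch.bool)
action_is_pad[:, -2:] = True

l1 = F.l1_loss(actions, actions_hat, reduction="none")  # (bsize, chunk_size, action_dim)
l1 = l1 * ~action_is_pad.unsqueeze(-1)                  # zero out padded steps
l1_per_sample = l1.view(bsize, -1).mean(dim=1)          # (bsize,)

assert l1_per_sample.shape == (bsize,)
scalar_loss = l1_per_sample.mean()  # the reduction the training loop now performs before backward()
print(scalar_loss.item())
```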
@@ -341,7 +341,11 @@ class DiffusionModel(nn.Module):
             in_episode_bound = ~batch["action_is_pad"]
             loss = loss * in_episode_bound.unsqueeze(-1)

-        return loss.mean()
+        # Compute average per item in the batch
+        bsize = loss.shape[0]
+        loss = loss.reshape(bsize, -1).mean(1)
+
+        return loss


 class SpatialSoftmax(nn.Module):
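Editor's note: for the diffusion loss the same idea applies: reshape to `(batch, -1)` and average per item. A quick sketch with hypothetical tensor shapes, also showing that the mean of the new per-item vector matches the old global `.mean()`:

```python
import torch

loss = torch.rand(8, 16, 7)  # e.g. (batch, horizon, action_dim); values are arbitrary

old_scalar = loss.mean()
per_item = loss.reshape(loss.shape[0], -1).mean(1)  # one value per batch element

assert per_item.shape == (8,)
assert torch.allclose(per_item.mean(), old_scalar)
```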
@@ -396,51 +396,39 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         # Compute consistency loss as MSE loss between latents predicted from the rollout and latents
         # predicted from the (target model's) observation encoder.
         consistency_loss = (
-            (
-                temporal_loss_coeffs
-                * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
-                # `z_preds` depends on the current observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
-                * ~batch["action_is_pad"]
-                # `z_targets` depends on the next observation.
-                * ~batch["observation.state_is_pad"][1:]
-            )
-            .sum(0)
-            .mean()
-        )
+            temporal_loss_coeffs
+            * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
+            # `z_preds` depends on the current observation and the actions.
+            * ~batch["observation.state_is_pad"][0]
+            * ~batch["action_is_pad"]
+            # `z_targets` depends on the next observation.
+            * ~batch["observation.state_is_pad"][1:]
+        ).sum(0)
         # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
         # rewards.
         reward_loss = (
-            (
-                temporal_loss_coeffs
-                * F.mse_loss(reward_preds, reward, reduction="none")
-                * ~batch["next.reward_is_pad"]
-                # `reward_preds` depends on the current observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
-                * ~batch["action_is_pad"]
-            )
-            .sum(0)
-            .mean()
-        )
+            temporal_loss_coeffs
+            * F.mse_loss(reward_preds, reward, reduction="none")
+            * ~batch["next.reward_is_pad"]
+            # `reward_preds` depends on the current observation and the actions.
+            * ~batch["observation.state_is_pad"][0]
+            * ~batch["action_is_pad"]
+        ).sum(0)
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         q_value_loss = (
-            (
-                temporal_loss_coeffs
-                * F.mse_loss(
-                    q_preds_ensemble,
-                    einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
-                    reduction="none",
-                ).sum(0)  # sum over ensemble
-                # `q_preds_ensemble` depends on the first observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
-                * ~batch["action_is_pad"]
-                # q_targets depends on the reward and the next observations.
-                * ~batch["next.reward_is_pad"]
-                * ~batch["observation.state_is_pad"][1:]
-            )
-            .sum(0)
-            .mean()
-        )
+            temporal_loss_coeffs
+            * F.mse_loss(
+                q_preds_ensemble,
+                einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
+                reduction="none",
+            ).sum(0)  # sum over ensemble
+            # `q_preds_ensemble` depends on the first observation and the actions.
+            * ~batch["observation.state_is_pad"][0]
+            * ~batch["action_is_pad"]
+            # q_targets depends on the reward and the next observations.
+            * ~batch["next.reward_is_pad"]
+            * ~batch["observation.state_is_pad"][1:]
+        ).sum(0)
         # Compute state value loss as in eqn 3 of FOWM.
         diff = v_targets - v_preds
         # Expectile loss penalizes:
@@ -450,16 +438,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
         ) * (diff**2)
         v_value_loss = (
-            (
-                temporal_loss_coeffs
-                * raw_v_value_loss
-                # `v_targets` depends on the first observation and the actions, as does `v_preds`.
-                * ~batch["observation.state_is_pad"][0]
-                * ~batch["action_is_pad"]
-            )
-            .sum(0)
-            .mean()
-        )
+            temporal_loss_coeffs
+            * raw_v_value_loss
+            # `v_targets` depends on the first observation and the actions, as does `v_preds`.
+            * ~batch["observation.state_is_pad"][0]
+            * ~batch["action_is_pad"]
+        ).sum(0)

         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
         # We won't need these gradients again so detach.
@@ -492,7 +476,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             # `action_preds` depends on the first observation and the actions.
             * ~batch["observation.state_is_pad"][0]
             * ~batch["action_is_pad"]
-        ).mean()
+        ).sum(0)

         loss = (
             self.config.consistency_coeff * consistency_loss
@@ -504,13 +488,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):

         info.update(
             {
-                "consistency_loss": consistency_loss.item(),
-                "reward_loss": reward_loss.item(),
-                "Q_value_loss": q_value_loss.item(),
-                "V_value_loss": v_value_loss.item(),
-                "pi_loss": pi_loss.item(),
+                "consistency_loss": consistency_loss,
+                "reward_loss": reward_loss,
+                "Q_value_loss": q_value_loss,
+                "V_value_loss": v_value_loss,
+                "pi_loss": pi_loss,
                 "loss": loss,
-                "sum_loss": loss.item() * self.config.horizon,
+                "sum_loss": loss * self.config.horizon,
             }
         )

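Editor's note: across the TDMPC hunks above the pattern is the same: each loss keeps its leading time dimension, gets weighted by the temporal coefficients and padding masks, and only `.sum(0)` is applied, leaving one value per batch element (the final `.mean()` now happens in the training loop). A shape-only sketch of that reduction; the horizon, batch size, and decay value are placeholders:

```python
import torch

horizon, batch = 5, 3
# Toy stand-in for the temporal discount coefficients, shaped (horizon, 1) so it broadcasts over the batch.
temporal_loss_coeffs = torch.pow(0.9, torch.arange(horizon, dtype=torch.float32)).unsqueeze(-1)
per_step_error = torch.rand(horizon, batch)             # e.g. an MSE term per (timestep, batch element)
is_pad = torch.zeros(horizon, batch, dtype=torch.bool)  # no padding in this toy example

loss_per_item = (temporal_loss_coeffs * per_step_error * ~is_pad).sum(0)  # shape: (batch,)
assert loss_per_item.shape == (batch,)

scalar_loss = loss_per_item.mean()  # the extra reduction that used to live inside the policy
```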
@@ -13,7 +13,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+from pathlib import Path
+
 import torch
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils._errors import RepositoryNotFoundError
+from huggingface_hub.utils._validators import HFValidationError
 from torch import nn


@@ -47,3 +53,26 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
     Note: assumes that all parameters have the same dtype.
     """
     return next(iter(module.parameters())).dtype
+
+
+def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
+    try:
+        pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
+    except (HFValidationError, RepositoryNotFoundError) as e:
+        if isinstance(e, HFValidationError):
+            error_message = (
+                "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
+            )
+        else:
+            error_message = (
+                "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
+            )
+
+        logging.warning(f"{error_message} Treating it as a local directory.")
+        pretrained_policy_path = Path(pretrained_policy_name_or_path)
+    if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
+        raise ValueError(
+            "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
+            "repo ID, nor is it an existing local directory."
+        )
+    return pretrained_policy_path
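Editor's note: with `get_pretrained_policy_path` now living in `lerobot.common.policies.utils`, callers resolve either a Hub repo id or a local directory through the same helper. A hedged usage sketch; the repo id and local path below are illustrative placeholders, not values taken from the commit:

```python
from pathlib import Path

from lerobot.common.policies.utils import get_pretrained_policy_path

# A Hub repo id is snapshot-downloaded and the local cache path is returned...
hub_path: Path = get_pretrained_policy_path("lerobot/act_aloha_sim_transfer_cube_human")

# ...while anything that is not a valid/known repo id is treated as a local directory
# (and a ValueError is raised if that directory does not exist).
local_path: Path = get_pretrained_policy_path("outputs/train/my_run/checkpoints/last/pretrained_model")
```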
@@ -24,7 +24,7 @@ training:
   online_steps_between_rollouts: 1

   delta_timestamps:
-    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

 eval:
   n_episodes: 50
@@ -50,7 +50,7 @@ training:
   online_steps_between_rollouts: 1

   delta_timestamps:
-    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

 eval:
   n_episodes: 50
@@ -48,7 +48,7 @@ training:
   online_steps_between_rollouts: 1

   delta_timestamps:
-    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

 eval:
   n_episodes: 50
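Editor's note: the config change repeated in the three hunks above shifts the action query timestamps by one frame: `range(${policy.chunk_size})` starts at the current frame, while `range(1, ${policy.chunk_size} + 1)` starts one step in the future. A plain-Python sketch of what the two interpolated expressions evaluate to (the fps and chunk_size values are just examples):

```python
fps = 50
chunk_size = 5

old_action_timestamps = [i / fps for i in range(chunk_size)]
new_action_timestamps = [i / fps for i in range(1, chunk_size + 1)]

print(old_action_timestamps)  # [0.0, 0.02, 0.04, 0.06, 0.08]
print(new_action_timestamps)  # [0.02, 0.04, 0.06, 0.08, 0.1]
```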
@@ -56,9 +56,6 @@ import einops
 import gymnasium as gym
 import numpy as np
 import torch
-from huggingface_hub import snapshot_download
-from huggingface_hub.utils._errors import RepositoryNotFoundError
-from huggingface_hub.utils._validators import HFValidationError
 from torch import Tensor, nn
 from tqdm import trange

@@ -68,7 +65,7 @@ from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.policies.utils import get_device_from_parameters
+from lerobot.common.policies.utils import get_device_from_parameters, get_pretrained_policy_path
 from lerobot.common.utils.io_utils import write_video
 from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed

@@ -501,29 +498,6 @@ def main(
     logging.info("End of eval")


-def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
-    try:
-        pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
-    except (HFValidationError, RepositoryNotFoundError) as e:
-        if isinstance(e, HFValidationError):
-            error_message = (
-                "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
-            )
-        else:
-            error_message = (
-                "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
-            )
-
-        logging.warning(f"{error_message} Treating it as a local directory.")
-        pretrained_policy_path = Path(pretrained_policy_name_or_path)
-    if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
-        raise ValueError(
-            "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
-            "repo ID, nor is it an existing local directory."
-        )
-    return pretrained_policy_path
-
-
 if __name__ == "__main__":
     init_logging()

@@ -120,8 +120,7 @@ def update_policy(
     policy.train()
     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
         output_dict = policy.forward(batch)
-        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-        loss = output_dict["loss"]
+        loss = output_dict["loss"].mean()
         grad_scaler.scale(loss).backward()

     # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
@@ -150,14 +149,12 @@ def update_policy(
         policy.update()

     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
         "update_s": time.perf_counter() - start_time,
-        **{k: v for k, v in output_dict.items() if k != "loss"},
+        **{k: v.detach().mean().item() for k, v in output_dict.items() if "loss" in k},
+        **{k: v for k, v in output_dict.items() if "loss" not in k},
     }
-    info.update({k: v for k, v in output_dict.items() if k not in info})

     return info

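Editor's note: with the policies returning per-sample loss tensors, the logging dict splits the outputs by key: anything with "loss" in its name is detached and averaged down to a Python float, everything else is passed through untouched. A standalone sketch of that split; the fake `output_dict` below is illustrative:

```python
import torch

output_dict = {
    "loss": torch.rand(4),        # per-sample total loss
    "l1_loss": torch.rand(4),     # per-sample component loss
    "action": torch.rand(4, 10),  # non-loss output, e.g. unnormalized actions
}

info = {
    **{k: v.detach().mean().item() for k, v in output_dict.items() if "loss" in k},
    **{k: v for k, v in output_dict.items() if "loss" not in k},
}

assert isinstance(info["loss"], float) and isinstance(info["l1_loss"], float)
assert torch.is_tensor(info["action"])
```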
@@ -13,6 +13,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Save the policy tests artifacts.
+
+Note: Run on the cluster
+
+Example of usage:
+```bash
+DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py
+```
+"""
+
+import platform
 import shutil
 from pathlib import Path

@@ -54,7 +66,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
     loss = output_dict["loss"]

-    loss.backward()
+    loss.mean().backward()
     grad_stats = {}
     for key, param in policy.named_parameters():
         if param.requires_grad:
@@ -96,10 +108,21 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
         print(f"Overwrite existing safetensors in '{env_policy_dir}':")
         print(f" - Validate with: `git add {env_policy_dir}`")
         print(f" - Revert with: `git checkout -- {env_policy_dir}`")
+
+    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
+
+    from safetensors.torch import load_file
+
+    if (env_policy_dir / "output_dict.safetensors").exists():
+        prev_loss = load_file(env_policy_dir / "output_dict.safetensors")["loss"]
+        print(f"Previous loss={prev_loss}")
+        print(f"New loss={output_dict['loss'].mean()}")
+        print()
+
+    if env_policy_dir.exists():
         shutil.rmtree(env_policy_dir)

     env_policy_dir.mkdir(parents=True, exist_ok=True)
-    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
     save_file(output_dict, env_policy_dir / "output_dict.safetensors")
     save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
     save_file(param_stats, env_policy_dir / "param_stats.safetensors")
@@ -107,27 +130,32 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override


 if __name__ == "__main__":
+    if platform.machine() != "x86_64":
+        raise OSError("Generate policy artifacts on x86_64 machine since it is used for the unit tests. ")
+
     env_policies = [
-        # ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
-        # ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
-        # (
-        #     "pusht",
-        #     "diffusion",
-        #     [
-        #         "policy.n_action_steps=8",
-        #         "policy.num_inference_steps=10",
-        #         "policy.down_dims=[128, 256, 512]",
-        #     ],
-        #     "",
-        # ),
-        # ("aloha", "act", ["policy.n_action_steps=10"], ""),
-        # ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
-        # ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
-        # ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
+        ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
+        ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
+        (
+            "pusht",
+            "diffusion",
+            [
+                "policy.n_action_steps=8",
+                "policy.num_inference_steps=10",
+                "policy.down_dims=[128, 256, 512]",
+            ],
+            "",
+        ),
+        ("aloha", "act", ["policy.n_action_steps=10"], ""),
+        ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
+        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
+        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
     ]
+    if len(env_policies) == 0:
+        raise RuntimeError("No policies were provided!")
     for env, policy, extra_overrides, file_name_extra in env_policies:
         print(f"env={env} policy={policy} extra_overrides={extra_overrides}")
         save_policy_to_safetensors(
             "tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
         )
         print()
@@ -147,10 +147,11 @@ def test_policy(env_name, policy_name, extra_overrides):
     # Check that we run select_actions and get the appropriate output.
     env = make_env(cfg, n_envs=2)

+    batch_size = 2
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=0,
-        batch_size=2,
+        batch_size=batch_size,
         shuffle=True,
         pin_memory=DEVICE != "cpu",
         drop_last=True,
@@ -164,12 +165,19 @@ def test_policy(env_name, policy_name, extra_overrides):

     # Test updating the policy (and test that it does not mutate the batch)
     batch_ = deepcopy(batch)
-    policy.forward(batch)
+    out = policy.forward(batch)
     assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
     assert all(
         torch.equal(batch[k], batch_[k]) for k in batch
     ), "Batch values are not the same after a forward pass."

+    # Test loss can be visualized using visualize_dataset_html.py
+    for key in out:
+        if "loss" in key:
+            assert (
+                out[key].ndim == 1 and out[key].shape[0] == batch_size
+            ), f"1 loss value per item in the batch is expected, but {out[key].shape} provided instead."
+
     # reset the policy and environment
     policy.reset()
     observation, _ = env.reset(seed=cfg.seed)
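Editor's note: the new test encodes the contract this commit establishes: every `*loss*` entry returned by `forward` is a 1-D tensor with one value per batch element. A minimal sketch of the same check, run against a hypothetical output dict rather than a real policy:

```python
import torch

batch_size = 2
out = {
    "loss": torch.rand(batch_size),
    "l1_loss": torch.rand(batch_size),
    "action": torch.rand(batch_size, 10),
}

for key in out:
    if "loss" in key:
        assert (
            out[key].ndim == 1 and out[key].shape[0] == batch_size
        ), f"1 loss value per item in the batch is expected, but {out[key].shape} provided instead."
```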
@@ -234,6 +242,7 @@ def test_policy_defaults(policy_name: str):
     [
         ("xarm", "tdmpc"),
         ("pusht", "diffusion"),
+        ("pusht", "vqbet"),
         ("aloha", "act"),
     ],
 )
@@ -250,7 +259,7 @@ def test_yaml_matches_dataclass(env_name: str, policy_name: str):
 def test_save_and_load_pretrained(policy_name: str):
     policy_cls, _ = get_policy_and_config_classes(policy_name)
     policy: Policy = policy_cls()
-    save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
+    save_dir = f"/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
     policy.save_pretrained(save_dir)
     policy_ = policy_cls.from_pretrained(save_dir)
     assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
@@ -365,6 +374,7 @@ def test_normalize(insert_temporal_dim):
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
             "",
         ),
+        ("pusht", "vqbet", "[]", ""),
         ("aloha", "act", ["policy.n_action_steps=10"], ""),
         ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
         ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
@@ -461,7 +471,3 @@ def test_act_temporal_ensembler():
     assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
     # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
     assert torch.allclose(online_avg, offline_avg, atol=1e-4)
-
-
-if __name__ == "__main__":
-    test_act_temporal_ensembler()
@@ -25,13 +25,13 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
     ["lerobot/pusht"],
 )
 @pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
-def test_visualize_local_dataset(tmpdir, repo_id, root):
+def test_visualize_dataset_root(tmpdir, repo_id, root):
     rrd_path = visualize_dataset(
         repo_id,
-        root=root,
         episode_index=0,
         batch_size=32,
         save=True,
         output_dir=tmpdir,
+        root=root,
     )
     assert rrd_path.exists()