Tidy up yaml configs (#121)

This commit is contained in:
Alexander Soare 2024-04-30 16:08:59 +01:00 committed by GitHub
parent e4e739f4f8
commit 9d60dce6f3
GPG Key ID: B5690EEEBB952194
21 changed files with 142 additions and 207 deletions

View File

@ -30,21 +30,21 @@ test-act-ete-train:
policy=act \
env=aloha \
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
eval_episodes=1 \
training.offline_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
training.save_model=true \
training.save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
policy.batch_size=2 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/act/
test-act-ete-eval:
python lerobot/scripts/eval.py \
--config tests/outputs/act/.hydra/config.yaml \
eval_episodes=1 \
eval.n_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt
@ -54,19 +54,19 @@ test-diffusion-ete-train:
policy=diffusion \
env=pusht \
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
eval_episodes=1 \
training.offline_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
policy.batch_size=2 \
training.save_model=true \
training.save_freq=2 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/diffusion/
test-diffusion-ete-eval:
python lerobot/scripts/eval.py \
--config tests/outputs/diffusion/.hydra/config.yaml \
eval_episodes=1 \
eval.n_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
@ -76,20 +76,20 @@ test-tdmpc-ete-train:
policy=tdmpc \
env=xarm \
wandb.enable=False \
offline_steps=1 \
online_steps=2 \
eval_episodes=1 \
training.offline_steps=1 \
training.online_steps=2 \
eval.n_episodes=1 \
env.episode_length=2 \
device=cpu \
save_model=true \
save_freq=2 \
policy.batch_size=2 \
training.save_model=true \
training.save_freq=2 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval:
python lerobot/scripts/eval.py \
--config tests/outputs/tdmpc/.hydra/config.yaml \
eval_episodes=1 \
eval.n_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
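The renames above replace flat Hydra override keys (`offline_steps`, `eval_episodes`, `policy.batch_size`, ...) with grouped ones under `training.` and `eval.`. A minimal sketch of composing the tidied config with the new names — the `config_path` and `default` entry point are assumptions about the repo layout, not taken from this diff:

```python
from hydra import compose, initialize

# Assumed layout: lerobot/configs/default.yaml plus policy/ and env/ groups.
with initialize(config_path="lerobot/configs"):
    cfg = compose(
        config_name="default",
        overrides=[
            "policy=act",
            "env=aloha",
            "training.offline_steps=2",  # was: offline_steps=2
            "training.batch_size=2",     # was: policy.batch_size=2
            "eval.n_episodes=1",         # was: eval_episodes=1
            "device=cpu",
        ],
    )

print(cfg.training.offline_steps, cfg.eval.n_episodes)
```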

View File

@ -23,8 +23,8 @@ weights_path = folder / "model.pt"
# Override some config parameters to do with evaluation.
overrides = [
f"policy.pretrained_model_path={weights_path}",
"eval_episodes=10",
"rollout_batch_size=10",
"eval.n_episodes=10",
"eval.batch_size=10",
"device=cuda",
]

View File

@ -38,15 +38,13 @@ policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, da
policy.train()
policy.to(device)
optimizer = torch.optim.Adam(
policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
# Create dataloader for offline training.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=cfg.batch_size,
batch_size=64,
shuffle=True,
pin_memory=device != torch.device("cpu"),
drop_last=True,
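For context, a sketch of the offline loop that would sit after the simplified optimizer and dataloader above. It continues the example's code (`policy`, `optimizer`, `dataloader`, `device`, `training_steps`) and assumes the policy's `forward()` returns a dict with a scalar `"loss"` entry, which may not match the exact example script:

```python
# Illustrative training loop; the forward() return shape is an assumption.
step = 0
done = False
while not done:
    for batch in dataloader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        loss = policy.forward(batch)["loss"]  # assumed: dict with a "loss" key
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break
```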

View File

@ -14,12 +14,13 @@ def make_dataset(
cfg,
split="train",
):
if cfg.env.name not in cfg.dataset.repo_id:
if cfg.env.name not in cfg.dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
)
delta_timestamps = cfg.policy.get("delta_timestamps")
delta_timestamps = cfg.training.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
@ -28,7 +29,7 @@ def make_dataset(
# TODO(rcadene): add data augmentations
dataset = LeRobotDataset(
cfg.dataset.repo_id,
cfg.dataset_repo_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
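The `delta_timestamps` entries that moved from `policy` to `training` are stored as strings of Python list comprehensions. The `isinstance(..., str)` check above suggests they are evaluated into lists of floats roughly like this (an illustrative sketch, not the repo's exact code; the values assume Hydra has already interpolated `${fps}` and `${policy.chunk_size}`):

```python
fps = 50
chunk_size = 100
# After OmegaConf interpolation the string contains plain numbers.
delta_timestamps = {"action": f"[i / {fps} for i in range({chunk_size})]"}

for key, value in delta_timestamps.items():
    if isinstance(value, str):
        delta_timestamps[key] = eval(value)  # -> [0.0, 0.02, 0.04, ..., 1.98]
```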

View File

@ -29,9 +29,9 @@ class Logger:
self._job_name = job_name
self._model_dir = self._log_dir / "models"
self._buffer_dir = self._log_dir / "buffers"
self._save_model = cfg.save_model
self._save_model = cfg.training.save_model
self._disable_wandb_artifact = cfg.wandb.disable_artifact
self._save_buffer = cfg.save_buffer
self._save_buffer = cfg.training.get("save_buffer", False)
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
self._cfg = cfg
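The switch to `cfg.training.get("save_buffer", False)` relies on OmegaConf's dict-style `get`, so configs that no longer define `save_buffer` (most of the tidied ones) fall back to `False` instead of raising. A quick illustration:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"training": {"save_model": True}})
# Missing keys fall back to the provided default instead of raising an error.
print(cfg.training.get("save_buffer", False))  # -> False
```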

View File

@ -112,15 +112,6 @@ class ActionChunkingTransformerConfig:
dropout: float = 0.1
kl_weight: float = 10.0
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: int = 8
lr: float = 1e-5
lr_backbone: float = 1e-5
weight_decay: float = 1e-4
grad_clip_norm: float = 10
utd: int = 1
def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):

View File

@ -119,15 +119,6 @@ class DiffusionConfig:
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: int = 64
grad_clip_norm: int = 10
lr: float = 1.0e-4
lr_scheduler: str = "cosine"
lr_warmup_steps: int = 500
adam_betas: tuple[float, float] = (0.95, 0.999)
adam_eps: float = 1.0e-8
adam_weight_decay: float = 1.0e-6
utd: int = 1
use_ema: bool = True
ema_update_after_step: int = 0
ema_min_alpha: float = 0.0

View File

@ -35,7 +35,7 @@ def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.training.offline_steps, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
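`_policy_cfg_from_hydra_cfg` itself is not shown in this diff; a plausible sketch of what it does (an assumption, not the repo's actual helper) is to instantiate the policy dataclass from the matching keys of `hydra_cfg.policy`:

```python
from dataclasses import fields
from omegaconf import OmegaConf

def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
    # Keep only the keys the dataclass declares, then instantiate it.
    expected = {f.name for f in fields(policy_cfg_class)}
    kwargs = {
        k: v
        for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
        if k in expected
    }
    return policy_cfg_class(**kwargs)
```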

View File

@ -9,31 +9,23 @@ hydra:
job:
name: default
seed: 1337
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
# NOTE: only diffusion policy supports rollout_batch_size > 1
rollout_batch_size: 1
device: cuda # cpu
prefetch: 4
eval_freq: ???
save_freq: ???
eval_episodes: ???
save_video: false
save_model: false
save_buffer: false
train_steps: ???
fps: ???
seed: ???
dataset_repo_id: lerobot/pusht
offline_prioritized_sampler: true
training:
offline_steps: ???
online_steps: ???
online_steps_between_rollouts: ???
eval_freq: ???
save_freq: ???
log_freq: 250
save_model: false
dataset:
repo_id: ???
n_action_steps: ???
n_obs_steps: ???
env: ???
policy: ???
eval:
n_episodes: 1
# TODO(alexander-soare): Right now this does not work. Reinstate this.
batch_size: 1
wandb:
enable: true

View File

@ -1,18 +1,7 @@
# @package _global_
eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
fps: 50
dataset:
repo_id: lerobot/aloha_sim_insertion_human
env:
name: aloha
task: AlohaInsertion-v0

View File

@ -1,18 +1,7 @@
# @package _global_
eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
fps: 10
dataset:
repo_id: lerobot/pusht
env:
name: pusht
task: PushT-v0

View File

@ -1,17 +1,7 @@
# @package _global_
eval_episodes: 20
eval_freq: 1000
save_freq: 10000
log_freq: 50
offline_steps: 25000
online_steps: 25000
fps: 15
dataset:
repo_id: lerobot/xarm_lift_medium
env:
name: xarm
task: XarmLift-v0

View File

@ -1,21 +1,34 @@
# @package _global_
offline_steps: 80000
online_steps: 0
seed: 1000
dataset_repo_id: lerobot/aloha_sim_insertion_human
eval_episodes: 1
eval_freq: 10000
save_freq: 100000
log_freq: 250
training:
offline_steps: 80000
online_steps: 0
eval_freq: 10000
save_freq: 100000
log_freq: 250
save_model: true
n_obs_steps: 1
# when temporal_agg=False, n_action_steps=horizon
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
# See `configuration_act.py` for more details.
policy:
@ -24,7 +37,7 @@ policy:
pretrained_model_path:
# Input / output structure.
n_obs_steps: ${n_obs_steps}
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
@ -66,15 +79,3 @@ policy:
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
utd: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"

View File

@ -1,22 +1,33 @@
# @package _global_
seed: 100000
horizon: 16
n_obs_steps: 2
n_action_steps: 8
dataset_obs_steps: ${n_obs_steps}
past_action_visible: False
keypoint_visible_rate: 1.0
dataset_repo_id: lerobot/pusht
eval_episodes: 50
eval_freq: 5000
save_freq: 5000
log_freq: 250
training:
offline_steps: 200000
online_steps: 0
eval_freq: 5000
save_freq: 5000
log_freq: 250
save_model: true
offline_steps: 200000
online_steps: 0
batch_size: 64
grad_clip_norm: 10
lr: 1.0e-4
lr_scheduler: cosine
lr_warmup_steps: 500
adam_betas: [0.95, 0.999]
adam_eps: 1.0e-8
adam_weight_decay: 1.0e-6
online_steps_between_rollouts: 1
offline_prioritized_sampler: true
delta_timestamps:
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
eval:
n_episodes: 50
override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
@ -38,9 +49,9 @@ policy:
pretrained_model_path:
# Input / output structure.
n_obs_steps: ${n_obs_steps}
horizon: ${horizon}
n_action_steps: ${n_action_steps}
n_obs_steps: 2
horizon: 16
n_action_steps: 8
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
@ -84,23 +95,9 @@ policy:
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: 64
grad_clip_norm: 10
lr: 1.0e-4
lr_scheduler: cosine
lr_warmup_steps: 500
adam_betas: [0.95, 0.999]
adam_eps: 1.0e-8
adam_weight_decay: 1.0e-6
utd: 1
use_ema: true
ema_update_after_step: 0
ema_min_alpha: 0.0
ema_max_alpha: 0.9999
ema_inv_gamma: 1.0
ema_power: 0.75
delta_timestamps:
observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"

View File

@ -54,7 +54,7 @@ policy:
seed_steps: 0
update_freq: 2
tau: 0.01
utd: 1
online_steps_between_rollouts: 1
# offline rl
# dataset_dir: ???

View File

@ -16,14 +16,14 @@ You have a specific config file to go with trained model weights, and want to ru
python lerobot/scripts/eval.py \
--config PATH/TO/FOLDER/config.yaml \
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
eval_episodes=10
eval.n_episodes=10
```
You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case,
you don't need to specify which weights to use):
```
python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval.n_episodes=10
```
"""
@ -365,7 +365,7 @@ def eval(cfg: dict, out_dir=None):
log_output_dir(out_dir)
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
logging.info("Making policy.")
policy = make_policy(cfg)

View File

@ -81,7 +81,7 @@ def log_train_info(logger, info, step, cfg, dataset, is_offline):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
num_samples = (step + 1) * cfg.policy.batch_size
num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_samples
@ -117,7 +117,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
num_samples = (step + 1) * cfg.policy.batch_size
num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_samples
@ -246,8 +246,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
if cfg.online_steps > 0:
assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
if cfg.training.online_steps > 0:
assert cfg.eval.batch_size == 1, "eval.batch_size > 1 not supported for online training steps"
init_logging()
@ -262,7 +262,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
offline_dataset = make_dataset(cfg)
logging.info("make_env")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
logging.info("make_policy")
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
@ -282,31 +282,27 @@ def train(cfg: dict, out_dir=None, job_name=None):
"params": [
p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
],
"lr": cfg.policy.lr_backbone,
"lr": cfg.training.lr_backbone,
},
]
optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.policy.lr, weight_decay=cfg.policy.weight_decay
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
)
lr_scheduler = None
elif cfg.policy.name == "diffusion":
optimizer = torch.optim.Adam(
policy.diffusion.parameters(),
cfg.policy.lr,
cfg.policy.adam_betas,
cfg.policy.adam_eps,
cfg.policy.adam_weight_decay,
cfg.training.lr,
cfg.training.adam_betas,
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
# configure lr scheduler
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
lr_scheduler = get_scheduler(
cfg.policy.lr_scheduler,
cfg.training.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.policy.lr_warmup_steps,
num_training_steps=cfg.offline_steps,
# pytorch assumes stepping LRScheduler every epoch
# however huggingface diffusers steps it every batch
last_epoch=-1,
num_warmup_steps=cfg.training.lr_warmup_steps,
num_training_steps=cfg.training.offline_steps,
)
elif policy.name == "tdmpc":
raise NotImplementedError("TD-MPC not implemented yet.")
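The ACT branch in the hunk above only shows part of the parameter-group setup. A sketch of the full split (backbone parameters trained at `cfg.training.lr_backbone`, everything else at `cfg.training.lr`), continuing the surrounding code's `policy` and `cfg`; attribute paths not visible in the diff are assumptions:

```python
import torch

optimizer_params_dicts = [
    {
        # Non-backbone parameters use the main learning rate passed to AdamW below.
        "params": [
            p for n, p in policy.named_parameters()
            if not n.startswith("backbone") and p.requires_grad
        ],
    },
    {
        # Backbone parameters get their own (typically smaller) learning rate.
        "params": [
            p for n, p in policy.named_parameters()
            if n.startswith("backbone") and p.requires_grad
        ],
        "lr": cfg.training.lr_backbone,
    },
]
optimizer = torch.optim.AdamW(
    optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
)
```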
@ -319,8 +315,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
logging.info(f"{cfg.online_steps=}")
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
@ -328,7 +324,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# Note: this helper will be used in offline and online training loops.
def _maybe_eval_and_maybe_save(step):
if step % cfg.eval_freq == 0:
if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
eval_info = eval_policy(
env,
@ -342,7 +338,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
logger.log_video(eval_info["videos"][0], step, mode="eval")
logging.info("Resume training")
if cfg.save_model and step % cfg.save_freq == 0:
if cfg.training.save_model and step % cfg.training.save_freq == 0:
logging.info(f"Checkpoint policy after step {step}")
logger.save_model(policy, identifier=step)
logging.info("Resume training")
@ -351,7 +347,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
dataloader = torch.utils.data.DataLoader(
offline_dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
batch_size=cfg.training.batch_size,
shuffle=True,
pin_memory=cfg.device != "cpu",
drop_last=False,
@ -360,7 +356,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
for offline_step in range(cfg.offline_steps):
for offline_step in range(cfg.training.offline_steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
policy.train()
@ -369,10 +365,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0:
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
@ -398,7 +394,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
batch_size=cfg.training.batch_size,
sampler=sampler,
pin_memory=cfg.device != "cpu",
drop_last=False,
@ -407,7 +403,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_step = 0
is_offline = False
for env_step in range(cfg.online_steps):
for env_step in range(cfg.training.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
@ -428,16 +424,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
pc_online_samples=cfg.get("demo_schedule", 0.5),
)
for _ in range(cfg.policy.utd):
for _ in range(cfg.training.online_steps_between_rollouts):
policy.train()
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
if step % cfg.log_freq == 0:
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass

View File

@ -33,7 +33,7 @@ def test_factory(env_name, repo_id, policy_name):
DEFAULT_CONFIG_PATH,
overrides=[
f"env={env_name}",
f"dataset.repo_id={repo_id}",
f"dataset_repo_id={repo_id}",
f"policy={policy_name}",
f"device={DEVICE}",
],

View File

@ -39,7 +39,7 @@ def test_examples_3_and_2():
("training_steps = 5000", "training_steps = 1"),
("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=cfg.batch_size", "batch_size=1"),
("batch_size=64", "batch_size=1"),
],
)
@ -58,8 +58,8 @@ def test_examples_3_and_2():
file_contents = _find_and_replace(
file_contents,
[
('"eval_episodes=10"', '"eval_episodes=1"'),
('"rollout_batch_size=10"', '"rollout_batch_size=1"'),
('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
('"eval.batch_size=10"', '"eval.batch_size=1"'),
('"device=cuda"', '"device=cpu"'),
(
'# folder = Path("outputs/train/example_pusht_diffusion")',

View File

@ -21,21 +21,21 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
# ("xarm", "tdmpc", ["policy.mpc=true"]),
# ("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
(
"aloha",
"act",
["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"],
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_scripted"],
),
(
"aloha",
"act",
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"],
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_human"],
),
(
"aloha",
"act",
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
),
],
)

View File

@ -20,7 +20,7 @@ def test_visualize_dataset(tmpdir, repo_id):
overrides=[
"policy=act",
"env=aloha",
f"dataset.repo_id={repo_id}",
f"dataset_repo_id={repo_id}",
],
)
video_paths = visualize_dataset(cfg, out_dir=tmpdir)