Add real-world support for ACT on Aloha/Aloha2 (#228)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>

parent 504d2aaf48
commit d585c73f9f

diff --git a/.gitignore b/.gitignore
@@ -121,7 +121,6 @@ celerybeat.pid
 # Environments
 .env
 .venv
-env/
 venv/
 ENV/
 env.bak/
diff --git a/lerobot/__init__.py b/lerobot/__init__.py
@@ -45,6 +45,9 @@ import itertools

 from lerobot.__version__ import __version__  # noqa: F401

+# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
+# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
+# a yaml file AND an environment name. The difference should be more obvious.
 available_tasks_per_env = {
     "aloha": [
         "AlohaInsertion-v0",
@@ -52,6 +55,7 @@ available_tasks_per_env = {
     ],
     "pusht": ["PushT-v0"],
     "xarm": ["XarmLift-v0"],
+    "dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
 }
 available_envs = list(available_tasks_per_env.keys())

@@ -77,6 +81,23 @@ available_datasets_per_env = {
         "lerobot/xarm_push_medium_image",
         "lerobot/xarm_push_medium_replay_image",
     ],
+    "dora_aloha_real": [
+        "lerobot/aloha_static_battery",
+        "lerobot/aloha_static_candy",
+        "lerobot/aloha_static_coffee",
+        "lerobot/aloha_static_coffee_new",
+        "lerobot/aloha_static_cups_open",
+        "lerobot/aloha_static_fork_pick_up",
+        "lerobot/aloha_static_pingpong_test",
+        "lerobot/aloha_static_pro_pencil",
+        "lerobot/aloha_static_screw_driver",
+        "lerobot/aloha_static_tape",
+        "lerobot/aloha_static_thread_velcro",
+        "lerobot/aloha_static_towel",
+        "lerobot/aloha_static_vinh_cup",
+        "lerobot/aloha_static_vinh_cup_left",
+        "lerobot/aloha_static_ziploc_slide",
+    ],
 }

 available_real_world_datasets = [
@@ -108,16 +129,19 @@ available_datasets = list(
     itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
 )

+# lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
 available_policies = [
     "act",
     "diffusion",
     "tdmpc",
 ]

+# keys and values refer to yaml files
 available_policies_per_env = {
     "aloha": ["act"],
     "pusht": ["diffusion"],
     "xarm": ["tdmpc"],
+    "dora_aloha_real": ["act_real"],
 }

 env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
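
For reference, the registry entries above are importable straight from the package root. A minimal sketch of how the new real-world entries can be inspected after this change:

```python
import lerobot

# New real-world env/task entries added by this commit:
print(lerobot.available_tasks_per_env["dora_aloha_real"])
# ['DoraAloha-v0', 'DoraKoch-v0', 'DoraReachy2-v0']

# The new env maps to the new `act_real` policy yaml:
print(lerobot.available_policies_per_env["dora_aloha_real"])
# ['act_real']
```
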
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
@@ -55,9 +55,17 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
             "strings to load multiple datasets."
         )

-    if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
-        logging.warning(
-            f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
-            f"environment ({cfg.env.name=})."
-        )
+    # A soft check that warns on a mismatch between the environment and the dataset. Don't check if we are using a real-world env (dora).
+    if cfg.env.name != "dora":
+        if isinstance(cfg.dataset_repo_id, str):
+            dataset_repo_ids = [cfg.dataset_repo_id]  # single dataset
+        else:
+            dataset_repo_ids = cfg.dataset_repo_id  # multiple datasets
+
+        for dataset_repo_id in dataset_repo_ids:
+            if cfg.env.name not in dataset_repo_id:
+                logging.warning(
+                    f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
+                    f"environment ({cfg.env.name=})."
+                )

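
The new check generalizes the old single-dataset warning to lists of repo ids, and skips real-world (dora) runs entirely since dataset names don't contain the env name there. A standalone restatement of the same logic, with a hypothetical helper name, for illustration:

```python
import logging

def warn_on_env_dataset_mismatch(env_name: str, dataset_repo_id: str | list[str]) -> None:
    """Hypothetical helper mirroring the soft check added above."""
    if env_name == "dora":
        return  # real-world env: dataset names won't contain the env name
    repo_ids = [dataset_repo_id] if isinstance(dataset_repo_id, str) else dataset_repo_id
    for repo_id in repo_ids:
        if env_name not in repo_id:
            logging.warning(
                f"There might be a mismatch between your training dataset ({repo_id=}) "
                f"and your environment ({env_name=})."
            )

# Warns only for the second repo id:
warn_on_env_dataset_mismatch("aloha", ["lerobot/aloha_static_towel", "lerobot/pusht"])
```
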
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
@@ -25,6 +25,13 @@ class ACTConfig:
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
     Those are: `input_shapes` and `output_shapes`.

+    Notes on the inputs and outputs:
+        - At least one key starting with "observation.image" is required as an input.
+        - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
+          views. Right now we only support all images having the same shape.
+        - May optionally work without an "observation.state" key for the proprioceptive robot state.
+        - "action" is required as an output key.
+
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
@@ -33,15 +40,15 @@ class ACTConfig:
             This should be no greater than the chunk size. For example, if the chunk size is 100, you may
             set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
             environment, and throws the other 50 out.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.images.top" refers to an input from the
-            "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesn't include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std"
             which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
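
To make the reworked docstring concrete, here is a sketch of what `input_shapes` / `output_shapes` look like for the 4-camera real-world setup this PR introduces (values taken from `act_real.yaml` below):

```python
# Shapes exclude batch and temporal dimensions, as the docstring notes.
input_shapes = {
    "observation.images.cam_right_wrist": [3, 480, 640],
    "observation.images.cam_left_wrist": [3, 480, 640],
    "observation.images.cam_high": [3, 480, 640],
    "observation.images.cam_low": [3, 480, 640],
    "observation.state": [14],  # now optional for ACT; omit it for camera-only training
}
output_shapes = {
    "action": [14],  # 14-dimensional joint-space targets
}
```
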
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
@@ -200,25 +200,29 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+        self.use_input_state = "observation.state" in config.input_shapes
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            self.vae_encoder_robot_state_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
-            )
+            if self.use_input_state:
+                self.vae_encoder_robot_state_input_proj = nn.Linear(
+                    config.input_shapes["observation.state"][0], config.dim_model
+                )
             # Projection layer for action (joint-space target) to hidden dimension.
             self.vae_encoder_action_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
+                config.output_shapes["action"][0], config.dim_model
             )
-            self.latent_dim = config.latent_dim
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
-            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
+            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
+            num_input_token_encoder = 1 + config.chunk_size
+            if self.use_input_state:
+                num_input_token_encoder += 1
             self.register_buffer(
                 "vae_encoder_pos_enc",
-                create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
+                create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
             )

         # Backbone for image feature extraction.
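
The VAE encoder's positional-embedding buffer is now sized from a token count that depends on whether a robot-state token exists, instead of the hard-coded `1 + 1 + chunk_size`. A small sketch of the arithmetic:

```python
# VAE encoder input tokens:
#   with state:    [cls, robot_state, *action_chunk] -> 2 + chunk_size
#   without state: [cls, *action_chunk]              -> 1 + chunk_size
chunk_size = 100
use_input_state = False  # e.g. "observation.state" missing from input_shapes
num_input_token_encoder = 1 + chunk_size + (1 if use_input_state else 0)
print(num_input_token_encoder)  # 101 (was always 102 before this change)
```
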
@@ -238,15 +242,17 @@ class ACT(nn.Module):

         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        self.encoder_robot_state_input_proj = nn.Linear(
-            config.input_shapes["observation.state"][0], config.dim_model
-        )
-        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
+        if self.use_input_state:
+            self.encoder_robot_state_input_proj = nn.Linear(
+                config.input_shapes["observation.state"][0], config.dim_model
+            )
+        self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
+        num_input_token_decoder = 2 if self.use_input_state else 1
+        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

         # Transformer decoder.
@@ -285,7 +291,7 @@ class ACT(nn.Module):
             "action" in batch
         ), "actions must be provided when using the variational objective in training mode."

-        batch_size = batch["observation.state"].shape[0]
+        batch_size = batch["observation.images"].shape[0]

         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -293,11 +299,16 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
-                1
-            )  # (B, 1, D)
+            if self.use_input_state:
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-            vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
+
+            if self.use_input_state:
+                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+            else:
+                vae_encoder_input = [cls_embed, action_embed]
+            vae_encoder_input = torch.cat(vae_encoder_input, axis=1)

             # Prepare fixed positional embedding.
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
@@ -308,16 +319,17 @@ class ACT(nn.Module):
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
             )[0]  # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
-            mu = latent_pdf_params[:, : self.latent_dim]
+            mu = latent_pdf_params[:, : self.config.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
-            log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
+            log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]

             # Sample the latent with the reparameterization trick.
             latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
-            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
+            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )

@@ -326,8 +338,10 @@ class ACT(nn.Module):
         all_cam_features = []
         all_cam_pos_embeds = []
         images = batch["observation.images"]
+
         for cam_index in range(images.shape[-4]):
             cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)
@@ -337,13 +351,15 @@ class ACT(nn.Module):
         cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)

         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        if self.use_input_state:
+            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)

         # Stack encoder input and positional embeddings moving to (S, B, C).
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
         encoder_in = torch.cat(
             [
-                torch.stack([latent_embed, robot_state_embed], axis=0),
+                torch.stack(encoder_in_feats, axis=0),
                 einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )
@@ -357,6 +373,7 @@ class ACT(nn.Module):

         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
             dtype=pos_embed.dtype,
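
Taken together, these changes make the proprioceptive state optional as a policy input, and the batch size is read from the images instead. A minimal sketch with hypothetical shapes:

```python
import torch

# 2 samples, 4 cameras, 480x640 RGB frames, 100-step action chunks of dim 14.
batch = {
    "observation.images": torch.randn(2, 4, 3, 480, 640),  # (B, num_cams, C, H, W)
    "action": torch.randn(2, 100, 14),  # needed only for the VAE objective in training
}
# The batch size is now derived from the images, which ACT always requires,
# rather than from the state, which is now optional as a policy input:
batch_size = batch["observation.images"].shape[0]
```
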
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -26,21 +26,26 @@ class DiffusionConfig:
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
     Those are: `input_shapes` and `output_shapes`.

+    Notes on the inputs and outputs:
+        - "observation.state" is required as an input key.
+        - A key starting with "observation.image" is required as an input.
+        - "action" is required as an output key.
+
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
         horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
         n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
             See `DiffusionPolicy.select_action` for more details.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.image" refers to an input from
-            a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesnt include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std"
             which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
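
For what the two normalization modes in these docstrings amount to, a rough element-wise sketch, assuming "min_max" rescales to [-1, 1] as in lerobot's normalization layer:

```python
import torch

def mean_std_normalize(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # "mean_std": subtract the mean and divide by the standard deviation
    return (x - mean) / (std + 1e-8)

def min_max_normalize(x: torch.Tensor, min_: torch.Tensor, max_: torch.Tensor) -> torch.Tensor:
    # "min_max": rescale to [-1, 1] (assumed range)
    return (x - min_) / (max_ - min_ + 1e-8) * 2 - 1
```
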
diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
@@ -31,6 +31,15 @@ class TDMPCConfig:
         n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
             action repeats in Q-learning or ask your favorite chatbot)
         horizon: Horizon for model predictive control.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std"
             which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
diff --git a/lerobot/configs/env/dora_aloha_real.yaml b/lerobot/configs/env/dora_aloha_real.yaml
@@ -0,0 +1,13 @@
+# @package _global_
+
+fps: 30
+
+env:
+  name: dora
+  task: DoraAloha-v0
+  state_dim: 14
+  action_dim: 14
+  fps: ${fps}
+  episode_length: 400
+  gym:
+    fps: ${fps}
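
For clarity, the `${fps}` interpolations in this new env config resolve at load time; sketched as plain Python:

```python
# Resolved view of dora_aloha_real.yaml (sketch):
fps = 30
env = {
    "name": "dora",
    "task": "DoraAloha-v0",
    "state_dim": 14,   # 14 joint values for the bimanual Aloha (incl. grippers)
    "action_dim": 14,
    "fps": fps,              # ${fps}
    "episode_length": 400,
    "gym": {"fps": fps},     # ${fps}
}
```
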
diff --git a/lerobot/configs/policy/act_real.yaml b/lerobot/configs/policy/act_real.yaml
@@ -0,0 +1,115 @@
+# @package _global_
+
+# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
+# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high,
+# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This parameter sets
+# how often (in training steps) checkpoints are evaluated; setting it to -1 deactivates evaluation.
+# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
+# Look at its README for more information on how to evaluate a checkpoint in the real world.
+#
+# Example of usage for training:
+# ```bash
+# python lerobot/scripts/train.py \
+#   policy=act_real \
+#   env=dora_aloha_real
+# ```
+
+seed: 1000
+dataset_repo_id: lerobot/aloha_static_vinh_cup
+
+override_dataset_stats:
+  observation.images.cam_right_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_left_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_high:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_low:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+
+training:
+  offline_steps: 80000
+  online_steps: 0
+  eval_freq: -1
+  save_freq: 10000
+  log_freq: 100
+  save_checkpoint: true
+
+  batch_size: 8
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  online_steps_between_rollouts: 1
+
+  delta_timestamps:
+    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+# See `configuration_act.py` for more details.
+policy:
+  name: act
+
+  # Input / output structure.
+  n_obs_steps: 1
+  chunk_size: 100 # chunk_size
+  n_action_steps: 100
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.images.cam_right_wrist: [3, 480, 640]
+    observation.images.cam_left_wrist: [3, 480, 640]
+    observation.images.cam_high: [3, 480, 640]
+    observation.images.cam_low: [3, 480, 640]
+    observation.state: ["${env.state_dim}"]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.images.cam_right_wrist: mean_std
+    observation.images.cam_left_wrist: mean_std
+    observation.images.cam_high: mean_std
+    observation.images.cam_low: mean_std
+    observation.state: mean_std
+  output_normalization_modes:
+    action: mean_std
+
+  # Architecture.
+  # Vision backbone.
+  vision_backbone: resnet18
+  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
+  replace_final_stride_with_dilation: false
+  # Transformer layers.
+  pre_norm: false
+  dim_model: 512
+  n_heads: 8
+  dim_feedforward: 3200
+  feedforward_activation: relu
+  n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
+  n_decoder_layers: 1
+  # VAE.
+  use_vae: true
+  latent_dim: 32
+  n_vae_encoder_layers: 4
+
+  # Inference.
+  temporal_ensemble_momentum: null
+
+  # Training and loss computation.
+  dropout: 0.1
+  kl_weight: 10.0
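
The `delta_timestamps` expression in this config is evaluated with the config's `fps` and `chunk_size`; concretely:

```python
fps = 30
chunk_size = 100
# action: "[i / ${fps} for i in range(${policy.chunk_size})]" resolves to:
delta_timestamps = [i / fps for i in range(chunk_size)]
print(delta_timestamps[:3])   # [0.0, 0.03333333333333333, 0.06666666666666667]
print(len(delta_timestamps))  # 100 future action timestamps, one chunk at 30 fps
```
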
diff --git a/lerobot/configs/policy/act_real_no_state.yaml b/lerobot/configs/policy/act_real_no_state.yaml
@@ -0,0 +1,111 @@
+# @package _global_
+
+# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras).
+# Compared to `act_real.yaml`, it is camera-only and does not use the robot state (a vector of joint positions) as input.
+# We validated experimentally that not using the state reaches a better success rate. Our hypothesis is that `act_real.yaml`
+# might overfit to the state, because the images are more complex to learn from since they are moving.
+#
+# Example of usage for training:
+# ```bash
+# python lerobot/scripts/train.py \
+#   policy=act_real_no_state \
+#   env=dora_aloha_real
+# ```
+
+seed: 1000
+dataset_repo_id: lerobot/aloha_static_vinh_cup
+
+override_dataset_stats:
+  observation.images.cam_right_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_left_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_high:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_low:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+
+training:
+  offline_steps: 80000
+  online_steps: 0
+  eval_freq: -1
+  save_freq: 10000
+  log_freq: 100
+  save_checkpoint: true
+
+  batch_size: 8
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  online_steps_between_rollouts: 1
+
+  delta_timestamps:
+    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+# See `configuration_act.py` for more details.
+policy:
+  name: act
+
+  # Input / output structure.
+  n_obs_steps: 1
+  chunk_size: 100 # chunk_size
+  n_action_steps: 100
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.images.cam_right_wrist: [3, 480, 640]
+    observation.images.cam_left_wrist: [3, 480, 640]
+    observation.images.cam_high: [3, 480, 640]
+    observation.images.cam_low: [3, 480, 640]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.images.cam_right_wrist: mean_std
+    observation.images.cam_left_wrist: mean_std
+    observation.images.cam_high: mean_std
+    observation.images.cam_low: mean_std
+  output_normalization_modes:
+    action: mean_std
+
+  # Architecture.
+  # Vision backbone.
+  vision_backbone: resnet18
+  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
+  replace_final_stride_with_dilation: false
+  # Transformer layers.
+  pre_norm: false
+  dim_model: 512
+  n_heads: 8
+  dim_feedforward: 3200
+  feedforward_activation: relu
+  n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
+  n_decoder_layers: 1
+  # VAE.
+  use_vae: true
+  latent_dim: 32
+  n_vae_encoder_layers: 4
+
+  # Inference.
+  temporal_ensemble_momentum: null
+
+  # Training and loss computation.
+  dropout: 0.1
+  kl_weight: 10.0
diff --git a/poetry.lock b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.

 [[package]]
 name = "absl-py"
@@ -444,63 +444,63 @@ files = [

 [[package]]
 name = "coverage"
-version = "7.5.1"
+version = "7.5.3"
 description = "Code coverage measurement for Python"
 optional = true
 python-versions = ">=3.8"
 files = [
-    {file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
+    {file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"},
-    {file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
+    {file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"},
-    {file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
+    {file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"},
-    {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
+    {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"},
-    {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
+    {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"},
-    {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
+    {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"},
-    {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
+    {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"},
-    {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
+    {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"},
-    {file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
+    {file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"},
-    {file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
+    {file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"},
-    {file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
+    {file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"},
-    {file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
+    {file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"},
-    {file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
+    {file = "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"},
-    {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
+    {file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"},
-    {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
+    {file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"},
-    {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
+    {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"},
-    {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
+    {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"},
-    {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
+    {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"},
-    {file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
+    {file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"},
-    {file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
+    {file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"},
-    {file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
+    {file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"},
-    {file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
+    {file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"},
-    {file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
+    {file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"},
-    {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
+    {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"},
-    {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
+    {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"},
-    {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
+    {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"},
-    {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
+    {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"},
-    {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
+    {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"},
-    {file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
+    {file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"},
-    {file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
+    {file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"},
-    {file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
+    {file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"},
-    {file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
+    {file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"},
-    {file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
+    {file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"},
-    {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
+    {file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"},
-    {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
+    {file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"},
-    {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
+    {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"},
-    {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
+    {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"},
-    {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
+    {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"},
-    {file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
+    {file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"},
-    {file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
+    {file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"},
-    {file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
+    {file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"},
-    {file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
+    {file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"},
-    {file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
+    {file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"},
-    {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
+    {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"},
-    {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
+    {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"},
-    {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
+    {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"},
-    {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
+    {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"},
-    {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
+    {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"},
-    {file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
+    {file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"},
-    {file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
+    {file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"},
-    {file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
+    {file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"},
-    {file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
+    {file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"},
 ]

 [package.dependencies]
@@ -785,6 +785,26 @@ files = [
 [package.dependencies]
 six = ">=1.4.0"

+[[package]]
+name = "dora-rs"
+version = "0.3.4"
+description = "`dora` goal is to be a low latency, composable, and distributed data flow."
+optional = true
+python-versions = "*"
+files = [
+    {file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"},
+    {file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"},
+    {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"},
+    {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"},
+    {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"},
+    {file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"},
+    {file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"},
+    {file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"},
+]
+
+[package.dependencies]
+pyarrow = "*"
+
 [[package]]
 name = "einops"
 version = "0.8.0"
@@ -1066,6 +1086,27 @@ mujoco = ">=2.3.7,<3.0.0"
 dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
 test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]

+[[package]]
+name = "gym-dora"
+version = "0.1.0"
+description = ""
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+dora-rs = ">=0.3.4"
+gymnasium = ">=0.29.1"
+pyarrow = ">=12.0.0"
+
+[package.source]
+type = "git"
+url = "https://github.com/dora-rs/dora-lerobot.git"
+reference = "HEAD"
+resolved_reference = "ed0c00a4fdc6ec856c9842551acd7dc7ee776f79"
+subdirectory = "gym_dora"
+
 [[package]]
 name = "gym-pusht"
 version = "0.1.4"
@@ -1269,13 +1310,13 @@ files = [

 [[package]]
 name = "huggingface-hub"
-version = "0.23.1"
+version = "0.23.2"
 description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
 optional = false
 python-versions = ">=3.8.0"
 files = [
-    {file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"},
-    {file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"},
+    {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"},
+    {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"},
 ]

 [package.dependencies]
@@ -2061,18 +2102,15 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]

 [[package]]
 name = "nodeenv"
-version = "1.8.0"
+version = "1.9.0"
 description = "Node.js virtual environment builder"
 optional = true
-python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
 files = [
-    {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
-    {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
+    {file = "nodeenv-1.9.0-py2.py3-none-any.whl", hash = "sha256:508ecec98f9f3330b636d4448c0f1a56fc68017c68f1e7857ebc52acf0eb879a"},
+    {file = "nodeenv-1.9.0.tar.gz", hash = "sha256:07f144e90dae547bf0d4ee8da0ee42664a42a04e02ed68e06324348dafe4bdb1"},
 ]

-[package.dependencies]
-setuptools = "*"
-
 [[package]]
 name = "numba"
 version = "0.59.1"
@@ -2406,6 +2444,7 @@ optional = false
 python-versions = ">=3.9"
 files = [
     {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
+    {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
     {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
     {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
     {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
@@ -2426,6 +2465,7 @@ files = [
     {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
     {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
     {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
+    {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
     {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
     {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
     {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
@@ -3188,13 +3228,13 @@ files = [

 [[package]]
 name = "requests"
-version = "2.32.2"
+version = "2.32.3"
 description = "Python HTTP for Humans."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
-    {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
+    {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
+    {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
 ]

 [package.dependencies]
@@ -3210,16 +3250,16 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]

 [[package]]
 name = "rerun-sdk"
-version = "0.16.0"
+version = "0.16.1"
 description = "The Rerun Logging SDK"
 optional = false
 python-versions = "<3.13,>=3.8"
 files = [
-    {file = "rerun_sdk-0.16.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1cc6dc66d089e296f945dc238301889efb61dd6d338b5d00f76981cf7aed0a74"},
-    {file = "rerun_sdk-0.16.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:faf231897655e46eb975695df2b0ace07db362d697e697f9a3dff52f81c0dc5d"},
-    {file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:860a6394380d3e9b9e48bf34423bd56dda54d5b0158d2ae0e433698659b34198"},
-    {file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:5b8d1476f73a3ad1a5d3f21b61c633f3ab62aa80fa0b049f5ad10bf1227681ab"},
-    {file = "rerun_sdk-0.16.0-cp38-abi3-win_amd64.whl", hash = "sha256:aff0051a263b8c3067243c0126d319845baf4fe640899f04aeef7daf151f35e4"},
+    {file = "rerun_sdk-0.16.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:170c6976634008611753e10dfef8cdc395ce8180e634c169e7c61cef2f89a277"},
+    {file = "rerun_sdk-0.16.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c9a76eab7eb5559276737dad655200e9350df0837158dbc5a896970ab4201454"},
+    {file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:4d6436752d57e8b8038489a0e7e37f0c760b088e96db5fb81667d3a376d63fea"},
+    {file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:37b7b47948471873e84f224b16f417a94a91c7cbd6c72c68281eeff1ba414b8f"},
+    {file = "rerun_sdk-0.16.1-cp38-abi3-win_amd64.whl", hash = "sha256:be88799c8afdf68eafa99e64e2e4f0a484e187e017a180219abbe6bb988acd4e"},
 ]

 [package.dependencies]
@@ -3696,17 +3736,17 @@ files = [

 [[package]]
 name = "sympy"
-version = "1.12"
+version = "1.12.1"
 description = "Computer algebra system (CAS) in Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
-    {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
+    {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"},
+    {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"},
 ]

 [package.dependencies]
-mpmath = ">=0.19"
+mpmath = ">=1.1.0,<1.4.0"

 [[package]]
 name = "tbb"
@@ -4220,13 +4260,13 @@ multidict = ">=4.0"

 [[package]]
 name = "zarr"
-version = "2.18.1"
+version = "2.18.2"
 description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "zarr-2.18.1-py3-none-any.whl", hash = "sha256:a1770d194eec4ec0a41a01295a6f724e1c3471d704d3aca906d3b3a7f8830245"},
-    {file = "zarr-2.18.1.tar.gz", hash = "sha256:28c360ed123e606c425a694a83300227a907cb86a995fc9eef620ecafbe5f92d"},
+    {file = "zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38"},
+    {file = "zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47"},
 ]

 [package.dependencies]
@@ -4241,13 +4281,13 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]

 [[package]]
 name = "zipp"
-version = "3.18.2"
+version = "3.19.0"
 description = "Backport of pathlib-compatible object wrapper for zip files"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"},
-    {file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"},
+    {file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"},
+    {file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"},
 ]

 [package.extras]
@@ -4257,6 +4297,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more
 [extras]
 aloha = ["gym-aloha"]
 dev = ["debugpy", "pre-commit"]
+dora = ["gym-dora"]
 pusht = ["gym-pusht"]
 test = ["pytest", "pytest-cov"]
 umi = ["imagecodecs"]
@@ -4265,4 +4306,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "1ad6ef0f88f0056ab639e60e033e586f7460a9c5fc3676a477bbd47923f41cb6"
+content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771"

@@ -46,6 +46,7 @@ h5py = ">=3.10.0"
 huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
 gymnasium = ">=0.29.1"
 cmake = ">=3.29.0.1"
+gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
 gym-pusht = { version = ">=0.1.3", optional = true}
 gym-xarm = { version = ">=0.1.1", optional = true}
 gym-aloha = { version = ">=0.1.1", optional = true}
@@ -62,6 +63,7 @@ deepdiff = ">=7.0.1"


 [tool.poetry.extras]
+dora = ["gym-dora"]
 pusht = ["gym-pusht"]
 xarm = ["gym-xarm"]
 aloha = ["gym-aloha"]

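Note on usage: `gym-dora` is declared as an optional git dependency and exposed through the new `dora` extra, so it is only installed when explicitly requested (for example with `poetry install --extras dora`). A minimal defensive-import sketch in Python follows; the helper name and error message are illustrative, not part of this commit:

import importlib.util


def ensure_gym_dora() -> None:
    # gym_dora comes from the optional `dora` extra (a git dependency pinned
    # in poetry.lock); fail early with an actionable hint if it is missing.
    if importlib.util.find_spec("gym_dora") is None:
        raise ImportError(
            "gym_dora is not installed; install it with the `dora` extra, "
            "e.g. `poetry install --extras dora`."
        )
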
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
+size 5104

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
+size 31688

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
+size 68

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
+size 34928

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
+size 5104

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
+size 30808

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97455b4360748c99905cd103473c1a52da6901d0a73ffbc51b5ea3eb250d1386
+size 68

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
+size 33608
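The eight new test artifacts above are tracked with Git LFS, so the repository stores only three-line pointer files (spec version, sha256 oid, byte size) while the binary payloads live on the LFS server. A small Python sketch of parsing that pointer format, for illustration only (the helper is not part of this commit):

def parse_lfs_pointer(text: str) -> dict[str, str]:
    # Each pointer line is "<key> <value>", e.g. "size 5104".
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


pointer = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4\n"
    "size 5104\n"
)
assert pointer["oid"].startswith("sha256:")
assert pointer["size"] == "5104"
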
@@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
     dataset.delta_timestamps = None
     batch = next(iter(dataloader))
-    obs = {
-        k: batch[k]
-        for k in batch
-        if k in ["observation.image", "observation.images.top", "observation.state"]
-    }
+    obs = {}
+    for k in batch:
+        if k.startswith("observation"):
+            obs[k] = batch[k]

-    actions_queue = (
-        cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
-    )
+    if "n_action_steps" in cfg.policy:
+        actions_queue = cfg.policy.n_action_steps
+    else:
+        actions_queue = cfg.policy.n_action_repeats
+
     actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
     return output_dict, grad_stats, param_stats, actions
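The hunk above replaces a hard-coded observation whitelist with a prefix match, so every `observation.*` key (state plus any number of camera streams, as on the real Aloha setups) is forwarded to the policy; the `actions_queue` rewrite from a ternary to an if/else is behavior-preserving. A toy illustration of the filtering, with made-up batch keys and shapes:

import torch

batch = {
    "observation.state": torch.zeros(2, 14),
    "observation.images.cam_high": torch.zeros(2, 3, 480, 640),  # extra camera
    "action": torch.zeros(2, 14),
    "timestamp": torch.zeros(2),
}

# New behavior: keep every observation key, drop everything else.
obs = {}
for k in batch:
    if k.startswith("observation"):
        obs[k] = batch[k]

assert sorted(obs) == ["observation.images.cam_high", "observation.state"]
# The old whitelist would have silently dropped "observation.images.cam_high".
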
@@ -114,6 +115,8 @@ if __name__ == "__main__":
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
         ),
         ("aloha", "act", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
     ]
     for env, policy, extra_overrides in env_policies:
         save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
@@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
-from tests.scripts.save_policy_to_safetensor import get_policy_stats
+from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str):
         ),
         # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
         ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
+        ("dora_aloha_real", "act_real", []),
+        ("dora_aloha_real", "act_real_no_state", []),
     ],
 )
 @require_env
@@ -84,6 +86,9 @@ def test_policy(env_name, policy_name, extra_overrides):
     - Updating the policy.
     - Using the policy to select actions at inference time.
     - Test the action can be applied to the policy
+
+    Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
+    and for now we add tests as we see fit.
     """
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
@@ -135,7 +140,7 @@ def test_policy(env_name, policy_name, extra_overrides):

     dataloader = torch.utils.data.DataLoader(
         dataset,
-        num_workers=4,
+        num_workers=0,
         batch_size=2,
         shuffle=True,
         pin_memory=DEVICE != "cpu",
@@ -291,6 +296,8 @@ def test_normalize(insert_temporal_dim):
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
         ),
         ("aloha", "act", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
     ],
 )
 # As artifacts have been generated on an x86_64 kernel, this test won't