[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-03-24 13:16:38 +00:00
parent 313812df16
commit ce0008850d
97 changed files with 1596 additions and 492 deletions
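pre-commit.ci applies whatever hooks the repository declares in its .pre-commit-config.yaml. The wrapping pattern in the hunks below (long lines split at ~88 columns, long literal lists exploded one element per line behind a trailing comma) is characteristic of a Black-compatible formatter such as ruff-format. A hypothetical config that would produce fixes like these — not the repository's actual file, which is not part of this diff — might look like:

# Hypothetical .pre-commit-config.yaml (illustrative only; rev is assumed)
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.9.10
    hooks:
      - id: ruff          # lint with autofix
        args: [--fix]
      - id: ruff-format   # Black-style formatting that wraps lines as shown below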

View File

@@ -22,7 +22,10 @@ from pathlib import Path
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
@@ -48,12 +51,18 @@ def main():
# - dataset stats: for normalization and denormalization of input/outputs
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
input_features = {
key: ft for key, ft in features.items() if key not in output_features
}
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
cfg = DiffusionConfig(
input_features=input_features, output_features=output_features
)
# We can now instantiate our policy with this config and the dataset stats.
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
@@ -63,8 +72,12 @@ def main():
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number of frames
# which can differ for inputs, outputs and rewards (if there are any).
delta_timestamps = {
"observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
"observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
"observation.image": [
i / dataset_metadata.fps for i in cfg.observation_delta_indices
],
"observation.state": [
i / dataset_metadata.fps for i in cfg.observation_delta_indices
],
"action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
}
@@ -77,7 +90,24 @@ def main():
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to supervise the policy.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
}
# We can then instantiate the dataset with this delta_timestamps configuration.
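To make the conversion concrete: each frame-index offset is divided by the dataset's fps to obtain a timestamp in seconds. A minimal sketch, assuming PushT's 10 fps (consistent with the 0.1 s action spacing in the comments above) and hypothetical index lists:

fps = 10
observation_delta_indices = [-1, 0]         # hypothetical: previous and current frame
action_delta_indices = list(range(-1, 15))  # 16 offsets, matching the action list above
delta_timestamps = {
    "observation.image": [i / fps for i in observation_delta_indices],  # [-0.1, 0.0]
    "action": [i / fps for i in action_delta_indices],                  # [-0.1, 0.0, ..., 1.4]
}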
@@ -99,7 +129,10 @@ def main():
done = False
while not done:
for batch in dataloader:
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
batch = {
k: (v.to(device) if isinstance(v, torch.Tensor) else v)
for k, v in batch.items()
}
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()

View File

@@ -54,7 +54,24 @@ def main():
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to calculate the loss.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
}
# Load the last 10% of episodes of the dataset as a validation set.
@@ -73,7 +90,9 @@ def main():
train_dataset = LeRobotDataset(
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
val_dataset = LeRobotDataset(
"lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps
)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
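The episode split itself happens outside this hunk; a minimal sketch of a 90/10 split consistent with the comments, assuming episode indices are contiguous:

dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
total_episodes = dataset_metadata.total_episodes
num_train = int(0.9 * total_episodes)
train_episodes = list(range(num_train))                # first 90% of episodes
val_episodes = list(range(num_train, total_episodes))  # last 10% of episodes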

View File

@@ -19,7 +19,10 @@ from lerobot.common.datasets.utils import load_image_as_numpy
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
dataset_len: int,
min_num_samples: int = 100,
max_num_samples: int = 10_000,
power: float = 0.75,
) -> int:
"""Heuristic to estimate the number of samples based on dataset size.
The power controls the sample growth relative to dataset size.
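The function body is cut off by the diff context, but the signature suggests dataset_len ** power clamped to [min_num_samples, max_num_samples]. A hypothetical illustration of that shape:

def estimate(n: int, lo: int = 100, hi: int = 10_000, power: float = 0.75) -> int:
    # assumed clamp structure; not necessarily the actual implementation
    return max(lo, min(hi, int(n**power)))

print(estimate(100))        # 100    (clamped to the floor)
print(estimate(10_000))     # 1000   (10_000 ** 0.75)
print(estimate(1_000_000))  # 10000  (clamped to the ceiling)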
@@ -43,14 +46,18 @@ def sample_indices(data_len: int) -> list[int]:
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
def auto_downsample_height_width(
img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
):
_, height, width = img.shape
if max(width, height) < max_size_threshold:
# no downsampling needed
return img
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
downsample_factor = (
int(width / target_size) if width > height else int(height / target_size)
)
return img[:, ::downsample_factor, ::downsample_factor]
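Worked example of the integer-stride downsampling above: a (3, 480, 640) image exceeds the 300-pixel threshold, so the factor becomes int(640 / 150) = 4 and every 4th row and column is kept:

import numpy as np

img = np.zeros((3, 480, 640), dtype=np.uint8)  # (C, H, W), widest side 640
factor = int(640 / 150)                        # -> 4
print(img[:, ::factor, ::factor].shape)        # (3, 120, 160)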
@@ -72,7 +79,9 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
def get_feature_stats(
array: np.ndarray, axis: tuple, keepdims: bool
) -> dict[str, np.ndarray]:
return {
"min": np.min(array, axis=axis, keepdims=keepdims),
"max": np.max(array, axis=axis, keepdims=keepdims),
@@ -82,7 +91,9 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st
}
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
def compute_episode_stats(
episode_data: dict[str, list[str] | np.ndarray], features: dict
) -> dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
@@ -96,12 +107,15 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
ep_stats[key] = get_feature_stats(
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims
)
# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v / 255.0, axis=0)
for k, v in ep_stats[key].items()
}
return ep_stats
@@ -116,14 +130,22 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
f"Stats must be composed of numpy arrays, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
)
if v.ndim == 0:
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
raise ValueError(
"Number of dimensions must be at least 1, and is 0 instead."
)
if k == "count" and v.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
raise ValueError(
f"Shape of 'count' must be (1), but is {v.shape} instead."
)
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
raise ValueError(
f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead."
)
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
def aggregate_feature_stats(
stats_ft_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
@@ -152,7 +174,9 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
}
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
def aggregate_stats(
stats_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
The final stats will have the union of all data keys from each of the stats dicts.
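The stacking of means and squared stds above feeds a count-weighted combination. A simplified sketch for 1-D features (assuming each per-episode dict carries "mean", "std" and "count", as get_feature_stats suggests), using the law of total variance — within-episode variance plus between-episode spread:

import numpy as np

def combine(stats_ft_list: list[dict]) -> dict:
    counts = np.array([s["count"] for s in stats_ft_list], dtype=float)[:, None]  # (N, 1)
    means = np.stack([s["mean"] for s in stats_ft_list])                          # (N, D)
    variances = np.stack([s["std"] ** 2 for s in stats_ft_list])                  # (N, D)
    total = counts.sum()
    mean = (counts * means).sum(axis=0) / total
    var = (counts * (variances + (means - mean) ** 2)).sum(axis=0) / total
    return {"mean": mean, "std": np.sqrt(var), "count": total}

a = {"mean": np.array([1.0]), "std": np.array([1.0]), "count": 10}
b = {"mean": np.array([3.0]), "std": np.array([1.0]), "count": 30}
print(combine([a, b]))  # mean 2.5, std = sqrt(1.75) ~= 1.32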

View File

@@ -58,7 +58,9 @@ def resolve_delta_timestamps(
if key == "action" and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
delta_timestamps[key] = [
i / ds_meta.fps for i in cfg.observation_delta_indices
]
if len(delta_timestamps) == 0:
delta_timestamps = None
@@ -79,7 +81,9 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
LeRobotDataset | MultiLeRobotDataset
"""
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
ImageTransforms(cfg.dataset.image_transforms)
if cfg.dataset.image_transforms.enable
else None
)
if isinstance(cfg.dataset.repo_id, str):
@@ -113,6 +117,8 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
dataset.meta.stats[key][stats_type] = torch.tensor(
stats, dtype=torch.float32
)
return dataset

View File

@@ -38,10 +38,14 @@ def safe_stop_image_writer(func):
return wrapper
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
def image_array_to_pil_image(
image_array: np.ndarray, range_check: bool = True
) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
raise ValueError(
f"The array has {image_array.ndim} dimensions, but 3 is expected for an image."
)
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)

View File

@@ -108,7 +108,9 @@ class LeRobotDatasetMetadata:
self.episodes = load_episodes(self.root)
if self._version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
self.episodes_stats = backward_compatible_episodes_stats(
self.stats, self.episodes
)
else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
@@ -238,7 +240,9 @@ class LeRobotDatasetMetadata:
Given a task in natural language, add it to the dictionary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
raise ValueError(
f"The task '{task}' already exists and can't be added twice."
)
task_index = self.info["total_tasks"]
self.task_to_task_index[task] = task_index
@@ -281,7 +285,11 @@ class LeRobotDatasetMetadata:
write_episode(episode_dict, self.root)
self.episodes_stats[episode_index] = episode_stats
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
self.stats = (
aggregate_stats([self.stats, episode_stats])
if self.stats
else episode_stats
)
write_episode_stats(episode_index, episode_stats, self.root)
def update_video_info(self) -> None:
@@ -345,13 +353,17 @@ class LeRobotDatasetMetadata:
# as this would break the dict flattening in the stats computation, which uses '/' as separator
for key in features:
if "/" in key:
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
raise ValueError(
f"Feature names should not contain '/'. Found '/' in feature '{key}'."
)
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.task_to_task_index = {}, {}
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, features, use_videos
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
@@ -482,7 +494,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.video_backend = (
video_backend if video_backend else get_safe_default_codec()
)
self.delta_indices = None
# Unused attributes
@@ -495,28 +509,39 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
if self.episodes is not None and self.meta._version >= packaging.version.parse(
"v2.1"
):
episodes_stats = [
self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes
]
self.stats = aggregate_stats(episodes_stats)
# Load actual data
try:
if force_cache_sync:
raise FileNotFoundError
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
assert all(
(self.root / fpath).is_file()
for fpath in self.get_episodes_file_paths()
)
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
self.episode_data_index = get_episode_data_index(
self.meta.episodes, self.episodes
)
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
check_timestamps_sync(
timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
)
# Setup delta_indices
if self.delta_timestamps is not None:
@@ -568,7 +593,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
hub_api.upload_folder(**upload_kwargs)
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
if not hub_api.file_exists(
self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch
):
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
@@ -576,8 +603,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
hub_api.delete_tag(
self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset"
)
hub_api.create_tag(
self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
)
def pull_from_repo(
self,
@@ -609,7 +640,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[Path]:
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
episodes = (
self.episodes
if self.episodes is not None
else list(range(self.meta.total_episodes))
)
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self.meta.video_keys) > 0:
video_files = [
@@ -640,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
ft_dict = {col: [] for col in features}
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
hf_dataset = datasets.Dataset.from_dict(
ft_dict, features=features, split="train"
)
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
@@ -726,7 +763,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
def _query_videos(
self, query_timestamps: dict[str, list[float]], ep_idx: int
) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second DataLoader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -735,7 +774,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
frames = decode_video_frames(
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames.squeeze(0)
return item
@@ -789,7 +830,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
current_ep_idx = (
self.meta.total_episodes if episode_index is None else episode_index
)
ep_buffer = {}
# size and task are special cases that are not in self.features
ep_buffer["size"] = 0
@@ -887,7 +930,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["index"] = np.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionary
@@ -897,12 +942,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta.add_task(task)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
episode_buffer["task_index"] = np.array(
[self.meta.get_task_index(task) for task in tasks]
)
for key, ft in self.features.items():
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
"image",
"video",
]:
continue
episode_buffer[key] = np.stack(episode_buffer[key])
@@ -944,7 +994,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = datasets.Dataset.from_dict(
episode_dict, features=self.hf_features, split="train"
)
ep_dataset = embed_images(ep_dataset)
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
self.hf_dataset.set_transform(hf_transform_to_torch)
@@ -1063,7 +1115,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.video_backend = (
video_backend if video_backend is not None else get_safe_default_codec()
)
return obj
@@ -1088,7 +1142,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
self.tolerances_s = (
tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [

View File

@@ -141,12 +141,16 @@ class SharpnessJitter(Transform):
return float(sharpness[0]), float(sharpness[1])
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
sharpness_factor = (
torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
)
return {"sharpness_factor": sharpness_factor}
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = params["sharpness_factor"]
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
return self._call_kernel(
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
)
@dataclass

View File

@@ -135,7 +135,9 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
elif isinstance(value, (int, float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
raise NotImplementedError(
f"The value '{value}' of type '{type(value)}' is not supported."
)
return unflatten_dict(serialized_dict)
@@ -214,7 +216,10 @@ def write_task(task_index: int, task: dict, local_dir: Path):
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
tasks = {
item["task_index"]: item["task"]
for item in sorted(tasks, key=lambda x: x["task_index"])
}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
@@ -225,13 +230,19 @@ def write_episode(episode: dict, local_dir: Path):
def load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
return {
item["episode_index"]: item
for item in sorted(episodes, key=lambda x: x["episode_index"])
}
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
episode_stats = {
"episode_index": episode_index,
"stats": serialize_dict(episode_stats),
}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
@@ -275,7 +286,9 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
items_dict[key] = [
x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]
]
return items_dict
@@ -328,7 +341,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
Otherwise, will throw a `CompatibilityError`.
"""
target_version = (
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
packaging.version.parse(version)
if not isinstance(version, packaging.version.Version)
else version
)
hub_versions = get_repo_versions(repo_id)
@@ -349,12 +364,16 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
return f"v{target_version}"
compatibles = [
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
v
for v in hub_versions
if v.major == target_version.major and v.minor <= target_version.minor
]
if compatibles:
return_version = max(compatibles)
if return_version < target_version:
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
logging.warning(
f"Revision {version} for {repo_id} not found, using version v{return_version}"
)
return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major]
@@ -461,7 +480,9 @@ def create_empty_dataset_info(
def get_episode_data_index(
episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
episode_lengths = {
ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()
}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@@ -511,7 +532,9 @@ def check_timestamps_sync(
# Mask to ignore differences at the boundaries between episodes
mask = np.ones(len(diffs), dtype=bool)
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
ignored_diffs = (
episode_data_index["to"][:-1] - 1
) # indices at the end of each episode
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
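To illustrate the mask: with two episodes of lengths 3 and 2 at 10 fps, only the diff that crosses the episode boundary is dropped before the tolerance check. A small, self-contained sketch:

import numpy as np

timestamps = np.array([0.0, 0.1, 0.2, 0.0, 0.1])  # two episodes, fps = 10
diffs = np.diff(timestamps)                       # [0.1, 0.1, -0.2, 0.1]
to = np.array([3, 5])                             # episode_data_index["to"]
mask = np.ones(len(diffs), dtype=bool)
mask[to[:-1] - 1] = False                         # ignore the boundary diff at index 2
within_tolerance = np.abs(diffs - 1 / 10) < 1e-4
assert within_tolerance[mask].all()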
@@ -720,14 +743,18 @@ def validate_frame(frame: dict, features: dict):
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
error_message = validate_features_presence(actual_features, expected_features, optional_features)
error_message = validate_features_presence(
actual_features, expected_features, optional_features
)
if "task" in frame:
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
for name in common_features - {"task"}:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
error_message += validate_feature_dtype_and_shape(
name, features[name], frame[name]
)
if error_message:
raise ValueError(error_message)
@@ -750,7 +777,9 @@ def validate_features_presence(
return error_message
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
def validate_feature_dtype_and_shape(
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
@@ -760,7 +789,9 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
raise NotImplementedError(
f"The feature dtype '{expected_dtype}' is not implemented yet."
)
def validate_feature_numpy_array(
@@ -782,13 +813,17 @@ def validate_feature_numpy_array(
return error_message
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
def validate_feature_image_or_video(
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
):
# Note: The check of pixel range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
if len(actual_shape) != 3 or (
actual_shape != (c, h, w) and actual_shape != (h, w, c)
):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass
@@ -819,7 +854,9 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
)
if episode_buffer["size"] == 0:
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
raise ValueError(
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if not buffer_keys == set(features):

View File

@@ -35,22 +35,30 @@ def fix_dataset(repo_id: str) -> str:
dataset_info = get_dataset_config_info(repo_id, "default")
with SuppressWarnings():
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
lerobot_metadata = LeRobotDatasetMetadata(
repo_id, revision=V20, force_cache_sync=True
)
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
meta_features = {
key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"
}
parquet_features = set(dataset_info.features)
diff_parquet_meta = parquet_features - meta_features
diff_meta_parquet = meta_features - parquet_features
if diff_parquet_meta:
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
raise ValueError(
f"In parquet not in info.json: {parquet_features - meta_features}"
)
if not diff_meta_parquet:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
logging.warning(
f"In info.json not in parquet: {meta_features - parquet_features}"
)
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)

View File

@@ -37,8 +37,16 @@ import logging
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
from lerobot.common.datasets.utils import (
EPISODES_STATS_PATH,
STATS_PATH,
load_stats,
write_info,
)
from lerobot.common.datasets.v21.convert_stats import (
check_aggregate_stats,
convert_stats,
)
V20 = "v2.0"
V21 = "v2.1"
@@ -79,13 +87,21 @@ def convert_dataset(
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
repo_id=dataset.repo_id,
filename=STATS_PATH,
revision=branch,
repo_type="dataset",
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
path_in_repo=STATS_PATH,
repo_id=dataset.repo_id,
revision=branch,
repo_type="dataset",
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
hub_api.create_tag(
repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
)
if __name__ == "__main__":

View File

@@ -17,12 +17,18 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
get_feature_stats,
sample_indices,
)
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
def sample_episode_video_frames(
dataset: LeRobotDataset, episode_index: int, ft_key: str
) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
@@ -45,11 +51,14 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
ep_stats[key] = get_feature_stats(
ep_ft_data, axis=axes_to_reduce, keepdims=keepdims
)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v, axis=0)
for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats
@@ -95,5 +104,9 @@ def check_aggregate_stats(
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
val,
reference_stats[key][stat],
rtol=rtol,
atol=atol,
err_msg=err_msg,
)

View File

@@ -65,7 +65,9 @@ def decode_video_frames(
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
else:
raise ValueError(f"Unsupported video backend: {backend}")

View File

@@ -61,10 +61,16 @@ class AlohaEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels":
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
self.features["top"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(480, 640, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
self.features["agent_pos"] = PolicyFeature(
type=FeatureType.STATE, shape=(14,)
)
self.features["pixels/top"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(480, 640, 3)
)
@property
def gym_kwargs(self) -> dict:
@@ -102,9 +108,13 @@ class PushtEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
self.features["pixels"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(384, 384, 3)
)
elif self.obs_type == "environment_state_agent_pos":
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
self.features["environment_state"] = PolicyFeature(
type=FeatureType.ENV, shape=(16,)
)
@property
def gym_kwargs(self) -> dict:
@@ -143,7 +153,9 @@ class XarmEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
self.features["agent_pos"] = PolicyFeature(
type=FeatureType.STATE, shape=(4,)
)
@property
def gym_kwargs(self) -> dict:

View File

@@ -32,7 +32,9 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
def make_env(
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the config.
Args:
@@ -56,7 +58,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
print(
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
)
raise e
gym_handle = f"{package_name}/{cfg.task}"
@@ -64,7 +68,10 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
[
lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs)
for _ in range(n_envs)
]
)
return env

View File

@@ -46,7 +46,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
assert c < h and c < w, (
f"expect channel last images, but instead got {img.shape=}"
)
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@@ -79,7 +81,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL:
if len(ft.shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
raise ValueError(
f"Number of dimensions of {key} != 3 (shape={ft.shape})"
)
shape = get_channel_first_image_shape(ft.shape)
feature = PolicyFeature(type=ft.type, shape=shape)
@@ -92,7 +96,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
return policy_features
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.

View File

@@ -250,9 +250,9 @@ class Logger:
)
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(
optimizer.keys()
), "Optimizer dictionaries do not have the same keys during resume!"
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
"Optimizer dictionaries do not have the same keys during resume!"
)
for k, v in training_state["optimizer"].items():
optimizer[k].load_state_dict(v)
else:

View File

@@ -34,7 +34,13 @@ def make_optimizer_and_scheduler(
Returns:
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
"""
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
params = (
policy.get_optim_params()
if cfg.use_policy_training_preset
else policy.parameters()
)
optimizer = cfg.optimizer.build(params)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
lr_scheduler = (
cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
)
return optimizer, lr_scheduler

View File

@@ -102,7 +102,9 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
def load_optimizer_state(
optimizer: torch.optim.Optimizer, save_dir: Path
) -> torch.optim.Optimizer:
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)

View File

@@ -36,7 +36,9 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
return self.get_choice_name(self.__class__)
@abc.abstractmethod
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
def build(
self, optimizer: Optimizer, num_training_steps: int
) -> LRScheduler | None:
raise NotImplementedError
@@ -49,7 +51,11 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
from diffusers.optimization import get_scheduler
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
kwargs = {
**asdict(self),
"num_training_steps": num_training_steps,
"optimizer": optimizer,
}
return get_scheduler(**kwargs)
@@ -71,7 +77,14 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
progress = float(adjusted_step - self.num_warmup_steps) / float(
max(1, num_training_steps - self.num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
return max(
0.0,
0.5
* (
1.0
+ math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)
),
)
return LambdaLR(optimizer, lr_lambda, -1)
@@ -98,7 +111,9 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
def cosine_decay_schedule(current_step):
step = min(current_step, self.num_decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
cosine_decay = 0.5 * (
1 + math.cos(math.pi * step / self.num_decay_steps)
)
alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha
return decayed
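Plugging hypothetical numbers into the decay above shows the multiplier falling from 1.0 at step 0 to alpha once num_decay_steps is reached:

import math

peak_lr, decay_lr, num_decay_steps = 1e-4, 1e-6, 100  # hypothetical values

def factor(current_step: int) -> float:
    step = min(current_step, num_decay_steps)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * step / num_decay_steps))
    alpha = decay_lr / peak_lr
    return (1 - alpha) * cosine_decay + alpha

print(factor(0), factor(50), factor(100))  # 1.0, ~0.505, 0.01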
@@ -117,6 +132,8 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
state_dict = deserialize_json_into_object(
save_dir / SCHEDULER_STATE, scheduler.state_dict()
)
scheduler.load_state_dict(state_dict)
return scheduler

View File

@@ -171,7 +171,9 @@ class ACTConfig(PreTrainedConfig):
def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
@property
def observation_delta_indices(self) -> None:

View File

@@ -63,7 +63,9 @@ class ACTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -120,8 +122,12 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [
batch[key] for key in self.config.image_features
]
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -148,8 +154,12 @@ class ACTPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [
batch[key] for key in self.config.image_features
]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -406,14 +416,18 @@ class ACT(nn.Module):
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
config.dim_model // 2
)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
self.action_head = nn.Linear(
config.dim_model, self.config.action_feature.shape[0]
)
self._reset_parameters()
@@ -461,14 +475,20 @@ class ACT(nn.Module):
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = self.vae_encoder_robot_state_input_proj(
batch["observation.state"]
)
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(
batch["action"]
) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
vae_encoder_input = [
cls_embed,
robot_state_embed,
action_embed,
] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
@@ -517,7 +537,9 @@ class ACT(nn.Module):
)
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
encoder_in_tokens.append(
self.encoder_robot_state_input_proj(batch["observation.state"])
)
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(
@@ -534,7 +556,9 @@ class ACT(nn.Module):
# For a list of images, the H and W may vary but H*W is constant.
for img in batch["observation.images"]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
dtype=cam_features.dtype
)
cam_features = self.encoder_img_feat_input_proj(cam_features)
# Rearrange features to (sequence, batch, dim).

View File

@@ -205,11 +205,16 @@ class DiffusionConfig(PreTrainedConfig):
def validate_features(self) -> None:
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
if (
self.crop_shape[0] > image_ft.shape[1]
or self.crop_shape[1] > image_ft.shape[2]
):
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "

View File

@ -70,7 +70,9 @@ class DiffusionPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -97,7 +99,9 @@ class DiffusionPolicy(PreTrainedPolicy):
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
self._queues["observation.environment_state"] = deque(
maxlen=self.config.n_obs_steps
)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -123,7 +127,9 @@ class DiffusionPolicy(PreTrainedPolicy):
"""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
@@ -151,7 +157,9 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
@@ -515,11 +523,15 @@ class DiffusionRgbEncoder(nn.Module):
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape_h_w = (
config.crop_shape if config.crop_shape is not None else images_shape[1:]
)
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
@@ -719,7 +731,9 @@ class DiffusionConditionalUnet1d(nn.Module):
)
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
DiffusionConv1dBlock(
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
),
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
)

View File

@@ -104,7 +104,9 @@ def make_policy(
PreTrainedPolicy: _description_
"""
if bool(ds_meta) == bool(env_cfg):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
raise ValueError(
"Either one of a dataset metadata or a sim env must be provided."
)
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
@@ -134,8 +136,12 @@ def make_policy(
)
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
cfg.input_features = {
key: ft for key, ft in features.items() if key not in cfg.output_features
}
kwargs["config"] = cfg
if cfg.pretrained_path:

View File

@@ -82,25 +82,43 @@ def create_stats_buffers(
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(
dtype=torch.float32
)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(
dtype=torch.float32
)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(
dtype=torch.float32
)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(
dtype=torch.float32
)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
buffer["mean"].data = (
stats[key]["mean"].clone().to(dtype=torch.float32)
)
buffer["std"].data = (
stats[key]["std"].clone().to(dtype=torch.float32)
)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
buffer["min"].data = (
stats[key]["min"].clone().to(dtype=torch.float32)
)
buffer["max"].data = (
stats[key]["max"].clone().to(dtype=torch.float32)
)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
raise ValueError(
f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead."
)
stats_buffers[key] = buffer
return stats_buffers
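These buffers are consumed by the normalization modules. A minimal sketch of the two modes (the library's exact epsilon handling is not shown in this hunk and is assumed here):

import torch

def normalize_mean_std(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    return (x - mean) / (std + 1e-8)  # epsilon is an assumption

def normalize_min_max(x: torch.Tensor, vmin: torch.Tensor, vmax: torch.Tensor) -> torch.Tensor:
    x = (x - vmin) / (vmax - vmin + 1e-8)  # -> [0, 1]
    return x * 2 - 1                       # -> [-1, 1]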

View File

@@ -44,7 +44,9 @@ def main():
else:
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_torch_dir = (
Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
)
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
save_dir = Path(f"../openpi/data/{model_name}/save")
@@ -70,7 +72,9 @@ def main():
# Create LeRobot batch from Jax
batch = {}
for cam_key, uint_chw_array in example["images"].items():
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
batch[f"observation.images.{cam_key}"] = (
torch.from_numpy(uint_chw_array) / 255.0
)
batch["observation.state"] = torch.from_numpy(example["state"])
batch["action"] = torch.from_numpy(outputs["actions"])
batch["task"] = example["prompt"]

View File

@@ -54,7 +54,9 @@ def get_paligemma_config(precision: str):
"projector_hidden_act": "gelu_fast",
"vision_use_head": False,
}
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
final_config = PaliGemmaConfig(
text_config=text_config, vision_config=vision_config, **config
)
return final_config

View File

@@ -61,7 +61,11 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
)
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
PRECISIONS = {
"bfloat16": torch.bfloat16,
"float32": torch.float32,
"float16": torch.float16,
}
def slice_paligemma_state_dict(state_dict, config):
@@ -318,7 +322,9 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
return {f"{prefix}{key}": value for key, value in d.items()}
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
def convert_pi0_checkpoint(
checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str
):
# Break down orbax ckpts - they are in OCDBT
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
# process projection params
@@ -378,7 +384,9 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st
# gemma_config=gemma_config, paligemma_config=paligemma_config)
pi0_model = PI0Policy(pi0_config)
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
paligemma_params = update_keys_with_prefix(
paligemma_params, "model.paligemma_with_expert."
)
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
projection_params = update_keys_with_prefix(projection_params, "model.")

View File

@@ -48,18 +48,32 @@ def flex_attention_forward(
key_states = key_states[:, :, :, None, :]
key_states = key_states.expand(
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
batch_size,
key_states.shape[1],
num_key_value_heads,
num_key_value_groups,
head_dim,
)
key_states = key_states.reshape(
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
batch_size,
key_states.shape[1],
num_key_value_heads * num_key_value_groups,
head_dim,
)
value_states = value_states[:, :, :, None, :]
value_states = value_states.expand(
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
batch_size,
value_states.shape[1],
num_key_value_heads,
num_key_value_groups,
head_dim,
)
value_states = value_states.reshape(
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
batch_size,
value_states.shape[1],
num_key_value_heads * num_key_value_groups,
head_dim,
)
query_states = query_states.transpose(1, 2)
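The expand/reshape pair above repeats each key/value head so grouped-query attention can match the query head count. A tiny shape check with hypothetical sizes:

import torch

batch_size, seq_len, num_key_value_heads, num_key_value_groups, head_dim = 1, 2, 2, 3, 4
k = torch.randn(batch_size, seq_len, num_key_value_heads, head_dim)
k = k[:, :, :, None, :].expand(
    batch_size, seq_len, num_key_value_heads, num_key_value_groups, head_dim
)
k = k.reshape(batch_size, seq_len, num_key_value_heads * num_key_value_groups, head_dim)
print(k.shape)  # torch.Size([1, 2, 6, 4])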

View File

@@ -69,7 +69,11 @@ from lerobot.common.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
time: torch.tensor,
dimension: int,
min_period: float,
max_period: float,
device="cpu",
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
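For reference, a minimal sketch of the sine-cosine embedding this function computes, assuming periods spaced geometrically between `min_period` and `max_period`, with half the dimensions carrying sine and half cosine (inferred from the signature, not from this hunk):

import torch

def sincos_embedding_sketch(time, dimension, min_period, max_period, device="cpu"):
    # time: (batch,) tensor of scalar positions; even dimension assumed
    # (hence the parity check above).
    fraction = torch.linspace(0.0, 1.0, dimension // 2, device=device)
    period = min_period * (max_period / min_period) ** fraction
    angle = 2 * torch.pi * time[:, None] / period[None, :]
    return torch.cat([torch.sin(angle), torch.cos(angle)], dim=1)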
@ -189,7 +193,9 @@ def aloha_gripper_to_angular(value):
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
2 * horn_radius * linear_position
)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
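The wrapped quotient is the law-of-cosines term: for horn radius r, linkage length l, and linear position d, (r² + d² − l²)/(2rd) is the cosine of the angle opposite the linkage, and `safe_arcsin` of a cosine yields the complementary angle. A numeric sanity check with made-up geometry:

import math

horn_radius, arm_length = 0.022, 0.036   # hypothetical geometry, metres
linear_position = 0.040
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
    2 * horn_radius * linear_position
)
theta = math.acos(value)                 # angle opposite the linkage
# asin(cos(theta)) == pi/2 - theta for value in [-1, 1]
assert abs(math.asin(value) - (math.pi / 2 - theta)) < 1e-9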
@ -240,7 +246,9 @@ class PI0Policy(PreTrainedPolicy):
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@ -248,7 +256,9 @@ class PI0Policy(PreTrainedPolicy):
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.language_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224"
)
self.model = PI0FlowMatching(config)
self.reset()
@ -261,7 +271,9 @@ class PI0Policy(PreTrainedPolicy):
return self.parameters()
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def select_action(
self, batch: dict[str, Tensor], noise: Tensor | None = None
) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
@ -300,7 +312,9 @@ class PI0Policy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
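`select_action` drains a deque that was filled with a whole predicted chunk: transposing puts the step dimension first, so `extend` enqueues one `(batch, action_dim)` action per control step. A self-contained sketch with assumed sizes:

from collections import deque

import torch

n_action_steps, batch, action_dim = 50, 1, 14  # hypothetical sizes
actions = torch.randn(batch, n_action_steps, action_dim)

queue = deque([], maxlen=n_action_steps)
queue.extend(actions.transpose(0, 1))  # iterate over steps, not batches
first_action = queue.popleft()
assert first_action.shape == (batch, action_dim)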
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
def forward(
self, batch: dict[str, Tensor], noise=None, time=None
) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@ -316,7 +330,9 @@ class PI0Policy(PreTrainedPolicy):
actions_is_pad = batch.get("actions_is_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
losses = self.model.forward(
images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
)
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
@ -343,7 +359,9 @@ class PI0Policy(PreTrainedPolicy):
img_masks = []
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
missing_img_keys = [
key for key in self.config.image_features if key not in batch
]
if len(present_img_keys) == 0:
raise ValueError(
@ -355,7 +373,9 @@ class PI0Policy(PreTrainedPolicy):
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
img = resize_with_pad(
img, *self.config.resize_imgs_with_padding, pad_value=0
)
# Normalize from range [0,1] to [-1,1] as expected by SigLIP
img = img * 2.0 - 1.0
@ -394,7 +414,9 @@ class PI0Policy(PreTrainedPolicy):
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
lang_masks = tokenized_prompt["attention_mask"].to(
device=device, dtype=torch.bool
)
return lang_tokens, lang_masks
@ -413,7 +435,9 @@ class PI0Policy(PreTrainedPolicy):
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
actions[:, :, motor_idx] = aloha_gripper_from_angular(
actions[:, :, motor_idx]
)
return actions
def _pi_aloha_encode_actions_inv(self, actions):
@ -422,7 +446,9 @@ class PI0Policy(PreTrainedPolicy):
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
actions[:, :, motor_idx]
)
return actions
def prepare_state(self, batch):
@ -472,15 +498,25 @@ class PI0FlowMatching(nn.Module):
train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation,
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_with_export_config
)
# Projections are float32
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
self.action_in_proj = nn.Linear(
self.config.max_action_dim, self.config.proj_width
)
self.action_out_proj = nn.Linear(
self.config.proj_width, self.config.max_action_dim
)
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
self.action_time_mlp_in = nn.Linear(
self.config.proj_width * 2, self.config.proj_width
)
self.action_time_mlp_out = nn.Linear(
self.config.proj_width, self.config.proj_width
)
self.set_requires_grad()
@ -524,7 +560,9 @@ class PI0FlowMatching(nn.Module):
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
img_emb = img_emb * torch.tensor(
img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
)
bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
@ -577,7 +615,11 @@ class PI0FlowMatching(nn.Module):
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
timestep,
self.config.proj_width,
min_period=4e-3,
max_period=4.0,
device=device,
)
time_emb = time_emb.type(dtype=dtype)
@ -595,7 +637,9 @@ class PI0FlowMatching(nn.Module):
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
action_time_mask = torch.ones(
bsize, action_time_dim, dtype=torch.bool, device=device
)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
@ -609,7 +653,15 @@ class PI0FlowMatching(nn.Module):
return embs, pad_masks, att_masks
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
actions,
noise=None,
time=None,
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
@ -625,7 +677,9 @@ class PI0FlowMatching(nn.Module):
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
state, x_t, time
)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
@ -649,13 +703,19 @@ class PI0FlowMatching(nn.Module):
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
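The unreduced MSE above is the conditional flow-matching objective. A sketch of the quantities involved, assuming the linear interpolation path that the surrounding names (`x_t`, `u_t`, `v_t`) suggest, with the model prediction stubbed out:

import torch
import torch.nn.functional as F

actions = torch.randn(8, 50, 32)     # (batch, horizon, max_action_dim)
noise = torch.randn_like(actions)
time = torch.rand(8)                 # per-sample timestep in [0, 1]

t = time[:, None, None]
x_t = t * noise + (1 - t) * actions  # noisy actions fed to the model
u_t = noise - actions                # target velocity field
v_t = torch.zeros_like(u_t)          # stand-in for the model's prediction
losses = F.mse_loss(u_t, v_t, reduction="none")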
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None
) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
if noise is None:
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
actions_shape = (
bsize,
self.config.n_action_steps,
self.config.max_action_dim,
)
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@ -703,12 +763,16 @@ class PI0FlowMatching(nn.Module):
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
state, x_t, timestep
)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
batch_size, suffix_len, prefix_len
)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
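`denoise_step` only returns the velocity at the current `x_t`; sampling integrates the probability-flow ODE with fixed-step Euler. A generic sketch (the step count, the time convention, and the stubbed model call are assumptions):

import torch

num_steps = 10
dt = -1.0 / num_steps            # integrate from t=1 (noise) down to t=0
x_t = torch.randn(1, 50, 32)     # start from pure noise
t = 1.0
for _ in range(num_steps):
    v_t = torch.zeros_like(x_t)  # stand-in for model.denoise_step(...)
    x_t = x_t + dt * v_t
    t = t + dt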

View File

@ -39,9 +39,13 @@ def apply_rope(x, positions, max_wavelength=10_000):
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
d_half, dtype=torch.float32, device=device
)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
torch.float32
)
radians = radians[..., None, :]
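After the wrapped timescale computation, RoPE rotates pairs of features by position-dependent angles. A self-contained sketch assuming the split-halves convention and float32 inputs of shape `(seq_len, dim)`:

import torch

def apply_rope_sketch(x, positions, max_wavelength=10_000):
    # x: (seq_len, dim), positions: (seq_len,)
    d_half = x.shape[-1] // 2
    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale
    sin, cos = torch.sin(radians), torch.cos(radians)
    x1, x2 = x.split(d_half, dim=-1)
    # Rotate each (x1_i, x2_i) pair by its position-dependent angle.
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)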
@ -174,7 +178,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config)
self.config = config
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.paligemma = PaliGemmaForConditionalGeneration(
config=config.paligemma_config
)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None
@ -291,14 +297,22 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
key_states = torch.cat(
[past_key_values[layer_idx]["key_states"], key_states], dim=1
)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], dim=1
[past_key_values[layer_idx]["value_states"], value_states],
dim=1,
)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
)
att_output = att_output.to(dtype=torch.bfloat16)
@ -358,15 +372,29 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
return attention_interface
def flash_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
self,
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
):
raise NotImplementedError("FA2 is not implemented (yet)")
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
self,
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
num_key_value_heads = (
self.config.paligemma_config.text_config.num_key_value_heads
)
num_key_value_groups = num_att_heads // num_key_value_heads
# query_states: batch_size, sequence_length, num_att_head, head_dim
@ -375,17 +403,31 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads,
num_key_value_groups,
head_dim,
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads * num_key_value_groups,
head_dim,
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads,
num_key_value_groups,
head_dim,
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads * num_key_value_groups,
head_dim,
)
# Attention here is upcasted to float32 to match the original eager implementation.
@ -400,7 +442,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
att_weights *= head_dim**-0.5
big_neg = -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
masked_att_weights = torch.where(
attention_mask[:, None, :, :], att_weights, big_neg
)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
@ -412,6 +456,8 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
att_output = att_output.reshape(
batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
)
return att_output
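Downstream of these reshapes, eager attention is the textbook sequence shown across this hunk: scale by `head_dim**-0.5`, mask with a large negative constant, softmax, then a weighted sum over values. A compact sketch with assumed shapes:

import torch
from torch import nn

B, S, H, D = 2, 5, 8, 64
q = torch.randn(B, H, S, D)  # query_states after transposing to heads-first
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)
attention_mask = torch.ones(B, S, S, dtype=torch.bool)

att_weights = torch.matmul(q, k.transpose(2, 3)) * D**-0.5
big_neg = -2.3819763e38  # same constant as above
masked = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked, dim=-1)
att_output = torch.matmul(probs, v)  # (B, H, S, D)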

View File

@ -71,7 +71,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
def _save_pretrained(self, save_directory: Path) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
save_model_as_safetensor(
model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)
)
@classmethod
def from_pretrained(
@ -110,7 +112,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
policy = cls._load_as_safetensor(
instance, model_file, config.device, strict
)
else:
try:
model_file = hf_hub_download(
@ -124,7 +128,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
token=token,
local_files_only=local_files_only,
)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
policy = cls._load_as_safetensor(
instance, model_file, config.device, strict
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
@ -135,8 +141,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
return policy
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
def _load_as_safetensor(
cls, model: T, model_file: str, map_location: str, strict: bool
) -> T:
if packaging.version.parse(safetensors.__version__) < packaging.version.parse(
"0.4.3"
):
load_model_as_safetensor(model, model_file, strict=strict)
if map_location != "cpu":
logging.warning(
@ -147,7 +157,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
)
model.to(map_location)
else:
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
safetensors.torch.load_model(
model, model_file, strict=strict, device=map_location
)
return model
# def generate_model_card(self, *args, **kwargs) -> ModelCard:

View File

@ -639,9 +639,9 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(
log_std
).any(), "[ERROR] log_std became NaN after std_layer!"
assert not torch.isnan(log_std).any(), (
"[ERROR] log_std became NaN after std_layer!"
)
if self.use_tanh_squash:
log_std = torch.tanh(log_std)

View File

@ -187,7 +187,9 @@ class TDMPCConfig(PreTrainedConfig):
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
)
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
raise ValueError(
"`n_action_steps` must be less than or equal to `horizon`."
)
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(lr=self.optimizer_lr)
@ -207,7 +209,9 @@ class TDMPCConfig(PreTrainedConfig):
if image_ft.shape[-2] != image_ft.shape[-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
raise ValueError(
f"Only square images are handled now. Got image shape {image_ft.shape}."
)
@property
def observation_delta_indices(self) -> list:

View File

@ -39,7 +39,11 @@ from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_output_shape,
populate_queues,
)
class TDMPCPolicy(PreTrainedPolicy):
@ -63,7 +67,11 @@ class TDMPCPolicy(PreTrainedPolicy):
config_class = TDMPCConfig
name = "tdmpc"
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
def __init__(
self,
config: TDMPCConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
@ -75,7 +83,9 @@ class TDMPCPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@ -117,7 +127,9 @@ class TDMPCPolicy(PreTrainedPolicy):
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[next(iter(self.config.image_features))]
self._queues = populate_queues(self._queues, batch)
@ -201,7 +213,10 @@ class TDMPCPolicy(PreTrainedPolicy):
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
self.config.horizon,
batch_size,
self.config.action_feature.shape[0],
device=device,
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
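This zero-initialized mean (optionally warm-started from the previous step) seeds the cross-entropy method: sample action sequences around the current Gaussian, score them, and refit the Gaussian to the elites. A generic CEM sketch with the scoring stubbed out:

import torch

horizon, action_dim = 5, 4
num_samples, num_elites, iterations = 256, 32, 6
mean = torch.zeros(horizon, action_dim)
std = torch.ones(horizon, action_dim)

for _ in range(iterations):
    samples = mean + std * torch.randn(num_samples, horizon, action_dim)
    scores = -samples.pow(2).sum(dim=(1, 2))  # stand-in for model-based returns
    elites = samples[scores.topk(num_elites).indices]
    mean, std = elites.mean(dim=0), elites.std(dim=0)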
@ -339,7 +354,9 @@ class TDMPCPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
@ -371,7 +388,9 @@ class TDMPCPolicy(PreTrainedPolicy):
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[
"observation.image" if self.config.image_features else "observation.environment_state"
"observation.image"
if self.config.image_features
else "observation.environment_state"
].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.
@ -569,7 +588,9 @@ class TDMPCTOLD(nn.Module):
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@ -580,7 +601,9 @@ class TDMPCTOLD(nn.Module):
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@ -600,7 +623,10 @@ class TDMPCTOLD(nn.Module):
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.action_feature.shape[0],
config.mlp_dim,
),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@ -786,7 +812,9 @@ class TDMPCObservationEncoder(nn.Module):
if config.robot_state_feature:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.Linear(
config.robot_state_feature.shape[0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@ -795,7 +823,9 @@ class TDMPCObservationEncoder(nn.Module):
if config.env_state_feature:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.Linear(
config.env_state_feature.shape[0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@ -813,7 +843,8 @@ class TDMPCObservationEncoder(nn.Module):
if self.config.image_features:
feat.append(
flatten_forward_unflatten(
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
self.image_enc_layers,
obs_dict[next(iter(self.config.image_features))],
)
)
if self.config.env_state_feature:

View File

@ -172,7 +172,10 @@ class VQBeTConfig(PreTrainedConfig):
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
if (
self.crop_shape[0] > image_ft.shape[1]
or self.crop_shape[1] > image_ft.shape[2]
):
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
@ -193,7 +196,12 @@ class VQBeTConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
return list(
range(
1 - self.n_obs_steps,
self.n_action_pred_token + self.action_chunk_size - 1,
)
)
@property
def reward_delta_indices(self) -> None:
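A worked example of the `action_delta_indices` expression reformatted above, with hypothetical config values:

n_obs_steps = 2
n_action_pred_token = 3
action_chunk_size = 5

indices = list(range(1 - n_obs_steps, n_action_pred_token + action_chunk_size - 1))
# -> [-1, 0, 1, 2, 3, 4, 5, 6]: one past step, the current one, six future offsets.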

View File

@ -29,7 +29,11 @@ from torch import Tensor, nn
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_output_shape,
populate_queues,
)
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
@ -60,7 +64,9 @@ class VQBeTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@ -91,11 +97,17 @@ class VQBeTPolicy(PreTrainedPolicy):
if self.config.sequentially_select:
decay_params = (
decay_params
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
+ list(
self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
)
+ list(
self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
)
)
else:
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
decay_params = decay_params + list(
self.vqbet.action_head.map_to_cbet_preds_bin.parameters()
)
return [
{
@ -133,8 +145,12 @@ class VQBeTPolicy(PreTrainedPolicy):
"""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@ -165,8 +181,12 @@ class VQBeTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
@ -334,7 +354,8 @@ class VQBeTModel(nn.Module):
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP(
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
config.robot_state_feature.shape[0],
hidden_channels=[self.config.gpt_input_dim],
)
self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
@ -406,9 +427,9 @@ class VQBeTModel(nn.Module):
features = self.policy(input_tokens)
# len(self.config.input_features) is the number of different observation modes.
# this line gets the index of action prompt tokens.
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
self.config.input_features
)
historical_act_pred_index = np.arange(0, n_obs_steps) * (
len(self.config.input_features) + 1
) + len(self.config.input_features)
# only extract the output tokens at the position of action query:
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
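A worked example of the index arithmetic just rewrapped (sizes are made up): each observation step contributes one token per input feature plus one action-query token, so the query positions land at the end of each step's group.

import numpy as np

n_obs_steps = 3
num_input_features = 2  # stand-in for len(self.config.input_features)

historical_act_pred_index = np.arange(0, n_obs_steps) * (
    num_input_features + 1
) + num_input_features
# Tokens per step are [feat_0, feat_1, action_query] -> positions [2, 5, 8].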
@ -771,11 +792,15 @@ class VQBeTRgbEncoder(nn.Module):
# height and width from `config.image_features`.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape_h_w = (
config.crop_shape if config.crop_shape is not None else images_shape[1:]
)
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
@ -871,7 +896,8 @@ class VqVae(nn.Module):
)
self.encoder = MLP(
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
in_channels=self.config.action_feature.shape[0]
* self.config.action_chunk_size,
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
@ -899,9 +925,13 @@ class VqVae(nn.Module):
# given latent vector, this function outputs the decoded action.
output = self.decoder(latent)
if self.config.action_chunk_size == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
)
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
)
def get_code(self, state):
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)

View File

@ -290,10 +290,10 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert (
len(inter_params) == 0
), "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
assert len(inter_params) == 0, (
"parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
)
assert len(param_dict.keys() - union_params) == 0, (
"parameters {} were not separated into either decay/no_decay set!".format(
@ -664,14 +664,14 @@ class VectorQuantize(nn.Module):
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
assert not (
ema_update and learnable_codebook
), "learnable codebook not compatible with EMA update"
assert not (ema_update and learnable_codebook), (
"learnable codebook not compatible with EMA update"
)
assert 0 <= sync_update_v <= 1.0
assert not (
sync_update_v > 0.0 and not learnable_codebook
), "learnable codebook must be turned on"
assert not (sync_update_v > 0.0 and not learnable_codebook), (
"learnable codebook must be turned on"
)
self.sync_update_v = sync_update_v

View File

@ -57,7 +57,9 @@ class OpenCVCameraConfig(CameraConfig):
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)
@CameraConfig.register_subclass("intelrealsense")
@ -102,8 +104,12 @@ class IntelRealSenseCameraConfig(CameraConfig):
self.channels = 3
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
at_least_one_is_not_none = (
self.fps is not None or self.width is not None or self.height is not None
)
at_least_one_is_none = (
self.fps is None or self.width is None or self.height is None
)
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
@ -111,4 +117,6 @@ class IntelRealSenseCameraConfig(CameraConfig):
)
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)

View File

@ -303,7 +303,11 @@ class IntelRealSenseCamera:
if self.fps and self.capture_width and self.capture_height:
# TODO(rcadene): can we set rgb8 directly?
config.enable_stream(
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
rs.stream.color,
self.capture_width,
self.capture_height,
rs.format.rgb8,
self.fps,
)
else:
config.enable_stream(rs.stream.color)
@ -311,7 +315,11 @@ class IntelRealSenseCamera:
if self.use_depth:
if self.fps and self.capture_width and self.capture_height:
config.enable_stream(
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
rs.stream.depth,
self.capture_width,
self.capture_height,
rs.format.z16,
self.fps,
)
else:
config.enable_stream(rs.stream.depth)

View File

@ -144,7 +144,9 @@ def save_images_from_cameras(
print("Connecting cameras")
cameras = []
for cam_idx in camera_ids:
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
config = OpenCVCameraConfig(
camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock
)
camera = OpenCVCamera(config)
camera.connect()
print(
@ -250,7 +252,9 @@ class OpenCVCamera:
# Retrieve the camera index from a potentially symlinked path
self.camera_index = get_camera_index_from_unix_port(self.port)
else:
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
raise ValueError(
f"Please check the provided camera_index: {self.camera_index}"
)
# Store the raw (capture) resolution from the config.
self.capture_width = config.width
@ -314,7 +318,11 @@ class OpenCVCamera:
else cv2.CAP_ANY
)
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
camera_idx = (
f"/dev/video{self.camera_index}"
if platform.system() == "Linux"
else self.camera_index
)
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(camera_idx, backend)

View File

@ -41,7 +41,9 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
cameras[key] = OpenCVCamera(cfg)
elif cfg.type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
from lerobot.common.robot_devices.cameras.intelrealsense import (
IntelRealSenseCamera,
)
cameras[key] = IntelRealSenseCamera(cfg)
else:
@ -58,7 +60,9 @@ def make_camera(camera_type, **kwargs) -> Camera:
return OpenCVCamera(config)
elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
from lerobot.common.robot_devices.cameras.intelrealsense import (
IntelRealSenseCamera,
)
config = IntelRealSenseCameraConfig(**kwargs)
return IntelRealSenseCamera(config)

View File

@ -93,7 +93,9 @@ class RecordControlConfig(ControlConfig):
policy_path = parser.get_path_arg("control.policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("control.policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=cli_overrides
)
self.policy.pretrained_path = policy_path

View File

@ -282,7 +282,10 @@ def control_loop(
if policy is not None:
pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
observation,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.

View File

@ -23,7 +23,10 @@ import numpy as np
import tqdm
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 2.0

View File

@ -23,7 +23,10 @@ import numpy as np
import tqdm
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 0

View File

@ -30,7 +30,9 @@ class MotorsBus(Protocol):
def write(self): ...
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
def make_motors_buses_from_configs(
motors_bus_configs: dict[str, MotorsBusConfig],
) -> list[MotorsBus]:
motors_buses = {}
for key, cfg in motors_bus_configs.items():

View File

@ -69,9 +69,13 @@ class ManipulatorRobotConfig(RobotConfig):
if not cam.mock:
cam.mock = True
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
if self.max_relative_target is not None and isinstance(
self.max_relative_target, Sequence
):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
if len(self.follower_arms[name].motors) != len(
self.max_relative_target
):
raise ValueError(
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "

View File

@ -42,7 +42,9 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
local_dict = {}
for name, cam in cameras.items():
frame = cam.async_read()
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
ret, buffer = cv2.imencode(
".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]
)
if ret:
local_dict[name] = base64.b64encode(buffer).decode("utf-8")
else:
@ -61,7 +63,9 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
calib_dir.mkdir(parents=True, exist_ok=True)
calib_file = calib_dir / "main_follower.json"
try:
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
except ImportError:
print("[WARNING] Calibration function not available. Skipping calibration.")
return
@ -72,7 +76,9 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
print(f"[INFO] Loaded calibration from {calib_file}")
else:
print("[INFO] Calibration file not found. Running manual calibration...")
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
calibration = run_arm_manual_calibration(
motors_bus, "lekiwi", "follower_arm", "follower"
)
print(f"[INFO] Calibration complete. Saving to {calib_file}")
with open(calib_file, "w") as f:
json.dump(calibration, f)
@ -116,7 +122,14 @@ def run_lekiwi(robot_config):
robot = LeKiwi(motors_bus)
# Define the expected arm motor IDs.
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
arm_motor_ids = [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
]
# Disable torque for each arm motor.
for motor in arm_motor_ids:
@ -130,7 +143,9 @@ def run_lekiwi(robot_config):
images_lock = threading.Lock()
stop_event = threading.Event()
cam_thread = threading.Thread(
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
target=run_camera_capture,
args=(cameras, images_lock, latest_images_dict, stop_event),
daemon=True,
)
cam_thread.start()
@ -159,7 +174,9 @@ def run_lekiwi(robot_config):
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
)
else:
for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
for motor, pos in zip(
arm_motor_ids, arm_positions, strict=False
):
motors_bus.write("Goal_Position", pos, motor)
# Process wheel (base) commands.
if "raw_velocity" in data:
@ -190,7 +207,9 @@ def run_lekiwi(robot_config):
try:
pos = motors_bus.read("Present_Position", motor)
# Convert the position to a float (or use as is if already numeric).
follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
follower_arm_state.append(
float(pos) if not isinstance(pos, (int, float)) else pos
)
except Exception as e:
print(f"[ERROR] Reading motor {motor} failed: {e}")

View File

@ -28,7 +28,10 @@ import numpy as np
import torch
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.motors.utils import (
MotorsBus,
make_motors_buses_from_configs,
)
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import (

View File

@ -25,9 +25,14 @@ import zmq
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.feetech import TorqueMode
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.motors.utils import (
MotorsBus,
make_motors_buses_from_configs,
)
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
@ -266,7 +271,9 @@ class MobileManipulator:
calibration = json.load(f)
else:
print(f"Missing calibration file '{arm_calib_path}'")
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
@ -296,7 +303,9 @@ class MobileManipulator:
bus.write("Torque_Enable", 0, motor_id)
# Then filter out wheels
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
arm_only_dict = {
k: v for k, v in bus.motors.items() if not k.startswith("wheel_")
}
if not arm_only_dict:
continue
@ -324,7 +333,11 @@ class MobileManipulator:
socks = dict(poller.poll(15))
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
# No new data arrived → reuse ALL old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return (
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
# Drain all messages, keep only the last
last_msg = None
@ -337,7 +350,11 @@ class MobileManipulator:
if not last_msg:
# No new message → also reuse old
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return (
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
# Decode only the final message
try:
@ -360,7 +377,9 @@ class MobileManipulator:
if new_arm_state is not None and frames is not None:
self.last_frames = frames
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
remote_arm_state_tensor = torch.tensor(
new_arm_state, dtype=torch.float32
)
self.last_remote_arm_state = remote_arm_state_tensor
present_speed = new_speed
@ -375,14 +394,21 @@ class MobileManipulator:
except Exception as e:
print(f"[DEBUG] Error decoding video message: {e}")
# If decode fails, fall back to old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return (
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
return frames, present_speed, remote_arm_state_tensor
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
state_tensor = torch.zeros(3, dtype=torch.int32)
if present_speed:
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
decoded = {
key: MobileManipulator.raw_to_degps(value)
for key, value in present_speed.items()
}
if "1" in decoded:
state_tensor[0] = decoded["1"]
if "2" in decoded:
@ -395,7 +421,9 @@ class MobileManipulator:
self, record_data: bool = False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
raise RobotDeviceNotConnectedError(
"MobileManipulator is not connected. Run `connect()` first."
)
speed_setting = self.speed_levels[self.speed_index]
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
@ -461,9 +489,15 @@ class MobileManipulator:
body_state = self.wheel_raw_to_body(present_speed)
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
body_state_mm = (
body_state[0] * 1000.0,
body_state[1] * 1000.0,
body_state[2],
) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
combined_state_tensor = torch.cat(
(remote_arm_state_tensor, wheel_state_tensor), dim=0
)
obs_dict = {"observation.state": combined_state_tensor}
@ -620,7 +654,11 @@ class MobileManipulator:
# Convert each wheels angular speed (deg/s) to a raw integer.
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
return {
"left_wheel": wheel_raw[0],
"back_wheel": wheel_raw[1],
"right_wheel": wheel_raw[2],
}
def wheel_raw_to_body(
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125

View File

@ -72,7 +72,9 @@ def make_robot_from_config(config: RobotConfig):
return ManipulatorRobot(config)
elif isinstance(config, LeKiwiRobotConfig):
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
from lerobot.common.robot_devices.robots.mobile_manipulator import (
MobileManipulator,
)
return MobileManipulator(config)
else:

View File

@ -69,7 +69,9 @@ class HubMixin:
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return self.push_to_hub(
repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs
)
return None
def _save_pretrained(self, save_directory: Path) -> None:
@ -175,7 +177,9 @@ class HubMixin:
The url of the commit of your object in the given repository.
"""
api = HfApi(token=token)
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
repo_id = api.create_repo(
repo_id=repo_id, private=private, exist_ok=True
).repo_id
if commit_message is None:
if "Policy" in self.__class__.__name__:

View File

@ -20,7 +20,16 @@ from typing import TypeVar
import imageio
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
JsonLike = (
str
| int
| float
| bool
| None
| list["JsonLike"]
| dict[str, "JsonLike"]
| tuple["JsonLike", ...]
)
T = TypeVar("T", bound=JsonLike)
@ -76,7 +85,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# Check length
if len(target) != len(source):
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
raise ValueError(
f"List length mismatch: expected {len(target)}, got {len(source)}"
)
# Recursively update each element.
for i in range(len(target)):
@ -88,10 +99,14 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# which we'll convert back to a tuple.
elif isinstance(target, tuple):
if not isinstance(source, list):
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
raise TypeError(
f"Type mismatch: expected list (for tuple), got {type(source)}"
)
if len(target) != len(source):
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
raise ValueError(
f"Tuple length mismatch: expected {len(target)}, got {len(source)}"
)
# Convert each element, forming a new tuple.
converted_items = []
@ -105,7 +120,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
else:
# Check the exact type. If these must match 1:1, do:
if type(target) is not type(source):
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
raise TypeError(
f"Type mismatch: expected {type(target)}, got {type(source)}"
)
return source
# Perform the in-place/recursive deserialization

View File

@ -107,13 +107,17 @@ class MetricsTracker:
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
def __getattr__(
self, name: str
) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__:
return self.__dict__[name]
elif name in self.metrics:
return self.metrics[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def __setattr__(self, name: str, value: Any) -> None:
if name in self.__dict__:
@ -121,7 +125,9 @@ class MetricsTracker:
elif name in self.metrics:
self.metrics[name].update(value)
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def step(self) -> None:
"""

View File

@ -42,7 +42,11 @@ def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> Non
"""
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
"""
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
py_state = (
rng_state_dict["py_rng_version"].item(),
tuple(rng_state_dict["py_rng_state"].tolist()),
None,
)
random.setstate(py_state)
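The docstring above names `serialize_python_rng_state()`; here is a plausible counterpart consistent with this deserializer (a sketch, the real implementation is not shown in this hunk):

import random

import torch

def serialize_python_rng_state() -> dict[str, torch.Tensor]:
    # random.getstate() returns (version, state_words, gauss_next);
    # gauss_next is dropped here and restored as None above.
    version, state_words, _ = random.getstate()
    return {
        "py_rng_version": torch.tensor(version),
        "py_rng_state": torch.tensor(state_words),
    }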
@ -119,7 +123,9 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
torch_rng_state_dict = {
k: v for k, v in rng_state_dict.items() if k.startswith("torch")
}
deserialize_python_rng_state(py_rng_state_dict)
deserialize_numpy_rng_state(np_rng_state_dict)

View File

@ -48,7 +48,9 @@ def auto_select_torch_device() -> torch.device:
logging.info("Metal backend detected, using mps.")
return torch.device("mps")
else:
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
logging.warning(
"No accelerated backend detected. Using default cpu, this will be slow."
)
return torch.device("cpu")
@ -96,7 +98,9 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu":
return True
else:
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
raise ValueError(
f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu."
)
def is_amp_available(device: str):
@ -219,7 +223,9 @@ def say(text, blocking=False):
if blocking:
subprocess.run(cmd, check=True)
else:
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
subprocess.Popen(
cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0
)
def log_say(text, play_sounds, blocking=False):

View File

@ -26,7 +26,9 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR
from lerobot.configs.train import TrainPipelineConfig
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
def cfg_to_group(
cfg: TrainPipelineConfig, return_list: bool = False
) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
@ -92,7 +94,9 @@ class WandBLogger:
resume="must" if cfg.resume else None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
self._wandb = wandb
def log_policy(self, checkpoint_dir: Path):
@ -104,7 +108,9 @@ class WandBLogger:
artifact_name = f"{self._group}-{step_id}"
artifact_name = get_safe_wandb_artifact_name(artifact_name)
artifact = self._wandb.Artifact(artifact_name, type="model")
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
artifact.add_file(
checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE
)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"):

View File

@ -33,7 +33,9 @@ class DatasetConfig:
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
image_transforms: ImageTransformsConfig = field(
default_factory=ImageTransformsConfig
)
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)

View File

@ -40,7 +40,9 @@ class EvalPipelineConfig:
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=cli_overrides
)
self.policy.pretrained_path = policy_path
else:

View File

@ -29,7 +29,9 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
draccus.set_config_type("json")
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
def get_cli_overrides(
field_name: str, args: Sequence[str] | None = None
) -> list[str] | None:
"""Parses arguments from cli at a given nested attribute level.
For example, supposing the main script was called with:
@ -42,7 +44,10 @@ def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> lis
args = sys.argv[1:]
attr_level_args = []
detect_string = f"--{field_name}."
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
exclude_strings = (
f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
f"--{field_name}.{PATH_KEY}=",
)
for arg in args:
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
denested_arg = f"--{arg.removeprefix(detect_string)}"
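A made-up invocation illustrating the filtering above, assuming `PATH_KEY == "path"` and `draccus.CHOICE_TYPE_KEY == "type"`:

args = [
    "--policy.path=lerobot/pi0",       # excluded by the PATH_KEY pattern
    "--policy.type=pi0",               # excluded by the CHOICE_TYPE_KEY pattern
    "--policy.device=cuda",
    "--dataset.repo_id=lerobot/pusht", # different field, not matched
]
get_cli_overrides("policy", args)
# -> ["--device=cuda"]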
@ -153,7 +158,9 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
def filter_path_args(
fields_to_filter: str | list[str], args: Sequence[str] | None = None
) -> list[str]:
"""
Filters command-line arguments related to fields with specific path arguments.
@ -181,7 +188,9 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
argument=None,
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
)
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
filtered_args = [
arg for arg in filtered_args if not arg.startswith(f"--{field}.")
]
return filtered_args
@ -213,7 +222,9 @@ def wrap(config_path: Path | None = None):
load_plugin(plugin_path)
except PluginLoadError as e:
# add the relevant CLI arg to the error message
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
raise PluginLoadError(
f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}"
) from e
cli_args = filter_arg(plugin_cli_arg, cli_args)
config_path_cli = parse_arg("config_path", cli_args)
if has_method(argtype, "__get_path_fields__"):
@ -223,7 +234,9 @@ def wrap(config_path: Path | None = None):
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
cfg = draccus.parse(
config_class=argtype, config_path=config_path, args=cli_args
)
response = fn(cfg, *args, **kwargs)
return response
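
For context, the override extraction above rewrites nested CLI flags so they can be re-parsed at the attribute level. A rough standalone sketch of that prefix handling, simplified and omitting the draccus type/path exclusions shown in the diff:

def get_overrides(field_name: str, args: list[str]) -> list[str]:
    # "--policy.device=cuda" is forwarded to the nested parser as "--device=cuda".
    detect = f"--{field_name}."
    return [f"--{arg.removeprefix(detect)}" for arg in args if arg.startswith(detect)]

print(get_overrides("policy", ["--policy.device=cuda", "--env.type=pusht"]))
# ['--device=cuda']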

View File

@ -26,7 +26,11 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.common.utils.utils import (
auto_select_torch_device,
is_amp_available,
is_torch_device_available,
)
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Generic variable that is either PreTrainedConfig or a subclass thereof
@ -64,7 +68,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
logging.warning(
f"Device '{self.device}' is not available. Switching to '{auto_device}'."
)
self.device = auto_device.type
# Automatically deactivate AMP if necessary
@ -118,7 +124,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property
def image_features(self) -> dict[str, PolicyFeature]:
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
return {
key: ft
for key, ft in self.input_features.items()
if ft.type is FeatureType.VISUAL
}
@property
def action_feature(self) -> PolicyFeature | None:
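
The fallback above delegates to `auto_select_torch_device`; presumably it walks the usual accelerator preference order. A hedged re-sketch of that behavior, not the actual lerobot implementation:

import torch

def pick_device() -> torch.device:
    # Illustrative preference order: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

print(pick_device().type)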

View File

@ -73,7 +73,9 @@ class TrainPipelineConfig(HubMixin):
if policy_path:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=cli_overrides
)
self.policy.pretrained_path = policy_path
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
@ -97,7 +99,11 @@ class TrainPipelineConfig(HubMixin):
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
if (
not self.resume
and isinstance(self.output_dir, Path)
and self.output_dir.is_dir()
):
raise FileExistsError(
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
f"Please change your output directory so that {self.output_dir} is not overwritten."
@ -108,10 +114,16 @@ class TrainPipelineConfig(HubMixin):
self.output_dir = Path("outputs/train") / train_dir
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
raise NotImplementedError(
"LeRobotMultiDataset is not currently implemented."
)
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
if not self.use_policy_training_preset and (
self.optimizer is None or self.scheduler is None
):
raise ValueError(
"Optimizer and Scheduler must be set when the policy presets are not used."
)
elif self.use_policy_training_preset and not self.resume:
self.optimizer = self.policy.get_optimizer_preset()
self.scheduler = self.policy.get_scheduler_preset()
@ -125,7 +137,10 @@ class TrainPipelineConfig(HubMixin):
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
with (
open(save_directory / TRAIN_CONFIG_NAME, "w") as f,
draccus.config_type("json"),
):
draccus.dump(self, f, indent=4)
@classmethod
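
The `_save_pretrained` rewrite uses parenthesized context managers, a Python 3.10+ form in which both managers enter and exit as a unit. A self-contained sketch of the same shape, with an illustrative file name:

import json
from contextlib import nullcontext

with (
    open("train_config.json", "w") as f,  # any second manager fits the same form
    nullcontext(),
):
    json.dump({"steps": 100}, f, indent=4)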

View File

@ -38,7 +38,12 @@ def get_motor_bus_cls(brand: str) -> tuple:
FeetechMotorsBus,
)
return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE
return (
FeetechMotorsBusConfig,
FeetechMotorsBus,
MODEL_BAUDRATE_TABLE,
SCS_SERIES_BAUDRATE_TABLE,
)
elif brand == "dynamixel":
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
@ -48,7 +53,12 @@ def get_motor_bus_cls(brand: str) -> tuple:
DynamixelMotorsBus,
)
return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE
return (
DynamixelMotorsBusConfig,
DynamixelMotorsBus,
MODEL_BAUDRATE_TABLE,
X_SERIES_BAUDRATE_TABLE,
)
else:
raise ValueError(
@ -57,8 +67,8 @@ def get_motor_bus_cls(brand: str) -> tuple:
def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls(
brand
motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = (
get_motor_bus_cls(brand)
)
# Check if the provided model exists in the model_baud_rate_table
@ -72,7 +82,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument
motor_model = model # Use the motor model passed via argument
config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})
config = motor_bus_config_cls(
port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}
)
# Initialize the MotorBus with the correct port and motor configurations
motor_bus = motor_bus_cls(config=config)
@ -139,8 +151,12 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
print(f"Setting its index to desired index {motor_idx_des}")
if brand == "feetech":
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Lock", 0
)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "ID", motor_idx_des
)
present_idx = motor_bus.read_with_motor_ids(
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2

View File

@ -156,7 +156,6 @@ from lerobot.common.robot_devices.control_utils import (
log_control_info,
record_episode,
reset_environment,
reset_follower_position,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
stop_recording,
@ -251,7 +250,8 @@ def record(
if len(robot.cameras) > 0:
dataset.start_image_writer(
num_processes=cfg.num_image_writer_processes,
num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
num_threads=cfg.num_image_writer_threads_per_camera
* len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
else:
@ -264,14 +264,19 @@ def record(
robot=robot,
use_videos=cfg.video,
image_writer_processes=cfg.num_image_writer_processes,
image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
image_writer_threads=cfg.num_image_writer_threads_per_camera
* len(robot.cameras),
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
policy = (
None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
policy = (
None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
)
if not robot.is_connected:
robot.connect()
@ -286,7 +291,14 @@ def record(
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
log_say("Warmup record", cfg.play_sounds)
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
warmup_record(
robot,
events,
enable_teleoperation,
cfg.warmup_time_s,
cfg.display_cameras,
cfg.fps,
)
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()

View File

@ -262,7 +262,11 @@ def record(
shape = env.observation_space[key].shape
if not key.startswith("observation.image."):
key = "observation.image." + key
features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape}
features[key] = {
"dtype": "video",
"names": ["channels", "height", "width"],
"shape": shape,
}
for key, obs_key in state_keys_dict.items():
features[key] = {

View File

@ -152,7 +152,8 @@ def rollout(
all_observations.append(deepcopy(observation))
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
key: observation[key].to(device, non_blocking=device.type == "cuda")
for key in observation
}
with torch.inference_mode():
@ -511,10 +512,14 @@ def eval_main(cfg: EvalPipelineConfig):
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
)
logging.info("Making environment.")
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
env = make_env(
cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs
)
logging.info("Making policy.")
@ -524,7 +529,12 @@ def eval_main(cfg: EvalPipelineConfig):
)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
with (
torch.no_grad(),
torch.autocast(device_type=device.type)
if cfg.policy.use_amp
else nullcontext(),
):
info = eval_policy(
env,
policy,

View File

@ -480,7 +480,7 @@ def test_forward_kinematics(robot, fps=10):
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
logging.info(f"EE Position: {ee_pos[:3,3]}")
logging.info(f"EE Position: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@ -519,7 +519,7 @@ def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Leader EE: {leader_ee[:3,3]}, Follower EE: {ee_pos[:3,3]}")
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@ -574,9 +574,9 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
# Logging
logging.info(
f"Current EE: {current_ee_pos[:3,3]}, Desired EE: {desired_ee_pos[:3,3]}"
f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}"
)
logging.info(f"Delta EE: {ee_delta[:3,3]}")
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
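
On the `[:3, 3]` indexing whose spacing is fixed above: the forward-kinematics helpers return 4x4 homogeneous transforms, where the rotation sits in `[:3, :3]` and the translation, the logged EE position, in `[:3, 3]`. For example:

import numpy as np

T = np.eye(4)  # a homogeneous end-effector pose
T[:3, 3] = [0.1, 0.2, 0.3]  # translation column
print(T[:3, 3])   # [0.1 0.2 0.3] -- what the logs above print
print(T[:3, :3])  # 3x3 rotation block (identity here)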

View File

@ -1087,9 +1087,9 @@ class GamepadControlWrapper(gym.Wrapper):
class ActionScaleWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None):
super().__init__(env)
assert (
ee_action_space_params is not None
), "TODO: method implemented for ee action space only so far"
assert ee_action_space_params is not None, (
"TODO: method implemented for ee action space only so far"
)
self.scale_vector = np.array(
[
[
@ -1546,4 +1546,4 @@ if __name__ == "__main__":
busy_wait(1 / args.fps - dt_s)
logging.info(f"Success after 20 steps {sucesses}")
logging.info(f"success rate {sum(sucesses)/ len(sucesses)}")
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")

View File

@ -41,7 +41,7 @@ def send_bytes_in_chunks(
logging_method = logging.info if not silent else logging.debug
logging_method(f"{log_prefix} Buffer size {size_in_bytes/1024/1024} MB with")
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
@ -60,7 +60,7 @@ def send_bytes_in_chunks(
f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
)
logging_method(f"{log_prefix} Published {sent_bytes/1024/1024} MB")
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(
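
The spacing fix above also highlights the unit conversion: dividing by 1024 twice yields mebibytes, though the log labels them MB. For instance:

size_in_bytes = 5 * 1024 * 1024
print(f"Buffer size {size_in_bytes / 1024 / 1024} MB")  # Buffer size 5.0 MB (strictly MiB)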

View File

@ -223,14 +223,18 @@ def train(cfg: TrainPipelineConfig):
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
step, optimizer, lr_scheduler = load_training_state(
cfg.checkpoint_path, optimizer, lr_scheduler
)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
)
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
@ -273,7 +277,11 @@ def train(cfg: TrainPipelineConfig):
}
train_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
train_metrics,
initial_step=step,
)
logging.info("Start offline training on a fixed dataset")
@ -327,7 +335,9 @@ def train(cfg: TrainPipelineConfig):
logging.info(f"Eval policy at step {step}")
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
torch.autocast(device_type=device.type)
if cfg.policy.use_amp
else nullcontext(),
):
eval_info = eval_policy(
eval_env,
@ -344,7 +354,11 @@ def train(cfg: TrainPipelineConfig):
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
eval_metrics,
initial_step=step,
)
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")

View File

@ -426,7 +426,7 @@ def train(
# Training loop with validation and checkpointing
for epoch in range(cfg.training.num_epochs):
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
logging.info(f"\nEpoch {epoch + 1}/{cfg.training.num_epochs}")
train_epoch(
model,
@ -470,7 +470,7 @@ def train(
policy=model,
optimizer=optimizer,
scheduler=None,
identifier=f"{epoch+1:06d}",
identifier=f"{epoch + 1:06d}",
)
step += len(train_loader)
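
The `{epoch + 1:06d}` identifier above zero-pads to six digits so checkpoint directories sort lexicographically in step order:

for epoch in (0, 9, 41):
    print(f"{epoch + 1:06d}")  # 000001, 000010, 000042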

View File

@ -94,9 +94,9 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
assert (
c < h and c < w
), f"expect channel first images, but instead {chw_float32_torch.shape}"
assert c < h and c < w, (
f"expect channel first images, but instead {chw_float32_torch.shape}"
)
hwc_uint8_numpy = (
(chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
)
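
The conversion being reformatted here goes from channel-first float32 in [0, 1] to the channel-last uint8 layout most image libraries expect. A runnable miniature:

import torch

chw = torch.rand(3, 480, 640)  # channel-first, float32 in [0, 1]
hwc = (chw * 255).type(torch.uint8).permute(1, 2, 0).numpy()
print(hwc.shape, hwc.dtype)  # (480, 640, 3) uint8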

View File

@ -158,7 +158,9 @@ def run_server(
400,
)
dataset_version = (
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
str(dataset.meta._version)
if isinstance(dataset, LeRobotDataset)
else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
@ -166,7 +168,9 @@ def run_server(
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
episode_data_csv_str, columns, ignored_columns = get_episode_data(
dataset, episode_id
)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
@ -208,7 +212,8 @@ def run_server(
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl",
timeout=5,
)
response.raise_for_status()
# Split into lines and parse each line as JSON
@ -256,7 +261,11 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns = [
col
for col, ft in dataset.features.items()
if ft["dtype"] in ["float32", "int32"]
]
selected_columns.remove("timestamp")
ignored_columns = []
@ -361,7 +370,8 @@ def get_episode_language_instruction(
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json",
timeout=5,
)
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()

View File

@ -47,7 +47,9 @@ OUTPUT_DIR = Path("outputs/image_transforms")
to_pil = ToPILImage()
def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
def save_all_transforms(
cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
):
output_dir_all = output_dir / "all"
output_dir_all.mkdir(parents=True, exist_ok=True)
@ -60,7 +62,9 @@ def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir,
print(f" {output_dir_all}")
def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
def save_each_transform(
cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
):
if not cfg.enable:
logging.warning(
"No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
@ -89,9 +93,15 @@ def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir,
tf_cfg_kwgs_max[key] = [max_, max_]
tf_cfg_kwgs_avg[key] = [avg, avg]
tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}))
tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}))
tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}))
tf_min = make_transform_from_config(
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min})
)
tf_max = make_transform_from_config(
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max})
)
tf_avg = make_transform_from_config(
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg})
)
tf_frame_min = tf_min(original_frame)
tf_frame_max = tf_max(original_frame)
@ -105,7 +115,9 @@ def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir,
@draccus.wrap()
def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
def visualize_image_transforms(
cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5
):
dataset = LeRobotDataset(
repo_id=cfg.repo_id,
episodes=cfg.episodes,

View File

@ -51,7 +51,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
batch = next(iter(dataloader))
loss, output_dict = policy.forward(batch)
if output_dict is not None:
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict = {
k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)
}
output_dict["loss"] = loss
else:
output_dict = {"loss": loss}
@ -69,7 +71,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
param_stats = {}
for key, param in policy.named_parameters():
param_stats[f"{key}_mean"] = param.mean()
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
param_stats[f"{key}_std"] = (
param.std() if param.numel() > 1 else torch.tensor(float(0.0))
)
optimizer.zero_grad()
policy.reset()
@ -96,11 +100,15 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
else:
actions_queue = train_cfg.policy.n_action_repeats
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
actions = {
str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)
}
return output_dict, grad_stats, param_stats, actions
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
def save_policy_to_safetensors(
output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict
):
if output_dir.exists():
print(f"Overwrite existing safetensors in '{output_dir}':")
print(f" - Validate with: `git add {output_dir}`")
@ -108,7 +116,9 @@ def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: s
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
output_dict, grad_stats, param_stats, actions = get_policy_stats(
ds_repo_id, policy_name, policy_kwargs
)
save_file(output_dict, output_dir / "output_dict.safetensors")
save_file(grad_stats, output_dir / "grad_stats.safetensors")
save_file(param_stats, output_dir / "param_stats.safetensors")
@ -141,5 +151,7 @@ if __name__ == "__main__":
raise RuntimeError("No policies were provided!")
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
ds_name = ds_repo_id.split("/")[-1]
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
output_dir = (
Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
)
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
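
The artifact files written here use safetensors, which stores a flat mapping of names to tensors. A minimal round trip, with illustrative paths:

import torch
from safetensors.torch import load_file, save_file

stats = {"loss": torch.tensor(0.123), "grad_norm": torch.tensor(1.5)}
save_file(stats, "output_dict.safetensors")
print(load_file("output_dict.safetensors")["loss"])  # tensor(0.1230)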

View File

@ -226,7 +226,13 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
@require_camera
def test_camera_rotation(request, camera_type, mock):
config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30}
config_kwargs = {
"camera_type": camera_type,
"mock": mock,
"width": 640,
"height": 480,
"fps": 30,
}
# No rotation.
camera = make_camera(**config_kwargs, rotation=None)

View File

@ -9,7 +9,9 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
def create_plugin_code(
*, base_class: str = "EnvConfig", plugin_name: str = "test_env"
) -> str:
"""Creates a dummy plugin module that implements its own EnvConfig subclass."""
return f"""
from dataclasses import dataclass

View File

@ -31,7 +31,11 @@ from lerobot.common.datasets.compute_stats import (
def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
return (
np.ones((3, 32, 32), dtype=dtype)
if channel_first
else np.ones((32, 32, 3), dtype=dtype)
)
@pytest.fixture
@ -61,7 +65,10 @@ def test_sample_indices():
assert len(indices) == estimate_num_samples(10)
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
@patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
side_effect=mock_load_image_as_numpy,
)
def test_sample_images(mock_load):
image_paths = [f"image_{i}.jpg" for i in range(100)]
images = sample_images(image_paths)
@ -74,9 +81,20 @@ def test_sample_images(mock_load):
def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
assert (
"min" in stats
and "max" in stats
and "mean" in stats
and "std" in stats
and "count" in stats
)
np.testing.assert_equal(stats["count"], np.array([100]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
assert (
stats["min"].shape
== stats["max"].shape
== stats["mean"].shape
== stats["std"].shape
)
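
As exercised by this test, reducing over `axis=(0, 2, 3)` with `keepdims=True` leaves one statistic per channel in a broadcast-ready shape:

import numpy as np

data = np.random.rand(100, 3, 32, 32)  # batch of CHW images
mean = data.mean(axis=(0, 2, 3), keepdims=True)
print(mean.shape)  # (1, 3, 1, 1) -- broadcasts cleanly for normalization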
def test_get_feature_stats_axis_0_keepdims(sample_array):
@ -145,7 +163,8 @@ def test_compute_episode_stats():
}
with patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
side_effect=mock_load_image_as_numpy,
):
stats = compute_episode_stats(episode_data, features)
@ -233,7 +252,13 @@ def test_aggregate_stats():
"std": [2.87, 5.87, 8.87],
"count": 10,
},
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
"observation.state": {
"min": 1,
"max": 10,
"mean": 5.5,
"std": 2.87,
"count": 10,
},
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
},
{
@ -244,7 +269,13 @@ def test_aggregate_stats():
"std": [3.42, 2.42, 1.42],
"count": 15,
},
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
"observation.state": {
"min": 2,
"max": 15,
"mean": 8.5,
"std": 3.42,
"count": 15,
},
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
},
]
@ -284,28 +315,47 @@ def test_aggregate_stats():
for ep_stats in all_stats:
for fkey, stats in ep_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
stats[k] = np.array(
stats[k], dtype=np.int64 if k == "count" else np.float32
)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
stats[k] = stats[k].reshape(
3, 1, 1
) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
# cast to numpy
for fkey, stats in expected_agg_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
stats[k] = np.array(
stats[k], dtype=np.int64 if k == "count" else np.float32
)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
stats[k] = stats[k].reshape(
3, 1, 1
) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
results = aggregate_stats(all_stats)
for fkey in expected_agg_stats:
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
np.testing.assert_allclose(
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
results[fkey]["min"], expected_agg_stats[fkey]["min"]
)
np.testing.assert_allclose(
results[fkey]["max"], expected_agg_stats[fkey]["max"]
)
np.testing.assert_allclose(
results[fkey]["mean"], expected_agg_stats[fkey]["mean"]
)
np.testing.assert_allclose(
results[fkey]["std"],
expected_agg_stats[fkey]["std"],
atol=1e-04,
rtol=1e-04,
)
np.testing.assert_allclose(
results[fkey]["count"], expected_agg_stats[fkey]["count"]
)
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])

View File

@ -72,7 +72,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
dataset_create = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
@ -104,7 +106,8 @@ def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
ValueError,
match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n",
):
dataset.add_frame({"state": torch.randn(1)})
@ -113,7 +116,8 @@ def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
ValueError,
match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n",
):
dataset.add_frame({"task": "Dummy task"})
@ -122,18 +126,24 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
ValueError,
match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n",
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
dataset.add_frame(
{"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}
)
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
ValueError,
match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n",
):
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
dataset.add_frame(
{"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}
)
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
@ -141,7 +151,9 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
match=re.escape(
"The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"
),
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
@ -163,7 +175,9 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
match=re.escape(
"The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"
),
):
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
@ -457,7 +471,9 @@ def test_flatten_unflatten_dict():
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), (
f"{original_d} != {d}"
)
@pytest.mark.parametrize(
@ -511,7 +527,13 @@ def test_backward_compatibility(repo_id):
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(
dataset.episode_data_index["to"][0].item()
- dataset.episode_data_index["from"][0].item()
)
/ 2
)
load_and_compare(i)
load_and_compare(i + 1)

View File

@ -54,7 +54,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
@pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
def _create_synced_timestamps(
fps: int = 30,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
@ -69,8 +71,12 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
def _create_unsynced_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
fps=fps
)
timestamps[30] += (
tolerance_s * 1.1
) # Modify a single timestamp just outside tolerance
return timestamps, episode_indices, episode_data_index
return _create_unsynced_timestamps
@ -81,8 +87,12 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
def _create_slightly_off_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
fps=fps
)
timestamps[30] += (
tolerance_s * 0.9
) # Modify a single timestamp just inside tolerance
return timestamps, episode_indices, episode_data_index
return _create_slightly_off_timestamps
@ -91,9 +101,13 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
fps: int = 30,
keys: list = DUMMY_MOTOR_FEATURES,
min_max_range: tuple[int, int] = (-10, 10),
) -> dict:
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
delta_timestamps = {
key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys
}
return delta_timestamps
return _create_valid_delta_timestamps
@ -130,7 +144,9 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
@pytest.fixture(scope="module")
def delta_indices_factory():
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
def _delta_indices(
keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
) -> dict:
return {key: list(range(*min_max_range)) for key in keys}
return _delta_indices
@ -182,7 +198,9 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(
fps, tolerance_s
)
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
@ -223,7 +241,9 @@ def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(
fps, tolerance_s
)
result = check_delta_timestamps(
delta_timestamps=slightly_off_delta_timestamps,
fps=fps,

View File

@ -33,7 +33,9 @@ from lerobot.scripts.visualize_image_transforms import (
save_all_transforms,
save_each_transform,
)
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import (
ARTIFACT_DIR,
)
from tests.utils import require_x86_64_kernel
@ -80,7 +82,11 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})},
tfs={
"brightness": ImageTransformConfig(
type="ColorJitter", kwargs={"brightness": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(brightness=min_max)
@ -91,7 +97,12 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
enable=True,
tfs={
"contrast": ImageTransformConfig(
type="ColorJitter", kwargs={"contrast": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(contrast=min_max)
@ -103,7 +114,11 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})},
tfs={
"saturation": ImageTransformConfig(
type="ColorJitter", kwargs={"saturation": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(saturation=min_max)
@ -114,7 +129,8 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
def test_get_image_transforms_hue(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
enable=True,
tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(hue=min_max)
@ -126,7 +142,11 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})},
tfs={
"sharpness": ImageTransformConfig(
type="SharpnessJitter", kwargs={"sharpness": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = SharpnessJitter(sharpness=min_max)
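
These jitter tests all follow the same pattern: a (min, max) range per property, compared against the equivalent torchvision transform. A standalone version of the check, assuming torchvision's v2 API as imported in the test module:

import torch
from torchvision.transforms import v2

img = torch.rand(3, 64, 64)
jitter = v2.ColorJitter(brightness=(0.5, 1.5))  # factor sampled from the range
torch.manual_seed(0)
out = jitter(img)
print(out.shape)  # torch.Size([3, 64, 64])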
@ -342,7 +362,9 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
# Check if the combined transforms directory exists and contains the right files
combined_transforms_dir = tmp_path / "all"
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
assert combined_transforms_dir.exists(), (
"Combined transforms directory was not created."
)
assert any(combined_transforms_dir.iterdir()), (
"No transformed images found in combined transforms directory."
)
@ -364,9 +386,9 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
for transform in transforms:
transform_dir = tmp_path / transform
assert transform_dir.exists(), f"{transform} directory was not created."
assert any(
transform_dir.iterdir()
), f"No transformed images found in {transform} directory."
assert any(transform_dir.iterdir()), (
f"No transformed images found in {transform} directory."
)
# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [

View File

@ -176,7 +176,9 @@ def test_delta_timestamps_within_tolerance():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
torch.testing.assert_close(
data, torch.tensor([0, 2, 3]), msg="Data does not match expected values"
)
assert not is_pad.any(), "Unexpected padding detected"
@ -212,7 +214,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), (
"Data does not match expected values"
)
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
"Padding does not match expected values"
)
@ -275,7 +279,8 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
online_sampling_ratio=online_sampling_ratio,
)
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
)
@ -297,7 +302,8 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
online_drop_n_last_frames=1,
)
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
)
@ -318,4 +324,6 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
online_sampling_ratio=0.5,
online_drop_n_last_frames=1,
)
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
torch.testing.assert_close(
weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])
)

View File

@ -18,8 +18,13 @@ import torch
from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
)
from lerobot.common.datasets.utils import (
create_lerobot_dataset_card,
hf_transform_to_torch,
)
def test_default_parameters():

View File

@ -210,7 +210,10 @@ def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks = {}
for task_index in range(total_tasks):
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
task_dict = {
"task_index": task_index,
"task": f"Perform action {task_index}.",
}
tasks[task_index] = task_dict
return tasks
@ -297,8 +300,12 @@ def hf_dataset_factory(
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes.values():
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
timestamp_col = np.concatenate(
(timestamp_col, np.arange(ep_dict["length"]) / fps)
)
frame_index_col = np.concatenate(
(frame_index_col, np.arange(ep_dict["length"], dtype=int))
)
episode_index_col = np.concatenate(
(
episode_index_col,
@ -385,7 +392,9 @@ def lerobot_dataset_metadata_factory(
episodes=episodes,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
) as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
@ -433,7 +442,9 @@ def lerobot_dataset_factory(
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=total_episodes
)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
@ -466,8 +477,12 @@ def lerobot_dataset_factory(
episodes=episode_dicts,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata"
) as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
) as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,

View File

@ -59,7 +59,9 @@ def stats_path(stats_factory):
@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
def _create_episodes_stats_jsonl_file(
dir: Path, episodes_stats: list[dict] | None = None
) -> Path:
if not episodes_stats:
episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH

View File

@ -99,7 +99,13 @@ def mock_snapshot_download_factory(
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
meta_files = [
INFO_PATH,
STATS_PATH,
EPISODES_STATS_PATH,
TASKS_PATH,
EPISODES_PATH,
]
all_files.extend(meta_files)
data_files = []

View File

@ -35,5 +35,7 @@ def optimizer(model_params):
@pytest.fixture
def scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
config = VQBeTSchedulerConfig(
num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5
)
return config.build(optimizer, num_training_steps=100)
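
The scheduler under test builds down to a `LambdaLR` (see the assertion in the next file). Below is a generic warmup-then-cosine `LambdaLR`, sketched only to show the mechanics; the actual VQBeT schedule also special-cases the VQ-VAE training steps:

import math

import torch
from torch.optim.lr_scheduler import LambdaLR

opt = torch.optim.Adam(torch.nn.Linear(2, 2).parameters(), lr=1e-3)
warmup, total, cycles = 10, 100, 0.5

def lr_lambda(step: int) -> float:
    if step < warmup:
        return step / max(1, warmup)  # linear warmup from 0 to the base lr
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))

scheduler = LambdaLR(opt, lr_lambda)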

View File

@ -43,7 +43,9 @@ def test_diffuser_scheduler(optimizer):
def test_vqbet_scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
config = VQBeTSchedulerConfig(
num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5
)
scheduler = config.build(optimizer, num_training_steps=100)
assert isinstance(scheduler, LambdaLR)

View File

@ -59,16 +59,33 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
"observation.state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
}
info = info_factory(
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
total_episodes=1,
total_frames=1,
camera_features=camera_features,
motor_features=motor_features,
)
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta
@ -81,7 +98,8 @@ def test_get_policy_and_config_classes(policy_name: str):
policy_cfg = make_policy_config(policy_name)
assert policy_cls.name == policy_name
assert issubclass(
policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
policy_cfg.__class__,
inspect.signature(policy_cls.__init__).parameters["config"].annotation,
)
@ -92,7 +110,13 @@ def test_get_policy_and_config_classes(policy_name: str):
("lerobot/pusht", "pusht", {}, "diffusion", {}),
("lerobot/pusht", "pusht", {}, "vqbet", {}),
("lerobot/pusht", "pusht", {}, "act", {}),
("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
{"task": "AlohaInsertion-v0"},
"act",
{},
),
(
"lerobot/aloha_sim_insertion_scripted",
"aloha",
@ -172,11 +196,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(
batch_
), "Batch keys are not the same after a forward pass."
assert set(batch) == set(batch_), (
"Batch keys are not the same after a forward pass."
)
assert all(
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
torch.equal(batch[k], batch_[k])
if isinstance(batch[k], torch.Tensor)
else batch[k] == batch_[k]
for k in batch
), "Batch values are not the same after a forward pass."
@ -215,8 +241,12 @@ def test_act_backbone_lr():
cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
dataset=DatasetConfig(
repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]
),
policy=make_policy_config(
"act", optimizer_lr=0.01, optimizer_lr_backbone=0.001
),
)
cfg.validate() # Needed for auto-setting some parameters
@ -239,7 +269,9 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
@ -251,7 +283,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
@ -260,7 +294,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
torch.testing.assert_close(
list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0
)
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
@ -400,7 +436,9 @@ def test_normalize(insert_temporal_dim):
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
def test_backward_compatibility(
ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str
):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
@ -414,13 +452,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
"""
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
artifact_dir = (
Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
)
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
saved_actions = load_file(artifact_dir / "actions.safetensors")
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
output_dict, grad_stats, param_stats, actions = get_policy_stats(
ds_repo_id, policy_name, policy_kwargs
)
for key in saved_output_dict:
torch.testing.assert_close(output_dict[key], saved_output_dict[key])
@ -429,8 +471,12 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
for key in saved_param_stats:
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
for key in saved_actions:
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
rtol, atol = (
(2e-3, 5e-6) if policy_name == "diffusion" else (None, None)
) # HACK
torch.testing.assert_close(
actions[key], saved_actions[key], rtol=rtol, atol=atol
)
def test_act_temporal_ensembler():

View File

@ -180,7 +180,9 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
replay_cfg = ReplayControlConfig(
episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False
)
replay(robot, replay_cfg)
policy_cfg = ACTConfig()
@ -335,12 +337,12 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
)
dataset = record(robot, rec_cfg)
assert not mock_events[
"rerecord_episode"
], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events[
"exit_early"
], "`exit_early` wasn't properly reset to False"
assert not mock_events["rerecord_episode"], (
"`rerecord_episode` wasn't properly reset to False"
)
assert not mock_events["exit_early"], (
"`exit_early` wasn't properly reset to False"
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@ -390,7 +392,9 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
dataset = record(robot, rec_cfg)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert not mock_events["exit_early"], (
"`exit_early` wasn't properly reset to False"
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@ -399,7 +403,9 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
[("koch", True, 0), ("koch", True, 1)],
)
@require_robot
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):
def test_record_with_event_stop_recording(
tmp_path, request, robot_type, mock, num_image_writer_processes
):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock:
@ -445,5 +451,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n
dataset = record(robot, rec_cfg)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert not mock_events["exit_early"], (
"`exit_early` wasn't properly reset to False"
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"

View File

@ -40,7 +40,10 @@ import pytest
import torch
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@ -131,7 +134,9 @@ def test_robot(tmp_path, request, robot_type, mock):
if "image" in name:
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
continue
torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
torch.testing.assert_close(
captured_observation[name], observation[name], rtol=1e-4, atol=1
)
assert captured_observation[name].shape == observation[name].shape
# Test send_action can run

View File

@ -227,9 +227,9 @@ def test_resume_function(
config_dir = os.path.abspath(
os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")
)
assert os.path.exists(
config_dir
), f"Config directory does not exist at {config_dir}"
assert os.path.exists(config_dir), (
f"Config directory does not exist at {config_dir}"
)
with initialize_config_dir(
config_dir=config_dir, job_name="test_app", version_base="1.2"

View File

@ -26,10 +26,16 @@ from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
from lerobot.common.robot_devices.motors.utils import (
make_motors_bus as make_motors_bus_device,
)
from lerobot.common.utils.import_utils import is_package_available
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
DEVICE = (
os.environ.get("LEROBOT_TEST_DEVICE", "cuda")
if torch.cuda.is_available()
else "cpu"
)
TEST_ROBOT_TYPES = []
for robot_type in available_robots:
@ -45,7 +51,9 @@ for motor_type in available_motors:
# Camera indices used for connecting physical cameras
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
INTELREALSENSE_SERIAL_NUMBER = int(
os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)
)
DYNAMIXEL_PORT = os.environ.get(
"LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081"

View File

@ -18,7 +18,10 @@ from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
@pytest.fixture
def mock_metrics():
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
return {
"loss": AverageMeter("loss", ":.3f"),
"accuracy": AverageMeter("accuracy", ":.2f"),
}
def test_average_meter_initialization():
@ -58,7 +61,11 @@ def test_average_meter_str():
def test_metrics_tracker_initialization(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=10,
)
assert tracker.steps == 10
assert tracker.samples == 10 * 32
@ -70,7 +77,11 @@ def test_metrics_tracker_initialization(mock_metrics):
def test_metrics_tracker_step(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=5,
)
tracker.step()
assert tracker.steps == 6
@ -80,7 +91,9 @@ def test_metrics_tracker_step(mock_metrics):
def test_metrics_tracker_getattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
assert tracker.loss == mock_metrics["loss"]
assert tracker.accuracy == mock_metrics["accuracy"]
with pytest.raises(AttributeError):
@ -88,13 +101,17 @@ def test_metrics_tracker_getattr(mock_metrics):
def test_metrics_tracker_setattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss = 2.0
assert tracker.loss.val == 2.0
def test_metrics_tracker_str(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss.update(3.456, 1)
tracker.accuracy.update(0.876, 1)
output = str(tracker)
@ -103,7 +120,9 @@ def test_metrics_tracker_str(mock_metrics):
def test_metrics_tracker_to_dict(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss.update(5, 2)
metrics_dict = tracker.to_dict()
assert isinstance(metrics_dict, dict)
@ -112,7 +131,9 @@ def test_metrics_tracker_to_dict(mock_metrics):
def test_metrics_tracker_reset_averages(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss.update(10, 3)
tracker.accuracy.update(0.95, 5)
tracker.reset_averages()
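
Taken together, the tests above pin down the MetricsTracker API. A usage sketch assembled strictly from the behavior they exercise, with no calls beyond what the tests show:

    from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker

    metrics = {"loss": AverageMeter("loss", ":.3f")}
    tracker = MetricsTracker(
        batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics, initial_step=0
    )
    tracker.loss = 2.0            # attribute writes update the named AverageMeter
    tracker.step()                # steps -> 1, samples -> 1 * 32
    snapshot = tracker.to_dict()  # plain dict, convenient for loggers
    tracker.reset_averages()      # clears the running averages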

View File

@ -118,5 +118,9 @@ def test_seeded_context(fixed_seed):
seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item())
assert seeded_val1 == seeded_val2
assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context
assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting
assert all(
a != b for a, b in zip(val1, seeded_val1, strict=True)
) # changed inside the context
assert all(
a != b for a, b in zip(val2, seeded_val2, strict=True)
) # changed again after exiting
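
The test above relies on reseeding reproducing identical draws across all three RNGs. A minimal sketch of that property, using a hypothetical helper (the repo's own seeded_context utility may do more, such as managing RNG state around the block):

    import random

    import numpy as np
    import torch

    def set_all_seeds(seed: int) -> None:
        # Hypothetical helper, not the repo's implementation.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    set_all_seeds(1337)
    first = (random.random(), np.random.rand(), torch.rand(1).item())
    set_all_seeds(1337)
    second = (random.random(), np.random.rand(), torch.rand(1).item())
    assert first == second  # identical draws after reseeding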

View File

@ -91,7 +91,9 @@ def test_save_training_state(tmp_path, optimizer, scheduler):
def test_save_load_training_state(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler)
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler)
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(
tmp_path, optimizer, scheduler
)
assert loaded_step == 10
assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler
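
The contract checked above — the returned step matches what was saved, and the optimizer and scheduler are mutated in place and returned as the same objects — can be met with a plain torch checkpoint. A hypothetical sketch, not the repo's actual implementation (file name and layout assumed):

    import torch

    def save_training_state(dirpath, step, optimizer, scheduler):
        torch.save(
            {
                "step": step,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            f"{dirpath}/training_state.pt",
        )

    def load_training_state(dirpath, optimizer, scheduler):
        state = torch.load(f"{dirpath}/training_state.pt")
        optimizer.load_state_dict(state["optimizer"])  # mutate in place...
        scheduler.load_state_dict(state["scheduler"])
        return state["step"], optimizer, scheduler     # ...and return the same objects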