[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent 313812df16
commit ce0008850d
@@ -22,7 +22,10 @@ from pathlib import Path
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

@@ -48,12 +51,18 @@ def main():
# - dataset stats: for normalization and denormalization of input/outputs
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
input_features = {
key: ft for key, ft in features.items() if key not in output_features
}
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
cfg = DiffusionConfig(
input_features=input_features, output_features=output_features
)
# We can now instantiate our policy with this config and the dataset stats.
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)

@@ -63,8 +72,12 @@ def main():
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
# which can differ for inputs, outputs and rewards (if there are some).
delta_timestamps = {
"observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
"observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
"observation.image": [
i / dataset_metadata.fps for i in cfg.observation_delta_indices
],
"observation.state": [
i / dataset_metadata.fps for i in cfg.observation_delta_indices
],
"action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
}

@@ -77,7 +90,24 @@ def main():
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to supervise the policy.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
}
# We can then instantiate the dataset with these delta_timestamps configuration.

@@ -99,7 +129,10 @@ def main():
done = False
while not done:
for batch in dataloader:
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
batch = {
k: (v.to(device) if isinstance(v, torch.Tensor) else v)
for k, v in batch.items()
}
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
@@ -54,7 +54,24 @@ def main():
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to calculate the loss.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
}
# Load the last 10% of episodes of the dataset as a validation set.

@@ -73,7 +90,9 @@ def main():
train_dataset = LeRobotDataset(
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
val_dataset = LeRobotDataset(
"lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps
)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
@@ -19,7 +19,10 @@ from lerobot.common.datasets.utils import load_image_as_numpy
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
dataset_len: int,
min_num_samples: int = 100,
max_num_samples: int = 10_000,
power: float = 0.75,
) -> int:
"""Heuristic to estimate the number of samples based on dataset size.
The power controls the sample growth relative to dataset size.

@@ -43,14 +46,18 @@ def sample_indices(data_len: int) -> list[int]:
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
def auto_downsample_height_width(
img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
):
_, height, width = img.shape
if max(width, height) < max_size_threshold:
# no downsampling needed
return img
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
downsample_factor = (
int(width / target_size) if width > height else int(height / target_size)
)
return img[:, ::downsample_factor, ::downsample_factor]

@@ -72,7 +79,9 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
def get_feature_stats(
array: np.ndarray, axis: tuple, keepdims: bool
) -> dict[str, np.ndarray]:
return {
"min": np.min(array, axis=axis, keepdims=keepdims),
"max": np.max(array, axis=axis, keepdims=keepdims),

@@ -82,7 +91,9 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st
}
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
def compute_episode_stats(
episode_data: dict[str, list[str] | np.ndarray], features: dict
) -> dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":

@@ -96,12 +107,15 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
axes_to_reduce = 0  # compute stats over the first axis
keepdims = data.ndim == 1  # keep as np.array
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
ep_stats[key] = get_feature_stats(
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims
)
# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v / 255.0, axis=0)
for k, v in ep_stats[key].items()
}
return ep_stats

@@ -116,14 +130,22 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
)
if v.ndim == 0:
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
raise ValueError(
"Number of dimensions must be at least 1, and is 0 instead."
)
if k == "count" and v.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
raise ValueError(
f"Shape of 'count' must be (1), but is {v.shape} instead."
)
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
raise ValueError(
f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead."
)
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
def aggregate_feature_stats(
stats_ft_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])

@@ -152,7 +174,9 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
}
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
def aggregate_stats(
stats_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
The final stats will have the union of all data keys from each of the stats dicts.
@@ -58,7 +58,9 @@ def resolve_delta_timestamps(
if key == "action" and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
delta_timestamps[key] = [
i / ds_meta.fps for i in cfg.observation_delta_indices
]
if len(delta_timestamps) == 0:
delta_timestamps = None

@@ -79,7 +81,9 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
LeRobotDataset | MultiLeRobotDataset
"""
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
ImageTransforms(cfg.dataset.image_transforms)
if cfg.dataset.image_transforms.enable
else None
)
if isinstance(cfg.dataset.repo_id, str):

@@ -113,6 +117,8 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
dataset.meta.stats[key][stats_type] = torch.tensor(
stats, dtype=torch.float32
)
return dataset
@@ -38,10 +38,14 @@ def safe_stop_image_writer(func):
return wrapper
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
def image_array_to_pil_image(
image_array: np.ndarray, range_check: bool = True
) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
raise ValueError(
f"The array has {image_array.ndim} dimensions, but 3 is expected for an image."
)
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
@@ -108,7 +108,9 @@ class LeRobotDatasetMetadata:
self.episodes = load_episodes(self.root)
if self._version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
self.episodes_stats = backward_compatible_episodes_stats(
self.stats, self.episodes
)
else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))

@@ -238,7 +240,9 @@ class LeRobotDatasetMetadata:
Given a task in natural language, add it to the dictionary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
raise ValueError(
f"The task '{task}' already exists and can't be added twice."
)
task_index = self.info["total_tasks"]
self.task_to_task_index[task] = task_index

@@ -281,7 +285,11 @@ class LeRobotDatasetMetadata:
write_episode(episode_dict, self.root)
self.episodes_stats[episode_index] = episode_stats
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
self.stats = (
aggregate_stats([self.stats, episode_stats])
if self.stats
else episode_stats
)
write_episode_stats(episode_index, episode_stats, self.root)
def update_video_info(self) -> None:

@@ -345,13 +353,17 @@ class LeRobotDatasetMetadata:
# as this would break the dict flattening in the stats computation, which uses '/' as separator
for key in features:
if "/" in key:
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
raise ValueError(
f"Feature names should not contain '/'. Found '/' in feature '{key}'."
)
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.task_to_task_index = {}, {}
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, features, use_videos
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)

@@ -482,7 +494,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.video_backend = (
video_backend if video_backend else get_safe_default_codec()
)
self.delta_indices = None
# Unused attributes

@@ -495,28 +509,39 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
if self.episodes is not None and self.meta._version >= packaging.version.parse(
"v2.1"
):
episodes_stats = [
self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes
]
self.stats = aggregate_stats(episodes_stats)
# Load actual data
try:
if force_cache_sync:
raise FileNotFoundError
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
assert all(
(self.root / fpath).is_file()
for fpath in self.get_episodes_file_paths()
)
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
self.episode_data_index = get_episode_data_index(
self.meta.episodes, self.episodes
)
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
check_timestamps_sync(
timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
)
# Setup delta_indices
if self.delta_timestamps is not None:
@@ -568,7 +593,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
hub_api.upload_folder(**upload_kwargs)
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
if not hub_api.file_exists(
self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch
):
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)

@@ -576,8 +603,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
hub_api.delete_tag(
self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset"
)
hub_api.create_tag(
self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
)
def pull_from_repo(
self,

@@ -609,7 +640,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[Path]:
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
episodes = (
self.episodes
if self.episodes is not None
else list(range(self.meta.total_episodes))
)
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self.meta.video_keys) > 0:
video_files = [

@@ -640,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
ft_dict = {col: [] for col in features}
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
hf_dataset = datasets.Dataset.from_dict(
ft_dict, features=features, split="train"
)
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)

@@ -726,7 +763,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
def _query_videos(
self, query_timestamps: dict[str, list[float]], ep_idx: int
) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in

@@ -735,7 +774,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
frames = decode_video_frames(
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames.squeeze(0)
return item

@@ -789,7 +830,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
current_ep_idx = (
self.meta.total_episodes if episode_index is None else episode_index
)
ep_buffer = {}
# size and task are special cases that are not in self.features
ep_buffer["size"] = 0

@@ -887,7 +930,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["index"] = np.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionary

@@ -897,12 +942,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta.add_task(task)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
episode_buffer["task_index"] = np.array(
[self.meta.get_task_index(task) for task in tasks]
)
for key, ft in self.features.items():
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
"image",
"video",
]:
continue
episode_buffer[key] = np.stack(episode_buffer[key])

@@ -944,7 +994,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = datasets.Dataset.from_dict(
episode_dict, features=self.hf_features, split="train"
)
ep_dataset = embed_images(ep_dataset)
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
self.hf_dataset.set_transform(hf_transform_to_torch)

@@ -1063,7 +1115,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.video_backend = (
video_backend if video_backend is not None else get_safe_default_codec()
)
return obj

@@ -1088,7 +1142,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
self.tolerances_s = (
tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
@@ -141,12 +141,16 @@ class SharpnessJitter(Transform):
return float(sharpness[0]), float(sharpness[1])
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
sharpness_factor = (
torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
)
return {"sharpness_factor": sharpness_factor}
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = params["sharpness_factor"]
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
return self._call_kernel(
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
)
@dataclass
@@ -135,7 +135,9 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
elif isinstance(value, (int, float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
raise NotImplementedError(
f"The value '{value}' of type '{type(value)}' is not supported."
)
return unflatten_dict(serialized_dict)

@@ -214,7 +216,10 @@ def write_task(task_index: int, task: dict, local_dir: Path):
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
tasks = {
item["task_index"]: item["task"]
for item in sorted(tasks, key=lambda x: x["task_index"])
}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index

@@ -225,13 +230,19 @@ def write_episode(episode: dict, local_dir: Path):
def load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
return {
item["episode_index"]: item
for item in sorted(episodes, key=lambda x: x["episode_index"])
}
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
episode_stats = {
"episode_index": episode_index,
"stats": serialize_dict(episode_stats),
}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)

@@ -275,7 +286,9 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
items_dict[key] = [
x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]
]
return items_dict

@@ -328,7 +341,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
Otherwise, will throw a `CompatibilityError`.
"""
target_version = (
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
packaging.version.parse(version)
if not isinstance(version, packaging.version.Version)
else version
)
hub_versions = get_repo_versions(repo_id)

@@ -349,12 +364,16 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
return f"v{target_version}"
compatibles = [
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
v
for v in hub_versions
if v.major == target_version.major and v.minor <= target_version.minor
]
if compatibles:
return_version = max(compatibles)
if return_version < target_version:
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
logging.warning(
f"Revision {version} for {repo_id} not found, using version v{return_version}"
)
return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major]
@@ -461,7 +480,9 @@ def create_empty_dataset_info(
def get_episode_data_index(
episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
episode_lengths = {
ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()
}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}

@@ -511,7 +532,9 @@ def check_timestamps_sync(
# Mask to ignore differences at the boundaries between episodes
mask = np.ones(len(diffs), dtype=bool)
ignored_diffs = episode_data_index["to"][:-1] - 1  # indices at the end of each episode
ignored_diffs = (
episode_data_index["to"][:-1] - 1
)  # indices at the end of each episode
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]

@@ -720,14 +743,18 @@ def validate_frame(frame: dict, features: dict):
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
error_message = validate_features_presence(actual_features, expected_features, optional_features)
error_message = validate_features_presence(
actual_features, expected_features, optional_features
)
if "task" in frame:
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
for name in common_features - {"task"}:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
error_message += validate_feature_dtype_and_shape(
name, features[name], frame[name]
)
if error_message:
raise ValueError(error_message)

@@ -750,7 +777,9 @@ def validate_features_presence(
return error_message
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
def validate_feature_dtype_and_shape(
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):

@@ -760,7 +789,9 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
raise NotImplementedError(
f"The feature dtype '{expected_dtype}' is not implemented yet."
)
def validate_feature_numpy_array(

@@ -782,13 +813,17 @@ def validate_feature_numpy_array(
return error_message
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
def validate_feature_image_or_video(
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
):
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
if len(actual_shape) != 3 or (
actual_shape != (c, h, w) and actual_shape != (h, w, c)
):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass

@@ -819,7 +854,9 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
)
if episode_buffer["size"] == 0:
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
raise ValueError(
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if not buffer_keys == set(features):
@@ -35,22 +35,30 @@ def fix_dataset(repo_id: str) -> str:
dataset_info = get_dataset_config_info(repo_id, "default")
with SuppressWarnings():
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
lerobot_metadata = LeRobotDatasetMetadata(
repo_id, revision=V20, force_cache_sync=True
)
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
meta_features = {
key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"
}
parquet_features = set(dataset_info.features)
diff_parquet_meta = parquet_features - meta_features
diff_meta_parquet = meta_features - parquet_features
if diff_parquet_meta:
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
raise ValueError(
f"In parquet not in info.json: {parquet_features - meta_features}"
)
if not diff_meta_parquet:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
logging.warning(
f"In info.json not in parquet: {meta_features - parquet_features}"
)
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)
@@ -37,8 +37,16 @@ import logging
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
from lerobot.common.datasets.utils import (
EPISODES_STATS_PATH,
STATS_PATH,
load_stats,
write_info,
)
from lerobot.common.datasets.v21.convert_stats import (
check_aggregate_stats,
convert_stats,
)
V20 = "v2.0"
V21 = "v2.1"

@@ -79,13 +87,21 @@ def convert_dataset(
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
repo_id=dataset.repo_id,
filename=STATS_PATH,
revision=branch,
repo_type="dataset",
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
path_in_repo=STATS_PATH,
repo_id=dataset.repo_id,
revision=branch,
repo_type="dataset",
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
hub_api.create_tag(
repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
)
if __name__ == "__main__":
@@ -17,12 +17,18 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
get_feature_stats,
sample_indices,
)
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
def sample_episode_video_frames(
dataset: LeRobotDataset, episode_index: int, ft_key: str
) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})

@@ -45,11 +51,14 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
ep_stats[key] = get_feature_stats(
ep_ft_data, axis=axes_to_reduce, keepdims=keepdims
)
if ft["dtype"] in ["image", "video"]:  # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v, axis=0)
for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats

@@ -95,5 +104,9 @@ def check_aggregate_stats(
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
val,
reference_stats[key][stat],
rtol=rtol,
atol=atol,
err_msg=err_msg,
)
@@ -65,7 +65,9 @@ def decode_video_frames(
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
else:
raise ValueError(f"Unsupported video backend: {backend}")
@@ -61,10 +61,16 @@ class AlohaEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels":
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
self.features["top"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(480, 640, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
self.features["agent_pos"] = PolicyFeature(
type=FeatureType.STATE, shape=(14,)
)
self.features["pixels/top"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(480, 640, 3)
)
@property
def gym_kwargs(self) -> dict:

@@ -102,9 +108,13 @@ class PushtEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
self.features["pixels"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(384, 384, 3)
)
elif self.obs_type == "environment_state_agent_pos":
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
self.features["environment_state"] = PolicyFeature(
type=FeatureType.ENV, shape=(16,)
)
@property
def gym_kwargs(self) -> dict:

@@ -143,7 +153,9 @@ class XarmEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
self.features["agent_pos"] = PolicyFeature(
type=FeatureType.STATE, shape=(4,)
)
@property
def gym_kwargs(self) -> dict:
@@ -32,7 +32,9 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
def make_env(
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the config.
Args:

@@ -56,7 +58,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
print(
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
)
raise e
gym_handle = f"{package_name}/{cfg.task}"

@@ -64,7 +68,10 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
[
lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs)
for _ in range(n_envs)
]
)
return env
@@ -46,7 +46,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
assert c < h and c < w, (
f"expect channel last images, but instead got {img.shape=}"
)
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

@@ -79,7 +81,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL:
if len(ft.shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
raise ValueError(
f"Number of dimensions of {key} != 3 (shape={ft.shape})"
)
shape = get_channel_first_image_shape(ft.shape)
feature = PolicyFeature(type=ft.type, shape=shape)

@@ -92,7 +96,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
return policy_features
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
@@ -250,9 +250,9 @@ class Logger:
)
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(
optimizer.keys()
), "Optimizer dictionaries do not have the same keys during resume!"
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
"Optimizer dictionaries do not have the same keys during resume!"
)
for k, v in training_state["optimizer"].items():
optimizer[k].load_state_dict(v)
else:
@@ -34,7 +34,13 @@ def make_optimizer_and_scheduler(
Returns:
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
"""
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
params = (
policy.get_optim_params()
if cfg.use_policy_training_preset
else policy.parameters()
)
optimizer = cfg.optimizer.build(params)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
lr_scheduler = (
cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
)
return optimizer, lr_scheduler
@@ -102,7 +102,9 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
def load_optimizer_state(
optimizer: torch.optim.Optimizer, save_dir: Path
) -> torch.optim.Optimizer:
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
@@ -36,7 +36,9 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
return self.get_choice_name(self.__class__)
@abc.abstractmethod
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
def build(
self, optimizer: Optimizer, num_training_steps: int
) -> LRScheduler | None:
raise NotImplementedError

@@ -49,7 +51,11 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
from diffusers.optimization import get_scheduler
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
kwargs = {
**asdict(self),
"num_training_steps": num_training_steps,
"optimizer": optimizer,
}
return get_scheduler(**kwargs)

@@ -71,7 +77,14 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
progress = float(adjusted_step - self.num_warmup_steps) / float(
max(1, num_training_steps - self.num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
return max(
0.0,
0.5
* (
1.0
+ math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)
),
)
return LambdaLR(optimizer, lr_lambda, -1)

@@ -98,7 +111,9 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
def cosine_decay_schedule(current_step):
step = min(current_step, self.num_decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
cosine_decay = 0.5 * (
1 + math.cos(math.pi * step / self.num_decay_steps)
)
alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha
return decayed

@@ -117,6 +132,8 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
state_dict = deserialize_json_into_object(
save_dir / SCHEDULER_STATE, scheduler.state_dict()
)
scheduler.load_state_dict(state_dict)
return scheduler
@@ -171,7 +171,9 @@ class ACTConfig(PreTrainedConfig):
def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
@property
def observation_delta_indices(self) -> None:
@@ -63,7 +63,9 @@ class ACTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)

@@ -120,8 +122,12 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
batch = dict(
batch
)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [
batch[key] for key in self.config.image_features
]
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.

@@ -148,8 +154,12 @@ class ACTPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
batch = dict(
batch
)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [
batch[key] for key in self.config.image_features
]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

@@ -406,14 +416,18 @@ class ACT(nn.Module):
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
config.dim_model // 2
)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
self.action_head = nn.Linear(
config.dim_model, self.config.action_feature.shape[0]
)
self._reset_parameters()
@@ -461,14 +475,20 @@ class ACT(nn.Module):
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
)  # (B, 1, D)
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = self.vae_encoder_robot_state_input_proj(
batch["observation.state"]
)
robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(
batch["action"]
)  # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
vae_encoder_input = [
cls_embed,
robot_state_embed,
action_embed,
]  # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)

@@ -517,7 +537,9 @@ class ACT(nn.Module):
)
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
encoder_in_tokens.append(
self.encoder_robot_state_input_proj(batch["observation.state"])
)
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(

@@ -534,7 +556,9 @@ class ACT(nn.Module):
# For a list of images, the H and W may vary but H*W is constant.
for img in batch["observation.images"]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
dtype=cam_features.dtype
)
cam_features = self.encoder_img_feat_input_proj(cam_features)
# Rearrange features to (sequence, batch, dim).
@@ -205,11 +205,16 @@ class DiffusionConfig(PreTrainedConfig):
def validate_features(self) -> None:
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
if (
self.crop_shape[0] > image_ft.shape[1]
or self.crop_shape[1] > image_ft.shape[2]
):
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
@@ -70,7 +70,9 @@ class DiffusionPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)

@@ -97,7 +99,9 @@ class DiffusionPolicy(PreTrainedPolicy):
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
self._queues["observation.environment_state"] = deque(
maxlen=self.config.n_obs_steps
)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:

@@ -123,7 +127,9 @@ class DiffusionPolicy(PreTrainedPolicy):
"""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)

@@ -151,7 +157,9 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)

@@ -515,11 +523,15 @@ class DiffusionRgbEncoder(nn.Module):
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape_h_w = (
config.crop_shape if config.crop_shape is not None else images_shape[1:]
)
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()

@@ -719,7 +731,9 @@ class DiffusionConditionalUnet1d(nn.Module):
)
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
DiffusionConv1dBlock(
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
),
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
)
@@ -104,7 +104,9 @@ def make_policy(
PreTrainedPolicy: _description_
"""
if bool(ds_meta) == bool(env_cfg):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
raise ValueError(
"Either one of a dataset metadata or a sim env must be provided."
)

# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
@@ -134,8 +136,12 @@ def make_policy(
)
features = env_to_policy_features(env_cfg)

cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
cfg.input_features = {
key: ft for key, ft in features.items() if key not in cfg.output_features
}
kwargs["config"] = cfg

if cfg.pretrained_path:

@@ -82,25 +82,43 @@ def create_stats_buffers(
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(
dtype=torch.float32
)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(
dtype=torch.float32
)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(
dtype=torch.float32
)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(
dtype=torch.float32
)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
buffer["mean"].data = (
stats[key]["mean"].clone().to(dtype=torch.float32)
)
buffer["std"].data = (
stats[key]["std"].clone().to(dtype=torch.float32)
)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
buffer["min"].data = (
stats[key]["min"].clone().to(dtype=torch.float32)
)
buffer["max"].data = (
stats[key]["max"].clone().to(dtype=torch.float32)
)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
raise ValueError(
f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead."
)

stats_buffers[key] = buffer
return stats_buffers

@@ -44,7 +44,9 @@ def main():
else:
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"

ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_torch_dir = (
Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
)
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
save_dir = Path(f"../openpi/data/{model_name}/save")

@@ -70,7 +72,9 @@ def main():
# Create LeRobot batch from Jax
batch = {}
for cam_key, uint_chw_array in example["images"].items():
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
batch[f"observation.images.{cam_key}"] = (
torch.from_numpy(uint_chw_array) / 255.0
)
batch["observation.state"] = torch.from_numpy(example["state"])
batch["action"] = torch.from_numpy(outputs["actions"])
batch["task"] = example["prompt"]

@@ -54,7 +54,9 @@ def get_paligemma_config(precision: str):
"projector_hidden_act": "gelu_fast",
"vision_use_head": False,
}
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
final_config = PaliGemmaConfig(
text_config=text_config, vision_config=vision_config, **config
)
return final_config

@@ -61,7 +61,11 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
)
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy

PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
PRECISIONS = {
"bfloat16": torch.bfloat16,
"float32": torch.float32,
"float16": torch.float16,
}


def slice_paligemma_state_dict(state_dict, config):
@@ -318,7 +322,9 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
return {f"{prefix}{key}": value for key, value in d.items()}


def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
def convert_pi0_checkpoint(
checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str
):
# Break down orbax ckpts - they are in OCDBT
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
# process projection params
@@ -378,7 +384,9 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st
# gemma_config=gemma_config, paligemma_config=paligemma_config)
pi0_model = PI0Policy(pi0_config)

paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
paligemma_params = update_keys_with_prefix(
paligemma_params, "model.paligemma_with_expert."
)
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
projection_params = update_keys_with_prefix(projection_params, "model.")

@@ -48,18 +48,32 @@ def flex_attention_forward(

key_states = key_states[:, :, :, None, :]
key_states = key_states.expand(
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
batch_size,
key_states.shape[1],
num_key_value_heads,
num_key_value_groups,
head_dim,
)
key_states = key_states.reshape(
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
batch_size,
key_states.shape[1],
num_key_value_heads * num_key_value_groups,
head_dim,
)

value_states = value_states[:, :, :, None, :]
value_states = value_states.expand(
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
batch_size,
value_states.shape[1],
num_key_value_heads,
num_key_value_groups,
head_dim,
)
value_states = value_states.reshape(
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
batch_size,
value_states.shape[1],
num_key_value_heads * num_key_value_groups,
head_dim,
)

query_states = query_states.transpose(1, 2)

@@ -69,7 +69,11 @@ from lerobot.common.utils.utils import get_safe_dtype


def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
time: torch.tensor,
dimension: int,
min_period: float,
max_period: float,
device="cpu",
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
@@ -189,7 +193,9 @@ def aloha_gripper_to_angular(value):

# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
2 * horn_radius * linear_position
)
return safe_arcsin(value)

# The constants are taken from the Interbotix code.
@@ -240,7 +246,9 @@ class PI0Policy(PreTrainedPolicy):
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -248,7 +256,9 @@ class PI0Policy(PreTrainedPolicy):
config.output_features, config.normalization_mapping, dataset_stats
)

self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.language_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224"
)
self.model = PI0FlowMatching(config)

self.reset()
@@ -261,7 +271,9 @@ class PI0Policy(PreTrainedPolicy):
return self.parameters()

@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def select_action(
self, batch: dict[str, Tensor], noise: Tensor | None = None
) -> Tensor:
"""Select a single action given environment observations.

This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -300,7 +312,9 @@ class PI0Policy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()

def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
def forward(
self, batch: dict[str, Tensor], noise=None, time=None
) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -316,7 +330,9 @@ class PI0Policy(PreTrainedPolicy):
actions_is_pad = batch.get("actions_is_pad")

loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
losses = self.model.forward(
images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
)
loss_dict["losses_after_forward"] = losses.clone()

if actions_is_pad is not None:
@@ -343,7 +359,9 @@ class PI0Policy(PreTrainedPolicy):
img_masks = []

present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
missing_img_keys = [
key for key in self.config.image_features if key not in batch
]

if len(present_img_keys) == 0:
raise ValueError(
@@ -355,7 +373,9 @@ class PI0Policy(PreTrainedPolicy):
img = batch[key]

if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
img = resize_with_pad(
img, *self.config.resize_imgs_with_padding, pad_value=0
)

# Normalize from range [0,1] to [-1,1] as expacted by siglip
img = img * 2.0 - 1.0
@@ -394,7 +414,9 @@ class PI0Policy(PreTrainedPolicy):
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
lang_masks = tokenized_prompt["attention_mask"].to(
device=device, dtype=torch.bool
)

return lang_tokens, lang_masks

@@ -413,7 +435,9 @@ class PI0Policy(PreTrainedPolicy):
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
actions[:, :, motor_idx] = aloha_gripper_from_angular(
actions[:, :, motor_idx]
)
return actions

def _pi_aloha_encode_actions_inv(self, actions):
@@ -422,7 +446,9 @@ class PI0Policy(PreTrainedPolicy):
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
actions[:, :, motor_idx]
)
return actions

def prepare_state(self, batch):

@@ -472,15 +498,25 @@ class PI0FlowMatching(nn.Module):
train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation,
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_with_export_config
)

# Projections are float32
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
self.action_in_proj = nn.Linear(
self.config.max_action_dim, self.config.proj_width
)
self.action_out_proj = nn.Linear(
self.config.proj_width, self.config.max_action_dim
)

self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
self.action_time_mlp_in = nn.Linear(
self.config.proj_width * 2, self.config.proj_width
)
self.action_time_mlp_out = nn.Linear(
self.config.proj_width, self.config.proj_width
)

self.set_requires_grad()

@@ -524,7 +560,9 @@ class PI0FlowMatching(nn.Module):

# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
img_emb = img_emb * torch.tensor(
img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
)

bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
@@ -577,7 +615,11 @@ class PI0FlowMatching(nn.Module):

# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
timestep,
self.config.proj_width,
min_period=4e-3,
max_period=4.0,
device=device,
)
time_emb = time_emb.type(dtype=dtype)

@@ -595,7 +637,9 @@ class PI0FlowMatching(nn.Module):
embs.append(action_time_emb)

bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
action_time_mask = torch.ones(
bsize, action_time_dim, dtype=torch.bool, device=device
)
pad_masks.append(action_time_mask)

# Set attention masks so that image, language and state inputs do not attend to action tokens
@@ -609,7 +653,15 @@ class PI0FlowMatching(nn.Module):
return embs, pad_masks, att_masks

def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
actions,
noise=None,
time=None,
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
@@ -625,7 +677,9 @@ class PI0FlowMatching(nn.Module):
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
state, x_t, time
)

pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
@@ -649,13 +703,19 @@ class PI0FlowMatching(nn.Module):
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses

def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None
) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device

if noise is None:
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
actions_shape = (
bsize,
self.config.n_action_steps,
self.config.max_action_dim,
)
noise = self.sample_noise(actions_shape, device)

prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -703,12 +763,16 @@ class PI0FlowMatching(nn.Module):
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
state, x_t, timestep
)

suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
batch_size, suffix_len, prefix_len
)

suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)

@@ -39,9 +39,13 @@ def apply_rope(x, positions, max_wavelength=10_000):
dtype = x.dtype
x = x.to(torch.float32)

freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
d_half, dtype=torch.float32, device=device
)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
torch.float32
)

radians = radians[..., None, :]

@@ -174,7 +178,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config)
self.config = config
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.paligemma = PaliGemmaForConditionalGeneration(
config=config.paligemma_config
)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None
@@ -291,14 +297,22 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
key_states = torch.cat(
[past_key_values[layer_idx]["key_states"], key_states], dim=1
)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], dim=1
[past_key_values[layer_idx]["value_states"], value_states],
dim=1,
)

attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
)
att_output = att_output.to(dtype=torch.bfloat16)

@@ -358,15 +372,29 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
return attention_interface

def flash_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
self,
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
):
raise NotImplementedError("FA2 is not implemented (yet)")

def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
self,
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
num_key_value_heads = (
self.config.paligemma_config.text_config.num_key_value_heads
)
num_key_value_groups = num_att_heads // num_key_value_heads

# query_states: batch_size, sequence_length, num_att_head, head_dim
@@ -375,17 +403,31 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
sequence_length = key_states.shape[1]

key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads,
num_key_value_groups,
head_dim,
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads * num_key_value_groups,
head_dim,
)

value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads,
num_key_value_groups,
head_dim,
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads * num_key_value_groups,
head_dim,
)

# Attention here is upcasted to float32 to match the original eager implementation.
@@ -400,7 +442,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
att_weights *= head_dim**-0.5
big_neg = -2.3819763e38  # See gemma/modules.py

masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
masked_att_weights = torch.where(
attention_mask[:, None, :, :], att_weights, big_neg
)

probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
@@ -412,6 +456,8 @@ class PaliGemmaWithExpertModel(PreTrainedModel):

att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
att_output = att_output.reshape(
batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
)

return att_output

@@ -71,7 +71,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
def _save_pretrained(self, save_directory: Path) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
save_model_as_safetensor(
model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)
)

@classmethod
def from_pretrained(
@@ -110,7 +112,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
policy = cls._load_as_safetensor(
instance, model_file, config.device, strict
)
else:
try:
model_file = hf_hub_download(
@@ -124,7 +128,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
token=token,
local_files_only=local_files_only,
)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
policy = cls._load_as_safetensor(
instance, model_file, config.device, strict
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
@@ -135,8 +141,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
return policy

@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
def _load_as_safetensor(
cls, model: T, model_file: str, map_location: str, strict: bool
) -> T:
if packaging.version.parse(safetensors.__version__) < packaging.version.parse(
"0.4.3"
):
load_model_as_safetensor(model, model_file, strict=strict)
if map_location != "cpu":
logging.warning(
@@ -147,7 +157,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
)
model.to(map_location)
else:
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
safetensors.torch.load_model(
model, model_file, strict=strict, device=map_location
)
return model

# def generate_model_card(self, *args, **kwargs) -> ModelCard:

@@ -639,9 +639,9 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(
log_std
).any(), "[ERROR] log_std became NaN after std_layer!"
assert not torch.isnan(log_std).any(), (
"[ERROR] log_std became NaN after std_layer!"
)

if self.use_tanh_squash:
log_std = torch.tanh(log_std)

@@ -187,7 +187,9 @@ class TDMPCConfig(PreTrainedConfig):
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
)
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
raise ValueError(
"`n_action_steps` must be less than or equal to `horizon`."
)

def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(lr=self.optimizer_lr)
@@ -207,7 +209,9 @@ class TDMPCConfig(PreTrainedConfig):
if image_ft.shape[-2] != image_ft.shape[-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
raise ValueError(
f"Only square images are handled now. Got image shape {image_ft.shape}."
)

@property
def observation_delta_indices(self) -> list:

@@ -39,7 +39,11 @@ from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_output_shape,
populate_queues,
)


class TDMPCPolicy(PreTrainedPolicy):
@@ -63,7 +67,11 @@ class TDMPCPolicy(PreTrainedPolicy):
config_class = TDMPCConfig
name = "tdmpc"

def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
def __init__(
self,
config: TDMPCConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
@@ -75,7 +83,9 @@ class TDMPCPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config

self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -117,7 +127,9 @@ class TDMPCPolicy(PreTrainedPolicy):
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
)  # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[next(iter(self.config.image_features))]

self._queues = populate_queues(self._queues, batch)
@@ -201,7 +213,10 @@ class TDMPCPolicy(PreTrainedPolicy):
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
self.config.horizon,
batch_size,
self.config.action_feature.shape[0],
device=device,
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
@@ -339,7 +354,9 @@ class TDMPCPolicy(PreTrainedPolicy):

batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
)  # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)

@@ -371,7 +388,9 @@ class TDMPCPolicy(PreTrainedPolicy):
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[
"observation.image" if self.config.image_features else "observation.environment_state"
"observation.image"
if self.config.image_features
else "observation.environment_state"
].shape[:2]

# Run latent rollout using the latent dynamics model and policy model.

@@ -569,7 +588,9 @@ class TDMPCTOLD(nn.Module):
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -580,7 +601,9 @@ class TDMPCTOLD(nn.Module):
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.action_feature.shape[0], config.mlp_dim
),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -600,7 +623,10 @@ class TDMPCTOLD(nn.Module):
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.action_feature.shape[0],
config.mlp_dim,
),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -786,7 +812,9 @@ class TDMPCObservationEncoder(nn.Module):

if config.robot_state_feature:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.Linear(
config.robot_state_feature.shape[0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -795,7 +823,9 @@ class TDMPCObservationEncoder(nn.Module):

if config.env_state_feature:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.Linear(
config.env_state_feature.shape[0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -813,7 +843,8 @@ class TDMPCObservationEncoder(nn.Module):
if self.config.image_features:
feat.append(
flatten_forward_unflatten(
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
self.image_enc_layers,
obs_dict[next(iter(self.config.image_features))],
)
)
if self.config.env_state_feature:

@@ -172,7 +172,10 @@ class VQBeTConfig(PreTrainedConfig):

if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
if (
self.crop_shape[0] > image_ft.shape[1]
or self.crop_shape[1] > image_ft.shape[2]
):
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
@@ -193,7 +196,12 @@ class VQBeTConfig(PreTrainedConfig):

@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
return list(
range(
1 - self.n_obs_steps,
self.n_action_pred_token + self.action_chunk_size - 1,
)
)

@property
def reward_delta_indices(self) -> None:

@@ -29,7 +29,11 @@ from torch import Tensor, nn

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_output_shape,
populate_queues,
)
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ

@@ -60,7 +64,9 @@ class VQBeTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config

self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -91,11 +97,17 @@ class VQBeTPolicy(PreTrainedPolicy):
if self.config.sequentially_select:
decay_params = (
decay_params
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
+ list(
self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
)
+ list(
self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
)
)
else:
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
decay_params = decay_params + list(
self.vqbet.action_head.map_to_cbet_preds_bin.parameters()
)

return [
{
@@ -133,8 +145,12 @@ class VQBeTPolicy(PreTrainedPolicy):
"""

batch = self.normalize_inputs(batch)
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = dict(
batch
)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)

@@ -165,8 +181,12 @@ class VQBeTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = dict(
batch
)  # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():

@@ -334,7 +354,8 @@ class VQBeTModel(nn.Module):

# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP(
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
config.robot_state_feature.shape[0],
hidden_channels=[self.config.gpt_input_dim],
)
self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
@@ -406,9 +427,9 @@ class VQBeTModel(nn.Module):
features = self.policy(input_tokens)
# len(self.config.input_features) is the number of different observation modes.
# this line gets the index of action prompt tokens.
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
self.config.input_features
)
historical_act_pred_index = np.arange(0, n_obs_steps) * (
len(self.config.input_features) + 1
) + len(self.config.input_features)

# only extract the output tokens at the position of action query:
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
@@ -771,11 +792,15 @@ class VQBeTRgbEncoder(nn.Module):
# height and width from `config.image_features`.

images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape_h_w = (
config.crop_shape if config.crop_shape is not None else images_shape[1:]
)
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
@@ -871,7 +896,8 @@ class VqVae(nn.Module):
)

self.encoder = MLP(
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
in_channels=self.config.action_feature.shape[0]
* self.config.action_chunk_size,
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
@@ -899,9 +925,13 @@ class VqVae(nn.Module):
# given latent vector, this function outputs the decoded action.
output = self.decoder(latent)
if self.config.action_chunk_size == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
)
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]
)

def get_code(self, state):
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)

@@ -290,11 +290,11 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert (
len(inter_params) == 0
), "parameters {} made it into both decay/no_decay sets!".format(
assert len(inter_params) == 0, (
"parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
)
assert len(param_dict.keys() - union_params) == 0, (
"parameters {} were not separated into either decay/no_decay set!".format(
str(param_dict.keys() - union_params),
@@ -664,14 +664,14 @@ class VectorQuantize(nn.Module):
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

assert not (
ema_update and learnable_codebook
), "learnable codebook not compatible with EMA update"
assert not (ema_update and learnable_codebook), (
"learnable codebook not compatible with EMA update"
)

assert 0 <= sync_update_v <= 1.0
assert not (
sync_update_v > 0.0 and not learnable_codebook
), "learnable codebook must be turned on"
assert not (sync_update_v > 0.0 and not learnable_codebook), (
"learnable codebook must be turned on"
)

self.sync_update_v = sync_update_v

@@ -57,7 +57,9 @@ class OpenCVCameraConfig(CameraConfig):
self.channels = 3

if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)


@CameraConfig.register_subclass("intelrealsense")
@@ -102,8 +104,12 @@ class IntelRealSenseCameraConfig(CameraConfig):

self.channels = 3

at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
at_least_one_is_not_none = (
self.fps is not None or self.width is not None or self.height is not None
)
at_least_one_is_none = (
self.fps is None or self.width is None or self.height is None
)
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
@@ -111,4 +117,6 @@ class IntelRealSenseCameraConfig(CameraConfig):
)

if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)

@@ -303,7 +303,11 @@ class IntelRealSenseCamera:
if self.fps and self.capture_width and self.capture_height:
# TODO(rcadene): can we set rgb8 directly?
config.enable_stream(
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
rs.stream.color,
self.capture_width,
self.capture_height,
rs.format.rgb8,
self.fps,
)
else:
config.enable_stream(rs.stream.color)
@@ -311,7 +315,11 @@ class IntelRealSenseCamera:
if self.use_depth:
if self.fps and self.capture_width and self.capture_height:
config.enable_stream(
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
rs.stream.depth,
self.capture_width,
self.capture_height,
rs.format.z16,
self.fps,
)
else:
config.enable_stream(rs.stream.depth)

@@ -144,7 +144,9 @@ def save_images_from_cameras(
print("Connecting cameras")
cameras = []
for cam_idx in camera_ids:
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
config = OpenCVCameraConfig(
camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock
)
camera = OpenCVCamera(config)
camera.connect()
print(
@@ -250,7 +252,9 @@ class OpenCVCamera:
# Retrieve the camera index from a potentially symlinked path
self.camera_index = get_camera_index_from_unix_port(self.port)
else:
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
raise ValueError(
f"Please check the provided camera_index: {self.camera_index}"
)

# Store the raw (capture) resolution from the config.
self.capture_width = config.width
@@ -314,7 +318,11 @@ class OpenCVCamera:
else cv2.CAP_ANY
)

camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
camera_idx = (
f"/dev/video{self.camera_index}"
if platform.system() == "Linux"
else self.camera_index
)
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(camera_idx, backend)

@@ -41,7 +41,9 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
cameras[key] = OpenCVCamera(cfg)

elif cfg.type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
from lerobot.common.robot_devices.cameras.intelrealsense import (
IntelRealSenseCamera,
)

cameras[key] = IntelRealSenseCamera(cfg)
else:
@@ -58,7 +60,9 @@ def make_camera(camera_type, **kwargs) -> Camera:
return OpenCVCamera(config)

elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
from lerobot.common.robot_devices.cameras.intelrealsense import (
IntelRealSenseCamera,
)

config = IntelRealSenseCameraConfig(**kwargs)
return IntelRealSenseCamera(config)

@@ -93,7 +93,9 @@ class RecordControlConfig(ControlConfig):
policy_path = parser.get_path_arg("control.policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("control.policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=cli_overrides
)
self.policy.pretrained_path = policy_path

@@ -282,7 +282,10 @@ def control_loop(

if policy is not None:
pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
observation,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.

@@ -23,7 +23,10 @@ import numpy as np
import tqdm

from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc

PROTOCOL_VERSION = 2.0

@@ -23,7 +23,10 @@ import numpy as np
import tqdm

from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc

PROTOCOL_VERSION = 0

@@ -30,7 +30,9 @@ class MotorsBus(Protocol):
def write(self): ...


def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
def make_motors_buses_from_configs(
motors_bus_configs: dict[str, MotorsBusConfig],
) -> list[MotorsBus]:
motors_buses = {}

for key, cfg in motors_bus_configs.items():

@@ -69,9 +69,13 @@ class ManipulatorRobotConfig(RobotConfig):
if not cam.mock:
cam.mock = True

if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
if self.max_relative_target is not None and isinstance(
self.max_relative_target, Sequence
):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
if len(self.follower_arms[name].motors) != len(
self.max_relative_target
):
raise ValueError(
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "

@ -42,7 +42,9 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
|
|||
local_dict = {}
|
||||
for name, cam in cameras.items():
|
||||
frame = cam.async_read()
|
||||
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
|
||||
ret, buffer = cv2.imencode(
|
||||
".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]
|
||||
)
|
||||
if ret:
|
||||
local_dict[name] = base64.b64encode(buffer).decode("utf-8")
|
||||
else:
|
||||
|
@ -61,7 +63,9 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
|||
calib_dir.mkdir(parents=True, exist_ok=True)
|
||||
calib_file = calib_dir / "main_follower.json"
|
||||
try:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
except ImportError:
|
||||
print("[WARNING] Calibration function not available. Skipping calibration.")
|
||||
return
|
||||
|
@ -72,7 +76,9 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
|||
print(f"[INFO] Loaded calibration from {calib_file}")
|
||||
else:
|
||||
print("[INFO] Calibration file not found. Running manual calibration...")
|
||||
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||
calibration = run_arm_manual_calibration(
|
||||
motors_bus, "lekiwi", "follower_arm", "follower"
|
||||
)
|
||||
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
||||
with open(calib_file, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
|
@@ -116,7 +122,14 @@ def run_lekiwi(robot_config):
robot = LeKiwi(motors_bus)
# Define the expected arm motor IDs.
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
arm_motor_ids = [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
]
# Disable torque for each arm motor.
for motor in arm_motor_ids:
@@ -130,7 +143,9 @@ def run_lekiwi(robot_config):
images_lock = threading.Lock()
stop_event = threading.Event()
cam_thread = threading.Thread(
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
target=run_camera_capture,
args=(cameras, images_lock, latest_images_dict, stop_event),
daemon=True,
)
cam_thread.start()
@@ -159,7 +174,9 @@ def run_lekiwi(robot_config):
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
)
else:
for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
for motor, pos in zip(
arm_motor_ids, arm_positions, strict=False
):
motors_bus.write("Goal_Position", pos, motor)
# Process wheel (base) commands.
if "raw_velocity" in data:
@@ -190,7 +207,9 @@ def run_lekiwi(robot_config):
try:
pos = motors_bus.read("Present_Position", motor)
# Convert the position to a float (or use as is if already numeric).
follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
follower_arm_state.append(
float(pos) if not isinstance(pos, (int, float)) else pos
)
except Exception as e:
print(f"[ERROR] Reading motor {motor} failed: {e}")
@@ -28,7 +28,10 @@ import numpy as np
import torch
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.motors.utils import (
MotorsBus,
make_motors_buses_from_configs,
)
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import (
@@ -25,9 +25,14 @@ import zmq
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.feetech import TorqueMode
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.motors.utils import (
MotorsBus,
make_motors_buses_from_configs,
)
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
@@ -266,7 +271,9 @@ class MobileManipulator:
calibration = json.load(f)
else:
print(f"Missing calibration file '{arm_calib_path}'")
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
@@ -296,7 +303,9 @@ class MobileManipulator:
bus.write("Torque_Enable", 0, motor_id)
# Then filter out wheels
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
arm_only_dict = {
k: v for k, v in bus.motors.items() if not k.startswith("wheel_")
}
if not arm_only_dict:
continue
@@ -324,7 +333,11 @@ class MobileManipulator:
socks = dict(poller.poll(15))
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
# No new data arrived → reuse ALL old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return (
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
# Drain all messages, keep only the last
last_msg = None
@@ -337,7 +350,11 @@ class MobileManipulator:
if not last_msg:
# No new message → also reuse old
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return (
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
# Decode only the final message
try:
@@ -360,7 +377,9 @@ class MobileManipulator:
if new_arm_state is not None and frames is not None:
self.last_frames = frames
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
remote_arm_state_tensor = torch.tensor(
new_arm_state, dtype=torch.float32
)
self.last_remote_arm_state = remote_arm_state_tensor
present_speed = new_speed
@@ -375,14 +394,21 @@ class MobileManipulator:
except Exception as e:
print(f"[DEBUG] Error decoding video message: {e}")
# If decode fails, fall back to old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return (
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
return frames, present_speed, remote_arm_state_tensor
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
state_tensor = torch.zeros(3, dtype=torch.int32)
if present_speed:
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
decoded = {
key: MobileManipulator.raw_to_degps(value)
for key, value in present_speed.items()
}
if "1" in decoded:
state_tensor[0] = decoded["1"]
if "2" in decoded:
@@ -395,7 +421,9 @@ class MobileManipulator:
self, record_data: bool = False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
raise RobotDeviceNotConnectedError(
"MobileManipulator is not connected. Run `connect()` first."
)
speed_setting = self.speed_levels[self.speed_index]
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
@@ -461,9 +489,15 @@ class MobileManipulator:
body_state = self.wheel_raw_to_body(present_speed)
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
body_state_mm = (
body_state[0] * 1000.0,
body_state[1] * 1000.0,
body_state[2],
) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
combined_state_tensor = torch.cat(
(remote_arm_state_tensor, wheel_state_tensor), dim=0
)
obs_dict = {"observation.state": combined_state_tensor}
@@ -620,7 +654,11 @@ class MobileManipulator:
# Convert each wheel’s angular speed (deg/s) to a raw integer.
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
return {
"left_wheel": wheel_raw[0],
"back_wheel": wheel_raw[1],
"right_wheel": wheel_raw[2],
}
def wheel_raw_to_body(
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
@@ -72,7 +72,9 @@ def make_robot_from_config(config: RobotConfig):
return ManipulatorRobot(config)
elif isinstance(config, LeKiwiRobotConfig):
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
from lerobot.common.robot_devices.robots.mobile_manipulator import (
MobileManipulator,
)
return MobileManipulator(config)
else:
@@ -69,7 +69,9 @@ class HubMixin:
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return self.push_to_hub(
repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs
)
return None
def _save_pretrained(self, save_directory: Path) -> None:
@@ -175,7 +177,9 @@ class HubMixin:
The url of the commit of your object in the given repository.
"""
api = HfApi(token=token)
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
repo_id = api.create_repo(
repo_id=repo_id, private=private, exist_ok=True
).repo_id
if commit_message is None:
if "Policy" in self.__class__.__name__:
@@ -20,7 +20,16 @@ from typing import TypeVar
import imageio
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
JsonLike = (
str
| int
| float
| bool
| None
| list["JsonLike"]
| dict[str, "JsonLike"]
| tuple["JsonLike", ...]
)
T = TypeVar("T", bound=JsonLike)
@@ -76,7 +85,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# Check length
if len(target) != len(source):
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
raise ValueError(
f"List length mismatch: expected {len(target)}, got {len(source)}"
)
# Recursively update each element.
for i in range(len(target)):
@@ -88,10 +99,14 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# which we'll convert back to a tuple.
elif isinstance(target, tuple):
if not isinstance(source, list):
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
raise TypeError(
f"Type mismatch: expected list (for tuple), got {type(source)}"
)
if len(target) != len(source):
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
raise ValueError(
f"Tuple length mismatch: expected {len(target)}, got {len(source)}"
)
# Convert each element, forming a new tuple.
converted_items = []
@@ -105,7 +120,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
else:
# Check the exact type. If these must match 1:1, do:
if type(target) is not type(source):
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
raise TypeError(
f"Type mismatch: expected {type(target)}, got {type(source)}"
)
return source
# Perform the in-place/recursive deserialization
@@ -107,13 +107,17 @@ class MetricsTracker:
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
def __getattr__(
self, name: str
) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__:
return self.__dict__[name]
elif name in self.metrics:
return self.metrics[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def __setattr__(self, name: str, value: Any) -> None:
if name in self.__dict__:
@@ -121,7 +125,9 @@ class MetricsTracker:
elif name in self.metrics:
self.metrics[name].update(value)
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def step(self) -> None:
"""
@@ -42,7 +42,11 @@ def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> Non
"""
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
"""
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
py_state = (
rng_state_dict["py_rng_version"].item(),
tuple(rng_state_dict["py_rng_state"].tolist()),
None,
)
random.setstate(py_state)
@@ -119,7 +123,9 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
torch_rng_state_dict = {
k: v for k, v in rng_state_dict.items() if k.startswith("torch")
}
deserialize_python_rng_state(py_rng_state_dict)
deserialize_numpy_rng_state(np_rng_state_dict)
@@ -48,7 +48,9 @@ def auto_select_torch_device() -> torch.device:
logging.info("Metal backend detected, using cuda.")
return torch.device("mps")
else:
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
logging.warning(
"No accelerated backend detected. Using default cpu, this will be slow."
)
return torch.device("cpu")
@@ -96,7 +98,9 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu":
return True
else:
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
raise ValueError(
f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu."
)
def is_amp_available(device: str):
@@ -219,7 +223,9 @@ def say(text, blocking=False):
if blocking:
subprocess.run(cmd, check=True)
else:
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
subprocess.Popen(
cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0
)
def log_say(text, play_sounds, blocking=False):
@@ -26,7 +26,9 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR
from lerobot.configs.train import TrainPipelineConfig
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
def cfg_to_group(
cfg: TrainPipelineConfig, return_list: bool = False
) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
@@ -92,7 +94,9 @@ class WandBLogger:
resume="must" if cfg.resume else None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
self._wandb = wandb
def log_policy(self, checkpoint_dir: Path):
@@ -104,7 +108,9 @@ class WandBLogger:
artifact_name = f"{self._group}-{step_id}"
artifact_name = get_safe_wandb_artifact_name(artifact_name)
artifact = self._wandb.Artifact(artifact_name, type="model")
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
artifact.add_file(
checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE
)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"):
@ -33,7 +33,9 @@ class DatasetConfig:
|
|||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
image_transforms: ImageTransformsConfig = field(
|
||||
default_factory=ImageTransformsConfig
|
||||
)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
|
|
|
@ -40,7 +40,9 @@ class EvalPipelineConfig:
|
|||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
else:
|
||||
|
|
|
@ -29,7 +29,9 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
|||
draccus.set_config_type("json")
|
||||
|
||||
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
def get_cli_overrides(
|
||||
field_name: str, args: Sequence[str] | None = None
|
||||
) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
|
||||
For example, supposing the main script was called with:
|
||||
|
@ -42,7 +44,10 @@ def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> lis
|
|||
args = sys.argv[1:]
|
||||
attr_level_args = []
|
||||
detect_string = f"--{field_name}."
|
||||
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
|
||||
exclude_strings = (
|
||||
f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
|
||||
f"--{field_name}.{PATH_KEY}=",
|
||||
)
|
||||
for arg in args:
|
||||
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
|
||||
denested_arg = f"--{arg.removeprefix(detect_string)}"
|
||||
|
@ -153,7 +158,9 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[
|
|||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
|
||||
def filter_path_args(
|
||||
fields_to_filter: str | list[str], args: Sequence[str] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Filters command-line arguments related to fields with specific path arguments.
|
||||
|
||||
|
@ -181,7 +188,9 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
|||
argument=None,
|
||||
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
|
||||
)
|
||||
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
|
||||
filtered_args = [
|
||||
arg for arg in filtered_args if not arg.startswith(f"--{field}.")
|
||||
]
|
||||
|
||||
return filtered_args
|
||||
|
||||
|
@ -213,7 +222,9 @@ def wrap(config_path: Path | None = None):
|
|||
load_plugin(plugin_path)
|
||||
except PluginLoadError as e:
|
||||
# add the relevant CLI arg to the error message
|
||||
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
|
||||
raise PluginLoadError(
|
||||
f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}"
|
||||
) from e
|
||||
cli_args = filter_arg(plugin_cli_arg, cli_args)
|
||||
config_path_cli = parse_arg("config_path", cli_args)
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
|
@ -223,7 +234,9 @@ def wrap(config_path: Path | None = None):
|
|||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype, config_path=config_path, args=cli_args
|
||||
)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
|
|
@ -26,7 +26,11 @@ from huggingface_hub.errors import HfHubHTTPError
|
|||
from lerobot.common.optim.optimizers import OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
from lerobot.common.utils.utils import (
|
||||
auto_select_torch_device,
|
||||
is_amp_available,
|
||||
is_torch_device_available,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
||||
|
@ -64,7 +68,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||
self.pretrained_path = None
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
logging.warning(
|
||||
f"Device '{self.device}' is not available. Switching to '{auto_device}'."
|
||||
)
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
|
@ -118,7 +124,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
return {
|
||||
key: ft
|
||||
for key, ft in self.input_features.items()
|
||||
if ft.type is FeatureType.VISUAL
|
||||
}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
|
|
|
@ -73,7 +73,9 @@ class TrainPipelineConfig(HubMixin):
|
|||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = policy_path
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
|
@ -97,7 +99,11 @@ class TrainPipelineConfig(HubMixin):
|
|||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
|
||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||
if (
|
||||
not self.resume
|
||||
and isinstance(self.output_dir, Path)
|
||||
and self.output_dir.is_dir()
|
||||
):
|
||||
raise FileExistsError(
|
||||
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
|
||||
f"Please change your output directory so that {self.output_dir} is not overwritten."
|
||||
|
@ -108,10 +114,16 @@ class TrainPipelineConfig(HubMixin):
|
|||
self.output_dir = Path("outputs/train") / train_dir
|
||||
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
||||
raise NotImplementedError(
|
||||
"LeRobotMultiDataset is not currently implemented."
|
||||
)
|
||||
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
if not self.use_policy_training_preset and (
|
||||
self.optimizer is None or self.scheduler is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Optimizer and Scheduler must be set when the policy presets are not used."
|
||||
)
|
||||
elif self.use_policy_training_preset and not self.resume:
|
||||
self.optimizer = self.policy.get_optimizer_preset()
|
||||
self.scheduler = self.policy.get_scheduler_preset()
|
||||
|
@ -125,7 +137,10 @@ class TrainPipelineConfig(HubMixin):
|
|||
return draccus.encode(self)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
with (
|
||||
open(save_directory / TRAIN_CONFIG_NAME, "w") as f,
|
||||
draccus.config_type("json"),
|
||||
):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -38,7 +38,12 @@ def get_motor_bus_cls(brand: str) -> tuple:
|
|||
FeetechMotorsBus,
|
||||
)
|
||||
|
||||
return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE
|
||||
return (
|
||||
FeetechMotorsBusConfig,
|
||||
FeetechMotorsBus,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
SCS_SERIES_BAUDRATE_TABLE,
|
||||
)
|
||||
|
||||
elif brand == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
|
@ -48,7 +53,12 @@ def get_motor_bus_cls(brand: str) -> tuple:
|
|||
DynamixelMotorsBus,
|
||||
)
|
||||
|
||||
return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE
|
||||
return (
|
||||
DynamixelMotorsBusConfig,
|
||||
DynamixelMotorsBus,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
X_SERIES_BAUDRATE_TABLE,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -57,8 +67,8 @@ def get_motor_bus_cls(brand: str) -> tuple:
|
|||
|
||||
|
||||
def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls(
|
||||
brand
|
||||
motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = (
|
||||
get_motor_bus_cls(brand)
|
||||
)
|
||||
|
||||
# Check if the provided model exists in the model_baud_rate_table
|
||||
|
@ -72,7 +82,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
|||
motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument
|
||||
motor_model = model # Use the motor model passed via argument
|
||||
|
||||
config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})
|
||||
config = motor_bus_config_cls(
|
||||
port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}
|
||||
)
|
||||
|
||||
# Initialize the MotorBus with the correct port and motor configurations
|
||||
motor_bus = motor_bus_cls(config=config)
|
||||
|
@ -139,8 +151,12 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
|||
|
||||
print(f"Setting its index to desired index {motor_idx_des}")
|
||||
if brand == "feetech":
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
|
||||
motor_bus.write_with_motor_ids(
|
||||
motor_bus.motor_models, motor_index, "Lock", 0
|
||||
)
|
||||
motor_bus.write_with_motor_ids(
|
||||
motor_bus.motor_models, motor_index, "ID", motor_idx_des
|
||||
)
|
||||
|
||||
present_idx = motor_bus.read_with_motor_ids(
|
||||
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
@ -156,7 +156,6 @@ from lerobot.common.robot_devices.control_utils import (
|
|||
log_control_info,
|
||||
record_episode,
|
||||
reset_environment,
|
||||
reset_follower_position,
|
||||
sanity_check_dataset_name,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
stop_recording,
|
||||
|
@ -251,7 +250,8 @@ def record(
|
|||
if len(robot.cameras) > 0:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.num_image_writer_processes,
|
||||
num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
num_threads=cfg.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
|
||||
else:
|
||||
|
@ -264,14 +264,19 @@ def record(
|
|||
robot=robot,
|
||||
use_videos=cfg.video,
|
||||
image_writer_processes=cfg.num_image_writer_processes,
|
||||
image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
image_writer_threads=cfg.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras),
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
policy = (
|
||||
None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
policy = (
|
||||
None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
@ -286,7 +291,14 @@ def record(
|
|||
# 3. place the cameras windows on screen
|
||||
enable_teleoperation = policy is None
|
||||
log_say("Warmup record", cfg.play_sounds)
|
||||
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
|
||||
warmup_record(
|
||||
robot,
|
||||
events,
|
||||
enable_teleoperation,
|
||||
cfg.warmup_time_s,
|
||||
cfg.display_cameras,
|
||||
cfg.fps,
|
||||
)
|
||||
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
|
|
@ -262,7 +262,11 @@ def record(
|
|||
shape = env.observation_space[key].shape
|
||||
if not key.startswith("observation.image."):
|
||||
key = "observation.image." + key
|
||||
features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape}
|
||||
features[key] = {
|
||||
"dtype": "video",
|
||||
"names": ["channels", "height", "width"],
|
||||
"shape": shape,
|
||||
}
|
||||
|
||||
for key, obs_key in state_keys_dict.items():
|
||||
features[key] = {
|
||||
|
|
|
@ -152,7 +152,8 @@ def rollout(
|
|||
all_observations.append(deepcopy(observation))
|
||||
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda")
|
||||
for key in observation
|
||||
}
|
||||
|
||||
with torch.inference_mode():
|
||||
|
@ -511,10 +512,14 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_seed(cfg.seed)
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
logging.info(
|
||||
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
|
||||
)
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
env = make_env(
|
||||
cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs
|
||||
)
|
||||
|
||||
logging.info("Making policy.")
|
||||
|
||||
|
@ -524,7 +529,12 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||
)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if cfg.policy.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
info = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
|
|
|
@ -1087,9 +1087,9 @@ class GamepadControlWrapper(gym.Wrapper):
|
|||
class ActionScaleWrapper(gym.ActionWrapper):
|
||||
def __init__(self, env, ee_action_space_params=None):
|
||||
super().__init__(env)
|
||||
assert (
|
||||
ee_action_space_params is not None
|
||||
), "TODO: method implemented for ee action space only so far"
|
||||
assert ee_action_space_params is not None, (
|
||||
"TODO: method implemented for ee action space only so far"
|
||||
)
|
||||
self.scale_vector = np.array(
|
||||
[
|
||||
[
|
||||
|
|
|
@ -223,14 +223,18 @@ def train(cfg: TrainPipelineConfig):
|
|||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
step, optimizer, lr_scheduler = load_training_state(
|
||||
cfg.checkpoint_path, optimizer, lr_scheduler
|
||||
)
|
||||
|
||||
num_learnable_params = sum(
|
||||
p.numel() for p in policy.parameters() if p.requires_grad
|
||||
)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
logging.info(
|
||||
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
|
||||
)
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
|
@ -273,7 +277,11 @@ def train(cfg: TrainPipelineConfig):
|
|||
}
|
||||
|
||||
train_tracker = MetricsTracker(
|
||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
|
||||
cfg.batch_size,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
train_metrics,
|
||||
initial_step=step,
|
||||
)
|
||||
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
@ -327,7 +335,9 @@ def train(cfg: TrainPipelineConfig):
|
|||
logging.info(f"Eval policy at step {step}")
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if cfg.policy.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
|
@ -344,7 +354,11 @@ def train(cfg: TrainPipelineConfig):
|
|||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||
}
|
||||
eval_tracker = MetricsTracker(
|
||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
||||
cfg.batch_size,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
eval_metrics,
|
||||
initial_step=step,
|
||||
)
|
||||
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
||||
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
||||
|
|
|
@ -94,9 +94,9 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
|||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
c, h, w = chw_float32_torch.shape
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"expect channel first images, but instead {chw_float32_torch.shape}"
|
||||
assert c < h and c < w, (
|
||||
f"expect channel first images, but instead {chw_float32_torch.shape}"
|
||||
)
|
||||
hwc_uint8_numpy = (
|
||||
(chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
|
||||
)
|
||||
|
|
|
@ -158,7 +158,9 @@ def run_server(
|
|||
400,
|
||||
)
|
||||
dataset_version = (
|
||||
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
str(dataset.meta._version)
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.codebase_version
|
||||
)
|
||||
match = re.search(r"v(\d+)\.", dataset_version)
|
||||
if match:
|
||||
|
@ -166,7 +168,9 @@ def run_server(
|
|||
if major_version < 2:
|
||||
return "Make sure to convert your LeRobotDataset to v2 & above."
|
||||
|
||||
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
|
||||
episode_data_csv_str, columns, ignored_columns = get_episode_data(
|
||||
dataset, episode_id
|
||||
)
|
||||
dataset_info = {
|
||||
"repo_id": f"{dataset_namespace}/{dataset_name}",
|
||||
"num_samples": dataset.num_frames
|
||||
|
@ -208,7 +212,8 @@ def run_server(
|
|||
]
|
||||
|
||||
response = requests.get(
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl",
|
||||
timeout=5,
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Split into lines and parse each line as JSON
|
||||
|
@ -256,7 +261,11 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
|||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
columns = []
|
||||
|
||||
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
|
||||
selected_columns = [
|
||||
col
|
||||
for col, ft in dataset.features.items()
|
||||
if ft["dtype"] in ["float32", "int32"]
|
||||
]
|
||||
selected_columns.remove("timestamp")
|
||||
|
||||
ignored_columns = []
|
||||
|
@ -361,7 +370,8 @@ def get_episode_language_instruction(
|
|||
|
||||
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
||||
response = requests.get(
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json",
|
||||
timeout=5,
|
||||
)
|
||||
response.raise_for_status() # Raises an HTTPError for bad responses
|
||||
dataset_info = response.json()
|
||||
|
|
|
@ -47,7 +47,9 @@ OUTPUT_DIR = Path("outputs/image_transforms")
|
|||
to_pil = ToPILImage()
|
||||
|
||||
|
||||
def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
|
||||
def save_all_transforms(
|
||||
cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
|
||||
):
|
||||
output_dir_all = output_dir / "all"
|
||||
output_dir_all.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
@ -60,7 +62,9 @@ def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir,
|
|||
print(f" {output_dir_all}")
|
||||
|
||||
|
||||
def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
|
||||
def save_each_transform(
|
||||
cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
|
||||
):
|
||||
if not cfg.enable:
|
||||
logging.warning(
|
||||
"No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
|
||||
|
@ -89,9 +93,15 @@ def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir,
|
|||
tf_cfg_kwgs_max[key] = [max_, max_]
|
||||
tf_cfg_kwgs_avg[key] = [avg, avg]
|
||||
|
||||
tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}))
|
||||
tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}))
|
||||
tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}))
|
||||
tf_min = make_transform_from_config(
|
||||
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min})
|
||||
)
|
||||
tf_max = make_transform_from_config(
|
||||
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max})
|
||||
)
|
||||
tf_avg = make_transform_from_config(
|
||||
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg})
|
||||
)
|
||||
|
||||
tf_frame_min = tf_min(original_frame)
|
||||
tf_frame_max = tf_max(original_frame)
|
||||
|
@ -105,7 +115,9 @@ def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir,
|
|||
|
||||
|
||||
@draccus.wrap()
|
||||
def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
|
||||
def visualize_image_transforms(
|
||||
cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5
|
||||
):
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.repo_id,
|
||||
episodes=cfg.episodes,
|
||||
|
|
|
@ -51,7 +51,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
|||
batch = next(iter(dataloader))
|
||||
loss, output_dict = policy.forward(batch)
|
||||
if output_dict is not None:
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict = {
|
||||
k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)
|
||||
}
|
||||
output_dict["loss"] = loss
|
||||
else:
|
||||
output_dict = {"loss": loss}
|
||||
|
@ -69,7 +71,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
|||
param_stats = {}
|
||||
for key, param in policy.named_parameters():
|
||||
param_stats[f"{key}_mean"] = param.mean()
|
||||
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
param_stats[f"{key}_std"] = (
|
||||
param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
policy.reset()
|
||||
|
@ -96,11 +100,15 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
|||
else:
|
||||
actions_queue = train_cfg.policy.n_action_repeats
|
||||
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
actions = {
|
||||
str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)
|
||||
}
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
def save_policy_to_safetensors(
|
||||
output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict
|
||||
):
|
||||
if output_dir.exists():
|
||||
print(f"Overwrite existing safetensors in '{output_dir}':")
|
||||
print(f" - Validate with: `git add {output_dir}`")
|
||||
|
@ -108,7 +116,9 @@ def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: s
|
|||
shutil.rmtree(output_dir)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(
|
||||
ds_repo_id, policy_name, policy_kwargs
|
||||
)
|
||||
save_file(output_dict, output_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, output_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, output_dir / "param_stats.safetensors")
|
||||
|
@ -141,5 +151,7 @@ if __name__ == "__main__":
|
|||
raise RuntimeError("No policies were provided!")
|
||||
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
|
||||
output_dir = (
|
||||
Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
|
||||
)
|
||||
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
|
||||
|
|
|
@ -226,7 +226,13 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
|
|||
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
|
||||
@require_camera
|
||||
def test_camera_rotation(request, camera_type, mock):
|
||||
config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30}
|
||||
config_kwargs = {
|
||||
"camera_type": camera_type,
|
||||
"mock": mock,
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
"fps": 30,
|
||||
}
|
||||
|
||||
# No rotation.
|
||||
camera = make_camera(**config_kwargs, rotation=None)
|
||||
|
|
|
@ -9,7 +9,9 @@ from lerobot.common.envs.configs import EnvConfig
|
|||
from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
|
||||
|
||||
|
||||
def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
|
||||
def create_plugin_code(
|
||||
*, base_class: str = "EnvConfig", plugin_name: str = "test_env"
|
||||
) -> str:
|
||||
"""Creates a dummy plugin module that implements its own EnvConfig subclass."""
|
||||
return f"""
|
||||
from dataclasses import dataclass
|
||||
|
|
|
@ -31,7 +31,11 @@ from lerobot.common.datasets.compute_stats import (
|
|||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||
return (
|
||||
np.ones((3, 32, 32), dtype=dtype)
|
||||
if channel_first
|
||||
else np.ones((32, 32, 3), dtype=dtype)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -61,7 +65,10 @@ def test_sample_indices():
|
|||
assert len(indices) == estimate_num_samples(10)
|
||||
|
||||
|
||||
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
||||
@patch(
|
||||
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
|
||||
side_effect=mock_load_image_as_numpy,
|
||||
)
|
||||
def test_sample_images(mock_load):
|
||||
image_paths = [f"image_{i}.jpg" for i in range(100)]
|
||||
images = sample_images(image_paths)
|
||||
|
@ -74,9 +81,20 @@ def test_sample_images(mock_load):
|
|||
def test_get_feature_stats_images():
|
||||
data = np.random.rand(100, 3, 32, 32)
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
||||
assert (
|
||||
"min" in stats
|
||||
and "max" in stats
|
||||
and "mean" in stats
|
||||
and "std" in stats
|
||||
and "count" in stats
|
||||
)
|
||||
np.testing.assert_equal(stats["count"], np.array([100]))
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
assert (
|
||||
stats["min"].shape
|
||||
== stats["max"].shape
|
||||
== stats["mean"].shape
|
||||
== stats["std"].shape
|
||||
)
|
||||
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
|
@ -145,7 +163,8 @@ def test_compute_episode_stats():
|
|||
}
|
||||
|
||||
with patch(
|
||||
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
|
||||
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
|
||||
side_effect=mock_load_image_as_numpy,
|
||||
):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
|
@ -233,7 +252,13 @@ def test_aggregate_stats():
|
|||
"std": [2.87, 5.87, 8.87],
|
||||
"count": 10,
|
||||
},
|
||||
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
|
||||
"observation.state": {
|
||||
"min": 1,
|
||||
"max": 10,
|
||||
"mean": 5.5,
|
||||
"std": 2.87,
|
||||
"count": 10,
|
||||
},
|
||||
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
|
||||
},
|
||||
{
|
||||
|
@ -244,7 +269,13 @@ def test_aggregate_stats():
|
|||
"std": [3.42, 2.42, 1.42],
|
||||
"count": 15,
|
||||
},
|
||||
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
|
||||
"observation.state": {
|
||||
"min": 2,
|
||||
"max": 15,
|
||||
"mean": 8.5,
|
||||
"std": 3.42,
|
||||
"count": 15,
|
||||
},
|
||||
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
|
||||
},
|
||||
]
|
||||
|
@ -284,28 +315,47 @@ def test_aggregate_stats():
|
|||
for ep_stats in all_stats:
|
||||
for fkey, stats in ep_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
stats[k] = np.array(
|
||||
stats[k], dtype=np.int64 if k == "count" else np.float32
|
||||
)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
stats[k] = stats[k].reshape(
|
||||
3, 1, 1
|
||||
) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
|
||||
# cast to numpy
|
||||
for fkey, stats in expected_agg_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
stats[k] = np.array(
|
||||
stats[k], dtype=np.int64 if k == "count" else np.float32
|
||||
)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
stats[k] = stats[k].reshape(
|
||||
3, 1, 1
|
||||
) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
|
||||
results = aggregate_stats(all_stats)
|
||||
|
||||
for fkey in expected_agg_stats:
|
||||
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
|
||||
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
|
||||
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
|
||||
results[fkey]["min"], expected_agg_stats[fkey]["min"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["max"], expected_agg_stats[fkey]["max"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["mean"], expected_agg_stats[fkey]["mean"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["std"],
|
||||
expected_agg_stats[fkey]["std"],
|
||||
atol=1e-04,
|
||||
rtol=1e-04,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["count"], expected_agg_stats[fkey]["count"]
|
||||
)
|
||||
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
|
||||
|
|
|
@ -72,7 +72,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
|||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
dataset_create = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
|
||||
)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
|
@ -104,7 +106,8 @@ def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
|
||||
ValueError,
|
||||
match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n",
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1)})
|
||||
|
||||
|
@ -113,7 +116,8 @@ def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
|
||||
ValueError,
|
||||
match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n",
|
||||
):
|
||||
dataset.add_frame({"task": "Dummy task"})
|
||||
|
||||
|
@ -122,18 +126,24 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
|
||||
ValueError,
|
||||
match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n",
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
|
||||
dataset.add_frame(
|
||||
{"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
|
||||
ValueError,
|
||||
match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n",
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
|
||||
dataset.add_frame(
|
||||
{"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
@ -141,7 +151,9 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
|||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
||||
match=re.escape(
|
||||
"The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"
|
||||
),
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||
|
||||
|
@ -163,7 +175,9 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
|||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
||||
match=re.escape(
|
||||
"The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"
|
||||
),
|
||||
):
|
||||
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
|
||||
|
||||
|
@ -457,7 +471,9 @@ def test_flatten_unflatten_dict():
|
|||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), (
|
||||
f"{original_d} != {d}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -511,7 +527,13 @@ def test_backward_compatibility(repo_id):
|
|||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
i = int(
|
||||
(
|
||||
dataset.episode_data_index["to"][0].item()
|
||||
- dataset.episode_data_index["from"][0].item()
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
|
|
|
@ -54,7 +54,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n

@pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
def _create_synced_timestamps(
fps: int = 30,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()

@ -69,8 +71,12 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
def _create_unsynced_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 1.1  # Modify a single timestamp just outside tolerance
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
fps=fps
)
timestamps[30] += (
tolerance_s * 1.1
)  # Modify a single timestamp just outside tolerance
return timestamps, episode_indices, episode_data_index

return _create_unsynced_timestamps

@ -81,8 +87,12 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
def _create_slightly_off_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 0.9  # Modify a single timestamp just inside tolerance
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
fps=fps
)
timestamps[30] += (
tolerance_s * 0.9
)  # Modify a single timestamp just inside tolerance
return timestamps, episode_indices, episode_data_index

return _create_slightly_off_timestamps

@ -91,9 +101,13 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
fps: int = 30,
keys: list = DUMMY_MOTOR_FEATURES,
min_max_range: tuple[int, int] = (-10, 10),
) -> dict:
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
delta_timestamps = {
key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys
}
return delta_timestamps

return _create_valid_delta_timestamps

@ -130,7 +144,9 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):

@pytest.fixture(scope="module")
def delta_indices_factory():
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
def _delta_indices(
keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
) -> dict:
return {key: list(range(*min_max_range)) for key in keys}

return _delta_indices

@ -182,7 +198,9 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(
fps, tolerance_s
)
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,

@ -223,7 +241,9 @@ def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(
fps, tolerance_s
)
result = check_delta_timestamps(
delta_timestamps=slightly_off_delta_timestamps,
fps=fps,

@ -33,7 +33,9 @@ from lerobot.scripts.visualize_image_transforms import (
save_all_transforms,
save_each_transform,
)
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import (
ARTIFACT_DIR,
)
from tests.utils import require_x86_64_kernel

@ -80,7 +82,11 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})},
tfs={
"brightness": ImageTransformConfig(
type="ColorJitter", kwargs={"brightness": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(brightness=min_max)

@ -91,7 +97,12 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
enable=True,
tfs={
"contrast": ImageTransformConfig(
type="ColorJitter", kwargs={"contrast": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(contrast=min_max)

@ -103,7 +114,11 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})},
tfs={
"saturation": ImageTransformConfig(
type="ColorJitter", kwargs={"saturation": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(saturation=min_max)

@ -114,7 +129,8 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
def test_get_image_transforms_hue(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
enable=True,
tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(hue=min_max)

@ -126,7 +142,11 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})},
tfs={
"sharpness": ImageTransformConfig(
type="SharpnessJitter", kwargs={"sharpness": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = SharpnessJitter(sharpness=min_max)

@ -342,7 +362,9 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):

# Check if the combined transforms directory exists and contains the right files
combined_transforms_dir = tmp_path / "all"
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
assert combined_transforms_dir.exists(), (
"Combined transforms directory was not created."
)
assert any(combined_transforms_dir.iterdir()), (
"No transformed images found in combined transforms directory."
)

@ -364,9 +386,9 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
for transform in transforms:
transform_dir = tmp_path / transform
assert transform_dir.exists(), f"{transform} directory was not created."
assert any(
transform_dir.iterdir()
), f"No transformed images found in {transform} directory."
assert any(transform_dir.iterdir()), (
f"No transformed images found in {transform} directory."
)

# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [

@ -176,7 +176,9 @@ def test_delta_timestamps_within_tolerance():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
torch.testing.assert_close(
data, torch.tensor([0, 2, 3]), msg="Data does not match expected values"
)
assert not is_pad.any(), "Unexpected padding detected"

@ -212,7 +214,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), (
"Data does not match expected values"
)
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
"Padding does not match expected values"
)

@ -275,7 +279,8 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
online_sampling_ratio=online_sampling_ratio,
)
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
)

@ -297,7 +302,8 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
online_drop_n_last_frames=1,
)
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
)

@ -318,4 +324,6 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
online_sampling_ratio=0.5,
online_drop_n_last_frames=1,
)
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
torch.testing.assert_close(
weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])
)

@ -18,8 +18,13 @@ import torch
from datasets import Dataset
from huggingface_hub import DatasetCard

from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
)
from lerobot.common.datasets.utils import (
create_lerobot_dataset_card,
hf_transform_to_torch,
)

def test_default_parameters():

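Aside: the import changes in this hunk follow the standard parenthesized-import style: a long `from ... import a, b` line becomes a parenthesized block with one name per line and a trailing comma, which keeps later additions to a one-line diff. A small sketch of the same shape using stdlib names (illustrative, not from the repo):

# before: one long line
# from collections import OrderedDict, defaultdict, namedtuple
# after: parenthesized, one name per line, trailing comma
from collections import (
    OrderedDict,
    defaultdict,
    namedtuple,
)
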
@ -210,7 +210,10 @@ def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks = {}
for task_index in range(total_tasks):
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
task_dict = {
"task_index": task_index,
"task": f"Perform action {task_index}.",
}
tasks[task_index] = task_dict
return tasks

@ -297,8 +300,12 @@ def hf_dataset_factory(
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes.values():
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
timestamp_col = np.concatenate(
(timestamp_col, np.arange(ep_dict["length"]) / fps)
)
frame_index_col = np.concatenate(
(frame_index_col, np.arange(ep_dict["length"], dtype=int))
)
episode_index_col = np.concatenate(
(
episode_index_col,

@ -385,7 +392,9 @@ def lerobot_dataset_metadata_factory(
episodes=episodes,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
) as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,

@ -433,7 +442,9 @@ def lerobot_dataset_factory(
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=total_episodes
)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:

@ -466,8 +477,12 @@ def lerobot_dataset_factory(
episodes=episode_dicts,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata"
) as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
) as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,

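Aside: the fixture hunks above keep the parenthesized `with` form (available since Python 3.10) and simply wrap each `patch(...)` call over several lines. A self-contained sketch of the same shape, patching stdlib functions rather than lerobot internals (targets and names are illustrative):

from unittest.mock import patch

with (
    patch("os.getcwd") as mock_getcwd,
    patch("os.cpu_count") as mock_cpu_count,
):
    # both functions are replaced by MagicMock objects inside this block
    mock_getcwd.return_value = "/tmp"
    mock_cpu_count.return_value = 1
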
@ -59,7 +59,9 @@ def stats_path(stats_factory):

@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
def _create_episodes_stats_jsonl_file(
dir: Path, episodes_stats: list[dict] | None = None
) -> Path:
if not episodes_stats:
episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH

@ -99,7 +99,13 @@ def mock_snapshot_download_factory(

# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
meta_files = [
INFO_PATH,
STATS_PATH,
EPISODES_STATS_PATH,
TASKS_PATH,
EPISODES_PATH,
]
all_files.extend(meta_files)

data_files = []

@ -35,5 +35,7 @@ def optimizer(model_params):

@pytest.fixture
def scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
config = VQBeTSchedulerConfig(
num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5
)
return config.build(optimizer, num_training_steps=100)

@ -43,7 +43,9 @@ def test_diffuser_scheduler(optimizer):

def test_vqbet_scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
config = VQBeTSchedulerConfig(
num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5
)
scheduler = config.build(optimizer, num_training_steps=100)
assert isinstance(scheduler, LambdaLR)

@ -59,16 +59,33 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
"observation.state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
}
info = info_factory(
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
total_episodes=1,
total_frames=1,
camera_features=camera_features,
motor_features=motor_features,
)
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta

@ -81,7 +98,8 @@ def test_get_policy_and_config_classes(policy_name: str):
policy_cfg = make_policy_config(policy_name)
assert policy_cls.name == policy_name
assert issubclass(
policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
policy_cfg.__class__,
inspect.signature(policy_cls.__init__).parameters["config"].annotation,
)

@ -92,7 +110,13 @@ def test_get_policy_and_config_classes(policy_name: str):
("lerobot/pusht", "pusht", {}, "diffusion", {}),
("lerobot/pusht", "pusht", {}, "vqbet", {}),
("lerobot/pusht", "pusht", {}, "act", {}),
("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
{"task": "AlohaInsertion-v0"},
"act",
{},
),
(
"lerobot/aloha_sim_insertion_scripted",
"aloha",

@ -172,11 +196,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(
batch_
), "Batch keys are not the same after a forward pass."
assert set(batch) == set(batch_), (
"Batch keys are not the same after a forward pass."
)
assert all(
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
torch.equal(batch[k], batch_[k])
if isinstance(batch[k], torch.Tensor)
else batch[k] == batch_[k]
for k in batch
), "Batch values are not the same after a forward pass."

@ -215,8 +241,12 @@ def test_act_backbone_lr():

cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
dataset=DatasetConfig(
repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]
),
policy=make_policy_config(
"act", optimizer_lr=0.01, optimizer_lr_backbone=0.001
),
)
cfg.validate()  # Needed for auto-setting some parameters

@ -239,7 +269,9 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}

@ -251,7 +283,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}

@ -260,7 +294,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
torch.testing.assert_close(
list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0
)

@pytest.mark.parametrize("insert_temporal_dim", [False, True])

@ -400,7 +436,9 @@ def test_normalize(insert_temporal_dim):
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
def test_backward_compatibility(
ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str
):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should

@ -414,13 +452,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
"""
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
artifact_dir = (
Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
)
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
saved_actions = load_file(artifact_dir / "actions.safetensors")

output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
output_dict, grad_stats, param_stats, actions = get_policy_stats(
ds_repo_id, policy_name, policy_kwargs
)

for key in saved_output_dict:
torch.testing.assert_close(output_dict[key], saved_output_dict[key])

@ -429,8 +471,12 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
for key in saved_param_stats:
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
for key in saved_actions:
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None)  # HACK
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
rtol, atol = (
(2e-3, 5e-6) if policy_name == "diffusion" else (None, None)
)  # HACK
torch.testing.assert_close(
actions[key], saved_actions[key], rtol=rtol, atol=atol
)

def test_act_temporal_ensembler():

@ -180,7 +180,9 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2

replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
replay_cfg = ReplayControlConfig(
episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False
)
replay(robot, replay_cfg)

policy_cfg = ACTConfig()

@ -335,12 +337,12 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
)
dataset = record(robot, rec_cfg)

assert not mock_events[
"rerecord_episode"
], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events[
"exit_early"
], "`exit_early` wasn't properly reset to False"
assert not mock_events["rerecord_episode"], (
"`rerecord_episode` wasn't properly reset to False"
)
assert not mock_events["exit_early"], (
"`exit_early` wasn't properly reset to False"
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"

@ -390,7 +392,9 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):

dataset = record(robot, rec_cfg)

assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert not mock_events["exit_early"], (
"`exit_early` wasn't properly reset to False"
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"

@ -399,7 +403,9 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
[("koch", True, 0), ("koch", True, 1)],
)
@require_robot
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):
def test_record_with_event_stop_recording(
tmp_path, request, robot_type, mock, num_image_writer_processes
):
robot_kwargs = {"robot_type": robot_type, "mock": mock}

if mock:

@ -445,5 +451,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n

dataset = record(robot, rec_cfg)

assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert not mock_events["exit_early"], (
"`exit_early` wasn't properly reset to False"
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"

@ -40,7 +40,10 @@ import pytest
import torch

from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot

@ -131,7 +134,9 @@ def test_robot(tmp_path, request, robot_type, mock):
if "image" in name:
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
continue
torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
torch.testing.assert_close(
captured_observation[name], observation[name], rtol=1e-4, atol=1
)
assert captured_observation[name].shape == observation[name].shape

# Test send_action can run

@ -227,9 +227,9 @@ def test_resume_function(
config_dir = os.path.abspath(
os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")
)
assert os.path.exists(
config_dir
), f"Config directory does not exist at {config_dir}"
assert os.path.exists(config_dir), (
f"Config directory does not exist at {config_dir}"
)

with initialize_config_dir(
config_dir=config_dir, job_name="test_app", version_base="1.2"

@ -26,10 +26,16 @@ from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
from lerobot.common.robot_devices.motors.utils import (
    make_motors_bus as make_motors_bus_device,
)
from lerobot.common.utils.import_utils import is_package_available

DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
DEVICE = (
os.environ.get("LEROBOT_TEST_DEVICE", "cuda")
if torch.cuda.is_available()
else "cpu"
)

TEST_ROBOT_TYPES = []
for robot_type in available_robots:

@ -45,7 +51,9 @@ for motor_type in available_motors:

# Camera indices used for connecting physical cameras
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
INTELREALSENSE_SERIAL_NUMBER = int(
os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)
)

DYNAMIXEL_PORT = os.environ.get(
"LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081"

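Aside: the `DEVICE` rewrite in the hunk above shows how a long conditional expression gets wrapped: the whole ternary is enclosed in parentheses so each branch can sit on its own line. A torch-free sketch of the pattern (the flag is an illustrative stand-in for `torch.cuda.is_available()`):

import os

CUDA_AVAILABLE = False  # illustrative stand-in for torch.cuda.is_available()
DEVICE = (
    os.environ.get("LEROBOT_TEST_DEVICE", "cuda")
    if CUDA_AVAILABLE
    else "cpu"
)
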
@ -18,7 +18,10 @@ from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker

@pytest.fixture
def mock_metrics():
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
return {
"loss": AverageMeter("loss", ":.3f"),
"accuracy": AverageMeter("accuracy", ":.2f"),
}

def test_average_meter_initialization():

@ -58,7 +61,11 @@ def test_average_meter_str():

def test_metrics_tracker_initialization(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=10,
)
assert tracker.steps == 10
assert tracker.samples == 10 * 32

@ -70,7 +77,11 @@ def test_metrics_tracker_initialization(mock_metrics):

def test_metrics_tracker_step(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=5,
)
tracker.step()
assert tracker.steps == 6

@ -80,7 +91,9 @@ def test_metrics_tracker_step(mock_metrics):

def test_metrics_tracker_getattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
assert tracker.loss == mock_metrics["loss"]
assert tracker.accuracy == mock_metrics["accuracy"]
with pytest.raises(AttributeError):

@ -88,13 +101,17 @@ def test_metrics_tracker_getattr(mock_metrics):

def test_metrics_tracker_setattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss = 2.0
assert tracker.loss.val == 2.0

def test_metrics_tracker_str(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss.update(3.456, 1)
tracker.accuracy.update(0.876, 1)
output = str(tracker)

@ -103,7 +120,9 @@ def test_metrics_tracker_str(mock_metrics):

def test_metrics_tracker_to_dict(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss.update(5, 2)
metrics_dict = tracker.to_dict()
assert isinstance(metrics_dict, dict)

@ -112,7 +131,9 @@ def test_metrics_tracker_to_dict(mock_metrics):

def test_metrics_tracker_reset_averages(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss.update(10, 3)
tracker.accuracy.update(0.95, 5)
tracker.reset_averages()

@ -118,5 +118,9 @@ def test_seeded_context(fixed_seed):
seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item())

assert seeded_val1 == seeded_val2
assert all(a != b for a, b in zip(val1, seeded_val1, strict=True))  # changed inside the context
assert all(a != b for a, b in zip(val2, seeded_val2, strict=True))  # changed again after exiting
assert all(
a != b for a, b in zip(val1, seeded_val1, strict=True)
)  # changed inside the context
assert all(
a != b for a, b in zip(val2, seeded_val2, strict=True)
)  # changed again after exiting

@ -91,7 +91,9 @@ def test_save_training_state(tmp_path, optimizer, scheduler):

def test_save_load_training_state(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler)
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler)
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(
tmp_path, optimizer, scheduler
)
assert loaded_step == 10
assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler